In [4]:
import torch
import torch.nn as nn

In [68]:
# Define network
class Teamformer(nn.Module):
    def __init__(self, rob_state_size, targ_state_size,
                 d_model, enc_heads, enc_num_layers,
                  dec_heads, dec_num_layers, 
                  num_acts, num_heteros, max_seq_len):

        super().__init__()
        self.num_heteros = num_heteros

        self.loc_embed = nn.Linear(2, d_model)
        self.hetero_embed = nn.Linear(num_heteros, d_model)
        self.rob_spec_embed = nn.Linear(rob_state_size, d_model)
        self.targ_spec_embed = nn.Linear(targ_state_size, d_model)
        # PAD, SOS, EOS
        self.token_embed = nn.Embedding(3, d_model, padding_idx=0)

        enc_layer = nn.TransformerEncoderLayer(d_model, enc_heads, batch_first=True, norm_first=True)
        self.enc = nn.TransformerEncoder(enc_layer, enc_num_layers, enable_nested_tensor=False)
        dec_layer = nn.TransformerDecoderLayer(d_model, dec_heads, batch_first=True, norm_first=True)
        self.dec = nn.TransformerDecoder(dec_layer, dec_num_layers)

        self.act_layer = nn.Linear(d_model, num_acts)
        self.max_seq_len = max_seq_len

    def forward(self, x: list[torch.tensor]):
        rob_state, targ_state = x
        batch = rob_state.shape[0]

        # Encode team and task states
        team_locs, team_caps, team_info = torch.split(rob_state, [2, self.num_heteros, 2], dim=2)
        rob_embed = self.loc_embed(team_locs) + self.hetero_embed(team_caps) + self.rob_spec_embed(team_info)
        task_locs, task_reqs, task_info = torch.split(targ_state, [2, self.num_heteros, 2], dim=2)
        targ_embed = self.loc_embed(task_locs) + self.hetero_embed(task_reqs) + self.targ_spec_embed(task_info)

        fin_in = torch.cat((rob_embed, targ_embed), dim=1)
        enc_out = self.enc(fin_in)

        actions = torch.zeros((batch, self.num_heteros + 1))
        action_embeds = self.token_embed(torch.zeros((batch, 1), dtype=int))
        # Autoregress
        for i in range(self.max_seq_len):
            dec_out = self.dec(action_embeds, enc_out, tgt_mask=nn.Transformer.generate_square_subsequent_mask(action_embeds.shape[1]), tgt_is_causal=True)
            # Note, the action mask is removed for clarity
            logits = self.act_layer(dec_out[:, -1:])

            d = torch.distributions.Categorical(logits=logits)
            action = d.sample() # or d.mode() during inference
            actions[torch.arange(batch), action] += 1

            act = torch.zeros((batch, self.num_heteros + 1))
            act[torch.arange(batch), action] += 1

            act_embed = self.hetero_embed(act[:, :self.num_heteros]).unsqueeze(1)
            action_embeds = torch.cat((action_embeds, act_embed), dim=1)
        return actions




In [74]:
# Toy forward pass on random inputs to check output shapes.
torch.manual_seed(0)  # weight init and action sampling are stochastic; fix the seed for reproducibility

batch = 2
num_robs = 3
num_targs = 4

num_heteros = 5
rob_info_size = 2
task_info_size = 2

d_model = 128
enc_heads = 4
enc_layers = 2
dec_heads = 4
dec_layers = 2

# Normally, the heteros are ints, but random floats suffice for a shape check
rob_state = torch.rand((batch, num_robs, 2 + num_heteros + rob_info_size))
targ_state = torch.rand((batch, num_targs, 2 + num_heteros + task_info_size))

model = Teamformer(
    rob_state_size=rob_info_size,
    targ_state_size=task_info_size,
    d_model=d_model,
    enc_heads=enc_heads,
    enc_num_layers=enc_layers,
    dec_heads=dec_heads,
    dec_num_layers=dec_layers,
    num_acts=num_heteros + 1,
    num_heteros=num_heteros,
    max_seq_len=num_robs + 1,
)

# Action counts, shape (batch, num_heteros + 1)
out = model([rob_state, targ_state])
assert out.shape == (batch, num_heteros + 1)
out.shape

torch.Size([2, 6])
