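## Transformer-based encoder–decoder

A minimal sketch of a sequence model built from PyTorch's `nn.TransformerEncoder` and `nn.TransformerDecoder`: the encoder reads `input_len` clip feature vectors, and the decoder predicts a sequence of up to `output_len` class tokens (the vocabulary includes a start-of-sequence token, SOS).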

In [14]:
import torch
import torch.nn as nn

In [18]:
device = "cpu"
seed = 0  # fixed seed so "Restart & Run All" reproduces the same weights and batch
torch.manual_seed(seed)

d_model = 256
bs = 8
sos_idx = 1
vocab_size = 15  # num of classes inc. SOS
input_len = 4  # num of clips
output_len = 25  # max seq len

# Define the model.
# nn.TransformerEncoder / nn.TransformerDecoder clone the template layer they are
# given, so moving the stacks to `device` is sufficient.
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=6).to(device)

decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=4, batch_first=True)
decoder = nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=6).to(device)

# Token embedding for decoder inputs and the classification head over the vocabulary.
# Both must live on the same device as the encoder/decoder; otherwise any device
# other than "cpu" fails with a device-mismatch error in the forward pass.
decoder_emb = nn.Embedding(vocab_size, d_model).to(device)
predictor = nn.Linear(d_model, vocab_size).to(device)

# NOTE(review): no positional encoding is added to the encoder inputs or to the
# decoder embeddings, so attention has no explicit notion of token order.
# Consider adding a learned or sinusoidal positional encoding.

# A single random batch: x holds one d_model-sized feature vector per clip,
# y holds target class indices in [0, vocab_size).
x = torch.randn(bs, input_len, d_model, device=device)
y = torch.randint(0, vocab_size, (bs, output_len), device=device)

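## Forward pass

Run the encoder and decoder once on the random batch, feeding the ground-truth tokens to the decoder under a causal mask, and check the output shapes.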

In [19]:
# Forward pass of the model with teacher forcing.
encoder_output = encoder(x)  # (bs, input_len, d_model)

# Shift the targets right by one and start every row with SOS. With the causal
# mask, position t then sees only y[:, :t], and its output is the prediction
# for y[:, t]. Feeding y unshifted would let each position see the very token
# it is supposed to predict. This also matches inference, which starts from SOS.
sos_column = torch.full((y.size(0), 1), sos_idx, dtype=torch.long, device=device)
decoder_input = torch.cat([sos_column, y[:, :-1]], dim=1)  # (bs, output_len)

tgt_emb = decoder_emb(decoder_input)

# Causal mask: -inf strictly above the diagonal blocks attention to future positions.
# Built directly rather than instantiating a whole nn.Transformer just for its helper.
tgt_len = decoder_input.size(1)
tgt_mask = torch.triu(
    torch.full((tgt_len, tgt_len), float("-inf"), device=device), diagonal=1
)

decoder_output = decoder(tgt=tgt_emb, tgt_mask=tgt_mask, memory=encoder_output)
output = predictor(decoder_output)  # logits, (bs, output_len, vocab_size)

print("Encoder output shape:", encoder_output.shape)
print("Decoder output shape:", decoder_output.shape)
print("Final output shape:", output.shape)



Encoder output shape: torch.Size([8, 4, 256])
Decoder output shape: torch.Size([8, 25, 256])
Final output shape: torch.Size([8, 25, 15])


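## Inference

Greedy autoregressive decoding: every sequence starts with the SOS token, and at each step the most likely class is appended.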

In [12]:
# Greedy autoregressive decoding.
# Modules are in training mode after construction, so dropout would be active
# and make predictions random; switch to eval mode for inference.
# Call .train() on these modules before any further training.
for module in (encoder, decoder, decoder_emb, predictor):
    module.eval()

with torch.no_grad():  # generation needs no autograd graph
    encoder_output = encoder(x)  # (bs, input_len, d_model)

    # Initialize every decoder sequence with sos_idx (start-of-sequence token);
    # positions 1..output_len-1 are filled in one step at a time below.
    output = torch.full(
        (encoder_output.size(0), output_len), sos_idx, dtype=torch.long, device=device
    )

    for t in range(1, output_len):
        tgt_emb = decoder_emb(output[:, :t])  # (bs, t, d_model)
        # Causal mask over the t tokens produced so far.
        tgt_mask = torch.triu(
            torch.full((t, t), float("-inf"), device=device), diagonal=1
        )

        decoder_output = decoder(tgt=tgt_emb, memory=encoder_output, tgt_mask=tgt_mask)

        # The next token is predicted from the last position only.
        logits_t = predictor(decoder_output[:, -1, :])  # (bs, vocab_size)
        output[:, t] = logits_t.argmax(dim=-1)

print("Output shape:", output.shape)

Output shape: torch.Size([8, 25])
