In [41]:
import torch
import torch.nn as nn
import torch.optim as optim
import math

# 1. Define a Transformer model
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_encoder_layers, dim_feedforward, num_classes):
        super(TransformerModel, self).__init__()
        self.d_model = d_model
        self.pos_encoder = PositionalEncoding(d_model)
        self.embedding = nn.Linear(input_dim, d_model)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers)
        self.decoder = nn.Linear(d_model, num_classes)

    def forward(self, src):
        # src shape: (batch_size, seq_len, input_dim)
        src = self.embedding(src) * math.sqrt(self.d_model)
        # After embedding: (batch_size, seq_len, d_model)
        # Transformer expects (seq_len, batch_size, d_model) if batch_first=False (default)
        # or (batch_size, seq_len, d_model) if batch_first=True.
        # Our layer is batch_first=True.
        
        # The PositionalEncoding expects (seq_len, batch, dim)
        # so we need to permute before and after
        src = src.permute(1, 0, 2)
        src = self.pos_encoder(src)
        src = src.permute(1, 0, 2)

        output = self.transformer_encoder(src)
        # output shape: (batch_size, seq_len, d_model)
        
        # We can average the output over the sequence length for classification
        output = output.mean(dim=1)
        output = self.decoder(output)
        return output

In [42]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [43]:
# 2. Create four hard-coded example inputs in two batches with different sequence lengths
# The transformer model expects an input dimension, let's make it 1 for simplicity
input_dim = 1

# Batch 1 will have sequence length 10
input1 = torch.randn(2, 10, input_dim).to(device) # batch of 2
target1 = torch.randint(0, 5, (2,)).to(device)

# Batch 2 will have sequence length 15
input2 = torch.randn(2, 15, input_dim).to(device) # batch of 2
target2 = torch.randint(0, 5, (2,)).to(device)

input_batch1 = input1
target_batch1 = target1
input_batch2 = input2
target_batch2 = target2

print(f"Shape of input_batch1: {input_batch1.shape}")
print(f"Shape of input_batch2: {input_batch2.shape}")
print(f"Shape of target_batch1: {target_batch1.shape}")
print(f"Shape of target_batch2: {target_batch2.shape}")

Shape of input_batch1: torch.Size([2, 10, 1])
Shape of input_batch2: torch.Size([2, 15, 1])
Shape of target_batch1: torch.Size([2])
Shape of target_batch2: torch.Size([2])


In [44]:
# 3. Create a model and optimizer
input_dim = 1
d_model = 32
nhead = 4
num_encoder_layers = 2
dim_feedforward = 128
num_classes = 5

model = TransformerModel(input_dim, d_model, nhead, num_encoder_layers, dim_feedforward, num_classes).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# 4. Compile the model
compiled_model = torch.compile(model)

In [45]:
# Evaluate before training
with torch.no_grad():
    output_before1 = compiled_model(input_batch1)
    loss_before1 = criterion(output_before1, target_batch1)
    output_before2 = compiled_model(input_batch2)
    loss_before2 = criterion(output_before2, target_batch2)
    print(f"Loss before training (batch 1): {loss_before1.item()}")
    print(f"Loss before training (batch 2): {loss_before2.item()}")

Loss before training (batch 1): 1.3857465982437134
Loss before training (batch 2): 2.0797219276428223


In [46]:
# 5. Two train steps
print("First train step (batch 1)")
optimizer.zero_grad()
output1 = compiled_model(input_batch1)
loss1 = criterion(output1, target_batch1)
loss1.backward()
optimizer.step()
print(f"Loss 1 (batch 1): {loss1.item()}")

First train step (batch 1)
Loss 1 (batch 1): 1.4687457084655762
Loss 1 (batch 1): 1.4687457084655762


In [47]:
print("\nSecond train step (batch 2)")
optimizer.zero_grad()
output2 = compiled_model(input_batch2)
loss2 = criterion(output2, target_batch2)
loss2.backward()
optimizer.step()
print(f"Loss 2 (batch 2): {loss2.item()}")


Second train step (batch 2)
Loss 2 (batch 2): 2.028371810913086


In [48]:
# Evaluate after training
with torch.no_grad():
    output_after1 = compiled_model(input_batch1)
    loss_after1 = criterion(output_after1, target_batch1)
    output_after2 = compiled_model(input_batch2)
    loss_after2 = criterion(output_after2, target_batch2)
    print(f"Loss after training (batch 1): {loss_after1.item()}")
    print(f"Loss after training (batch 2): {loss_after2.item()}")
    
    print(f"\nLoss for batch 1 went from {loss_before1.item()} to {loss_after1.item()}, delta: {loss_after1.item() - loss_before1.item()} (lower is better)")
    print(f"Loss for batch 2 went from {loss_before2.item()} to {loss_after2.item()}, delta: {loss_after2.item() - loss_before2.item()} (lower is better)")

Loss after training (batch 1): 1.256601095199585
Loss after training (batch 2): 1.7875232696533203

Loss for batch 1 went from 1.3857465982437134 to 1.256601095199585, delta: -0.12914550304412842 (lower is better)
Loss for batch 2 went from 2.0797219276428223 to 1.7875232696533203, delta: -0.29219865798950195 (lower is better)
