# Looped Transformers for Length Generalization
This notebook compares looped transformers to simple ones and shows how they can generalize out of distribution to with addition.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt
import numpy as np

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {device}")

Running on: cuda


In [2]:
# ==========================================
# 1. DATA GENERATION
# ==========================================
def generate_batch(batch_size, bits):
    X, Y = [], []
    for _ in range(batch_size):
        a = random.randint(0, 2**bits - 1)
        b = random.randint(0, 2**bits - 1)
        c = a + b
        
        a_bin = [int(x) for x in format(a, f'0{bits}b')[::-1]]
        b_bin = [int(x) for x in format(b, f'0{bits}b')[::-1]]
        c_bin = [int(x) for x in format(c, f'0{bits}b')[::-1]]
        
        # Handle overflow/padding
        if len(c_bin) < bits: 
            c_bin += [0] * (bits - len(c_bin))
        c_bin = c_bin[:bits]
        
        x_seq = [[a_val, b_val] for a_val, b_val in zip(a_bin, b_bin)]
        X.append(x_seq)
        Y.append(c_bin)
        
    return torch.tensor(X).float().to(device), torch.tensor(Y).long().to(device)

# Example usage:
X_example, Y_example = generate_batch(2, bits=8)
print("Example Input (X):", X_example)
print("Example Output (Y):", Y_example)

Example Input (X): tensor([[[0., 1.],
         [1., 1.],
         [1., 1.],
         [0., 1.],
         [0., 1.],
         [1., 0.],
         [0., 0.],
         [1., 1.]],

        [[1., 0.],
         [1., 1.],
         [1., 0.],
         [1., 0.],
         [0., 1.],
         [1., 1.],
         [1., 0.],
         [0., 1.]]], device='cuda:0')
Example Output (Y): tensor([[1, 0, 1, 0, 0, 0, 1, 0],
        [1, 0, 0, 0, 0, 1, 0, 0]], device='cuda:0')


In [3]:
class BaseTransformer(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_layers=4, max_len=100):
        super().__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        
        # Absolute Positional Encoding
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, d_model) * 0.1)
                
        # Transformer Encoder Block
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True, norm_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_head = nn.Linear(d_model, 2) # MLP that outputs 2 classes (0 or 1)

    def forward(self, x, num_loops=None):
        seq_len = x.shape[1]
        # Add absolute position (up to current length)
        # If seq_len > max_len, this crashes or requires slicing (we slice here)
        pos = self.pos_embedding[:, :seq_len, :]
        x = self.embedding(x) + pos
        x = self.transformer(x)
        return self.output_head(x)
    
    
class UpdatedTransformer(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_layers=4, max_len=100):
        super().__init__()
        self.embedding = nn.Linear(input_dim, d_model)
        
        # Absolute Positional Encoding
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, d_model) * 0.1)
                
        # Transformer Encoder Block
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True, norm_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_head = nn.Linear(d_model, 2) # MLP that outputs 2 classes (0 or 1)

    def forward(self, x, num_loops=None):
        seq_len = x.shape[1]
        # Add absolute position (up to current length)
        # If seq_len > max_len, this crashes or requires slicing (we slice here)
        pos = self.pos_embedding[:, :seq_len, :]
        x = self.embedding(x) + pos
        x = self.transformer(x)
        return self.output_head(x)
    
    

In [4]:
def generate_decimal_batch(batch_size, length):
    X, Y = [], []
    
    # We always define the sequence length as N + 1
    # This guarantees space for the final carry.
    seq_len = length + 1
    
    for _ in range(batch_size):
        # 1. Generate random numbers of 'length' digits
        # e.g., if length=2, max is 99.
        max_val = (10 ** length) - 1
        a = random.randint(0, max_val)
        b = random.randint(0, max_val)
        c = a + b # Real sum (e.g., 99+99=198)
        
        # 2. Extract digits for the FULL sequence (length + 1)
        # We grab digits 0 to 'length'.
        # If the number is small (e.g. 50+20=70), the last digit will naturally be 0.
        a_seq = [(a // (10**i)) % 10 for i in range(seq_len)]
        b_seq = [(b // (10**i)) % 10 for i in range(seq_len)]
        c_seq = [(c // (10**i)) % 10 for i in range(seq_len)]
        
        # 3. Stack inputs
        # The last position of input will always be [0, 0]
        # The AI sees this and thinks "Ah, a blank space for me to write the carry."
        x_seq = [[d_a, d_b] for d_a, d_b in zip(a_seq, b_seq)]
        
        X.append(x_seq)
        Y.append(c_seq)
        
    return torch.tensor(X).long().to(device), torch.tensor(Y).long().to(device)

In [None]:
def train_and_compare():
    D_MODEL = 128 # Slightly larger for decimal complexity
    NHEAD = 4
    
    # 1. Setup Models
    std_model = BaseTransformer(D_MODEL, NHEAD, num_layers=4).to(device)
    loop_model = UpdatedTransformer(D_MODEL, NHEAD, num_layers=4).to(device) # Window=4
    
    optimizers = {
        "Standard": optim.Adam(std_model.parameters(), lr=0.001),
        "Looped": optim.Adam(loop_model.parameters(), lr=0.001)
    }
    
    criterion = nn.CrossEntropyLoss()
    
    print("--- Training Phase (8-Digit Numbers) ---")
    # Train both for 2000 steps
    for step in range(2001):
        X, Y = generate_decimal_batch(64, length=8)
        
        # Train Standard
        opt = optimizers["Standard"]
        opt.zero_grad()
        out = std_model(X)
        loss_std = criterion(out.reshape(-1, 10), Y.reshape(-1))
        loss_std.backward()
        opt.step()
        
        # Train Looped
        opt = optimizers["Looped"]
        opt.zero_grad()
        out = loop_model(X, num_loops=None) # 8 digits + buffer
        loss_loop = criterion(out.reshape(-1, 10), Y.reshape(-1))
        loss_loop.backward()
        opt.step()
        
        if step % 500 == 0:
            print(f"Step {step}: Std Loss {loss_std.item():.3f} | Loop Loss {loss_loop.item():.3f}")

    print("\n--- Generalization Test (Length 8 to 24) ---")
    lengths = [8, 12, 16, 20, 24]
    std_accs = []
    loop_accs = []
    
    for L in lengths:
        with torch.no_grad():
            X_test, Y_test = generate_decimal_batch(100, length=L)
            
            # Test Standard
            out_std = std_model(X_test)
            acc_s = (out_std.argmax(-1) == Y_test).float().mean().item()
            std_accs.append(acc_s)
            
            # Test Looped (Scale loops with length)
            out_loop = loop_model(X_test, num_loops=L + 5)
            acc_l = (out_loop.argmax(-1) == Y_test).float().mean().item()
            loop_accs.append(acc_l)
            
    # Print Table
    print(f"{'Length':<10} | {'Standard Acc':<15} | {'Looped Acc':<15}")
    print("-" * 45)
    for i, L in enumerate(lengths):
        print(f"{L:<10} | {std_accs[i]:.2f}{'':<11} | {loop_accs[i]:.2f}")

    # Plot
    plt.figure(figsize=(8, 5))
    plt.plot(lengths, std_accs, 'r-o', label='Standard (Memorization)')
    plt.plot(lengths, loop_accs, 'g-o', label='Looped (Abstraction)')
    plt.axvline(x=8, color='gray', linestyle='--', label='Training Cutoff')
    plt.title("Decimal Addition: Generalization to Long Numbers")
    plt.xlabel("Number of Digits")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

train_and_compare()

TypeError: BaseTransformer.__init__() missing 1 required positional argument: 'nhead'