In [1]:
# install: pip install transformer-lens

import torch
import torch.nn as nn
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Configuration
config = HookedTransformerConfig(
    n_layers=2,
    n_heads=8,
    d_model=128,
    d_head=16,  # d_model / n_heads
    d_mlp=None,  # No MLPs (attention-only)
    act_fn=None,  # No activation (no MLPs)
    attention_dir="causal",  # Causal attention
    attn_only=True,  # Attention-only model
    normalization_type=None,  # No LayerNorm for simplicity
    d_vocab=50,  # 26 letters + 10 digits + special tokens
    n_ctx=50,  # Max sequence length
    init_weights=True,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

model = HookedTransformer(config)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model device: {next(model.parameters()).device}")

Model parameters: 151,346
Model device: cuda:0


In [2]:
# Simple character-level tokenizer
class CountingTokenizer:
    def __init__(self):
        # Vocabulary: letters + digits + special tokens
        self.chars = list("abcdefghijklmnopqrstuvwxyz0123456789")
        self.special = ["<PAD>", "<BOS>", "<EOS>", ":", " ", "Count", "the", "letter", "in"]
        
        self.vocab = self.special + self.chars
        self.vocab_size = len(self.vocab)
        
        self.char_to_id = {c: i for i, c in enumerate(self.vocab)}
        self.id_to_char = {i: c for i, c in enumerate(self.vocab)}
    
    def encode(self, text):
        """Convert text to token IDs"""
        tokens = []
        i = 0
        while i < len(text):
            # Try multi-char tokens first
            if text[i:i+5] == "Count":
                tokens.append(self.char_to_id["Count"])
                i += 5
            elif text[i:i+3] == "the":
                tokens.append(self.char_to_id["the"])
                i += 3
            elif text[i:i+6] == "letter":
                tokens.append(self.char_to_id["letter"])
                i += 6
            elif text[i:i+2] == "in":
                tokens.append(self.char_to_id["in"])
                i += 2
            else:
                tokens.append(self.char_to_id[text[i]])
                i += 1
        return tokens
    
    def decode(self, ids):
        """Convert token IDs to text"""
        return "".join([self.id_to_char[i] for i in ids])

tokenizer = CountingTokenizer()

In [3]:
import random

def generate_counting_example(target_letter='a', multiplicity_range=(1, 2), 
                              length_range=(5, 10), tokenizer=None):
    """
    Generate: "Count the letter a in: banana" -> "3"
    Format: [question tokens] [answer token]
    """
    # Sample words with target letter
    words_with_target = []
    count = random.randint(*multiplicity_range)
    
    # Generate string with exact count of target letter
    chars = list("abcdefghijklmnopqrstuvwxyz")
    chars.remove(target_letter)
    
    length = random.randint(*length_range)
    string_chars = random.choices(chars, k=length - count)
    
    # Insert target letters
    positions = random.sample(range(length), count)
    for pos in positions:
        string_chars.insert(pos, target_letter)
    
    input_string = "".join(string_chars[:length])
    
    # Format question
    question = f"Count the letter {target_letter} in: {input_string}"
    answer = str(count)
    
    # Tokenize
    question_tokens = tokenizer.encode(question)
    answer_token = tokenizer.encode(answer)[0]  # Single digit
    
    # Combine: question + answer
    full_tokens = question_tokens + [answer_token]
    
    return {
        'tokens': full_tokens,
        'question_length': len(question_tokens),  # For loss masking
        'answer': count,
        'text': question + answer
    }

# Test
tokenizer = CountingTokenizer()
example = generate_counting_example(tokenizer=tokenizer)
print(f"Text: {example['text']}")
print(f"Tokens: {example['tokens']}")
print(f"Question length: {example['question_length']}")
print(f"Answer: {example['answer']}")

Text: Count the letter a in: nidpah1
Tokens: [5, 4, 6, 4, 7, 4, 9, 4, 8, 3, 4, 22, 17, 12, 24, 9, 16, 36]
Question length: 17
Answer: 1


In [4]:
from torch.utils.data import Dataset, DataLoader

class CountingDataset(Dataset):
    def __init__(self, n_examples=10000, difficulty='easy', tokenizer=None):
        """
        difficulty: 'easy', 'bpe-hard', 'mult-hard', etc.
        """
        self.tokenizer = tokenizer
        self.examples = []
        
        # Set parameters based on difficulty
        if difficulty == 'easy':
            mult_range = (1, 2)
            len_range = (5, 10)
        elif difficulty == 'mult-hard':
            mult_range = (3, 10)
            len_range = (5, 10)
        elif difficulty == 'length-hard':
            mult_range = (1, 2)
            len_range = (20, 50)
        elif difficulty == 'all-hard':
            mult_range = (3, 10)
            len_range = (20, 50)
        
        # Generate examples
        target_letters = list("abcdefghijklmnopqrstuvwxyz")
        for _ in range(n_examples):
            target = random.choice(target_letters)
            example = generate_counting_example(
                target_letter=target,
                multiplicity_range=mult_range,
                length_range=len_range,
                tokenizer=tokenizer
            )
            self.examples.append(example)
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        return self.examples[idx]

def collate_fn(batch, pad_id=0, max_len=50):
    """Pad sequences to same length"""
    # Pad tokens
    tokens = [ex['tokens'] for ex in batch]
    max_batch_len = min(max(len(t) for t in tokens), max_len)
    
    padded_tokens = []
    masks = []  # Loss mask: 1 for answer token, 0 elsewhere
    
    for ex in batch:
        seq = ex['tokens'][:max_batch_len]
        q_len = min(ex['question_length'], max_batch_len - 1)
        
        # Pad sequence
        padded = seq + [pad_id] * (max_batch_len - len(seq))
        padded_tokens.append(padded)
        
        # Create mask: only compute loss on answer token
        mask = [0] * max_batch_len
        if q_len < len(seq):  # If answer token exists
            mask[q_len] = 1  # Answer is right after question
        masks.append(mask)
    
    return {
        'input_ids': torch.tensor(padded_tokens, dtype=torch.long),
        'loss_mask': torch.tensor(masks, dtype=torch.float),
        'answers': torch.tensor([ex['answer'] for ex in batch], dtype=torch.long)
    }

# Create dataloaders
tokenizer = CountingTokenizer()
train_dataset = CountingDataset(n_examples=10000, difficulty='easy', tokenizer=tokenizer)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, 
                         collate_fn=collate_fn)

# Test batch
batch = next(iter(train_loader))
print(f"Input shape: {batch['input_ids'].shape}")
print(f"Mask shape: {batch['loss_mask'].shape}")
print(f"Example tokens: {batch['input_ids'][0]}")
print(f"Example mask: {batch['loss_mask'][0]}")

Input shape: torch.Size([64, 22])
Mask shape: torch.Size([64, 22])
Example tokens: tensor([ 5,  4,  6,  4,  7,  4, 12,  4,  8,  3,  4, 18, 31, 25, 30, 16, 31, 12,
        36,  0,  0,  0])
Example mask: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 0., 0., 0.])


In [5]:
import torch.optim as optim
from tqdm import tqdm

def train_model(model, train_loader, n_epochs=10, lr=1e-3, device='cuda'):
    """
    Train with MASKED loss (only on answer token)
    
    Classification loss: CrossEntropy on vocabulary
    (Can also try regression loss on digit value)
    """
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    
    # Classification loss (predict token ID)
    criterion = nn.CrossEntropyLoss(reduction='none')  # Per-token loss
    
    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}")
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)  # [batch, seq_len]
            loss_mask = batch['loss_mask'].to(device)  # [batch, seq_len]
            
            # Forward pass
            # Input: all tokens except last
            # Target: all tokens except first (shifted by 1)
            logits = model(input_ids[:, :-1])  # [batch, seq_len-1, vocab]
            targets = input_ids[:, 1:]  # [batch, seq_len-1]
            
            # Compute loss
            loss_per_token = criterion(
                logits.reshape(-1, logits.size(-1)),  # [batch*(seq_len-1), vocab]
                targets.reshape(-1)  # [batch*(seq_len-1)]
            )
            loss_per_token = loss_per_token.reshape(targets.shape)  # [batch, seq_len-1]
            
            # Apply mask: only compute loss on answer token
            mask = loss_mask[:, 1:]  # Align with targets
            masked_loss = (loss_per_token * mask).sum() / mask.sum()
            
            # Backward pass
            optimizer.zero_grad()
            masked_loss.backward()
            optimizer.step()
            
            # Metrics
            total_loss += masked_loss.item()
            
            # Accuracy: check if predicted answer digit is correct
            preds = logits.argmax(dim=-1)  # [batch, seq_len-1]
            answer_positions = mask.bool()
            if answer_positions.any():
                correct += (preds[answer_positions] == targets[answer_positions]).sum().item()
                total += answer_positions.sum().item()
            
            pbar.set_postfix({'loss': masked_loss.item(), 
                            'acc': correct/total if total > 0 else 0})
        
        scheduler.step()
        
        print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}, "
              f"Acc={correct/total:.4f}")
        
        # Save checkpoint
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'checkpoint_epoch_{epoch+1}.pt')
    
    return model

# Train!
model = HookedTransformer(config)
trained_model = train_model(model, train_loader, n_epochs=20, lr=1e-3)

Moving model to device:  cuda


Epoch 1/20: 100%|██████████| 157/157 [00:01<00:00, 102.28it/s, loss=0.691, acc=0.499]


Epoch 1: Loss=0.8805, Acc=0.4994


Epoch 2/20: 100%|██████████| 157/157 [00:01<00:00, 126.94it/s, loss=0.775, acc=0.499]


Epoch 2: Loss=0.6979, Acc=0.4988


Epoch 3/20: 100%|██████████| 157/157 [00:01<00:00, 127.86it/s, loss=0.739, acc=0.508]


Epoch 3: Loss=0.6984, Acc=0.5085


Epoch 4/20: 100%|██████████| 157/157 [00:01<00:00, 126.14it/s, loss=0.699, acc=0.506]


Epoch 4: Loss=0.7002, Acc=0.5064


Epoch 5/20: 100%|██████████| 157/157 [00:01<00:00, 130.63it/s, loss=0.678, acc=0.505]


Epoch 5: Loss=0.6970, Acc=0.5054


Epoch 6/20: 100%|██████████| 157/157 [00:01<00:00, 126.98it/s, loss=0.674, acc=0.516]


Epoch 6: Loss=0.6925, Acc=0.5162


Epoch 7/20: 100%|██████████| 157/157 [00:01<00:00, 131.47it/s, loss=0.606, acc=0.536]


Epoch 7: Loss=0.6901, Acc=0.5359


Epoch 8/20: 100%|██████████| 157/157 [00:01<00:00, 128.25it/s, loss=0.791, acc=0.571]


Epoch 8: Loss=0.6801, Acc=0.5705


Epoch 9/20: 100%|██████████| 157/157 [00:01<00:00, 131.59it/s, loss=0.627, acc=0.588]


Epoch 9: Loss=0.6718, Acc=0.5879


Epoch 10/20: 100%|██████████| 157/157 [00:01<00:00, 128.31it/s, loss=0.694, acc=0.609]


Epoch 10: Loss=0.6591, Acc=0.6091


Epoch 11/20: 100%|██████████| 157/157 [00:01<00:00, 121.91it/s, loss=0.722, acc=0.619]


Epoch 11: Loss=0.6502, Acc=0.6192


Epoch 12/20: 100%|██████████| 157/157 [00:01<00:00, 127.83it/s, loss=0.678, acc=0.64] 


Epoch 12: Loss=0.6353, Acc=0.6403


Epoch 13/20: 100%|██████████| 157/157 [00:01<00:00, 131.44it/s, loss=0.616, acc=0.665]


Epoch 13: Loss=0.6075, Acc=0.6652


Epoch 14/20: 100%|██████████| 157/157 [00:01<00:00, 128.01it/s, loss=0.543, acc=0.703]


Epoch 14: Loss=0.5659, Acc=0.7031


Epoch 15/20: 100%|██████████| 157/157 [00:01<00:00, 131.10it/s, loss=0.632, acc=0.761]


Epoch 15: Loss=0.4915, Acc=0.7613


Epoch 16/20: 100%|██████████| 157/157 [00:01<00:00, 121.47it/s, loss=0.374, acc=0.83] 


Epoch 16: Loss=0.3896, Acc=0.8304


Epoch 17/20: 100%|██████████| 157/157 [00:01<00:00, 128.89it/s, loss=0.199, acc=0.883]


Epoch 17: Loss=0.2957, Acc=0.8828


Epoch 18/20: 100%|██████████| 157/157 [00:01<00:00, 130.23it/s, loss=0.44, acc=0.912] 


Epoch 18: Loss=0.2418, Acc=0.9118


Epoch 19/20: 100%|██████████| 157/157 [00:01<00:00, 129.62it/s, loss=0.13, acc=0.93]  


Epoch 19: Loss=0.2094, Acc=0.9302


Epoch 20/20: 100%|██████████| 157/157 [00:01<00:00, 129.49it/s, loss=0.131, acc=0.937]

Epoch 20: Loss=0.1973, Acc=0.9371





In [6]:
def test_model(model, text, tokenizer, device='cuda'):
    """Quick inference test"""
    model.eval()
    tokens = tokenizer.encode(text)
    input_ids = torch.tensor([tokens]).to(device)
    
    with torch.no_grad():
        logits = model(input_ids)
        pred_token = logits[0, -1].argmax().item()
        pred_char = tokenizer.id_to_char[pred_token]
    
    print(f"Input: {text}")
    print(f"Predicted: {pred_char}")
    return pred_char

# Test
test_model(trained_model, "Count the letter a in: banana", tokenizer)

Input: Count the letter a in: banana
Predicted: 2


'2'