In [1]:
import torch
import torch.nn as nn
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Configuration
config = HookedTransformerConfig(
    n_layers=2,
    n_heads=8,
    d_model=128,
    d_head=16,  # d_model / n_heads
    d_mlp=None,  # No MLPs (attention-only)
    act_fn=None,  # No activation (no MLPs)
    attention_dir="causal",  # Causal attention
    attn_only=True,  # Attention-only model
    normalization_type=None,  # No LayerNorm for simplicity
    d_vocab=50,  # 26 letters + 10 digits + special tokens
    n_ctx=60,  # Max sequence length
    init_weights=True,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

model = HookedTransformer(config)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model device: {next(model.parameters()).device}")
print(f"Model device name: {torch.cuda.get_device_name(0)}")

Model parameters: 152,626
Model device: cuda:0
Model device name: NVIDIA GeForce RTX 4080 SUPER


In [2]:
class Vocabulary:
    class TrieNode:
        def __init__(self):
            self.id = None
            self.next = {}

    def __init__(self):
        self.root = self.TrieNode()
        self.token_map = {} # Stores mapping of ID to token
        self.size = 0

    # Adds a new token into the vocabulary
    def add_token(self, token):
        node = self.root
        for c in token:
            if c not in node.next:
                node.next[c] = self.TrieNode()
            node = node.next[c]
        if node.id is None:
            node.id = self.size
            self.token_map[self.size] = token
            self.size += 1

    # Finds id of longest prefix token of text[start:end], and returns length of token
    def longest_prefix_token(self, text, start):
        longest_token = None
        longest_length = 0
        node = self.root
        for i in range(start, len(text)):
            if text[i] not in node.next:
                break
            node = node.next[text[i]]
            if node.id is not None:
                longest_token = node.id
                longest_length = i - start + 1
        assert longest_token is not None
        return longest_token, longest_length

    # Converts an id to the corresponding token
    def get_token(self, id):
        return self.token_map[id]

In [3]:
# Simple character-level tokenizer
class CountingTokenizer:
    def __init__(self):
        # Vocabulary: letters + digits + special tokens
        self.vocab = Vocabulary()
        chars = list("abcdefghijklmnopqrstuvwxyz0123456789")
        special = ["<PAD>", "<BOS>", "<EOS>", ":", " ", "Count", "the", "letter", "in"]
        raw_vocab = special + chars
        for token in raw_vocab:
            self.vocab.add_token(token)

    def encode(self, text, include_lengths = False):
        """Convert text to token IDs"""
        ids = []
        i = 0
        while i < len(text):
            id, token_length = self.vocab.longest_prefix_token(text, i)
            assert id != -1
            if include_lengths:
                ids.append((id, token_length))
            else:
                ids.append(id)
            i += token_length
        return ids
    
    def decode(self, ids):
        """Convert token IDs to text"""
        return "".join([self.vocab.get_token(id) for id in ids])

    def apply_bpe(self, words, max_token_length=3):
        """Adds merge rules based on a list of words for BPE"""
        text = "".join([f"<BOS>{word}<EOS>" for word in words])
        ignore_tokens = ["<PAD>", "<BOS>", "<EOS>", ":", " "]
        while True:
            encoded = self.encode(text, include_lengths=True)
            pairs = {}
            merge_pair = ()
            for i in range(len(encoded) - 1):
                token_pair = encoded[i], encoded[i + 1]
                if token_pair[0][1] + token_pair[1][1] > max_token_length:
                    continue
                if any(self.vocab.get_token(token[0]) in ignore_tokens for token in token_pair):
                    continue
                pairs[token_pair] = pairs.get(token_pair, 0) + 1
                if not merge_pair or pairs[merge_pair] < pairs[token_pair]:
                    merge_pair = token_pair
            if not merge_pair or pairs[merge_pair] < 2:
                break
            self.vocab.add_token("".join([self.vocab.get_token(token[0]) for token in merge_pair]))

tokenizer = CountingTokenizer()

In [4]:
import random

def generate_counting_question(target_letter='a', multiplicity_range=(1, 2), length_range=(5, 10)):
    """
    Generates a new word based on input parameters and gives the answer to the question
    """
    
    # Sample words with target letter
    words_with_target = []
    count = random.randint(*multiplicity_range)
    
    # Generate string with exact count of target letter
    chars = list("abcdefghijklmnopqrstuvwxyz")
    chars.remove(target_letter)
    
    length = random.randint(max(count, length_range[0]), length_range[1])
    string_chars = random.choices(chars, k=length - count)
    
    # Insert target letters
    positions = random.sample(range(length), count)
    for pos in positions:
        string_chars.insert(pos, target_letter)
    
    input_string = "".join(string_chars[:length])
    question = f"Count the letter {target_letter} in: {input_string}"
    answer = str(count)

    return question, answer

def generate_counting_example(qa, tokenizer=None):
    """
    Generate: "Count the letter a in: banana" -> "3"
    Format: [question tokens] [answer token]
    """
    question, answer = qa
    
    # Tokenize
    question_tokens = tokenizer.encode(question)
    question_tokens_decoded = [tokenizer.decode([token]) for token in question_tokens]
    answer_token = tokenizer.encode(answer)[0]  # Single digit
    
    # Combine: question + answer
    full_tokens = question_tokens + [answer_token]
    
    return {
        'tokens': full_tokens,
        'question_tokens_decoded': question_tokens_decoded,
        'question_length': len(question_tokens),  # For loss masking
        'answer': int(answer),
        'text': question + answer
    }

# Test
tokenizer = CountingTokenizer()
qa = generate_counting_question()
example = generate_counting_example(qa, tokenizer=tokenizer)
print(f"Text: {example['text']}")
print(f"Tokens: {example['tokens']}")
print(f"Question Tokens Decoded: {example['question_tokens_decoded']}")
print(f"Question length: {example['question_length']}")
print(f"Answer: {example['answer']}")

Text: Count the letter a in: atlga2
Tokens: [5, 4, 6, 4, 7, 4, 9, 4, 8, 3, 4, 9, 28, 20, 15, 9, 37]
Question Tokens Decoded: ['Count', ' ', 'the', ' ', 'letter', ' ', 'a', ' ', 'in', ':', ' ', 'a', 't', 'l', 'g', 'a']
Question length: 16
Answer: 2


In [5]:
# Test
qas = [generate_counting_question() for i in range(10)]
tokenizer = CountingTokenizer()
tokenizer.apply_bpe([qa[0] for qa in qas])
qa = generate_counting_question()
example = generate_counting_example(qa, tokenizer=tokenizer)
print(f"Sample Pre-Processing Text: {"".join([f"<BOS>{qa[0]}<EOS>" for qa in qas])}")
print(f"Text: {example['text']}")
print(f"Tokens: {example['tokens']}")
print(f"Question Tokens Decoded: {example['question_tokens_decoded']}")
print(f"Question length: {example['question_length']}")
print(f"Answer: {example['answer']}")

Sample Pre-Processing Text: <BOS>Count the letter a in: raylqa<EOS><BOS>Count the letter a in: ajiwuiblik<EOS><BOS>Count the letter a in: jfaqswallk<EOS><BOS>Count the letter a in: imaqa<EOS><BOS>Count the letter a in: zsvaa<EOS><BOS>Count the letter a in: rznxdurabr<EOS><BOS>Count the letter a in: acqqay<EOS><BOS>Count the letter a in: lrhdabiac<EOS><BOS>Count the letter a in: ayancrgik<EOS><BOS>Count the letter a in: redhoeafq<EOS>
Text: Count the letter a in: ucfxkal1
Tokens: [5, 4, 6, 4, 7, 4, 9, 4, 8, 3, 4, 29, 11, 14, 32, 19, 9, 20, 36]
Question Tokens Decoded: ['Count', ' ', 'the', ' ', 'letter', ' ', 'a', ' ', 'in', ':', ' ', 'u', 'c', 'f', 'x', 'k', 'a', 'l']
Question length: 18
Answer: 1


In [6]:
import pickle
from torch.utils.data import Dataset, DataLoader

GENERATE_NEW_TRAINING_DATA = False # IMPORTANT: Set this to True only if you want to generate new datasets

class CountingDataset(Dataset):
    def __init__(self, n_examples=50000, difficulty='easy', tokenizer=None, allow_bpe=True):
        """
        difficulty: 'easy', 'bpe-hard', 'mult-hard', etc.
        """
        self.tokenizer = tokenizer
        self.examples = []

        use_bpe = False
        # Set parameters based on difficulty
        if difficulty == 'easy':
            mult_range = (1, 2)
            len_range = (5, 10)
        elif difficulty == 'bpe-hard':
            mult_range = (1, 2)
            len_range = (5, 10)
            use_bpe = True
        elif difficulty == 'mult-hard':
            mult_range = (3, 10)
            len_range = (5, 10)
        elif difficulty == 'length-hard':
            mult_range = (1, 2)
            len_range = (20, 50)
        elif difficulty == 'all-hard':
            mult_range = (3, 10)
            len_range = (20, 50)
            use_bpe = True
        elif difficulty == 'mixed':
            mult_range = (1,10)
            len_range = (5, 50)
            use_bpe = True
        else:
            assert False
        
        # Generate basic words
        target_letters = list("abcdefghijklmnopqrstuvwxyz")
        qas = []
        for _ in range(n_examples):
            target = random.choice(target_letters)
            qa = generate_counting_question(
                target_letter=target,
                multiplicity_range=mult_range,
                length_range=len_range,
            )
            qas.append(qa)

        # Add basic words to the vocabulary if we are applying BPE
        if use_bpe and allow_bpe:
            tokenizer.apply_bpe([qa[0] for qa in qas])

        # Generate the full questions and examples
        basic_tokenizer = CountingTokenizer()
        if difficulty == "mixed":
            bpe_set = random.sample(range(len(qas)), len(qas) // 2)
        else:
            bpe_set = range(len(qas))
        for i in range(len(qas)):
            example = generate_counting_example(
                qas[i],
                tokenizer=tokenizer if i in bpe_set else basic_tokenizer
            )
            self.examples.append(example)
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        return self.examples[idx]

def collate_fn(batch, pad_id=0, max_len=60):
    """Pad sequences to same length"""
    # Pad tokens
    tokens = [ex['tokens'] for ex in batch]
    max_batch_len = min(max(len(t) for t in tokens), max_len)
    
    padded_tokens = []
    masks = []  # Loss mask: 1 for answer token, 0 elsewhere
    
    for ex in batch:
        seq = ex['tokens'][:max_batch_len]
        q_len = min(ex['question_length'], max_batch_len - 1)
        
        # Pad sequence
        padded = seq + [pad_id] * (max_batch_len - len(seq))
        padded_tokens.append(padded)
        
        # Create mask: only compute loss on answer token
        mask = [0] * max_batch_len
        if q_len < len(seq):  # If answer token exists
            mask[q_len] = 1  # Answer is right after question
        masks.append(mask)
    
    return {
        'input_ids': torch.tensor(padded_tokens, dtype=torch.long),
        'loss_mask': torch.tensor(masks, dtype=torch.float),
        'answers': torch.tensor([ex['answer'] for ex in batch], dtype=torch.long)
    }

# Create dataloaders
train_dataset_names = ["easy", "bpe-hard", "mult-hard", "length-hard", "all-hard", "mixed"]

# Training dataloaders
train_datasets = {}
train_loaders = {}
train_tokenizers = {}
for name in train_dataset_names:
    if GENERATE_NEW_TRAINING_DATA:
        train_tokenizers[name] = CountingTokenizer()
        train_datasets[name] = CountingDataset(n_examples=20000, difficulty=name, tokenizer=train_tokenizers[name])
        with open(f"train-{name}-dataset.pkl", "wb") as f:
            pickle.dump(train_datasets[name], f)
        with open(f"train-{name}-tokenizer.pkl", "wb") as f:
            pickle.dump(train_tokenizers[name], f)
    else:
        with open(f"train-{name}-dataset.pkl", "rb") as f:
            train_datasets[name] = pickle.load(f)
        with open(f"train-{name}-tokenizer.pkl", "rb") as f:
            train_tokenizers[name] = pickle.load(f)
            print(train_tokenizers[name].vocab.size)
    train_loaders[name] = DataLoader(train_datasets[name], batch_size=64, shuffle=True, collate_fn=collate_fn)

# Test batch
batch = next(iter(train_loaders["bpe-hard"]))
print(f"Input shape: {batch['input_ids'].shape}")
print(f"Mask shape: {batch['loss_mask'].shape}")
print(f"Example tokens: {batch['input_ids'][0]}")
print(f"Example mask: {batch['loss_mask'][0]}")

45
3004
45
45
3588
3785
Input shape: torch.Size([64, 17])
Mask shape: torch.Size([64, 17])
Example tokens: tensor([  5,   4,   6,   4,   7,   4,  18,   4,   8,   3,   4, 153, 194, 169,
         37,   0,   0])
Example mask: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.])


In [11]:
GENERATE_NEW_TESTING_DATA = False # IMPORTANT: Set this to True only if you want to generate new datasets

# Testing dataloaders
test_dataset_names = ["easy", "bpe-hard", "mult-hard", "length-hard", "all-hard", "mixed"]
test_datasets = {}
test_tokenizers = {}
for name in test_dataset_names:
    if GENERATE_NEW_TESTING_DATA:
        test_datasets[name] = CountingDataset(n_examples=2000, difficulty=name, tokenizer=train_tokenizers[name], allow_bpe=False)
        with open(f"test-{name}-dataset.pkl", "wb") as f:
            pickle.dump(test_datasets[name], f)
    else:
        with open(f"test-{name}-dataset.pkl", "rb") as f:
            test_datasets[name] = pickle.load(f)

In [15]:
import torch.optim as optim
from tqdm import tqdm

RETRAIN = False # IMPORTANT: Only toggle to True if you want to actually train new models

def train_model(model, name, train_loader, n_epochs=10, lr=1e-3, device='cuda'):
    """
    Train with MASKED loss (only on answer token)
    
    Classification loss: CrossEntropy on vocabulary
    (Can also try regression loss on digit value)
    """
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    
    # Classification loss (predict token ID)
    criterion = nn.CrossEntropyLoss(reduction='none')  # Per-token loss
    
    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}")
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)  # [batch, seq_len]
            loss_mask = batch['loss_mask'].to(device)  # [batch, seq_len]
            
            # Forward pass
            # Input: all tokens except last
            # Target: all tokens except first (shifted by 1)
            logits = model(input_ids[:, :-1])  # [batch, seq_len-1, vocab]
            targets = input_ids[:, 1:]  # [batch, seq_len-1]
            
            # Compute loss
            loss_per_token = criterion(
                logits.reshape(-1, logits.size(-1)),  # [batch*(seq_len-1), vocab]
                targets.reshape(-1)  # [batch*(seq_len-1)]
            )
            loss_per_token = loss_per_token.reshape(targets.shape)  # [batch, seq_len-1]
            
            # Apply mask: only compute loss on answer token
            mask = loss_mask[:, 1:]  # Align with targets
            masked_loss = (loss_per_token * mask).sum() / mask.sum()
            
            # Backward pass
            optimizer.zero_grad()
            masked_loss.backward()
            optimizer.step()
            
            # Metrics
            total_loss += masked_loss.item()
            
            # Accuracy: check if predicted answer digit is correct
            preds = logits.argmax(dim=-1)  # [batch, seq_len-1]
            answer_positions = mask.bool()
            if answer_positions.any():
                correct += (preds[answer_positions] == targets[answer_positions]).sum().item()
                total += answer_positions.sum().item()
            
            pbar.set_postfix({'loss': masked_loss.item(), 
                            'acc': correct/total if total > 0 else 0})
        
        scheduler.step()
        
        print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}, "
              f"Acc={correct/total:.4f}")
        
        # Save checkpoint
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'checkpoint-{name}-epoch-{epoch+1}.pt')
    
    return model

# Train!
models = {}
for name in train_dataset_names:
    config = HookedTransformerConfig(
            n_layers=2,
            n_heads=8,
            d_model=128,
            d_head=16,  # d_model / n_heads
            d_mlp=None,  # No MLPs (attention-only)
            act_fn=None,  # No activation (no MLPs)
            attention_dir="causal",  # Causal attention
            attn_only=True,  # Attention-only model
            normalization_type=None,  # No LayerNorm for simplicity
            d_vocab=train_tokenizers[name].vocab.size + 5,  # Total vocab size for particular tokenizer
            n_ctx=60,  # Max sequence length
            init_weights=True,
            device="cuda" if torch.cuda.is_available() else "cpu"
        )
    model = HookedTransformer(config)
    if RETRAIN:
        print(f"Starting to train model for {name} with vocab size {train_tokenizers[name].vocab.size}")
        models[name] = train_model(model, name, train_loaders[name], n_epochs=20, lr=1e-3)
    else:
        checkpoint = torch.load(f'checkpoint-{name}-epoch-100.pt', map_location="cpu")
        
        model.load_state_dict(checkpoint["model_state_dict"])
        models[name] = model
        

Starting to train model for easy with vocab size 45
Moving model to device:  cuda


Epoch 1/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 263.14it/s, loss=0.698, acc=0.493]


Epoch 1: Loss=0.7930, Acc=0.4929


Epoch 2/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 298.67it/s, loss=0.683, acc=0.507]


Epoch 2: Loss=0.6968, Acc=0.5070


Epoch 3/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 295.75it/s, loss=0.706, acc=0.502]


Epoch 3: Loss=0.6975, Acc=0.5019


Epoch 4/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 294.91it/s, loss=0.693, acc=0.504]


Epoch 4: Loss=0.6969, Acc=0.5036


Epoch 5/20: 100%|███████████████████████████████████| 313/313 [00:01<00:00, 292.31it/s, loss=0.7, acc=0.504]


Epoch 5: Loss=0.6955, Acc=0.5037


Epoch 6/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 296.90it/s, loss=0.655, acc=0.539]


Epoch 6: Loss=0.6896, Acc=0.5386


Epoch 7/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 297.79it/s, loss=0.636, acc=0.591]


Epoch 7: Loss=0.6684, Acc=0.5906


Epoch 8/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 299.02it/s, loss=0.545, acc=0.657]


Epoch 8: Loss=0.6114, Acc=0.6573


Epoch 9/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 297.75it/s, loss=0.632, acc=0.848]


Epoch 9: Loss=0.3431, Acc=0.8481


Epoch 10/20: 100%|███████████████████████████████| 313/313 [00:01<00:00, 303.43it/s, loss=0.0194, acc=0.981]


Epoch 10: Loss=0.0591, Acc=0.9805


Epoch 11/20: 100%|███████████████████████████████| 313/313 [00:01<00:00, 297.59it/s, loss=0.0135, acc=0.997]


Epoch 11: Loss=0.0134, Acc=0.9970


Epoch 12/20: 100%|█████████████████████████████| 313/313 [00:01<00:00, 297.35it/s, loss=0.000513, acc=0.999]


Epoch 12: Loss=0.0088, Acc=0.9985


Epoch 13/20: 100%|█████████████████████████████| 313/313 [00:01<00:00, 304.04it/s, loss=0.000976, acc=0.998]


Epoch 13: Loss=0.0076, Acc=0.9983


Epoch 14/20: 100%|█████████████████████████████| 313/313 [00:01<00:00, 301.08it/s, loss=0.000181, acc=0.999]


Epoch 14: Loss=0.0050, Acc=0.9990


Epoch 15/20: 100%|███████████████████████████████| 313/313 [00:01<00:00, 292.03it/s, loss=0.0156, acc=0.999]


Epoch 15: Loss=0.0026, Acc=0.9994


Epoch 16/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 293.93it/s, loss=0.000146, acc=1]


Epoch 16: Loss=0.0019, Acc=0.9996


Epoch 17/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 303.09it/s, loss=0.000246, acc=1]


Epoch 17: Loss=0.0016, Acc=0.9997


Epoch 18/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 303.30it/s, loss=0.000134, acc=1]


Epoch 18: Loss=0.0010, Acc=0.9998


Epoch 19/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 300.28it/s, loss=0.000586, acc=1]


Epoch 19: Loss=0.0008, Acc=0.9998


Epoch 20/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 300.98it/s, loss=5.94e-5, acc=1]


Epoch 20: Loss=0.0006, Acc=0.9999
Starting to train model for bpe-hard with vocab size 3004
Moving model to device:  cuda


Epoch 1/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 299.79it/s, loss=0.743, acc=0.49]


Epoch 1: Loss=0.9627, Acc=0.4904


Epoch 2/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 301.80it/s, loss=0.794, acc=0.553]


Epoch 2: Loss=0.6829, Acc=0.5532


Epoch 3/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 301.04it/s, loss=0.699, acc=0.588]


Epoch 3: Loss=0.6586, Acc=0.5877


Epoch 4/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 302.07it/s, loss=0.706, acc=0.626]


Epoch 4: Loss=0.6206, Acc=0.6260


Epoch 5/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 301.00it/s, loss=0.679, acc=0.66]


Epoch 5: Loss=0.5797, Acc=0.6602


Epoch 6/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 288.26it/s, loss=0.577, acc=0.689]


Epoch 6: Loss=0.5453, Acc=0.6894


Epoch 7/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 297.71it/s, loss=0.567, acc=0.72]


Epoch 7: Loss=0.5057, Acc=0.7201


Epoch 8/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 304.25it/s, loss=0.549, acc=0.753]


Epoch 8: Loss=0.4603, Acc=0.7526


Epoch 9/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 297.00it/s, loss=0.323, acc=0.789]


Epoch 9: Loss=0.4093, Acc=0.7892


Epoch 10/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 300.82it/s, loss=0.347, acc=0.817]


Epoch 10: Loss=0.3611, Acc=0.8171


Epoch 11/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 296.37it/s, loss=0.247, acc=0.848]


Epoch 11: Loss=0.3126, Acc=0.8482


Epoch 12/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 298.80it/s, loss=0.385, acc=0.873]


Epoch 12: Loss=0.2682, Acc=0.8729


Epoch 13/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 292.64it/s, loss=0.176, acc=0.896]


Epoch 13: Loss=0.2265, Acc=0.8962


Epoch 14/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 304.90it/s, loss=0.284, acc=0.912]


Epoch 14: Loss=0.1961, Acc=0.9121


Epoch 15/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 301.64it/s, loss=0.258, acc=0.928]


Epoch 15: Loss=0.1640, Acc=0.9285


Epoch 16/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 302.63it/s, loss=0.168, acc=0.941]


Epoch 16: Loss=0.1383, Acc=0.9411


Epoch 17/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 295.25it/s, loss=0.148, acc=0.952]


Epoch 17: Loss=0.1184, Acc=0.9522


Epoch 18/20: 100%|███████████████████████████████| 313/313 [00:01<00:00, 297.19it/s, loss=0.0662, acc=0.958]


Epoch 18: Loss=0.1048, Acc=0.9578


Epoch 19/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 296.18it/s, loss=0.137, acc=0.962]


Epoch 19: Loss=0.0970, Acc=0.9622


Epoch 20/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 299.40it/s, loss=0.074, acc=0.965]


Epoch 20: Loss=0.0929, Acc=0.9647
Starting to train model for mult-hard with vocab size 45
Moving model to device:  cuda


Epoch 1/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 302.19it/s, loss=1.76, acc=0.244]


Epoch 1: Loss=1.9294, Acc=0.2440


Epoch 2/20: 100%|███████████████████████████████████| 313/313 [00:01<00:00, 309.63it/s, loss=1.5, acc=0.304]


Epoch 2: Loss=1.7141, Acc=0.3044


Epoch 3/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 306.84it/s, loss=0.893, acc=0.46]


Epoch 3: Loss=1.2792, Acc=0.4603


Epoch 4/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 297.26it/s, loss=0.689, acc=0.613]


Epoch 4: Loss=0.9090, Acc=0.6128


Epoch 5/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 294.92it/s, loss=0.486, acc=0.68]


Epoch 5: Loss=0.7476, Acc=0.6796


Epoch 6/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 292.11it/s, loss=0.625, acc=0.738]


Epoch 6: Loss=0.6306, Acc=0.7382


Epoch 7/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 298.61it/s, loss=0.473, acc=0.765]


Epoch 7: Loss=0.5554, Acc=0.7646


Epoch 8/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 293.90it/s, loss=0.355, acc=0.799]


Epoch 8: Loss=0.4800, Acc=0.7986


Epoch 9/20: 100%|███████████████████████████████████| 313/313 [00:01<00:00, 293.54it/s, loss=0.4, acc=0.827]


Epoch 9: Loss=0.4207, Acc=0.8270


Epoch 10/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 297.41it/s, loss=0.327, acc=0.858]


Epoch 10: Loss=0.3573, Acc=0.8576


Epoch 11/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 297.25it/s, loss=0.454, acc=0.882]


Epoch 11: Loss=0.2958, Acc=0.8824


Epoch 12/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 304.55it/s, loss=0.503, acc=0.903]


Epoch 12: Loss=0.2474, Acc=0.9033


Epoch 13/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 301.57it/s, loss=0.138, acc=0.918]


Epoch 13: Loss=0.2120, Acc=0.9181


Epoch 14/20: 100%|███████████████████████████████| 313/313 [00:01<00:00, 295.43it/s, loss=0.0842, acc=0.936]


Epoch 14: Loss=0.1749, Acc=0.9358


Epoch 15/20: 100%|███████████████████████████████| 313/313 [00:01<00:00, 296.24it/s, loss=0.0739, acc=0.945]


Epoch 15: Loss=0.1508, Acc=0.9450


Epoch 16/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 298.43it/s, loss=0.125, acc=0.956]


Epoch 16: Loss=0.1275, Acc=0.9565


Epoch 17/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 295.29it/s, loss=0.158, acc=0.963]


Epoch 17: Loss=0.1104, Acc=0.9625


Epoch 18/20: 100%|███████████████████████████████| 313/313 [00:01<00:00, 303.06it/s, loss=0.0651, acc=0.966]


Epoch 18: Loss=0.1003, Acc=0.9664


Epoch 19/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 300.61it/s, loss=0.0581, acc=0.97]


Epoch 19: Loss=0.0935, Acc=0.9697


Epoch 20/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 294.59it/s, loss=0.13, acc=0.971]


Epoch 20: Loss=0.0898, Acc=0.9713
Starting to train model for length-hard with vocab size 45
Moving model to device:  cuda


Epoch 1/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 280.53it/s, loss=1.11, acc=0.467]


Epoch 1: Loss=1.0543, Acc=0.4669


Epoch 2/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 282.26it/s, loss=0.861, acc=0.473]


Epoch 2: Loss=0.9134, Acc=0.4728


Epoch 3/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 279.05it/s, loss=0.733, acc=0.478]


Epoch 3: Loss=0.9101, Acc=0.4780


Epoch 4/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 281.40it/s, loss=0.866, acc=0.471]


Epoch 4: Loss=0.9072, Acc=0.4708


Epoch 5/20: 100%|███████████████████████████████████| 313/313 [00:01<00:00, 276.05it/s, loss=1.07, acc=0.48]


Epoch 5: Loss=0.9060, Acc=0.4797


Epoch 6/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 282.51it/s, loss=1.05, acc=0.476]


Epoch 6: Loss=0.9038, Acc=0.4761


Epoch 7/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 278.17it/s, loss=0.931, acc=0.473]


Epoch 7: Loss=0.9026, Acc=0.4734


Epoch 8/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 274.21it/s, loss=1.15, acc=0.487]


Epoch 8: Loss=0.8984, Acc=0.4867


Epoch 9/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 280.25it/s, loss=0.867, acc=0.475]


Epoch 9: Loss=0.8967, Acc=0.4748


Epoch 10/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 280.93it/s, loss=0.774, acc=0.483]


Epoch 10: Loss=0.8944, Acc=0.4831


Epoch 11/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 288.32it/s, loss=0.995, acc=0.487]


Epoch 11: Loss=0.8899, Acc=0.4867


Epoch 12/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 275.37it/s, loss=0.807, acc=0.485]


Epoch 12: Loss=0.8869, Acc=0.4849


Epoch 13/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 280.26it/s, loss=0.868, acc=0.483]


Epoch 13: Loss=0.8824, Acc=0.4834


Epoch 14/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 276.66it/s, loss=1.07, acc=0.486]


Epoch 14: Loss=0.8787, Acc=0.4856


Epoch 15/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 276.39it/s, loss=0.909, acc=0.493]


Epoch 15: Loss=0.8719, Acc=0.4929


Epoch 16/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 275.70it/s, loss=0.921, acc=0.496]


Epoch 16: Loss=0.8667, Acc=0.4962


Epoch 17/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 285.24it/s, loss=0.999, acc=0.494]


Epoch 17: Loss=0.8633, Acc=0.4944


Epoch 18/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 282.11it/s, loss=0.902, acc=0.496]


Epoch 18: Loss=0.8602, Acc=0.4963


Epoch 19/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 277.93it/s, loss=0.751, acc=0.501]


Epoch 19: Loss=0.8577, Acc=0.5007


Epoch 20/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 281.50it/s, loss=0.897, acc=0.502]


Epoch 20: Loss=0.8572, Acc=0.5017
Starting to train model for all-hard with vocab size 3588
Moving model to device:  cuda


Epoch 1/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 286.22it/s, loss=2.14, acc=0.132]


Epoch 1: Loss=2.3563, Acc=0.1321


Epoch 2/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 280.25it/s, loss=2.03, acc=0.194]


Epoch 2: Loss=2.0230, Acc=0.1943


Epoch 3/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 288.33it/s, loss=1.95, acc=0.262]


Epoch 3: Loss=1.9002, Acc=0.2616


Epoch 4/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 293.36it/s, loss=1.61, acc=0.319]


Epoch 4: Loss=1.7826, Acc=0.3192


Epoch 5/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 293.25it/s, loss=2.02, acc=0.383]


Epoch 5: Loss=1.6157, Acc=0.3830


Epoch 6/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 288.35it/s, loss=1.59, acc=0.465]


Epoch 6: Loss=1.4270, Acc=0.4652


Epoch 7/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 286.88it/s, loss=1.35, acc=0.532]


Epoch 7: Loss=1.2646, Acc=0.5321


Epoch 8/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 287.07it/s, loss=1.13, acc=0.586]


Epoch 8: Loss=1.1229, Acc=0.5863


Epoch 9/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 287.24it/s, loss=1.33, acc=0.642]


Epoch 9: Loss=0.9932, Acc=0.6422


Epoch 10/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 285.83it/s, loss=1.27, acc=0.689]


Epoch 10: Loss=0.8773, Acc=0.6888


Epoch 11/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 284.22it/s, loss=1.23, acc=0.738]


Epoch 11: Loss=0.7535, Acc=0.7380


Epoch 12/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 288.17it/s, loss=0.871, acc=0.786]


Epoch 12: Loss=0.6344, Acc=0.7861


Epoch 13/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 283.86it/s, loss=0.77, acc=0.83]


Epoch 13: Loss=0.5204, Acc=0.8295


Epoch 14/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 286.62it/s, loss=0.539, acc=0.87]


Epoch 14: Loss=0.4208, Acc=0.8702


Epoch 15/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 289.55it/s, loss=0.318, acc=0.9]


Epoch 15: Loss=0.3358, Acc=0.8998


Epoch 16/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 284.98it/s, loss=0.316, acc=0.925]


Epoch 16: Loss=0.2710, Acc=0.9247


Epoch 17/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 290.26it/s, loss=0.191, acc=0.941]


Epoch 17: Loss=0.2259, Acc=0.9405


Epoch 18/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 290.92it/s, loss=0.104, acc=0.951]


Epoch 18: Loss=0.1955, Acc=0.9505


Epoch 19/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 287.92it/s, loss=0.093, acc=0.955]


Epoch 19: Loss=0.1781, Acc=0.9554


Epoch 20/20: 100%|███████████████████████████████| 313/313 [00:01<00:00, 292.54it/s, loss=0.0807, acc=0.958]


Epoch 20: Loss=0.1698, Acc=0.9580
Starting to train model for mixed with vocab size 3785
Moving model to device:  cuda


Epoch 1/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 221.65it/s, loss=2.28, acc=0.181]


Epoch 1: Loss=2.5569, Acc=0.1807


Epoch 2/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 225.05it/s, loss=2.14, acc=0.194]


Epoch 2: Loss=2.2141, Acc=0.1937


Epoch 3/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 224.51it/s, loss=2.08, acc=0.212]


Epoch 3: Loss=2.1414, Acc=0.2125


Epoch 4/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 224.77it/s, loss=1.89, acc=0.268]


Epoch 4: Loss=2.0274, Acc=0.2676


Epoch 5/20: 100%|████████████████████████████████████| 313/313 [00:01<00:00, 223.80it/s, loss=1.7, acc=0.33]


Epoch 5: Loss=1.8545, Acc=0.3296


Epoch 6/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 225.97it/s, loss=1.76, acc=0.393]


Epoch 6: Loss=1.6842, Acc=0.3930


Epoch 7/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 220.22it/s, loss=1.24, acc=0.454]


Epoch 7: Loss=1.5317, Acc=0.4543


Epoch 8/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 223.69it/s, loss=1.77, acc=0.499]


Epoch 8: Loss=1.4090, Acc=0.4988


Epoch 9/20: 100%|███████████████████████████████████| 313/313 [00:01<00:00, 223.92it/s, loss=1.03, acc=0.53]


Epoch 9: Loss=1.3124, Acc=0.5305


Epoch 10/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 224.40it/s, loss=1.27, acc=0.554]


Epoch 10: Loss=1.2429, Acc=0.5541


Epoch 11/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 222.46it/s, loss=1.08, acc=0.574]


Epoch 11: Loss=1.1813, Acc=0.5741


Epoch 12/20: 100%|██████████████████████████████████| 313/313 [00:01<00:00, 223.87it/s, loss=1.1, acc=0.588]


Epoch 12: Loss=1.1441, Acc=0.5882


Epoch 13/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 221.74it/s, loss=1.37, acc=0.596]


Epoch 13: Loss=1.1192, Acc=0.5958


Epoch 14/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 222.63it/s, loss=0.935, acc=0.601]


Epoch 14: Loss=1.0918, Acc=0.6006


Epoch 15/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 223.31it/s, loss=0.834, acc=0.609]


Epoch 15: Loss=1.0762, Acc=0.6093


Epoch 16/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 223.72it/s, loss=0.89, acc=0.612]


Epoch 16: Loss=1.0631, Acc=0.6124


Epoch 17/20: 100%|████████████████████████████████████| 313/313 [00:01<00:00, 226.63it/s, loss=1, acc=0.615]


Epoch 17: Loss=1.0535, Acc=0.6151


Epoch 18/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 222.92it/s, loss=0.759, acc=0.617]


Epoch 18: Loss=1.0440, Acc=0.6172


Epoch 19/20: 100%|████████████████████████████████| 313/313 [00:01<00:00, 223.71it/s, loss=0.942, acc=0.619]


Epoch 19: Loss=1.0380, Acc=0.6193


Epoch 20/20: 100%|█████████████████████████████████| 313/313 [00:01<00:00, 223.22it/s, loss=1.04, acc=0.621]


Epoch 20: Loss=1.0353, Acc=0.6210


In [18]:
def evaluate_model(model, dataloader, device='cuda'):
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            loss_mask = batch["loss_mask"].to(device)
            
            # Forward pass
            logits = model(input_ids[:, :-1])
            targets = input_ids[:, 1:]
            mask = loss_mask[:, 1:]
            
            # Get predictions
            preds = logits.argmax(dim=-1)
            
            # Identify valid rows
            per_row = mask.sum(dim=1)
            valid_rows = (per_row == 1)
            
            if not valid_rows.any():
                continue   # skip batches with no valid rows
            
            # Filter only valid rows
            mask_valid = mask[valid_rows].bool()
            preds_valid = preds[valid_rows]
            targets_valid = targets[valid_rows]
            
            # Extract predictions and true labels at masked positions
            pred_answers = preds_valid[mask_valid]
            true_answers = targets_valid[mask_valid]
            
            correct += (pred_answers == true_answers).sum().item()
            total += pred_answers.numel()
            
    accuracy = correct / total
    return accuracy

# Test
for name in train_dataset_names:
    accuracy = evaluate_model(models[name], DataLoader(
        test_datasets[name],
        batch_size=64,
        shuffle=False,
        collate_fn=collate_fn
    ))
    print(f"Testing model {name} accuracy on {name}: {accuracy}")

Testing model easy accuracy on easy: 0.9985
Testing model bpe-hard accuracy on bpe-hard: 0.519
Testing model mult-hard accuracy on mult-hard: 0.9315
Testing model length-hard accuracy on length-hard: 0.4795
Testing model all-hard accuracy on all-hard: 0.182
Testing model mixed accuracy on mixed: 0.158
