In [1]:
import torch
import torch.nn as nn
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Configuration
config = HookedTransformerConfig(
    n_layers=2,
    n_heads=8,
    d_model=128,
    d_head=16,  # d_model / n_heads
    d_mlp=None,  # No MLPs (attention-only)
    act_fn=None,  # No activation (no MLPs)
    attention_dir="causal",  # Causal attention
    attn_only=True,  # Attention-only model
    normalization_type=None,  # No LayerNorm for simplicity
    d_vocab=50,  # 26 letters + 10 digits + special tokens
    n_ctx=60,  # Max sequence length
    init_weights=True,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

model = HookedTransformer(config)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model device: {next(model.parameters()).device}")
print(f"Model device name: {torch.cuda.get_device_name(0)}")

Model parameters: 152,626
Model device: cuda:0
Model device name: NVIDIA GeForce RTX 5070 Ti


In [2]:
class Vocabulary:
    class TrieNode:
        def __init__(self):
            self.id = None
            self.next = {}

    def __init__(self):
        self.root = self.TrieNode()
        self.token_map = {} # Stores mapping of ID to token
        self.size = 0

    # Adds a new token into the vocabulary
    def add_token(self, token):
        node = self.root
        for c in token:
            if c not in node.next:
                node.next[c] = self.TrieNode()
            node = node.next[c]
        if node.id is None:
            node.id = self.size
            self.token_map[self.size] = token
            self.size += 1

    # Finds id of longest prefix token of text[start:end], and returns length of token
    def longest_prefix_token(self, text, start):
        longest_token = None
        longest_length = 0
        node = self.root
        for i in range(start, len(text)):
            if text[i] not in node.next:
                break
            node = node.next[text[i]]
            if node.id is not None:
                longest_token = node.id
                longest_length = i - start + 1
        assert longest_token is not None
        return longest_token, longest_length

    # Converts an id to the corresponding token
    def get_token(self, id):
        return self.token_map[id]

In [19]:
# Simple character-level tokenizer
class CountingTokenizer:
    def __init__(self):
        # Vocabulary: letters + digits + special tokens
        self.vocab = Vocabulary()
        chars = list("abcdefghijklmnopqrstuvwxyz0123456789")
        special = ["<PAD>", "<BOS>", "<EOS>", ":", " ", "Count", "the", "letter", "in"]
        raw_vocab = special + chars
        for token in raw_vocab:
            self.vocab.add_token(token)

    def encode(self, text, include_lengths = False):
        """Convert text to token IDs"""
        ids = []
        i = 0
        while i < len(text):
            id, token_length = self.vocab.longest_prefix_token(text, i)
            assert id != -1
            if include_lengths:
                ids.append((id, token_length))
            else:
                ids.append(id)
            i += token_length
        return ids
    
    def decode(self, ids):
        """Convert token IDs to text"""
        return "".join([self.vocab.get_token(id) for id in ids])

    def apply_bpe(self, words, max_token_length=3):
        """Adds merge rules based on a list of words for BPE"""
        text = "".join([f"<BOS>{word}<EOS>" for word in words])
        ignore_tokens = ["<PAD>", "<BOS>", "<EOS>", ":", " "]
        while True:
            encoded = self.encode(text, include_lengths=True)
            pairs = {}
            merge_pair = ()
            for i in range(len(encoded) - 1):
                token_pair = encoded[i], encoded[i + 1]
                if token_pair[0][1] + token_pair[1][1] > max_token_length:
                    continue
                if any(self.vocab.get_token(token[0]) in ignore_tokens for token in token_pair):
                    continue
                pairs[token_pair] = pairs.get(token_pair, 0) + 1
                if not merge_pair or pairs[merge_pair] < pairs[token_pair]:
                    merge_pair = token_pair
            if not merge_pair or pairs[merge_pair] < 2:
                break
            self.vocab.add_token("".join([self.vocab.get_token(token[0]) for token in merge_pair]))

tokenizer = CountingTokenizer()

In [20]:
import random

def generate_counting_question(target_letter='a', multiplicity_range=(1, 2), length_range=(5, 10)):
    """
    Generates a new word based on input parameters and gives the answer to the question
    """
    
    # Sample words with target letter
    words_with_target = []
    count = random.randint(*multiplicity_range)
    
    # Generate string with exact count of target letter
    chars = list("abcdefghijklmnopqrstuvwxyz")
    chars.remove(target_letter)
    
    length = random.randint(max(count, length_range[0]), length_range[1])
    string_chars = random.choices(chars, k=length - count)
    
    # Insert target letters
    positions = random.sample(range(length), count)
    for pos in positions:
        string_chars.insert(pos, target_letter)
    
    input_string = "".join(string_chars[:length])
    question = f"Count the letter {target_letter} in: {input_string}"
    answer = str(count)

    return question, answer

def generate_counting_example(qa, tokenizer=None):
    """
    Generate: "Count the letter a in: banana" -> "3"
    Format: [question tokens] [answer token]
    """
    question, answer = qa
    
    # Tokenize
    question_tokens = tokenizer.encode(question)
    question_tokens_decoded = [tokenizer.decode([token]) for token in question_tokens]
    answer_token = tokenizer.encode(answer)[0]  # Single digit
    
    # Combine: question + answer
    full_tokens = question_tokens + [answer_token]
    
    return {
        'tokens': full_tokens,
        'question_tokens_decoded': question_tokens_decoded,
        'question_length': len(question_tokens),  # For loss masking
        'answer': int(answer),
        'text': question + answer
    }

# Test
tokenizer = CountingTokenizer()
qa = generate_counting_question()
example = generate_counting_example(qa, tokenizer=tokenizer)
print(f"Text: {example['text']}")
print(f"Tokens: {example['tokens']}")
print(f"Question Tokens Decoded: {example['question_tokens_decoded']}")
print(f"Question length: {example['question_length']}")
print(f"Answer: {example['answer']}")

Text: Count the letter a in: qhgajcn1
Tokens: [5, 4, 6, 4, 7, 4, 9, 4, 8, 3, 4, 25, 16, 15, 9, 18, 11, 22, 36]
Question Tokens Decoded: ['Count', ' ', 'the', ' ', 'letter', ' ', 'a', ' ', 'in', ':', ' ', 'q', 'h', 'g', 'a', 'j', 'c', 'n']
Question length: 18
Answer: 1


In [21]:
# Test
qas = [generate_counting_question() for i in range(10)]
tokenizer = CountingTokenizer()
tokenizer.apply_bpe([qa[0] for qa in qas])
qa = generate_counting_question()
example = generate_counting_example(qa, tokenizer=tokenizer)
print(f"Sample Pre-Processing Text: {"".join([f"<BOS>{qa[0]}<EOS>" for qa in qas])}")
print(f"Text: {example['text']}")
print(f"Tokens: {example['tokens']}")
print(f"Question Tokens Decoded: {example['question_tokens_decoded']}")
print(f"Question length: {example['question_length']}")
print(f"Answer: {example['answer']}")

Sample Pre-Processing Text: <BOS>Count the letter a in: vvbawnvajn<EOS><BOS>Count the letter a in: sxnjiharl<EOS><BOS>Count the letter a in: rhanfb<EOS><BOS>Count the letter a in: paoae<EOS><BOS>Count the letter a in: aamilgus<EOS><BOS>Count the letter a in: sgqnlant<EOS><BOS>Count the letter a in: foxaa<EOS><BOS>Count the letter a in: xchaubpba<EOS><BOS>Count the letter a in: mpeaualz<EOS><BOS>Count the letter a in: leamdssug<EOS>
Text: Count the letter a in: aygkohe1
Tokens: [5, 4, 6, 4, 7, 4, 9, 4, 8, 3, 4, 9, 33, 15, 19, 23, 16, 13, 36]
Question Tokens Decoded: ['Count', ' ', 'the', ' ', 'letter', ' ', 'a', ' ', 'in', ':', ' ', 'a', 'y', 'g', 'k', 'o', 'h', 'e']
Question length: 18
Answer: 1


In [41]:
import pickle
from torch.utils.data import Dataset, DataLoader


class CountingDataset(Dataset):
    def __init__(self, n_examples=50000, difficulty='easy', tokenizer=None, allow_bpe=True, train_bpe=True):
        """
        difficulty: 'easy', 'bpe-hard', 'mult-hard', etc.
        """
        self.tokenizer = tokenizer
        self.examples = []

        use_bpe = False
        # Set parameters based on difficulty
        if difficulty == 'easy':
            mult_range = (1, 2)
            len_range = (5, 10)
        elif difficulty == 'bpe-hard':
            mult_range = (1, 2)
            len_range = (5, 10)
            use_bpe = True
        elif difficulty == 'mult-hard':
            mult_range = (3, 10)
            len_range = (5, 10)
        elif difficulty == 'length-hard':
            mult_range = (1, 2)
            len_range = (20, 50)
        elif difficulty == 'all-hard':
            mult_range = (3, 10)
            len_range = (20, 50)
            use_bpe = True
        elif difficulty == 'mixed':
            mult_range = (1,10)
            len_range = (5, 50)
            use_bpe = True
        else:
            assert False
        
        # Generate basic words
        target_letters = list("abcdefghijklmnopqrstuvwxyz")
        qas = []
        for _ in range(n_examples):
            target = random.choice(target_letters)
            qa = generate_counting_question(
                target_letter=target,
                multiplicity_range=mult_range,
                length_range=len_range,
            )
            qas.append(qa)

        # Add basic words to the vocabulary if we are applying BPE
        if use_bpe and allow_bpe and train_bpe:
            tokenizer.apply_bpe([qa[0] for qa in qas])

        # Generate the full questions and examples
        basic_tokenizer = CountingTokenizer()
        if difficulty == "mixed":
            bpe_set = random.sample(range(len(qas)), len(qas) // 2)
        else:
            bpe_set = range(len(qas))
        for i in range(len(qas)):
            example = generate_counting_example(
                qas[i],
                tokenizer=tokenizer if i in bpe_set else basic_tokenizer
            )
            self.examples.append(example)
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        return self.examples[idx]

def collate_fn(batch, pad_id=0, max_len=100):
    """Pad sequences to same length"""
    # Pad tokens
    tokens = [ex['tokens'] for ex in batch]
    max_batch_len = min(max(len(t) for t in tokens), max_len)
    
    padded_tokens = []
    masks = []  # Loss mask: 1 for answer token, 0 elsewhere
    
    for ex in batch:
        seq = ex['tokens'][:max_batch_len]
        q_len = min(ex['question_length'], max_batch_len - 1)
        
        # Pad sequence
        padded = seq + [pad_id] * (max_batch_len - len(seq))
        padded_tokens.append(padded)
        
        # Create mask: only compute loss on answer token
        mask = [0] * max_batch_len
        if q_len < len(seq):  # If answer token exists
            mask[q_len] = 1  # Answer is right after question
        masks.append(mask)
    
    return {
        'input_ids': torch.tensor(padded_tokens, dtype=torch.long),
        'loss_mask': torch.tensor(masks, dtype=torch.float),
        'answers': torch.tensor([ex['answer'] for ex in batch], dtype=torch.long)
    }


In [None]:
# Create dataloaders
train_dataset_names = ["easy", "bpe-hard", "mult-hard", "length-hard", "all-hard"]
GENERATE_NEW_TRAINING_DATA = False # IMPORTANT: Set this to True only if you want to generate new datasets

# Training dataloaders
train_datasets = {}
train_loaders = {}
train_tokenizers = {}
for name in train_dataset_names:
    if GENERATE_NEW_TRAINING_DATA:
        train_tokenizers[name] = CountingTokenizer()
        train_datasets[name] = CountingDataset(n_examples=20000, difficulty=name, tokenizer=train_tokenizers[name])
        with open(f"train-{name}-dataset.pkl", "wb") as f:
            pickle.dump(train_datasets[name], f)
        with open(f"train-{name}-tokenizer.pkl", "wb") as f:
            pickle.dump(train_tokenizers[name], f)
    else:
        with open(f"train-{name}-dataset.pkl", "rb") as f:
            train_datasets[name] = pickle.load(f)
        with open(f"train-{name}-tokenizer.pkl", "rb") as f:
            train_tokenizers[name] = pickle.load(f)
            print(train_tokenizers[name].vocab.size)
    train_loaders[name] = DataLoader(train_datasets[name], batch_size=64, shuffle=True, collate_fn=collate_fn)

# Test batch
batch = next(iter(train_loaders["bpe-hard"]))
print(f"Input shape: {batch['input_ids'].shape}")
print(f"Mask shape: {batch['loss_mask'].shape}")
print(f"Example tokens: {batch['input_ids'][0]}")
print(f"Example mask: {batch['loss_mask'][0]}")

45
3004
45
45
3588
3785
Input shape: torch.Size([64, 17])
Mask shape: torch.Size([64, 17])
Example tokens: tensor([   5,    4,    6,    4,    7,    4,   29,    4,    8,    3,    4,  214,
        1195,   37,    0,    0,    0])
Example mask: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.])


In [47]:
GENERATE_NEW_TESTING_DATA = True # IMPORTANT: Set this to True only if you want to generate new datasets

# Testing dataloaders
test_dataset_names = ["easy", "bpe-hard", "mult-hard", "length-hard", "all-hard"]
test_datasets = {}
test_tokenizers = {}
for name in test_dataset_names:
    if GENERATE_NEW_TESTING_DATA:
        test_datasets[name] = CountingDataset(n_examples=2000, difficulty=name, tokenizer=train_tokenizers[name], allow_bpe=True, train_bpe=False)
        with open(f"test-{name}-dataset.pkl", "wb") as f:
            pickle.dump(test_datasets[name], f)
    else:
        with open(f"test-{name}-dataset.pkl", "rb") as f:
            test_datasets[name] = pickle.load(f)

In [45]:
import torch.optim as optim
from tqdm import tqdm

RETRAIN = True # IMPORTANT: Only toggle to True if you want to actually train new models

def train_model(model, name, train_loader, n_epochs=10, lr=1e-3, device='cuda'):
    """
    Train with MASKED loss (only on answer token)
    
    Classification loss: CrossEntropy on vocabulary
    (Can also try regression loss on digit value)
    """
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    
    # Classification loss (predict token ID)
    criterion = nn.CrossEntropyLoss(reduction='none')  # Per-token loss
    
    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}")
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)  # [batch, seq_len]
            loss_mask = batch['loss_mask'].to(device)  # [batch, seq_len]
            
            # Forward pass
            # Input: all tokens except last
            # Target: all tokens except first (shifted by 1)
            logits = model(input_ids[:, :-1])  # [batch, seq_len-1, vocab]
            targets = input_ids[:, 1:]  # [batch, seq_len-1]
            
            # Compute loss
            loss_per_token = criterion(
                logits.reshape(-1, logits.size(-1)),  # [batch*(seq_len-1), vocab]
                targets.reshape(-1)  # [batch*(seq_len-1)]
            )
            loss_per_token = loss_per_token.reshape(targets.shape)  # [batch, seq_len-1]
            
            # Apply mask: only compute loss on answer token
            mask = loss_mask[:, 1:]  # Align with targets
            masked_loss = (loss_per_token * mask).sum() / mask.sum()
            
            # Backward pass
            optimizer.zero_grad()
            masked_loss.backward()
            optimizer.step()
            
            # Metrics
            total_loss += masked_loss.item()
            
            # Accuracy: check if predicted answer digit is correct
            preds = logits.argmax(dim=-1)  # [batch, seq_len-1]
            answer_positions = mask.bool()
            if answer_positions.any():
                correct += (preds[answer_positions] == targets[answer_positions]).sum().item()
                total += answer_positions.sum().item()
            
            pbar.set_postfix({'loss': masked_loss.item(), 
                            'acc': correct/total if total > 0 else 0})
        
        scheduler.step()
        
        print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}, "
              f"Acc={correct/total:.4f}")
        
        # Save checkpoint
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'checkpoint-{name}-epoch-{epoch+1}.pt')
    
    return model

# Train!
models = {}
for name in train_dataset_names:
    config = HookedTransformerConfig(
            n_layers=2,
            n_heads=8,
            d_model=128,
            d_head=16,  # d_model / n_heads
            d_mlp=None,  # No MLPs (attention-only)
            act_fn=None,  # No activation (no MLPs)
            attention_dir="causal",  # Causal attention
            attn_only=True,  # Attention-only model
            normalization_type=None,  # No LayerNorm for simplicity
            d_vocab=train_tokenizers[name].vocab.size + 5,  # Total vocab size for particular tokenizer
            n_ctx=100,  # Max sequence length
            init_weights=True,
            device="cuda" if torch.cuda.is_available() else "cpu"
        )
    model = HookedTransformer(config)
    if RETRAIN:
        print(f"Starting to train model for {name} with vocab size {train_tokenizers[name].vocab.size}")
        models[name] = train_model(model, name, train_loaders[name], n_epochs=20, lr=1e-3)
    else:
        checkpoint = torch.load(f'checkpoint-{name}-epoch-100.pt', map_location="cpu")
        
        model.load_state_dict(checkpoint["model_state_dict"])
        models[name] = model
        

Starting to train model for easy with vocab size 45
Moving model to device:  cuda


Epoch 1/20: 100%|██████████| 313/313 [00:02<00:00, 121.81it/s, loss=0.704, acc=0.504]


Epoch 1: Loss=0.7856, Acc=0.5038


Epoch 2/20: 100%|██████████| 313/313 [00:02<00:00, 129.98it/s, loss=0.69, acc=0.498] 


Epoch 2: Loss=0.7016, Acc=0.4978


Epoch 3/20: 100%|██████████| 313/313 [00:02<00:00, 138.41it/s, loss=0.691, acc=0.502]


Epoch 3: Loss=0.6968, Acc=0.5020


Epoch 4/20: 100%|██████████| 313/313 [00:02<00:00, 135.83it/s, loss=0.66, acc=0.523] 


Epoch 4: Loss=0.6940, Acc=0.5229


Epoch 5/20: 100%|██████████| 313/313 [00:02<00:00, 136.52it/s, loss=0.647, acc=0.564]


Epoch 5: Loss=0.6820, Acc=0.5636


Epoch 6/20: 100%|██████████| 313/313 [00:02<00:00, 134.49it/s, loss=0.665, acc=0.587]


Epoch 6: Loss=0.6714, Acc=0.5875


Epoch 7/20: 100%|██████████| 313/313 [00:02<00:00, 133.82it/s, loss=0.609, acc=0.631]


Epoch 7: Loss=0.6396, Acc=0.6309


Epoch 8/20: 100%|██████████| 313/313 [00:02<00:00, 133.10it/s, loss=0.474, acc=0.717]


Epoch 8: Loss=0.5494, Acc=0.7167


Epoch 9/20: 100%|██████████| 313/313 [00:02<00:00, 134.51it/s, loss=0.137, acc=0.881] 


Epoch 9: Loss=0.2795, Acc=0.8811


Epoch 10/20: 100%|██████████| 313/313 [00:02<00:00, 122.21it/s, loss=0.0129, acc=0.983] 


Epoch 10: Loss=0.0560, Acc=0.9834


Epoch 11/20: 100%|██████████| 313/313 [00:02<00:00, 125.73it/s, loss=0.00197, acc=0.992]


Epoch 11: Loss=0.0236, Acc=0.9923


Epoch 12/20: 100%|██████████| 313/313 [00:02<00:00, 123.51it/s, loss=0.0154, acc=0.998]  


Epoch 12: Loss=0.0093, Acc=0.9977


Epoch 13/20: 100%|██████████| 313/313 [00:02<00:00, 121.37it/s, loss=0.00379, acc=0.999] 


Epoch 13: Loss=0.0043, Acc=0.9993


Epoch 14/20: 100%|██████████| 313/313 [00:02<00:00, 123.46it/s, loss=0.000371, acc=1]


Epoch 14: Loss=0.0018, Acc=0.9998


Epoch 15/20: 100%|██████████| 313/313 [00:02<00:00, 126.64it/s, loss=0.000242, acc=1]


Epoch 15: Loss=0.0013, Acc=0.9998


Epoch 16/20: 100%|██████████| 313/313 [00:02<00:00, 129.27it/s, loss=4.61e-5, acc=1] 


Epoch 16: Loss=0.0006, Acc=1.0000


Epoch 17/20: 100%|██████████| 313/313 [00:02<00:00, 128.95it/s, loss=0.00044, acc=1] 


Epoch 17: Loss=0.0005, Acc=0.9999


Epoch 18/20: 100%|██████████| 313/313 [00:02<00:00, 123.90it/s, loss=0.000134, acc=1]


Epoch 18: Loss=0.0004, Acc=1.0000


Epoch 19/20: 100%|██████████| 313/313 [00:02<00:00, 123.10it/s, loss=0.00054, acc=1] 


Epoch 19: Loss=0.0004, Acc=1.0000


Epoch 20/20: 100%|██████████| 313/313 [00:02<00:00, 127.68it/s, loss=0.00049, acc=1] 


Epoch 20: Loss=0.0003, Acc=1.0000
Starting to train model for bpe-hard with vocab size 3004
Moving model to device:  cuda


Epoch 1/20: 100%|██████████| 313/313 [00:02<00:00, 128.14it/s, loss=0.68, acc=0.503] 


Epoch 1: Loss=0.9713, Acc=0.5033


Epoch 2/20: 100%|██████████| 313/313 [00:02<00:00, 130.72it/s, loss=0.735, acc=0.559]


Epoch 2: Loss=0.6831, Acc=0.5591


Epoch 3/20: 100%|██████████| 313/313 [00:02<00:00, 131.71it/s, loss=0.67, acc=0.595] 


Epoch 3: Loss=0.6579, Acc=0.5950


Epoch 4/20: 100%|██████████| 313/313 [00:02<00:00, 127.79it/s, loss=0.621, acc=0.632]


Epoch 4: Loss=0.6190, Acc=0.6315


Epoch 5/20: 100%|██████████| 313/313 [00:02<00:00, 121.50it/s, loss=0.556, acc=0.67] 


Epoch 5: Loss=0.5742, Acc=0.6702


Epoch 6/20: 100%|██████████| 313/313 [00:02<00:00, 124.44it/s, loss=0.511, acc=0.7]  


Epoch 6: Loss=0.5350, Acc=0.6998


Epoch 7/20: 100%|██████████| 313/313 [00:02<00:00, 128.73it/s, loss=0.539, acc=0.732]


Epoch 7: Loss=0.4929, Acc=0.7317


Epoch 8/20: 100%|██████████| 313/313 [00:02<00:00, 133.97it/s, loss=0.456, acc=0.766]


Epoch 8: Loss=0.4457, Acc=0.7661


Epoch 9/20: 100%|██████████| 313/313 [00:02<00:00, 134.91it/s, loss=0.51, acc=0.798] 


Epoch 9: Loss=0.3928, Acc=0.7978


Epoch 10/20: 100%|██████████| 313/313 [00:02<00:00, 134.29it/s, loss=0.495, acc=0.831]


Epoch 10: Loss=0.3403, Acc=0.8308


Epoch 11/20: 100%|██████████| 313/313 [00:02<00:00, 132.11it/s, loss=0.645, acc=0.857]


Epoch 11: Loss=0.2932, Acc=0.8566


Epoch 12/20: 100%|██████████| 313/313 [00:02<00:00, 133.56it/s, loss=0.238, acc=0.878]


Epoch 12: Loss=0.2496, Acc=0.8782


Epoch 13/20: 100%|██████████| 313/313 [00:02<00:00, 131.43it/s, loss=0.742, acc=0.899]


Epoch 13: Loss=0.2122, Acc=0.8994


Epoch 14/20: 100%|██████████| 313/313 [00:02<00:00, 132.80it/s, loss=0.13, acc=0.917]  


Epoch 14: Loss=0.1777, Acc=0.9169


Epoch 15/20: 100%|██████████| 313/313 [00:02<00:00, 131.44it/s, loss=0.0978, acc=0.934]


Epoch 15: Loss=0.1487, Acc=0.9343


Epoch 16/20: 100%|██████████| 313/313 [00:02<00:00, 134.17it/s, loss=0.135, acc=0.947] 


Epoch 16: Loss=0.1229, Acc=0.9468


Epoch 17/20: 100%|██████████| 313/313 [00:02<00:00, 131.56it/s, loss=0.142, acc=0.956] 


Epoch 17: Loss=0.1055, Acc=0.9556


Epoch 18/20: 100%|██████████| 313/313 [00:02<00:00, 134.10it/s, loss=0.0303, acc=0.962]


Epoch 18: Loss=0.0933, Acc=0.9620


Epoch 19/20: 100%|██████████| 313/313 [00:02<00:00, 133.04it/s, loss=0.0871, acc=0.966]


Epoch 19: Loss=0.0863, Acc=0.9662


Epoch 20/20: 100%|██████████| 313/313 [00:02<00:00, 134.22it/s, loss=0.0639, acc=0.968]


Epoch 20: Loss=0.0827, Acc=0.9679
Starting to train model for mult-hard with vocab size 45
Moving model to device:  cuda


Epoch 1/20: 100%|██████████| 313/313 [00:02<00:00, 131.74it/s, loss=1.69, acc=0.25] 


Epoch 1: Loss=1.9272, Acc=0.2501


Epoch 2/20: 100%|██████████| 313/313 [00:02<00:00, 134.93it/s, loss=1.28, acc=0.311]


Epoch 2: Loss=1.6745, Acc=0.3111


Epoch 3/20: 100%|██████████| 313/313 [00:02<00:00, 133.98it/s, loss=0.696, acc=0.585]


Epoch 3: Loss=0.9569, Acc=0.5845


Epoch 4/20: 100%|██████████| 313/313 [00:02<00:00, 133.51it/s, loss=0.634, acc=0.696]


Epoch 4: Loss=0.6973, Acc=0.6960


Epoch 5/20: 100%|██████████| 313/313 [00:02<00:00, 132.57it/s, loss=0.298, acc=0.748]


Epoch 5: Loss=0.5825, Acc=0.7479


Epoch 6/20: 100%|██████████| 313/313 [00:02<00:00, 132.06it/s, loss=0.484, acc=0.772]


Epoch 6: Loss=0.5297, Acc=0.7721


Epoch 7/20: 100%|██████████| 313/313 [00:02<00:00, 132.89it/s, loss=0.393, acc=0.8]  


Epoch 7: Loss=0.4697, Acc=0.8000


Epoch 8/20: 100%|██████████| 313/313 [00:02<00:00, 134.22it/s, loss=0.748, acc=0.822]


Epoch 8: Loss=0.4187, Acc=0.8224


Epoch 9/20: 100%|██████████| 313/313 [00:02<00:00, 134.19it/s, loss=0.252, acc=0.835]


Epoch 9: Loss=0.3921, Acc=0.8354


Epoch 10/20: 100%|██████████| 313/313 [00:02<00:00, 135.50it/s, loss=0.485, acc=0.86] 


Epoch 10: Loss=0.3406, Acc=0.8600


Epoch 11/20: 100%|██████████| 313/313 [00:02<00:00, 132.01it/s, loss=0.292, acc=0.876]


Epoch 11: Loss=0.3058, Acc=0.8756


Epoch 12/20: 100%|██████████| 313/313 [00:02<00:00, 132.82it/s, loss=0.217, acc=0.891]


Epoch 12: Loss=0.2716, Acc=0.8908


Epoch 13/20: 100%|██████████| 313/313 [00:02<00:00, 133.41it/s, loss=0.161, acc=0.912] 


Epoch 13: Loss=0.2247, Acc=0.9122


Epoch 14/20: 100%|██████████| 313/313 [00:02<00:00, 134.60it/s, loss=0.207, acc=0.927] 


Epoch 14: Loss=0.1908, Acc=0.9274


Epoch 15/20: 100%|██████████| 313/313 [00:02<00:00, 134.18it/s, loss=0.134, acc=0.941] 


Epoch 15: Loss=0.1590, Acc=0.9411


Epoch 16/20: 100%|██████████| 313/313 [00:02<00:00, 130.32it/s, loss=0.0754, acc=0.953]


Epoch 16: Loss=0.1327, Acc=0.9526


Epoch 17/20: 100%|██████████| 313/313 [00:02<00:00, 132.72it/s, loss=0.116, acc=0.96]  


Epoch 17: Loss=0.1140, Acc=0.9604


Epoch 18/20: 100%|██████████| 313/313 [00:02<00:00, 133.36it/s, loss=0.123, acc=0.965] 


Epoch 18: Loss=0.1024, Acc=0.9650


Epoch 19/20: 100%|██████████| 313/313 [00:02<00:00, 134.52it/s, loss=0.0933, acc=0.969]


Epoch 19: Loss=0.0940, Acc=0.9689


Epoch 20/20: 100%|██████████| 313/313 [00:02<00:00, 133.34it/s, loss=0.187, acc=0.97]  


Epoch 20: Loss=0.0901, Acc=0.9698
Starting to train model for length-hard with vocab size 45
Moving model to device:  cuda


Epoch 1/20: 100%|██████████| 313/313 [00:02<00:00, 124.38it/s, loss=0.692, acc=0.494]


Epoch 1: Loss=0.7961, Acc=0.4942


Epoch 2/20: 100%|██████████| 313/313 [00:02<00:00, 125.89it/s, loss=0.697, acc=0.506]


Epoch 2: Loss=0.6967, Acc=0.5058


Epoch 3/20: 100%|██████████| 313/313 [00:02<00:00, 126.87it/s, loss=0.679, acc=0.507]


Epoch 3: Loss=0.6977, Acc=0.5068


Epoch 4/20: 100%|██████████| 313/313 [00:02<00:00, 127.85it/s, loss=0.689, acc=0.51] 


Epoch 4: Loss=0.6963, Acc=0.5100


Epoch 5/20: 100%|██████████| 313/313 [00:02<00:00, 127.14it/s, loss=0.682, acc=0.505]


Epoch 5: Loss=0.6967, Acc=0.5053


Epoch 6/20: 100%|██████████| 313/313 [00:02<00:00, 126.64it/s, loss=0.704, acc=0.51] 


Epoch 6: Loss=0.6948, Acc=0.5105


Epoch 7/20: 100%|██████████| 313/313 [00:02<00:00, 126.25it/s, loss=0.683, acc=0.511]


Epoch 7: Loss=0.6947, Acc=0.5109


Epoch 8/20: 100%|██████████| 313/313 [00:02<00:00, 126.43it/s, loss=0.694, acc=0.511]


Epoch 8: Loss=0.6947, Acc=0.5111


Epoch 9/20: 100%|██████████| 313/313 [00:02<00:00, 125.98it/s, loss=0.695, acc=0.513]


Epoch 9: Loss=0.6933, Acc=0.5131


Epoch 10/20: 100%|██████████| 313/313 [00:02<00:00, 126.80it/s, loss=0.697, acc=0.52] 


Epoch 10: Loss=0.6920, Acc=0.5197


Epoch 11/20: 100%|██████████| 313/313 [00:02<00:00, 126.34it/s, loss=0.697, acc=0.528]


Epoch 11: Loss=0.6908, Acc=0.5278


Epoch 12/20: 100%|██████████| 313/313 [00:02<00:00, 125.29it/s, loss=0.712, acc=0.531]


Epoch 12: Loss=0.6892, Acc=0.5305


Epoch 13/20: 100%|██████████| 313/313 [00:02<00:00, 126.84it/s, loss=0.697, acc=0.537]


Epoch 13: Loss=0.6873, Acc=0.5374


Epoch 14/20: 100%|██████████| 313/313 [00:02<00:00, 127.18it/s, loss=0.672, acc=0.541]


Epoch 14: Loss=0.6854, Acc=0.5406


Epoch 15/20: 100%|██████████| 313/313 [00:02<00:00, 128.52it/s, loss=0.691, acc=0.548]


Epoch 15: Loss=0.6829, Acc=0.5479


Epoch 16/20: 100%|██████████| 313/313 [00:02<00:00, 125.82it/s, loss=0.71, acc=0.552] 


Epoch 16: Loss=0.6812, Acc=0.5524


Epoch 17/20: 100%|██████████| 313/313 [00:02<00:00, 127.72it/s, loss=0.642, acc=0.555]


Epoch 17: Loss=0.6798, Acc=0.5547


Epoch 18/20: 100%|██████████| 313/313 [00:02<00:00, 125.73it/s, loss=0.688, acc=0.557]


Epoch 18: Loss=0.6789, Acc=0.5572


Epoch 19/20: 100%|██████████| 313/313 [00:02<00:00, 125.79it/s, loss=0.637, acc=0.561]


Epoch 19: Loss=0.6779, Acc=0.5606


Epoch 20/20: 100%|██████████| 313/313 [00:02<00:00, 126.68it/s, loss=0.693, acc=0.563]


Epoch 20: Loss=0.6777, Acc=0.5631
Starting to train model for all-hard with vocab size 3588
Moving model to device:  cuda


Epoch 1/20: 100%|██████████| 313/313 [00:02<00:00, 128.00it/s, loss=2.05, acc=0.134]


Epoch 1: Loss=2.3690, Acc=0.1335


Epoch 2/20: 100%|██████████| 313/313 [00:02<00:00, 131.42it/s, loss=2.02, acc=0.196]


Epoch 2: Loss=2.0118, Acc=0.1965


Epoch 3/20: 100%|██████████| 313/313 [00:02<00:00, 128.26it/s, loss=1.91, acc=0.262]


Epoch 3: Loss=1.8939, Acc=0.2620


Epoch 4/20: 100%|██████████| 313/313 [00:02<00:00, 131.25it/s, loss=1.75, acc=0.319]


Epoch 4: Loss=1.7733, Acc=0.3190


Epoch 5/20: 100%|██████████| 313/313 [00:02<00:00, 130.21it/s, loss=1.62, acc=0.396]


Epoch 5: Loss=1.5942, Acc=0.3958


Epoch 6/20: 100%|██████████| 313/313 [00:02<00:00, 130.28it/s, loss=1.67, acc=0.48] 


Epoch 6: Loss=1.4006, Acc=0.4804


Epoch 7/20: 100%|██████████| 313/313 [00:02<00:00, 130.66it/s, loss=1.3, acc=0.551]  


Epoch 7: Loss=1.2213, Acc=0.5513


Epoch 8/20: 100%|██████████| 313/313 [00:02<00:00, 130.11it/s, loss=0.988, acc=0.617]


Epoch 8: Loss=1.0552, Acc=0.6167


Epoch 9/20: 100%|██████████| 313/313 [00:02<00:00, 129.70it/s, loss=0.784, acc=0.677]


Epoch 9: Loss=0.9060, Acc=0.6766


Epoch 10/20: 100%|██████████| 313/313 [00:02<00:00, 130.52it/s, loss=0.762, acc=0.736]


Epoch 10: Loss=0.7519, Acc=0.7356


Epoch 11/20: 100%|██████████| 313/313 [00:02<00:00, 128.49it/s, loss=0.504, acc=0.792]


Epoch 11: Loss=0.6059, Acc=0.7920


Epoch 12/20: 100%|██████████| 313/313 [00:02<00:00, 130.43it/s, loss=0.757, acc=0.845]


Epoch 12: Loss=0.4709, Acc=0.8445


Epoch 13/20: 100%|██████████| 313/313 [00:02<00:00, 129.59it/s, loss=0.407, acc=0.891]


Epoch 13: Loss=0.3518, Acc=0.8908


Epoch 14/20: 100%|██████████| 313/313 [00:02<00:00, 129.91it/s, loss=0.479, acc=0.92]  


Epoch 14: Loss=0.2619, Acc=0.9204


Epoch 15/20: 100%|██████████| 313/313 [00:02<00:00, 128.61it/s, loss=0.0586, acc=0.951]


Epoch 15: Loss=0.1847, Acc=0.9512


Epoch 16/20: 100%|██████████| 313/313 [00:02<00:00, 128.09it/s, loss=0.1, acc=0.967]   


Epoch 16: Loss=0.1344, Acc=0.9667


Epoch 17/20: 100%|██████████| 313/313 [00:02<00:00, 129.44it/s, loss=0.116, acc=0.976] 


Epoch 17: Loss=0.1035, Acc=0.9762


Epoch 18/20: 100%|██████████| 313/313 [00:02<00:00, 131.01it/s, loss=0.0778, acc=0.981]


Epoch 18: Loss=0.0843, Acc=0.9814


Epoch 19/20: 100%|██████████| 313/313 [00:02<00:00, 130.44it/s, loss=0.0134, acc=0.984]


Epoch 19: Loss=0.0740, Acc=0.9842


Epoch 20/20: 100%|██████████| 313/313 [00:02<00:00, 128.27it/s, loss=0.149, acc=0.986]  


Epoch 20: Loss=0.0694, Acc=0.9857
Starting to train model for mixed with vocab size 3785
Moving model to device:  cuda


Epoch 1/20: 100%|██████████| 313/313 [00:02<00:00, 122.57it/s, loss=2.19, acc=0.188]


Epoch 1: Loss=2.4493, Acc=0.1875


Epoch 2/20: 100%|██████████| 313/313 [00:02<00:00, 123.51it/s, loss=2.21, acc=0.205]


Epoch 2: Loss=2.1516, Acc=0.2051


Epoch 3/20: 100%|██████████| 313/313 [00:02<00:00, 125.88it/s, loss=1.87, acc=0.227]


Epoch 3: Loss=2.0724, Acc=0.2271


Epoch 4/20: 100%|██████████| 313/313 [00:02<00:00, 123.63it/s, loss=1.96, acc=0.282]


Epoch 4: Loss=1.9507, Acc=0.2816


Epoch 5/20: 100%|██████████| 313/313 [00:02<00:00, 122.04it/s, loss=1.7, acc=0.349] 


Epoch 5: Loss=1.7770, Acc=0.3488


Epoch 6/20: 100%|██████████| 313/313 [00:02<00:00, 125.72it/s, loss=1.38, acc=0.417]


Epoch 6: Loss=1.5929, Acc=0.4168


Epoch 7/20: 100%|██████████| 313/313 [00:02<00:00, 119.13it/s, loss=1.71, acc=0.473]


Epoch 7: Loss=1.4428, Acc=0.4726


Epoch 8/20: 100%|██████████| 313/313 [00:02<00:00, 116.60it/s, loss=1.2, acc=0.52]  


Epoch 8: Loss=1.3206, Acc=0.5198


Epoch 9/20: 100%|██████████| 313/313 [00:02<00:00, 113.98it/s, loss=1.62, acc=0.546] 


Epoch 9: Loss=1.2382, Acc=0.5464


Epoch 10/20: 100%|██████████| 313/313 [00:02<00:00, 107.90it/s, loss=1.24, acc=0.572] 


Epoch 10: Loss=1.1692, Acc=0.5717


Epoch 11/20: 100%|██████████| 313/313 [00:02<00:00, 116.68it/s, loss=1.26, acc=0.582] 


Epoch 11: Loss=1.1321, Acc=0.5816


Epoch 12/20: 100%|██████████| 313/313 [00:02<00:00, 124.69it/s, loss=1.18, acc=0.594] 


Epoch 12: Loss=1.1008, Acc=0.5936


Epoch 13/20: 100%|██████████| 313/313 [00:02<00:00, 122.93it/s, loss=1.1, acc=0.6]    


Epoch 13: Loss=1.0786, Acc=0.5996


Epoch 14/20: 100%|██████████| 313/313 [00:02<00:00, 124.51it/s, loss=1.2, acc=0.605]  


Epoch 14: Loss=1.0611, Acc=0.6047


Epoch 15/20: 100%|██████████| 313/313 [00:02<00:00, 122.90it/s, loss=0.529, acc=0.611]


Epoch 15: Loss=1.0473, Acc=0.6108


Epoch 16/20: 100%|██████████| 313/313 [00:02<00:00, 120.53it/s, loss=0.799, acc=0.612]


Epoch 16: Loss=1.0380, Acc=0.6122


Epoch 17/20: 100%|██████████| 313/313 [00:02<00:00, 123.38it/s, loss=0.79, acc=0.616] 


Epoch 17: Loss=1.0289, Acc=0.6156


Epoch 18/20: 100%|██████████| 313/313 [00:02<00:00, 124.12it/s, loss=1.04, acc=0.618] 


Epoch 18: Loss=1.0229, Acc=0.6182


Epoch 19/20: 100%|██████████| 313/313 [00:02<00:00, 124.74it/s, loss=0.953, acc=0.619]


Epoch 19: Loss=1.0182, Acc=0.6194


Epoch 20/20: 100%|██████████| 313/313 [00:02<00:00, 121.85it/s, loss=0.949, acc=0.62] 


Epoch 20: Loss=1.0153, Acc=0.6203


In [48]:
def evaluate_model(model, dataloader, device='cuda'):
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            loss_mask = batch["loss_mask"].to(device)
            
            # Forward pass
            logits = model(input_ids[:, :-1])
            targets = input_ids[:, 1:]
            mask = loss_mask[:, 1:]
            
            # Get predictions
            preds = logits.argmax(dim=-1)
            
            # Identify valid rows
            per_row = mask.sum(dim=1)
            valid_rows = (per_row == 1)
            
            if not valid_rows.any():
                continue   # skip batches with no valid rows
            
            # Filter only valid rows
            mask_valid = mask[valid_rows].bool()
            preds_valid = preds[valid_rows]
            targets_valid = targets[valid_rows]
            
            # Extract predictions and true labels at masked positions
            pred_answers = preds_valid[mask_valid]
            true_answers = targets_valid[mask_valid]
            
            correct += (pred_answers == true_answers).sum().item()
            total += pred_answers.numel()
            
    accuracy = correct / total
    return accuracy

# Test
for name in train_dataset_names:
    accuracy = evaluate_model(models[name], DataLoader(
        test_datasets[name],
        batch_size=64,
        shuffle=False,
        collate_fn=collate_fn
    ))
    print(f"Testing model {name} accuracy on {name}: {accuracy}")

Testing model easy accuracy on easy: 1.0
Testing model bpe-hard accuracy on bpe-hard: 0.5165
Testing model mult-hard accuracy on mult-hard: 0.949
Testing model length-hard accuracy on length-hard: 0.5105
Testing model all-hard accuracy on all-hard: 0.1565


KeyError: 'mixed'