In [1]:
import torch
import torch.nn as nn
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Configuration
config = HookedTransformerConfig(
    n_layers=2,
    n_heads=8,
    d_model=128,
    d_head=16,  # d_model / n_heads
    d_mlp=None,  # No MLPs (attention-only)
    act_fn=None,  # No activation (no MLPs)
    attention_dir="causal",  # Causal attention
    attn_only=True,  # Attention-only model
    normalization_type=None,  # No LayerNorm for simplicity
    d_vocab=50,  # 26 letters + 10 digits + special tokens
    n_ctx=60,  # Max sequence length
    init_weights=True,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

model = HookedTransformer(config)
model_device = next(model.parameters()).device
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model device: {model_device}")
# The config above falls back to CPU, so only ask CUDA for a device name when
# the model really lives on a GPU; the query fails on CPU-only machines.
if model_device.type == "cuda":
    print(f"Model device name: {torch.cuda.get_device_name(model_device)}")
else:
    print("Model device name: CPU")

Model parameters: 152,626
Model device: cuda:0
Model device name: NVIDIA GeForce RTX 4080 SUPER


In [2]:
class Vocabulary:
    """Trie-backed token store that assigns sequential integer ids to tokens."""

    class TrieNode:
        """Single trie node: `id` is the id of the token ending here (None if
        no token ends here) and `next` maps a character to the child node."""

        def __init__(self):
            self.id = None
            self.next = {}

    def __init__(self):
        self.root = self.TrieNode()
        self.token_map = {}  # id -> token string
        self.size = 0

    def add_token(self, token):
        """Insert `token`; it receives the next free id unless it already has one."""
        cursor = self.root
        for ch in token:
            child = cursor.next.get(ch)
            if child is None:
                child = self.TrieNode()
                cursor.next[ch] = child
            cursor = child
        if cursor.id is not None:
            # Token is already registered: keep its existing id.
            return
        new_id = self.size
        cursor.id = new_id
        self.token_map[new_id] = token
        self.size = new_id + 1

    def longest_prefix_token(self, text, start):
        """Return (token id, token length) of the longest token that is a
        prefix of text[start:]. Asserts that at least one token matches."""
        best_id = None
        best_length = 0
        cursor = self.root
        pos = start
        end = len(text)
        while pos < end:
            child = cursor.next.get(text[pos])
            if child is None:
                break
            cursor = child
            pos += 1
            if cursor.id is not None:
                # A registered token ends here; remember it and keep walking
                # in case a longer token also matches.
                best_id = cursor.id
                best_length = pos - start
        assert best_id is not None
        return best_id, best_length

    def get_token(self, id):
        """Return the token string registered under `id`."""
        return self.token_map[id]

In [3]:
# Simple character-level tokenizer
class CountingTokenizer:
    """Greedy longest-match tokenizer over a `Vocabulary`, with optional
    BPE-style merges learned from sample text."""

    def __init__(self):
        # Base vocabulary: special/word tokens are added first (lowest ids),
        # followed by the single lowercase letters and digits.
        self.vocab = Vocabulary()
        special = ["<PAD>", "<BOS>", "<EOS>", ":", " ", "Count", "the", "letter", "in"]
        chars = list("abcdefghijklmnopqrstuvwxyz0123456789")
        for token in special + chars:
            self.vocab.add_token(token)

    def encode(self, text, include_lengths = False):
        """Convert text to token IDs by repeatedly taking the longest matching token.

        With include_lengths=True each entry is an (id, token_length) tuple
        instead of a bare id.
        """
        ids = []
        cursor = 0
        while cursor < len(text):
            token_id, token_length = self.vocab.longest_prefix_token(text, cursor)
            assert token_id != -1
            ids.append((token_id, token_length) if include_lengths else token_id)
            cursor += token_length
        return ids

    def decode(self, ids):
        """Convert token IDs back to text by concatenating their token strings."""
        return "".join(map(self.vocab.get_token, ids))

    def apply_bpe(self, words, max_token_length=3):
        """Learn BPE-style merges from `words` and add them to the vocabulary.

        Each round re-encodes the corpus, counts adjacent token pairs whose
        combined length fits in `max_token_length` and which contain no
        ignored token, and registers the most frequent pair as a new token.
        Stops once no eligible pair occurs at least twice.
        """
        corpus = "".join(f"<BOS>{word}<EOS>" for word in words)
        ignore_tokens = ["<PAD>", "<BOS>", "<EOS>", ":", " "]
        while True:
            encoded = self.encode(corpus, include_lengths=True)
            counts = {}
            best = ()
            for left, right in zip(encoded, encoded[1:]):
                if left[1] + right[1] > max_token_length:
                    continue
                if (self.vocab.get_token(left[0]) in ignore_tokens
                        or self.vocab.get_token(right[0]) in ignore_tokens):
                    continue
                pair = (left, right)
                counts[pair] = counts.get(pair, 0) + 1
                # Strict comparison: on ties the pair that reached the count first wins.
                if not best or counts[best] < counts[pair]:
                    best = pair
            if not best or counts[best] < 2:
                break
            self.vocab.add_token("".join(self.vocab.get_token(tok[0]) for tok in best))

# Module-level tokenizer holding only the base vocabulary (no BPE merges applied).
tokenizer = CountingTokenizer()

In [4]:
import random

def generate_counting_question(target_letter='a', multiplicity_range=(1, 2), length_range=(5, 10)):
    """
    Generate one random letter-counting question together with its answer.

    A random lowercase word is built that contains `target_letter` exactly
    `count` times, where `count` is drawn uniformly from `multiplicity_range`.
    The occurrences are placed at uniformly sampled distinct positions.

    Args:
        target_letter: Single lowercase letter (a-z) to be counted.
        multiplicity_range: Inclusive (min, max) bounds for the number of
            occurrences of the target letter.
        length_range: Inclusive (min, max) bounds for the word length. The
            lower bound is raised to `count` so all occurrences fit.

    Returns:
        (question, answer): `question` is "Count the letter X in: <word>" and
        `answer` is the count as a decimal string.

    Raises:
        ValueError: If `target_letter` is not one of a-z, or the sampled count
            cannot fit into the maximum word length.
    """
    count = random.randint(*multiplicity_range)

    # Filler letters are every lowercase letter except the target, so the
    # target can only appear where it is placed explicitly below.
    filler_letters = list("abcdefghijklmnopqrstuvwxyz")
    filler_letters.remove(target_letter)

    min_length, max_length = length_range[0], length_range[1]
    if count > max_length:
        raise ValueError(
            f"Cannot place {count} occurrences of {target_letter!r} in a word of "
            f"at most {max_length} characters"
        )
    length = random.randint(max(count, min_length), max_length)

    fillers = iter(random.choices(filler_letters, k=length - count))
    target_positions = set(random.sample(range(length), count))

    # Build the word position by position so the target letters land exactly
    # on the sampled indices.
    input_string = "".join(
        target_letter if index in target_positions else next(fillers)
        for index in range(length)
    )
    question = f"Count the letter {target_letter} in: {input_string}"
    answer = str(count)

    return question, answer

def generate_counting_example(qa, tokenizer=None):
    """
    Turn a (question, answer) pair into a tokenized example.

    Example: "Count the letter a in: banana" -> "3"
    Token layout: [question tokens] [answer token(s)]

    Args:
        qa: (question, answer) tuple of strings; `answer` must parse as an int.
        tokenizer: Object providing `encode(text) -> list of ids` and
            `decode(ids) -> str`. Required despite the None default.

    Returns:
        dict with:
            'tokens': question ids followed by all answer ids,
            'question_tokens_decoded': the question tokens as strings,
            'question_length': number of question tokens, i.e. the index of
                the first answer token (used for loss masking),
            'answer': the answer as an int,
            'text': question and answer concatenated.

    Raises:
        ValueError: If no tokenizer is supplied or `answer` is not an integer string.
    """
    if tokenizer is None:
        raise ValueError("generate_counting_example requires a tokenizer")

    question, answer = qa

    # Tokenize
    question_tokens = tokenizer.encode(question)
    question_tokens_decoded = [tokenizer.decode([token]) for token in question_tokens]
    # Keep every answer token: an answer such as "10" spans more than one
    # token, and keeping only the first would store it as "1" while 'answer'
    # and 'text' still carry the full value.
    answer_tokens = tokenizer.encode(answer)

    # Combine: question + answer
    full_tokens = question_tokens + answer_tokens

    return {
        'tokens': full_tokens,
        'question_tokens_decoded': question_tokens_decoded,
        'question_length': len(question_tokens),  # For loss masking
        'answer': int(answer),
        'text': question + answer
    }

# Test
# Smoke test: build one random question with the base tokenizer (no BPE merges)
# and print every field of the resulting example.
tokenizer = CountingTokenizer()
qa = generate_counting_question()
example = generate_counting_example(qa, tokenizer=tokenizer)
print(f"Text: {example['text']}")
print(f"Tokens: {example['tokens']}")
print(f"Question Tokens Decoded: {example['question_tokens_decoded']}")
print(f"Question length: {example['question_length']}")
print(f"Answer: {example['answer']}")

Text: Count the letter a in: avekjaj2
Tokens: [5, 4, 6, 4, 7, 4, 9, 4, 8, 3, 4, 9, 30, 13, 19, 18, 9, 18, 37]
Question Tokens Decoded: ['Count', ' ', 'the', ' ', 'letter', ' ', 'a', ' ', 'in', ':', ' ', 'a', 'v', 'e', 'k', 'j', 'a', 'j']
Question length: 18
Answer: 2


In [5]:
# Test
# Smoke test for BPE: learn merges from 10 sample questions, then tokenize a
# freshly generated question with the updated tokenizer.
qas = [generate_counting_question() for _ in range(10)]
tokenizer = CountingTokenizer()
tokenizer.apply_bpe([qa[0] for qa in qas])
qa = generate_counting_question()
example = generate_counting_example(qa, tokenizer=tokenizer)
# Build the corpus string outside the f-string: reusing double quotes inside
# an f-string replacement field is a SyntaxError before Python 3.12.
bpe_corpus = "".join([f"<BOS>{sample[0]}<EOS>" for sample in qas])
print(f"Sample Pre-Processing Text: {bpe_corpus}")
print(f"Text: {example['text']}")
print(f"Tokens: {example['tokens']}")
print(f"Question Tokens Decoded: {example['question_tokens_decoded']}")
print(f"Question length: {example['question_length']}")
print(f"Answer: {example['answer']}")

Sample Pre-Processing Text: <BOS>Count the letter a in: dvdabc<EOS><BOS>Count the letter a in: afzceac<EOS><BOS>Count the letter a in: yonchwdash<EOS><BOS>Count the letter a in: papmabm<EOS><BOS>Count the letter a in: yoqgam<EOS><BOS>Count the letter a in: ajqna<EOS><BOS>Count the letter a in: kagvhda<EOS><BOS>Count the letter a in: hztavv<EOS><BOS>Count the letter a in: afxae<EOS><BOS>Count the letter a in: auppacfb<EOS>
Text: Count the letter a in: fakfwih1
Tokens: [5, 4, 6, 4, 7, 4, 9, 4, 8, 3, 4, 14, 9, 19, 14, 31, 17, 16, 36]
Question Tokens Decoded: ['Count', ' ', 'the', ' ', 'letter', ' ', 'a', ' ', 'in', ':', ' ', 'f', 'a', 'k', 'f', 'w', 'i', 'h']
Question length: 18
Answer: 1


In [6]:
import pickle
from torch.utils.data import Dataset, DataLoader


class CountingDataset(Dataset):
    """Dataset of tokenized letter-counting examples for a given difficulty preset."""

    # difficulty name -> (multiplicity_range, length_range), both inclusive.
    DIFFICULTY_PRESETS = {
        'easy': ((1, 2), (5, 10)),
        'bpe-hard': ((1, 2), (5, 10)),
        'mult-hard': ((3, 10), (5, 10)),
        'length-hard': ((1, 2), (20, 50)),
        'all-hard': ((3, 10), (20, 50)),
        'mixed': ((1, 10), (5, 50)),
    }

    def __init__(self, n_examples=50000, difficulty='easy', tokenizer=None, allow_bpe=True, train_bpe=True):
        """
        Generate `n_examples` random counting examples.

        Args:
            n_examples: Number of examples to generate.
            difficulty: One of the keys of DIFFICULTY_PRESETS
                ('easy', 'bpe-hard', 'mult-hard', 'length-hard', 'all-hard', 'mixed').
            tokenizer: Tokenizer used to encode the examples; a fresh
                CountingTokenizer is created when None.
            allow_bpe, train_bpe: BPE merges are learned on the generated
                questions only when both are true AND BPE is enabled for the
                difficulty (see `use_bpe` below).

        Raises:
            ValueError: If `difficulty` is not a known preset.
        """
        # Validate first, with a real exception: an `assert` disappears under
        # `python -O` and gives no hint about the accepted values.
        if difficulty not in self.DIFFICULTY_PRESETS:
            raise ValueError(
                f"Unknown difficulty {difficulty!r}; expected one of "
                f"{sorted(self.DIFFICULTY_PRESETS)}"
            )
        mult_range, len_range = self.DIFFICULTY_PRESETS[difficulty]

        if tokenizer is None:
            tokenizer = CountingTokenizer()
        self.tokenizer = tokenizer
        self.examples = []

        # BPE training is currently switched off for every difficulty;
        # 'bpe-hard', 'all-hard' and 'mixed' are the presets meant to enable it.
        use_bpe = False

        # Generate basic words
        target_letters = list("abcdefghijklmnopqrstuvwxyz")
        qas = []
        for _ in range(n_examples):
            target = random.choice(target_letters)
            qas.append(generate_counting_question(
                target_letter=target,
                multiplicity_range=mult_range,
                length_range=len_range,
            ))

        # Add basic words to the vocabulary if we are applying BPE
        if use_bpe and allow_bpe and train_bpe:
            tokenizer.apply_bpe([qa[0] for qa in qas])

        # Generate the full questions and examples
        basic_tokenizer = CountingTokenizer()
        if difficulty == "mixed":
            # Half of the examples use the supplied tokenizer, the rest the
            # basic one. A set keeps the membership test below O(1); a list
            # would make the loop quadratic in the number of examples.
            bpe_set = set(random.sample(range(len(qas)), len(qas) // 2))
        else:
            bpe_set = range(len(qas))
        for i, qa in enumerate(qas):
            example = generate_counting_example(
                qa,
                tokenizer=tokenizer if i in bpe_set else basic_tokenizer
            )
            self.examples.append(example)

    def __len__(self):
        """Number of generated examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the example dict at position `idx`."""
        return self.examples[idx]

def collate_fn(batch, pad_id=0, max_len=100):
    """
    Pad a list of examples to a common length and build the answer loss mask.

    Args:
        batch: Non-empty list of example dicts with keys 'tokens',
            'question_length' and 'answer'.
        pad_id: Token id used for right-padding.
        max_len: Upper bound on the padded length; longer sequences are truncated.

    Returns:
        dict with:
            'input_ids': LongTensor [batch, L] of padded token ids,
            'loss_mask': FloatTensor [batch, L], 1 at the index right after
                the question (the first answer token), 0 elsewhere,
            'answers': LongTensor [batch] of integer answers,
        where L = min(longest sequence in the batch, max_len).
    """
    max_batch_len = min(max(len(ex['tokens']) for ex in batch), max_len)

    padded_tokens = []
    masks = []  # Loss mask: 1 for answer token, 0 elsewhere

    for ex in batch:
        seq = ex['tokens'][:max_batch_len]

        # Pad sequence
        padded_tokens.append(seq + [pad_id] * (max_batch_len - len(seq)))

        # Create mask: only compute loss on the answer token. The answer index
        # is deliberately NOT clamped to the truncated length: if truncation
        # removed the answer, the row gets an all-zero mask instead of
        # marking the last question token as the answer.
        mask = [0] * max_batch_len
        answer_pos = ex['question_length']
        if answer_pos < len(seq):  # Answer token survived truncation
            mask[answer_pos] = 1
        masks.append(mask)

    return {
        'input_ids': torch.tensor(padded_tokens, dtype=torch.long),
        'loss_mask': torch.tensor(masks, dtype=torch.float),
        'answers': torch.tensor([ex['answer'] for ex in batch], dtype=torch.long)
    }


In [7]:
# Create dataloaders
# One dataset, tokenizer and DataLoader is built (or loaded) per difficulty name.
train_dataset_names = ["easy", "bpe-hard", "mult-hard", "length-hard", "all-hard"]
GENERATE_NEW_TRAINING_DATA = True # IMPORTANT: True regenerates and overwrites the pickled datasets/tokenizers; False loads the existing pickle files

# Training dataloaders
train_datasets = {}
train_loaders = {}
train_tokenizers = {}
for name in train_dataset_names:
    if GENERATE_NEW_TRAINING_DATA:
        # Fresh tokenizer per difficulty; dataset and tokenizer are pickled to the working directory.
        train_tokenizers[name] = CountingTokenizer()
        train_datasets[name] = CountingDataset(n_examples=100000, difficulty=name, tokenizer=train_tokenizers[name])
        with open(f"train-{name}-dataset.pkl", "wb") as f:
            pickle.dump(train_datasets[name], f)
        with open(f"train-{name}-tokenizer.pkl", "wb") as f:
            pickle.dump(train_tokenizers[name], f)
    else:
        # NOTE(review): pickle.load executes arbitrary code from the file; only load pickles you created yourself.
        with open(f"train-{name}-dataset.pkl", "rb") as f:
            train_datasets[name] = pickle.load(f)
        with open(f"train-{name}-tokenizer.pkl", "rb") as f:
            train_tokenizers[name] = pickle.load(f)
            # Print the loaded vocabulary size as a quick sanity check.
            print(train_tokenizers[name].vocab.size)
    train_loaders[name] = DataLoader(train_datasets[name], batch_size=64, shuffle=True, collate_fn=collate_fn)

# Test batch
# Pull one batch to inspect tensor shapes and the position of the answer mask.
batch = next(iter(train_loaders["bpe-hard"]))
print(f"Input shape: {batch['input_ids'].shape}")
print(f"Mask shape: {batch['loss_mask'].shape}")
print(f"Example tokens: {batch['input_ids'][0]}")
print(f"Example mask: {batch['loss_mask'][0]}")

Input shape: torch.Size([64, 22])
Mask shape: torch.Size([64, 22])
Example tokens: tensor([ 5,  4,  6,  4,  7,  4, 15,  4,  8,  3,  4, 34, 33, 30, 15, 15, 23, 14,
        37,  0,  0,  0])
Example mask: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 0., 0., 0.])


In [8]:
GENERATE_NEW_TESTING_DATA = True # IMPORTANT: True regenerates and overwrites the pickled test datasets; False loads the existing pickle files

# Testing dataloaders
test_dataset_names = ["easy", "bpe-hard", "mult-hard", "length-hard", "all-hard"]
test_datasets = {}
test_tokenizers = {}
for name in test_dataset_names:
    if GENERATE_NEW_TESTING_DATA:
        # Reuse the matching training tokenizer (so every test name must also be a
        # training name); train_bpe=False means no new merges are learned from test data.
        test_datasets[name] = CountingDataset(n_examples=10000, difficulty=name, tokenizer=train_tokenizers[name], allow_bpe=True, train_bpe=False)
        with open(f"test-{name}-dataset.pkl", "wb") as f:
            pickle.dump(test_datasets[name], f)
    else:
        # NOTE(review): pickle.load executes arbitrary code from the file; only load pickles you created yourself.
        with open(f"test-{name}-dataset.pkl", "rb") as f:
            test_datasets[name] = pickle.load(f)

In [None]:
import torch.optim as optim
from tqdm import tqdm

# IMPORTANT: True trains new models from scratch; False loads previously saved checkpoints instead.
RETRAIN = True

def train_model(model, name, train_loader, n_epochs=10, lr=1e-3, device='cuda'):
    """
    Train with MASKED loss (only on answer token)

    Classification loss: CrossEntropy on vocabulary
    (Can also try regression loss on digit value)

    Args:
        model: Model mapping token ids [batch, seq] to logits [batch, seq, vocab].
        name: Label used in checkpoint file names.
        train_loader: Iterable of batches with 'input_ids' and 'loss_mask'
            (as produced by collate_fn).
        n_epochs: Number of passes over train_loader.
        lr: AdamW learning rate (cosine-annealed over n_epochs).
        device: Device to train on; a CUDA request falls back to CPU when
            CUDA is not available.

    Returns:
        The trained model (moved to the training device).

    Side effects:
        Writes 'checkpoint-{name}-epoch-{k}.pt' every 5 epochs and after the
        final epoch.
    """
    # Fall back to CPU when a CUDA device is requested (the default) but this
    # machine has none.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'

    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

    # Classification loss (predict token ID)
    criterion = nn.CrossEntropyLoss(reduction='none')  # Per-token loss

    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        n_batches = 0
        correct = 0
        total = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}")
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)  # [batch, seq_len]
            loss_mask = batch['loss_mask'].to(device)  # [batch, seq_len]

            # Forward pass
            # Input: all tokens except last
            # Target: all tokens except first (shifted by 1)
            logits = model(input_ids[:, :-1])  # [batch, seq_len-1, vocab]
            targets = input_ids[:, 1:]  # [batch, seq_len-1]
            mask = loss_mask[:, 1:]  # Align with targets

            # A batch without any answer token would divide by zero below and
            # push a NaN loss through the optimizer, so skip it.
            n_answers = mask.sum()
            if n_answers.item() == 0:
                continue

            # Compute loss
            loss_per_token = criterion(
                logits.reshape(-1, logits.size(-1)),  # [batch*(seq_len-1), vocab]
                targets.reshape(-1)  # [batch*(seq_len-1)]
            )
            loss_per_token = loss_per_token.reshape(targets.shape)  # [batch, seq_len-1]

            # Apply mask: only compute loss on answer token
            masked_loss = (loss_per_token * mask).sum() / n_answers

            # Backward pass
            optimizer.zero_grad()
            masked_loss.backward()
            optimizer.step()

            # Metrics
            total_loss += masked_loss.item()
            n_batches += 1

            # Accuracy: check if predicted answer digit is correct
            preds = logits.argmax(dim=-1)  # [batch, seq_len-1]
            answer_positions = mask.bool()
            correct += (preds[answer_positions] == targets[answer_positions]).sum().item()
            total += answer_positions.sum().item()

            pbar.set_postfix({'loss': masked_loss.item(),
                            'acc': correct/total if total > 0 else 0})

        scheduler.step()

        # Guard the epoch summary against an empty loader / no answer tokens.
        mean_loss = total_loss / n_batches if n_batches > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0
        print(f"Epoch {epoch+1}: Loss={mean_loss:.4f}, "
              f"Acc={accuracy:.4f}")

        # Save checkpoint every 5 epochs, and always after the last epoch so
        # the final weights are on disk even when n_epochs is not a multiple of 5.
        if (epoch + 1) % 5 == 0 or epoch + 1 == n_epochs:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'checkpoint-{name}-epoch-{epoch+1}.pt')

    return model

# Train!
# Single source of truth for the epoch count: the checkpoint loaded when
# RETRAIN is False must be the one written at the end of training.
N_EPOCHS = 20
models = {}
for name in train_dataset_names:
    config = HookedTransformerConfig(
            n_layers=2,
            n_heads=8,
            d_model=256,
            d_head=16,  # NOTE: n_heads * d_head = 128, half of d_model (not d_model / n_heads)
            d_mlp=None,  # No MLPs (attention-only)
            act_fn=None,  # No activation (no MLPs)
            attention_dir="causal",  # Causal attention
            attn_only=True,  # Attention-only model
            normalization_type=None,  # No LayerNorm for simplicity
            d_vocab=train_tokenizers[name].vocab.size + 5,  # Tokenizer vocab size plus 5 spare ids
            n_ctx=100,  # Max sequence length
            init_weights=True,
            device="cuda" if torch.cuda.is_available() else "cpu"
        )
    model = HookedTransformer(config)
    if RETRAIN:
        print(f"Starting to train model for {name} with vocab size {train_tokenizers[name].vocab.size}")
        models[name] = train_model(model, name, train_loaders[name], n_epochs=N_EPOCHS, lr=3e-4)
    else:
        # Load the checkpoint saved after the final training epoch.
        checkpoint = torch.load(f'checkpoint-{name}-epoch-{N_EPOCHS}.pt', map_location="cpu")

        model.load_state_dict(checkpoint["model_state_dict"])
        models[name] = model
        

Starting to train model for easy with vocab size 45
Moving model to device:  cuda


Epoch 1/20: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 296.39it/s, loss=0.669, acc=0.499]


Epoch 1: Loss=0.7359, Acc=0.4995


Epoch 2/20: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 314.92it/s, loss=0.71, acc=0.501]


Epoch 2: Loss=0.6961, Acc=0.5013


Epoch 3/20: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 321.24it/s, loss=0.00534, acc=0.678]


Epoch 3: Loss=0.5064, Acc=0.6780


Epoch 4/20: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 318.27it/s, loss=0.000526, acc=0.996]


Epoch 4: Loss=0.0193, Acc=0.9959


Epoch 5/20: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 317.42it/s, loss=0.00139, acc=0.997]


Epoch 5: Loss=0.0138, Acc=0.9966


Epoch 6/20: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 319.07it/s, loss=0.00138, acc=0.998]


Epoch 6: Loss=0.0092, Acc=0.9981


Epoch 7/20: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 319.15it/s, loss=0.000288, acc=0.999]


Epoch 7: Loss=0.0074, Acc=0.9985


Epoch 8/20: 100%|█████████████████████████████| 1563/1563 [00:05<00:00, 307.81it/s, loss=7.86e-5, acc=0.999]


Epoch 8: Loss=0.0050, Acc=0.9989


Epoch 9/20: 100%|████████████████████████████| 1563/1563 [00:05<00:00, 309.81it/s, loss=0.000267, acc=0.999]


Epoch 9: Loss=0.0029, Acc=0.9994


Epoch 10/20: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 311.63it/s, loss=0.00318, acc=1]


Epoch 10: Loss=0.0020, Acc=0.9996


Epoch 11/20: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 309.93it/s, loss=1.35e-5, acc=1]


Epoch 11: Loss=0.0016, Acc=0.9996


Epoch 12/20: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 308.79it/s, loss=4.12e-6, acc=1]


Epoch 12: Loss=0.0007, Acc=0.9998


Epoch 13/20: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 303.48it/s, loss=7.66e-6, acc=1]


Epoch 13: Loss=0.0008, Acc=0.9998


Epoch 14/20: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 305.15it/s, loss=2.13e-6, acc=1]


Epoch 14: Loss=0.0004, Acc=0.9999


Epoch 15/20: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 307.56it/s, loss=0.000211, acc=1]


Epoch 15: Loss=0.0002, Acc=1.0000


Epoch 16/20: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 306.52it/s, loss=4.48e-6, acc=1]


Epoch 16: Loss=0.0002, Acc=0.9999


Epoch 17/20: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 311.39it/s, loss=8.07e-5, acc=1]


Epoch 17: Loss=0.0001, Acc=1.0000


Epoch 18/20: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 308.89it/s, loss=1.27e-6, acc=1]


Epoch 18: Loss=0.0000, Acc=1.0000


Epoch 19/20: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 312.37it/s, loss=6.46e-6, acc=1]


Epoch 19: Loss=0.0000, Acc=1.0000


Epoch 20/20: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 312.28it/s, loss=1.26e-6, acc=1]


Epoch 20: Loss=0.0000, Acc=1.0000
Starting to train model for bpe-hard with vocab size 45
Moving model to device:  cuda


Epoch 1/20: 100%|█████████████████████████████████| 1563/1563 [00:05<00:00, 312.50it/s, loss=0.703, acc=0.5]


Epoch 1: Loss=0.7356, Acc=0.5001


Epoch 2/20: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 312.21it/s, loss=0.692, acc=0.503]


Epoch 2: Loss=0.6958, Acc=0.5031


Epoch 3/20: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 309.68it/s, loss=0.0565, acc=0.64]


Epoch 3: Loss=0.5963, Acc=0.6402


Epoch 4/20: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 310.34it/s, loss=0.0545, acc=0.984]


Epoch 4: Loss=0.0497, Acc=0.9842


Epoch 5/20: 100%|█████████████████████████████| 1563/1563 [00:05<00:00, 312.34it/s, loss=0.00399, acc=0.995]


Epoch 5: Loss=0.0190, Acc=0.9946


Epoch 6/20: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 318.66it/s, loss=0.000658, acc=0.996]


Epoch 6: Loss=0.0146, Acc=0.9961


Epoch 7/20: 100%|██████████████████████████████| 1563/1563 [00:04<00:00, 320.56it/s, loss=0.0747, acc=0.998]


Epoch 7: Loss=0.0096, Acc=0.9976


Epoch 8/20: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 320.21it/s, loss=0.00143, acc=0.997]


Epoch 8: Loss=0.0106, Acc=0.9970


Epoch 9/20: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 318.82it/s, loss=0.00454, acc=0.998]


Epoch 9: Loss=0.0071, Acc=0.9982


Epoch 10/20: 100%|███████████████████████████| 1563/1563 [00:04<00:00, 321.00it/s, loss=0.000826, acc=0.999]


Epoch 10: Loss=0.0058, Acc=0.9986


Epoch 11/20: 100%|██████████████████████████████| 1563/1563 [00:04<00:00, 324.50it/s, loss=0.126, acc=0.999]


Epoch 11: Loss=0.0040, Acc=0.9992


Epoch 12/20: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 320.65it/s, loss=0.00226, acc=0.999]


Epoch 12: Loss=0.0030, Acc=0.9995


Epoch 13/20: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 321.84it/s, loss=2.02e-5, acc=0.999]


Epoch 13: Loss=0.0031, Acc=0.9993


Epoch 14/20: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 329.22it/s, loss=5.01e-5, acc=1]


Epoch 14: Loss=0.0019, Acc=0.9997


Epoch 15/20: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 318.19it/s, loss=3.21e-5, acc=1]


Epoch 15: Loss=0.0013, Acc=0.9999


Epoch 16/20: 100%|███████████████████████████████████| 1563/1563 [00:04<00:00, 320.58it/s, loss=2e-5, acc=1]


Epoch 16: Loss=0.0010, Acc=0.9999


Epoch 17/20: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 323.89it/s, loss=2.19e-5, acc=1]


Epoch 17: Loss=0.0006, Acc=1.0000


Epoch 18/20: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 323.31it/s, loss=4.39e-5, acc=1]


Epoch 18: Loss=0.0006, Acc=0.9999


Epoch 19/20: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 322.33it/s, loss=1.73e-5, acc=1]


Epoch 19: Loss=0.0004, Acc=1.0000


Epoch 20/20: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 324.82it/s, loss=1.41e-5, acc=1]


Epoch 20: Loss=0.0004, Acc=1.0000
Starting to train model for mult-hard with vocab size 45
Moving model to device:  cuda


Epoch 1/20: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 321.52it/s, loss=0.954, acc=0.298]


Epoch 1: Loss=1.7476, Acc=0.2985


Epoch 2/20: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 326.90it/s, loss=0.802, acc=0.639]


Epoch 2: Loss=0.8203, Acc=0.6391


Epoch 3/20: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 317.98it/s, loss=0.507, acc=0.767]


Epoch 3: Loss=0.5470, Acc=0.7672


Epoch 4/20: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 322.39it/s, loss=0.109, acc=0.861]


Epoch 4: Loss=0.3499, Acc=0.8609


Epoch 5/20: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 316.77it/s, loss=0.151, acc=0.905]


Epoch 5: Loss=0.2475, Acc=0.9050


Epoch 6/20: 100%|██████████████████████████████| 1563/1563 [00:04<00:00, 321.48it/s, loss=0.0837, acc=0.936]


Epoch 6: Loss=0.1766, Acc=0.9359


Epoch 7/20: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 317.38it/s, loss=0.103, acc=0.954]


Epoch 7: Loss=0.1289, Acc=0.9537


Epoch 8/20: 100%|██████████████████████████████| 1563/1563 [00:04<00:00, 322.24it/s, loss=0.0288, acc=0.971]


Epoch 8: Loss=0.0858, Acc=0.9710


Epoch 9/20: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 324.30it/s, loss=0.0214, acc=0.98]


Epoch 9: Loss=0.0648, Acc=0.9799


Epoch 10/20: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 336.23it/s, loss=0.0102, acc=0.987]


Epoch 10: Loss=0.0415, Acc=0.9873


Epoch 11/20: 100%|██████████████████████████████| 1563/1563 [00:04<00:00, 328.49it/s, loss=0.0155, acc=0.99]


Epoch 11: Loss=0.0319, Acc=0.9903


Epoch 12/20: 100%|███████████████████████████| 1563/1563 [00:04<00:00, 326.87it/s, loss=0.000745, acc=0.993]


Epoch 12: Loss=0.0243, Acc=0.9926


Epoch 13/20: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 322.39it/s, loss=0.0352, acc=0.995]


Epoch 13: Loss=0.0176, Acc=0.9949


Epoch 14/20: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 330.85it/s, loss=0.0015, acc=0.996]


Epoch 14: Loss=0.0128, Acc=0.9964


Epoch 15/20: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 326.14it/s, loss=0.0152, acc=0.998]


Epoch 15: Loss=0.0088, Acc=0.9977


Epoch 16/20: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 321.62it/s, loss=0.0323, acc=0.998]


Epoch 16: Loss=0.0066, Acc=0.9983


Epoch 17/20: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 324.04it/s, loss=0.00534, acc=0.999]


Epoch 17: Loss=0.0049, Acc=0.9989


Epoch 18/20: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 315.60it/s, loss=0.00322, acc=0.999]


Epoch 18: Loss=0.0039, Acc=0.9993


Epoch 19/20: 100%|███████████████████████████| 1563/1563 [00:05<00:00, 306.18it/s, loss=0.000173, acc=0.999]


Epoch 19: Loss=0.0031, Acc=0.9994


Epoch 20/20: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 305.27it/s, loss=0.00209, acc=1]


Epoch 20: Loss=0.0029, Acc=0.9996
Starting to train model for length-hard with vocab size 45
Moving model to device:  cuda


Epoch 1/20: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 276.59it/s, loss=0.693, acc=0.499]


Epoch 1: Loss=0.7399, Acc=0.4987


Epoch 2/20: 100%|██████████████████████████████████| 1563/1563 [00:05<00:00, 283.11it/s, loss=0.68, acc=0.5]


Epoch 2: Loss=0.6964, Acc=0.5001


Epoch 3/20: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 276.75it/s, loss=0.696, acc=0.501]


Epoch 3: Loss=0.6950, Acc=0.5006


Epoch 4/20:  89%|███████████████████████████▍   | 1384/1563 [00:04<00:00, 283.67it/s, loss=0.693, acc=0.503]

In [None]:
def evaluate_model(model, dataloader, device='cuda'):
    """
    Compute answer-token accuracy of `model` over `dataloader`.

    Only rows whose loss mask marks exactly one target position are scored;
    a row counts as correct when the argmax prediction at that position
    equals the target token.

    Args:
        model: Model mapping token ids [batch, seq] to logits [batch, seq, vocab].
            It is not moved; it must already live on `device`.
        dataloader: Iterable of batches with 'input_ids' and 'loss_mask'.
        device: Device the batches are moved to; a CUDA request falls back to
            CPU when CUDA is not available.

    Returns:
        Accuracy as a float in [0, 1], or NaN when no row could be scored.
    """
    # Fall back to CPU when a CUDA device is requested (the default) but this
    # machine has none.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'

    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            loss_mask = batch["loss_mask"].to(device)

            # Forward pass (inputs drop the last token, targets are shifted by one)
            logits = model(input_ids[:, :-1])
            targets = input_ids[:, 1:]
            mask = loss_mask[:, 1:]

            # Get predictions
            preds = logits.argmax(dim=-1)

            # Identify valid rows: exactly one masked (answer) position
            valid_rows = (mask.sum(dim=1) == 1)
            if not valid_rows.any():
                continue   # skip batches with no valid rows

            # Extract predictions and true labels at the masked positions
            mask_valid = mask[valid_rows].bool()
            pred_answers = preds[valid_rows][mask_valid]
            true_answers = targets[valid_rows][mask_valid]

            correct += (pred_answers == true_answers).sum().item()
            total += pred_answers.numel()

    # No scorable rows (empty loader or no answer tokens): report NaN rather
    # than raising ZeroDivisionError.
    if total == 0:
        return float('nan')
    return correct / total

# Test
# Evaluate each trained model on the test set of the same difficulty it was trained on.
for name in train_dataset_names:
    accuracy = evaluate_model(models[name], DataLoader(
        test_datasets[name],
        batch_size=64,
        shuffle=False,
        collate_fn=collate_fn
    ))
    print(f"Testing model {name} accuracy on {name}: {accuracy}")