In [1]:
import torch
import torch.nn as nn
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Configuration: a small two-layer, attention-only transformer used as a smoke test
config = HookedTransformerConfig(
    n_layers=2,
    n_heads=8,
    d_model=128,
    d_head=16,  # d_model / n_heads
    d_mlp=None,  # No MLPs (attention-only)
    act_fn=None,  # No activation (no MLPs)
    attention_dir="causal",  # Causal attention
    attn_only=True,  # Attention-only model
    normalization_type=None,  # No LayerNorm for simplicity
    d_vocab=50,  # Base tokenizer has 45 tokens (26 letters + 10 digits + 9 special/keyword tokens); 50 leaves headroom
    n_ctx=60,  # Max sequence length
    init_weights=True,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

model = HookedTransformer(config)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
model_device = next(model.parameters()).device
print(f"Model device: {model_device}")
# torch.cuda.get_device_name raises on machines without CUDA, so only query it
# when the model actually lives on a GPU (the config above may fall back to CPU).
if model_device.type == "cuda":
    print(f"Model device name: {torch.cuda.get_device_name(model_device)}")
else:
    print("Model device name: CPU")

Model parameters: 152,626
Model device: cuda:0
Model device name: NVIDIA GeForce RTX 4080 SUPER


In [2]:
class Vocabulary:
    """Token vocabulary stored as a character trie.

    Each distinct token receives the next free integer id (0, 1, 2, ...) in the
    order it is first added. The trie supports greedy longest-prefix lookup.
    """

    class TrieNode:
        """Single trie node."""

        def __init__(self):
            self.id = None  # Token id if a token ends at this node, otherwise None
            self.next = {}  # Maps the next character to the child TrieNode

    def __init__(self):
        self.root = self.TrieNode()
        self.token_map = {}  # Stores mapping of ID to token
        self.size = 0  # Number of registered tokens; also the next id to hand out

    def add_token(self, token):
        """Add a new token to the vocabulary; re-adding an existing token is a no-op."""
        node = self.root
        for c in token:
            if c not in node.next:
                node.next[c] = self.TrieNode()
            node = node.next[c]
        if node.id is None:
            node.id = self.size
            self.token_map[self.size] = token
            self.size += 1

    def longest_prefix_token(self, text, start):
        """Find the longest registered token that is a prefix of text[start:].

        Returns:
            (token_id, token_length) of the longest match.

        Raises:
            ValueError: if no registered token matches at `start`. An explicit
                exception is used instead of `assert` so the check still runs
                under `python -O`.
        """
        longest_token = None
        longest_length = 0
        node = self.root
        for i in range(start, len(text)):
            node = node.next.get(text[i])
            if node is None:
                break
            if node.id is not None:
                # Keep walking: a longer token may still match further down the trie.
                longest_token = node.id
                longest_length = i - start + 1
        if longest_token is None:
            raise ValueError(
                f"No vocabulary token matches the text at position {start}: {text[start:start + 10]!r}"
            )
        return longest_token, longest_length

    def get_token(self, id):
        """Convert an id to the corresponding token (KeyError if the id is unknown)."""
        return self.token_map[id]

In [3]:
# Greedy longest-match tokenizer on top of a character-level base vocabulary
class CountingTokenizer:
    """Tokenizer for the letter-counting task.

    The base vocabulary holds the special/keyword tokens followed by single
    lowercase letters and digits. `apply_bpe` can extend it with merged
    multi-character tokens. Encoding always picks the longest vocabulary token
    that matches at the current position.
    """

    def __init__(self):
        # Vocabulary: letters + digits + special tokens
        self.vocab = Vocabulary()
        chars = list("abcdefghijklmnopqrstuvwxyz0123456789")
        special = ["<PAD>", "<BOS>", "<EOS>", ":", " ", "Count", "the", "letter", "in"]
        # Special tokens are registered first, so "<PAD>" is the first token added.
        for token in special + chars:
            self.vocab.add_token(token)

    def encode(self, text, include_lengths = False):
        """Convert text to token IDs.

        Args:
            text: string to tokenize.
            include_lengths: when True, return (token_id, token_length) tuples
                instead of bare ids.

        Raises:
            ValueError: if the vocabulary reports no usable match. A missing id
                or a zero-length match would otherwise corrupt the output or
                loop forever.
        """
        ids = []
        pos = 0
        while pos < len(text):
            token_id, token_length = self.vocab.longest_prefix_token(text, pos)
            if token_id is None or token_id < 0 or token_length <= 0:
                raise ValueError(f"Cannot tokenize text at position {pos}: {text[pos:pos + 10]!r}")
            ids.append((token_id, token_length) if include_lengths else token_id)
            pos += token_length
        return ids

    def decode(self, ids):
        """Convert token IDs to text"""
        return "".join([self.vocab.get_token(token_id) for token_id in ids])

    def apply_bpe(self, words, max_token_length=3):
        """Adds merge rules based on a list of words for BPE.

        Each round re-encodes the whole corpus and counts adjacent token pairs.
        The pair that first reaches the highest count is added as a new token.
        Pairs containing a structural token, or whose merged length would
        exceed `max_token_length`, are skipped. Merging stops once the best
        pair occurs fewer than two times.
        """
        text = "".join([f"<BOS>{word}<EOS>" for word in words])
        ignore_tokens = {"<PAD>", "<BOS>", "<EOS>", ":", " "}
        while True:
            encoded = self.encode(text, include_lengths=True)
            pair_counts = {}
            best_pair = None
            for token_pair in zip(encoded, encoded[1:]):
                (left_id, left_len), (right_id, right_len) = token_pair
                if left_len + right_len > max_token_length:
                    continue
                if (self.vocab.get_token(left_id) in ignore_tokens
                        or self.vocab.get_token(right_id) in ignore_tokens):
                    continue
                pair_counts[token_pair] = pair_counts.get(token_pair, 0) + 1
                # Strict '<' keeps the earliest pair to reach a given count on ties.
                if best_pair is None or pair_counts[best_pair] < pair_counts[token_pair]:
                    best_pair = token_pair
            if best_pair is None or pair_counts[best_pair] < 2:
                break
            self.vocab.add_token("".join([self.vocab.get_token(token[0]) for token in best_pair]))

# Notebook-level tokenizer holding only the base vocabulary (no BPE merges have been applied to it)
tokenizer = CountingTokenizer()

In [4]:
import random

def generate_counting_question(target_letter='a', multiplicity_range=(1, 2), length_range=(5, 10)):
    """
    Generates a new random word based on input parameters and gives the answer to the question

    Args:
        target_letter: lowercase ASCII letter whose occurrences are counted.
        multiplicity_range: inclusive (min, max) for how often the target letter occurs.
        length_range: inclusive (min, max) for the word length. The lower bound is
            raised to the sampled count so the word can hold every target letter.

    Returns:
        (question, answer): the question string and the count as a decimal string.

    Raises:
        ValueError: if `target_letter` is not a lowercase ASCII letter, or the
            sampled count cannot fit into a word of at most length_range[1].
    """
    filler_chars = list("abcdefghijklmnopqrstuvwxyz")
    if target_letter not in filler_chars:
        raise ValueError(f"target_letter must be a single lowercase ASCII letter, got {target_letter!r}")
    # Fillers never contain the target, so the final count is exact.
    filler_chars.remove(target_letter)

    count = random.randint(*multiplicity_range)
    min_length = max(count, length_range[0])
    if min_length > length_range[1]:
        raise ValueError(
            f"Cannot place {count} target letters in a word of at most {length_range[1]} characters"
        )
    length = random.randint(min_length, length_range[1])

    fillers = iter(random.choices(filler_chars, k=length - count))
    # Put the target letter at exactly the sampled indices. Inserting into the
    # filler list in arbitrary order shifts earlier insertions and biases the
    # placement, so the word is built position by position instead.
    target_positions = set(random.sample(range(length), count))
    input_string = "".join(
        target_letter if i in target_positions else next(fillers)
        for i in range(length)
    )

    question = f"Count the letter {target_letter} in: {input_string}"
    answer = str(count)

    return question, answer

def generate_counting_example(qa, tokenizer=None):
    """
    Turn a (question, answer) pair into a tokenized example.

    Generate: "Count the letter a in: banana" -> "3"
    Format: [question tokens] [answer token(s)]

    Args:
        qa: (question, answer) tuple; `answer` is the count as a decimal string.
        tokenizer: object with `encode(text) -> list of ids` and `decode(ids) -> str`.

    Returns:
        dict with 'tokens' (question ids followed by all answer ids),
        'question_tokens_decoded', 'question_length' (index of the first answer
        token), 'answer' (int) and 'text'.

    Raises:
        ValueError: if no tokenizer is supplied.
    """
    if tokenizer is None:
        raise ValueError("generate_counting_example requires a tokenizer")
    question, answer = qa

    # Tokenize
    question_tokens = tokenizer.encode(question)
    question_tokens_decoded = [tokenizer.decode([token]) for token in question_tokens]
    # Keep every answer token. Counts of 10 or more encode to several digit
    # tokens, and taking only the first would label "10" the same as "1".
    answer_tokens = list(tokenizer.encode(answer))

    # Combine: question + answer
    full_tokens = question_tokens + answer_tokens

    return {
        'tokens': full_tokens,
        'question_tokens_decoded': question_tokens_decoded,
        'question_length': len(question_tokens),  # For loss masking
        'answer': int(answer),
        'text': question + answer
    }

# Smoke test: build one random question and show how the base tokenizer splits it
tokenizer = CountingTokenizer()
qa = generate_counting_question()
example = generate_counting_example(qa, tokenizer=tokenizer)
print("\n".join([
    f"Text: {example['text']}",
    f"Tokens: {example['tokens']}",
    f"Question Tokens Decoded: {example['question_tokens_decoded']}",
    f"Question length: {example['question_length']}",
    f"Answer: {example['answer']}",
]))

Text: Count the letter a in: gayshvpy1
Tokens: [5, 4, 6, 4, 7, 4, 9, 4, 8, 3, 4, 15, 9, 33, 27, 16, 30, 24, 33, 36]
Question Tokens Decoded: ['Count', ' ', 'the', ' ', 'letter', ' ', 'a', ' ', 'in', ':', ' ', 'g', 'a', 'y', 's', 'h', 'v', 'p', 'y']
Question length: 19
Answer: 1


In [5]:
# Test: learn BPE merges from a handful of questions, then tokenize a fresh one
qas = [generate_counting_question() for _ in range(10)]
tokenizer = CountingTokenizer()
tokenizer.apply_bpe([qa[0] for qa in qas])
# Built outside the f-string: reusing double quotes inside a double-quoted
# f-string only parses on Python 3.12+ and is a SyntaxError on older versions.
sample_text = "".join([f"<BOS>{sample_question}<EOS>" for sample_question, _ in qas])
qa = generate_counting_question()
example = generate_counting_example(qa, tokenizer=tokenizer)
print(f"Sample Pre-Processing Text: {sample_text}")
print(f"Text: {example['text']}")
print(f"Tokens: {example['tokens']}")
print(f"Question Tokens Decoded: {example['question_tokens_decoded']}")
print(f"Question length: {example['question_length']}")
print(f"Answer: {example['answer']}")

Sample Pre-Processing Text: <BOS>Count the letter a in: guaauc<EOS><BOS>Count the letter a in: blazkzj<EOS><BOS>Count the letter a in: mazvja<EOS><BOS>Count the letter a in: earsqhv<EOS><BOS>Count the letter a in: agagezw<EOS><BOS>Count the letter a in: nxeqagot<EOS><BOS>Count the letter a in: tjuafaz<EOS><BOS>Count the letter a in: staswz<EOS><BOS>Count the letter a in: haeia<EOS><BOS>Count the letter a in: bnebidbka<EOS>
Text: Count the letter a in: atqgtwgkd1
Tokens: [5, 4, 6, 4, 7, 4, 9, 4, 8, 3, 4, 9, 28, 25, 15, 28, 31, 15, 19, 12, 36]
Question Tokens Decoded: ['Count', ' ', 'the', ' ', 'letter', ' ', 'a', ' ', 'in', ':', ' ', 'a', 't', 'q', 'g', 't', 'w', 'g', 'k', 'd']
Question length: 20
Answer: 1


In [6]:
import pickle
from torch.utils.data import Dataset, DataLoader


class CountingDataset(Dataset):
    """Randomly generated letter-counting examples for one difficulty preset.

    Each item is the dict produced by `generate_counting_example`.
    """

    # difficulty -> (multiplicity_range, length_range), both inclusive
    DIFFICULTY_SETTINGS = {
        'easy': ((1, 2), (5, 10)),
        'bpe-hard': ((1, 2), (5, 10)),
        'mult-hard': ((3, 10), (5, 10)),
        'length-hard': ((1, 2), (20, 50)),
        'all-hard': ((3, 10), (20, 50)),
        'mixed': ((1, 10), (5, 50)),
    }

    def __init__(self, n_examples=50000, difficulty='easy', tokenizer=None, allow_bpe=True, train_bpe=True):
        """
        Args:
            n_examples: number of examples to generate.
            difficulty: 'easy', 'bpe-hard', 'mult-hard', 'length-hard', 'all-hard' or 'mixed'.
            tokenizer: tokenizer used to encode the examples; a fresh base
                CountingTokenizer is created when omitted.
            allow_bpe, train_bpe: both must be True (and BPE must be enabled for
                the difficulty) for merges to be learned from the generated questions.

        Raises:
            ValueError: for an unknown difficulty.
        """
        # Validate with a real exception; `assert False` disappears under `python -O`.
        if difficulty not in self.DIFFICULTY_SETTINGS:
            raise ValueError(
                f"Unknown difficulty {difficulty!r}; expected one of {sorted(self.DIFFICULTY_SETTINGS)}"
            )
        mult_range, len_range = self.DIFFICULTY_SETTINGS[difficulty]

        if tokenizer is None:
            tokenizer = CountingTokenizer()
        self.tokenizer = tokenizer
        self.examples = []

        # BPE merge training is currently switched off for every difficulty
        # (it was previously toggled for 'bpe-hard', 'all-hard' and 'mixed').
        use_bpe = False

        # Generate basic words
        target_letters = list("abcdefghijklmnopqrstuvwxyz")
        qas = []
        for _ in range(n_examples):
            target = random.choice(target_letters)
            qa = generate_counting_question(
                target_letter=target,
                multiplicity_range=mult_range,
                length_range=len_range,
            )
            qas.append(qa)

        # Add basic words to the vocabulary if we are applying BPE
        if use_bpe and allow_bpe and train_bpe:
            tokenizer.apply_bpe([qa[0] for qa in qas])

        # Generate the full questions and examples
        basic_tokenizer = CountingTokenizer()
        if difficulty == "mixed":
            # Half of the examples use the supplied tokenizer, the rest the base one.
            # A set keeps the membership test below O(1) instead of scanning a list.
            bpe_set = set(random.sample(range(len(qas)), len(qas) // 2))
        else:
            bpe_set = range(len(qas))
        for i, qa in enumerate(qas):
            example = generate_counting_example(
                qa,
                tokenizer=tokenizer if i in bpe_set else basic_tokenizer
            )
            self.examples.append(example)

    def __len__(self):
        """Number of generated examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the example dict at position `idx`."""
        return self.examples[idx]

def collate_fn(batch, pad_id=0, max_len=100):
    """Pad sequences to same length and build the answer loss mask.

    Args:
        batch: list of example dicts with 'tokens', 'question_length' and 'answer'.
        pad_id: token id used for right-padding.
        max_len: sequences longer than this are truncated.

    Returns:
        dict with 'input_ids' (long, [batch, seq]), 'loss_mask' (float, [batch, seq],
        1 at every answer position and 0 elsewhere) and 'answers' (long, [batch]).
    """
    # Pad tokens
    tokens = [ex['tokens'] for ex in batch]
    max_batch_len = min(max(len(t) for t in tokens), max_len)

    padded_tokens = []
    masks = []  # Loss mask: 1 for answer token(s), 0 elsewhere

    for ex in batch:
        seq = ex['tokens'][:max_batch_len]

        # Pad sequence
        padded_tokens.append(seq + [pad_id] * (max_batch_len - len(seq)))

        # Create mask: only compute loss on answer tokens, which start right
        # after the question. If truncation removed the answer the range is
        # empty, so no question token is ever marked as an answer.
        mask = [0] * max_batch_len
        for pos in range(ex['question_length'], len(seq)):
            mask[pos] = 1
        masks.append(mask)

    return {
        'input_ids': torch.tensor(padded_tokens, dtype=torch.long),
        'loss_mask': torch.tensor(masks, dtype=torch.float),
        'answers': torch.tensor([ex['answer'] for ex in batch], dtype=torch.long)
    }


In [7]:
# Build (or reload from disk) the training datasets, tokenizers and dataloaders
train_dataset_names = ["easy", "bpe-hard", "mult-hard", "length-hard", "all-hard"]
GENERATE_NEW_TRAINING_DATA = False # IMPORTANT: Set this to True only if you want to generate new datasets

# One entry per difficulty name
train_datasets, train_loaders, train_tokenizers = {}, {}, {}
for name in train_dataset_names:
    if not GENERATE_NEW_TRAINING_DATA:
        # Reuse the previously pickled dataset/tokenizer pair
        with open(f"train-{name}-dataset.pkl", "rb") as f:
            train_datasets[name] = pickle.load(f)
        with open(f"train-{name}-tokenizer.pkl", "rb") as f:
            train_tokenizers[name] = pickle.load(f)
        print(train_tokenizers[name].vocab.size)
    else:
        # Generate fresh data and persist both artifacts for later runs
        train_tokenizers[name] = CountingTokenizer()
        train_datasets[name] = CountingDataset(
            n_examples=100000,
            difficulty=name,
            tokenizer=train_tokenizers[name],
        )
        with open(f"train-{name}-dataset.pkl", "wb") as f:
            pickle.dump(train_datasets[name], f)
        with open(f"train-{name}-tokenizer.pkl", "wb") as f:
            pickle.dump(train_tokenizers[name], f)
    train_loaders[name] = DataLoader(
        train_datasets[name],
        batch_size=64,
        shuffle=True,
        collate_fn=collate_fn,
    )

# Inspect one collated batch as a sanity check
batch = next(iter(train_loaders["bpe-hard"]))
print("\n".join([
    f"Input shape: {batch['input_ids'].shape}",
    f"Mask shape: {batch['loss_mask'].shape}",
    f"Example tokens: {batch['input_ids'][0]}",
    f"Example mask: {batch['loss_mask'][0]}",
]))

45
45
45
45
45
Input shape: torch.Size([64, 22])
Mask shape: torch.Size([64, 22])
Example tokens: tensor([ 5,  4,  6,  4,  7,  4, 28,  4,  8,  3,  4, 20, 23, 26, 32, 28, 23, 36,
         0,  0,  0,  0])
Example mask: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
        0., 0., 0., 0.])


In [8]:
GENERATE_NEW_TESTING_DATA = False # IMPORTANT: Set this to True only if you want to generate new datasets

# Held-out test datasets, one per difficulty, encoded with the matching training tokenizer
test_dataset_names = ["easy", "bpe-hard", "mult-hard", "length-hard", "all-hard"]
test_datasets, test_tokenizers = {}, {}
for name in test_dataset_names:
    if not GENERATE_NEW_TESTING_DATA:
        # Reload the previously pickled test set
        with open(f"test-{name}-dataset.pkl", "rb") as f:
            test_datasets[name] = pickle.load(f)
        continue
    # train_bpe=False: the training tokenizer is reused as-is, no new merges are learned
    test_datasets[name] = CountingDataset(
        n_examples=10000,
        difficulty=name,
        tokenizer=train_tokenizers[name],
        allow_bpe=True,
        train_bpe=False,
    )
    with open(f"test-{name}-dataset.pkl", "wb") as f:
        pickle.dump(test_datasets[name], f)

In [9]:
import torch.optim as optim
from tqdm import tqdm

# When True, every model below is trained from scratch; when False, the saved
# 'checkpoint-{name}-epoch-40.pt' files are loaded instead.
RETRAIN = True # IMPORTANT: Only toggle to True if you want to actually train new models

def train_model(model, name, train_loader, n_epochs=10, lr=1e-3, device='cuda'):
    """
    Train with MASKED loss (only on answer token)

    Classification loss: CrossEntropy on vocabulary
    (Can also try regression loss on digit value)

    Args:
        model: module called as `model(input_ids)` that returns logits of shape
            [batch, seq_len, vocab].
        name: label used in the checkpoint file names.
        train_loader: iterable of batches with 'input_ids' and 'loss_mask' tensors.
        n_epochs: number of passes over `train_loader`.
        lr: AdamW learning rate (cosine-annealed over `n_epochs`).
        device: device to train on. If a CUDA device is requested but CUDA is
            unavailable, training falls back to CPU.

    Returns:
        The trained model (after `model.to(device)`).

    Side effects:
        Writes 'checkpoint-{name}-epoch-{k}.pt' every 5th epoch. The stored
        'epoch' field is the 0-based epoch index, while the file name is 1-based.
    """
    # The default is 'cuda'; do not crash on CPU-only machines.
    if device is not None and torch.device(device).type == 'cuda' and not torch.cuda.is_available():
        print("CUDA is not available; training on CPU instead")
        device = 'cpu'

    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

    # Classification loss (predict token ID)
    criterion = nn.CrossEntropyLoss(reduction='none')  # Per-token loss

    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        n_batches = 0  # Batches that contributed a loss
        correct = 0
        total = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}")
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)  # [batch, seq_len]
            loss_mask = batch['loss_mask'].to(device)  # [batch, seq_len]

            # Align the mask with next-token targets (shifted by 1)
            mask = loss_mask[:, 1:]
            answer_positions = mask.bool()
            if not answer_positions.any():
                # Nothing to supervise in this batch. Dividing by mask.sum() == 0
                # would produce a NaN loss and poison the weights, so skip it.
                continue

            # Forward pass
            # Input: all tokens except last
            # Target: all tokens except first (shifted by 1)
            logits = model(input_ids[:, :-1])  # [batch, seq_len-1, vocab]
            targets = input_ids[:, 1:]  # [batch, seq_len-1]

            # Compute loss
            loss_per_token = criterion(
                logits.reshape(-1, logits.size(-1)),  # [batch*(seq_len-1), vocab]
                targets.reshape(-1)  # [batch*(seq_len-1)]
            )
            loss_per_token = loss_per_token.reshape(targets.shape)  # [batch, seq_len-1]

            # Apply mask: only compute loss on answer token
            masked_loss = (loss_per_token * mask).sum() / mask.sum()

            # Backward pass
            optimizer.zero_grad()
            masked_loss.backward()
            optimizer.step()

            # Metrics
            batch_loss = masked_loss.item()
            total_loss += batch_loss
            n_batches += 1

            # Accuracy: check if predicted answer digit is correct
            preds = logits.argmax(dim=-1)  # [batch, seq_len-1]
            correct += (preds[answer_positions] == targets[answer_positions]).sum().item()
            total += answer_positions.sum().item()

            pbar.set_postfix({'loss': batch_loss,
                            'acc': correct/total if total > 0 else 0})

        scheduler.step()

        # Guard the epoch summary against an empty loader or unsupervised batches
        mean_loss = total_loss / n_batches if n_batches > 0 else float('nan')
        accuracy = correct / total if total > 0 else float('nan')
        print(f"Epoch {epoch+1}: Loss={mean_loss:.4f}, "
              f"Acc={accuracy:.4f}")

        # Save checkpoint
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'checkpoint-{name}-epoch-{epoch+1}.pt')

    return model

# Train (or reload) one model per training dataset
models = {}
# Resolve the device once so the config and the training loop agree on it;
# train_model otherwise defaults to 'cuda' even when the config fell back to CPU.
train_device = "cuda" if torch.cuda.is_available() else "cpu"
for name in train_dataset_names:
    config = HookedTransformerConfig(
            n_layers=2,
            n_heads=8,
            d_model=256,
            d_head=16,  # NOTE: n_heads * d_head = 128, i.e. half of d_model (not d_model / n_heads)
            d_mlp=None,  # No MLPs (attention-only)
            act_fn=None,  # No activation (no MLPs)
            attention_dir="causal",  # Causal attention
            attn_only=True,  # Attention-only model
            normalization_type=None,  # No LayerNorm for simplicity
            d_vocab=train_tokenizers[name].vocab.size + 5,  # Tokenizer vocab size plus 5 spare ids
            n_ctx=100,  # Max sequence length
            init_weights=True,
            device=train_device
        )
    model = HookedTransformer(config)
    if RETRAIN:
        print(f"Starting to train model for {name} with vocab size {train_tokenizers[name].vocab.size}")
        models[name] = train_model(model, name, train_loaders[name], n_epochs=60, lr=3e-4, device=train_device)
    else:
        # train_model writes a checkpoint every 5 epochs; the epoch-40 one is reused here.
        checkpoint = torch.load(f'checkpoint-{name}-epoch-40.pt', map_location="cpu")

        model.load_state_dict(checkpoint["model_state_dict"])
        models[name] = model
        

Starting to train model for easy with vocab size 45
Moving model to device:  cuda


Epoch 1/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 302.99it/s, loss=0.689, acc=0.499]


Epoch 1: Loss=0.7359, Acc=0.4989


Epoch 2/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 318.26it/s, loss=0.687, acc=0.497]


Epoch 2: Loss=0.6964, Acc=0.4972


Epoch 3/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 324.69it/s, loss=0.431, acc=0.565]


Epoch 3: Loss=0.6605, Acc=0.5648


Epoch 4/60: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 326.28it/s, loss=0.000618, acc=0.983]


Epoch 4: Loss=0.0488, Acc=0.9829


Epoch 5/60: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 323.48it/s, loss=0.000187, acc=0.999]


Epoch 5: Loss=0.0050, Acc=0.9986


Epoch 6/60: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 329.48it/s, loss=0.000161, acc=0.999]


Epoch 6: Loss=0.0051, Acc=0.9986


Epoch 7/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 323.87it/s, loss=1.32e-5, acc=0.999]


Epoch 7: Loss=0.0041, Acc=0.9988


Epoch 8/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 325.78it/s, loss=1.41e-5, acc=0.999]


Epoch 8: Loss=0.0044, Acc=0.9985


Epoch 9/60: 100%|█████████████████████████████████| 1563/1563 [00:04<00:00, 326.86it/s, loss=1.52e-5, acc=1]


Epoch 9: Loss=0.0014, Acc=0.9997


Epoch 10/60: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 323.09it/s, loss=5.31e-5, acc=0.999]


Epoch 10: Loss=0.0035, Acc=0.9990


Epoch 11/60: 100%|███████████████████████████| 1563/1563 [00:04<00:00, 325.87it/s, loss=0.000191, acc=0.999]


Epoch 11: Loss=0.0022, Acc=0.9994


Epoch 12/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 332.42it/s, loss=1.32e-5, acc=1]


Epoch 12: Loss=0.0013, Acc=0.9996


Epoch 13/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 326.63it/s, loss=2.97e-6, acc=1]


Epoch 13: Loss=0.0010, Acc=0.9996


Epoch 14/60: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 325.82it/s, loss=2.92e-5, acc=0.999]


Epoch 14: Loss=0.0023, Acc=0.9993


Epoch 15/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 328.37it/s, loss=2.38e-6, acc=1]


Epoch 15: Loss=0.0000, Acc=1.0000


Epoch 16/60: 100%|█████████████████████████████████| 1563/1563 [00:04<00:00, 323.88it/s, loss=1.8e-6, acc=1]


Epoch 16: Loss=0.0000, Acc=1.0000


Epoch 17/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 323.07it/s, loss=7.67e-7, acc=1]


Epoch 17: Loss=0.0000, Acc=1.0000


Epoch 18/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 321.19it/s, loss=2.31e-7, acc=1]


Epoch 18: Loss=0.0000, Acc=1.0000


Epoch 19/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 325.86it/s, loss=2.27e-7, acc=1]


Epoch 19: Loss=0.0000, Acc=1.0000


Epoch 20/60: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 326.41it/s, loss=6.09e-5, acc=0.998]


Epoch 20: Loss=0.0054, Acc=0.9985


Epoch 21/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 321.45it/s, loss=2.12e-5, acc=1]


Epoch 21: Loss=0.0002, Acc=1.0000


Epoch 22/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 326.22it/s, loss=4.22e-6, acc=1]


Epoch 22: Loss=0.0000, Acc=1.0000


Epoch 23/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 324.42it/s, loss=4.53e-6, acc=1]


Epoch 23: Loss=0.0000, Acc=1.0000


Epoch 24/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 326.64it/s, loss=2.97e-6, acc=1]


Epoch 24: Loss=0.0000, Acc=1.0000


Epoch 25/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 321.95it/s, loss=3.32e-7, acc=1]


Epoch 25: Loss=0.0000, Acc=1.0000


Epoch 26/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 327.98it/s, loss=2.35e-7, acc=1]


Epoch 26: Loss=0.0000, Acc=1.0000


Epoch 27/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 328.58it/s, loss=8.94e-8, acc=1]


Epoch 27: Loss=0.0000, Acc=1.0000


Epoch 28/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 328.00it/s, loss=1.83e-7, acc=1]


Epoch 28: Loss=0.0000, Acc=1.0000


Epoch 29/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 330.13it/s, loss=7.45e-9, acc=1]


Epoch 29: Loss=0.0000, Acc=1.0000


Epoch 30/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 323.31it/s, loss=3.73e-9, acc=1]


Epoch 30: Loss=0.0000, Acc=1.0000


Epoch 31/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 328.79it/s, loss=7.45e-9, acc=1]


Epoch 31: Loss=0.0000, Acc=1.0000


Epoch 32/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 327.54it/s, loss=0, acc=1]


Epoch 32: Loss=0.0000, Acc=1.0000


Epoch 33/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 325.30it/s, loss=0, acc=1]


Epoch 33: Loss=0.0000, Acc=1.0000


Epoch 34/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 327.69it/s, loss=0, acc=1]


Epoch 34: Loss=0.0000, Acc=1.0000


Epoch 35/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 322.20it/s, loss=0, acc=1]


Epoch 35: Loss=0.0000, Acc=1.0000


Epoch 36/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 323.42it/s, loss=0, acc=1]


Epoch 36: Loss=0.0000, Acc=1.0000


Epoch 37/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 327.11it/s, loss=0, acc=1]


Epoch 37: Loss=0.0000, Acc=1.0000


Epoch 38/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 350.19it/s, loss=0, acc=1]


Epoch 38: Loss=0.0000, Acc=1.0000


Epoch 39/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 349.68it/s, loss=0, acc=1]


Epoch 39: Loss=0.0000, Acc=1.0000


Epoch 40/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 348.16it/s, loss=0, acc=1]


Epoch 40: Loss=0.0000, Acc=1.0000


Epoch 41/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 348.84it/s, loss=0, acc=1]


Epoch 41: Loss=0.0000, Acc=1.0000


Epoch 42/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 347.44it/s, loss=0, acc=1]


Epoch 42: Loss=0.0000, Acc=1.0000


Epoch 43/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 347.26it/s, loss=0, acc=1]


Epoch 43: Loss=0.0000, Acc=1.0000


Epoch 44/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 344.98it/s, loss=0, acc=1]


Epoch 44: Loss=0.0000, Acc=1.0000


Epoch 45/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 352.74it/s, loss=0, acc=1]


Epoch 45: Loss=0.0000, Acc=1.0000


Epoch 46/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 351.94it/s, loss=0, acc=1]


Epoch 46: Loss=0.0000, Acc=1.0000


Epoch 47/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 346.27it/s, loss=0, acc=1]


Epoch 47: Loss=0.0000, Acc=1.0000


Epoch 48/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 351.99it/s, loss=0, acc=1]


Epoch 48: Loss=0.0000, Acc=1.0000


Epoch 49/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 349.83it/s, loss=0, acc=1]


Epoch 49: Loss=0.0000, Acc=1.0000


Epoch 50/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 353.09it/s, loss=0, acc=1]


Epoch 50: Loss=0.0000, Acc=1.0000


Epoch 51/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 351.69it/s, loss=0, acc=1]


Epoch 51: Loss=0.0000, Acc=1.0000


Epoch 52/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 351.16it/s, loss=0, acc=1]


Epoch 52: Loss=0.0000, Acc=1.0000


Epoch 53/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 348.77it/s, loss=0, acc=1]


Epoch 53: Loss=0.0000, Acc=1.0000


Epoch 54/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 347.55it/s, loss=0, acc=1]


Epoch 54: Loss=0.0000, Acc=1.0000


Epoch 55/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 350.99it/s, loss=0, acc=1]


Epoch 55: Loss=0.0000, Acc=1.0000


Epoch 56/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 351.99it/s, loss=0, acc=1]


Epoch 56: Loss=0.0000, Acc=1.0000


Epoch 57/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 343.91it/s, loss=0, acc=1]


Epoch 57: Loss=0.0000, Acc=1.0000


Epoch 58/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 342.08it/s, loss=0, acc=1]


Epoch 58: Loss=0.0000, Acc=1.0000


Epoch 59/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 352.02it/s, loss=0, acc=1]


Epoch 59: Loss=0.0000, Acc=1.0000


Epoch 60/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 356.99it/s, loss=0, acc=1]


Epoch 60: Loss=0.0000, Acc=1.0000
Starting to train model for bpe-hard with vocab size 45
Moving model to device:  cuda


Epoch 1/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 316.97it/s, loss=0.705, acc=0.498]


Epoch 1: Loss=0.7372, Acc=0.4984


Epoch 2/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 320.43it/s, loss=0.695, acc=0.502]


Epoch 2: Loss=0.6960, Acc=0.5016


Epoch 3/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 321.16it/s, loss=0.517, acc=0.565]


Epoch 3: Loss=0.6724, Acc=0.5654


Epoch 4/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 326.68it/s, loss=0.00428, acc=0.955]


Epoch 4: Loss=0.1124, Acc=0.9551


Epoch 5/60: 100%|██████████████████████████████| 1563/1563 [00:04<00:00, 323.53it/s, loss=0.0376, acc=0.992]


Epoch 5: Loss=0.0233, Acc=0.9924


Epoch 6/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 322.69it/s, loss=0.00048, acc=0.998]


Epoch 6: Loss=0.0079, Acc=0.9977


Epoch 7/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 326.61it/s, loss=3.48e-5, acc=0.997]


Epoch 7: Loss=0.0088, Acc=0.9971


Epoch 8/60: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 320.44it/s, loss=0.000637, acc=0.997]


Epoch 8: Loss=0.0083, Acc=0.9975


Epoch 9/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 323.89it/s, loss=0.00224, acc=0.998]


Epoch 9: Loss=0.0063, Acc=0.9980


Epoch 10/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 324.87it/s, loss=8.54e-6, acc=1]


Epoch 10: Loss=0.0009, Acc=0.9998


Epoch 11/60: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 323.25it/s, loss=2.98e-5, acc=0.998]


Epoch 11: Loss=0.0083, Acc=0.9980


Epoch 12/60: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 319.51it/s, loss=1.92e-5, acc=0.999]


Epoch 12: Loss=0.0039, Acc=0.9989


Epoch 13/60: 100%|███████████████████████████| 1563/1563 [00:04<00:00, 326.58it/s, loss=0.000335, acc=0.997]


Epoch 13: Loss=0.0094, Acc=0.9971


Epoch 14/60: 100%|█████████████████████████████████| 1563/1563 [00:04<00:00, 323.09it/s, loss=3.1e-5, acc=1]


Epoch 14: Loss=0.0001, Acc=1.0000


Epoch 15/60: 100%|███████████████████████████| 1563/1563 [00:04<00:00, 324.61it/s, loss=0.000172, acc=0.998]


Epoch 15: Loss=0.0066, Acc=0.9981


Epoch 16/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 322.48it/s, loss=5.96e-6, acc=1]


Epoch 16: Loss=0.0002, Acc=1.0000


Epoch 17/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 324.48it/s, loss=8.13e-6, acc=1]


Epoch 17: Loss=0.0000, Acc=1.0000


Epoch 18/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 325.97it/s, loss=8.91e-6, acc=1]


Epoch 18: Loss=0.0000, Acc=1.0000


Epoch 19/60: 100%|█████████████████████████████████| 1563/1563 [00:04<00:00, 326.55it/s, loss=5.7e-7, acc=1]


Epoch 19: Loss=0.0000, Acc=1.0000


Epoch 20/60: 100%|███████████████████████████| 1563/1563 [00:04<00:00, 324.10it/s, loss=0.000145, acc=0.998]


Epoch 20: Loss=0.0059, Acc=0.9983


Epoch 21/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 331.80it/s, loss=9.53e-6, acc=1]


Epoch 21: Loss=0.0001, Acc=1.0000


Epoch 22/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 315.66it/s, loss=3.54e-6, acc=1]


Epoch 22: Loss=0.0000, Acc=1.0000


Epoch 23/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 318.71it/s, loss=2.17e-6, acc=1]


Epoch 23: Loss=0.0000, Acc=1.0000


Epoch 24/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 326.54it/s, loss=7.56e-7, acc=1]


Epoch 24: Loss=0.0000, Acc=1.0000


Epoch 25/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 321.92it/s, loss=2.91e-7, acc=1]


Epoch 25: Loss=0.0000, Acc=1.0000


Epoch 26/60: 100%|█████████████████████████████████| 1563/1563 [00:04<00:00, 322.72it/s, loss=6.3e-7, acc=1]


Epoch 26: Loss=0.0000, Acc=1.0000


Epoch 27/60: 100%|███████████████████████████| 1563/1563 [00:04<00:00, 334.70it/s, loss=0.000467, acc=0.998]


Epoch 27: Loss=0.0107, Acc=0.9978


Epoch 28/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 330.67it/s, loss=2.54e-5, acc=1]


Epoch 28: Loss=0.0001, Acc=1.0000


Epoch 29/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 331.15it/s, loss=1.31e-5, acc=1]


Epoch 29: Loss=0.0000, Acc=1.0000


Epoch 30/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 337.31it/s, loss=4.58e-6, acc=1]


Epoch 30: Loss=0.0000, Acc=1.0000


Epoch 31/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 334.59it/s, loss=1.76e-5, acc=1]


Epoch 31: Loss=0.0000, Acc=1.0000


Epoch 32/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 332.19it/s, loss=2.32e-6, acc=1]


Epoch 32: Loss=0.0000, Acc=1.0000


Epoch 33/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 329.97it/s, loss=1.17e-6, acc=1]


Epoch 33: Loss=0.0000, Acc=1.0000


Epoch 34/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 329.95it/s, loss=2.64e-7, acc=1]


Epoch 34: Loss=0.0000, Acc=1.0000


Epoch 35/60: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 326.78it/s, loss=8.42e-6, acc=0.999]


Epoch 35: Loss=0.0030, Acc=0.9992


Epoch 36/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 331.90it/s, loss=3.04e-5, acc=1]


Epoch 36: Loss=0.0000, Acc=1.0000


Epoch 37/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 325.36it/s, loss=1.51e-6, acc=1]


Epoch 37: Loss=0.0000, Acc=1.0000


Epoch 38/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 329.49it/s, loss=3.29e-6, acc=1]


Epoch 38: Loss=0.0000, Acc=1.0000


Epoch 39/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 331.44it/s, loss=1.55e-6, acc=1]


Epoch 39: Loss=0.0000, Acc=1.0000


Epoch 40/60: 100%|█████████████████████████████████| 1563/1563 [00:04<00:00, 331.02it/s, loss=3.1e-6, acc=1]


Epoch 40: Loss=0.0000, Acc=1.0000


Epoch 41/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 329.03it/s, loss=1.49e-7, acc=1]


Epoch 41: Loss=0.0000, Acc=1.0000


Epoch 42/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 326.29it/s, loss=1.53e-7, acc=1]


Epoch 42: Loss=0.0000, Acc=1.0000


Epoch 43/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 336.85it/s, loss=1.42e-7, acc=1]


Epoch 43: Loss=0.0000, Acc=1.0000


Epoch 44/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 335.94it/s, loss=3.35e-8, acc=1]


Epoch 44: Loss=0.0000, Acc=1.0000


Epoch 45/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 333.81it/s, loss=0, acc=1]


Epoch 45: Loss=0.0000, Acc=1.0000


Epoch 46/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 330.32it/s, loss=3.73e-9, acc=1]


Epoch 46: Loss=0.0000, Acc=1.0000


Epoch 47/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 332.48it/s, loss=0, acc=1]


Epoch 47: Loss=0.0000, Acc=1.0000


Epoch 48/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 333.85it/s, loss=7.45e-9, acc=1]


Epoch 48: Loss=0.0000, Acc=1.0000


Epoch 49/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 337.26it/s, loss=0, acc=1]


Epoch 49: Loss=0.0000, Acc=1.0000


Epoch 50/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 330.72it/s, loss=0, acc=1]


Epoch 50: Loss=0.0000, Acc=1.0000


Epoch 51/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 328.28it/s, loss=0, acc=1]


Epoch 51: Loss=0.0000, Acc=1.0000


Epoch 52/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 330.14it/s, loss=0, acc=1]


Epoch 52: Loss=0.0000, Acc=1.0000


Epoch 53/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 326.18it/s, loss=0, acc=1]


Epoch 53: Loss=0.0000, Acc=1.0000


Epoch 54/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 324.02it/s, loss=0, acc=1]


Epoch 54: Loss=0.0000, Acc=1.0000


Epoch 55/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 327.47it/s, loss=0, acc=1]


Epoch 55: Loss=0.0000, Acc=1.0000


Epoch 56/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 323.19it/s, loss=0, acc=1]


Epoch 56: Loss=0.0000, Acc=1.0000


Epoch 57/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 347.27it/s, loss=0, acc=1]


Epoch 57: Loss=0.0000, Acc=1.0000


Epoch 58/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 346.24it/s, loss=0, acc=1]


Epoch 58: Loss=0.0000, Acc=1.0000


Epoch 59/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 344.13it/s, loss=0, acc=1]


Epoch 59: Loss=0.0000, Acc=1.0000


Epoch 60/60: 100%|██████████████████████████████████████| 1563/1563 [00:04<00:00, 343.09it/s, loss=0, acc=1]


Epoch 60: Loss=0.0000, Acc=1.0000
Starting to train model for mult-hard with vocab size 45
Moving model to device:  cuda


Epoch 1/60: 100%|██████████████████████████████████| 1563/1563 [00:04<00:00, 319.93it/s, loss=1.37, acc=0.3]


Epoch 1: Loss=1.7431, Acc=0.2996


Epoch 2/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 322.72it/s, loss=0.727, acc=0.654]


Epoch 2: Loss=0.7755, Acc=0.6540


Epoch 3/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 328.84it/s, loss=0.432, acc=0.743]


Epoch 3: Loss=0.5852, Acc=0.7427


Epoch 4/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 332.92it/s, loss=0.662, acc=0.778]


Epoch 4: Loss=0.5123, Acc=0.7778


Epoch 5/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 324.37it/s, loss=0.259, acc=0.796]


Epoch 5: Loss=0.4728, Acc=0.7960


Epoch 6/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 326.18it/s, loss=0.686, acc=0.814]


Epoch 6: Loss=0.4324, Acc=0.8141


Epoch 7/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 326.93it/s, loss=0.281, acc=0.833]


Epoch 7: Loss=0.3896, Acc=0.8329


Epoch 8/60: 100%|█████████████████████████████████| 1563/1563 [00:04<00:00, 336.83it/s, loss=0.23, acc=0.86]


Epoch 8: Loss=0.3347, Acc=0.8600


Epoch 9/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 331.41it/s, loss=0.121, acc=0.885]


Epoch 9: Loss=0.2809, Acc=0.8851


Epoch 10/60: 100%|██████████████████████████████| 1563/1563 [00:04<00:00, 326.82it/s, loss=0.233, acc=0.902]


Epoch 10: Loss=0.2442, Acc=0.9018


Epoch 11/60: 100%|██████████████████████████████| 1563/1563 [00:04<00:00, 331.15it/s, loss=0.247, acc=0.915]


Epoch 11: Loss=0.2136, Acc=0.9150


Epoch 12/60: 100%|██████████████████████████████| 1563/1563 [00:04<00:00, 335.39it/s, loss=0.289, acc=0.926]


Epoch 12: Loss=0.1891, Acc=0.9262


Epoch 13/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 326.00it/s, loss=0.22, acc=0.935]


Epoch 13: Loss=0.1679, Acc=0.9355


Epoch 14/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 335.70it/s, loss=0.0961, acc=0.944]


Epoch 14: Loss=0.1482, Acc=0.9438


Epoch 15/60: 100%|██████████████████████████████| 1563/1563 [00:04<00:00, 328.57it/s, loss=0.0897, acc=0.95]


Epoch 15: Loss=0.1340, Acc=0.9503


Epoch 16/60: 100%|██████████████████████████████| 1563/1563 [00:04<00:00, 329.55it/s, loss=0.106, acc=0.957]


Epoch 16: Loss=0.1175, Acc=0.9570


Epoch 17/60: 100%|██████████████████████████████| 1563/1563 [00:04<00:00, 329.51it/s, loss=0.126, acc=0.962]


Epoch 17: Loss=0.1042, Acc=0.9622


Epoch 18/60: 100%|██████████████████████████████| 1563/1563 [00:04<00:00, 330.75it/s, loss=0.121, acc=0.967]


Epoch 18: Loss=0.0912, Acc=0.9674


Epoch 19/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 326.74it/s, loss=0.0222, acc=0.972]


Epoch 19: Loss=0.0792, Acc=0.9718


Epoch 20/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 333.55it/s, loss=0.0809, acc=0.975]


Epoch 20: Loss=0.0694, Acc=0.9755


Epoch 21/60: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 340.91it/s, loss=0.00648, acc=0.979]


Epoch 21: Loss=0.0613, Acc=0.9787


Epoch 22/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 334.10it/s, loss=0.0079, acc=0.981]


Epoch 22: Loss=0.0527, Acc=0.9809


Epoch 23/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 329.88it/s, loss=0.0178, acc=0.983]


Epoch 23: Loss=0.0474, Acc=0.9835


Epoch 24/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 328.49it/s, loss=0.0259, acc=0.986]


Epoch 24: Loss=0.0397, Acc=0.9863


Epoch 25/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 328.16it/s, loss=0.0018, acc=0.987]


Epoch 25: Loss=0.0365, Acc=0.9872


Epoch 26/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 331.49it/s, loss=0.0551, acc=0.988]


Epoch 26: Loss=0.0319, Acc=0.9885


Epoch 27/60: 100%|██████████████████████████████| 1563/1563 [00:04<00:00, 329.76it/s, loss=0.0966, acc=0.99]


Epoch 27: Loss=0.0293, Acc=0.9900


Epoch 28/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 329.57it/s, loss=0.0624, acc=0.991]


Epoch 28: Loss=0.0252, Acc=0.9914


Epoch 29/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 329.21it/s, loss=0.0104, acc=0.992]


Epoch 29: Loss=0.0228, Acc=0.9922


Epoch 30/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 330.61it/s, loss=0.0428, acc=0.994]


Epoch 30: Loss=0.0177, Acc=0.9939


Epoch 31/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 328.59it/s, loss=0.0791, acc=0.994]


Epoch 31: Loss=0.0170, Acc=0.9938


Epoch 32/60: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 330.75it/s, loss=0.00105, acc=0.994]


Epoch 32: Loss=0.0152, Acc=0.9944


Epoch 33/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 332.96it/s, loss=0.0197, acc=0.996]


Epoch 33: Loss=0.0131, Acc=0.9958


Epoch 34/60: 100%|█████████████████████████████| 1563/1563 [00:04<00:00, 328.26it/s, loss=0.0188, acc=0.996]


Epoch 34: Loss=0.0123, Acc=0.9959


Epoch 35/60: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 329.00it/s, loss=0.00349, acc=0.996]


Epoch 35: Loss=0.0109, Acc=0.9963


Epoch 36/60: 100%|███████████████████████████| 1563/1563 [00:04<00:00, 330.72it/s, loss=0.000618, acc=0.998]


Epoch 36: Loss=0.0069, Acc=0.9978


Epoch 37/60: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 327.78it/s, loss=0.00608, acc=0.997]


Epoch 37: Loss=0.0080, Acc=0.9972


Epoch 38/60: 100%|███████████████████████████| 1563/1563 [00:04<00:00, 333.05it/s, loss=0.000208, acc=0.999]


Epoch 38: Loss=0.0049, Acc=0.9985


Epoch 39/60: 100%|███████████████████████████| 1563/1563 [00:04<00:00, 335.85it/s, loss=0.000157, acc=0.998]


Epoch 39: Loss=0.0058, Acc=0.9981


Epoch 40/60: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 328.52it/s, loss=0.00257, acc=0.999]


Epoch 40: Loss=0.0038, Acc=0.9988


Epoch 41/60: 100%|████████████████████████████| 1563/1563 [00:04<00:00, 332.49it/s, loss=0.00917, acc=0.999]


Epoch 41: Loss=0.0031, Acc=0.9992


Epoch 42/60: 100%|███████████████████████████| 1563/1563 [00:04<00:00, 334.67it/s, loss=0.000137, acc=0.999]


Epoch 42: Loss=0.0031, Acc=0.9990


Epoch 43/60: 100%|███████████████████████████| 1563/1563 [00:04<00:00, 325.99it/s, loss=0.000254, acc=0.999]


Epoch 43: Loss=0.0033, Acc=0.9990


Epoch 44/60: 100%|█████████████████████████████████| 1563/1563 [00:04<00:00, 333.83it/s, loss=1.5e-5, acc=1]


Epoch 44: Loss=0.0013, Acc=0.9997


Epoch 45/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 324.83it/s, loss=0.000274, acc=1]


Epoch 45: Loss=0.0017, Acc=0.9996


Epoch 46/60: 100%|█████████████████████████████████| 1563/1563 [00:04<00:00, 330.19it/s, loss=0.0196, acc=1]


Epoch 46: Loss=0.0008, Acc=0.9999


Epoch 47/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 328.38it/s, loss=5.38e-5, acc=1]


Epoch 47: Loss=0.0008, Acc=0.9998


Epoch 48/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 330.42it/s, loss=0.000542, acc=1]


Epoch 48: Loss=0.0001, Acc=1.0000


Epoch 49/60: 100%|█████████████████████████████████| 1563/1563 [00:04<00:00, 328.55it/s, loss=2.4e-5, acc=1]


Epoch 49: Loss=0.0009, Acc=0.9998


Epoch 50/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 336.27it/s, loss=4.16e-5, acc=1]


Epoch 50: Loss=0.0003, Acc=1.0000


Epoch 51/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 328.42it/s, loss=2.65e-5, acc=1]


Epoch 51: Loss=0.0001, Acc=1.0000


Epoch 52/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 332.73it/s, loss=1.33e-5, acc=1]


Epoch 52: Loss=0.0001, Acc=1.0000


Epoch 53/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 330.22it/s, loss=3.22e-5, acc=1]


Epoch 53: Loss=0.0001, Acc=1.0000


Epoch 54/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 332.67it/s, loss=7.13e-6, acc=1]


Epoch 54: Loss=0.0000, Acc=1.0000


Epoch 55/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 330.59it/s, loss=1.17e-5, acc=1]


Epoch 55: Loss=0.0000, Acc=1.0000


Epoch 56/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 329.80it/s, loss=0.000114, acc=1]


Epoch 56: Loss=0.0000, Acc=1.0000


Epoch 57/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 331.32it/s, loss=4.32e-7, acc=1]


Epoch 57: Loss=0.0000, Acc=1.0000


Epoch 58/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 325.15it/s, loss=6.12e-6, acc=1]


Epoch 58: Loss=0.0000, Acc=1.0000


Epoch 59/60: 100%|████████████████████████████████| 1563/1563 [00:04<00:00, 329.81it/s, loss=3.75e-6, acc=1]


Epoch 59: Loss=0.0000, Acc=1.0000


Epoch 60/60: 100%|███████████████████████████████| 1563/1563 [00:04<00:00, 336.02it/s, loss=0.000131, acc=1]


Epoch 60: Loss=0.0000, Acc=1.0000
Starting to train model for length-hard with vocab size 45
Moving model to device:  cuda


Epoch 1/60: 100%|█████████████████████████████████| 1563/1563 [00:05<00:00, 294.74it/s, loss=0.702, acc=0.5]


Epoch 1: Loss=0.7403, Acc=0.5004


Epoch 2/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 299.06it/s, loss=0.688, acc=0.501]


Epoch 2: Loss=0.6966, Acc=0.5013


Epoch 3/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 292.45it/s, loss=0.697, acc=0.501]


Epoch 3: Loss=0.6950, Acc=0.5013


Epoch 4/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 300.18it/s, loss=0.693, acc=0.502]


Epoch 4: Loss=0.6945, Acc=0.5021


Epoch 5/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 297.33it/s, loss=0.696, acc=0.505]


Epoch 5: Loss=0.6940, Acc=0.5048


Epoch 6/60: 100%|█████████████████████████████████| 1563/1563 [00:05<00:00, 295.35it/s, loss=0.693, acc=0.5]


Epoch 6: Loss=0.6939, Acc=0.5002


Epoch 7/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 295.34it/s, loss=0.693, acc=0.502]


Epoch 7: Loss=0.6936, Acc=0.5022


Epoch 8/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 291.21it/s, loss=0.691, acc=0.504]


Epoch 8: Loss=0.6934, Acc=0.5037


Epoch 9/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 293.82it/s, loss=0.697, acc=0.507]


Epoch 9: Loss=0.6932, Acc=0.5067


Epoch 10/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 301.27it/s, loss=0.691, acc=0.508]


Epoch 10: Loss=0.6931, Acc=0.5075


Epoch 11/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 294.26it/s, loss=0.693, acc=0.513]


Epoch 11: Loss=0.6926, Acc=0.5130


Epoch 12/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 299.61it/s, loss=0.664, acc=0.527]


Epoch 12: Loss=0.6774, Acc=0.5270


Epoch 13/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 295.06it/s, loss=0.63, acc=0.534]


Epoch 13: Loss=0.6661, Acc=0.5336


Epoch 14/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 295.74it/s, loss=0.578, acc=0.537]


Epoch 14: Loss=0.6657, Acc=0.5374


Epoch 15/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 306.40it/s, loss=0.681, acc=0.539]


Epoch 15: Loss=0.6652, Acc=0.5388


Epoch 16/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 306.66it/s, loss=0.695, acc=0.541]


Epoch 16: Loss=0.6653, Acc=0.5407


Epoch 17/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 311.98it/s, loss=0.585, acc=0.545]


Epoch 17: Loss=0.6634, Acc=0.5452


Epoch 18/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 298.87it/s, loss=0.641, acc=0.556]


Epoch 18: Loss=0.6616, Acc=0.5563


Epoch 19/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 309.15it/s, loss=0.65, acc=0.561]


Epoch 19: Loss=0.6609, Acc=0.5610


Epoch 20/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 300.89it/s, loss=0.647, acc=0.564]


Epoch 20: Loss=0.6593, Acc=0.5640


Epoch 21/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 305.71it/s, loss=0.632, acc=0.567]


Epoch 21: Loss=0.6577, Acc=0.5665


Epoch 22/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 298.16it/s, loss=0.641, acc=0.57]


Epoch 22: Loss=0.6569, Acc=0.5704


Epoch 23/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 305.62it/s, loss=0.641, acc=0.573]


Epoch 23: Loss=0.6552, Acc=0.5735


Epoch 24/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 297.27it/s, loss=0.625, acc=0.574]


Epoch 24: Loss=0.6541, Acc=0.5737


Epoch 25/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 298.33it/s, loss=0.666, acc=0.58]


Epoch 25: Loss=0.6522, Acc=0.5798


Epoch 26/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 296.78it/s, loss=0.645, acc=0.585]


Epoch 26: Loss=0.6447, Acc=0.5850


Epoch 27/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 296.59it/s, loss=0.555, acc=0.591]


Epoch 27: Loss=0.6298, Acc=0.5908


Epoch 28/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 296.20it/s, loss=0.675, acc=0.598]


Epoch 28: Loss=0.6254, Acc=0.5982


Epoch 29/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 294.23it/s, loss=0.59, acc=0.603]


Epoch 29: Loss=0.6164, Acc=0.6028


Epoch 30/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 297.42it/s, loss=0.6, acc=0.611]


Epoch 30: Loss=0.6056, Acc=0.6111


Epoch 31/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 294.21it/s, loss=0.612, acc=0.616]


Epoch 31: Loss=0.6011, Acc=0.6157


Epoch 32/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 297.29it/s, loss=0.563, acc=0.621]


Epoch 32: Loss=0.5980, Acc=0.6211


Epoch 33/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 299.68it/s, loss=0.703, acc=0.626]


Epoch 33: Loss=0.5955, Acc=0.6262


Epoch 34/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 291.99it/s, loss=0.667, acc=0.634]


Epoch 34: Loss=0.5907, Acc=0.6342


Epoch 35/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 293.18it/s, loss=0.583, acc=0.646]


Epoch 35: Loss=0.5836, Acc=0.6461


Epoch 36/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 308.92it/s, loss=0.68, acc=0.661]


Epoch 36: Loss=0.5672, Acc=0.6614


Epoch 37/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 299.33it/s, loss=0.534, acc=0.696]


Epoch 37: Loss=0.5302, Acc=0.6958


Epoch 38/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 299.18it/s, loss=0.33, acc=0.755]


Epoch 38: Loss=0.4519, Acc=0.7547


Epoch 39/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 300.95it/s, loss=0.345, acc=0.809]


Epoch 39: Loss=0.3556, Acc=0.8090


Epoch 40/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 291.91it/s, loss=0.257, acc=0.839]


Epoch 40: Loss=0.2955, Acc=0.8386


Epoch 41/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 293.98it/s, loss=0.282, acc=0.854]


Epoch 41: Loss=0.2651, Acc=0.8538


Epoch 42/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 298.26it/s, loss=0.23, acc=0.867]


Epoch 42: Loss=0.2423, Acc=0.8666


Epoch 43/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 295.49it/s, loss=0.112, acc=0.88]


Epoch 43: Loss=0.2179, Acc=0.8804


Epoch 44/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 298.97it/s, loss=0.209, acc=0.887]


Epoch 44: Loss=0.2026, Acc=0.8872


Epoch 45/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 299.95it/s, loss=0.221, acc=0.89]


Epoch 45: Loss=0.1948, Acc=0.8903


Epoch 46/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 292.10it/s, loss=0.191, acc=0.897]


Epoch 46: Loss=0.1880, Acc=0.8965


Epoch 47/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 300.28it/s, loss=0.174, acc=0.899]


Epoch 47: Loss=0.1833, Acc=0.8991


Epoch 48/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 298.94it/s, loss=0.246, acc=0.903]


Epoch 48: Loss=0.1779, Acc=0.9030


Epoch 49/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 299.92it/s, loss=0.124, acc=0.909]


Epoch 49: Loss=0.1711, Acc=0.9087


Epoch 50/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 301.52it/s, loss=0.164, acc=0.914]


Epoch 50: Loss=0.1617, Acc=0.9145


Epoch 51/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 301.80it/s, loss=0.202, acc=0.921]


Epoch 51: Loss=0.1496, Acc=0.9215


Epoch 52/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 301.18it/s, loss=0.105, acc=0.925]


Epoch 52: Loss=0.1438, Acc=0.9253


Epoch 53/60: 100%|█████████████████████████████| 1563/1563 [00:05<00:00, 297.46it/s, loss=0.0543, acc=0.929]


Epoch 53: Loss=0.1389, Acc=0.9286


Epoch 54/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 297.64it/s, loss=0.143, acc=0.933]


Epoch 54: Loss=0.1351, Acc=0.9326


Epoch 55/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 298.81it/s, loss=0.126, acc=0.935]


Epoch 55: Loss=0.1313, Acc=0.9346


Epoch 56/60: 100%|█████████████████████████████| 1563/1563 [00:05<00:00, 296.72it/s, loss=0.0866, acc=0.936]


Epoch 56: Loss=0.1285, Acc=0.9363


Epoch 57/60: 100%|█████████████████████████████| 1563/1563 [00:05<00:00, 302.70it/s, loss=0.0912, acc=0.939]


Epoch 57: Loss=0.1263, Acc=0.9388


Epoch 58/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 297.53it/s, loss=0.11, acc=0.939]


Epoch 58: Loss=0.1248, Acc=0.9392


Epoch 59/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 297.94it/s, loss=0.193, acc=0.94]


Epoch 59: Loss=0.1240, Acc=0.9401


Epoch 60/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 300.33it/s, loss=0.16, acc=0.94]


Epoch 60: Loss=0.1236, Acc=0.9403
Starting to train model for all-hard with vocab size 45
Moving model to device:  cuda


Epoch 1/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 296.27it/s, loss=2.14, acc=0.124]


Epoch 1: Loss=2.1161, Acc=0.1238


Epoch 2/60: 100%|███████████████████████████████████| 1563/1563 [00:05<00:00, 298.81it/s, loss=2, acc=0.131]


Epoch 2: Loss=2.0821, Acc=0.1309


Epoch 3/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 309.00it/s, loss=0.862, acc=0.352]


Epoch 3: Loss=1.5473, Acc=0.3520


Epoch 4/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 300.19it/s, loss=0.392, acc=0.714]


Epoch 4: Loss=0.6707, Acc=0.7141


Epoch 5/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 301.17it/s, loss=0.424, acc=0.86]


Epoch 5: Loss=0.3547, Acc=0.8598


Epoch 6/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 299.57it/s, loss=0.0969, acc=0.909]


Epoch 6: Loss=0.2426, Acc=0.9088


Epoch 7/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 301.64it/s, loss=0.172, acc=0.926]


Epoch 7: Loss=0.2021, Acc=0.9261


Epoch 8/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 303.82it/s, loss=0.216, acc=0.949]


Epoch 8: Loss=0.1510, Acc=0.9492


Epoch 9/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 290.83it/s, loss=0.0262, acc=0.948]


Epoch 9: Loss=0.1505, Acc=0.9483


Epoch 10/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 299.27it/s, loss=0.255, acc=0.958]


Epoch 10: Loss=0.1252, Acc=0.9580


Epoch 11/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 295.97it/s, loss=0.427, acc=0.971]


Epoch 11: Loss=0.0927, Acc=0.9715


Epoch 12/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 292.36it/s, loss=0.015, acc=0.972]


Epoch 12: Loss=0.0874, Acc=0.9720


Epoch 13/60: 100%|█████████████████████████████| 1563/1563 [00:05<00:00, 300.22it/s, loss=0.0318, acc=0.984]


Epoch 13: Loss=0.0617, Acc=0.9836


Epoch 14/60: 100%|█████████████████████████████| 1563/1563 [00:05<00:00, 297.79it/s, loss=0.0536, acc=0.979]


Epoch 14: Loss=0.0753, Acc=0.9787


Epoch 15/60: 100%|████████████████████████████| 1563/1563 [00:05<00:00, 297.26it/s, loss=0.00536, acc=0.985]


Epoch 15: Loss=0.0632, Acc=0.9845


Epoch 16/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 294.16it/s, loss=0.091, acc=0.987]


Epoch 16: Loss=0.0433, Acc=0.9874


Epoch 17/60: 100%|██████████████████████████████| 1563/1563 [00:05<00:00, 300.05it/s, loss=0.205, acc=0.989]


Epoch 17: Loss=0.0451, Acc=0.9886


Epoch 18/60: 100%|█████████████████████████████| 1563/1563 [00:05<00:00, 292.76it/s, loss=0.0119, acc=0.989]


Epoch 18: Loss=0.0377, Acc=0.9891


Epoch 19/60: 100%|████████████████████████████| 1563/1563 [00:05<00:00, 292.09it/s, loss=0.00185, acc=0.992]


Epoch 19: Loss=0.0276, Acc=0.9921


Epoch 20/60: 100%|████████████████████████████| 1563/1563 [00:05<00:00, 298.89it/s, loss=0.00417, acc=0.993]


Epoch 20: Loss=0.0322, Acc=0.9928


Epoch 21/60: 100%|████████████████████████████| 1563/1563 [00:05<00:00, 298.05it/s, loss=0.00464, acc=0.991]


Epoch 21: Loss=0.0323, Acc=0.9909


Epoch 22/60: 100%|████████████████████████████| 1563/1563 [00:05<00:00, 311.00it/s, loss=0.00238, acc=0.995]


Epoch 22: Loss=0.0235, Acc=0.9948


Epoch 23/60: 100%|████████████████████████████| 1563/1563 [00:05<00:00, 306.80it/s, loss=0.00712, acc=0.993]


Epoch 23: Loss=0.0244, Acc=0.9932


Epoch 24/60: 100%|█████████████████████████████| 1563/1563 [00:05<00:00, 304.56it/s, loss=0.0027, acc=0.993]


Epoch 24: Loss=0.0245, Acc=0.9930


Epoch 25/60: 100%|██████████████████████████████| 1563/1563 [00:04<00:00, 314.19it/s, loss=0.438, acc=0.998]


Epoch 25: Loss=0.0085, Acc=0.9981


Epoch 26/60: 100%|████████████████████████████| 1563/1563 [00:05<00:00, 296.43it/s, loss=0.00252, acc=0.995]


Epoch 26: Loss=0.0194, Acc=0.9949


Epoch 27/60: 100%|████████████████████████████| 1563/1563 [00:05<00:00, 309.96it/s, loss=0.00347, acc=0.995]


Epoch 27: Loss=0.0233, Acc=0.9949


Epoch 28/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 301.09it/s, loss=0.000937, acc=1]


Epoch 28: Loss=0.0018, Acc=1.0000


Epoch 29/60: 100%|████████████████████████████| 1563/1563 [00:05<00:00, 293.34it/s, loss=0.00235, acc=0.994]


Epoch 29: Loss=0.0233, Acc=0.9942


Epoch 30/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 297.51it/s, loss=0.000882, acc=1]


Epoch 30: Loss=0.0017, Acc=1.0000


Epoch 31/60: 100%|████████████████████████████| 1563/1563 [00:05<00:00, 288.20it/s, loss=0.00366, acc=0.994]


Epoch 31: Loss=0.0267, Acc=0.9943


Epoch 32/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 294.18it/s, loss=0.00155, acc=1]


Epoch 32: Loss=0.0020, Acc=1.0000


Epoch 33/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 293.57it/s, loss=0.000545, acc=1]


Epoch 33: Loss=0.0010, Acc=1.0000


Epoch 34/60: 100%|████████████████████████████| 1563/1563 [00:05<00:00, 298.44it/s, loss=0.00116, acc=0.995]


Epoch 34: Loss=0.0184, Acc=0.9951


Epoch 35/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 295.67it/s, loss=0.000833, acc=1]


Epoch 35: Loss=0.0012, Acc=1.0000


Epoch 36/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 300.01it/s, loss=0.00037, acc=1]


Epoch 36: Loss=0.0006, Acc=1.0000


Epoch 37/60: 100%|████████████████████████████| 1563/1563 [00:05<00:00, 294.37it/s, loss=0.00157, acc=0.997]


Epoch 37: Loss=0.0087, Acc=0.9973


Epoch 38/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 299.61it/s, loss=0.000788, acc=1]


Epoch 38: Loss=0.0009, Acc=1.0000


Epoch 39/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 300.79it/s, loss=0.000369, acc=1]


Epoch 39: Loss=0.0004, Acc=1.0000


Epoch 40/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 296.12it/s, loss=0.000153, acc=1]


Epoch 40: Loss=0.0003, Acc=1.0000


Epoch 41/60: 100%|███████████████████████████| 1563/1563 [00:05<00:00, 310.55it/s, loss=0.000232, acc=0.999]


Epoch 41: Loss=0.0038, Acc=0.9988


Epoch 42/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 300.28it/s, loss=0.000196, acc=1]


Epoch 42: Loss=0.0003, Acc=1.0000


Epoch 43/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 303.66it/s, loss=0.000257, acc=1]


Epoch 43: Loss=0.0002, Acc=1.0000


Epoch 44/60: 100%|███████████████████████████| 1563/1563 [00:05<00:00, 294.78it/s, loss=0.000182, acc=0.998]


Epoch 44: Loss=0.0052, Acc=0.9984


Epoch 45/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 290.64it/s, loss=0.000477, acc=1]


Epoch 45: Loss=0.0003, Acc=1.0000


Epoch 46/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 304.02it/s, loss=0.000137, acc=1]


Epoch 46: Loss=0.0002, Acc=1.0000


Epoch 47/60: 100%|███████████████████████████████| 1563/1563 [00:05<00:00, 296.97it/s, loss=0.000128, acc=1]


Epoch 47: Loss=0.0001, Acc=1.0000


Epoch 48/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 295.50it/s, loss=4.74e-5, acc=1]


Epoch 48: Loss=0.0001, Acc=1.0000


Epoch 49/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 294.54it/s, loss=8.56e-5, acc=1]


Epoch 49: Loss=0.0001, Acc=1.0000


Epoch 50/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 296.64it/s, loss=7.91e-5, acc=1]


Epoch 50: Loss=0.0001, Acc=1.0000


Epoch 51/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 299.98it/s, loss=6.07e-5, acc=1]


Epoch 51: Loss=0.0002, Acc=1.0000


Epoch 52/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 300.27it/s, loss=3.03e-5, acc=1]


Epoch 52: Loss=0.0000, Acc=1.0000


Epoch 53/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 291.42it/s, loss=4.31e-5, acc=1]


Epoch 53: Loss=0.0000, Acc=1.0000


Epoch 54/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 292.56it/s, loss=2.41e-5, acc=1]


Epoch 54: Loss=0.0000, Acc=1.0000


Epoch 55/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 299.61it/s, loss=2.04e-5, acc=1]


Epoch 55: Loss=0.0000, Acc=1.0000


Epoch 56/60: 100%|█████████████████████████████████| 1563/1563 [00:05<00:00, 302.38it/s, loss=1.2e-5, acc=1]


Epoch 56: Loss=0.0000, Acc=1.0000


Epoch 57/60: 100%|█████████████████████████████████| 1563/1563 [00:05<00:00, 290.85it/s, loss=1.2e-5, acc=1]


Epoch 57: Loss=0.0000, Acc=1.0000


Epoch 58/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 300.30it/s, loss=6.93e-6, acc=1]


Epoch 58: Loss=0.0000, Acc=1.0000


Epoch 59/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 301.29it/s, loss=1.69e-5, acc=1]


Epoch 59: Loss=0.0000, Acc=1.0000


Epoch 60/60: 100%|████████████████████████████████| 1563/1563 [00:05<00:00, 303.04it/s, loss=4.98e-6, acc=1]


Epoch 60: Loss=0.0000, Acc=1.0000


In [10]:
def evaluate_model(model, dataloader, device='cuda') -> float:
    """Measure next-token accuracy at the single masked position of each row.

    The model is switched to eval mode (and left in it) and run without
    gradient tracking. Each batch must be a mapping with "input_ids" and
    "loss_mask" tensors. The model sees every token but the last and is
    scored against the sequence shifted left by one; only rows whose shifted
    loss mask selects exactly one position are counted.

    Args:
        model: Callable with an ``eval()`` method that maps token ids to
            logits. Assumes logits are (batch, seq, vocab) -- TODO confirm.
        dataloader: Iterable of batches as described above.
        device: Device the batch tensors are moved to before the forward pass.
            Presumably must match the device the model lives on.

    Returns:
        Fraction of counted rows whose argmax prediction equals the target
        token, or NaN when no row was counted (empty dataloader, or no row
        with exactly one masked position).
    """
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            loss_mask = batch["loss_mask"].to(device)

            # Next-token setup: position i of the logits is scored against
            # token i + 1, so targets and mask are shifted left by one.
            logits = model(input_ids[:, :-1])
            targets = input_ids[:, 1:]
            mask = loss_mask[:, 1:]

            # Greedy prediction at every position
            preds = logits.argmax(dim=-1)

            # Only rows with exactly one masked position are scored; rows with
            # zero or several masked positions are ignored.
            valid_rows = mask.sum(dim=1) == 1
            if not valid_rows.any():
                continue   # nothing to score in this batch

            # Boolean-index the kept rows down to their masked position
            mask_valid = mask[valid_rows].bool()
            pred_answers = preds[valid_rows][mask_valid]
            true_answers = targets[valid_rows][mask_valid]

            correct += (pred_answers == true_answers).sum().item()
            total += pred_answers.numel()

    # Accuracy is undefined when nothing was scored; report NaN instead of
    # crashing with ZeroDivisionError.
    if total == 0:
        return float("nan")
    return correct / total

# Test: score every trained model on the test dataset that shares its name
for name in train_dataset_names:
    # Fixed order (no shuffling) so repeated runs see the same batches
    _eval_loader = DataLoader(
        test_datasets[name],
        batch_size=64,
        shuffle=False,
        collate_fn=collate_fn,
    )
    accuracy = evaluate_model(models[name], _eval_loader)
    print(f"Testing model {name} accuracy on {name}: {accuracy}")

Testing model easy accuracy on easy: 1.0
Testing model bpe-hard accuracy on bpe-hard: 1.0
Testing model mult-hard accuracy on mult-hard: 0.996
Testing model length-hard accuracy on length-hard: 0.9291
Testing model all-hard accuracy on all-hard: 0.9997
