In [None]:
import torch
from train import trainGPT, SanityConfig
from dataclasses import dataclass
from transformers import AutoTokenizer
from data_gen import gen_grok_data

@dataclass
class GrokConfig:
    """Hyper-parameters for the small transformer trained on the grokking task.

    Fields keep their original order, so positional construction
    (``GrokConfig(block_size, vocab_size, ...)``) behaves as before.
    """

    # Presumably the maximum context length the model is built for — confirm in trainGPT.
    block_size: int = 1_024
    # Placeholder vocabulary size; the training script overwrites it from the tokenizer.
    vocab_size: int = 50_304
    # Depth/width/heads: small model; head count is taken from the paper.
    n_layer: int = 2
    n_head: int = 4
    n_embd: int = 128
    # Presumably a dropout probability — 0 disables it.
    dropout: float = 0.
    # Presumably toggles bias terms in the model's layers — confirm in the model code.
    bias: bool = True

# Seed torch's global RNG so torch-side randomness (presumably model init and
# batch sampling inside trainGPT) is reproducible.
# NOTE(review): if gen_grok_data shuffles with Python's `random` or NumPy, those
# generators are not seeded here — confirm in data_gen.
torch.manual_seed(2025)
config = GrokConfig()

# ByT5 tokenizes at the byte level: each ASCII character of "a / b = c" is one token.
tokenizer = AutoTokenizer.from_pretrained('google/byt5-base')
# Use len(tokenizer) rather than tokenizer.vocab_size: ByT5 places byte ids after
# its special tokens (pad/eos/unk), and recent transformers releases report only
# the 256 byte tokens in vocab_size, which would not cover the highest ids.
# len() also counts special and added tokens, so every id fits the embedding.
config.vocab_size = len(tokenizer)
# Pad with EOS; presumably what trainGPT's batching/loss masking expects — confirm.
tokenizer.pad_token = tokenizer.eos_token
# Division-only data ("a / b = c", as the printed sample shows) for modulus 97
# (flag meanings inferred from names — TODO confirm in data_gen); 0.50 like the paper.
train, test = gen_grok_data([97], train_split=0.50, a=False, s=False, d=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): the "Asking to truncate to max_length" warning in this cell's
# output suggests trainGPT tokenizes with truncation but no max_length; consider
# tokenizer.model_max_length = config.block_size once trainGPT's padding mode is checked.
model = trainGPT(config,
                 tokenizer,
                 train,
                 test,
                 device=device,
                 epochs=100_000,
                 weight_decay=1,  # params from the paper
                 lr=1e-3,
                 batch_size=512,
                 betas=(0.9, 0.98),  # tuple, as torch optimizers document it
                 max_steps=100_000,  # step count as an int, not the float 1e5
                 grokking=True)

  from .autonotebook import tqdm as notebook_tqdm
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


['33 / 81 = 4', '0 / 60 = 0', '42 / 3 = 14', '39 / 4 = 34', '16 / 77 = 38', '52 / 62 = 29', '87 / 73 = 57', '77 / 23 = 16', '22 / 55 = 78', '20 / 45 = 22', '61 / 95 = 18', '73 / 69 = 84', '42 / 34 = 64', '45 / 39 = 31', '77 / 70 = 69', '41 / 1 = 41', '37 / 20 = 94', '85 / 34 = 51', '90 / 27 = 68', '52 / 95 = 71', '83 / 81 = 13', '80 / 46 = 65', '47 / 9 = 16', '91 / 5 = 57', '84 / 79 = 60', '0 / 10 = 0', '75 / 35 = 16', '62 / 34 = 76', '9 / 13 = 38', '52 / 22 = 20', '4 / 34 = 80', '57 / 9 = 71', '30 / 86 = 59', '80 / 60 = 66', '49 / 49 = 1', '86 / 91 = 18', '23 / 65 = 69', '18 / 26 = 38', '69 / 78 = 27', '87 / 75 = 71', '45 / 77 = 22', '26 / 25 = 67', '37 / 64 = 90', '97 / 43 = 0', '30 / 37 = 48', '53 / 42 = 59', '30 / 54 = 76', '33 / 54 = 6', '91 / 2 = 94', '45 / 2 = 71', '32 / 27 = 91', '57 / 59 = 47', '79 / 13 = 21', '89 / 48 = 16', '62 / 13 = 57', '67 / 62 = 84', '8 / 80 = 68', '37 / 58 = 9', '53 / 41 = 77', '83 / 23 = 50', '47 / 74 = 57', '52 / 11 = 40', '31 / 47 = 44', '48 / 67 = 

Training Progress:   0%|          | 340/100000.0 [06:37<32:17:58,  1.17s/it, loss=0.08768026158213615, train_acc=0.9541245698928833, val_loss=3.4098711013793945, val_acc=0.10290404409170151, step=3400]