# Setup

In [1]:
# Make packages installed in the project's local ./venv importable from this kernel.
# Jupyter kernels are often not attached to the project's virtual environment, so this
# cell appends the venv's site-packages directory to sys.path as a workaround.
# You won't need this cell if your kernel already uses the venv, but running it won't
# hurt anything either.
import sys
import os

current_dir = os.getcwd()  # the venv is looked up relative to the current working directory
venv_dir = os.path.join(current_dir, './venv')
python_version = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
# NOTE: 'lib/pythonX.Y/site-packages' is the POSIX venv layout; Windows venvs use
# 'Lib\site-packages' instead, so this path will not match there.
site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')

# Only append once so that re-running this cell does not keep growing sys.path
# with duplicate entries.
if site_packages_path not in sys.path:
    sys.path.append(site_packages_path)

# Instantiate a brand new model

In [2]:
# Project-local modules used by this cell.
from tokenizer import get_tokenizer
from config import ModelConfig, TrainConfig
from model import customGPT

# --- tokenizer ---
# Size options are 128, 256, 512 and 1024.
size = 1024
path = './tokenizers/tiny_stories_tokenizer_{}.model'.format(size)
tokenizer = get_tokenizer(path)

# --- model config ---
# Start from the default config and take the vocabulary length from the tokenizer.
cfg = ModelConfig()
cfg.vocab_len = tokenizer.vocab_len
print(cfg, '\n')

# --- model ---
# Build the model from the config and move it to the configured device.
model = customGPT(cfg).to(cfg.device)

# Report the number of parameters (in thousands), followed by the module tree.
param_count = sum([weights.numel() for weights in model.parameters()])
print(param_count / 1e3, 'K parameters\n')
print(model)

ModelConfig(dim=64, vocab_len=1027, device='cpu', num_layers=2, pre_connect_dropout=True, second_resid_norm=False, mlp_hidden_mult=2, mlp_bias=False, mlp_nonlinearity='GeLU', mlp_gated=True, num_q_heads=4, num_kv_heads=1, theta=10000, max_seq_len=512, scale_first_resid=True, norm_type='RMSNorm', norm_affine=True, norm_bias=True, eps=1e-06, max_batch_size=1) 

136.0 K parameters

customGPT(
  (token_embedder): Embedding(1027, 64)
  (layers): ModuleList(
    (0-1): 2 x ResidualLayer(
      (pre_attn_norm): Norm()
      (attn): MQSA(
        (Wq): Linear(in_features=64, out_features=64, bias=False)
        (Wk): Linear(in_features=64, out_features=16, bias=False)
        (Wv): Linear(in_features=64, out_features=16, bias=False)
        (Wo): Linear(in_features=64, out_features=64, bias=False)
      )
      (pre_mlp_norm): Norm()
      (mlp): MLP(
        (Wgate): Linear(in_features=64, out_features=128, bias=False)
        (Wup): Linear(in_features=64, out_features=128, bias=False)
      

# Training

In [3]:
import torch
from train import get_data_loader, scheduler_lambda, train

# Training hyperparameters, left at the TrainConfig defaults.
tcfg = TrainConfig()

# AdamW over every model parameter. The optimizer's base learning rate is lr_max;
# LambdaLR then multiplies it by the factor scheduler_lambda returns for each step.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=tcfg.lr_max,
    weight_decay=tcfg.weight_decay,
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_lambda)

# One data loader per split; note that test_data_loader reads the 'validation' split.
loader_kwargs = dict(batch_size=tcfg.batch_size)
train_data_loader = get_data_loader(split='train', **loader_kwargs)
test_data_loader = get_data_loader(split='validation', **loader_kwargs)

Found cached dataset json (/Users/tunadorable/.cache/huggingface/datasets/noanabeshima___json/noanabeshima--TinyStoriesV2-226173b7dd235c68/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
Found cached dataset json (/Users/tunadorable/.cache/huggingface/datasets/noanabeshima___json/noanabeshima--TinyStoriesV2-226173b7dd235c68/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


In [4]:
# Flip this flag to True if you'd like to see a graph of the learning rate schedule.
show_lr_schedule = False

if show_lr_schedule:
    import matplotlib.pyplot as plt

    # Evaluate the schedule function once per training iteration.
    lrs = list(map(scheduler_lambda, range(tcfg.max_iters)))

    # Draw the curve on an explicit figure/axes pair.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(lrs, label='Learning Rate')
    ax.set_title('Learning Rate Schedule')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Learning Rate')
    ax.grid(True)
    ax.legend()
    plt.show()

In [5]:
# Run the training loop. train() returns three values, which are rebound to
# model, optimizer and log_data for use by the cells that follow.
# Commented-out optional arguments, kept for reference:
#   log_data: list = None
#   detect_anomoly = False  # use if you're getting crazy errors about the gradient being broken
train_outputs = train(
    model, tokenizer, cfg,
    optimizer, scheduler, tcfg,
    train_data_loader, test_data_loader,
)
model, optimizer, log_data = train_outputs

step 0000: lr 0.000010, train loss 51.8585, val loss 52.3006, ppl 51743153155552311246848, time elapsed: 0.71 seconds
step 0050: lr 0.000037, train loss 50.3469, val loss 49.7388, ppl 3992911695825004920832, time elapsed: 25.96 seconds
step 0099: lr 0.000000, train loss 48.5074, val loss 49.6796, ppl 3763391088840625094656, time elapsed: 54.49 seconds


# Inference test before you decide to save it
If `tcfg.checkpoint_interval != None`, checkpoints have already been saved.

In [6]:
# Quick qualitative check: generate text from a short prompt with the current model.
from inference import generate

prompt = "Once"

# Switch to eval mode for generation and always switch back to training mode
# afterwards, even if generate() raises, so that a failed run cannot leave the
# model stuck in eval mode for any later training.
model.eval()
try:
    output = generate(
        prompt,
        model,
        tokenizer,
        # Optional settings, left commented out so generate() uses its own defaults:
        #max_gen_len = 512,
        #temperature = 0.7,
        #memory_saver_div = 8,
        #top_p = 0.9,
        #top_k = 32,
    )
finally:
    model.train()
print(output)

OnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnce

# Saving your final model

In [7]:
# Save the final model: save_model receives the trained model, the model and
# training configs, and the log data returned by train().
from tools import save_model

save_model(
    model,
    cfg,
    tcfg,
    log_data,
)