# Setup

In [None]:
# My virtual environments are rarely properly connected to Jupyter, so this cell
# puts the local `venv`'s site-packages directory on `sys.path`.
# You probably won't need this cell, but running it won't hurt anything either.
import os
import sys


def add_venv_site_packages(base_dir=None, venv_name='venv'):
    """Append a local virtual environment's site-packages directory to ``sys.path``.

    The directory is only appended when it exists and is not already listed,
    so re-running the cell never piles up duplicate or dangling entries.

    Args:
        base_dir: Directory that contains the virtual environment. Defaults to
            the current working directory.
        venv_name: Name of the virtual environment folder inside ``base_dir``.

    Returns:
        The site-packages path that was found (it is on ``sys.path`` when this
        returns), or ``None`` when no such directory exists, in which case
        ``sys.path`` is left untouched.
    """
    root_dir = os.getcwd() if base_dir is None else base_dir
    venv_dir = os.path.join(root_dir, venv_name)
    python_version = f'{sys.version_info.major}.{sys.version_info.minor}'
    candidates = (
        # POSIX layout: venv/lib/pythonX.Y/site-packages
        os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages'),
        # Windows layout: venv\Lib\site-packages
        os.path.join(venv_dir, 'Lib', 'site-packages'),
    )
    for candidate in candidates:
        if os.path.isdir(candidate):
            if candidate not in sys.path:
                sys.path.append(candidate)
            return candidate
    return None


site_packages_path = add_venv_site_packages()

# Instantiate a brand new model

In [2]:
# Load the model and training hyperparameters from the config file and show them
from config import ModelConfig, TrainConfig
cfg, tcfg = ModelConfig(), TrainConfig()
print(cfg, tcfg, sep=' \n\n ')

# Look up the `get_tokenizer` callable for the tokenizer named in cfg, then build the tokenizer
from tools import import_from_nested_path
imported_objects = import_from_nested_path(
    ['tokenizers', cfg.tokenizer],
    'tokenizer',
    ['get_tokenizer'],
)
get_tokenizer = imported_objects.get('get_tokenizer')
tokenizer = get_tokenizer(size=cfg.vocab_len)

# Build the model from cfg and move it onto the configured device
from modules.model import Model
model = Model(cfg).to(cfg.device)

# Report the parameter count (in thousands), followed by the module structure
param_count = sum(param.numel() for param in model.parameters())
print(f'\n{param_count/1e3}K parameters\n')
print(model)

ModelConfig(dim=32, device='cpu', dropout_rate=0.1, weight_tying=True, tokenizer='bpe_v2', vocab_len=1024, num_layers=6, second_resid_norm=False, mlp_hidden_mult=4, mlp_bias=False, mlp_nonlinearity='SiLU', mlp_gated=True, num_q_heads=2, num_kv_heads=1, head_dim=16, theta=10000, max_seq_len=128, ca_num_q_heads=2, ca_num_kv_heads=1, ca_head_dim=16, scale_first_resid=True, norm_type='RMSNorm', norm_affine=True, norm_bias=True, eps=1e-06, pool_type='sum', pre_pool_norm=True, pool_output_linear=False, pool_bias=False, compress_freq='constant', compress_freq_n=1, fs_mult=4, fs_periods=3, fs_loss_lambda=1.0, max_batch_size=1) 

 TrainConfig(weight_decay=0.05, batch_size=32, max_iters=20, eval_interval=2, eval_samples=1, checkpoint_interval=None, lr_init=1e-06, lr_max=0.1, lr_min=0.001, warmup_iters=0, final_flat_iters=2, anneal_type='cos', num_restarts=3, T_mult=2)

110.752K parameters

Model(
  (token_embedder): Embedding(1027, 32)
  (body_layers): ModuleList(
    (0-2): 3 x Layer(
      (pr

# Training

In [3]:
import torch
from tools import get_data_loader
from train import scheduler_lambda, train

# AdamW is created with tcfg.lr_max as its base rate; LambdaLR then multiplies that
# base rate by whatever scheduler_lambda returns for the current step.
# NOTE(review): the training log in this notebook starts at lr 0.010000 and ends at
# 0.000100 while the printed config has lr_max=0.1 and lr_min=0.001. That looks like
# scheduler_lambda already returns absolute rates which then get multiplied by lr_max
# a second time — confirm against train.py.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=tcfg.lr_max,
    weight_decay=tcfg.weight_decay,
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_lambda)

# Same batch size for both splits; the 'validation' split is bound to test_data_loader
train_data_loader, test_data_loader = (
    get_data_loader(batch_size=tcfg.batch_size, split=split_name)
    for split_name in ('train', 'validation')
)

Found cached dataset json (/Users/tunadorable/.cache/huggingface/datasets/noanabeshima___json/noanabeshima--TinyStoriesV2-40971520ba3bacdf/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
Found cached dataset json (/Users/tunadorable/.cache/huggingface/datasets/noanabeshima___json/noanabeshima--TinyStoriesV2-40971520ba3bacdf/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


In [4]:
PLOT_LR_SCHEDULE = False  # flip to True if you'd like to see a graph of the learning rate schedule
if PLOT_LR_SCHEDULE:
    import matplotlib.pyplot as plt

    # Evaluate the schedule function once per training iteration
    lrs = [scheduler_lambda(step) for step in range(tcfg.max_iters)]

    # Draw the curve on an explicit figure/axes pair
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(lrs, label='Learning Rate')
    ax.set_title('Learning Rate Schedule')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Learning Rate')
    ax.grid(True)
    ax.legend()
    plt.show()

In [5]:
# Run training. `train` hands back the model and optimizer together with log_data,
# which is later passed to save_model.
model, optimizer, log_data = train(
    model, tokenizer, cfg,
    optimizer, scheduler, tcfg,
    train_data_loader, test_data_loader,
    # Presumably further optional arguments of `train`, left disabled here — confirm in train.py:
    #log_data: list = None,
    #detect_anomoly = False # use if you're getting crazy errors about the gradient being broken
)

  5%|█▉                                    | 1/20 [00:01<00:19,  1.02s/it]

step 0000: lr 0.010000, train loss 8.7213, val loss 8.7315, ppl 6195, time elapsed: 0.57 seconds


 15%|█████▋                                | 3/20 [00:02<00:12,  1.32it/s]

step 0002: lr 0.007525, train loss 8.1504, val loss 8.1596, ppl 3497, time elapsed: 1.84 seconds


 25%|█████████▌                            | 5/20 [00:03<00:10,  1.41it/s]

step 0004: lr 0.009831, train loss 7.5920, val loss 7.5594, ppl 1919, time elapsed: 3.12 seconds


 35%|█████████████▎                        | 7/20 [00:04<00:08,  1.46it/s]

step 0006: lr 0.005050, train loss 6.9572, val loss 6.9563, ppl 1050, time elapsed: 4.39 seconds


 45%|█████████████████                     | 9/20 [00:06<00:07,  1.47it/s]

step 0008: lr 0.000269, train loss 6.8806, val loss 6.8202, ppl 916, time elapsed: 5.65 seconds


 55%|████████████████████▎                | 11/20 [00:07<00:06,  1.48it/s]

step 0010: lr 0.009337, train loss 6.1873, val loss 6.2482, ppl 517, time elapsed: 6.91 seconds


 65%|████████████████████████             | 13/20 [00:08<00:04,  1.49it/s]

step 0012: lr 0.006944, train loss 5.9866, val loss 5.9878, ppl 399, time elapsed: 8.17 seconds


 75%|███████████████████████████▊         | 15/20 [00:09<00:03,  1.49it/s]

step 0014: lr 0.003769, train loss 5.8173, val loss 5.8236, ppl 338, time elapsed: 9.42 seconds


 85%|███████████████████████████████▍     | 17/20 [00:11<00:02,  1.48it/s]

step 0016: lr 0.001123, train loss 5.7426, val loss 5.8237, ppl 338, time elapsed: 10.69 seconds


 95%|███████████████████████████████████▏ | 19/20 [00:12<00:00,  1.47it/s]

step 0018: lr 0.000100, train loss 5.8140, val loss 5.7992, ppl 330, time elapsed: 11.98 seconds


100%|█████████████████████████████████████| 20/20 [00:13<00:00,  1.50it/s]

step 0019: lr 0.000100, train loss 5.7262, val loss 5.8085, ppl 333, time elapsed: 12.85 seconds





# Inference test before you decide to save the model

In [6]:
from inference import generate

prompt = "Once upon a time"

# Generate in eval mode, and always switch back to train mode afterwards — even if
# generation raises — so a failed run cannot leave the model stuck in eval mode.
model.eval()
try:
    output = generate(
        prompt,
        model,
        tokenizer,
        # Presumably optional sampling arguments of `generate`; confirm in inference.py
        #max_gen_len = 512,
        temperature = 0.9,
        #memory_saver_div = 8,
        #top_p = 0.9,
        #top_k = 32,
    )
finally:
    model.train()
print(output)

Once upon a time a.
 foen bu bigt. ca trt h.un Shes wanted a fo an bird, I wantdle. Time tod sor "t.
 ant thd.e theem fod wasstieyer th nod anie not ae bigo to to herinem adndet ba,tereo anel.emrd, thetntddd thr li..ro ca a, animals It.rs out an anin.,t o th Th.



# Saving your final model
If `tcfg.checkpoint_interval` is not `None`, checkpoints have already been saved during training.

You DO still need to run this cell even if you were saving checkpoints: the final state of the model has not been saved yet.

In [7]:
from tools import save_model
# Hand the trained model, both config objects and the log_data returned by train()
# to save_model (see tools.py for where and how the result is written).
save_model(model, cfg, tcfg, log_data)