# Setup

In [1]:
# Jupyter kernels are not always wired up to the project's virtual environment, so this
# cell makes the venv's site-packages importable by hand. If your kernel already runs
# inside the venv you don't need this cell, but running it is harmless.
import sys
import os

current_dir = os.getcwd()  # the venv is looked up under the current working directory
venv_dir = os.path.join(current_dir, 'venv')
python_version = f'{sys.version_info.major}.{sys.version_info.minor}'
if os.name == 'nt':
    # Windows venvs keep packages in <venv>\Lib\site-packages (no pythonX.Y level)
    site_packages_path = os.path.join(venv_dir, 'Lib', 'site-packages')
else:
    site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')

# Add the entry only once so that re-running this cell does not pile up duplicates.
if site_packages_path not in sys.path:
    sys.path.append(site_packages_path)

# Instantiate a brand new model

In [2]:
# --- tokenizer ---
# available sizes: 95 (character-wise), 128, 256, 512, 1024, 2048 & 4096
from tokenizer import get_tokenizer
tokenizer = get_tokenizer(size=2048)

# --- configuration objects ---
from config import ModelConfig, TrainConfig
cfg = ModelConfig()
cfg.vocab_len = tokenizer.vocab_len  # vocabulary size is taken from the tokenizer
tcfg = TrainConfig()
print(cfg, tcfg, sep=' \n\n ')

# --- model ---
from model import customGPT
model = customGPT(cfg).to(cfg.device)

# report the parameter count (in thousands), then the module structure
param_count = sum(p.numel() for p in model.parameters())
print(param_count / 1e3, 'K parameters\n')
print(model)

ModelConfig(dim=64, vocab_len=2051, device='cpu', num_layers=8, second_resid_norm=False, mlp_hidden_mult=2, mlp_bias=False, mlp_nonlinearity='SiLU', mlp_gated=True, num_q_heads=8, num_kv_heads=2, head_dim=16, theta=10000, max_seq_len=512, scale_first_resid=True, norm_type='RMSNorm', norm_affine=True, norm_bias=True, eps=1e-06, max_batch_size=1) 

 TrainConfig(weight_decay=0.02, batch_size=32, max_iters=1000, eval_interval=10, eval_samples=1, checkpoint_interval=None, lr_max=0.1, lr_min=1e-06, warmup_iters=100, final_flat_iters=100, anneal_type='cos', num_restarts=3, T_mult=2)
493.888 K parameters

customGPT(
  (token_embedder): Embedding(2051, 64)
  (layers): ModuleList(
    (0-7): 8 x ResidualLayer(
      (pre_attn_norm): Norm()
      (attn): MQSA(
        (Wq): Linear(in_features=64, out_features=128, bias=False)
        (Wk): Linear(in_features=64, out_features=32, bias=False)
        (Wv): Linear(in_features=64, out_features=32, bias=False)
        (Wo): Linear(in_features=128, out

# Training

In [3]:
import torch
from tools import get_data_loader
from train import scheduler_lambda, train

# AdamW is created with tcfg.lr_max as its base learning rate; LambdaLR then multiplies
# that base rate by whatever scheduler_lambda returns for the current step.
# NOTE(review): the training log below peaks at lr 0.01 while lr_max is 0.1, which would
# be consistent with scheduler_lambda returning an absolute learning rate that then gets
# multiplied by lr_max a second time — confirm against train.py.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=tcfg.lr_max,
    weight_decay=tcfg.weight_decay,
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_lambda)

# one loader per dataset split, both using the training batch size
train_data_loader, test_data_loader = (
    get_data_loader(batch_size=tcfg.batch_size, split=split_name)
    for split_name in ('train', 'validation')
)

In [4]:
# Flip to True to plot what scheduler_lambda returns for every training iteration.
SHOW_LR_SCHEDULE = False

if SHOW_LR_SCHEDULE:
    import matplotlib.pyplot as plt

    # Evaluate the schedule function once per iteration of the planned run
    lrs = [scheduler_lambda(i) for i in range(tcfg.max_iters)]

    # Draw on an explicit Axes rather than through the implicit pyplot state
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(lrs, label='Learning Rate')
    ax.set_title('Learning Rate Schedule')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Learning Rate')
    ax.grid(True)
    ax.legend()
    plt.show()

In [5]:
# Run training. This rebinds `model` and `optimizer` to the values train() returns, so
# re-running this cell without re-running the cells above presumably continues from the
# already-trained state rather than starting fresh — TODO confirm against train().
# `log_data` (third return value) is later handed to save_model() in the final cell.
model, optimizer, log_data = train(
    model, 
    tokenizer, 
    cfg, 
    optimizer,
    scheduler,
    tcfg, 
    train_data_loader,
    test_data_loader,
    # Optional arguments left at their defaults (names as written by the author;
    # verify them against train()'s signature before uncommenting):
    #log_data: list = None,
    #detect_anomoly = False # use if you're getting crazy errors about the gradient being broken
)

  0%|                                                                          | 1/1000 [00:10<2:56:31, 10.60s/it]

step 0000: lr 0.000000, train loss 63.0248, val loss 63.1004, ppl 2536014693363622611949453312, time elapsed: 5.46 seconds


  1%|▊                                                                        | 11/1000 [01:06<1:49:39,  6.65s/it]

step 0010: lr 0.001000, train loss 61.5801, val loss 61.5277, ppl 526206877719226915238182912, time elapsed: 61.35 seconds


  2%|█▌                                                                       | 21/1000 [02:00<1:43:11,  6.32s/it]

step 0020: lr 0.002000, train loss 30.6337, val loss 30.7319, ppl 22217834364928, time elapsed: 115.59 seconds


  3%|██▎                                                                      | 31/1000 [02:53<1:40:52,  6.25s/it]

step 0030: lr 0.003000, train loss 16.0729, val loss 16.2064, ppl 10922704, time elapsed: 168.03 seconds


  4%|██▉                                                                      | 41/1000 [03:45<1:40:18,  6.28s/it]

step 0040: lr 0.004000, train loss 11.4316, val loss 11.3215, ppl 82576, time elapsed: 221.08 seconds


  5%|███▋                                                                     | 51/1000 [04:41<1:48:54,  6.89s/it]

step 0050: lr 0.005000, train loss 8.4893, val loss 8.6791, ppl 5879, time elapsed: 275.72 seconds


  6%|████▍                                                                    | 61/1000 [05:31<1:35:10,  6.08s/it]

step 0060: lr 0.006000, train loss 7.0748, val loss 7.1039, ppl 1217, time elapsed: 327.10 seconds


  7%|█████▏                                                                   | 71/1000 [06:22<1:33:20,  6.03s/it]

step 0070: lr 0.007000, train loss 6.0622, val loss 6.1606, ppl 474, time elapsed: 377.80 seconds


  8%|█████▉                                                                   | 81/1000 [07:13<1:32:03,  6.01s/it]

step 0080: lr 0.008000, train loss 5.4834, val loss 5.5435, ppl 256, time elapsed: 428.44 seconds


  9%|██████▋                                                                  | 91/1000 [08:03<1:31:00,  6.01s/it]

step 0090: lr 0.009000, train loss 5.2053, val loss 5.1188, ppl 167, time elapsed: 478.92 seconds


 10%|███████▎                                                                | 101/1000 [08:53<1:29:50,  6.00s/it]

step 0100: lr 0.010000, train loss 4.9042, val loss 4.9832, ppl 146, time elapsed: 529.02 seconds


 11%|███████▉                                                                | 111/1000 [09:45<1:34:11,  6.36s/it]

step 0110: lr 0.009157, train loss 4.6523, val loss 4.6749, ppl 107, time elapsed: 581.09 seconds


 12%|████████▋                                                               | 121/1000 [10:36<1:28:40,  6.05s/it]

step 0120: lr 0.006913, train loss 4.4726, val loss 4.5084, ppl 91, time elapsed: 632.09 seconds


 13%|█████████▍                                                              | 131/1000 [11:27<1:27:00,  6.01s/it]

step 0130: lr 0.004025, train loss 4.4285, val loss 4.3751, ppl 79, time elapsed: 682.90 seconds


 14%|██████████▏                                                             | 141/1000 [12:18<1:26:39,  6.05s/it]

step 0140: lr 0.001465, train loss 4.3653, val loss 4.4016, ppl 82, time elapsed: 733.65 seconds


 15%|██████████▊                                                             | 151/1000 [13:09<1:24:58,  6.00s/it]

step 0150: lr 0.000096, train loss 4.4078, val loss 4.3562, ppl 78, time elapsed: 784.34 seconds


 16%|███████████▌                                                            | 161/1000 [13:59<1:24:36,  6.05s/it]

step 0160: lr 0.009904, train loss 4.3608, val loss 4.3373, ppl 76, time elapsed: 834.94 seconds


 17%|████████████▎                                                           | 171/1000 [14:50<1:22:27,  5.97s/it]

step 0170: lr 0.009410, train loss 4.2756, val loss 4.1803, ppl 65, time elapsed: 885.39 seconds


 18%|█████████████                                                           | 181/1000 [15:40<1:21:28,  5.97s/it]

step 0180: lr 0.008536, train loss 3.9874, val loss 3.9986, ppl 55, time elapsed: 935.85 seconds


 19%|█████████████▊                                                          | 191/1000 [16:31<1:22:08,  6.09s/it]

step 0190: lr 0.007357, train loss 3.9803, val loss 3.9341, ppl 51, time elapsed: 986.37 seconds


 20%|██████████████▍                                                         | 201/1000 [17:21<1:19:46,  5.99s/it]

step 0200: lr 0.005975, train loss 3.8553, val loss 3.8799, ppl 48, time elapsed: 1036.89 seconds


 21%|███████████████▏                                                        | 211/1000 [18:11<1:18:50,  6.00s/it]

step 0210: lr 0.004510, train loss 3.8329, val loss 3.6867, ppl 40, time elapsed: 1087.30 seconds


 22%|███████████████▉                                                        | 221/1000 [19:02<1:17:18,  5.95s/it]

step 0220: lr 0.003087, train loss 3.6879, val loss 3.6695, ppl 39, time elapsed: 1137.73 seconds


 23%|████████████████▋                                                       | 231/1000 [19:52<1:17:28,  6.05s/it]

step 0230: lr 0.001828, train loss 3.6541, val loss 3.6754, ppl 39, time elapsed: 1187.90 seconds


 24%|█████████████████▎                                                      | 241/1000 [20:43<1:18:52,  6.24s/it]

step 0240: lr 0.000843, train loss 3.7186, val loss 3.6748, ppl 39, time elapsed: 1238.45 seconds


 25%|██████████████████                                                      | 251/1000 [21:37<1:19:17,  6.35s/it]

step 0250: lr 0.000215, train loss 3.7113, val loss 3.6625, ppl 39, time elapsed: 1292.17 seconds


 26%|██████████████████▊                                                     | 261/1000 [22:31<1:18:40,  6.39s/it]

step 0260: lr 0.000000, train loss 3.8241, val loss 3.6180, ppl 37, time elapsed: 1346.20 seconds


 27%|███████████████████▌                                                    | 271/1000 [23:24<1:17:36,  6.39s/it]

step 0270: lr 0.009946, train loss 3.7162, val loss 3.5677, ppl 35, time elapsed: 1399.50 seconds


 28%|████████████████████▏                                                   | 281/1000 [24:18<1:15:14,  6.28s/it]

step 0280: lr 0.009785, train loss 3.6009, val loss 3.5212, ppl 34, time elapsed: 1453.62 seconds


 29%|████████████████████▉                                                   | 291/1000 [25:11<1:14:31,  6.31s/it]

step 0290: lr 0.009520, train loss 3.5843, val loss 3.5569, ppl 35, time elapsed: 1506.24 seconds


 30%|█████████████████████▋                                                  | 301/1000 [26:04<1:13:11,  6.28s/it]

step 0300: lr 0.009157, train loss 3.5381, val loss 3.4974, ppl 33, time elapsed: 1559.79 seconds


 31%|██████████████████████▍                                                 | 311/1000 [26:57<1:12:03,  6.28s/it]

step 0310: lr 0.008705, train loss 3.4050, val loss 3.4404, ppl 31, time elapsed: 1612.50 seconds


 32%|███████████████████████                                                 | 321/1000 [27:50<1:11:38,  6.33s/it]

step 0320: lr 0.008172, train loss 3.3728, val loss 3.3141, ppl 27, time elapsed: 1665.16 seconds


 33%|███████████████████████▊                                                | 331/1000 [28:42<1:09:19,  6.22s/it]

step 0330: lr 0.007571, train loss 3.3293, val loss 3.3512, ppl 29, time elapsed: 1717.88 seconds


 34%|████████████████████████▌                                               | 341/1000 [29:35<1:08:43,  6.26s/it]

step 0340: lr 0.006913, train loss 3.4458, val loss 3.2617, ppl 26, time elapsed: 1770.72 seconds


 35%|█████████████████████████▎                                              | 351/1000 [30:28<1:06:54,  6.18s/it]

step 0350: lr 0.006215, train loss 3.1871, val loss 3.2522, ppl 26, time elapsed: 1823.42 seconds


 36%|█████████████████████████▉                                              | 361/1000 [31:20<1:06:50,  6.28s/it]

step 0360: lr 0.005490, train loss 3.2215, val loss 3.1580, ppl 24, time elapsed: 1875.69 seconds


 37%|██████████████████████████▋                                             | 371/1000 [32:13<1:05:33,  6.25s/it]

step 0370: lr 0.004755, train loss 3.1981, val loss 3.1319, ppl 23, time elapsed: 1928.63 seconds


 38%|███████████████████████████▍                                            | 381/1000 [33:06<1:04:32,  6.26s/it]

step 0380: lr 0.004025, train loss 3.1593, val loss 3.1908, ppl 24, time elapsed: 1981.25 seconds


 39%|████████████████████████████▏                                           | 391/1000 [34:00<1:02:00,  6.11s/it]

step 0390: lr 0.003316, train loss 3.1118, val loss 3.1475, ppl 23, time elapsed: 2035.51 seconds


 40%|████████████████████████████▊                                           | 401/1000 [34:51<1:00:40,  6.08s/it]

step 0400: lr 0.002643, train loss 3.2159, val loss 3.2070, ppl 25, time elapsed: 2086.78 seconds


 41%|██████████████████████████████▍                                           | 411/1000 [35:42<59:11,  6.03s/it]

step 0410: lr 0.002022, train loss 3.1306, val loss 3.1112, ppl 22, time elapsed: 2137.53 seconds


 42%|███████████████████████████████▏                                          | 421/1000 [36:33<59:02,  6.12s/it]

step 0420: lr 0.001465, train loss 3.1821, val loss 3.0941, ppl 22, time elapsed: 2188.57 seconds


 43%|███████████████████████████████▉                                          | 431/1000 [37:24<57:51,  6.10s/it]

step 0430: lr 0.000984, train loss 3.0257, val loss 3.1372, ppl 23, time elapsed: 2239.56 seconds


 44%|████████████████████████████████▋                                         | 441/1000 [38:15<57:17,  6.15s/it]

step 0440: lr 0.000590, train loss 3.1645, val loss 3.1271, ppl 23, time elapsed: 2290.85 seconds


 45%|█████████████████████████████████▎                                        | 451/1000 [39:06<55:14,  6.04s/it]

step 0450: lr 0.000292, train loss 3.0142, val loss 3.0113, ppl 20, time elapsed: 2341.56 seconds


 46%|██████████████████████████████████                                        | 461/1000 [39:59<57:45,  6.43s/it]

step 0460: lr 0.000096, train loss 3.1197, val loss 3.1033, ppl 22, time elapsed: 2394.42 seconds


 47%|██████████████████████████████████▊                                       | 471/1000 [40:53<56:40,  6.43s/it]

step 0470: lr 0.000006, train loss 3.1836, val loss 3.1651, ppl 24, time elapsed: 2448.29 seconds


 48%|███████████████████████████████████▌                                      | 481/1000 [41:47<55:35,  6.43s/it]

step 0480: lr 0.009994, train loss 3.1640, val loss 3.1138, ppl 23, time elapsed: 2502.01 seconds


 49%|████████████████████████████████████▎                                     | 491/1000 [42:41<54:33,  6.43s/it]

step 0490: lr 0.009962, train loss 3.0361, val loss 3.0889, ppl 22, time elapsed: 2555.95 seconds


 50%|█████████████████████████████████████                                     | 501/1000 [43:34<52:32,  6.32s/it]

step 0500: lr 0.009904, train loss 3.0993, val loss 3.1438, ppl 23, time elapsed: 2609.17 seconds


 51%|█████████████████████████████████████▊                                    | 511/1000 [44:28<52:19,  6.42s/it]

step 0510: lr 0.009819, train loss 3.0469, val loss 3.0104, ppl 20, time elapsed: 2662.98 seconds


 52%|██████████████████████████████████████▌                                   | 521/1000 [45:21<49:02,  6.14s/it]

step 0520: lr 0.009708, train loss 2.9228, val loss 3.1552, ppl 23, time elapsed: 2717.01 seconds


 53%|███████████████████████████████████████▎                                  | 531/1000 [46:12<47:14,  6.04s/it]

step 0530: lr 0.009571, train loss 3.0347, val loss 3.1676, ppl 24, time elapsed: 2768.01 seconds


 54%|████████████████████████████████████████                                  | 541/1000 [47:03<46:26,  6.07s/it]

step 0540: lr 0.009410, train loss 2.9904, val loss 3.0536, ppl 21, time elapsed: 2818.90 seconds


 55%|████████████████████████████████████████▊                                 | 551/1000 [47:54<45:32,  6.09s/it]

step 0550: lr 0.009224, train loss 2.9902, val loss 2.9655, ppl 19, time elapsed: 2870.15 seconds


 56%|█████████████████████████████████████████▌                                | 561/1000 [48:46<44:52,  6.13s/it]

step 0560: lr 0.009016, train loss 3.0925, val loss 2.8701, ppl 18, time elapsed: 2921.20 seconds


 57%|██████████████████████████████████████████▎                               | 571/1000 [49:36<43:19,  6.06s/it]

step 0570: lr 0.008786, train loss 2.9130, val loss 2.9917, ppl 20, time elapsed: 2972.17 seconds


 58%|██████████████████████████████████████████▉                               | 581/1000 [50:27<42:01,  6.02s/it]

step 0580: lr 0.008536, train loss 2.9642, val loss 2.9656, ppl 19, time elapsed: 3022.88 seconds


 59%|███████████████████████████████████████████▋                              | 591/1000 [51:19<41:40,  6.11s/it]

step 0590: lr 0.008266, train loss 2.9211, val loss 2.8636, ppl 18, time elapsed: 3074.21 seconds


 60%|████████████████████████████████████████████▍                             | 601/1000 [52:10<40:48,  6.14s/it]

step 0600: lr 0.007979, train loss 2.9221, val loss 2.9257, ppl 19, time elapsed: 3125.57 seconds


 61%|█████████████████████████████████████████████▏                            | 611/1000 [53:01<39:25,  6.08s/it]

step 0610: lr 0.007675, train loss 2.8540, val loss 2.9184, ppl 19, time elapsed: 3177.14 seconds


 62%|█████████████████████████████████████████████▉                            | 621/1000 [53:53<38:36,  6.11s/it]

step 0620: lr 0.007357, train loss 2.9317, val loss 2.8298, ppl 17, time elapsed: 3228.58 seconds


 63%|██████████████████████████████████████████████▋                           | 631/1000 [54:45<38:03,  6.19s/it]

step 0630: lr 0.007026, train loss 2.8122, val loss 2.9760, ppl 20, time elapsed: 3280.32 seconds


 64%|███████████████████████████████████████████████▍                          | 641/1000 [55:38<37:15,  6.23s/it]

step 0640: lr 0.006684, train loss 2.8574, val loss 2.8773, ppl 18, time elapsed: 3333.26 seconds


 65%|████████████████████████████████████████████████▏                         | 651/1000 [56:30<35:43,  6.14s/it]

step 0650: lr 0.006334, train loss 2.7373, val loss 2.8372, ppl 17, time elapsed: 3385.39 seconds


 66%|████████████████████████████████████████████████▉                         | 661/1000 [57:21<34:20,  6.08s/it]

step 0660: lr 0.005975, train loss 2.8714, val loss 2.8197, ppl 17, time elapsed: 3437.01 seconds


 67%|█████████████████████████████████████████████████▋                        | 671/1000 [58:13<33:43,  6.15s/it]

step 0670: lr 0.005612, train loss 2.7616, val loss 2.8985, ppl 18, time elapsed: 3488.65 seconds


 68%|██████████████████████████████████████████████████▍                       | 681/1000 [59:05<32:19,  6.08s/it]

step 0680: lr 0.005245, train loss 2.8418, val loss 2.7765, ppl 16, time elapsed: 3540.37 seconds


 69%|███████████████████████████████████████████████████▏                      | 691/1000 [59:56<31:48,  6.18s/it]

step 0690: lr 0.004877, train loss 2.7616, val loss 2.7688, ppl 16, time elapsed: 3592.28 seconds


 70%|██████████████████████████████████████████████████▍                     | 701/1000 [1:00:48<30:18,  6.08s/it]

step 0700: lr 0.004510, train loss 2.7901, val loss 2.7359, ppl 15, time elapsed: 3643.55 seconds


 71%|███████████████████████████████████████████████████▏                    | 711/1000 [1:01:39<29:07,  6.05s/it]

step 0710: lr 0.004145, train loss 2.8385, val loss 2.7485, ppl 16, time elapsed: 3694.54 seconds


 72%|███████████████████████████████████████████████████▉                    | 721/1000 [1:02:30<28:12,  6.07s/it]

step 0720: lr 0.003785, train loss 2.5733, val loss 2.7395, ppl 15, time elapsed: 3745.70 seconds


 73%|████████████████████████████████████████████████████▋                   | 731/1000 [1:03:21<27:13,  6.07s/it]

step 0730: lr 0.003432, train loss 2.6900, val loss 2.6090, ppl 14, time elapsed: 3797.03 seconds


 74%|█████████████████████████████████████████████████████▎                  | 741/1000 [1:04:13<26:29,  6.14s/it]

step 0740: lr 0.003087, train loss 2.8005, val loss 2.6209, ppl 14, time elapsed: 3848.40 seconds


 75%|██████████████████████████████████████████████████████                  | 751/1000 [1:05:04<25:16,  6.09s/it]

step 0750: lr 0.002752, train loss 2.7624, val loss 2.7472, ppl 16, time elapsed: 3899.70 seconds


 76%|██████████████████████████████████████████████████████▊                 | 761/1000 [1:05:56<24:34,  6.17s/it]

step 0760: lr 0.002430, train loss 2.7164, val loss 2.6561, ppl 14, time elapsed: 3951.36 seconds


 77%|███████████████████████████████████████████████████████▌                | 771/1000 [1:06:47<23:14,  6.09s/it]

step 0770: lr 0.002121, train loss 2.7114, val loss 2.7072, ppl 15, time elapsed: 4002.47 seconds


 78%|████████████████████████████████████████████████████████▏               | 781/1000 [1:07:38<21:58,  6.02s/it]

step 0780: lr 0.001828, train loss 2.7023, val loss 2.6466, ppl 14, time elapsed: 4053.55 seconds


 79%|████████████████████████████████████████████████████████▉               | 791/1000 [1:08:28<20:55,  6.01s/it]

step 0790: lr 0.001552, train loss 2.6275, val loss 2.5457, ppl 13, time elapsed: 4104.34 seconds


 80%|█████████████████████████████████████████████████████████▋              | 801/1000 [1:09:19<20:03,  6.05s/it]

step 0800: lr 0.001295, train loss 2.7158, val loss 2.5385, ppl 13, time elapsed: 4155.11 seconds


 81%|██████████████████████████████████████████████████████████▍             | 811/1000 [1:10:10<18:59,  6.03s/it]

step 0810: lr 0.001058, train loss 2.6540, val loss 2.6579, ppl 14, time elapsed: 4205.90 seconds


 82%|███████████████████████████████████████████████████████████             | 821/1000 [1:11:01<18:04,  6.06s/it]

step 0820: lr 0.000843, train loss 2.7025, val loss 2.6444, ppl 14, time elapsed: 4256.73 seconds


 83%|███████████████████████████████████████████████████████████▊            | 831/1000 [1:11:52<17:09,  6.09s/it]

step 0830: lr 0.000650, train loss 2.6434, val loss 2.6997, ppl 15, time elapsed: 4307.79 seconds


 84%|████████████████████████████████████████████████████████████▌           | 841/1000 [1:12:43<16:10,  6.10s/it]

step 0840: lr 0.000480, train loss 2.6827, val loss 2.6210, ppl 14, time elapsed: 4359.14 seconds


 85%|█████████████████████████████████████████████████████████████▎          | 851/1000 [1:13:34<14:50,  5.98s/it]

step 0850: lr 0.000335, train loss 2.6249, val loss 2.6165, ppl 14, time elapsed: 4409.53 seconds


 86%|█████████████████████████████████████████████████████████████▉          | 861/1000 [1:14:25<14:02,  6.06s/it]

step 0860: lr 0.000215, train loss 2.6681, val loss 2.7036, ppl 15, time elapsed: 4460.65 seconds


 87%|██████████████████████████████████████████████████████████████▋         | 871/1000 [1:15:16<13:02,  6.07s/it]

step 0870: lr 0.000122, train loss 2.6482, val loss 2.6231, ppl 14, time elapsed: 4511.58 seconds


 88%|███████████████████████████████████████████████████████████████▍        | 881/1000 [1:16:07<11:54,  6.00s/it]

step 0880: lr 0.000054, train loss 2.6391, val loss 2.6838, ppl 15, time elapsed: 4562.35 seconds


 89%|████████████████████████████████████████████████████████████████▏       | 891/1000 [1:16:57<10:52,  5.98s/it]

step 0890: lr 0.000014, train loss 2.6404, val loss 2.6373, ppl 14, time elapsed: 4612.89 seconds


 90%|████████████████████████████████████████████████████████████████▊       | 901/1000 [1:17:50<10:31,  6.38s/it]

step 0900: lr 0.000000, train loss 2.6356, val loss 2.6181, ppl 14, time elapsed: 4665.49 seconds


 91%|█████████████████████████████████████████████████████████████████▌      | 911/1000 [1:18:43<09:24,  6.35s/it]

step 0910: lr 0.000000, train loss 2.6283, val loss 2.5455, ppl 13, time elapsed: 4718.86 seconds


 92%|██████████████████████████████████████████████████████████████████▎     | 921/1000 [1:19:37<08:23,  6.38s/it]

step 0920: lr 0.000000, train loss 2.6396, val loss 2.7615, ppl 16, time elapsed: 4772.13 seconds


 93%|███████████████████████████████████████████████████████████████████     | 931/1000 [1:20:30<07:17,  6.34s/it]

step 0930: lr 0.000000, train loss 2.7587, val loss 2.6379, ppl 14, time elapsed: 4825.58 seconds


 94%|███████████████████████████████████████████████████████████████████▊    | 941/1000 [1:21:24<06:19,  6.44s/it]

step 0940: lr 0.000000, train loss 2.5930, val loss 2.5673, ppl 13, time elapsed: 4879.59 seconds


 95%|████████████████████████████████████████████████████████████████████▍   | 951/1000 [1:22:17<05:11,  6.36s/it]

step 0950: lr 0.000000, train loss 2.7526, val loss 2.6536, ppl 14, time elapsed: 4932.96 seconds


 96%|█████████████████████████████████████████████████████████████████████▏  | 961/1000 [1:23:11<03:57,  6.08s/it]

step 0960: lr 0.000000, train loss 2.5861, val loss 2.6601, ppl 14, time elapsed: 4986.65 seconds


 97%|█████████████████████████████████████████████████████████████████████▉  | 971/1000 [1:24:01<02:54,  6.02s/it]

step 0970: lr 0.000000, train loss 2.6336, val loss 2.7520, ppl 16, time elapsed: 5037.19 seconds


 98%|██████████████████████████████████████████████████████████████████████▋ | 981/1000 [1:24:52<01:53,  5.97s/it]

step 0980: lr 0.000000, train loss 2.6580, val loss 2.8351, ppl 17, time elapsed: 5087.55 seconds


 99%|███████████████████████████████████████████████████████████████████████▎| 991/1000 [1:25:43<00:54,  6.00s/it]

step 0990: lr 0.000000, train loss 2.8126, val loss 2.6469, ppl 14, time elapsed: 5138.50 seconds


100%|███████████████████████████████████████████████████████████████████████| 1000/1000 [1:26:28<00:00,  5.19s/it]

step 0999: lr 0.000000, train loss 2.5850, val loss 2.6339, ppl 14, time elapsed: 5184.22 seconds





# Inference test before you decide to save it
If `tcfg.checkpoint_interval` is not `None`, then checkpoints have already been saved.

In [6]:
from inference import generate
prompt = "Once upon a time"

# Generate in eval mode and always switch back to train mode afterwards — even if
# generate() raises — so a failed generation cannot leave the model in eval mode
# for a later run of the training cell.
model.eval()
try:
    output = generate(
        prompt,
        model,
        tokenizer,
        #max_gen_len = 512,
        temperature = 0.7,
        #memory_saver_div = 8,
        #top_p = 0.9,
        #top_k = 32,
    )
finally:
    model.train()
print(output)

Once upon a time, there was a pie to coloved Tim. The little girl named Tim. She was very happy. She would mom saw a little girl named Tim had a big mom. The mom was not good a little book, a brel decided to be be. It was very happy and it was not not earned. She wanted to like to man said, "I am always a scar.
She was so excited to the red and a and on the tople. The bird flew that down. The bird saw down and said, "I have did not in the wise the mom. The mom and the pretty becared and found a little. They are so happy and played with a a gree!" Max and went to the snail and you want to knew she longs. The dress was very good all was very happy. They played with the more together every day. He played Lily was so happy other and from the very sad.
Sue was very happy and it was not know what the wish.
The car your was asked to the ground and go the way the in the park. The cat and down. They were toon to make they was speciall and Lily.
A man said disappogether sking. She saybe were all

# Saving your final model
You DO still need to do this even if you have been saving checkpoints: the final state has not been saved yet.

In [7]:
from tools import save_model
# Save the final model, passing along both config objects and the log returned by train().
save_model(model, cfg, tcfg, log_data)