# Setup

In [1]:
# my virtual environments are rarely properly connected to jupyter so this fixes that. 
# you won't need this cell but running it won't hurt anything either
import sys
import os
current_dir = os.getcwd()  # Get the current working directory
venv_dir = os.path.join(current_dir, './venv') 
python_version = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')
sys.path.append(site_packages_path) 

# Instantiate a brand new model

In [2]:
# tokenizer
from tokenizer import get_tokenizer
tokenizer = get_tokenizer(size = 2048) # size options are 95(character-wise), 128, 256, 512, 1024, 2048 & 4096

# config file
from config import ModelConfig, TrainConfig
cfg = ModelConfig()
cfg.vocab_len = tokenizer.vocab_len
tcfg = TrainConfig()
print(cfg, '\n\n', tcfg)

# model modules
from model import customGPT
model = customGPT(cfg).to(cfg.device)

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters\n')
print(model)

ModelConfig(dim=64, vocab_len=2051, device='cpu', num_layers=10, second_resid_norm=False, mlp_hidden_mult=1, mlp_bias=False, mlp_nonlinearity='SiLU', mlp_gated=True, num_q_heads=10, num_kv_heads=2, head_dim=16, theta=10000, max_seq_len=512, scale_first_resid=True, norm_type='RMSNorm', norm_affine=True, norm_bias=True, eps=1e-06, max_batch_size=1) 

 TrainConfig(weight_decay=0.02, batch_size=32, max_iters=1000, eval_interval=10, eval_samples=1, checkpoint_interval=None, lr_max=0.1, lr_min=1e-06, warmup_iters=100, final_flat_iters=100, anneal_type='cos', num_restarts=3, T_mult=2)
502.592 K parameters

customGPT(
  (token_embedder): Embedding(2051, 64)
  (layers): ModuleList(
    (0-9): 10 x ResidualLayer(
      (pre_attn_norm): Norm()
      (attn): MQSA(
        (Wq): Linear(in_features=64, out_features=160, bias=False)
        (Wk): Linear(in_features=64, out_features=32, bias=False)
        (Wv): Linear(in_features=64, out_features=32, bias=False)
        (Wo): Linear(in_features=160, 

# Training

In [3]:
import torch
from tools import get_data_loader
from train import scheduler_lambda, train

optimizer = torch.optim.AdamW(model.parameters(), lr = tcfg.lr_max, weight_decay = tcfg.weight_decay)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_lambda)

train_data_loader = get_data_loader(batch_size=tcfg.batch_size, split='train')
test_data_loader = get_data_loader(batch_size=tcfg.batch_size, split='validation')

In [4]:
if False: # set to true if you'd like to see a graph of the learning rate schedule
    import matplotlib.pyplot as plt
    
    # Generate learning rate values
    lrs = [scheduler_lambda(i) for i in range(tcfg.max_iters)]
    
    # Plot the learning rates
    plt.figure(figsize=(10, 5))
    plt.plot(lrs, label='Learning Rate')
    plt.title('Learning Rate Schedule')
    plt.xlabel('Iteration')
    plt.ylabel('Learning Rate')
    plt.grid(True)
    plt.legend()
    plt.show()

In [4]:
model, optimizer, log_data = train(
    model, 
    tokenizer, 
    cfg, 
    optimizer,
    scheduler,
    tcfg, 
    train_data_loader,
    test_data_loader,
    #log_data: list = None, 
    #detect_anomoly = False # use if you're getting crazy errors about a the gradient being broken
)

  0%|                                                                          | 1/1000 [00:20<5:42:40, 20.58s/it]

step 0000: lr 0.000000, train loss 63.9337, val loss 64.2233, ppl 7794836105729155934864801792, time elapsed: 9.98 seconds


  1%|▊                                                                        | 11/1000 [02:10<3:05:26, 11.25s/it]

step 0010: lr 0.001000, train loss 61.9759, val loss 61.7167, ppl 635625584866842491603648512, time elapsed: 121.95 seconds


  2%|█▌                                                                       | 21/1000 [03:36<2:44:37, 10.09s/it]

step 0020: lr 0.002000, train loss 28.6671, val loss 28.6723, ppl 2832833183744, time elapsed: 208.19 seconds


  3%|██▎                                                                      | 31/1000 [04:58<2:40:20,  9.93s/it]

step 0030: lr 0.003000, train loss 16.8761, val loss 16.9013, ppl 21884830, time elapsed: 291.05 seconds


  4%|██▉                                                                      | 41/1000 [06:24<2:36:11,  9.77s/it]

step 0040: lr 0.004000, train loss 12.2021, val loss 12.3412, ppl 228930, time elapsed: 375.68 seconds


  5%|███▋                                                                     | 51/1000 [07:41<2:25:22,  9.19s/it]

step 0050: lr 0.005000, train loss 9.0983, val loss 9.1421, ppl 9340, time elapsed: 453.86 seconds


  6%|████▍                                                                    | 61/1000 [08:57<2:23:35,  9.18s/it]

step 0060: lr 0.006000, train loss 7.0627, val loss 7.2192, ppl 1365, time elapsed: 530.33 seconds


  7%|█████▏                                                                   | 71/1000 [10:15<2:24:20,  9.32s/it]

step 0070: lr 0.007000, train loss 6.1129, val loss 6.1202, ppl 455, time elapsed: 608.10 seconds


  8%|█████▉                                                                   | 81/1000 [11:35<2:26:57,  9.59s/it]

step 0080: lr 0.008000, train loss 5.6825, val loss 5.5900, ppl 268, time elapsed: 688.49 seconds


  9%|██████▋                                                                  | 91/1000 [12:52<2:17:42,  9.09s/it]

step 0090: lr 0.009000, train loss 5.2627, val loss 5.2000, ppl 181, time elapsed: 765.00 seconds


 10%|███████▎                                                                | 101/1000 [14:08<2:19:03,  9.28s/it]

step 0100: lr 0.010000, train loss 5.0429, val loss 5.0012, ppl 149, time elapsed: 841.50 seconds


 11%|███████▉                                                                | 111/1000 [15:25<2:15:39,  9.16s/it]

step 0110: lr 0.009157, train loss 4.8401, val loss 4.7502, ppl 116, time elapsed: 917.62 seconds


 12%|████████▋                                                               | 121/1000 [16:41<2:14:26,  9.18s/it]

step 0120: lr 0.006913, train loss 4.6308, val loss 4.6916, ppl 109, time elapsed: 993.61 seconds


 13%|█████████▍                                                              | 131/1000 [17:56<2:11:45,  9.10s/it]

step 0130: lr 0.004025, train loss 4.5274, val loss 4.4709, ppl 87, time elapsed: 1069.45 seconds


 14%|██████████▏                                                             | 141/1000 [19:12<2:09:01,  9.01s/it]

step 0140: lr 0.001465, train loss 4.4715, val loss 4.4377, ppl 85, time elapsed: 1144.79 seconds


 15%|██████████▊                                                             | 151/1000 [20:27<2:07:49,  9.03s/it]

step 0150: lr 0.000096, train loss 4.5046, val loss 4.5000, ppl 90, time elapsed: 1220.24 seconds


 16%|███████████▌                                                            | 161/1000 [21:44<2:07:30,  9.12s/it]

step 0160: lr 0.009904, train loss 4.4243, val loss 4.3826, ppl 80, time elapsed: 1296.60 seconds


 17%|████████████▎                                                           | 171/1000 [22:59<2:05:12,  9.06s/it]

step 0170: lr 0.009410, train loss 4.3896, val loss 4.4327, ppl 84, time elapsed: 1372.40 seconds


 18%|█████████████                                                           | 181/1000 [24:15<2:02:33,  8.98s/it]

step 0180: lr 0.008536, train loss 4.2756, val loss 4.2951, ppl 73, time elapsed: 1447.90 seconds


 19%|█████████████▊                                                          | 191/1000 [25:30<2:01:32,  9.01s/it]

step 0190: lr 0.007357, train loss 4.0393, val loss 4.1571, ppl 64, time elapsed: 1523.14 seconds


 20%|██████████████▍                                                         | 201/1000 [26:46<2:02:13,  9.18s/it]

step 0200: lr 0.005975, train loss 3.9964, val loss 4.1051, ppl 61, time elapsed: 1598.80 seconds


 21%|███████████████▏                                                        | 211/1000 [28:03<2:01:28,  9.24s/it]

step 0210: lr 0.004510, train loss 3.9647, val loss 3.9477, ppl 52, time elapsed: 1675.70 seconds


 22%|███████████████▉                                                        | 221/1000 [29:19<1:58:20,  9.11s/it]

step 0220: lr 0.003087, train loss 3.9985, val loss 3.9353, ppl 51, time elapsed: 1751.83 seconds


 23%|████████████████▋                                                       | 231/1000 [30:35<1:56:03,  9.05s/it]

step 0230: lr 0.001828, train loss 3.8604, val loss 3.9716, ppl 53, time elapsed: 1827.97 seconds


 24%|█████████████████▎                                                      | 241/1000 [31:50<1:54:13,  9.03s/it]

step 0240: lr 0.000843, train loss 3.8057, val loss 3.9341, ppl 51, time elapsed: 1903.05 seconds


 25%|██████████████████                                                      | 251/1000 [33:06<1:52:56,  9.05s/it]

step 0250: lr 0.000215, train loss 4.0078, val loss 4.0754, ppl 59, time elapsed: 1979.48 seconds


 26%|██████████████████▊                                                     | 261/1000 [34:21<1:51:08,  9.02s/it]

step 0260: lr 0.000000, train loss 3.8641, val loss 3.9758, ppl 53, time elapsed: 2054.46 seconds


 27%|███████████████████▌                                                    | 271/1000 [35:36<1:49:39,  9.02s/it]

step 0270: lr 0.009946, train loss 3.7748, val loss 3.9478, ppl 52, time elapsed: 2129.56 seconds


 28%|████████████████████▏                                                   | 281/1000 [36:51<1:47:59,  9.01s/it]

step 0280: lr 0.009785, train loss 3.7094, val loss 3.7897, ppl 44, time elapsed: 2204.49 seconds


 29%|████████████████████▉                                                   | 291/1000 [38:06<1:45:25,  8.92s/it]

step 0290: lr 0.009520, train loss 3.6159, val loss 3.6493, ppl 38, time elapsed: 2278.89 seconds


 30%|█████████████████████▋                                                  | 301/1000 [39:20<1:45:21,  9.04s/it]

step 0300: lr 0.009157, train loss 3.6784, val loss 3.6521, ppl 39, time elapsed: 2353.66 seconds


 31%|██████████████████████▍                                                 | 311/1000 [40:35<1:42:30,  8.93s/it]

step 0310: lr 0.008705, train loss 3.6139, val loss 3.5993, ppl 37, time elapsed: 2428.19 seconds


 32%|███████████████████████                                                 | 321/1000 [41:50<1:40:43,  8.90s/it]

step 0320: lr 0.008172, train loss 3.5047, val loss 3.4639, ppl 32, time elapsed: 2503.08 seconds


 33%|███████████████████████▊                                                | 331/1000 [43:05<1:40:44,  9.03s/it]

step 0330: lr 0.007571, train loss 3.3933, val loss 3.4347, ppl 31, time elapsed: 2578.39 seconds


 34%|████████████████████████▌                                               | 341/1000 [44:19<1:38:17,  8.95s/it]

step 0340: lr 0.006913, train loss 3.4478, val loss 3.4517, ppl 32, time elapsed: 2652.57 seconds


 35%|█████████████████████████▎                                              | 351/1000 [45:35<1:39:20,  9.18s/it]

step 0350: lr 0.006215, train loss 3.4624, val loss 3.3461, ppl 28, time elapsed: 2727.93 seconds


 36%|█████████████████████████▉                                              | 361/1000 [46:50<1:35:52,  9.00s/it]

step 0360: lr 0.005490, train loss 3.3698, val loss 3.3798, ppl 29, time elapsed: 2803.44 seconds


 37%|██████████████████████████▋                                             | 371/1000 [48:05<1:33:45,  8.94s/it]

step 0370: lr 0.004755, train loss 3.3326, val loss 3.3261, ppl 28, time elapsed: 2878.30 seconds


 38%|███████████████████████████▍                                            | 381/1000 [49:19<1:31:58,  8.92s/it]

step 0380: lr 0.004025, train loss 3.3871, val loss 3.3368, ppl 28, time elapsed: 2952.78 seconds


 39%|████████████████████████████▏                                           | 391/1000 [50:34<1:31:00,  8.97s/it]

step 0390: lr 0.003316, train loss 3.3832, val loss 3.2956, ppl 27, time elapsed: 3027.28 seconds


 40%|████████████████████████████▊                                           | 401/1000 [51:48<1:28:56,  8.91s/it]

step 0400: lr 0.002643, train loss 3.2785, val loss 3.2932, ppl 27, time elapsed: 3101.23 seconds


 41%|█████████████████████████████▌                                          | 411/1000 [53:03<1:27:40,  8.93s/it]

step 0410: lr 0.002022, train loss 3.2899, val loss 3.3440, ppl 28, time elapsed: 3175.89 seconds


 42%|██████████████████████████████▎                                         | 421/1000 [54:17<1:26:55,  9.01s/it]

step 0420: lr 0.001465, train loss 3.3149, val loss 3.2546, ppl 26, time elapsed: 3250.47 seconds


 43%|███████████████████████████████                                         | 431/1000 [55:31<1:24:40,  8.93s/it]

step 0430: lr 0.000984, train loss 3.3082, val loss 3.3306, ppl 28, time elapsed: 3324.83 seconds


 44%|███████████████████████████████▊                                        | 441/1000 [56:47<1:26:41,  9.30s/it]

step 0440: lr 0.000590, train loss 3.3013, val loss 3.2155, ppl 25, time elapsed: 3400.72 seconds


 45%|████████████████████████████████▍                                       | 451/1000 [58:02<1:22:06,  8.97s/it]

step 0450: lr 0.000292, train loss 3.2731, val loss 3.2230, ppl 25, time elapsed: 3475.26 seconds


 46%|█████████████████████████████████▏                                      | 461/1000 [59:18<1:20:49,  9.00s/it]

step 0460: lr 0.000096, train loss 3.1465, val loss 3.1981, ppl 24, time elapsed: 3550.83 seconds


 47%|████████████████████████████████▉                                     | 471/1000 [1:00:31<1:18:24,  8.89s/it]

step 0470: lr 0.000006, train loss 3.2498, val loss 3.2400, ppl 26, time elapsed: 3624.63 seconds


 48%|█████████████████████████████████▋                                    | 481/1000 [1:01:45<1:16:44,  8.87s/it]

step 0480: lr 0.009994, train loss 3.2595, val loss 3.4002, ppl 30, time elapsed: 3698.55 seconds


 49%|██████████████████████████████████▎                                   | 491/1000 [1:02:59<1:15:54,  8.95s/it]

step 0490: lr 0.009962, train loss 3.2452, val loss 3.3183, ppl 28, time elapsed: 3772.75 seconds


 50%|███████████████████████████████████                                   | 501/1000 [1:04:14<1:14:24,  8.95s/it]

step 0500: lr 0.009904, train loss 3.2497, val loss 3.2608, ppl 26, time elapsed: 3847.43 seconds


 51%|███████████████████████████████████▊                                  | 511/1000 [1:05:29<1:13:07,  8.97s/it]

step 0510: lr 0.009819, train loss 3.3313, val loss 3.1240, ppl 23, time elapsed: 3922.35 seconds


 52%|████████████████████████████████████▍                                 | 521/1000 [1:06:43<1:11:14,  8.92s/it]

step 0520: lr 0.009708, train loss 3.2765, val loss 3.2397, ppl 26, time elapsed: 3996.68 seconds


 53%|█████████████████████████████████████▏                                | 531/1000 [1:07:59<1:10:40,  9.04s/it]

step 0530: lr 0.009571, train loss 3.1315, val loss 3.1070, ppl 22, time elapsed: 4072.45 seconds


 54%|█████████████████████████████████████▊                                | 541/1000 [1:09:15<1:09:10,  9.04s/it]

step 0540: lr 0.009410, train loss 3.1268, val loss 3.1620, ppl 24, time elapsed: 4148.31 seconds


 55%|██████████████████████████████████████▌                               | 551/1000 [1:11:52<2:22:08, 18.99s/it]

step 0550: lr 0.009224, train loss 3.1269, val loss 3.2000, ppl 25, time elapsed: 4300.84 seconds


 56%|███████████████████████████████████████▎                              | 561/1000 [1:13:45<1:32:46, 12.68s/it]

step 0560: lr 0.009016, train loss 3.1579, val loss 3.1638, ppl 24, time elapsed: 4416.88 seconds


 57%|███████████████████████████████████████▉                              | 571/1000 [1:15:34<1:30:54, 12.71s/it]

step 0570: lr 0.008786, train loss 2.9091, val loss 3.0249, ppl 21, time elapsed: 4523.74 seconds


 58%|████████████████████████████████████████▋                             | 581/1000 [1:17:03<1:08:31,  9.81s/it]

step 0580: lr 0.008536, train loss 3.1063, val loss 3.0707, ppl 22, time elapsed: 4615.95 seconds


 59%|█████████████████████████████████████████▎                            | 591/1000 [1:18:22<1:04:45,  9.50s/it]

step 0590: lr 0.008266, train loss 3.1099, val loss 3.1413, ppl 23, time elapsed: 4694.73 seconds


 60%|██████████████████████████████████████████                            | 601/1000 [1:19:41<1:03:04,  9.48s/it]

step 0600: lr 0.007979, train loss 3.0318, val loss 2.9169, ppl 18, time elapsed: 4773.63 seconds


 61%|██████████████████████████████████████████▊                           | 611/1000 [1:20:59<1:00:46,  9.37s/it]

step 0610: lr 0.007675, train loss 3.0457, val loss 2.9910, ppl 20, time elapsed: 4851.88 seconds


 62%|████████████████████████████████████████████▋                           | 621/1000 [1:22:16<58:02,  9.19s/it]

step 0620: lr 0.007357, train loss 2.9137, val loss 3.0345, ppl 21, time elapsed: 4929.01 seconds


 63%|█████████████████████████████████████████████▍                          | 631/1000 [1:23:33<56:58,  9.27s/it]

step 0630: lr 0.007026, train loss 2.9682, val loss 2.8546, ppl 17, time elapsed: 5006.13 seconds


 64%|██████████████████████████████████████████████▏                         | 641/1000 [1:24:51<55:59,  9.36s/it]

step 0640: lr 0.006684, train loss 2.9077, val loss 2.9763, ppl 20, time elapsed: 5084.06 seconds


 65%|██████████████████████████████████████████████▊                         | 651/1000 [1:26:09<53:54,  9.27s/it]

step 0650: lr 0.006334, train loss 2.9841, val loss 2.9659, ppl 19, time elapsed: 5162.07 seconds


 66%|███████████████████████████████████████████████▌                        | 661/1000 [1:27:26<52:11,  9.24s/it]

step 0660: lr 0.005975, train loss 3.0707, val loss 2.9233, ppl 19, time elapsed: 5239.48 seconds


 67%|████████████████████████████████████████████████▎                       | 671/1000 [1:28:43<50:21,  9.18s/it]

step 0670: lr 0.005612, train loss 2.8910, val loss 2.9136, ppl 18, time elapsed: 5316.74 seconds


 68%|█████████████████████████████████████████████████                       | 681/1000 [1:30:00<48:29,  9.12s/it]

step 0680: lr 0.005245, train loss 2.9430, val loss 2.8711, ppl 18, time elapsed: 5393.16 seconds


 69%|█████████████████████████████████████████████████▊                      | 691/1000 [1:31:17<47:47,  9.28s/it]

step 0690: lr 0.004877, train loss 2.9331, val loss 2.8078, ppl 17, time elapsed: 5470.55 seconds


 70%|██████████████████████████████████████████████████▍                     | 701/1000 [1:32:34<46:06,  9.25s/it]

step 0700: lr 0.004510, train loss 2.9237, val loss 2.9256, ppl 19, time elapsed: 5547.66 seconds


 71%|███████████████████████████████████████████████████▏                    | 711/1000 [1:33:51<44:11,  9.18s/it]

step 0710: lr 0.004145, train loss 2.9131, val loss 2.8724, ppl 18, time elapsed: 5624.38 seconds


 72%|███████████████████████████████████████████████████▉                    | 721/1000 [1:35:07<42:08,  9.06s/it]

step 0720: lr 0.003785, train loss 2.8561, val loss 2.8781, ppl 18, time elapsed: 5700.54 seconds


 73%|████████████████████████████████████████████████████▋                   | 731/1000 [1:36:23<40:22,  9.01s/it]

step 0730: lr 0.003432, train loss 2.8289, val loss 2.7923, ppl 16, time elapsed: 5776.02 seconds


 74%|█████████████████████████████████████████████████████▎                  | 741/1000 [1:37:39<39:49,  9.23s/it]

step 0740: lr 0.003087, train loss 2.7230, val loss 2.8010, ppl 16, time elapsed: 5852.56 seconds


 75%|██████████████████████████████████████████████████████                  | 751/1000 [1:38:56<37:53,  9.13s/it]

step 0750: lr 0.002752, train loss 2.6812, val loss 2.8855, ppl 18, time elapsed: 5928.87 seconds


 76%|██████████████████████████████████████████████████████▊                 | 761/1000 [1:40:13<36:20,  9.13s/it]

step 0760: lr 0.002430, train loss 2.8551, val loss 2.7849, ppl 16, time elapsed: 6005.87 seconds


 77%|███████████████████████████████████████████████████████▌                | 771/1000 [1:41:30<34:47,  9.11s/it]

step 0770: lr 0.002121, train loss 2.7369, val loss 2.7431, ppl 16, time elapsed: 6083.12 seconds


 78%|████████████████████████████████████████████████████████▏               | 781/1000 [1:43:01<42:50, 11.74s/it]

step 0780: lr 0.001828, train loss 2.8456, val loss 2.7877, ppl 16, time elapsed: 6172.58 seconds


 79%|████████████████████████████████████████████████████████▉               | 791/1000 [1:44:46<35:10, 10.10s/it]

step 0790: lr 0.001552, train loss 2.7264, val loss 2.7575, ppl 16, time elapsed: 6278.89 seconds


 80%|█████████████████████████████████████████████████████████▋              | 801/1000 [1:46:04<31:02,  9.36s/it]

step 0800: lr 0.001295, train loss 2.7883, val loss 2.7797, ppl 16, time elapsed: 6357.40 seconds


 81%|██████████████████████████████████████████████████████████▍             | 811/1000 [1:47:34<34:49, 11.06s/it]

step 0810: lr 0.001058, train loss 2.7587, val loss 2.7402, ppl 15, time elapsed: 6445.33 seconds


 82%|███████████████████████████████████████████████████████████             | 821/1000 [1:49:08<33:06, 11.10s/it]

step 0820: lr 0.000843, train loss 2.6745, val loss 2.7459, ppl 16, time elapsed: 6540.97 seconds


 83%|███████████████████████████████████████████████████████████▊            | 831/1000 [1:50:29<27:07,  9.63s/it]

step 0830: lr 0.000650, train loss 2.6476, val loss 2.7153, ppl 15, time elapsed: 6621.91 seconds


 84%|████████████████████████████████████████████████████████████▌           | 841/1000 [1:51:50<25:28,  9.61s/it]

step 0840: lr 0.000480, train loss 2.7007, val loss 2.6882, ppl 15, time elapsed: 6702.52 seconds


 85%|█████████████████████████████████████████████████████████████▎          | 851/1000 [1:53:11<24:04,  9.70s/it]

step 0850: lr 0.000335, train loss 2.8432, val loss 2.7036, ppl 15, time elapsed: 6783.62 seconds


 86%|█████████████████████████████████████████████████████████████▉          | 861/1000 [1:54:47<30:07, 13.00s/it]

step 0860: lr 0.000215, train loss 2.6306, val loss 2.6856, ppl 15, time elapsed: 6874.18 seconds


 87%|██████████████████████████████████████████████████████████████▋         | 871/1000 [1:56:15<20:43,  9.64s/it]

step 0870: lr 0.000122, train loss 2.7258, val loss 2.7118, ppl 15, time elapsed: 6967.67 seconds


 88%|███████████████████████████████████████████████████████████████▍        | 881/1000 [1:57:32<18:16,  9.21s/it]

step 0880: lr 0.000054, train loss 2.6864, val loss 2.7004, ppl 15, time elapsed: 7045.03 seconds


 89%|████████████████████████████████████████████████████████████████▏       | 891/1000 [1:58:48<16:34,  9.13s/it]

step 0890: lr 0.000014, train loss 2.8017, val loss 2.7634, ppl 16, time elapsed: 7121.57 seconds


 90%|████████████████████████████████████████████████████████████████▊       | 901/1000 [2:00:06<15:26,  9.36s/it]

step 0900: lr 0.000000, train loss 2.5560, val loss 2.8118, ppl 17, time elapsed: 7199.48 seconds


 91%|█████████████████████████████████████████████████████████████████▌      | 911/1000 [2:01:24<13:46,  9.29s/it]

step 0910: lr 0.000000, train loss 2.7280, val loss 2.7879, ppl 16, time elapsed: 7276.72 seconds


 92%|██████████████████████████████████████████████████████████████████▎     | 921/1000 [2:02:41<12:10,  9.24s/it]

step 0920: lr 0.000000, train loss 2.8438, val loss 2.7801, ppl 16, time elapsed: 7353.77 seconds


 93%|███████████████████████████████████████████████████████████████████     | 931/1000 [2:04:00<10:58,  9.55s/it]

step 0930: lr 0.000000, train loss 2.7341, val loss 2.7202, ppl 15, time elapsed: 7433.01 seconds


 94%|███████████████████████████████████████████████████████████████████▊    | 941/1000 [2:05:16<08:57,  9.11s/it]

step 0940: lr 0.000000, train loss 2.7645, val loss 2.7566, ppl 16, time elapsed: 7509.59 seconds


 95%|████████████████████████████████████████████████████████████████████▍   | 951/1000 [2:06:33<07:30,  9.19s/it]

step 0950: lr 0.000000, train loss 2.7978, val loss 2.7650, ppl 16, time elapsed: 7585.91 seconds


 96%|█████████████████████████████████████████████████████████████████████▏  | 961/1000 [2:07:50<05:58,  9.19s/it]

step 0960: lr 0.000000, train loss 2.7471, val loss 2.8458, ppl 17, time elapsed: 7662.98 seconds


 97%|█████████████████████████████████████████████████████████████████████▉  | 971/1000 [2:09:06<04:23,  9.09s/it]

step 0970: lr 0.000000, train loss 2.7546, val loss 2.8407, ppl 17, time elapsed: 7739.28 seconds


 98%|██████████████████████████████████████████████████████████████████████▋ | 981/1000 [2:10:25<02:58,  9.39s/it]

step 0980: lr 0.000000, train loss 2.8420, val loss 2.7914, ppl 16, time elapsed: 7817.49 seconds


 99%|███████████████████████████████████████████████████████████████████████▎| 991/1000 [2:11:44<01:23,  9.29s/it]

step 0990: lr 0.000000, train loss 2.7034, val loss 2.7596, ppl 16, time elapsed: 7897.30 seconds


100%|███████████████████████████████████████████████████████████████████████| 1000/1000 [2:12:55<00:00,  7.98s/it]

step 0999: lr 0.000000, train loss 2.8573, val loss 2.6890, ppl 15, time elapsed: 7967.83 seconds





# inference test before you decide to save it
if `tcfg.checkpoint_interval != None` then checkpoints have already been saved

In [5]:
from inference import generate
prompt = "Once upon a time"
model.eval()
output = generate(
    prompt, 
    model, 
    tokenizer,
    #max_gen_len = 512,
    temperature = 0.7,
    #memory_saver_div = 8,
    #top_p = 0.9,
    #top_k = 32,
)
model.train()
print(output)

Once upon a time, there was a little boy named Tim. Tim was a loved in a big buse. One day, the dog very Lily in the park. The bird saw a big ball of the were very out not dog. They looked to fious with the balk. The dog was so happy. It was started to play with a like to the caf. The cat was a slide would not toy lose is sk. The dog was not pick of a lose of the made all the care. They were happy for the The bird went for a moused to the clostil. The little water and happy and the bird. The little girl and said, "I have went of the was not man her friends, and we collower ain the good to ted and in the dirl for the rung and Tim was something the ran and in the win, and made the fly.
The wind a fun, and were happy and they more a on her.


# Saving your final model
you DO still need to do this even if you had been saving checkpoints; the final state has not yet been saved

In [6]:
from tools import save_model
save_model(model, cfg, tcfg, log_data)