# Setup

In [1]:
# my virtual environments are rarely properly connected to jupyter so this fixes that. 
# you prolly won't need this cell but running it won't hurt anything either
import sys
import os
current_dir = os.getcwd()  # Get the current working directory
venv_dir = os.path.join(current_dir, 'venv') 
python_version = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')
sys.path.append(site_packages_path) 

# Instantiate a brand new model

In [2]:
# config file
from config import ModelConfig, TrainConfig
cfg = ModelConfig()
tcfg = TrainConfig()
print(cfg, '\n\n', tcfg)

# import the tokenizer specified by cfg
from tools import import_from_nested_path
imported_objects = import_from_nested_path(['tokenizers', cfg.tokenizer], 'tokenizer', ['get_tokenizer'])
get_tokenizer = imported_objects.get('get_tokenizer')
tokenizer = get_tokenizer(size = cfg.vocab_len)

# model modules
from modules.model import Model
model = Model(cfg).to(cfg.device)

# print the number of parameters in the model
print(f'\n{sum(p.numel() for p in model.parameters())/1e3}K parameters\n')
print(model)

ModelConfig(dim=80, device='cpu', tokenizer='bpe_v1', vocab_len=8192, num_layers=4, second_resid_norm=False, num_heads=2, head_dim=40, max_seq_len=512, mm_bias=False, pmem_size=224, pmem_count=1, scale_first_resid=True, norm_type='RMSNorm', norm_affine=True, norm_bias=True, eps=1e-06) 

 TrainConfig(weight_decay=0.05, batch_size=32, max_iters=2000, eval_interval=20, eval_samples=1, checkpoint_interval=None, lr_init=1e-06, lr_max=0.1, lr_min=0.001, warmup_iters=20, final_flat_iters=400, anneal_type='cos', num_restarts=3, T_mult=2)

928.456K parameters

Model(
  (token_embedder): Embedding(8195, 80)
  (layers): ModuleList(
    (0-3): 4 x Layer(
      (pre_context_norm): Norm()
      (context): ContextMem(
        (k_featurizer): KeyFeatureExtractor(
          (W_k): Linear(in_features=80, out_features=80, bias=False)
          (leaky_avg): LeakyAvg()
        )
        (v_featurizer): ValFeatureExtractor(
          (W_v): Linear(in_features=80, out_features=80, bias=False)
        )
     

# Training

In [3]:
import torch
from tools import get_data_loader
from train import scheduler_lambda, train

optimizer = torch.optim.AdamW(model.parameters(), lr = tcfg.lr_max, weight_decay = tcfg.weight_decay)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_lambda)

train_data_loader = get_data_loader(batch_size=tcfg.batch_size, split='train')
test_data_loader = get_data_loader(batch_size=tcfg.batch_size, split='validation')

In [4]:
if False: # set to true if you'd like to see a graph of the learning rate schedule
    import matplotlib.pyplot as plt
    
    # Generate learning rate values
    lrs = [scheduler_lambda(i) for i in range(tcfg.max_iters)]
    
    # Plot the learning rates
    plt.figure(figsize=(10, 5))
    plt.plot(lrs, label='Learning Rate')
    plt.title('Learning Rate Schedule')
    plt.xlabel('Iteration')
    plt.ylabel('Learning Rate')
    plt.grid(True)
    plt.legend()
    plt.show()

In [5]:
model, optimizer, log_data = train(
    model, 
    tokenizer, 
    cfg, 
    optimizer,
    scheduler,
    tcfg, 
    train_data_loader,
    test_data_loader,
    #log_data: list = None, 
    #detect_anomoly = False # use if you're getting crazy errors about a the gradient being broken
)

  0%|                                                                         | 1/2000 [00:19<10:38:45, 19.17s/it]

step 0000: lr 0.000000, train loss 77.6600, val loss 77.4064, ppl 4141650957493128962322235104690176, time elapsed: 15.92 seconds


  1%|▊                                                                        | 21/2000 [04:18<3:10:40,  5.78s/it]

step 0020: lr 0.010000, train loss 29.7979, val loss 29.8511, ppl 9208346968064, time elapsed: 255.80 seconds


  2%|█▍                                                                       | 41/2000 [06:15<3:41:47,  6.79s/it]

step 0040: lr 0.009145, train loss 6.6940, val loss 6.7940, ppl 892, time elapsed: 372.19 seconds


  3%|██▏                                                                      | 61/2000 [08:19<3:41:00,  6.84s/it]

step 0060: lr 0.006876, train loss 4.7159, val loss 4.7831, ppl 119, time elapsed: 496.74 seconds


  4%|██▉                                                                      | 81/2000 [10:21<3:54:40,  7.34s/it]

step 0080: lr 0.003976, train loss 4.3837, val loss 4.3944, ppl 81, time elapsed: 618.34 seconds


  5%|███▋                                                                    | 101/2000 [12:54<3:58:26,  7.53s/it]

step 0100: lr 0.001447, train loss 4.3154, val loss 4.3162, ppl 75, time elapsed: 771.25 seconds


  6%|████▎                                                                   | 121/2000 [15:06<3:47:29,  7.26s/it]

step 0120: lr 0.000162, train loss 4.2784, val loss 4.2528, ppl 70, time elapsed: 903.79 seconds


  7%|█████                                                                   | 141/2000 [17:20<3:48:00,  7.36s/it]

step 0140: lr 0.009882, train loss 4.1048, val loss 4.0047, ppl 55, time elapsed: 1037.52 seconds


  8%|█████▊                                                                  | 161/2000 [19:29<3:37:13,  7.09s/it]

step 0160: lr 0.009353, train loss 3.9474, val loss 3.9009, ppl 49, time elapsed: 1166.76 seconds


  9%|██████▌                                                                 | 181/2000 [22:25<5:50:16, 11.55s/it]

step 0180: lr 0.008444, train loss 3.8301, val loss 3.7774, ppl 44, time elapsed: 1339.84 seconds


 10%|███████▏                                                                | 201/2000 [24:16<2:48:26,  5.62s/it]

step 0200: lr 0.007236, train loss 3.5793, val loss 3.6239, ppl 37, time elapsed: 1453.83 seconds


 11%|███████▉                                                                | 221/2000 [25:52<2:47:07,  5.64s/it]

step 0220: lr 0.005834, train loss 3.5448, val loss 3.3940, ppl 30, time elapsed: 1549.80 seconds


 12%|████████▋                                                               | 241/2000 [27:27<2:40:28,  5.47s/it]

step 0240: lr 0.004363, train loss 3.4268, val loss 3.4375, ppl 31, time elapsed: 1644.34 seconds


 13%|█████████▍                                                              | 261/2000 [29:02<2:39:12,  5.49s/it]

step 0260: lr 0.002953, train loss 3.2839, val loss 3.3350, ppl 28, time elapsed: 1739.58 seconds


 14%|██████████                                                              | 281/2000 [30:38<2:40:43,  5.61s/it]

step 0280: lr 0.001728, train loss 3.1877, val loss 3.1962, ppl 24, time elapsed: 1835.89 seconds


 15%|██████████▊                                                             | 301/2000 [32:15<2:35:24,  5.49s/it]

step 0300: lr 0.000796, train loss 3.2071, val loss 3.2412, ppl 26, time elapsed: 1932.33 seconds


 16%|███████████▌                                                            | 321/2000 [33:49<2:28:17,  5.30s/it]

step 0320: lr 0.000240, train loss 3.0882, val loss 3.1790, ppl 24, time elapsed: 2026.37 seconds


 17%|████████████▎                                                           | 341/2000 [35:25<2:32:55,  5.53s/it]

step 0340: lr 0.009998, train loss 3.1156, val loss 3.0421, ppl 21, time elapsed: 2122.42 seconds


 18%|████████████▉                                                           | 361/2000 [36:58<2:28:05,  5.42s/it]

step 0360: lr 0.009921, train loss 2.9126, val loss 2.9449, ppl 19, time elapsed: 2215.54 seconds


 19%|█████████████▋                                                          | 381/2000 [38:31<2:24:40,  5.36s/it]

step 0380: lr 0.009736, train loss 2.9056, val loss 2.8764, ppl 18, time elapsed: 2308.79 seconds


 20%|██████████████▍                                                         | 401/2000 [39:59<2:14:19,  5.04s/it]

step 0400: lr 0.009447, train loss 2.7024, val loss 2.7805, ppl 16, time elapsed: 2396.92 seconds


 21%|███████████████▏                                                        | 421/2000 [41:26<2:11:02,  4.98s/it]

step 0420: lr 0.009060, train loss 2.6708, val loss 2.6252, ppl 14, time elapsed: 2483.74 seconds


 22%|███████████████▉                                                        | 441/2000 [42:52<2:10:43,  5.03s/it]

step 0440: lr 0.008585, train loss 2.5085, val loss 2.5735, ppl 13, time elapsed: 2569.63 seconds


 23%|████████████████▌                                                       | 461/2000 [44:17<2:05:55,  4.91s/it]

step 0460: lr 0.008031, train loss 2.6321, val loss 2.5414, ppl 13, time elapsed: 2655.09 seconds


 24%|█████████████████▎                                                      | 481/2000 [45:43<2:05:48,  4.97s/it]

step 0480: lr 0.007410, train loss 2.5225, val loss 2.4848, ppl 12, time elapsed: 2740.52 seconds


 25%|██████████████████                                                      | 501/2000 [47:11<2:06:47,  5.08s/it]

step 0500: lr 0.006738, train loss 2.6169, val loss 2.5361, ppl 13, time elapsed: 2828.22 seconds


 26%|██████████████████▊                                                     | 521/2000 [48:40<2:06:13,  5.12s/it]

step 0520: lr 0.006028, train loss 2.4507, val loss 2.4049, ppl 11, time elapsed: 2917.40 seconds


 27%|███████████████████▍                                                    | 541/2000 [50:08<2:08:18,  5.28s/it]

step 0540: lr 0.005296, train loss 2.4298, val loss 2.3953, ppl 11, time elapsed: 3005.84 seconds


 28%|████████████████████▏                                                   | 561/2000 [51:39<2:07:08,  5.30s/it]

step 0560: lr 0.004559, train loss 2.3541, val loss 2.4256, ppl 11, time elapsed: 3097.14 seconds


 29%|████████████████████▉                                                   | 581/2000 [53:10<2:04:49,  5.28s/it]

step 0580: lr 0.003832, train loss 2.4193, val loss 2.4552, ppl 12, time elapsed: 3187.67 seconds


 30%|█████████████████████▋                                                  | 601/2000 [54:38<1:57:22,  5.03s/it]

step 0600: lr 0.003133, train loss 2.3620, val loss 2.2903, ppl 10, time elapsed: 3275.37 seconds


 31%|██████████████████████▎                                                 | 621/2000 [56:06<1:55:39,  5.03s/it]

step 0620: lr 0.002476, train loss 2.3108, val loss 2.4266, ppl 11, time elapsed: 3363.79 seconds


 32%|███████████████████████                                                 | 641/2000 [57:36<1:53:46,  5.02s/it]

step 0640: lr 0.001877, train loss 2.3082, val loss 2.3792, ppl 11, time elapsed: 3453.56 seconds


 33%|███████████████████████▊                                                | 661/2000 [59:05<1:47:58,  4.84s/it]

step 0660: lr 0.001347, train loss 2.3220, val loss 2.2192, ppl 9, time elapsed: 3542.64 seconds


 34%|███████████████████████▊                                              | 681/2000 [1:00:31<1:47:03,  4.87s/it]

step 0680: lr 0.000900, train loss 2.3345, val loss 2.2974, ppl 10, time elapsed: 3629.35 seconds


 35%|████████████████████████▌                                             | 701/2000 [1:01:56<1:44:23,  4.82s/it]

step 0700: lr 0.000545, train loss 2.2632, val loss 2.2396, ppl 9, time elapsed: 3713.58 seconds


 36%|█████████████████████████▏                                            | 721/2000 [1:03:20<1:46:41,  5.01s/it]

step 0720: lr 0.000291, train loss 2.2248, val loss 2.2769, ppl 10, time elapsed: 3798.43 seconds


 37%|█████████████████████████▉                                            | 741/2000 [1:04:46<1:44:41,  4.99s/it]

step 0740: lr 0.000141, train loss 2.3429, val loss 2.2445, ppl 9, time elapsed: 3883.34 seconds


 38%|██████████████████████████▋                                           | 761/2000 [1:06:12<1:41:20,  4.91s/it]

step 0760: lr 0.010000, train loss 2.4129, val loss 2.3616, ppl 11, time elapsed: 3970.11 seconds


 39%|███████████████████████████▎                                          | 781/2000 [1:07:34<1:31:43,  4.51s/it]

step 0780: lr 0.009982, train loss 2.2956, val loss 2.2429, ppl 9, time elapsed: 4051.58 seconds


 40%|████████████████████████████                                          | 801/2000 [1:08:48<1:25:10,  4.26s/it]

step 0800: lr 0.009938, train loss 2.1974, val loss 2.2602, ppl 10, time elapsed: 4125.81 seconds


 41%|████████████████████████████▋                                         | 821/2000 [1:09:58<1:21:16,  4.14s/it]

step 0820: lr 0.009866, train loss 2.2565, val loss 2.2942, ppl 10, time elapsed: 4196.27 seconds


 42%|█████████████████████████████▍                                        | 841/2000 [1:11:06<1:15:14,  3.89s/it]

step 0840: lr 0.009767, train loss 2.2053, val loss 2.2472, ppl 9, time elapsed: 4264.49 seconds


 43%|██████████████████████████████▏                                       | 861/2000 [1:12:13<1:15:09,  3.96s/it]

step 0860: lr 0.009642, train loss 2.1466, val loss 2.0651, ppl 8, time elapsed: 4330.55 seconds


 44%|██████████████████████████████▊                                       | 881/2000 [1:13:17<1:13:29,  3.94s/it]

step 0880: lr 0.009491, train loss 2.1102, val loss 2.2075, ppl 9, time elapsed: 4395.55 seconds


 45%|███████████████████████████████▌                                      | 901/2000 [1:14:22<1:10:08,  3.83s/it]

step 0900: lr 0.009316, train loss 2.1198, val loss 2.0672, ppl 8, time elapsed: 4460.05 seconds


 46%|████████████████████████████████▏                                     | 921/2000 [1:15:27<1:08:22,  3.80s/it]

step 0920: lr 0.009117, train loss 2.0197, val loss 2.0517, ppl 8, time elapsed: 4524.97 seconds


 47%|████████████████████████████████▉                                     | 941/2000 [1:16:32<1:11:17,  4.04s/it]

step 0940: lr 0.008896, train loss 2.0614, val loss 2.1090, ppl 8, time elapsed: 4589.30 seconds


 48%|█████████████████████████████████▋                                    | 961/2000 [1:17:37<1:05:25,  3.78s/it]

step 0960: lr 0.008653, train loss 2.0201, val loss 1.9974, ppl 7, time elapsed: 4654.65 seconds


 49%|██████████████████████████████████▎                                   | 981/2000 [1:18:40<1:02:51,  3.70s/it]

step 0980: lr 0.008390, train loss 2.1277, val loss 2.0259, ppl 8, time elapsed: 4718.32 seconds


 50%|██████████████████████████████████▌                                  | 1001/2000 [1:19:44<1:02:12,  3.74s/it]

step 1000: lr 0.008109, train loss 2.0088, val loss 1.9717, ppl 7, time elapsed: 4782.06 seconds


 51%|███████████████████████████████████▏                                 | 1021/2000 [1:20:49<1:02:31,  3.83s/it]

step 1020: lr 0.007810, train loss 1.9405, val loss 1.9435, ppl 7, time elapsed: 4847.16 seconds


 52%|███████████████████████████████████▉                                 | 1041/2000 [1:21:54<1:01:18,  3.84s/it]

step 1040: lr 0.007497, train loss 1.9813, val loss 1.9767, ppl 7, time elapsed: 4912.23 seconds


 53%|█████████████████████████████████████▋                                 | 1061/2000 [1:22:58<59:14,  3.79s/it]

step 1060: lr 0.007169, train loss 1.9437, val loss 2.0429, ppl 8, time elapsed: 4976.29 seconds


 54%|██████████████████████████████████████▍                                | 1081/2000 [1:24:02<57:54,  3.78s/it]

step 1080: lr 0.006830, train loss 1.9185, val loss 2.0237, ppl 8, time elapsed: 5040.68 seconds


 55%|███████████████████████████████████████                                | 1101/2000 [1:25:07<55:17,  3.69s/it]

step 1100: lr 0.006481, train loss 1.9770, val loss 1.9571, ppl 7, time elapsed: 5104.82 seconds


 56%|███████████████████████████████████████▊                               | 1121/2000 [1:26:10<54:33,  3.72s/it]

step 1120: lr 0.006124, train loss 1.9786, val loss 1.9435, ppl 7, time elapsed: 5168.46 seconds


 57%|████████████████████████████████████████▌                              | 1141/2000 [1:27:14<55:04,  3.85s/it]

step 1140: lr 0.005761, train loss 1.9832, val loss 1.8988, ppl 7, time elapsed: 5232.36 seconds


 58%|█████████████████████████████████████████▏                             | 1161/2000 [1:28:22<55:31,  3.97s/it]

step 1160: lr 0.005394, train loss 1.9012, val loss 1.8961, ppl 7, time elapsed: 5300.12 seconds


 59%|█████████████████████████████████████████▉                             | 1181/2000 [1:29:27<52:43,  3.86s/it]

step 1180: lr 0.005025, train loss 1.9264, val loss 1.9001, ppl 7, time elapsed: 5365.40 seconds


 60%|██████████████████████████████████████████▋                            | 1201/2000 [1:30:31<49:51,  3.74s/it]

step 1200: lr 0.004657, train loss 1.8703, val loss 1.8607, ppl 6, time elapsed: 5428.84 seconds


 61%|███████████████████████████████████████████▎                           | 1221/2000 [1:31:36<50:00,  3.85s/it]

step 1220: lr 0.004290, train loss 1.7933, val loss 1.9310, ppl 7, time elapsed: 5493.93 seconds


 62%|████████████████████████████████████████████                           | 1241/2000 [1:32:40<48:22,  3.82s/it]

step 1240: lr 0.003928, train loss 1.8751, val loss 1.7911, ppl 6, time elapsed: 5558.32 seconds


 63%|████████████████████████████████████████████▊                          | 1261/2000 [1:33:45<46:30,  3.78s/it]

step 1260: lr 0.003572, train loss 1.8947, val loss 1.8154, ppl 6, time elapsed: 5622.62 seconds


 64%|█████████████████████████████████████████████▍                         | 1281/2000 [1:34:48<44:49,  3.74s/it]

step 1280: lr 0.003224, train loss 1.8997, val loss 1.9259, ppl 7, time elapsed: 5686.23 seconds


 65%|██████████████████████████████████████████████▏                        | 1301/2000 [1:35:53<46:57,  4.03s/it]

step 1300: lr 0.002886, train loss 1.8443, val loss 1.8402, ppl 6, time elapsed: 5750.50 seconds


 66%|██████████████████████████████████████████████▉                        | 1321/2000 [1:37:12<54:58,  4.86s/it]

step 1320: lr 0.002561, train loss 1.8338, val loss 1.7705, ppl 6, time elapsed: 5829.23 seconds


 67%|███████████████████████████████████████████████▌                       | 1341/2000 [1:38:23<41:58,  3.82s/it]

step 1340: lr 0.002249, train loss 1.8187, val loss 1.7400, ppl 6, time elapsed: 5901.30 seconds


 68%|████████████████████████████████████████████████▎                      | 1361/2000 [1:39:27<40:40,  3.82s/it]

step 1360: lr 0.001953, train loss 1.7267, val loss 1.8664, ppl 6, time elapsed: 5965.44 seconds


 69%|█████████████████████████████████████████████████                      | 1381/2000 [1:40:41<43:00,  4.17s/it]

step 1380: lr 0.001674, train loss 1.8223, val loss 1.7732, ppl 6, time elapsed: 6039.23 seconds


 70%|█████████████████████████████████████████████████▋                     | 1401/2000 [1:41:47<38:12,  3.83s/it]

step 1400: lr 0.001413, train loss 1.7410, val loss 1.8267, ppl 6, time elapsed: 6104.69 seconds


 71%|██████████████████████████████████████████████████▍                    | 1421/2000 [1:42:50<37:03,  3.84s/it]

step 1420: lr 0.001173, train loss 1.8675, val loss 1.7994, ppl 6, time elapsed: 6167.90 seconds


 72%|███████████████████████████████████████████████████▏                   | 1441/2000 [1:43:54<34:59,  3.76s/it]

step 1440: lr 0.000955, train loss 1.7712, val loss 1.7432, ppl 6, time elapsed: 6232.31 seconds


 73%|███████████████████████████████████████████████████▊                   | 1461/2000 [1:44:59<34:35,  3.85s/it]

step 1460: lr 0.000759, train loss 1.7718, val loss 1.7467, ppl 6, time elapsed: 6297.04 seconds


 74%|████████████████████████████████████████████████████▌                  | 1481/2000 [1:46:06<39:21,  4.55s/it]

step 1480: lr 0.000587, train loss 1.8002, val loss 1.7180, ppl 6, time elapsed: 6363.60 seconds


 75%|█████████████████████████████████████████████████████▎                 | 1501/2000 [1:47:12<32:52,  3.95s/it]

step 1500: lr 0.000440, train loss 1.7542, val loss 1.8648, ppl 6, time elapsed: 6429.65 seconds


 76%|█████████████████████████████████████████████████████▉                 | 1521/2000 [1:48:16<30:15,  3.79s/it]

step 1520: lr 0.000319, train loss 1.7779, val loss 1.7874, ppl 6, time elapsed: 6493.70 seconds


 77%|██████████████████████████████████████████████████████▋                | 1541/2000 [1:49:20<28:48,  3.77s/it]

step 1540: lr 0.000223, train loss 1.7163, val loss 1.7608, ppl 6, time elapsed: 6558.33 seconds


 78%|███████████████████████████████████████████████████████▍               | 1561/2000 [1:50:24<28:43,  3.93s/it]

step 1560: lr 0.000155, train loss 1.7490, val loss 1.7357, ppl 6, time elapsed: 6622.12 seconds


 79%|████████████████████████████████████████████████████████▏              | 1581/2000 [1:51:28<26:33,  3.80s/it]

step 1580: lr 0.000114, train loss 1.7875, val loss 1.8293, ppl 6, time elapsed: 6686.18 seconds


 80%|████████████████████████████████████████████████████████▊              | 1601/2000 [1:52:34<25:05,  3.77s/it]

step 1600: lr 0.000100, train loss 1.7343, val loss 1.6329, ppl 5, time elapsed: 6751.71 seconds


 81%|█████████████████████████████████████████████████████████▌             | 1621/2000 [1:53:39<23:42,  3.75s/it]

step 1620: lr 0.000100, train loss 1.7668, val loss 1.7818, ppl 6, time elapsed: 6816.75 seconds


 82%|██████████████████████████████████████████████████████████▎            | 1641/2000 [1:54:42<22:16,  3.72s/it]

step 1640: lr 0.000100, train loss 1.7222, val loss 1.8294, ppl 6, time elapsed: 6880.42 seconds


 83%|██████████████████████████████████████████████████████████▉            | 1661/2000 [1:55:46<21:17,  3.77s/it]

step 1660: lr 0.000100, train loss 1.7512, val loss 1.7304, ppl 6, time elapsed: 6943.88 seconds


 84%|███████████████████████████████████████████████████████████▋           | 1681/2000 [1:56:51<22:34,  4.25s/it]

step 1680: lr 0.000100, train loss 1.7857, val loss 1.7805, ppl 6, time elapsed: 7009.36 seconds


 85%|████████████████████████████████████████████████████████████▍          | 1701/2000 [1:58:01<22:13,  4.46s/it]

step 1700: lr 0.000100, train loss 1.8437, val loss 1.7718, ppl 6, time elapsed: 7078.02 seconds


 86%|█████████████████████████████████████████████████████████████          | 1721/2000 [1:59:09<18:11,  3.91s/it]

step 1720: lr 0.000100, train loss 1.7713, val loss 1.8026, ppl 6, time elapsed: 7147.28 seconds


 87%|█████████████████████████████████████████████████████████████▊         | 1741/2000 [2:00:14<16:17,  3.77s/it]

step 1740: lr 0.000100, train loss 1.8532, val loss 1.7893, ppl 6, time elapsed: 7211.83 seconds


 88%|██████████████████████████████████████████████████████████████▌        | 1761/2000 [2:01:18<15:20,  3.85s/it]

step 1760: lr 0.000100, train loss 1.8299, val loss 1.8081, ppl 6, time elapsed: 7276.38 seconds


 89%|███████████████████████████████████████████████████████████████▏       | 1781/2000 [2:02:22<13:36,  3.73s/it]

step 1780: lr 0.000100, train loss 1.7622, val loss 1.7025, ppl 5, time elapsed: 7340.27 seconds


 90%|███████████████████████████████████████████████████████████████▉       | 1801/2000 [2:03:28<12:43,  3.84s/it]

step 1800: lr 0.000100, train loss 1.6656, val loss 1.8455, ppl 6, time elapsed: 7405.70 seconds


 91%|████████████████████████████████████████████████████████████████▋      | 1821/2000 [2:04:32<11:50,  3.97s/it]

step 1820: lr 0.000100, train loss 1.7491, val loss 1.8033, ppl 6, time elapsed: 7470.19 seconds


 92%|█████████████████████████████████████████████████████████████████▎     | 1841/2000 [2:05:40<10:15,  3.87s/it]

step 1840: lr 0.000100, train loss 1.8502, val loss 1.7451, ppl 6, time elapsed: 7537.74 seconds


 93%|██████████████████████████████████████████████████████████████████     | 1861/2000 [2:06:44<08:37,  3.72s/it]

step 1860: lr 0.000100, train loss 1.7315, val loss 1.8476, ppl 6, time elapsed: 7601.86 seconds


 94%|██████████████████████████████████████████████████████████████████▊    | 1881/2000 [2:07:47<07:31,  3.80s/it]

step 1880: lr 0.000100, train loss 1.7628, val loss 1.7744, ppl 6, time elapsed: 7665.47 seconds


 95%|███████████████████████████████████████████████████████████████████▍   | 1901/2000 [2:08:53<06:38,  4.02s/it]

step 1900: lr 0.000100, train loss 1.8645, val loss 1.7410, ppl 6, time elapsed: 7731.53 seconds


 96%|████████████████████████████████████████████████████████████████████▏  | 1921/2000 [2:09:58<05:03,  3.84s/it]

step 1920: lr 0.000100, train loss 1.7621, val loss 1.8804, ppl 7, time elapsed: 7795.64 seconds


 97%|████████████████████████████████████████████████████████████████████▉  | 1941/2000 [2:11:01<03:37,  3.68s/it]

step 1940: lr 0.000100, train loss 1.7243, val loss 1.8273, ppl 6, time elapsed: 7859.46 seconds


 98%|█████████████████████████████████████████████████████████████████████▌ | 1961/2000 [2:12:05<02:24,  3.71s/it]

step 1960: lr 0.000100, train loss 1.7415, val loss 1.7295, ppl 6, time elapsed: 7923.38 seconds


 99%|██████████████████████████████████████████████████████████████████████▎| 1981/2000 [2:13:09<01:11,  3.79s/it]

step 1980: lr 0.000100, train loss 1.8125, val loss 1.7438, ppl 6, time elapsed: 7987.13 seconds


100%|███████████████████████████████████████████████████████████████████████| 2000/2000 [2:14:10<00:00,  4.03s/it]

step 1999: lr 0.000100, train loss 1.7581, val loss 1.7197, ppl 6, time elapsed: 8048.32 seconds





# inference test before you decide to save it

In [6]:
from inference import generate
prompt = "Once upon a time"
model.eval()
output = generate(
    prompt, 
    model, 
    tokenizer,
    #max_gen_len = 512,
    temperature = 0.9,
    #memory_saver_div = 8,
    #top_p = 0.9,
    #top_k = 32,
)
model.train()
print(output)

Once upon a time, there was a chake it was sad and special chake the chake it was a chake it saw a big chake. The chake the chake it did not real flow and found a girl and played back to play with his bird was happy. The chake it kept magirl and playing that the chake the chake the chake the chake the chake the chake. The chake the chake the chake the chake it and played to be the chake and flew the chake. The chake the chake the chake of the chake. The chake the chake the chake the chake the chake. The chake the chake a chake the chake the chake the chake it goo was a pretty. The chake the chake the chake the chake to playing together of the chake the chake the chake for her came and flew the chake the chake. The chake the chake the chake the chake the chake. The chake it


# Saving your final model
if `tcfg.checkpoint_interval != None` then checkpoints have already been saved

you DO still need to do this even if you had been saving checkpoints; the final state has not yet been saved

In [7]:
from tools import save_model
save_model(model, cfg, tcfg, log_data)