# Setup

In [1]:
# my virtual environments are rarely properly connected to jupyter so this fixes that.
# you prolly won't need this cell but running it won't hurt anything either
#
# What it does: looks for a virtual environment named `venv` in the current working
# directory and, if its site-packages folder exists, makes it importable from this kernel.
import sys
import os

current_dir = os.getcwd()  # Get the current working directory
venv_dir = os.path.join(current_dir, 'venv')
python_version = f'{sys.version_info.major}.{sys.version_info.minor}'

# venv uses a different layout on Windows (venv\Lib\site-packages) than on
# POSIX systems (venv/lib/pythonX.Y/site-packages)
if os.name == 'nt':
    site_packages_path = os.path.join(venv_dir, 'Lib', 'site-packages')
else:
    site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')

# Only add the path when it really exists, and only once, so re-running this cell
# neither stacks up duplicate entries nor leaves a dangling directory on sys.path
if os.path.isdir(site_packages_path) and site_packages_path not in sys.path:
    sys.path.append(site_packages_path)

# Instantiate a brand new model

In [2]:
# project-local imports used by this cell: config classes, import helper, model class
from config import ModelConfig, TrainConfig
from tools import import_from_nested_path
from modules.model import Model

# build the model and training configs with their default values and show them
cfg = ModelConfig()
tcfg = TrainConfig()
print(cfg, '\n\n', tcfg)

# fetch the `get_tokenizer` function belonging to the tokenizer named in cfg
imported_objects = import_from_nested_path(
    ['tokenizers', cfg.tokenizer],
    'tokenizer',
    ['get_tokenizer'],
)
get_tokenizer = imported_objects.get('get_tokenizer')
tokenizer = get_tokenizer(size=cfg.vocab_len)

# create the model from cfg and move it onto the configured device
model = Model(cfg).to(cfg.device)

# report the total parameter count (in thousands), then the module tree
param_count = sum(p.numel() for p in model.parameters())
print(f'\n{param_count/1e3}K parameters\n')
print(model)

ModelConfig(dim=80, device='cpu', tokenizer='bpe_v1', vocab_len=8192, num_layers=4, second_resid_norm=False, num_heads=2, head_dim=40, max_seq_len=512, mm_bias=False, pmem_size=224, pmem_count=1, scale_first_resid=True, norm_type='RMSNorm', norm_affine=True, norm_bias=True, eps=1e-06) 

 TrainConfig(weight_decay=0.05, batch_size=32, max_iters=4000, eval_interval=40, eval_samples=1, checkpoint_interval=None, lr_init=1e-06, lr_max=0.1, lr_min=0.001, warmup_iters=40, final_flat_iters=400, anneal_type='cos', num_restarts=3, T_mult=2)

928.456K parameters

Model(
  (token_embedder): Embedding(8195, 80)
  (layers): ModuleList(
    (0-3): 4 x Layer(
      (pre_context_norm): Norm()
      (context): ContextMem(
        (k_featurizer): KeyFeatureExtractor(
          (W_k): Linear(in_features=80, out_features=80, bias=False)
          (leaky_avg): LeakyAvg()
        )
        (v_featurizer): ValFeatureExtractor(
          (W_v): Linear(in_features=80, out_features=80, bias=False)
        )
     

# Training

In [3]:
import torch
from tools import get_data_loader
from train import scheduler_lambda, train

# AdamW over every model parameter. tcfg.lr_max is the optimizer's base learning rate;
# LambdaLR multiplies that base rate by scheduler_lambda(step) at each scheduler step.
# NOTE(review): the training log in this notebook peaks at lr 0.01 although the config
# prints lr_max=0.1 -- this looks like scheduler_lambda may already return an absolute
# learning rate that then gets scaled by lr_max a second time. Confirm against train.py.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=tcfg.lr_max,
    weight_decay=tcfg.weight_decay,
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_lambda)

# one loader per dataset split, both using the training batch size
train_data_loader, test_data_loader = (
    get_data_loader(batch_size=tcfg.batch_size, split=split_name)
    for split_name in ('train', 'validation')
)

In [4]:
if False: # set to true if you'd like to see a graph of the learning rate schedule
    import matplotlib.pyplot as plt

    # Evaluate the schedule function once per training iteration
    # NOTE(review): these are the raw scheduler_lambda outputs; whether they are absolute
    # learning rates or multipliers of the optimizer's base lr depends on train.py -- confirm
    lrs = list(map(scheduler_lambda, range(tcfg.max_iters)))

    # Draw the curve on an explicit figure/axes pair
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(lrs, label='Learning Rate')
    ax.set_title('Learning Rate Schedule')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Learning Rate')
    ax.grid(True)
    ax.legend()
    plt.show()

In [5]:
# Run training. `model`, `optimizer` and `log_data` are rebound to the three values
# returned by train(); `log_data` is passed to save_model() at the end of the notebook.
model, optimizer, log_data = train(
    model, 
    tokenizer, 
    cfg, 
    optimizer,
    scheduler,
    tcfg, 
    train_data_loader,
    test_data_loader,
    # optional arguments, left commented out:
    #log_data: list = None, 
    #detect_anomoly = False # use if you're getting crazy errors about the gradient being broken
)

  0%|                                                                         | 1/4000 [00:17<19:18:29, 17.38s/it]

step 0000: lr 0.000000, train loss 81.0457, val loss 81.4177, ppl 228704727989345035625610782904942592, time elapsed: 14.47 seconds


  1%|▋                                                                        | 41/4000 [06:10<6:28:32,  5.89s/it]

step 0040: lr 0.010000, train loss 12.4455, val loss 12.4570, ppl 257053, time elapsed: 367.33 seconds


  2%|█▍                                                                       | 81/4000 [10:00<7:08:07,  6.55s/it]

step 0080: lr 0.009322, train loss 4.5349, val loss 4.5225, ppl 92, time elapsed: 597.49 seconds


  3%|██▏                                                                     | 121/4000 [13:53<7:14:55,  6.73s/it]

step 0120: lr 0.007474, train loss 4.3637, val loss 4.3646, ppl 79, time elapsed: 830.64 seconds


  4%|██▉                                                                     | 161/4000 [17:32<6:01:50,  5.66s/it]

step 0160: lr 0.004963, train loss 3.9411, val loss 4.0349, ppl 57, time elapsed: 1049.86 seconds


  5%|███▌                                                                    | 201/4000 [20:35<5:29:26,  5.20s/it]

step 0200: lr 0.002475, train loss 3.8590, val loss 3.8787, ppl 48, time elapsed: 1233.26 seconds


  6%|████▎                                                                   | 241/4000 [23:33<5:15:05,  5.03s/it]

step 0240: lr 0.000692, train loss 3.7726, val loss 3.7395, ppl 42, time elapsed: 1411.43 seconds


  7%|█████                                                                   | 281/4000 [26:27<5:26:38,  5.27s/it]

step 0280: lr 0.009999, train loss 3.6738, val loss 3.7101, ppl 41, time elapsed: 1585.18 seconds


  8%|█████▊                                                                  | 321/4000 [29:13<4:45:17,  4.65s/it]

step 0320: lr 0.009804, train loss 3.2021, val loss 3.2539, ppl 26, time elapsed: 1750.60 seconds


  9%|██████▍                                                                 | 361/4000 [31:49<4:52:26,  4.82s/it]

step 0360: lr 0.009277, train loss 2.9622, val loss 2.9397, ppl 19, time elapsed: 1906.54 seconds


 10%|███████▏                                                                | 401/4000 [34:36<5:04:13,  5.07s/it]

step 0400: lr 0.008456, train loss 2.8356, val loss 2.8387, ppl 17, time elapsed: 2073.15 seconds


 11%|███████▉                                                                | 441/4000 [37:33<5:08:51,  5.21s/it]

step 0440: lr 0.007398, train loss 2.6862, val loss 2.6785, ppl 15, time elapsed: 2250.09 seconds


 12%|████████▋                                                               | 481/4000 [40:27<4:58:52,  5.10s/it]

step 0480: lr 0.006176, train loss 2.7287, val loss 2.7611, ppl 16, time elapsed: 2424.94 seconds


 13%|█████████▍                                                              | 521/4000 [43:20<4:44:17,  4.90s/it]

step 0520: lr 0.004875, train loss 2.6026, val loss 2.5853, ppl 13, time elapsed: 2598.00 seconds


 14%|██████████                                                              | 561/4000 [46:03<4:25:47,  4.64s/it]

step 0560: lr 0.003587, train loss 2.6362, val loss 2.5199, ppl 12, time elapsed: 2761.34 seconds


 15%|██████████▊                                                             | 601/4000 [48:39<4:14:31,  4.49s/it]

step 0600: lr 0.002401, train loss 2.4992, val loss 2.5656, ppl 13, time elapsed: 2917.28 seconds


 16%|███████████▌                                                            | 641/4000 [51:11<4:10:55,  4.48s/it]

step 0640: lr 0.001399, train loss 2.4721, val loss 2.5287, ppl 13, time elapsed: 3068.80 seconds


 17%|████████████▎                                                           | 681/4000 [53:40<3:59:18,  4.33s/it]

step 0680: lr 0.000651, train loss 2.5382, val loss 2.4650, ppl 12, time elapsed: 3218.66 seconds


 18%|████████████▉                                                           | 721/4000 [56:11<4:00:37,  4.40s/it]

step 0720: lr 0.000211, train loss 2.4991, val loss 2.4923, ppl 12, time elapsed: 3369.44 seconds


 19%|█████████████▋                                                          | 761/4000 [58:38<3:54:26,  4.34s/it]

step 0760: lr 0.009998, train loss 2.4408, val loss 2.4332, ppl 11, time elapsed: 3515.82 seconds


 20%|██████████████                                                        | 801/4000 [1:00:54<3:27:15,  3.89s/it]

step 0800: lr 0.009938, train loss 2.4048, val loss 2.4324, ppl 11, time elapsed: 3652.30 seconds


 21%|██████████████▋                                                       | 841/4000 [1:03:00<3:14:46,  3.70s/it]

step 0840: lr 0.009792, train loss 2.3370, val loss 2.3506, ppl 10, time elapsed: 3778.27 seconds


 22%|███████████████▍                                                      | 881/4000 [1:05:55<3:56:28,  4.55s/it]

step 0880: lr 0.009563, train loss 2.1910, val loss 2.2378, ppl 9, time elapsed: 3952.61 seconds


 23%|████████████████                                                      | 921/4000 [1:08:03<3:15:08,  3.80s/it]

step 0920: lr 0.009255, train loss 2.1363, val loss 2.1518, ppl 9, time elapsed: 4081.48 seconds


 24%|████████████████▊                                                     | 961/4000 [1:10:11<3:28:31,  4.12s/it]

step 0960: lr 0.008873, train loss 2.1027, val loss 2.0499, ppl 8, time elapsed: 4208.59 seconds


 25%|█████████████████▎                                                   | 1001/4000 [1:12:16<3:13:25,  3.87s/it]

step 1000: lr 0.008424, train loss 2.0448, val loss 2.1145, ppl 8, time elapsed: 4333.71 seconds


 26%|█████████████████▉                                                   | 1041/4000 [1:14:22<3:05:42,  3.77s/it]

step 1040: lr 0.007917, train loss 2.1773, val loss 2.0168, ppl 8, time elapsed: 4460.36 seconds


 27%|██████████████████▋                                                  | 1081/4000 [1:16:26<3:01:04,  3.72s/it]

step 1080: lr 0.007359, train loss 2.0440, val loss 2.0115, ppl 7, time elapsed: 4584.17 seconds


 28%|███████████████████▎                                                 | 1121/4000 [1:18:30<2:57:02,  3.69s/it]

step 1120: lr 0.006761, train loss 1.9650, val loss 2.0563, ppl 8, time elapsed: 4708.01 seconds


 29%|████████████████████                                                 | 1161/4000 [1:20:33<2:55:19,  3.71s/it]

step 1160: lr 0.006133, train loss 2.0452, val loss 2.0251, ppl 8, time elapsed: 4831.33 seconds


 30%|████████████████████▋                                                | 1201/4000 [1:22:34<2:52:49,  3.70s/it]

step 1200: lr 0.005486, train loss 1.9679, val loss 1.9184, ppl 7, time elapsed: 4952.53 seconds


 31%|█████████████████████▍                                               | 1241/4000 [1:24:36<2:47:50,  3.65s/it]

step 1240: lr 0.004832, train loss 1.9698, val loss 1.9951, ppl 7, time elapsed: 5074.31 seconds


 32%|██████████████████████                                               | 1281/4000 [1:26:39<2:51:04,  3.78s/it]

step 1280: lr 0.004181, train loss 1.9166, val loss 1.9230, ppl 7, time elapsed: 5197.06 seconds


 33%|██████████████████████▊                                              | 1321/4000 [1:28:42<2:51:28,  3.84s/it]

step 1320: lr 0.003545, train loss 1.8206, val loss 1.9143, ppl 7, time elapsed: 5319.73 seconds


 34%|███████████████████████▍                                             | 1361/4000 [1:30:42<2:37:58,  3.59s/it]

step 1360: lr 0.002936, train loss 1.9309, val loss 1.8785, ppl 7, time elapsed: 5440.05 seconds


 35%|████████████████████████▏                                            | 1401/4000 [1:32:47<2:39:14,  3.68s/it]

step 1400: lr 0.002364, train loss 1.8565, val loss 1.8707, ppl 6, time elapsed: 5564.74 seconds


 36%|████████████████████████▊                                            | 1441/4000 [1:34:50<2:40:33,  3.76s/it]

step 1440: lr 0.001839, train loss 1.8879, val loss 1.8091, ppl 6, time elapsed: 5687.46 seconds


 37%|█████████████████████████▌                                           | 1481/4000 [1:37:06<3:12:01,  4.57s/it]

step 1480: lr 0.001369, train loss 1.8727, val loss 1.8349, ppl 6, time elapsed: 5824.08 seconds


 38%|██████████████████████████▏                                          | 1521/4000 [1:39:28<2:43:51,  3.97s/it]

step 1520: lr 0.000965, train loss 1.8005, val loss 1.8768, ppl 7, time elapsed: 5965.86 seconds


 39%|██████████████████████████▉                                          | 1561/4000 [1:41:35<2:45:24,  4.07s/it]

step 1560: lr 0.000632, train loss 1.8243, val loss 1.8064, ppl 6, time elapsed: 6092.94 seconds


 40%|███████████████████████████▌                                         | 1601/4000 [1:43:41<2:40:04,  4.00s/it]

step 1600: lr 0.000376, train loss 1.9390, val loss 1.7415, ppl 6, time elapsed: 6218.97 seconds


 41%|████████████████████████████▎                                        | 1641/4000 [1:45:45<2:29:22,  3.80s/it]

step 1640: lr 0.000202, train loss 1.8807, val loss 1.7717, ppl 6, time elapsed: 6343.23 seconds


 42%|████████████████████████████▉                                        | 1681/4000 [1:47:49<2:24:43,  3.74s/it]

step 1680: lr 0.000112, train loss 1.9075, val loss 1.8229, ppl 6, time elapsed: 6466.76 seconds


 43%|█████████████████████████████▋                                       | 1721/4000 [1:49:53<2:18:04,  3.64s/it]

step 1720: lr 0.009998, train loss 2.0156, val loss 1.9215, ppl 7, time elapsed: 6590.96 seconds


 44%|██████████████████████████████▍                                      | 1761/4000 [1:51:57<2:18:01,  3.70s/it]

step 1760: lr 0.009977, train loss 1.9569, val loss 1.9772, ppl 7, time elapsed: 6714.90 seconds


 45%|███████████████████████████████                                      | 1801/4000 [1:54:00<2:14:22,  3.67s/it]

step 1800: lr 0.009934, train loss 1.7754, val loss 1.8368, ppl 6, time elapsed: 6838.03 seconds


 46%|███████████████████████████████▊                                     | 1841/4000 [1:56:03<2:17:14,  3.81s/it]

step 1840: lr 0.009870, train loss 1.8146, val loss 1.8466, ppl 6, time elapsed: 6961.11 seconds


 47%|████████████████████████████████▍                                    | 1881/4000 [1:58:06<2:13:42,  3.79s/it]

step 1880: lr 0.009785, train loss 1.7473, val loss 1.8701, ppl 6, time elapsed: 7084.22 seconds


 48%|█████████████████████████████████▏                                   | 1921/4000 [2:00:09<2:08:36,  3.71s/it]

step 1920: lr 0.009680, train loss 1.7726, val loss 1.7430, ppl 6, time elapsed: 7206.80 seconds


 49%|█████████████████████████████████▊                                   | 1961/4000 [2:02:11<2:08:20,  3.78s/it]

step 1960: lr 0.009554, train loss 1.7187, val loss 1.7314, ppl 6, time elapsed: 7329.36 seconds


 50%|██████████████████████████████████▌                                  | 2001/4000 [2:04:15<2:04:11,  3.73s/it]

step 2000: lr 0.009408, train loss 1.6596, val loss 1.7473, ppl 6, time elapsed: 7453.09 seconds


 51%|███████████████████████████████████▏                                 | 2041/4000 [2:06:24<2:07:19,  3.90s/it]

step 2040: lr 0.009243, train loss 1.6763, val loss 1.7156, ppl 6, time elapsed: 7582.25 seconds


 52%|███████████████████████████████████▉                                 | 2081/4000 [2:08:28<1:57:32,  3.68s/it]

step 2080: lr 0.009060, train loss 1.7102, val loss 1.6817, ppl 5, time elapsed: 7706.17 seconds


 53%|████████████████████████████████████▌                                | 2121/4000 [2:10:31<2:01:44,  3.89s/it]

step 2120: lr 0.008859, train loss 1.6343, val loss 1.6776, ppl 5, time elapsed: 7828.80 seconds


 54%|█████████████████████████████████████▎                               | 2161/4000 [2:12:35<1:53:12,  3.69s/it]

step 2160: lr 0.008642, train loss 1.6030, val loss 1.6323, ppl 5, time elapsed: 7953.71 seconds


 55%|█████████████████████████████████████▉                               | 2201/4000 [2:14:37<1:51:39,  3.72s/it]

step 2200: lr 0.008408, train loss 1.6307, val loss 1.5734, ppl 5, time elapsed: 8074.99 seconds


 56%|██████████████████████████████████████▋                              | 2241/4000 [2:16:40<1:47:59,  3.68s/it]

step 2240: lr 0.008161, train loss 1.6206, val loss 1.6388, ppl 5, time elapsed: 8198.04 seconds


 57%|███████████████████████████████████████▎                             | 2281/4000 [2:18:44<1:46:41,  3.72s/it]

step 2280: lr 0.007899, train loss 1.5377, val loss 1.5900, ppl 5, time elapsed: 8322.24 seconds


 58%|████████████████████████████████████████                             | 2321/4000 [2:20:48<1:47:33,  3.84s/it]

step 2320: lr 0.007625, train loss 1.5791, val loss 1.5567, ppl 5, time elapsed: 8446.13 seconds


 59%|████████████████████████████████████████▋                            | 2361/4000 [2:22:53<1:42:48,  3.76s/it]

step 2360: lr 0.007340, train loss 1.5900, val loss 1.5895, ppl 5, time elapsed: 8570.83 seconds


 60%|█████████████████████████████████████████▍                           | 2401/4000 [2:24:56<1:35:01,  3.57s/it]

step 2400: lr 0.007045, train loss 1.5639, val loss 1.5735, ppl 5, time elapsed: 8694.30 seconds


 61%|██████████████████████████████████████████                           | 2441/4000 [2:26:57<1:34:46,  3.65s/it]

step 2440: lr 0.006741, train loss 1.4618, val loss 1.6227, ppl 5, time elapsed: 8815.27 seconds


 62%|██████████████████████████████████████████▊                          | 2481/4000 [2:28:57<1:34:13,  3.72s/it]

step 2480: lr 0.006429, train loss 1.4933, val loss 1.5356, ppl 5, time elapsed: 8935.31 seconds


 63%|███████████████████████████████████████████▍                         | 2521/4000 [2:31:01<1:31:23,  3.71s/it]

step 2520: lr 0.006112, train loss 1.4847, val loss 1.5519, ppl 5, time elapsed: 9059.47 seconds


 64%|████████████████████████████████████████████▏                        | 2561/4000 [2:33:04<1:30:02,  3.75s/it]

step 2560: lr 0.005790, train loss 1.5226, val loss 1.4442, ppl 4, time elapsed: 9181.77 seconds


 65%|████████████████████████████████████████████▊                        | 2601/4000 [2:35:04<1:24:45,  3.64s/it]

step 2600: lr 0.005464, train loss 1.5452, val loss 1.5640, ppl 5, time elapsed: 9302.25 seconds


 66%|█████████████████████████████████████████████▌                       | 2641/4000 [2:37:11<1:24:59,  3.75s/it]

step 2640: lr 0.005137, train loss 1.4846, val loss 1.4983, ppl 4, time elapsed: 9428.76 seconds


 67%|██████████████████████████████████████████████▏                      | 2681/4000 [2:39:14<1:25:06,  3.87s/it]

step 2680: lr 0.004810, train loss 1.4899, val loss 1.4218, ppl 4, time elapsed: 9551.81 seconds


 68%|██████████████████████████████████████████████▉                      | 2721/4000 [2:41:14<1:17:00,  3.61s/it]

step 2720: lr 0.004483, train loss 1.5107, val loss 1.4276, ppl 4, time elapsed: 9672.26 seconds


 69%|███████████████████████████████████████████████▋                     | 2761/4000 [2:43:15<1:18:38,  3.81s/it]

step 2760: lr 0.004159, train loss 1.4064, val loss 1.5109, ppl 5, time elapsed: 9792.61 seconds


 70%|████████████████████████████████████████████████▎                    | 2801/4000 [2:45:16<1:11:21,  3.57s/it]

step 2800: lr 0.003839, train loss 1.5074, val loss 1.4223, ppl 4, time elapsed: 9914.28 seconds


 71%|█████████████████████████████████████████████████                    | 2841/4000 [2:47:18<1:11:15,  3.69s/it]

step 2840: lr 0.003525, train loss 1.4810, val loss 1.4638, ppl 4, time elapsed: 10035.91 seconds


 72%|█████████████████████████████████████████████████▋                   | 2881/4000 [2:49:20<1:06:35,  3.57s/it]

step 2880: lr 0.003216, train loss 1.4305, val loss 1.4651, ppl 4, time elapsed: 10158.78 seconds


 73%|██████████████████████████████████████████████████▍                  | 2921/4000 [2:51:21<1:04:59,  3.61s/it]

step 2920: lr 0.002916, train loss 1.4350, val loss 1.4360, ppl 4, time elapsed: 10279.20 seconds


 74%|███████████████████████████████████████████████████                  | 2961/4000 [2:53:24<1:01:45,  3.57s/it]

step 2960: lr 0.002626, train loss 1.4761, val loss 1.5143, ppl 5, time elapsed: 10402.57 seconds


 75%|███████████████████████████████████████████████████▊                 | 3001/4000 [2:55:25<1:02:20,  3.74s/it]

step 3000: lr 0.002345, train loss 1.4550, val loss 1.4824, ppl 4, time elapsed: 10523.09 seconds


 76%|█████████████████████████████████████████████████████▉                 | 3041/4000 [2:57:26<59:05,  3.70s/it]

step 3040: lr 0.002077, train loss 1.4079, val loss 1.5458, ppl 5, time elapsed: 10644.44 seconds


 77%|██████████████████████████████████████████████████████▋                | 3081/4000 [2:59:28<55:48,  3.64s/it]

step 3080: lr 0.001822, train loss 1.3432, val loss 1.5295, ppl 5, time elapsed: 10766.67 seconds


 78%|███████████████████████████████████████████████████████▍               | 3121/4000 [3:01:29<53:47,  3.67s/it]

step 3120: lr 0.001581, train loss 1.4778, val loss 1.3708, ppl 4, time elapsed: 10887.55 seconds


 79%|████████████████████████████████████████████████████████               | 3161/4000 [3:03:30<51:26,  3.68s/it]

step 3160: lr 0.001355, train loss 1.3840, val loss 1.3749, ppl 4, time elapsed: 11008.25 seconds


 80%|████████████████████████████████████████████████████████▊              | 3201/4000 [3:07:03<55:45,  4.19s/it]

step 3200: lr 0.001145, train loss 1.4134, val loss 1.3835, ppl 4, time elapsed: 11221.11 seconds


 81%|█████████████████████████████████████████████████████████▌             | 3241/4000 [3:09:10<47:52,  3.78s/it]

step 3240: lr 0.000953, train loss 1.4078, val loss 1.3409, ppl 4, time elapsed: 11348.06 seconds


 82%|██████████████████████████████████████████████████████████▏            | 3281/4000 [3:11:16<45:30,  3.80s/it]

step 3280: lr 0.000778, train loss 1.4848, val loss 1.4374, ppl 4, time elapsed: 11474.10 seconds


 83%|██████████████████████████████████████████████████████████▉            | 3321/4000 [3:13:23<43:43,  3.86s/it]

step 3320: lr 0.000622, train loss 1.3912, val loss 1.4538, ppl 4, time elapsed: 11601.05 seconds


 84%|███████████████████████████████████████████████████████████▋           | 3361/4000 [3:15:29<40:19,  3.79s/it]

step 3360: lr 0.000485, train loss 1.4131, val loss 1.3238, ppl 4, time elapsed: 11727.45 seconds


 85%|████████████████████████████████████████████████████████████▎          | 3401/4000 [3:17:36<40:06,  4.02s/it]

step 3400: lr 0.000369, train loss 1.4153, val loss 1.3638, ppl 4, time elapsed: 11854.13 seconds


 86%|█████████████████████████████████████████████████████████████          | 3441/4000 [3:19:52<36:04,  3.87s/it]

step 3440: lr 0.000272, train loss 1.4833, val loss 1.3898, ppl 4, time elapsed: 11989.99 seconds


 87%|█████████████████████████████████████████████████████████████▊         | 3481/4000 [3:21:57<32:34,  3.77s/it]

step 3480: lr 0.000197, train loss 1.4072, val loss 1.3261, ppl 4, time elapsed: 12115.15 seconds


 88%|██████████████████████████████████████████████████████████████▍        | 3521/4000 [3:24:03<30:45,  3.85s/it]

step 3520: lr 0.000143, train loss 1.3723, val loss 1.4531, ppl 4, time elapsed: 12240.94 seconds


 89%|███████████████████████████████████████████████████████████████▏       | 3561/4000 [3:26:05<27:00,  3.69s/it]

step 3560: lr 0.000111, train loss 1.4089, val loss 1.3708, ppl 4, time elapsed: 12363.01 seconds


 90%|███████████████████████████████████████████████████████████████▉       | 3601/4000 [3:28:06<23:52,  3.59s/it]

step 3600: lr 0.000100, train loss 1.3878, val loss 1.3364, ppl 4, time elapsed: 12484.38 seconds


 91%|████████████████████████████████████████████████████████████████▋      | 3641/4000 [3:30:09<22:02,  3.69s/it]

step 3640: lr 0.000100, train loss 1.3778, val loss 1.4122, ppl 4, time elapsed: 12606.84 seconds


 92%|█████████████████████████████████████████████████████████████████▎     | 3681/4000 [3:32:11<19:32,  3.68s/it]

step 3680: lr 0.000100, train loss 1.4492, val loss 1.4079, ppl 4, time elapsed: 12729.23 seconds


 93%|██████████████████████████████████████████████████████████████████     | 3721/4000 [3:34:18<24:42,  5.31s/it]

step 3720: lr 0.000100, train loss 1.4406, val loss 1.3319, ppl 4, time elapsed: 12856.46 seconds


 94%|██████████████████████████████████████████████████████████████████▊    | 3761/4000 [3:36:22<14:34,  3.66s/it]

step 3760: lr 0.000100, train loss 1.3513, val loss 1.3574, ppl 4, time elapsed: 12979.97 seconds


 95%|███████████████████████████████████████████████████████████████████▍   | 3801/4000 [3:38:28<12:15,  3.70s/it]

step 3800: lr 0.000100, train loss 1.3687, val loss 1.3822, ppl 4, time elapsed: 13106.13 seconds


 96%|████████████████████████████████████████████████████████████████████▏  | 3841/4000 [3:40:30<09:56,  3.75s/it]

step 3840: lr 0.000100, train loss 1.4106, val loss 1.4255, ppl 4, time elapsed: 13228.00 seconds


 97%|████████████████████████████████████████████████████████████████████▉  | 3881/4000 [3:42:34<07:22,  3.72s/it]

step 3880: lr 0.000100, train loss 1.3664, val loss 1.3953, ppl 4, time elapsed: 13352.17 seconds


 98%|█████████████████████████████████████████████████████████████████████▌ | 3921/4000 [3:44:38<05:02,  3.83s/it]

step 3920: lr 0.000100, train loss 1.3529, val loss 1.3685, ppl 4, time elapsed: 13476.21 seconds


 99%|██████████████████████████████████████████████████████████████████████▎| 3961/4000 [3:46:41<02:33,  3.92s/it]

step 3960: lr 0.000100, train loss 1.3310, val loss 1.4318, ppl 4, time elapsed: 13598.92 seconds


100%|███████████████████████████████████████████████████████████████████████| 4000/4000 [3:48:42<00:00,  3.43s/it]

step 3999: lr 0.000100, train loss 1.4152, val loss 1.3332, ppl 4, time elapsed: 13719.72 seconds





# inference test before you decide to save it

In [6]:
from inference import generate

prompt = "Once upon a time"

# Generate a sample with the model in eval mode. The switch back to train mode lives in
# `finally` so that an exception inside generate() cannot leave the model stuck in eval
# mode for whatever cell is run next.
model.eval()
try:
    output = generate(
        prompt,
        model,
        tokenizer,
        #max_gen_len = 512,
        temperature = 0.9,
        #memory_saver_div = 8,
        #top_p = 0.9,
        #top_k = 32,
    )
finally:
    model.train()
print(output)

Once upon a time, there was a busy cat. The busy cat lived in a big tree. The busy cat loved to play with the busy cat. One day, the busy cat loved to play with the busy cat wanted to play. The busy cat was a sunprised to play with the busy cat thought it would be slow. The busy cat wanted to play with the busy cat. The busy cat still always play with the busy cat would be slow again. The busy cat didn't be slow anymore.


# Saving your final model
If `tcfg.checkpoint_interval` is not `None`, then checkpoints have already been saved.

You DO still need to run this cell even if you have been saving checkpoints; the final state has not yet been saved.

In [7]:
from tools import save_model
# Hand the trained model, both config objects, and the training log to save_model
save_model(model, cfg, tcfg, log_data)