# Setup

In [1]:
# my virtual environments are rarely properly connected to jupyter so this fixes that. 
# you prolly won't need this cell but running it won't hurt anything either
import sys
import os
current_dir = os.getcwd()  # Get the current working directory
venv_dir = os.path.join(current_dir, 'venv') 
python_version = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')
sys.path.append(site_packages_path) 

# Instantiate a brand new model

In [2]:
# config file
from config import ModelConfig, TrainConfig
cfg = ModelConfig()
tcfg = TrainConfig()
print(cfg, '\n\n', tcfg)

# import the tokenizer specified by cfg
from tools import import_from_nested_path
imported_objects = import_from_nested_path(['tokenizers', cfg.tokenizer], 'tokenizer', ['get_tokenizer'])
get_tokenizer = imported_objects.get('get_tokenizer')
tokenizer = get_tokenizer(size = cfg.vocab_len)

# model modules
from modules.model import Model
model = Model(cfg).to(cfg.device)

# print the number of parameters in the model
print(f'\n{sum(p.numel() for p in model.parameters())/1e3}K parameters\n')
print(model)

ModelConfig(dim=128, device='mps', tokenizer='bpe_v1', vocab_len=8192, num_layers=4, second_resid_norm=False, num_heads=4, head_dim=32, max_seq_len=128, mm_bias=False, pmem_size=336, pmem_count=2, scale_first_resid=True, norm_type='RMSNorm', norm_affine=True, norm_bias=True, eps=1e-06) 

 TrainConfig(weight_decay=0.05, batch_size=32, max_iters=100, eval_interval=1, eval_samples=1, checkpoint_interval=None, lr_init=1e-06, lr_max=0.1, lr_min=0.001, warmup_iters=1, final_flat_iters=10, anneal_type='cos', num_restarts=3, T_mult=2)

2067.184K parameters

Model(
  (token_embedder): Embedding(8195, 128)
  (layers): ModuleList(
    (0-3): 4 x Layer(
      (pre_context_norm): Norm()
      (context): ContextMem(
        (k_featurizer): KeyFeatureExtractor(
          (W_k): Linear(in_features=128, out_features=128, bias=False)
          (leaky_avg): LeakyAvg()
        )
        (v_featurizer): ValFeatureExtractor(
          (W_v): Linear(in_features=128, out_features=128, bias=False)
        )
  

# Training

In [3]:
import torch
from tools import get_data_loader
from train import scheduler_lambda, train

optimizer = torch.optim.AdamW(model.parameters(), lr = tcfg.lr_max, weight_decay = tcfg.weight_decay)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_lambda)

train_data_loader = get_data_loader(batch_size=tcfg.batch_size, split='train')
test_data_loader = get_data_loader(batch_size=tcfg.batch_size, split='validation')

Found cached dataset json (/Users/tunadorable/.cache/huggingface/datasets/noanabeshima___json/noanabeshima--TinyStoriesV2-226173b7dd235c68/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
Found cached dataset json (/Users/tunadorable/.cache/huggingface/datasets/noanabeshima___json/noanabeshima--TinyStoriesV2-226173b7dd235c68/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


In [4]:
if False: # set to true if you'd like to see a graph of the learning rate schedule
    import matplotlib.pyplot as plt
    
    # Generate learning rate values
    lrs = [scheduler_lambda(i) for i in range(tcfg.max_iters)]
    
    # Plot the learning rates
    plt.figure(figsize=(10, 5))
    plt.plot(lrs, label='Learning Rate')
    plt.title('Learning Rate Schedule')
    plt.xlabel('Iteration')
    plt.ylabel('Learning Rate')
    plt.grid(True)
    plt.legend()
    plt.show()

In [5]:
model, optimizer, log_data = train(
    model, 
    tokenizer, 
    cfg, 
    optimizer,
    scheduler,
    tcfg, 
    train_data_loader,
    test_data_loader,
    #log_data: list = None, 
    #detect_anomoly = False # use if you're getting crazy errors about a the gradient being broken
)

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  1%|▏                       | 1/100 [00:01<02:56,  1.78s/it]

step 0000: lr 0.000000, train loss 126.8932, val loss 126.4831, ppl inf, time elapsed: 1.45 seconds


  2%|▍                       | 2/100 [00:02<01:58,  1.20s/it]

step 0001: lr 0.010000, train loss 122.7961, val loss 122.5900, ppl inf, time elapsed: 2.27 seconds


  3%|▋                       | 3/100 [00:03<01:39,  1.02s/it]

step 0002: lr 0.009322, train loss 119.0985, val loss 118.6497, ppl inf, time elapsed: 3.08 seconds


  4%|▉                       | 4/100 [00:04<01:30,  1.06it/s]

step 0003: lr 0.007474, train loss 115.7737, val loss 114.8903, ppl inf, time elapsed: 3.88 seconds


  5%|█▏                      | 5/100 [00:05<01:25,  1.12it/s]

step 0004: lr 0.004963, train loss 112.5038, val loss 113.0391, ppl inf, time elapsed: 4.70 seconds


  6%|█▍                      | 6/100 [00:05<01:21,  1.16it/s]

step 0005: lr 0.002475, train loss 111.4729, val loss 110.9884, ppl inf, time elapsed: 5.51 seconds


  7%|█▋                      | 7/100 [00:06<01:21,  1.14it/s]

step 0006: lr 0.000692, train loss 110.9977, val loss 110.5159, ppl inf, time elapsed: 6.36 seconds


  8%|█▉                      | 8/100 [00:07<01:26,  1.06it/s]

step 0007: lr 0.009999, train loss 102.5335, val loss 101.8255, ppl inf, time elapsed: 7.44 seconds


  9%|██▏                     | 9/100 [00:08<01:27,  1.04it/s]

step 0008: lr 0.009804, train loss 93.3865, val loss 92.8597, ppl inf, time elapsed: 8.48 seconds


 10%|██▎                    | 10/100 [00:09<01:27,  1.03it/s]

step 0009: lr 0.009277, train loss 83.0495, val loss 83.9436, ppl 2859238853227330943653326567084392448, time elapsed: 9.48 seconds


 11%|██▌                    | 11/100 [00:10<01:30,  1.01s/it]

step 0010: lr 0.008456, train loss 73.3327, val loss 73.4167, ppl 76641157974354600330006710190080, time elapsed: 10.47 seconds


 12%|██▊                    | 12/100 [00:12<01:46,  1.21s/it]

step 0011: lr 0.007398, train loss 64.8192, val loss 64.7995, ppl 13869847729161464388289298432, time elapsed: 12.11 seconds


 13%|██▉                    | 13/100 [00:13<01:41,  1.17s/it]

step 0012: lr 0.006176, train loss 58.6375, val loss 58.7319, ppl 32132026096328239837347840, time elapsed: 13.31 seconds


 14%|███▏                   | 14/100 [00:14<01:37,  1.13s/it]

step 0013: lr 0.004875, train loss 55.0867, val loss 55.0899, ppl 841847429993231317729280, time elapsed: 14.34 seconds


 15%|███▍                   | 15/100 [00:15<01:34,  1.12s/it]

step 0014: lr 0.003587, train loss 54.0578, val loss 53.9384, ppl 266164179129477072158720, time elapsed: 15.43 seconds


 16%|███▋                   | 16/100 [00:16<01:32,  1.11s/it]

step 0015: lr 0.002401, train loss 53.2725, val loss 52.9111, ppl 95278838263793573691392, time elapsed: 16.53 seconds


 17%|███▉                   | 17/100 [00:17<01:28,  1.07s/it]

step 0016: lr 0.001399, train loss 52.6860, val loss 52.6941, ppl 76692275436052677656576, time elapsed: 17.51 seconds


 18%|████▏                  | 18/100 [00:18<01:26,  1.06s/it]

step 0017: lr 0.000651, train loss 52.2787, val loss 52.6341, ppl 72224038072956294791168, time elapsed: 18.54 seconds


 19%|████▎                  | 19/100 [00:19<01:24,  1.05s/it]

step 0018: lr 0.000211, train loss 52.1697, val loss 51.9371, ppl 35974166103684150198272, time elapsed: 19.56 seconds


 20%|████▌                  | 20/100 [00:21<01:25,  1.07s/it]

step 0019: lr 0.009998, train loss 48.8778, val loss 48.5807, ppl 1254069537487821209600, time elapsed: 20.70 seconds


 21%|████▊                  | 21/100 [00:22<01:22,  1.04s/it]

step 0020: lr 0.009938, train loss 44.8291, val loss 45.0295, ppl 35978914375435026432, time elapsed: 21.69 seconds


 22%|█████                  | 22/100 [00:22<01:19,  1.02s/it]

step 0021: lr 0.009792, train loss 40.7828, val loss 40.8488, ppl 550031035392327680, time elapsed: 22.63 seconds


 23%|█████▎                 | 23/100 [00:23<01:16,  1.01it/s]

step 0022: lr 0.009563, train loss 36.9452, val loss 36.7763, ppl 9369870588182528, time elapsed: 23.58 seconds


 24%|█████▌                 | 24/100 [00:24<01:14,  1.02it/s]

step 0023: lr 0.009255, train loss 32.6411, val loss 32.6947, ppl 158170475397120, time elapsed: 24.51 seconds


 25%|█████▊                 | 25/100 [00:25<01:15,  1.01s/it]

step 0024: lr 0.008873, train loss 29.0265, val loss 29.1850, ppl 4730115850240, time elapsed: 25.60 seconds


 26%|█████▉                 | 26/100 [00:26<01:12,  1.02it/s]

step 0025: lr 0.008424, train loss 25.5189, val loss 25.8396, ppl 166720454656, time elapsed: 26.53 seconds


 27%|██████▏                | 27/100 [00:27<01:11,  1.02it/s]

step 0026: lr 0.007917, train loss 23.3407, val loss 22.8709, ppl 8564744704, time elapsed: 27.50 seconds


 28%|██████▍                | 28/100 [00:28<01:10,  1.02it/s]

step 0027: lr 0.007359, train loss 20.5724, val loss 20.7875, ppl 1066389248, time elapsed: 28.46 seconds


 29%|██████▋                | 29/100 [00:29<01:11,  1.00s/it]

step 0028: lr 0.006761, train loss 19.0581, val loss 19.0913, ppl 195539760, time elapsed: 29.53 seconds


 30%|██████▉                | 30/100 [00:30<01:10,  1.01s/it]

step 0029: lr 0.006133, train loss 17.8579, val loss 17.9189, ppl 60546628, time elapsed: 30.54 seconds


 31%|███████▏               | 31/100 [00:31<01:09,  1.01s/it]

step 0030: lr 0.005486, train loss 16.9118, val loss 17.1066, ppl 26872990, time elapsed: 31.55 seconds


 32%|███████▎               | 32/100 [00:33<01:12,  1.06s/it]

step 0031: lr 0.004832, train loss 16.1411, val loss 16.4623, ppl 14108068, time elapsed: 32.68 seconds


 33%|███████▌               | 33/100 [00:34<01:12,  1.08s/it]

step 0032: lr 0.004181, train loss 15.8663, val loss 15.6894, ppl 6513406, time elapsed: 33.85 seconds


 34%|███████▊               | 34/100 [00:35<01:09,  1.05s/it]

step 0033: lr 0.003545, train loss 14.4945, val loss 14.8151, ppl 2717272, time elapsed: 34.84 seconds


 35%|████████               | 35/100 [00:36<01:10,  1.08s/it]

step 0034: lr 0.002936, train loss 14.5083, val loss 14.4243, ppl 1838242, time elapsed: 36.02 seconds


 36%|████████▎              | 36/100 [00:37<01:09,  1.08s/it]

step 0035: lr 0.002364, train loss 14.2519, val loss 13.9798, ppl 1178579, time elapsed: 37.09 seconds


 37%|████████▌              | 37/100 [00:38<01:06,  1.05s/it]

step 0036: lr 0.001839, train loss 13.8519, val loss 13.7623, ppl 948155, time elapsed: 38.06 seconds


 38%|████████▋              | 38/100 [00:39<01:04,  1.04s/it]

step 0037: lr 0.001369, train loss 13.5570, val loss 13.3443, ppl 624230, time elapsed: 39.07 seconds


 39%|████████▉              | 39/100 [00:40<01:03,  1.04s/it]

step 0038: lr 0.000965, train loss 13.3519, val loss 13.2148, ppl 548419, time elapsed: 40.07 seconds


 40%|█████████▏             | 40/100 [00:41<01:03,  1.06s/it]

step 0039: lr 0.000632, train loss 13.2083, val loss 12.9441, ppl 418354, time elapsed: 41.22 seconds


 41%|█████████▍             | 41/100 [00:42<01:01,  1.04s/it]

step 0040: lr 0.000376, train loss 13.1581, val loss 12.9883, ppl 437260, time elapsed: 42.21 seconds


 42%|█████████▋             | 42/100 [00:43<00:59,  1.03s/it]

step 0041: lr 0.000202, train loss 13.0446, val loss 13.2518, ppl 569081, time elapsed: 43.21 seconds


 43%|█████████▉             | 43/100 [00:44<00:57,  1.00s/it]

step 0042: lr 0.000112, train loss 13.0681, val loss 13.2403, ppl 562559, time elapsed: 44.14 seconds


 44%|██████████             | 44/100 [00:45<00:55,  1.01it/s]

step 0043: lr 0.009998, train loss 12.3136, val loss 12.0247, ppl 166821, time elapsed: 45.14 seconds


 44%|██████████             | 44/100 [00:46<00:58,  1.05s/it]


KeyboardInterrupt: 

# inference test before you decide to save it

In [7]:
from inference import generate
prompt = "Once upon a time"
model.eval()
output = generate(
    prompt, 
    model, 
    tokenizer,
    #max_gen_len = 512,
    temperature = 0.9,
    #memory_saver_div = 8,
    #top_p = 0.9,
    #top_k = 32,
)
model.train()
print(output)

Once upon a timerelaxedJufencetoldCraBrookefinishjuicywedSinationbalancingzippingsuggestedsettingponiesbeautiinsistedyummiestundersufferingtrafficabyscoopidsJuliasplashingspacesrarefinalvaluableippdivefallendancersdiscoveredalrightcloudsrollCawoolkidspolishesscenouredapolotrunkstraofferadultsAbbieinfantshockedwichscoocompasstootellangryapologibrushesDoggy'sfoldeddarBujoggingKatexactlyGuardianhighllCharlotreturneddays: holesopeningexpercircusOnexttimeinvlonelyThencrumbelfchTwinkleaskedtowncubsscarvesiansetestedsecup, bossyaidwebteamshortwelcomlotionscallslawcrackershelperardenungstarsJaneDanthanksudgedtrafficburnt


# Saving your final model
if `tcfg.checkpoint_interval != None` then checkpoints have already been saved

you DO still need to do this even if you had been saving checkpoints; the final state has not yet been saved

In [8]:
from tools import save_model
save_model(model, cfg, tcfg, log_data)

NameError: name 'log_data' is not defined