# ToDo

- [x] figure out how to implement memory_saver_div into the kv cache
- [x] add dropout
- [ ] train bigger version (longer context length?)
- [ ] copy & paste into model.py
- [ ] make a train.py
- [ ] make a params.py
- [ ] build colab notebook

# Setup

In [None]:
# Jupyter kernels are often not started from the project's virtual environment, so
# make the packages installed in ./venv importable by putting the venv's
# site-packages directory on sys.path.
import os
import sys

current_dir = os.getcwd()  # the venv is looked up relative to the working directory
venv_dir = os.path.join(current_dir, 'venv')
python_version = f'{sys.version_info.major}.{sys.version_info.minor}'

# venv lays out site-packages differently per platform:
#   POSIX:   venv/lib/pythonX.Y/site-packages
#   Windows: venv\Lib\site-packages
if os.name == 'nt':
    site_packages_path = os.path.join(venv_dir, 'Lib', 'site-packages')
else:
    site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')

# Only register a directory that really exists, and only once, so re-running this
# cell does not keep growing sys.path.
if os.path.isdir(site_packages_path) and site_packages_path not in sys.path:
    sys.path.append(site_packages_path)

In [1]:
# importing the model config
from params import *

# importing minLlama3
from model import *

# used in the training loop
import time

# used to save & load models
import json
from dataclasses import asdict

In [2]:
# Build the model configuration using ModelArgs' default values
params = ModelArgs()
# Print the configuration so the settings used for this run are recorded in the cell output
print(params)

ModelArgs(dim=128, n_layers=8, n_heads=4, n_kv_heads=1, vocab_size=512, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=10000, max_batch_size=32, max_seq_len=512, device='cpu', dropout_rate=0.1)


In [3]:
# Read the entire training corpus into a single string
with open('input.txt', 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# Preview the first 200 characters. The file is one continuous text document
# containing all of the works of Shakespeare back-to-back
print(text[:200])

# Encode the corpus with the tokenizer and keep the token ids as a long tensor
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)

# Train/validation split: the leading 90% is for training, the remainder for validation
n = int(0.9*len(data))
train_data, val_data = data[:n], data[n:]

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


# Instantiate a brand new model

In [4]:
# Build a brand new model from the config and tokenizer, then move it to the configured device
model = Llama3(params, tokenizer)
model = model.to(params.device)

# Report the model size, in thousands of parameters
total_params = sum(p.numel() for p in model.parameters())
print(total_params/1e3, 'K parameters')

# Show the module hierarchy
print(model)

2033.792 K parameters
Llama3(
  (tok_embeddings): Embedding(512, 128)
  (layers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=128, out_features=128, bias=False)
        (wk): Linear(in_features=128, out_features=32, bias=False)
        (wv): Linear(in_features=128, out_features=32, bias=False)
        (wo): Linear(in_features=128, out_features=128, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=128, out_features=512, bias=False)
        (w2): Linear(in_features=512, out_features=128, bias=False)
        (w3): Linear(in_features=128, out_features=512, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=128, out_features=512, bias=False)
  (criterion): CrossEntropyLoss()
)


# Training

In [5]:
def get_batch(split, batch_size):
    """Sample a random batch of inputs ``x`` and next-token targets ``y``.

    Args:
        split: 'train' draws from ``train_data``; any other value draws from ``val_data``.
        batch_size: number of sequences to sample.

    Returns:
        Tuple ``(x, y)``, both moved to ``params.device``. Each row of ``x`` holds
        ``params.max_seq_len`` consecutive tokens and the matching row of ``y`` holds
        the same window shifted one token to the right.
    """
    source = train_data if split == 'train' else val_data
    seq_len = params.max_seq_len

    # random window starts; the upper bound leaves room for the one-token shift of the targets
    starts = torch.randint(len(source) - seq_len, (batch_size,)).tolist()

    inputs = [source[s:s + seq_len] for s in starts]
    targets = [source[s + 1:s + seq_len + 1] for s in starts]

    x = torch.stack(inputs).to(params.device)
    y = torch.stack(targets).to(params.device)
    return x, y

In [6]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 5):
    """Estimate the mean loss on the train and validation splits.

    Runs ``eval_iters`` random batches per split through the model with gradients
    disabled and the model in eval mode.

    Args:
        model: module called as ``model(X, targets=Y)`` and returning ``(logits, loss)``.
        batch_size: number of sequences per evaluation batch (forwarded to ``get_batch``).
        eval_iters: number of batches averaged per split.

    Returns:
        Dict with keys ``'train'`` and ``'val'``, each mapping to the mean loss
        as a 0-dim tensor.
    """
    # Remember the caller's mode so a model that was already in eval mode is not
    # silently flipped to training mode once we are done.
    was_training = model.training
    model.eval() # sets model to eval mode
    try:
        out = {}
        for split in ('train', 'val'):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split, batch_size)
                _, loss = model(X, targets=Y)  # logits are not needed here
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # restore the original mode even if evaluation raised
        model.train(was_training)
    return out

In [7]:
# --- training length ---
max_iters = 10        # how long we want to train for
eval_interval = 5     # how often we want to check & see how our loss is doing

# --- learning-rate schedule ---
lr_init = 1e-2        # peak learning rate handed to the optimizer
lr_final = 1e-5       # minimum learning rate
warmup_iters = 2      # number of warmup iterations
warmup_factor = 1e-2  # the initial learning rate is multiplied by this factor at the start of warmup

# --- optimizer ---
# These are not the settings used for the original model, but this learning rate
# & weight decay work for our tiny minLlama3
weight_decay = 0.02
optimizer = torch.optim.AdamW(model.parameters(), lr=lr_init, weight_decay=weight_decay)

def lr_lambda(current_iter):
    """Return the factor that ``lr_init`` is multiplied by at step ``current_iter``.

    The schedule has two phases:
      * warmup: the factor rises linearly from ``warmup_factor`` towards 1 over
        the first ``warmup_iters`` steps;
      * decay: the factor follows half a cosine wave from 1 down towards 0 over
        the remaining ``max_iters - warmup_iters`` steps, floored at
        ``lr_final / lr_init`` so the learning rate never drops below ``lr_final``.

    Steps at or beyond ``max_iters`` stay at the floor.

    Args:
        current_iter: zero-based step counter supplied by the scheduler.

    Returns:
        The multiplicative learning-rate factor as a float.
    """
    import math  # local import so the schedule does not depend on a star import providing it

    if current_iter < warmup_iters:
        # Warmup phase
        return warmup_factor + (1 - warmup_factor) * current_iter / warmup_iters

    min_factor = lr_final / lr_init
    decay_iters = max_iters - warmup_iters
    if decay_iters <= 0:
        # No decay phase is configured (warmup covers the whole run); avoid dividing by zero
        return min_factor

    # Clamp so the periodic cosine cannot raise the rate again after max_iters
    elapsed = min(current_iter - warmup_iters, decay_iters)
    cosine_decay = 0.5 * (1 + math.cos(math.pi * elapsed / decay_iters))
    return max(cosine_decay, min_factor)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [8]:
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

# The loop variable is named `step` rather than `iter` so the builtin iter()
# stays usable in every cell that runs after this one.
for step in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', params.max_batch_size)

    # forward pass, then backpropagate and update the weights
    logits, loss = model(xb, targets=yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Update the learning rate (after optimizer.step(), as PyTorch schedulers expect)
    scheduler.step()

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        # measured before the evaluation below, so it excludes this evaluation's own cost
        elapsed_time = time.time() - start_time
        losses = estimate_loss(model, params.max_batch_size)
        # read after scheduler.step(): this is the rate the next step will use
        current_lr = optimizer.param_groups[0]['lr']
        print(f"step {step:04d}: lr {current_lr:.6f}, train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0000: lr 0.005050, train loss 6.3305, val loss 6.3268, time elapsed: 3.69 seconds
step 0005: lr 0.005000, train loss 4.4239, val loss 4.4928, time elapsed: 31.13 seconds
step 0009: lr 0.000010, train loss 4.2324, val loss 4.3182, time elapsed: 55.30 seconds


# Inference test before you decide to save it

In [9]:
# Prompt the trained model with an unfinished line and print what it generates
prompt = "JULIET:\nO Romeo, Romeo! wherefore art thou R"
print(model.generate(prompt))

JULIET:
O Romeo, Romeo! wherefore art thou R a a for that mo the be you to doe have com sa of him your to the fort the ha in not the me a you is mo I of my a the that him to that al:

n pr, and you the of that to of that do you the to of of the so
To the the de de the to ma hes di ma, to hoe have sh, wo de the thee, the de, will be him not wo in the the of pr be my thee in the go do of be a hoe my dod to the with wo with you and deesn the dor the your pr c the be I have f that for w will this mod the to de a in that not the la it sh the prng br the thy the a ca that de to dore's ca of ws to iss a mo trll, w, not the from sh he to f not do I a wo a mo to fors the fa
l di's a the be so:
B in sh that have you wo and my wo di the be the to it this it ho, not in c in not pr I wo.

At lo sh, co not the of of to that dore my pr ct do's w in that a


# Saving your model

In [40]:
import os  # imported here because the setup cell that imports os may not have been run

# Timestamped base name shared by the weights file (.pth) and the config file (.json).
# NOTE(review): '|' is not a legal filename character on Windows; it is left as-is
# so the checkpoint naming scheme does not change.
name = f'models/{model.__class__.__name__}_{time.strftime("%Y-%m-%d|%H-%M-%S")}'

# torch.save does not create missing directories, so make sure models/ exists first
os.makedirs(os.path.dirname(name), exist_ok=True)
torch.save(model.state_dict(), f'{name}.pth')

# Convert the dataclass object to a dictionary
params_dict = asdict(params)

# Serialize the dictionary to a JSON file next to the weights
with open(f'{name}.json', 'w', encoding='utf-8') as f:
    json.dump(params_dict, f)