# ToDo

- [x] figure out how to implement memory_saver_div into the kv cache
- [x] add dropout
- [ ] train bigger version (longer context length?)
- [ ] copy & paste into model.py
- [ ] make a train.py
- [ ] make a params.py
- [ ] build colab notebook

# Setup

In [1]:
# Jupyter kernels frequently run outside the project's virtual environment,
# so make the venv's site-packages importable by hand.
import os
import sys

# "<major>.<minor>" of the running interpreter, e.g. "3.11"
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"

# The venv is expected to live in ./venv relative to where the notebook was launched
current_dir = os.getcwd()
venv_dir = os.path.join(current_dir, 'venv')

# POSIX-style venv layout: venv/lib/python<major>.<minor>/site-packages
site_packages_path = os.path.join(
    venv_dir, 'lib', f'python{python_version}', 'site-packages'
)

# Appended (not prepended), so anything already importable keeps priority
sys.path.append(site_packages_path)

In [2]:
# importing the model config
from params import *

# importing minLlama3
from model import *

# used in the training loop
import time

# used to save & load models
import json
from dataclasses import asdict

In [3]:
# Read the entire corpus into memory as one string
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Sanity check: show the first 200 characters of the corpus.
# The file is a single continuous document (Shakespeare's works back-to-back).
print(text[:200])

# Tokenize the full text once and keep it as a tensor of int64 token ids
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)

# Contiguous split: the first 90% of tokens train the model, the remaining 10% validate it
n = int(0.9*len(data))
train_data, val_data = data[:n], data[n:]

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


# Instantiate a brand new model

In [10]:
# Build the model configuration; no overrides are passed, so every field takes
# whatever default ModelArgs declares.
params = ModelArgs()
# Echo the resolved configuration so the run records which settings were used
print(params)

ModelArgs(dim=128, n_layers=12, n_heads=4, n_kv_heads=1, vocab_size=512, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=10000, max_batch_size=24, max_seq_len=512, device='cpu', dropout_rate=0.1)


In [11]:
# Construct the model from the config and tokenizer, then move it to the configured device
model = Llama3(params, tokenizer).to(params.device)

# print the number of parameters in the model (in thousands)
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# print the module tree so the layer shapes can be checked by eye
print(model)

2985.088 K parameters
Llama3(
  (tok_embeddings): Embedding(512, 128)
  (layers): ModuleList(
    (0-11): 12 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=128, out_features=128, bias=False)
        (wk): Linear(in_features=128, out_features=32, bias=False)
        (wv): Linear(in_features=128, out_features=32, bias=False)
        (wo): Linear(in_features=128, out_features=128, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=128, out_features=512, bias=False)
        (w2): Linear(in_features=512, out_features=128, bias=False)
        (w3): Linear(in_features=128, out_features=512, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=128, out_features=512, bias=False)
  (criterion): CrossEntropyLoss()
)


# Training

In [12]:
# data loading for training which generates a small batch of data of inputs x and targets y
def get_batch(split, batch_size):
    """Sample a random batch of (input, target) token windows.

    Args:
        split: 'train' draws from ``train_data``; any other value draws
            from ``val_data``.
        batch_size: number of windows to sample.

    Returns:
        A tuple ``(x, y)`` of stacked windows, each ``params.max_seq_len``
        tokens long and moved to ``params.device``. ``y`` is ``x`` shifted
        one position to the right, i.e. the next-token targets.
    """
    source = train_data if split == 'train' else val_data
    seq_len = params.max_seq_len

    # Random window starts; the upper bound leaves room for the extra
    # target token at the end of the last possible window.
    starts = torch.randint(len(source) - seq_len, (batch_size,))

    inputs = torch.stack([source[s:s+seq_len] for s in starts])
    targets = torch.stack([source[s+1:s+seq_len+1] for s in starts])
    return inputs.to(params.device), targets.to(params.device)

In [13]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 5): # to estimate loss during the training loop
    """Estimate the mean loss on the train and validation splits.

    Runs with gradients disabled and with the model in eval mode, averaging
    the loss over ``eval_iters`` freshly sampled batches per split.

    Args:
        model: callable as ``model(X, targets=Y)`` returning ``(logits, loss)``.
        batch_size: number of sequences per evaluation batch.
        eval_iters: number of batches averaged per split.

    Returns:
        dict with keys 'train' and 'val', each a scalar tensor holding the
        mean loss for that split.
    """
    # Remember the caller's mode so it can be restored afterwards instead of
    # unconditionally forcing training mode.
    was_training = model.training
    model.eval() # sets model to eval mode
    out = {}
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split, batch_size)
                _, loss = model(X, targets=Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore the previous mode even if evaluation raised
        model.train(was_training)
    return out

In [14]:
# --- training length -------------------------------------------------------
max_iters = 2000       # total number of optimizer steps
eval_interval = 100    # estimate train/val loss every this many steps

# --- learning-rate schedule -------------------------------------------------
lr_init = 1e-2         # base learning rate handed to the optimizer
lr_final = 1e-5        # minimum learning rate the schedule may decay to
warmup_iters = 50      # number of warmup iterations
warmup_factor = 1e-3   # the base learning rate is multiplied by this at iteration 0

# --- optimizer ---------------------------------------------------------------
# These values were hand-picked because they work for this tiny minLlama3 model;
# they are not taken from the original model's training recipe.
weight_decay = 0.02
optimizer = torch.optim.AdamW(model.parameters(), lr=lr_init, weight_decay=weight_decay)

def lr_lambda(current_iter):
    """Return the learning-rate multiplier for iteration ``current_iter``.

    The schedule has two phases:
      * warmup (``current_iter < warmup_iters``): linear ramp from
        ``warmup_factor`` up to 1;
      * decay: half-cosine from 1 down towards 0 over the remaining
        ``max_iters - warmup_iters`` iterations, never dropping below
        ``lr_final / lr_init``.

    Iterations past ``max_iters`` stay at the floor rather than following the
    cosine back up, and a schedule with no decay phase
    (``max_iters <= warmup_iters``) returns the floor instead of dividing by
    zero.

    Args:
        current_iter: zero-based count of scheduler steps taken so far.

    Returns:
        A float multiplier to apply to the base learning rate ``lr_init``.
    """
    # Imported locally so the function does not depend on `math` leaking in
    # through a star import elsewhere in the notebook.
    import math

    if current_iter < warmup_iters:
        # Warmup phase
        return warmup_factor + (1 - warmup_factor) * current_iter / warmup_iters

    min_ratio = lr_final / lr_init
    decay_iters = max_iters - warmup_iters
    if decay_iters <= 0:
        # No room for a decay phase: hold at the minimum learning rate
        return min_ratio

    # Clamp so that training past max_iters does not ride the cosine back up
    elapsed = min(current_iter - warmup_iters, decay_iters)
    cosine_decay = 0.5 * (1 + math.cos(math.pi * elapsed / decay_iters))
    return max(cosine_decay, min_ratio)
        
# LambdaLR sets each param group's lr to its initial lr times lr_lambda(n),
# where n counts the scheduler.step() calls made so far.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [15]:
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

# The loop variable is named `step` rather than `iter` so the builtin iter()
# is not shadowed in the notebook namespace after this cell runs.
for step in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', params.max_batch_size)

    # forward pass; the model returns the loss when targets are supplied
    logits, loss = model(xb, targets=yb)

    # clear old gradients, backpropagate, and apply the update
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Update the learning rate
    scheduler.step()

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        # elapsed time is taken before evaluation, so it excludes the cost of
        # the estimate_loss call below
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, params.max_batch_size)
        # read after scheduler.step(), so this is the rate the *next* step will use
        current_lr = optimizer.param_groups[0]['lr']
        print(f"step {step:04d}: lr {current_lr:.6f}, train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0000: lr 0.000210, train loss 6.4700, val loss 6.4506, time elapsed: 3.92 seconds
step 0100: lr 0.009983, train loss 3.0054, val loss 3.1842, time elapsed: 382.95 seconds
step 0200: lr 0.009853, train loss 2.5898, val loss 2.8687, time elapsed: 769.07 seconds
step 0300: lr 0.009597, train loss 2.3720, val loss 2.6845, time elapsed: 1161.52 seconds
step 0400: lr 0.009222, train loss 2.1987, val loss 2.6030, time elapsed: 1545.14 seconds
step 0500: lr 0.008737, train loss 2.1417, val loss 2.5857, time elapsed: 1924.22 seconds
step 0600: lr 0.008156, train loss 2.0663, val loss 2.5644, time elapsed: 2301.25 seconds
step 0700: lr 0.007493, train loss 1.9619, val loss 2.5837, time elapsed: 2680.06 seconds
step 0800: lr 0.006765, train loss 1.9134, val loss 2.5494, time elapsed: 3057.81 seconds
step 0900: lr 0.005992, train loss 1.8540, val loss 2.5846, time elapsed: 3436.21 seconds
step 1000: lr 0.005193, train loss 1.7917, val loss 2.6647, time elapsed: 3831.66 seconds
step 1100: lr 0

KeyboardInterrupt: 

# inference test before you decide to save it

In [21]:
# Prompt the model with an unfinished line and print whatever text `generate` returns
print(model.generate("JULIET:\nO Romeo, Romeo! wherefore art thou R"))

JULIET:
O Romeo, Romeo! wherefore art thou Romeo?
Ah, Romeo! thy hardness shall have heavy fire,
That thou shalt swear thy country and thy brother,
Shall I be patient to thy grief again;
And I for thee here a barge for thee,
Look, then thou hast thy side both being all,
But thou hast won thee to thy hand and me?
O she be not thy shoulder-black deserved thoughts,
Or I with banishment with some axe,
Shall I ever will have the time a happy days
And turn thy love, and my good night have shorted
That I am in a golden side.

ROMEO:
I am a poison of this fair is a soldier,
And I am gone with testimonies with my head.

FRIAR LAURENCE:
A thousand times are too dear fortune's death!

ROMEO:
A fair sir, that is thy device that way
Had on thy lives and meditating with my hand;
The trumpet, book in thy daughter shall be gone.

FRIAR LAURENCE:
And thou love me, and first 


# Saving your model

In [22]:
import os

# torch.save will not create missing directories; make sure the target exists
# so a finished training run is not lost to a FileNotFoundError.
os.makedirs('models', exist_ok=True)

# Checkpoint base name: models/<ModelClass>_<timestamp>
# NOTE(review): the '|' in the timestamp is not a legal filename character on
# Windows; left unchanged so existing naming stays consistent.
name = f'models/{model.__class__.__name__}_{time.strftime("%Y-%m-%d|%H-%M-%S")}'

# Save only the weights (state dict), not the pickled module
torch.save(model.state_dict(), f'{name}.pth')

# Convert the dataclass object to a dictionary
params_dict = asdict(params)

# Serialize the dictionary to a JSON file next to the weights so the model
# can be rebuilt with the same configuration
with open(f'{name}.json', 'w') as f:
    json.dump(params_dict, f)