# Setup

In [3]:
# Make packages installed in ./venv importable even when the Jupyter kernel is not running
# from that virtual environment. Safe in any environment: the path is only added when the
# directory actually exists, and only once, so re-running this cell is a no-op.
import sys
import os
current_dir = os.getcwd()  # the venv is expected to live in the notebook's working directory
venv_dir = os.path.join(current_dir, 'venv')
python_version = str(sys.version_info.major) + '.' + str(sys.version_info.minor)
if os.name == 'nt':
    # Windows venvs keep packages in venv\Lib\site-packages (no per-version directory)
    site_packages_path = os.path.join(venv_dir, 'Lib', 'site-packages')
else:
    # POSIX venvs use venv/lib/pythonX.Y/site-packages; assumes the venv was created with
    # the same major.minor Python as this kernel
    site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')
if os.path.isdir(site_packages_path) and site_packages_path not in sys.path:
    sys.path.append(site_packages_path)

In [4]:
# importing the model config
from config import *

# importing N-GPT
from model import cosine_norm, Model

# PyTorch is used directly by later cells (data tensors, get_batch, optimizer, training
# loop) before the cell that imports it, so import it explicitly here rather than
# depending on some earlier import having provided the name.
import torch

# imports for the tokenizer
import pickle
from tokenizer.tokenizer import BPE_Tokenizer

# used in the training loop
import time
import math

# used to save & load models
import json
from dataclasses import asdict

# Instantiate a brand new model

In [6]:
# Build the model and training configs with their default values, and show them.
cfg = ModelConfig()
print(cfg)
tcfg = TrainConfig()
print(tcfg)

# Load the trained BPE merges. Size options (cfg.vocab_len) are 512, 1024 and 2048; the
# tokenizer file is named after vocab_len - 3 (presumably 3 ids are reserved for special
# tokens — TODO confirm against the tokenizer code).
# NOTE: pickle.load can execute arbitrary code — only load tokenizer files you trust.
tokenizer_file = f'tokenizer/models/{cfg.vocab_len - 3}.model'
with open(tokenizer_file, 'rb') as fh:
    tokenizer_data = pickle.load(fh)
tokenizer = BPE_Tokenizer(tokenizer_data['merges'])

ModelConfig(dim=96, device='mps', max_seq_len=256, theta=10000, vocab_len=2048, num_layers=8, num_heads=4, mlp_hidden_mult=4)
TrainConfig(model_name='N-GPT_1m', micro_batch_size=16, grad_accum_steps=4, max_iters=1000, eval_interval=50, beta1=0.9, beta2=0.95, epsilon=1e-08, lr_init=0.005, lr_final=1e-05)


In [7]:
model = Model(cfg).to(cfg.device)

# print the number of parameters in the model
print(f'{model.get_num_params()} parameters')
print(model)

1089792 parameters
Model(
  (precompute_freqs): PrecomputeRotaryFrequencies()
  (token_embedder): Embedding(2048, 96)
  (layers): ModuleList(
    (0-7): 8 x Layer(
      (attn): SelfAttention(
        (Wq): Linear(in_features=96, out_features=96, bias=False)
        (Wk): Linear(in_features=96, out_features=96, bias=False)
        (Wv): Linear(in_features=96, out_features=96, bias=False)
        (Wo): Linear(in_features=96, out_features=96, bias=False)
      )
      (mlp): MLP(
        (Wup): Linear(in_features=96, out_features=256, bias=False)
        (Wgate): Linear(in_features=96, out_features=256, bias=False)
        (Wdown): Linear(in_features=256, out_features=96, bias=False)
      )
    )
  )
  (output): Linear(in_features=96, out_features=2048, bias=False)
  (criterion): CrossEntropyLoss()
)


# Training

In [9]:
# Load the training corpus: a single continuous text file with all of Shakespeare's
# works back-to-back.
with open('input.txt', 'r', encoding='utf-8') as corpus:
    text = corpus.read()

# Peek at the first 200 characters.
print(text[:200])

# Tokenize the whole corpus into a tensor of token ids.
# NOTE(review): the saved output of this cell ends in KeyboardInterrupt, so a run was
# stopped partway — encoding the full corpus may be slow; consider caching the ids.
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)

# The first 90% of tokens is the training set; the remaining 10% is held out for validation.
split_at = int(0.9*len(data))
train_data, val_data = data[:split_at], data[split_at:]

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


KeyboardInterrupt: 

In [None]:
# data loading for training which generates a small batch of data of inputs x and targets y
def get_batch(split, batch_size):
    """Sample `batch_size` random (input, target) windows of cfg.max_seq_len tokens.

    Args:
        split: 'train' draws from train_data; any other value draws from val_data.
        batch_size: number of windows to sample.

    Returns:
        (x, y) stacked tensors on cfg.device, where y is x shifted one token to the
        right (the next-token targets).
    """
    source = val_data if split != 'train' else train_data
    seq_len = cfg.max_seq_len
    # upper bound is exclusive, so the target window (start+1 .. start+seq_len) stays in range
    starts = torch.randint(len(source) - seq_len, (batch_size,))
    inputs = torch.stack([source[s:s+seq_len] for s in starts])
    targets = torch.stack([source[s+1:s+seq_len+1] for s in starts])
    return inputs.to(cfg.device), targets.to(cfg.device)

In [None]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 3): # to estimate loss during the training loop
    """Estimate the mean loss on the train and val splits without tracking gradients.

    Args:
        model: called as model(X, target_token_ids=Y) and expected to return (logits, loss).
        batch_size: micro-batch size passed to get_batch.
        eval_iters: number of random batches averaged per split.

    Returns:
        {'train': mean_loss, 'val': mean_loss}, each a 0-dim tensor.

    The model is switched to eval mode for the measurement. Its previous train/eval mode
    is restored afterwards, even if evaluation is interrupted (e.g. KeyboardInterrupt).
    """
    out = {}
    was_training = model.training
    model.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split, batch_size)
                _, loss = model(X, target_token_ids=Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # restore whatever mode the caller had, instead of unconditionally forcing train mode
        model.train(was_training)
    return out

In [None]:
# Optimizer: AdamW with weight decay disabled, to keep vectors on the unit hypersphere.
optimizer = torch.optim.AdamW(model.parameters(), lr=tcfg.lr_init, weight_decay=0.0)

def lr_lambda(current_iter):
    """Multiplicative LR factor for LambdaLR (no warmup).

    Cosine-decays from 1.0 at iteration 0 toward 0 at tcfg.max_iters, floored at
    lr_final / lr_init so the effective LR never drops below tcfg.lr_final.
    """
    floor = tcfg.lr_final / tcfg.lr_init
    cosine_factor = 0.5 * (1 + math.cos(math.pi * current_iter / tcfg.max_iters))
    return max(cosine_factor, floor)

# LambdaLR multiplies the base LR (tcfg.lr_init) by lr_lambda(step) on every scheduler.step().
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [None]:
start_time = time.time()
model.train()

# Resolve once which weights get re-normalized after every optimizer step. The modules
# and their weight shapes never change during training, so there is no need to re-scan
# every module's shape on every step. Each entry is (module, dim): an nn.Linear whose
# weight has an axis of size cfg.dim, normalized along the FIRST such axis.
# NOTE(review): for square (cfg.dim x cfg.dim) weights this always picks dim 0. Also,
# nn.Embedding (token_embedder) is not an nn.Linear, so it is never re-normalized here.
# Confirm both are intended.
norm_targets = []
for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        dim_to_normalize = next(
            (dim for dim, size in enumerate(module.weight.shape) if size == cfg.dim),
            None,
        )
        if dim_to_normalize is not None:
            norm_targets.append((module, dim_to_normalize))

# Parameters registered directly on each layer as 'a_A' / 'a_M'. They are forced
# non-negative after every optimizer step. Holding the Parameter objects is safe because
# the updates below only replace `.data`, never the Parameter itself.
nonneg_params = [
    param
    for layer in model.layers
    for name, param in layer.named_parameters()
    if name in ('a_A', 'a_M')
]

for step in range(tcfg.max_iters):
    # every eval_interval steps (and on the final step) estimate the train/val loss
    if step % tcfg.eval_interval == 0 or step == tcfg.max_iters - 1:
        elapsed_time = time.time() - start_time
        losses = estimate_loss(model, tcfg.micro_batch_size)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"step {step:04d}: lr {current_lr:.6f}, train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

    # setup for training
    optimizer.zero_grad()
    # Running sum of the scaled micro-batch losses. It is not printed (the losses logged
    # above come from estimate_loss), but it is kept available for inspection/logging.
    loss_accum = 0.0

    # simulate a larger batch size by accumulating gradients over several micro batches
    for micro_step in range(tcfg.grad_accum_steps):
        xb, yb = get_batch('train', tcfg.micro_batch_size)
        logits, loss = model(input_token_ids = xb, target_token_ids = yb)
        # scale so the accumulated gradient averages over micro batches instead of summing
        loss = loss / tcfg.grad_accum_steps
        loss_accum += loss.detach()
        loss.backward()

    # update the parameters
    optimizer.step()

    with torch.no_grad():
        # re-normalize each targeted Linear weight with cosine_norm along its chosen axis
        for module, dim_to_normalize in norm_targets:
            module.weight.data = cosine_norm(module.weight.data, dim=dim_to_normalize)
        # keep the a_A / a_M scales non-negative
        for param in nonneg_params:
            param.data = param.data.abs()

    # advance the cosine learning-rate schedule
    scheduler.step()

In [None]:
# Post-training sanity checks. The training loop enforces both conditions after every
# optimizer step, so neither warning is expected to print.

# 1) The a_A / a_M parameters of every ModuleList entry must be non-negative.
for module_list in (m for m in model.modules() if isinstance(m, torch.nn.ModuleList)):
    for submodule in module_list:
        for name, param in submodule.named_parameters():
            if name in ['a_A', 'a_M'] and (param.data < 0).any():
                print(f"Warning: Parameter {name} contains negative values post-training.")

# 2) Every nn.Linear weight with an axis of size cfg.dim must have unit norm along the
#    first such axis (within atol=1e-4).
for module in model.modules():
    if not isinstance(module, torch.nn.Linear):
        continue
    weight = module.weight
    candidate_dims = [dim for dim, size in enumerate(weight.shape) if size == cfg.dim]
    if not candidate_dims:
        continue
    norm = weight.data.norm(dim=candidate_dims[0])
    if not torch.allclose(norm, torch.ones_like(norm), atol=1e-4):
        print(f"Warning: Weights in Linear layer {module} are not normalized post-training.\n{norm}")

In [None]:
import torch
from collections import defaultdict
from tabulate import tabulate

# Summarize the learnable scaling parameters: for every module that directly owns one,
# record its shape, mean and std, grouped by parameter name.
params = defaultdict(list)
scale_names = ['a_A', 'a_M', 's_qk', 's_u', 's_v', 's_z']

with torch.no_grad():
    for module in model.modules():
        # recurse=False: visit each parameter once, from the module that directly owns it.
        # A recursive call would re-walk every subtree from every ancestor module, and only
        # undotted (directly owned) names can ever equal an entry of scale_names anyway.
        for name, param in module.named_parameters(recurse=False):
            if name in scale_names:
                params[name].append({
                    'shape': tuple(param.shape),
                    'mean': torch.mean(param).item(),
                    # The std of fewer than two elements is undefined. Report nan directly
                    # instead of letting torch compute it and emit a degrees-of-freedom warning.
                    'std': torch.std(param).item() if param.numel() > 1 else float('nan'),
                })

# Print one table per parameter name; rows follow model.modules() order.
for param_name in scale_names:
    if params[param_name]:
        print(f"\n=== {param_name} Parameters ===")
        table_data = [[
            i+1,
            str(p['shape']),
            f"{p['mean']:.4f}",
            f"{p['std']:.4f}"
        ] for i, p in enumerate(params[param_name])]

        print(tabulate(
            table_data,
            headers=['#', 'Shape', 'Mean', 'Std'],
            tablefmt='simple',
            floatfmt='.4f'
        ))

# Inference test: sample from the model before deciding whether to save it

In [None]:
from inference import generate

# Quick qualitative check: continue a well-known prompt.
prompt = "JULIET:\nO Romeo, Romeo! wherefore art thou"
output = generate(
    prompt,
    model,
    tokenizer,
    # NOTE: the author found a very low temperature necessary here; the reason is
    # unexplained and worth investigating
    temperature=0.01,
    max_gen_len=128,
)
print(output)

# Saving your model

In [None]:
# Everything for this run is saved under models/<model_name>/.
save_dir = f'models/{tcfg.model_name}'
os.makedirs(save_dir, exist_ok=True)

# model weights
torch.save(model.state_dict(), f'{save_dir}/model.pth')

# both configs as JSON, so the model can be rebuilt with the same hyperparameters
cfg_dict = asdict(cfg)
tcfg_dict = asdict(tcfg)
for filename, payload in (('model_config.json', cfg_dict), ('train_config.json', tcfg_dict)):
    with open(f'{save_dir}/{filename}', 'w') as f:
        json.dump(payload, f)

print(f'model successfully saved to {save_dir}/')