# Setup

In [1]:
# My virtual environments are rarely properly connected to Jupyter, so this cell
# adds the site-packages directory of ./venv (relative to the current working
# directory) to the import path.
# You probably won't need this cell, but running it won't hurt anything either.
import sys
import os

current_dir = os.getcwd()  # the venv is looked up relative to the current working directory
venv_dir = os.path.join(current_dir, 'venv')
python_version = f'{sys.version_info.major}.{sys.version_info.minor}'
if os.name == 'nt':
    # Windows venvs keep packages in <venv>\Lib\site-packages (no pythonX.Y level)
    site_packages_path = os.path.join(venv_dir, 'Lib', 'site-packages')
else:
    site_packages_path = os.path.join(venv_dir, 'lib', 'python' + python_version, 'site-packages')

# Only append once so re-running the cell does not pile up duplicate entries
if site_packages_path not in sys.path:
    sys.path.append(site_packages_path)

In [2]:
# standard-library / framework modules that later cells reference by name
# (imported explicitly so the notebook does not depend on them leaking out of star imports)
import math
import os
import torch

# dataloader
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

# model modules
from model import *

# used in the training loop
import time
import tqdm

# used to save & load models
import json
from dataclasses import asdict

# inference code
from inference import *

# Instantiate a brand new model

In [3]:
# tokenizer
# NOTE(review): get_tokenizer is not defined in this notebook; presumably it is
# provided by this star import — confirm in tokenizer.py
from tokenizer import *
size = 512 # size options are 128, 256, 512 and 1024
# the tokenizer model file that gets loaded is selected by `size`
path = f'./tokenizers/tiny_stories_tokenizer_{size}.model'
tokenizer = get_tokenizer(path) 

# config file
# NOTE(review): Config is presumably provided by this star import — confirm in config.py
from config import *
cfg = Config()
# overwrite the config's vocab_len with the value reported by the loaded tokenizer
# (the printed config below shows vocab_len=515 for size=512)
cfg.vocab_len = tokenizer.vocab_len
print(cfg)

Config(dim=128, vocab_len=515, device='cpu', num_layers=10, pre_connect_dropout=False, second_resid_norm=False, mlp_hidden_mult=2, mlp_bias=False, mlp_nonlinearity='GeLU', mlp_gated=True, num_q_heads=4, num_kv_heads=1, theta=10000, max_seq_len=512, scale_first_resid=True, norm_type='RMSNorm', norm_affine=True, norm_bias=True, eps=1e-06, max_batch_size=1)


In [4]:
# build the model from the config, then move it onto the configured device
model = customGPT(cfg)
model = model.to(cfg.device)

# report the parameter count (in thousands) followed by the module tree
total_params = sum(param.numel() for param in model.parameters())
print(f'{total_params/1e3} K parameters')
print(model)

1463.936 K parameters
customGPT(
  (token_embedder): Embedding(515, 128)
  (layers): ModuleList(
    (0-9): 10 x ResidualLayer(
      (pre_attn_norm): Norm()
      (attn): MQSA(
        (Wq): Linear(in_features=128, out_features=128, bias=False)
        (Wk): Linear(in_features=128, out_features=32, bias=False)
        (Wv): Linear(in_features=128, out_features=32, bias=False)
        (Wo): Linear(in_features=128, out_features=128, bias=False)
      )
      (pre_mlp_norm): Norm()
      (mlp): MLP(
        (Wgate): Linear(in_features=128, out_features=256, bias=False)
        (Wup): Linear(in_features=128, out_features=256, bias=False)
        (Wdown): Linear(in_features=256, out_features=128, bias=False)
        (nonlinearity): GELU(approximate='none')
      )
    )
  )
  (final_norm): Norm()
  (criterion): CrossEntropyLoss()
)


# Training

In [5]:
class TinyStoriesDataset(Dataset):
    """Dataset wrapper that serves the 'text' field of one TinyStoriesV2 split."""

    def __init__(self, split):
        # `split` is forwarded unchanged to load_dataset
        # (this notebook passes 'train' and 'validation')
        self.dataset = load_dataset("noanabeshima/TinyStoriesV2", split=split)

    def __len__(self):
        """Number of records in the loaded split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the 'text' entry of the record at position `idx`."""
        record = self.dataset[idx]
        return record['text']

def get_data_loader(batch_size=32, shuffle=True, split='train', num_workers=0):
    """Build a DataLoader over one split of the TinyStories dataset.

    Args:
        batch_size: number of stories per batch.
        shuffle: forwarded to DataLoader's `shuffle` argument.
        split: dataset split handed to TinyStoriesDataset.
        num_workers: forwarded to DataLoader's `num_workers` argument.

    Returns:
        A DataLoader wrapping a freshly constructed TinyStoriesDataset.
    """
    loader_options = dict(batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    return DataLoader(TinyStoriesDataset(split), **loader_options)

In [6]:
# one batch size shared by the training and validation loaders
batch_size = 32
train_data_loader, test_data_loader = (
    get_data_loader(batch_size=batch_size, split=split_name)
    for split_name in ('train', 'validation')
)

Found cached dataset json (/Users/tunadorable/.cache/huggingface/datasets/noanabeshima___json/noanabeshima--TinyStoriesV2-226173b7dd235c68/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
Found cached dataset json (/Users/tunadorable/.cache/huggingface/datasets/noanabeshima___json/noanabeshima--TinyStoriesV2-226173b7dd235c68/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


In [7]:
# Draw one batch from the training loader and display its first element
b = next(iter(train_data_loader))
first_story = b[0]
print(first_story)

There was once a woman named Alexandra who was very charming. One day, Alexandra stepped out of her house to go for a walk. As she was walking, she heard a voice calling her name. It was her neighbour, Peter. Peter was also very charming. He asked her what she was doing and Alexandra said she was taking a stroll. He then offered to join her.
They walked together and started chatting. Alexandra was feeling a little peckish, so Peter offered her some ice cream. She happily accepted, and it was the most delicious ice cream she had ever tasted. They shared their ice creams and enjoyed the beautiful sunshine.
When it was time to go home, Alexandra thanked Peter for the delicious cream and for the lovely walk. "It was so much fun having you as my companion," she said. Peter blushed upon hearing her kind words. Before saying goodbye, Peter stepped forward and gave her a big hug.
Alexandra walked home feeling very content.


In [8]:
def torcherize_batch(batch, max_seq_len):
    """Tokenize a batch of strings into next-token-prediction input/target tensors.

    Args:
        batch: sequence of strings to tokenize.
        max_seq_len: number of token positions in each returned row.

    Returns:
        (x, y): two contiguous torch.long tensors of shape (len(batch), max_seq_len)
        located on cfg.device. `x` holds token positions 0..max_seq_len-1 and `y`
        holds positions 1..max_seq_len, i.e. the targets are the inputs shifted by one.
    """
    # Allocate the buffer as int64 directly on the configured device. The previous
    # float32 CPU buffer silently pulled every batch back to the CPU (so the returned
    # tensors ignored cfg.device) and round-tripped token ids through floating point.
    b = torch.zeros(len(batch), max_seq_len + 1, dtype=torch.long, device=cfg.device)
    for i, s in enumerate(batch):
        # assumes encode(..., pad=n) yields exactly n token ids (padded or truncated);
        # otherwise this row assignment raises a shape error — TODO confirm in tokenizer.py
        b[i] = torch.tensor(
            tokenizer.encode(s, bos=True, eos=True, pad=max_seq_len+1),
            dtype=torch.long,
            device=cfg.device,
        )
    x, y = b[:, :max_seq_len], b[:, 1:]
    # .contiguous() gives callers standalone tensors rather than strided views of
    # the buffer, so downstream reshapes keep working
    return x.contiguous(), y.contiguous()

@torch.no_grad()
def estimate_loss(model, dataloader, eval_iters = 3): # to estimate loss during the training loop
    """Average the model's loss over a few batches per split, without gradients.

    Args:
        model: called as `model(X, target_token_ids=Y)` and expected to return
            `(logits, loss)`; must expose `max_seq_len`.
        dataloader: either a single data loader, or a dict mapping split names
            (e.g. {'train': train_loader, 'val': val_loader}) to loaders.
            A single loader is used for both 'train' and 'val' (the legacy
            behaviour), which means both reported numbers then come from the
            same data — pass a dict to get a genuine train/val comparison.
        eval_iters: number of batches averaged per split.

    Returns:
        dict mapping each split name to a 0-dim tensor holding the mean loss.
    """
    if isinstance(dataloader, dict):
        loaders = dataloader
    else:
        # backward-compatible path: one loader reported under both labels
        loaders = {'train': dataloader, 'val': dataloader}

    out = {}
    model.eval() # sets model to eval mode
    try:
        for split, loader in loaders.items():
            losses = torch.zeros(eval_iters)
            # one iterator per split: distinct batches are drawn instead of
            # rebuilding the loader's iterator for every single batch
            batches = iter(loader)
            for k in range(eval_iters):
                try:
                    batch = next(batches)
                except StopIteration:
                    # loader shorter than eval_iters: start another pass
                    batches = iter(loader)
                    batch = next(batches)
                X, Y = torcherize_batch(batch, model.max_seq_len)
                logits, loss = model(X, target_token_ids=Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # put the model back into training mode even if evaluation raised
        model.train()
    return out

In [9]:
# ---- run length & logging cadence ----
max_iters = 4       # how long we want to train for (number of optimizer steps)
eval_interval = 2   # how often (in steps) we check & print how the loss is doing

# ---- learning-rate schedule ----
lr_init = 1e-2        # peak learning rate handed to the optimizer
lr_final = 1e-5       # minimum learning rate
warmup_iters = 2      # number of warmup iterations
warmup_factor = 1e-3  # the initial learning rate is multiplied by this factor

# ---- optimizer ----
# (original note: not the reference hyperparameters; this learning rate &
#  weight decay were picked because they work for this tiny model)
weight_decay = 0.02
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr_init,
    weight_decay=weight_decay,
)

def lr_lambda(current_iter):
    """Multiplicative learning-rate factor for LambdaLR at step `current_iter`.

    The factor rises linearly from `warmup_factor` to 1 over the first
    `warmup_iters` steps, then follows a half-cosine from 1 towards 0 over the
    remaining `max_iters - warmup_iters` steps, never dropping below
    `lr_final / lr_init`. Once the schedule is finished the floor is held.

    Args:
        current_iter: zero-based step index supplied by the scheduler.

    Returns:
        float factor that LambdaLR multiplies with the optimizer's initial lr.
    """
    if current_iter < warmup_iters:
        # Warmup phase
        return warmup_factor + (1 - warmup_factor) * current_iter / warmup_iters

    # Cosine decay phase with minimum learning rate
    min_factor = lr_final / lr_init
    decay_iters = max_iters - warmup_iters
    if decay_iters <= 0:
        # Warmup spans the whole run, so there is no decay phase to interpolate
        # over (previously a ZeroDivisionError); hold the floor instead.
        return min_factor
    # Clamp so steps beyond max_iters stay at the floor instead of riding the
    # cosine back up.
    progress = min((current_iter - warmup_iters) / decay_iters, 1.0)
    cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
    return max(cosine_decay, min_factor)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [10]:
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

# Create the batch iterator once. Calling iter() on the loader at every step
# rebuilds (and, with shuffle=True, re-shuffles) the loader just to draw one batch.
train_batches = iter(train_data_loader)

for i in range(max_iters):

    # sample a batch of data, starting a new pass once the loader is exhausted
    try:
        batch = next(train_batches)
    except StopIteration:
        train_batches = iter(train_data_loader)
        batch = next(train_batches)
    x,y = torcherize_batch(batch, cfg.max_seq_len)
    
    # train
    logits, loss = model(x, target_token_ids=y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Update the learning rate
    scheduler.step()
    
    # every once in a while evaluate the loss on train and val sets
    if i % eval_interval == 0 or i == max_iters - 1:
        current_time = time.time()
        elapsed_time = current_time - start_time
        # estimate_loss draws all of its batches from the single loader it is given,
        # so evaluate each split against its own loader; otherwise the "train loss"
        # below would actually be measured on validation data. This costs a second
        # evaluation pass per report.
        losses = {
            'train': estimate_loss(model, train_data_loader)['train'],
            'val': estimate_loss(model, test_data_loader)['val'],
        }
        # read after scheduler.step(): this is the rate the next iteration will use
        current_lr = optimizer.param_groups[0]['lr']
        print(
            f"step {i:04d}: "
            f"lr {current_lr:.6f}, "
            f"train loss {losses['train'].item():.4f}, "
            f"val loss {losses['val'].item():.4f}, "
            f"time elapsed: {elapsed_time:.2f} seconds"
        )

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

step 0000: lr 0.005005, train loss 102.0891, val loss 103.0794, time elapsed: 4.22 seconds
step 0002: lr 0.005000, train loss 50.0940, val loss 48.9597, time elapsed: 23.66 seconds
step 0003: lr 0.000010, train loss 37.1550, val loss 35.8669, time elapsed: 38.88 seconds


# Inference test before you decide to save it

In [11]:
# Quick sanity check: let the freshly trained model continue a prompt.
# With memory_saver_div = 4, generate reports a maximum attention matrix of
# 128x512 instead of 512x512 for this config.
prompt = "Once upon a time, there was a boy named Evin"
output = generate(prompt, model, tokenizer, memory_saver_div = 4)
print(output)

maximum attention matrix size in memory will be 128x512 rather than 512x512

Once upon a time, there was a boy named Evin                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      


# Saving your model

In [12]:
# Make sure the output directory exists; neither torch.save nor open() creates it
os.makedirs('models', exist_ok=True)

# Weights and config share one timestamped base name so they can be matched later.
# NOTE(review): the '|' in the timestamp is not a valid filename character on
# Windows; left unchanged so previously saved names keep the same format.
name = f'models/{model.__class__.__name__}_{time.strftime("%Y-%m-%d|%H-%M-%S")}'
torch.save(model.state_dict(), f'{name}.pth')

# Convert the dataclass object to a dictionary
cfg_dict = asdict(cfg)

# Serialize the dictionary to a JSON file next to the weights
with open(f'{name}.json', 'w') as f:
    json.dump(cfg_dict, f)