In [4]:
## IMPORTS
import os
import random

from tqdm import tqdm
from tokenizers import ByteLevelBPETokenizer
import torch

from model import AttentionModel, init_weights
from utils import get_all_files, get_lr, save_checkpoint, tokenize_files, get_random_batch

In [7]:
## CONSTANTS
BATCH_SIZE = 16
# Every example is padded to CONTEXT_SIZE tokens; inputs/targets are CONTEXT_SIZE - 1 long after the shift.
CONTEXT_SIZE = 1792
VOCAB_SIZE = 1024
# Seed for the stochastic parts of the run (batch sampling, weight init, dropout).
SEED = 42
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA
print(f"{DEVICE=}")

DEVICE=device(type='cuda')


In [8]:
# Collect the raw file lists for each split (get_all_files comes from utils).
train_files = get_all_files("./data/train")
val_files = get_all_files("./data/val")

for split_report in (f"Train files: {len(train_files)}", f"Validation files: {len(val_files)}"):
    print(split_report)

Train files: 216788
Validation files: 4425


In [9]:
tokenizer = ByteLevelBPETokenizer()

# NOTE(review): START_TOKEN is not registered as a special token and is never used below.
START_TOKEN = "<|startoftext|>"
EOD_TOKEN = "<|endoftext|>"
PAD_TOKEN = "<|pad|>"

# Learn the byte-level BPE vocabulary from every train and validation file.
# NOTE(review): including val_files lets the tokenizer see validation text; confirm this is intended.
tokenizer.train(
    files=train_files + val_files,
    vocab_size=VOCAB_SIZE,
    min_frequency=2,
    special_tokens=[EOD_TOKEN, PAD_TOKEN],
)


def encode(text):
    """Return the list of token ids the trained tokenizer produces for ``text``."""
    encoding = tokenizer.encode(text)
    return encoding.ids


def decode(digits):
    """Turn a sequence of token ids (``digits``) back into text."""
    return tokenizer.decode(digits)


# Round-trip sanity check: should print the input string unchanged.
round_trip = decode(encode("Hello, world!"))
print(round_trip)




Hello, world!


In [11]:
# Get the vocabulary as a dictionary: {token: index}
# NOTE(review): vocab_dict is not used elsewhere in the visible notebook; only the commented-out inspection code below reads it.
vocab_dict = tokenizer.get_vocab()

# # Print the first 1000 tokens (sorted by index)
# sorted_vocab = sorted(vocab_dict.items(), key=lambda item: item[1])
# for token, idx in sorted_vocab[:1000]:
#     print(f"{idx}: {repr(token)}")


In [None]:
lens = []  # token count of each file; the validation cell below appends to this same list
train_files_clean = []  # train files short enough to be padded up to CONTEXT_SIZE
# NOTE(review): every kept file is tokenized again when train_data is built; caching ids here would avoid a second pass.
for path in tqdm(train_files):
    with open(path, "r", encoding='utf-8') as handle:
        token_count = len(encode(handle.read()))
    lens.append(token_count)
    # Strictly shorter than CONTEXT_SIZE, so at least one pad token is appended later.
    if token_count < CONTEXT_SIZE:
        train_files_clean.append(path)

 73%|█████████████████████████████████████████████████████████████▏                      | 157781/216788 [02:20<00:56, 1046.50it/s]

In [13]:
val_files_clean = []  # validation files short enough to be padded up to CONTEXT_SIZE
for path in tqdm(val_files):
    with open(path, "r", encoding='utf-8') as handle:
        token_count = len(encode(handle.read()))
    # NOTE(review): this extends the train cell's `lens`, so that list mixes both splits.
    lens.append(token_count)
    if token_count < CONTEXT_SIZE:
        val_files_clean.append(path)

100%|████████████████████████████████████████████████████████████████████████████████████████| 4425/4425 [00:04<00:00, 1099.85it/s]


In [20]:
def load_padded(files, pad_id, context_size=CONTEXT_SIZE):
    """Tokenize each file and right-pad it with ``pad_id`` to exactly ``context_size`` tokens.

    Args:
        files: paths expected to encode to at most ``context_size`` tokens.
        pad_id: single-element list of token ids, as returned by ``encode(PAD_TOKEN)``.
        context_size: length of every returned sequence.

    Returns:
        A list of equal-length lists of token ids, one per file, in input order.

    Raises:
        ValueError: if a file encodes to more than ``context_size`` tokens. Such a row
            would otherwise make the data ragged and break ``torch.tensor`` later.
    """
    padded = []
    for file in tqdm(files):
        with open(file, "r", encoding='utf-8') as f:
            text_encoded = encode(f.read())
        if len(text_encoded) > context_size:
            raise ValueError(f"{file} encodes to {len(text_encoded)} tokens, more than context_size={context_size}")
        padded.append(text_encoded + pad_id * (context_size - len(text_encoded)))
    return padded


pad_id = encode(PAD_TOKEN)
print(pad_id)
# Padding uses list repetition (`pad_id * n`), which is only correct if the pad string maps to one id.
assert len(pad_id) == 1, f"PAD_TOKEN must encode to a single token id, got {pad_id}"

# NOTE(review): no EOD_TOKEN is appended before the padding, so the model never sees an end-of-game
# marker. Pad positions also appear as targets unless the model's loss ignores them. Confirm both are intended.
train_data = load_padded(train_files_clean, pad_id)
val_data = load_padded(val_files_clean, pad_id)

100%|████████████████████████████████████████████████████████████████████████████████████████| 4424/4424 [00:03<00:00, 1125.24it/s]


In [21]:
def _to_device_tensor(token_lists):
    """Stack equal-length token-id lists into a 2-D tensor living on DEVICE."""
    return torch.tensor(token_lists).to(DEVICE)


# Convert once up front, so sampled batches are already on DEVICE.
# NOTE(review): torch.tensor infers int64 here; ids fit a much smaller dtype, which would save GPU memory.
train_data = _to_device_tensor(train_data)
val_data = _to_device_tensor(val_data)

In [22]:
def get_random_batch(data, batch_size):
    """Sample ``batch_size`` random rows of ``data`` as (inputs, next-token targets).

    NOTE(review): this redefinition shadows the ``get_random_batch`` imported from ``utils``.

    Returns:
        x: each sampled row without its last token.
        y: the same rows without their first token, i.e. ``x`` shifted left by one.
    """
    rows = torch.randint(len(data), (batch_size,))
    picked = [data[row] for row in rows]
    inputs = torch.stack([seq[:-1] for seq in picked])
    targets = torch.stack([seq[1:] for seq in picked])
    return inputs, targets

In [26]:
# Sanity check: the decoded target should read as the decoded input shifted by one token.
xb, yb = get_random_batch(train_data, BATCH_SIZE)
first_input, first_target = xb[0].tolist(), yb[0].tolist()

print(decode(first_input[:100]))
print("*" * 50)
print(decode(first_target[:100]))

(;GM[1]FF[4]SZ[19]PB[Black]PW[White]KM[0]RE[B+8.5]TM[60]TT;B[br];W[pp];B[dd];W[pd];B[nq];W[pn];B[jp];W[cq];B[eq];W[dq];B[ep];W[cn];B[nc];W[qf];B[pb];W
**************************************************
GM[1]FF[4]SZ[19]PB[Black]PW[White]KM[0]RE[B+8.5]TM[60]TT;B[br];W[pp];B[dd];W[pd];B[nq];W[pn];B[jp];W[cq];B[eq];W[dq];B[ep];W[cn];B[nc];W[qf];B[pb];W[


In [27]:
# Number of BATCH_SIZE batches per pass over the training set (fractional when it does not divide evenly).
len(train_data) / BATCH_SIZE

13535.0625

In [28]:
# NOTE(review): 4376 and 3.5 are unexplained. len(train_data) / BATCH_SIZE above is ~13535 batches,
# so num_of_steps does not correspond to `epoch_count` full passes. Confirm the intended schedule.
epoch_count = 4
num_of_steps = int(epoch_count * 4376 * 3.5)  # total optimizer steps
total_steps = num_of_steps  # the value training_step passes to get_lr as the total
warmup_steps = 2000  # Warm up for first 2000 steps
check_val_every = 500  # evaluate / log / plot every N steps
eval_count = 250  # validation batches averaged per evaluation

lr_max, lr_min = 3e-4, 5e-6  # Maximum / minimum learning rate

for report in (f"{epoch_count=}", f"{4376*4=}", f"{num_of_steps=}", f"{warmup_steps=}"):
    print(report)

epoch_count=4
4376*4=17504
num_of_steps=61264
warmup_steps=2000


In [46]:
import time
from torch.amp import autocast, GradScaler
import matplotlib.pyplot as plt
from IPython.display import clear_output
import time
import math
import numpy as np
import torch

import os
import json

# Output folders for the per-evaluation figures and JSON results.
for output_dir in ("figures", "results"):
    os.makedirs(output_dir, exist_ok=True)

# Candidate configurations; only the first active entry is used for this run.
# Earlier configurations, kept for reference:
# {"att_size": 512, "head_count": 8, "dropout": 0.1, "layer_count":8, "gpt_init":True},
# {"att_size": 1024, "head_count": 16, "dropout": 0, "layer_count":16, "gpt_init":False},
# {"att_size": 1024, "head_count": 16, "dropout": 0.1, "layer_count":16, "gpt_init":False},
# {"att_size": 1024, "head_count": 16, "dropout": 0.1, "layer_count":16, "gpt_init":True},
hparam_search = [
    dict(att_size=768, head_count=12, dropout=0.1, layer_count=12, gpt_init=True),
]
hparams = hparam_search[0]

In [47]:
def calculate_grad_norm(model):
    """Return the global L2 norm of all existing parameter gradients, as a Python float.

    Parameters whose ``.grad`` is ``None`` are skipped. Returns 0.0 when no parameter has a gradient.
    NOTE(review): not called in the visible notebook; training_step uses clip_grad_norm_'s return value.
    """
    squared_norms = [p.grad.data.norm(2).item() ** 2 for p in model.parameters() if p.grad is not None]
    return sum(squared_norms, 0.0) ** 0.5

def compute_param_norm(model):
    """Return the global L2 norm over every trainable (``requires_grad``) parameter, as a float."""
    squared_norms = (p.data.norm(2).item() ** 2 for p in model.parameters() if p.requires_grad)
    return sum(squared_norms, 0.0) ** 0.5

# def compute_weight_update_norm(model, prev_params):
#     total_update_norm = 0.0
#     for p, prev_p in zip(model.parameters(), prev_params):
#         if p.requires_grad:
#             delta = (p.data - prev_p).norm(2).item()
#             total_update_norm += delta ** 2
#     return total_update_norm ** 0.5

def training_step(model, train_tokens, context_size, batch_size):
    """Run one bfloat16-autocast optimizer step on a random training batch.

    Args:
        model: the model being trained, called as ``model(xb, yb) -> (logits, loss)``.
        train_tokens: padded token sequences sampled with ``get_random_batch``.
        context_size: unused; kept for call-site compatibility.
        batch_size: number of sequences per batch.

    Returns:
        (loss, grad_norm, current_lr, entropy, perplexity, step_time) as Python floats.
        ``grad_norm`` is the total norm reported by ``clip_grad_norm_`` (before clipping).
        ``current_lr`` is the learning rate actually applied in this step.

    Raises:
        ValueError: if the loss is NaN or infinite (checked before the backward pass).

    NOTE(review): reads the notebook globals ``optimizer``, ``scaler``, ``step``,
    ``total_steps``, ``lr_max``, ``lr_min``, ``warmup_steps`` and ``num_of_steps``.
    """
    step_start = time.time()

    # Set this step's scheduled LR *before* the optimizer step, so the update (including step 0
    # of warmup) uses it. Otherwise it would use the value from the previous step or from the
    # optimizer constructor.
    # NOTE(review): the trailing `* 3` scales get_lr's output; confirm the intended peak LR.
    current_lr = get_lr(step, total_steps, lr_max, lr_min, warmup_steps, num_of_steps * 10) * 3
    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr

    xb, yb = get_random_batch(train_tokens, batch_size)
    optimizer.zero_grad(set_to_none=True)
    with autocast(device_type=xb.device.type, dtype=torch.bfloat16):
        logits, loss = model(xb, yb)

    # Fail fast, before spending a backward pass on a non-finite loss.
    if not torch.isfinite(loss).all():
        raise ValueError(f"!!! Invalid loss at step {step}: {loss.item()}")

    # Diagnostics only: detach so the float32 copy of the logits never joins the autograd graph.
    with torch.no_grad():
        entropy, perplexity = compute_entropy_and_perplexity(logits.detach().float(), yb)

    scaler.scale(loss).backward()

    # Unscale before clipping, so the 1.0 threshold applies to the true gradient magnitudes.
    scaler.unscale_(optimizer)
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    scaler.step(optimizer)
    scaler.update()

    step_time = time.time() - step_start
    return loss.item(), grad_norm.item(), current_lr, entropy, perplexity, step_time


def evaluate_model(model, val_tokens, context_size, batch_size, eval_count):
    """Average loss, perplexity and entropy of ``model`` over ``eval_count`` random validation batches.

    Args:
        model: model to evaluate, called as ``model(xb, yb) -> (logits, loss)``.
        val_tokens: padded validation sequences sampled with ``get_random_batch``.
        context_size: unused; kept for call-site compatibility.
        batch_size: sequences per validation batch.
        eval_count: number of batches to average over; must be positive.

    Returns:
        (test_loss_avg, test_perplexity_avg, test_entropy_avg) as floats. The perplexity is
        the mean of the per-batch perplexities, not exp of the mean loss.

    Raises:
        ValueError: if ``eval_count`` is not positive.

    The model runs in eval mode without autograd. Its previous train/eval mode is restored
    before returning, even if an exception is raised.
    """
    if eval_count <= 0:
        raise ValueError(f"eval_count must be positive, got {eval_count}")

    was_training = model.training
    model.eval()
    tmp_test_losses, tmp_eval_perplexity, tmp_eval_entropy = [], [], []
    try:
        with torch.no_grad():
            for _ in range(eval_count):
                xb, yb = get_random_batch(val_tokens, batch_size)
                with autocast(device_type=xb.device.type, dtype=torch.bfloat16):
                    test_logits, test_loss = model(xb, yb)
                # Same float32 diagnostic path as training_step, so train/test metrics are comparable.
                entropy, perplexity = compute_entropy_and_perplexity(test_logits.float(), yb)
                tmp_test_losses.append(test_loss.item())
                tmp_eval_perplexity.append(perplexity)
                tmp_eval_entropy.append(entropy)
    finally:
        model.train(was_training)

    test_loss_avg = sum(tmp_test_losses) / len(tmp_test_losses)
    test_perplexity_avg = sum(tmp_eval_perplexity) / len(tmp_eval_perplexity)
    test_entropy_avg = sum(tmp_eval_entropy) / len(tmp_eval_entropy)
    return test_loss_avg, test_perplexity_avg, test_entropy_avg


def plot_metrics(steps, losses, test_losses, learning_rates, grad_norms):
    """Render, save and display the training dashboard for the evaluations logged so far.

    All five sequences are parallel, with one entry per evaluation. The figure is saved to
    ``figures/step_<last step>.png``. The previous cell output is then cleared, and the new
    figure is shown and closed.

    Raises:
        ValueError: if ``steps`` is empty (there is nothing to plot or title).
    """
    if not steps:
        raise ValueError("plot_metrics needs at least one logged evaluation")

    # Derive the current step and grad norm from the arguments instead of notebook globals.
    step = steps[-1]
    grad_norm = grad_norms[-1]
    last_n = 10  # number of most recent evaluations shown in the zoomed-in panel

    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(10, 10), gridspec_kw={'height_ratios': [3, 1, 2, 3]})

    # Full loss curves
    ax1.plot(steps, losses, 'b-', label='Train Loss')
    ax1.plot(steps, test_losses, 'y-', label='Test Loss')
    ax1.set_title(f'Step {step} | Train {losses[-1]:.4f} | Test {test_losses[-1]:.4f}')
    ax1.grid()
    ax1.legend()

    # Learning rate
    ax2.plot(steps, learning_rates, 'g-')
    ax2.set_title('Learning Rate')
    ax2.grid()

    # Zoomed-in loss view. The points are evaluations, not individual training steps.
    ax3.plot(steps[-last_n:], losses[-last_n:], 'b-', label=f'Train Loss (Last {last_n} Evals)')
    ax3.plot(steps[-last_n:], test_losses[-last_n:], 'y-', label=f'Test Loss (Last {last_n} Evals)')
    ax3.set_title(f'Zoomed-in Loss View (Last {last_n} Evals)')
    ax3.grid()
    ax3.legend()

    # Gradient norm at each evaluation step
    ax4.plot(steps, grad_norms, 'y-', label='Grad Norm')
    ax4.set_title(f'Step {step} | Last Norm {grad_norm:.4f} ')
    ax4.grid()
    ax4.legend()

    plt.tight_layout()
    plt.savefig(f"figures/step_{step}.png")
    clear_output(wait=True)  # replace the previous dashboard instead of stacking figures
    plt.show()
    plt.close(fig)  # release the figure; this function runs once per evaluation


def sample_generations(model, test_prompts, encode, device, temperature, max_new_tokens, decode_fn=None):
    """Generate a continuation for each prompt and print the decoded text.

    Args:
        model: object exposing ``generate(idx, max_new_tokens=..., temperature=...)``. It is
            expected to return a (batch, length) tensor of token ids.
        test_prompts: prompt strings; each is generated as a batch of one.
        encode: callable mapping a string to a list of token ids.
        device: device the prompt tensors are created on.
        temperature: sampling temperature forwarded to ``generate``.
        max_new_tokens: number of tokens to generate per prompt.
        decode_fn: callable mapping a list of token ids to text. Defaults to the notebook's ``decode``.

    Generation runs in eval mode (dropout off) under ``torch.no_grad()``. The model's
    previous train/eval mode is restored afterwards, even if generation raises.
    """
    to_text = decode if decode_fn is None else decode_fn

    was_training = model.training
    model.eval()
    generations = []
    try:
        with torch.no_grad():
            for prompt in test_prompts:
                prompt_ids = torch.tensor([encode(prompt)], device=device)  # shape (1, prompt_len)
                output = model.generate(prompt_ids, max_new_tokens=max_new_tokens, temperature=temperature)
                generations.extend(output.tolist())
    finally:
        model.train(was_training)

    for generation in generations:
        print(to_text(generation))
        print("\n********\n")

In [48]:
# --- Model ---
m = AttentionModel(
    vocab_size=tokenizer.get_vocab_size(),
    att_size=hparams["att_size"],
    head_count=hparams["head_count"],
    layer_count= hparams["layer_count"],
    context_size=CONTEXT_SIZE,
    drop_out=hparams["dropout"]
)

# Optionally re-initialize every submodule with the project's init_weights (done before moving to DEVICE).
if hparams["gpt_init"]:
    m.apply(init_weights)
m = m.to(DEVICE)
# From here on `m` is the torch.compile wrapper; training, evaluation and sampling all use it.
m = torch.compile(m)
print(next(m.parameters()).device)

# AdamW over every parameter with a single weight-decay value.
# NOTE(review): weight_decay=0.1 also applies to any bias / norm / embedding weights; confirm intended.
# lr=1e-3 is only the construction value; training_step writes param_group['lr'] on every step.
optimizer = torch.optim.AdamW(m.parameters(),  betas=(0.9, 0.95),  lr=1e-3, weight_decay=0.1)
# NOTE(review): training autocasts to bfloat16, where loss scaling is generally unnecessary; kept as-is.
scaler = GradScaler('cuda')

start_time = time.time()  # wall-clock start used for total_time in the training cell

# Histories with one entry per evaluation (appended every check_val_every steps).
steps = []
losses = []
test_losses = []
learning_rates = []
tmp_losses = []  # per-step train losses since the last evaluation (reset there)
step_times = []  # per-step durations for the whole run (never reset)
grad_norms = []  # grad norm of each evaluation step

# Per-step train diagnostics since the last evaluation (reset there).
tmp_train_perplexity, tmp_train_entropy = [], []

start = time.time()  # NOTE(review): not read anywhere in the visible notebook; start_time is used instead
print(f"{num_of_steps=}")

cuda:0
num_of_steps=61264


In [49]:
import torch
import torch.nn.functional as F

def compute_entropy_and_perplexity(logits, targets, pad_token_id=None):
    """Token-level predictive entropy and perplexity for a batch of logits.

    Args:
        logits: raw model outputs whose last dimension is the vocabulary, e.g.
            (batch_size, seq_len, vocab_size), or already flattened to (num_tokens, vocab_size).
        targets: ground-truth token ids with one entry per logits row after flattening,
            e.g. (batch_size, seq_len).
        pad_token_id: optional int id. Positions whose target equals it are excluded from
            both averages. ``None`` (the default) averages over every position.

    Returns:
        (avg_entropy, perplexity) as Python floats.
        avg_entropy is the mean entropy, in nats, of the predicted distributions.
        perplexity is exp(mean cross-entropy against ``targets``).
        Both are NaN when masking leaves no positions.

    Computed under ``torch.no_grad()``: these are diagnostics and never join the autograd graph.
    """
    with torch.no_grad():
        vocab_size = logits.size(-1)
        # reshape rather than view, so non-contiguous inputs are accepted as well
        logits_flat = logits.reshape(-1, vocab_size)
        targets_flat = targets.reshape(-1)

        if pad_token_id is not None:
            keep = targets_flat != pad_token_id
            logits_flat, targets_flat = logits_flat[keep], targets_flat[keep]
            if targets_flat.numel() == 0:
                return float("nan"), float("nan")

        log_probs = F.log_softmax(logits_flat, dim=-1)
        # Per-token cross-entropy against the targets (used for perplexity)
        ce_loss = F.nll_loss(log_probs, targets_flat, reduction='none')
        # Per-token entropy of the model's own predicted distribution
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1)

        avg_ce = ce_loss.mean().item()
        avg_entropy = entropy.mean().item()

    # float64 exp keeps precision and yields inf/nan (rather than raising) for extreme losses
    perplexity = torch.tensor(avg_ce, dtype=torch.float64).exp().item()
    return avg_entropy, perplexity

In [50]:
# Record which accelerator the run used; it is also stored in the W&B config below.
has_cuda = torch.cuda.is_available()
if not has_cuda:
    device_name = "CPU"
    print("CUDA is not available.")
else:
    device_name = torch.cuda.get_device_name(0)
    print(f"GPU Model: {device_name}")

import wandb

# NOTE(review): entity/project are placeholders; set real values (e.g. from env vars) before a tracked run.
run_config = {
    "architecture": "GPT-2-V1",
    "dataset": "sgfs",
    "device_name": device_name,
    **hparams,
}
run = wandb.init(entity="entity_name", project="project_name", config=run_config)

GPU Model: NVIDIA GeForce RTX 5090


In [None]:
run_id = 1  # used in the result file names and JSON payloads
# Prompts for qualitative samples; the training games decoded above each start with "(".
test_prompts = [
    "(",
    "(",
    "("
]
times = []  # NOTE(review): never used below
# Main training loop. training_step reads `step` as a notebook global (for its LR schedule and
# error message), so this loop variable must keep its name.
for step in range(0, num_of_steps):
    loss, grad_norm, current_lr, train_entropy, train_perplexity, step_time = training_step(m, train_data, CONTEXT_SIZE, BATCH_SIZE)

    # Accumulate per-step metrics; they are averaged and reset at the next evaluation.
    tmp_train_entropy.append(train_entropy)
    tmp_train_perplexity.append(train_perplexity)

    tmp_losses.append(loss)
    step_times.append(step_time)

    # Every check_val_every steps (including step 0): evaluate, log, plot, save a partial result
    if step % check_val_every == 0:
        # Only the current step's grad norm is recorded, not an average over the interval.
        grad_norms.append(grad_norm)

        # Means over the steps since the previous evaluation (just step 0 at the first one).
        current_loss = sum(tmp_losses) / len(tmp_losses)
        avg_train_entropy = sum(tmp_train_entropy)/len(tmp_train_entropy)
        avg_train_perplexiy = sum(tmp_train_perplexity)/len(tmp_train_perplexity)  # (sic) name

        tmp_losses, tmp_train_entropy, tmp_train_perplexity = [], [], []

        test_loss_avg, test_perplexity_avg, test_entropy_avg = evaluate_model(m, val_data, CONTEXT_SIZE, BATCH_SIZE, eval_count)

        steps.append(step)
        losses.append(current_loss)
        test_losses.append(test_loss_avg)
        learning_rates.append(current_lr)

        param_norm = compute_param_norm(m)

        run.log({
            "train loss": current_loss,
            "val loss": test_loss_avg,
            "learning rate":current_lr,
            "grad norm": grad_norm,
            "param_norm":param_norm,
            "train_entropy":avg_train_entropy,
            "train_perplexity":avg_train_perplexiy,
            "test_entropy":test_entropy_avg,
            "test_perplexity":test_perplexity_avg,
        })

        plot_metrics(steps, losses, test_losses, learning_rates, grad_norms)

        # To enable, max_new_tokens must also be passed (it has no default), e.g.:
        # sample_generations(m, test_prompts, encode, DEVICE, temperature=0.2, max_new_tokens=...)

        # Make sure the next training step runs in training mode (evaluation may have switched to eval mode).
        m.train()

        # --- Save partial result ---
        partial_result = {
            "run_id": run_id,
            "step": step,
            "hparams": hparams,
            "train_loss": current_loss,
            "test_loss": test_loss_avg,
            "lr": current_lr,
            # The "_last_100" key name is historical: the value averages the last check_val_every step times.
            "avg_step_time_sec_last_100": sum(step_times[-check_val_every:]) / len(step_times[-check_val_every:]),
        }
        with open(f"results/partial_run_{run_id}_step_{step}.json", "w") as f:
            json.dump(partial_result, f, indent=2)

# Wall-clock time since start_time (set in the model setup cell).
end_time = time.time()
total_time = end_time - start_time

# Final Save
# NOTE(review): final losses come from the last evaluation, which may precede the last training
# steps (when num_of_steps - 1 is not a multiple of check_val_every).
final_result = {
    "run_id": run_id,
    "hparams": hparams,
    "final_train_loss": losses[-1],
    "final_val_loss": test_losses[-1],
    "duration_sec": total_time,
    "avg_step_time_sec": sum(step_times) / len(step_times),
}
with open(f"results/final_result_{run_id}.json", "w") as f:
    json.dump(final_result, f, indent=2)

In [None]:
# Mark the W&B run as finished so it stops logging and uploads any remaining data.
run.finish()

In [None]:
# Qualitative check after training: one sample per "(" prompt, temperature 0.4, up to 1700 new tokens each.
# NOTE(review): the training loop leaves `m` in train mode; make sure sampling runs with dropout disabled.
sample_generations(m, test_prompts, encode, DEVICE,0.4, 1700)