In [None]:
import torch
import os
import math
import time
from model import MD4Config, MD4
# Device setup
device = "cuda:0"
print(f"Using device: {device}")
# Enable TF32 for faster matrix multiplications on Ampere+ GPUs
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

Using device: cuda:0


In [None]:
with open("blobs/wine_0.txt", "r", encoding="utf-8") as f0, \
     open("blobs/wine_1.txt", "r", encoding="utf-8") as f1, \
     open("blobs/wine_2.txt", "r", encoding="utf-8") as f2:
    text = f0.read() + f1.read() + f2.read()

# Create train / val tensor
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"Total Text length: {len(text)} characters")
print(f"Vocab size: {vocab_size}")
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
def encode(s):
    return [stoi[c] for c in s]
def decode(ids):
    return ''.join([itos[i] for i in ids if i < vocab_size])
data_tensor = torch.tensor(encode(text), dtype=torch.long)
n = int(0.95 * len(data_tensor)) # 90/10 split
train_data = data_tensor[:n]
val_data = data_tensor[n:]

Total Text length: 68091636 characters
Vocab size: 90


In [None]:
config = MD4Config(vocab_size)

# Specify wether to train from scratch or resume from a checkpoint
resume_from = None
# resume_from = "checkpoints/model_step_51976.pt"
# Specify the new max steps if resumed
# if resume_from: config.max_steps = 55000 

# Set model dtype 
dtype = torch.bfloat16
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == torch.float16))

# Init model
model = MD4(config)
model.to(device)
optimizer = model.configure_optimizers(weight_decay=1e-1, learning_rate=config.learning_rate, device_type=device)

num_params = sum(p.numel() for p in model.parameters())
print(f"Number of model parameters: {num_params}")

start_step = 0
if resume_from and os.path.exists(resume_from):
    print(f"Resuming training from {resume_from}...")
    checkpoint = torch.load(resume_from, map_location=device, weights_only=False)
    # for compiled models
    state_dict = checkpoint['model_state_dict']
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_step = checkpoint['step']
    print(f"Successfully loaded. Resuming from step {start_step}")
else: print("Starting training from scratch...")

# Helpers
def get_lr(it):
    # 1) linear warmup
    if it < config.warmup_steps:
        return config.learning_rate * (it + 1) / (config.warmup_steps + 1)
    # 2) if it > max_steps, return min learning rate
    if it > config.max_steps:
        return config.min_lr
    # 3) in between, use cosine decay
    decay_ratio = (it - config.warmup_steps) / (config.max_steps - config.warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) 
    return config.min_lr + coeff * (config.learning_rate - config.min_lr)

def get_batch(split, seqlen, batch_size):
    d = train_data if split == 'train' else val_data
    ix = torch.randint(len(d) - seqlen, (batch_size,))
    X = torch.stack([d[i:i+seqlen] for i in ix])
    return X.to(device)

@torch.no_grad()
def estimate_loss(eval_iters=50):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X = get_batch(split, config.block_size, config.batch_size)
            with torch.amp.autocast(device_type="cuda", dtype=dtype):
                loss = model.compute_loss(X)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

tokens_per_step = config.batch_size * config.block_size
remaining_steps = config.max_steps - start_step
total_tokens_trained = tokens_per_step * config.max_steps
dataset_tokens = len(train_data)
print(f"--- Training Status ---")
print(f"Steps: {start_step} -> {config.max_steps}")
print(f"Remaining steps: {remaining_steps}")
print(f"----------------------")
os.makedirs("checkpoints", exist_ok=True)

# COMPILATION
model = torch.compile(model) 

  scaler = torch.cuda.amp.GradScaler(enabled=(dtype == torch.float16))


Optimizer using fused AdamW: True
Number of model parameters: 40622080
Starting training from scratch...
--- Training Status ---
Steps: 0 -> 50000
Remaining steps: 50000
----------------------


In [4]:
t0 = time.time()
start_time = time.time() 
for step in range(start_step, config.max_steps): 
    
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr    
    xb = get_batch('train', seqlen=config.block_size, batch_size=config.batch_size)
    
    # Forward & Backward
    with torch.amp.autocast(device_type="cuda", dtype=dtype):
        loss = model.compute_loss(xb)
    optimizer.zero_grad(set_to_none=True)
    if dtype == torch.float16:
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

    # Logging
    if step % 100 == 0:
        torch.cuda.synchronize() 
        t1 = time.time()
        dt = t1 - t0
        t0 = t1
        tokens_processed = config.batch_size * config.block_size * 100
        tps = tokens_processed / dt
        print(f"Step {step:4d} | Loss: {loss.item():.4f} | LR: {lr:.2e} | Speed: {tps:.0f} tok/s")
    if step % 500 == 0:
        print("-> Evaluating...")
        losses = estimate_loss()
        print(f"-> Train Loss: {losses['train']:.4f} | Val Loss: {losses['val']:.4f}")
    if step % 1000 == 0:
        print("\n--- Generating ---")
        gen_ids = model.generate(seq_len=config.block_size, steps=64)
        print(decode(gen_ids))
        print("------------------\n")
    
    # Checkpoint
    if (step+1) % 5000 == 0:
        checkpoint_path = f"checkpoints/model_step_{(step+1)}.pt"
        torch.save({
            'step': (step+1),
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
            'config': config,
        }, checkpoint_path)
        print(f"✓ Checkpoint saved: {checkpoint_path}")

total_time = time.time() - start_time
print(f"Training Complete. Total time: {total_time:.1f}s")
# -> 50k*128*128 =  819M tokens training
# -> roughly 1h @ 250kT/s

Step    0 | Loss: 573.8300 | LR: 2.00e-07 | Speed: 2936859 tok/s
-> Evaluating...
-> Train Loss: 588.0111 | Val Loss: 586.6883

--- Generating ---
JMzuC%
bVTtKEzMG]vNdS,-*K%.RVF6j@Iff'TBBt79H]OABiKR24
 w*UnAEv
Pj_)21J]NUx_`y:v5]CCE-loG#];QEN#:EacMB!DE4jTIs-Ha)FVp&$,"M/
------------------

Step  100 | Loss: 419.4408 | LR: 2.02e-05 | Speed: 124621 tok/s
Step  200 | Loss: 376.1278 | LR: 4.02e-05 | Speed: 244410 tok/s
Step  300 | Loss: 342.3965 | LR: 6.01e-05 | Speed: 246025 tok/s
Step  400 | Loss: 315.3194 | LR: 8.01e-05 | Speed: 246673 tok/s
Step  500 | Loss: 292.5927 | LR: 1.00e-04 | Speed: 246432 tok/s
-> Evaluating...
-> Train Loss: 272.7625 | Val Loss: 263.9272
Step  600 | Loss: 252.7778 | LR: 1.20e-04 | Speed: 174401 tok/s
Step  700 | Loss: 255.3970 | LR: 1.40e-04 | Speed: 243273 tok/s
Step  800 | Loss: 236.7417 | LR: 1.60e-04 | Speed: 246787 tok/s
Step  900 | Loss: 228.4727 | LR: 1.80e-04 | Speed: 244668 tok/s
Step 1000 | Loss: 241.4666 | LR: 2.00e-04 | Speed: 249019 tok/s
->

In [None]:
# Print all checkpoints to see their train / test loss error
# 'What's the overfitting situation?'

import glob
import re
model.eval()
# Find all checkpoint files
checkpoint_files = sorted(glob.glob('checkpoints/model_step_*.pt'))
# Extract step numbers and sort
checkpoint_files = sorted(checkpoint_files, key=lambda x: int(re.search(r'step_(\d+)', x).group(1)))
print("Evaluating checkpoints...\n")
print(f"{'Step':>8} | {'Train Loss':>10} | {'Val Loss':>10}")
print("-" * 35)
for checkpoint_path in checkpoint_files:
    # Extract step number from filename
    step_num = int(re.search(r'step_(\d+)', checkpoint_path).group(1))
    checkpoint = torch.load(checkpoint_path, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    losses = estimate_loss(eval_iters=50)
    print(f"{step_num:8d} | {losses['train']:10.4f} | {losses['val']:10.4f}")

Evaluating checkpoints...

    Step | Train Loss |   Val Loss
-----------------------------------
    5000 |   163.6587 |   161.5879
   10000 |   150.1948 |   149.5160
   15000 |   142.9646 |   139.1608
   20000 |   139.6942 |   136.4029
   25000 |   135.2363 |   132.2800
   30000 |   132.3328 |   130.1084
   35000 |   129.8577 |   128.3031
   40000 |   128.3549 |   126.6261
   45000 |   126.1109 |   125.4059
   50000 |   125.6819 |   125.6412


In [None]:
# Load once more and save it in 25MB blobs so that Github accepts it #


CHUNK_SIZE_LIMIT = 25 * 1024 * 1024 # MB per blob

import io
def save_chunked(obj, file_prefix, chunk_size=CHUNK_SIZE_LIMIT):
    """
    Serializes a python/torch object and saves it in parts 
    no larger than chunk_size.
    """
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    buffer.seek(0)
    part_num = 0
    while True:
        chunk = buffer.read(chunk_size)
        if not chunk:
            break
        # filenames like: model_weights.pt.part000, model_weights.pt.part001
        filename = f"blobs/{file_prefix}.part{part_num:03d}"
        with open(filename, "wb") as f:
            f.write(chunk)
        print(f"Saved {filename} ({len(chunk) / 1024 / 1024:.2f} MB)")
        part_num += 1

# Load checkpoint once more and save it directly after
checkpoint = torch.load('checkpoints/model_step_50000.pt', map_location="cpu", weights_only=False)
# Model Dictionary (Weights + Metadata)
model_bundle = {
    'step': checkpoint['step'],
    'loss': checkpoint['loss'],
    'config': checkpoint['config'],
    'model_state_dict': checkpoint['model_state_dict'] 
}
# Optimizer Dictionary
optimizer_bundle = {
    'optimizer_state_dict': checkpoint['optimizer_state_dict']
}
print("--- Saving Model Bundle ---")
save_chunked(model_bundle, "my_model_weights.pt")
print("\n--- Saving Optimizer Bundle ---")
save_chunked(optimizer_bundle, "my_optimizer_state.pt")

--- Saving Model Bundle ---
Saved blobs/my_model_weights.pt.part000 (25.00 MB)
Saved blobs/my_model_weights.pt.part001 (25.00 MB)
Saved blobs/my_model_weights.pt.part002 (25.00 MB)
Saved blobs/my_model_weights.pt.part003 (25.00 MB)
Saved blobs/my_model_weights.pt.part004 (25.00 MB)
Saved blobs/my_model_weights.pt.part005 (25.00 MB)
Saved blobs/my_model_weights.pt.part006 (4.99 MB)

--- Saving Optimizer Bundle ---
Saved blobs/my_optimizer_state.pt.part000 (25.00 MB)
Saved blobs/my_optimizer_state.pt.part001 (25.00 MB)
Saved blobs/my_optimizer_state.pt.part002 (25.00 MB)
Saved blobs/my_optimizer_state.pt.part003 (25.00 MB)
Saved blobs/my_optimizer_state.pt.part004 (25.00 MB)
Saved blobs/my_optimizer_state.pt.part005 (25.00 MB)
Saved blobs/my_optimizer_state.pt.part006 (25.00 MB)
Saved blobs/my_optimizer_state.pt.part007 (25.00 MB)
Saved blobs/my_optimizer_state.pt.part008 (25.00 MB)
Saved blobs/my_optimizer_state.pt.part009 (25.00 MB)
Saved blobs/my_optimizer_state.pt.part010 (25.00 MB)
