# ResNet-BK Full Training

Reproduce paper results with comprehensive training and evaluation.

This notebook includes:
- Full WikiText-2 training
- Learning rate scheduling
- Comprehensive metrics logging
- Checkpoint saving

**Runtime**: ~20-30 minutes on Google Colab (free tier T4 GPU)

In [None]:
# Install dependencies
!pip install datasets torch matplotlib -q

In [None]:
# Clone repository
import os
if not os.path.exists('src'):
    !git clone https://github.com/YOUR_USERNAME/resnet-bk.git
    %cd resnet-bk

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import time
import math

from src.models import LanguageModel
from src.utils import get_data_loader, TrainingMetrics, MetricsLogger

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## Configuration

In [None]:
# Hyperparameters (paper settings)

# --- Model architecture (forwarded to LanguageModel) ---
D_MODEL = 64       # d_model
N_LAYERS = 4       # n_layers
NUM_EXPERTS = 4    # num_experts
N_SEQ = 128        # n_seq; also the stride used when walking the training data

# --- Optimisation ---
BATCH_SIZE = 20
EPOCHS = 3
LR = 1e-3            # initial AdamW learning rate; the scheduler anneals down to LR / 10
WEIGHT_DECAY = 0.01

# Seed PyTorch's default generator so repeated runs start from the same state.
torch.manual_seed(42)

## Load Data

In [None]:
# Load the training tensor, the vocabulary mapping and the batching helper.
# NOTE(review): `data_limit=500000` presumably caps how much of the corpus is
# loaded, which would make this a subset of WikiText-2 — confirm against
# `get_data_loader`.
loader_kwargs = {
    'batch_size': BATCH_SIZE,
    'n_seq': N_SEQ,
    'dataset_name': 'wikitext-2',
    'data_limit': 500000,
}
train_data, vocab, get_batch = get_data_loader(**loader_kwargs)

# The vocabulary size is needed to dimension the model.
VOCAB_SIZE = vocab['vocab_size']
print(f"Vocabulary size: {VOCAB_SIZE}")
print(f"Training tokens: {train_data.numel()}")

## Create Model and Optimizer

In [None]:
# Build the language model from the configured hyperparameters and move it to
# the selected device.
model = LanguageModel(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    n_seq=N_SEQ,
    num_experts=NUM_EXPERTS,
    top_k=1,
)
model = model.to(device)

# AdamW with decoupled weight decay; token-level cross-entropy as the objective.
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
criterion = nn.CrossEntropyLoss()

# One scheduler step is taken per optimizer step, so the cosine schedule is
# stretched over the whole run and bottoms out at a tenth of the initial LR.
# NOTE(review): the training loop iterates over range(0, size - 1, N_SEQ) and
# skips short batches, so the real step count may be slightly below this
# estimate — confirm against `get_batch`.
steps_per_epoch = train_data.size(0) // N_SEQ
num_total_steps = steps_per_epoch * EPOCHS
scheduler = CosineAnnealingLR(optimizer, T_max=num_total_steps, eta_min=LR / 10)

# Count only the parameters that will actually receive gradients.
num_params = 0
for param in model.parameters():
    if param.requires_grad:
        num_params += param.numel()

print(f"Model parameters: {num_params/1e6:.2f}M")
print(f"Total training steps: {num_total_steps}")

## Setup Logging

In [None]:
# Metrics logger for this run, configured with the 'logs' directory and the
# experiment name 'resnet_bk_full'.
logger = MetricsLogger(
    log_dir='logs',
    experiment_name='resnet_bk_full',
)

## Train

In [None]:
# Training loop: for each epoch, walk the training data in strides of N_SEQ,
# run one optimizer + scheduler step per full-length batch, and log metrics
# every 50 optimizer steps. Batches that produce a non-finite loss are skipped,
# and the number skipped is reported at the end of the epoch so a diverging run
# is not mistaken for a healthy one.
model.train()
global_step = 0

for epoch in range(1, EPOCHS + 1):
    epoch_start = time.time()
    skipped_steps = 0  # batches dropped this epoch because the loss was NaN/inf

    for i in range(0, train_data.size(0) - 1, N_SEQ):
        step_start = time.time()

        x_batch, y_batch = get_batch(train_data, i)
        # Transpose so that dimension 1 is the sequence axis checked below.
        # assumes get_batch returns x as (seq, batch) — TODO confirm
        x_batch = x_batch.t().contiguous()

        # Ignore a batch whose sequence length differs from N_SEQ
        # (presumably the truncated batch at the end of the data).
        if x_batch.size(1) != N_SEQ:
            continue

        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        logits = model(x_batch)
        # Flatten logits to (N, vocab) to match the target tensor.
        # NOTE(review): assumes y_batch is already flattened in the same
        # element order as the flattened logits — verify against get_batch.
        loss = criterion(logits.view(-1, logits.size(-1)), y_batch)

        # Do not back-propagate a NaN/inf loss: it would poison the weights.
        # Count the skip instead of dropping it silently.
        if not torch.isfinite(loss):
            skipped_steps += 1
            continue

        loss.backward()
        # clip_grad_norm_ returns the total gradient norm measured before clipping.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        scheduler.step()
        global_step += 1

        # NOTE: wall-clock time only; on CUDA, kernels may still be in flight.
        step_time = time.time() - step_start

        # Log metrics
        if global_step % 50 == 0:
            metrics = TrainingMetrics(
                step=global_step,
                epoch=epoch,
                loss=loss.item(),
                learning_rate=scheduler.get_last_lr()[0],
                step_time=step_time,
                grad_norm=grad_norm.item(),
            )
            logger.log(metrics)

    epoch_time = time.time() - epoch_start
    print(f"\nEpoch {epoch} completed in {epoch_time:.1f}s\n")
    if skipped_steps:
        print(f"Warning: skipped {skipped_steps} step(s) in epoch {epoch} due to non-finite loss")

# Save final metrics: write the collected metrics out as JSON and print a
# summary of the run.
logger.save_json()
logger.print_summary()

# The check mark was previously mojibake ("âœ“": UTF-8 bytes decoded as cp1252).
print("\n✓ Training complete!")

## Save Checkpoint

In [None]:
# Bundle everything needed to restore the run: model weights, optimizer state,
# the hyperparameters used to build the model, and the logger's metric summary.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': {
        'vocab_size': VOCAB_SIZE,
        'd_model': D_MODEL,
        'n_layers': N_LAYERS,
        'n_seq': N_SEQ,
        'num_experts': NUM_EXPERTS,
    },
    'metrics': logger.get_summary(),
}

# torch.save does not create missing parent directories, so make sure the
# target directory exists first (it will not on a fresh clone / Colab runtime).
checkpoint_path = 'checkpoints/resnet_bk_final.pt'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

torch.save(checkpoint, checkpoint_path)
print(f"Checkpoint saved to {checkpoint_path}")