# ResNet-BK Quick Start

Train a small ResNet-BK model in under 5 minutes!

This notebook demonstrates:
- Basic model setup
- Training on WikiText-2
- Loss and perplexity tracking during training

**Runtime**: ~3-5 minutes on Google Colab (free-tier T4 GPU)

In [None]:
# Install dependencies
!pip install datasets torch -q

In [None]:
# Clone repository (if not already cloned)
import os
if not os.path.exists('src'):
    !git clone https://github.com/neko-jpg/Project-ResNet-BK-An-O-N-Language-Model-Architecture.git
    %cd Project-ResNet-BK-An-O-N-Language-Model-Architecture

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
import math

from src.models import LanguageModel
from src.utils import get_data_loader

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Configuration

Small model for quick training:

In [None]:
# Hyperparameters
D_MODEL = 64
N_SEQ = 128
BATCH_SIZE = 20
N_LAYERS = 4
NUM_EXPERTS = 4
EPOCHS = 2  # Quick training
LR = 1e-3

torch.manual_seed(42)
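How much training does this configuration actually buy? The rough count below makes two assumptions. First, `data_limit=100000` (passed to `get_data_loader` in the next cell) caps the stream at 100,000 tokens. Second, the loader folds that stream into `BATCH_SIZE` parallel columns. Neither detail is documented in this notebook, so treat the helper (`full_batches_per_epoch` is a hypothetical name) as an estimate.

```python
def full_batches_per_epoch(num_tokens: int, batch_size: int, n_seq: int) -> int:
    """Estimate how many full-length batches the training loop uses per epoch.

    Assumes the token stream is folded into `batch_size` columns, giving
    num_tokens // batch_size rows, and walked in strides of `n_seq` rows.
    The final short chunk is skipped, as the training loop does.
    """
    rows = num_tokens // batch_size
    # Start offsets are range(0, rows - 1, n_seq). A chunk is full-length when
    # i + n_seq <= rows - 1, which yields (rows - 1) // n_seq full chunks.
    return max(0, (rows - 1) // n_seq)


steps = full_batches_per_epoch(100_000, batch_size=20, n_seq=128)
tokens_per_step = 20 * 128
print(steps, steps * tokens_per_step)  # 39 99840
```

Under those assumptions, each epoch is only about 39 optimizer steps, or roughly 78 for `EPOCHS = 2`. That is why the run finishes in minutes, and why the resulting perplexity is best read as a smoke test rather than a benchmark.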

## Load Data

In [None]:
train_data, vocab, get_batch = get_data_loader(
    batch_size=BATCH_SIZE,
    n_seq=N_SEQ,
    dataset_name='wikitext-2',
    data_limit=100000  # Limit for quick training
)

VOCAB_SIZE = vocab['vocab_size']
print(f"Vocabulary size: {VOCAB_SIZE}")
print(f"Training tokens: {train_data.numel()}")
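The training loop below calls `.t()` on `x_batch`. That transpose only makes sense if `get_batch` returns sequence-major tensors. The sketch below reproduces that layout in the style of the PyTorch word-language-model example. `batchify` and this `get_batch` are hypothetical stand-ins, not the code in `src.utils`, which may differ. The sketch shows that once inputs are transposed to `(batch, seq)`, any flattened targets need the same reordering before they can be compared position by position with `(batch, seq, vocab)` logits.

```python
import torch


def batchify(tokens: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Trim a 1-D token stream and fold it into shape (num_rows, batch_size).

    Each column is a contiguous slice of the corpus, so row r holds the
    r-th token of every parallel stream (sequence-major layout).
    """
    num_rows = tokens.size(0) // batch_size
    return tokens[: num_rows * batch_size].view(batch_size, num_rows).t().contiguous()


def get_batch(source: torch.Tensor, i: int, n_seq: int):
    """Return inputs of shape (seq_len, batch) and next-token targets flattened seq-major."""
    seq_len = min(n_seq, source.size(0) - 1 - i)
    data = source[i : i + seq_len]
    target = source[i + 1 : i + 1 + seq_len].reshape(-1)
    return data, target


tokens = torch.arange(40)                 # toy "corpus": token ids 0..39
source = batchify(tokens, batch_size=4)   # shape (10, 4)
x, y = get_batch(source, 0, n_seq=3)      # x: (3, 4), y: (12,)

# Transpose both to batch-major so position [b, t] of the targets is the
# token that follows position [b, t] of the inputs.
x_bm = x.t().contiguous()                      # (4, 3)
y_bm = y.view(x.size(0), -1).t().reshape(-1)   # (12,), batch-major
```

Here `y_bm.view(4, 3)` equals `x_bm + 1`, so every target is the next token of its own stream. Reshaping the untransposed `y` the same way does not have this property.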

## Create Model

In [None]:
model = LanguageModel(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    n_layers=N_LAYERS,
    n_seq=N_SEQ,
    num_experts=NUM_EXPERTS,
    top_k=1,
).to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: {num_params/1e6:.2f}M")
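`num_experts=NUM_EXPERTS` and `top_k=1` suggest that each layer sends every token to a single expert out of four. This notebook does not show the repository's routing code. The following is therefore only a generic top-1 mixture-of-experts feed-forward layer in the Switch-Transformer style, with hypothetical names (`Top1MoE`, `d_hidden`). It illustrates what those two arguments control, not how `LanguageModel` implements them.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class Top1MoE(nn.Module):
    """Generic top-1 mixture-of-experts feed-forward layer (illustrative only)."""

    def __init__(self, d_model: int, num_experts: int, d_hidden: Optional[int] = None):
        super().__init__()
        d_hidden = d_hidden or 4 * d_model
        self.gate = nn.Linear(d_model, num_experts)  # router: one score per expert
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_hidden), nn.GELU(), nn.Linear(d_hidden, d_model))
            for _ in range(num_experts)
        )

    def forward(self, x: torch.Tensor):
        # x: (batch, seq, d_model). Flatten so each token is routed independently.
        tokens = x.reshape(-1, x.size(-1))
        probs = F.softmax(self.gate(tokens), dim=-1)   # (num_tokens, num_experts)
        weight, expert_idx = probs.max(dim=-1)         # top-1 expert per token
        out = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            mask = expert_idx == e
            if mask.any():
                # Scale by the gate probability so the router receives gradients.
                out[mask] = weight[mask].unsqueeze(-1) * expert(tokens[mask])
        return out.view_as(x), expert_idx.view(x.shape[:-1])


torch.manual_seed(0)
moe = Top1MoE(d_model=8, num_experts=4)
x = torch.randn(2, 5, 8)
y, expert_idx = moe(x)   # y: (2, 5, 8); expert_idx: which expert handled each token
```

With top-1 routing, each token pays the compute cost of one expert. Adding experts therefore grows the parameter count without growing per-token FLOPs, which is the usual motivation for this design.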

## Train

In [None]:
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

model.train()
for epoch in range(1, EPOCHS + 1):
    total_loss = 0.0
    num_batches = 0
    start_time = time.time()
    
    for i in range(0, train_data.size(0) - 1, N_SEQ):
        x_batch, y_batch = get_batch(train_data, i)
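        # get_batch is assumed to return sequence-major tensors; make inputs (batch, seq)
        x_batch = x_batch.t().contiguous()
        
        # Skip the final, shorter chunk so every batch has exactly N_SEQ positions
        if x_batch.size(1) != N_SEQ:
            continue
        
        # Reorder the (assumed seq-major) flattened targets to batch-major so
        # they line up with the (batch, seq, vocab) logits below
        y_batch = y_batch.reshape(N_SEQ, -1).t().reshape(-1)
        
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        
        optimizer.zero_grad()
        logits = model(x_batch)
        loss = criterion(logits.reshape(-1, logits.size(-1)), y_batch)
        
        # Skip batches whose loss diverged to NaN or inf
        if not torch.isfinite(loss):
            continue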
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
    
    avg_loss = total_loss / max(1, num_batches)
    perplexity = math.exp(avg_loss)
    elapsed = time.time() - start_time
    
    print(f"Epoch {epoch}/{EPOCHS} | Loss: {avg_loss:.4f} | PPL: {perplexity:.2f} | Time: {elapsed:.1f}s")

print("\n✅ Training complete!")
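The loop above reports perplexity on the same data it trains on. Perplexity is the exponential of the mean per-token cross-entropy, so a model that guesses uniformly over V tokens scores exactly V. That gives a handy sanity bound: a training PPL near `VOCAB_SIZE` means the model has learned little. The helper below is a hypothetical sketch for scoring held-out batches. It assumes you can produce `(x, y)` pairs shaped like the training batches, that is, `(batch, seq)` inputs with batch-major flattened targets. It is not clear that `get_data_loader`, as called here, returns a validation split.

```python
import math

import torch
import torch.nn as nn


@torch.no_grad()
def evaluate(model: nn.Module, batches, device="cpu"):
    """Return (mean per-token cross-entropy, perplexity) over an iterable of (x, y)."""
    was_training = model.training
    model.eval()  # disable dropout etc. while scoring
    criterion = nn.CrossEntropyLoss(reduction="sum")  # sum, then divide by token count
    total_loss, total_tokens = 0.0, 0
    for x, y in batches:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        total_loss += criterion(logits.reshape(-1, logits.size(-1)), y).item()
        total_tokens += y.numel()
    if was_training:
        model.train()  # restore the caller's mode
    avg_loss = total_loss / max(1, total_tokens)
    return avg_loss, math.exp(avg_loss)


class UniformLM(nn.Module):
    """Toy model that always emits all-zero logits, i.e. a uniform distribution."""

    def __init__(self, vocab_size: int):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, x):
        return torch.zeros(*x.shape, self.vocab_size, device=x.device)


# Sanity check: a uniform guess over 100 tokens should give perplexity 100.
x = torch.randint(0, 100, (2, 8))
y = torch.randint(0, 100, (16,))
loss, ppl = evaluate(UniformLM(100), [(x, y)])
```

Summing the loss and dividing by the token count keeps the average exact even when the last batch is smaller than the rest.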

## Results

You've successfully trained a ResNet-BK model!

**Next steps:**
- Try the Full Training notebook for better results
- Experiment with different hyperparameters
- Check out the Benchmarking notebook to compare configurations