# SASLM v2 Training Notebook

Train the Sri Aurobindo Small Language Model (SASLM) with:
- Weighted sampling (mature works prioritized)
- Checkpointing to Google Drive (survives disconnects)
- Grokking detection
- Comprehensive logging
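
As a rough illustration of the weighted-sampling idea, the sketch below draws documents in proportion to a per-period weight so that mature works appear more often. The period labels, weight values, and document titles are hypothetical placeholders, not the values used by this project's data loader.

```python
import random
from collections import Counter

# Hypothetical weights: the real values would come from the experiment config.
PERIOD_WEIGHTS = {"early": 1.0, "middle": 2.0, "mature": 4.0}


def sample_documents(docs, k, seed=0):
    """Draw k documents with replacement, proportional to their period weight."""
    rng = random.Random(seed)  # seeded so the draw is reproducible
    weights = [PERIOD_WEIGHTS[d["period"]] for d in docs]
    return rng.choices(docs, weights=weights, k=k)


# Placeholder documents, one early and one mature.
docs = [
    {"title": "early_essay", "period": "early"},
    {"title": "mature_treatise", "period": "mature"},
]

counts = Counter(d["title"] for d in sample_documents(docs, k=10_000))
print(counts)
```

With weights of 1.0 and 4.0, the mature document should make up roughly four fifths of the draws.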

## Experiments
- **EXP-A1**: From scratch, prose only
- **EXP-B1**: Fine-tune GPT-2, prose only
- **EXP-A2**: From scratch, prose + poetry
- **EXP-B2**: Fine-tune GPT-2, prose + poetry

## 1. Setup

In [None]:
# Mount Google Drive (for checkpoint persistence)
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone the repository (or upload files)
# Option 1: Clone from GitHub
# !git clone https://github.com/YOUR_USERNAME/saslm.git /content/saslm

# Option 2: Upload zip and extract
# from google.colab import files
# uploaded = files.upload()  # Upload saslm.zip
# !unzip saslm.zip -d /content/

# For now, assume files are in /content/saslm
%cd /content/saslm

In [None]:
# Install dependencies
!pip install torch transformers tokenizers datasets wandb tqdm pyyaml numpy -q

In [None]:
# Verify GPU
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 2. Build Corpus (if not already done)

In [None]:
import os

# Check if corpus exists
if not os.path.exists('./data/clean_prose/corpus_stats.json'):
    print("Building prose corpus...")
    !python src/data/build_corpus.py --mode prose --source ./processed_text
else:
    print("Prose corpus already exists")
    !python src/data/build_corpus.py --info --output ./data/clean_prose

## 3. Train Tokenizer (if not already done)

In [None]:
import os  # repeated here so this cell also runs on its own

# Check if tokenizer exists
if not os.path.exists('./tokenizers/tokenizer_16k/tokenizer.json'):
    print("Training tokenizer...")
    !python src/data/train_tokenizer.py \
        --corpus ./data/clean_prose \
        --vocab-size 16384 \
        --output ./tokenizers/tokenizer_16k
else:
    print("Tokenizer already exists")
    !python src/data/train_tokenizer.py --analyze ./tokenizers/tokenizer_16k --corpus ./data/clean_prose

## 4. Select Experiment

In [None]:
# Choose experiment
EXPERIMENT = "EXP-A1"  # Options: EXP-A1, EXP-B1, EXP-A2, EXP-B2

config_map = {
    "EXP-A1": "configs/exp_a1_prose_only.yaml",
    "EXP-B1": "configs/exp_b1_prose_only_finetune.yaml",
    "EXP-A2": "configs/exp_a2_prose_poetry.yaml",
    "EXP-B2": "configs/exp_b2_prose_poetry_finetune.yaml",
}

CONFIG_PATH = config_map[EXPERIMENT]
print(f"Selected: {EXPERIMENT}")
print(f"Config: {CONFIG_PATH}")

In [None]:
# View config
!cat {CONFIG_PATH}

## 5. Run Training

Training will:
- Auto-resume from checkpoint if disconnected
- Save checkpoints to Google Drive every 1000 steps
- Log metrics to wandb (optional)
- Detect grokking (validation loss dropping sharply long after training loss has plateaued)
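
One simple way to flag grokking from logged losses is sketched below. This is an illustrative heuristic, not the detector implemented in `src/training/train.py`: the function name, the `window`, `plateau_tol`, and `min_drop` parameters, and the sample loss values are all assumptions.

```python
def detect_grokking(train_losses, val_losses, window=3, plateau_tol=0.02, min_drop=0.1):
    """Return the first eval index where val loss falls sharply after train loss has plateaued.

    Returns None if no such point exists. Losses are assumed to be positive.
    """
    for i in range(window, len(val_losses)):
        recent_train = train_losses[i - window:i]
        # Train loss counts as plateaued if it moved less than plateau_tol over the window.
        train_flat = max(recent_train) - min(recent_train) < plateau_tol
        # Relative drop of the current val loss against the best value in the window.
        baseline = min(val_losses[i - window:i])
        val_drop = (baseline - val_losses[i]) / baseline
        if train_flat and val_drop > min_drop:
            return i
    return None


# Made-up losses: train flattens at 0.5, then val falls from 1.9 to 1.2 at the last eval.
train = [2.0, 1.0, 0.5, 0.5, 0.5, 0.5, 0.5]
val = [2.5, 2.0, 1.9, 1.9, 1.9, 1.9, 1.2]
print(detect_grokking(train, val))  # 6
```

A run where both losses fall together never satisfies the plateau condition, so the function returns `None` for it.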

In [None]:
# Optional: Login to Weights & Biases for tracking
# import wandb
# wandb.login()

In [None]:
# Run training (will auto-resume if checkpoint exists)
!python src/training/train.py \
    --config {CONFIG_PATH} \
    --resume

## 6. Evaluate Model

In [None]:
# Load the best model and generate samples
import torch
from tokenizers import Tokenizer
import sys
sys.path.insert(0, '.')

from src.training.train import GPT
from src.training.checkpoint_manager import CheckpointManager

# Load tokenizer
tokenizer = Tokenizer.from_file('./tokenizers/tokenizer_16k/tokenizer.json')
vocab_size = tokenizer.get_vocab_size()

# Create model (these hyperparameters must match the config the checkpoint was trained with)
model = GPT(
    vocab_size=vocab_size,
    block_size=512,
    n_layer=6,
    n_head=6,
    n_embd=384,
)

# Load best checkpoint
checkpoint_mgr = CheckpointManager(EXPERIMENT)
checkpoint_mgr.load_best(model, device='cuda')
model = model.cuda()
model.eval()

print("Model loaded!")

In [None]:
# Generate samples
prompts = [
    "The Supermind is",
    "The psychic being differs from the soul in that",
    "The goal of Integral Yoga is not merely liberation but",
    "In the process of spiritual evolution,",
    "The three modes of Nature are",
]

for prompt in prompts:
    # Encode
    encoded = tokenizer.encode(prompt)
    input_ids = torch.tensor([encoded.ids], device='cuda')
    
    # Generate
    with torch.no_grad():
        output = model.generate(input_ids, max_new_tokens=100, temperature=0.8, top_k=50)
    
    # Decode
    generated = tokenizer.decode(output[0].tolist())
    
    print(f"\n{'='*60}")
    print(f"Prompt: {prompt}")
    print(f"Generated: {generated}")

## 7. Run LLM Judge Evaluation

In [None]:
# Set API key for evaluation (choose one)
import os

# Option 1: OpenAI
# os.environ['OPENAI_API_KEY'] = 'your-key-here'

# Option 2: Anthropic
# os.environ['ANTHROPIC_API_KEY'] = 'your-key-here'

# Option 3: Google
# os.environ['GEMINI_API_KEY'] = 'your-key-here'

In [None]:
# Run evaluation
# !python src/evaluate.py \
#     --model-path /content/drive/MyDrive/saslm/experiments/{EXPERIMENT}/best_model.pt \
#     --tokenizer ./tokenizers/tokenizer_16k \
#     --judge claude \
#     --output ./results/{EXPERIMENT}_eval.csv

## 8. Upload to HuggingFace (Optional)

In [None]:
# Login to HuggingFace
# from huggingface_hub import login
# login()

In [None]:
# Upload model
# from huggingface_hub import HfApi
# api = HfApi()
# 
# api.upload_folder(
#     folder_path=f'/content/drive/MyDrive/saslm/experiments/{EXPERIMENT}',
#     repo_id='your-username/saslm-v2',
#     repo_type='model',
# )

## 9. View Training Curves

In [None]:
import json
import matplotlib.pyplot as plt

# Load metrics
metrics_path = f'/content/drive/MyDrive/saslm/experiments/{EXPERIMENT}/metrics.jsonl'

train_losses = []
val_losses = []

with open(metrics_path, 'r') as f:
    for line in f:
        if not line.strip():
            continue  # skip blank lines so json.loads does not fail
        data = json.loads(line)
        if 'train_loss' in data:
            train_losses.append((data['step'], data['train_loss']))
        if 'val_loss' in data:
            val_losses.append((data['step'], data['val_loss']))

# Plot
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

# Training loss
if train_losses:
    x, y = zip(*train_losses)
    ax1.plot(x, y, label='Train Loss', alpha=0.7)
if val_losses:
    x, y = zip(*val_losses)
    ax1.plot(x, y, label='Val Loss', alpha=0.7)
ax1.set_xlabel('Step')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Gap (for grokking detection)
if train_losses and val_losses:
    # Align by step
    train_dict = dict(train_losses)
    val_dict = dict(val_losses)
    common_steps = sorted(set(train_dict.keys()) & set(val_dict.keys()))
    gaps = [val_dict[s] - train_dict[s] for s in common_steps]
    ax2.plot(common_steps, gaps, label='Val - Train Gap', color='purple')
    ax2.axhline(y=0, color='gray', linestyle='--')
    ax2.set_xlabel('Step')
    ax2.set_ylabel('Gap')
    ax2.set_title('Generalization Gap (Grokking Indicator)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

plt.tight_layout()
import os
os.makedirs('./results', exist_ok=True)  # savefig fails if the directory does not exist
plt.savefig(f'./results/{EXPERIMENT}_training_curves.png', dpi=150)
plt.show()

---

## Notes

### If Colab Disconnects
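Re-run the "Run Training" cell. Because training is launched with `--resume`, it continues from the last checkpoint saved to Google Drive. If the runtime was reset, first re-run the Setup cells (mount Drive, restore the files in `/content/saslm`, install dependencies) and the "Select Experiment" cell so that `CONFIG_PATH` is defined again.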

### Checkpoint Locations
- Latest: `/content/drive/MyDrive/saslm/experiments/{EXPERIMENT}/checkpoint_latest.pt`
- Best: `/content/drive/MyDrive/saslm/experiments/{EXPERIMENT}/best_model.pt`
- Metrics: `/content/drive/MyDrive/saslm/experiments/{EXPERIMENT}/metrics.jsonl`
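
The sketch below builds these three paths for a given experiment and reports which files exist, which can be handy after a reconnect. The helper name `experiment_paths` and the constant `DRIVE_ROOT` are illustrative and not part of the repository.

```python
from pathlib import Path

# Root folder taken from the checkpoint locations listed above.
DRIVE_ROOT = Path("/content/drive/MyDrive/saslm/experiments")


def experiment_paths(experiment, root=DRIVE_ROOT):
    """Return the latest-checkpoint, best-model, and metrics paths for one experiment."""
    exp_dir = Path(root) / experiment
    return {
        "latest": exp_dir / "checkpoint_latest.pt",
        "best": exp_dir / "best_model.pt",
        "metrics": exp_dir / "metrics.jsonl",
    }


paths = experiment_paths("EXP-A1")
for name, path in paths.items():
    status = "found" if path.exists() else "missing"
    print(f"{name:8s} {path} ({status})")
```

If `latest` is reported as missing, Drive is probably not mounted or the experiment has not yet written a checkpoint.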

### Expected Training Time
- EXP-A1 (from scratch): ~8-12 hours for 100K steps on T4
- EXP-B1 (fine-tune): ~4-6 hours for 50K steps on T4
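
To sanity-check these estimates against a live run, it can help to convert them to seconds per step. The short calculation below uses only the figures listed above; the function name `seconds_per_step` is illustrative.

```python
def seconds_per_step(hours, steps):
    """Convert a total wall-clock estimate in hours into seconds per training step."""
    return hours * 3600 / steps


# Step counts and hour ranges copied from the estimates above.
estimates = {
    "EXP-A1": {"steps": 100_000, "hours": (8, 12)},
    "EXP-B1": {"steps": 50_000, "hours": (4, 6)},
}

for name, est in estimates.items():
    low, high = (seconds_per_step(h, est["steps"]) for h in est["hours"])
    print(f"{name}: {low:.2f}-{high:.2f} s/step")
```

Both estimates work out to about 0.29-0.43 seconds per step, so a run that is noticeably slower than that on a T4 will likely overshoot the listed times.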