# ODIN Training Notebook

Train the ODIN Selective SSM model on Google Colab.

**Instructions:**
1. Run cells in order
2. Checkpoints auto-save to Google Drive
3. If the runtime disconnects, re-run the cells in order - training auto-resumes from the checkpoint saved on Drive

## 1. Setup

In [None]:
# Mount Google Drive for checkpoints
import os

from google.colab import drive

drive.mount('/content/drive')

# Create checkpoint directory.
# os.makedirs raises if the directory cannot be created, whereas a failing
# `!mkdir` only prints an error and lets the notebook carry on.
# exist_ok=True keeps the cell safe to re-run (same as `mkdir -p`).
os.makedirs('/content/drive/MyDrive/odin_checkpoints', exist_ok=True)

In [None]:
# Clone the repo
!git clone https://github.com/badalraj9/odin.git
%cd odin

In [None]:
# Install dependencies
!pip install -q torch numpy scipy tqdm

In [None]:
# Report the PyTorch build and, when CUDA is usable, the first GPU's name and memory.
import torch

cuda_ok = torch.cuda.is_available()

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")

if cuda_ok:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is in bytes; show it in decimal gigabytes.
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

## 2. Generate Dataset (if not already done)

In [None]:
# Check if dataset exists
import os
if not os.path.exists('data/scenarios/stats.json'):
    print("Generating dataset...")
    %cd data/scripts
    !python generator.py --output ../scenarios --num 10000 --seed 42
    %cd ../..
else:
    print("Dataset already exists!")
    import json
    with open('data/scenarios/stats.json') as f:
        stats = json.load(f)
    print(f"Scenarios: {stats['total']}")

## 3. Create Model

In [None]:
import sys

# Imported here as well so this cell does not depend on the GPU-check cell
# having been run first.
import torch

# Make the cloned repository (the current directory) importable.
sys.path.insert(0, '.')

from odin.core import ODINModel, ODINConfig

# Choose model size
# - 'tiny': ~1.7M params, 10-15 min training
# - 'small': ~10M params, 1-2 hour training
# - 'medium': ~85M params, 4-6 hour training

MODEL_SIZE = 'tiny'  # Change this!

# Maps each size name to the ODINConfig factory that builds its config.
config_map = {
    'tiny': ODINConfig.tiny,
    'small': ODINConfig.small,
    'medium': ODINConfig.medium,
}

# Fail with an actionable message instead of a bare KeyError on a typo.
if MODEL_SIZE not in config_map:
    raise ValueError(
        f"Unknown MODEL_SIZE {MODEL_SIZE!r}; choose one of {sorted(config_map)}"
    )

config = config_map[MODEL_SIZE]()
print(f"Config: {MODEL_SIZE}")
print(f"  State dim: {config.state_dim}")
print(f"  Layers: {config.num_layers}")
print(f"  Rank: {config.rank}")

# Use the GPU when one is available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ODINModel(config).to(device)

print(f"\nModel parameters: {model.count_parameters():,}")
print(f"Device: {device}")

## 4. Test Forward Pass

In [None]:
# Smoke test: run one model step on a random input batch and print what comes back.
batch_size = 4
h = model.init_state(batch_size, device)
u = torch.randn(batch_size, config.input_dim, device=device)

# A single step returns the next state, the output and a diagnostics dict.
h_next, y, info = model(h, u)
outputs = info['outputs']

print(f"Input u: {u.shape}")
print(f"State h: {h.shape} -> {h_next.shape}")
print(f"Output y: {y.shape}")
print(f"\nState norm: {info['h_norm']:.4f}")
print(f"Stance probs: {outputs['stance_probs']}")
print(f"Confidence: {outputs['confidence']:.2%}")
print("\n✓ Forward pass works!")

## 5. Train!

In [None]:
from odin.dataset import create_dataloader
from train import ODINTrainer

# Checkpoints go to the mounted Drive so they outlive the Colab runtime.
checkpoint_dir = '/content/drive/MyDrive/odin_checkpoints'

# Build the training dataloader from the generated scenarios.
loader_kwargs = dict(
    scenario_dir='data/scenarios',
    batch_size=config.batch_size,
    shuffle=True,
    max_steps=10,
)
train_loader = create_dataloader(**loader_kwargs)

num_scenarios = len(train_loader.dataset)
num_batches = len(train_loader)
print(f"Dataset: {num_scenarios} scenarios")
print(f"Batches per epoch: {num_batches}")

# Wire the model, config and data into the trainer.
trainer = ODINTrainer(
    model=model,
    config=config,
    train_loader=train_loader,
    checkpoint_dir=checkpoint_dir,
    device=device,
)

# resume=True asks the trainer to pick up from an existing checkpoint, if any.
history = trainer.train(resume=True)

## 6. Results

In [None]:
import matplotlib.pyplot as plt

# One subplot per tracked metric: (key in `history`, panel title).
panels = [
    ('train_loss', 'Training Loss'),
    ('stance_acc', 'Stance Accuracy'),
    ('action_acc', 'Action Accuracy'),
]

fig, axes = plt.subplots(1, len(panels), figsize=(15, 4))

for ax, (metric_key, panel_title) in zip(axes, panels):
    ax.plot(history[metric_key])
    ax.set_title(panel_title)
    ax.set_xlabel('Epoch')

plt.tight_layout()
# Keep a copy of the figure next to the checkpoints on Drive.
plt.savefig('/content/drive/MyDrive/odin_checkpoints/training_curves.png')
plt.show()

# Text summary of the last recorded value of each metric.
final_loss = history['train_loss'][-1]
final_stance_acc = history['stance_acc'][-1]
final_action_acc = history['action_acc'][-1]

print("\nFinal metrics:")
print(f"  Loss: {final_loss:.4f}")
print(f"  Stance Acc: {final_stance_acc:.2%}")
print(f"  Action Acc: {final_action_acc:.2%}")

## 7. Export Model

In [None]:
# Write the final model to Drive, with the chosen size in the file name.
export_dir = '/content/drive/MyDrive/odin_checkpoints'
output_path = f'{export_dir}/odin_{MODEL_SIZE}.pt'

trainer.export_model(output_path)
print(f"\nModel saved to: {output_path}")

## 8. Load and Test

In [None]:
# Load the exported model.
# map_location remaps saved tensors onto the current device, so a checkpoint
# written on a GPU runtime still loads when this cell runs on a CPU runtime.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True;
# if export_model stores non-tensor Python objects in the file, this call may
# need weights_only=False - confirm against what export_model writes.
checkpoint = torch.load(output_path, map_location=device)

# Create new model with same config
loaded_model = ODINModel(config).to(device)
loaded_model.load_state_dict(checkpoint['model_state_dict'])
loaded_model.eval()

print("Model loaded successfully!")

# Test inference: one step on a single random input, with gradients disabled.
with torch.no_grad():
    h = loaded_model.init_state(1, device)
    u = torch.randn(1, config.input_dim, device=device)
    h_next, y, info = loaded_model(h, u)

    print(f"\nInference test:")
    print(f"  Stance probs: {info['outputs']['stance_probs']}")
    print(f"  Confidence: {info['outputs']['confidence']:.2%}")

---

## Done!

Your trained ODIN model is saved to Google Drive.

**Next steps:**
1. If tiny works → try small
2. If small works → try medium
3. Download the `.pt` file for local use