# Reasoning Model Training on A100

This notebook trains a custom 7B-parameter reasoning model from scratch,
then distills it into a 500M-parameter student model for local inference.

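**Requirements:**
- A100 80GB GPU (Colab Pro+)
- ~24-48 hours for full 7B training
- ~4-8 hours for distillation

Note: with standard mixed-precision AdamW, the weights, gradients and optimizer state of a 7B model alone need well over 80 GB. Check the memory estimate printed in Section 1 before launching a long run.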

**Outputs:**
- 7B teacher model checkpoint
- 500M student model (for HP tower deployment)

In [None]:
# Report which GPU this runtime has (if any) and whether it suits 7B training.
import torch

if not torch.cuda.is_available():
    print("❌ No GPU - training will be extremely slow")
else:
    gpu_name = torch.cuda.get_device_name(0)
    props = torch.cuda.get_device_properties(0)
    gpu_memory = props.total_memory / 1e9  # decimal gigabytes
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_memory:.1f} GB")

    # Pick the advice line by memory size, with the A100-80GB case first.
    if gpu_memory > 70 and "A100" in gpu_name:
        verdict = "\n✓ A100 80GB detected - perfect for 7B training!"
    elif gpu_memory > 40:
        verdict = "\n⚠ Large GPU detected - may need gradient checkpointing"
    else:
        verdict = "\n⚠ Smaller GPU - consider training 3B model instead"
    print(verdict)

In [None]:
# Install dependencies (IPython `!` shell escape: runs in the notebook's shell, not Python)
!pip install -q torch transformers datasets accelerate wandb

# Clone the reasoning-lab repository (replace YOUR_USERNAME with the real owner).
# `%cd` changes the kernel's working directory; the next cell's
# `sys.path.insert(0, '.')` relies on this to find the `src` package.
# NOTE(review): re-running this cell will fail at `git clone` (directory exists)
# and `%cd` resolves relative to the current directory — confirm before re-running.
!git clone https://github.com/YOUR_USERNAME/reasoning-lab.git
%cd reasoning-lab

# Or mount Google Drive if storing there
# from google.colab import drive
# drive.mount('/content/drive')
# %cd /content/drive/MyDrive/reasoning-lab

In [None]:
# Make the current directory (the repo root, after `%cd` above) importable so
# the project's `src` package resolves. This must run before the `src` imports.
import sys
sys.path.insert(0, '.')

import torch
# Model factories: generic constructor plus teacher (7B) and student (500M) configs.
from src.models import create_model, create_7b_config, create_500m_config
# Data pipeline: tokenizer, source configuration, dataset loading and batching.
from src.data import (
    create_tokenizer, 
    DatasetConfig, 
    create_combined_dataset,
    ReasoningDataset,
    create_dataloaders,
)
# Teacher training loop and its configuration.
from src.training import TrainingConfig, Trainer

print("Imports successful!")

## Configuration

Adjust these settings based on your needs and GPU.

In [None]:
# Central knobs for the teacher training run; later cells read from CONFIG.
CONFIG = dict(
    # Model
    model_size="7b",  # "3b", "7b", or "13b"
    # Data
    num_samples=100_000,  # None for full dataset (~150k)
    max_seq_length=2048,
    # Training
    epochs=3,
    batch_size=2,  # Per GPU
    gradient_accumulation=16,  # Effective batch = 2 * 16 = 32
    learning_rate=1e-4,
    warmup_steps=500,
    # Efficiency
    gradient_checkpointing=True,
    bf16=True,
    # Saving
    save_steps=500,
    output_dir="checkpoints/teacher_7b",
    # Logging
    use_wandb=True,
    wandb_project="reasoning-7b",
)

# Echo every setting, one indented "key: value" line each, under a header.
print("Configuration:", *(f"  {key}: {value}" for key, value in CONFIG.items()), sep="\n")

## 1. Create Model

In [None]:
# Build the teacher's architecture config for the selected size.
# The 3B/13B factories are imported lazily so only the chosen one is needed.
model_size = CONFIG["model_size"]
if model_size == "7b":
    model_config = create_7b_config()
elif model_size == "3b":
    from src.models import create_3b_config
    model_config = create_3b_config()
elif model_size == "13b":
    from src.models import create_13b_config
    model_config = create_13b_config()
else:
    # Without this branch an unsupported value leaves `model_config` undefined
    # (NameError below) or silently reuses a config from an earlier cell run.
    raise ValueError(
        f"Unsupported model_size {model_size!r}; expected '3b', '7b' or '13b'"
    )

model_config.gradient_checkpointing = CONFIG["gradient_checkpointing"]

print(f"Model: {model_config.name}")
print(f"Parameters: {model_config.num_parameters() / 1e9:.2f}B")

# Memory estimate at the configured per-GPU batch size
# (dtype_bytes=2 presumably means 16-bit weights — see memory_footprint).
mem = model_config.memory_footprint(dtype_bytes=2, batch_size=CONFIG["batch_size"])
print(f"Estimated training memory: {mem['total_training_gb']:.1f} GB")

In [None]:
# Initialize the teacher model (fresh weights) from the config built above.
model = create_model(model_config)
print(f"Created model with {model.num_parameters() / 1e9:.2f}B parameters")

# Load tokenizer (using Mistral - open source, similar vocab to Llama).
# add_reasoning_tokens=True can make len(tokenizer) differ from the config's vocab.
tokenizer = create_tokenizer(
    base_tokenizer="mistralai/Mistral-7B-v0.1",
    vocab_size=model_config.vocab_size,
    add_reasoning_tokens=True,
)
print(f"Tokenizer vocabulary: {len(tokenizer)} tokens")

# Resize embeddings if needed, so that every token id has an embedding row and
# an output logit.
new_vocab_size = len(tokenizer)
if new_vocab_size != model_config.vocab_size:
    print(f"Resizing embeddings from {model_config.vocab_size} to {new_vocab_size} tokens...")

    # Create the replacement on the same device/dtype as the original module
    # and carry over the rows both vocabularies share.
    old_embed_weight = model.embed_tokens.weight
    new_embed = torch.nn.Embedding(
        new_vocab_size,
        model_config.hidden_size,
        device=old_embed_weight.device,
        dtype=old_embed_weight.dtype,
    )
    keep = min(new_vocab_size, old_embed_weight.shape[0])
    with torch.no_grad():
        new_embed.weight[:keep].copy_(old_embed_weight[:keep])
    model.embed_tokens = new_embed

    # A None lm_head is left alone, as before (presumably tied to embed_tokens).
    if model.lm_head is not None:
        # Assumes lm_head is an nn.Linear, whose weight is (out=vocab, in=hidden).
        old_head_weight = model.lm_head.weight
        new_head = torch.nn.Linear(
            model_config.hidden_size,
            new_vocab_size,
            bias=False,
            device=old_head_weight.device,
            dtype=old_head_weight.dtype,
        )
        keep_head = min(new_vocab_size, old_head_weight.shape[0])
        with torch.no_grad():
            new_head.weight[:keep_head].copy_(old_head_weight[:keep_head])
        model.lm_head = new_head

    # Keep the config in sync with the actual vocabulary so later cells and any
    # saved config report the real size.
    model_config.vocab_size = new_vocab_size

## 2. Load Training Data

In [None]:
# Configure data sources
data_config = DatasetConfig(
    max_seq_length=CONFIG["max_seq_length"],
    seed=42,
)

# Limit samples if specified
if CONFIG["num_samples"]:
    samples_per_source = CONFIG["num_samples"] // len(data_config.sources)
    for source in data_config.sources:
        source["sample_size"] = min(source.get("sample_size", samples_per_source), samples_per_source)

print("Loading datasets...")
dataset = create_combined_dataset(data_config)
print(f"Total examples: {len(dataset)}")

In [None]:
# Hold out 2% of the examples for validation. Shuffle first, with a fixed seed
# matching DatasetConfig above, in case create_combined_dataset concatenates
# sources in order (not visible here). Otherwise the held-out tail could come
# entirely from whichever source is last.
shuffled = dataset.shuffle(seed=42)
train_size = int(0.98 * len(shuffled))
train_data = shuffled.select(range(train_size))
val_data = shuffled.select(range(train_size, len(shuffled)))

print(f"Train: {len(train_data)}, Val: {len(val_data)}")

# Create PyTorch datasets (tokenization/truncation to max_seq_length)
train_dataset = ReasoningDataset(train_data, tokenizer, max_length=CONFIG["max_seq_length"])
val_dataset = ReasoningDataset(val_data, tokenizer, max_length=CONFIG["max_seq_length"])

# Create dataloaders; returned as a dict keyed by split ("train", and presumably "val")
dataloaders = create_dataloaders(
    train_dataset,
    val_dataset,
    batch_size=CONFIG["batch_size"],
    num_workers=2,
)

print(f"Training batches: {len(dataloaders['train'])}")

## 3. Train Model

In [None]:
# Weights & Biases is optional: log in only when the run will report to it.
if CONFIG["use_wandb"]:
    import wandb
    wandb.login()

# TrainingConfig field name -> CONFIG key it is read from.
config_keys = {
    "num_epochs": "epochs",
    "learning_rate": "learning_rate",
    "batch_size": "batch_size",
    "gradient_accumulation_steps": "gradient_accumulation",
    "warmup_steps": "warmup_steps",
    "bf16": "bf16",
    "gradient_checkpointing": "gradient_checkpointing",
    "output_dir": "output_dir",
    "save_steps": "save_steps",
    "use_wandb": "use_wandb",
    "wandb_project": "wandb_project",
}
training_kwargs = {field: CONFIG[key] for field, key in config_keys.items()}

# Mixed precision is always on; CONFIG["bf16"] selects the 16-bit format flag.
training_config = TrainingConfig(mixed_precision=True, **training_kwargs)

print(f"Effective batch size: {training_config.effective_batch_size}")

In [None]:
# Wire the teacher model and the Section 2 dataloaders into the trainer.
trainer = Trainer(
    model=model,
    config=training_config,
    train_dataloader=dataloaders["train"],
    eval_dataloader=dataloaders.get("val"),
)

# Run the full training loop. `results` is indexed below for total_steps,
# final_loss and training_time (seconds, converted to hours for display).
results = trainer.train()

divider = "=" * 60
print(f"\n{divider}")
print("Training Complete!")
print(divider)
print(f"Total steps: {results['total_steps']}")
print(f"Final loss: {results['final_loss']:.4f}")
hours = results['training_time'] / 3600
print(f"Training time: {hours:.2f} hours")

## 4. Distill to 500M Student

Now we distill the trained 7B teacher into a 500M student model
small enough to run on your HP tower.

In [None]:
from src.training import DistillationConfig, DistillationTrainer

# Build the 500M student. Its vocabulary must match the tokenizer, which may
# include the added reasoning tokens, and therefore the teacher's output size.
# Otherwise those token ids would index past the student's embedding table, and
# teacher and student logits would have different widths.
student_config = create_500m_config()
student_config.vocab_size = len(tokenizer)
student = create_model(student_config)

print(f"Teacher: {model_config.num_parameters() / 1e9:.2f}B")
print(f"Student: {student_config.num_parameters() / 1e6:.0f}M")
print(f"Compression: {model_config.num_parameters() / student_config.num_parameters():.1f}x")

In [None]:
# Distillation loss settings (field meanings per DistillationConfig): softmax
# temperature, weights for the CE / KL / hidden-state terms, and a frozen teacher.
distill_config = DistillationConfig(
    temperature=2.0,
    alpha_ce=0.5,
    alpha_kl=0.5,
    alpha_hidden=0.1,
    freeze_teacher=True,
)

# Number of passes over the training set during distillation.
DISTILL_EPOCHS = 5

# Optimizer for student (only the student's parameters are updated)
optimizer = torch.optim.AdamW(
    student.parameters(),
    lr=5e-5,
    weight_decay=0.1,
)

val_loader = dataloaders.get("val")

# Create distillation trainer
distill_trainer = DistillationTrainer(
    teacher=model,
    student=student,
    config=distill_config,
    train_dataloader=dataloaders["train"],
    eval_dataloader=val_loader,
    optimizer=optimizer,
)

print("Starting distillation...")
for epoch in range(DISTILL_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{DISTILL_EPOCHS}")
    losses = distill_trainer.train_epoch()
    print(f"  Losses: {losses}")

    # Evaluate only when a non-empty validation loader exists.
    if val_loader:
        eval_metrics = distill_trainer.evaluate()
        print(f"  Eval: {eval_metrics}")

In [None]:
import os

# Persist the distilled student: weights, architecture config, and tokenizer.
student_dir = "checkpoints/student_500m"
os.makedirs(student_dir, exist_ok=True)

weights_path = os.path.join(student_dir, "model.pt")
torch.save(student.state_dict(), weights_path)
student_config.save(os.path.join(student_dir, "config.json"))

# The tokenizer carries the added reasoning tokens (add_reasoning_tokens=True).
# Without it, a local deployment cannot reproduce the token ids the student
# was trained on.
if hasattr(tokenizer, "save_pretrained"):
    tokenizer.save_pretrained(student_dir)

print(f"Student model saved to {student_dir}")
print(f"Model size: {os.path.getsize(weights_path) / 1e9:.2f} GB")

## 5. Test the Student Model

In [None]:
# Smoke-test the distilled student on a few reasoning prompts.
student.eval()
device = next(student.parameters()).device

test_prompts = [
    "What is 15 + 27? Think step by step.",
    "If all dogs are mammals, and some mammals can fly, can some dogs fly? Explain your reasoning.",
    "Debug this Python code: def add(a, b): return a - b",
]

for prompt in test_prompts:
    print(f"\nPrompt: {prompt}")
    print("-" * 50)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs.input_ids.shape[1]

    with torch.no_grad():
        outputs = student.generate(
            inputs.input_ids,
            max_new_tokens=200,
            temperature=0.7,
            top_p=0.9,
        )

    # generate() is assumed to return prompt + continuation (as the previous
    # string slicing also assumed). Decode only the new token ids: slicing the
    # decoded text by len(prompt) breaks whenever detokenization does not
    # reproduce the prompt character-for-character.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    print(f"Response: {response}")

## 6. Download Models

Download the trained models to your local machine or Google Drive.

In [None]:
# Option 1: Copy to Google Drive
from google.colab import drive
drive.mount('/content/drive')

!cp -r checkpoints/student_500m /content/drive/MyDrive/reasoning-models/
!cp -r checkpoints/teacher_7b/final /content/drive/MyDrive/reasoning-models/teacher_7b

print("Models copied to Google Drive!")

In [None]:
# Option 2: Download directly
from google.colab import files

# Compress student model
!tar -czvf student_500m.tar.gz checkpoints/student_500m

# Download
files.download('student_500m.tar.gz')

## Next Steps

1. **Deploy on HP Tower:**
   - Copy the 500M model to your HP machine
   - See `scripts/inference_local.py` for running locally

2. **Raspberry Pi Cluster:**
   - The 500M model can be quantized to INT8 for Pi deployment
   - See `src/deployment/quantize.py` (coming soon)

3. **Further Training:**
   - Generate more synthetic data with `src/data/synthetic.py`
   - Fine-tune on domain-specific reasoning tasks

4. **Evaluation:**
   - Run `scripts/evaluate_models.py --compare` to compare teacher/student
   - Check retention metrics for distillation quality