# SFT Training on MATH Dataset with Qwen3-4B

This notebook demonstrates the impact of Supervised Fine-Tuning (SFT) on math reasoning by:
1. Evaluating the base Qwen3-4B model (zero-shot)
2. Fine-tuning on MATH training data
3. Evaluating the fine-tuned model
4. Comparing results to see the improvement

**Default Model: Qwen/Qwen3-4B** (4B params, ~8GB)

**Other Supported Models:**
- Qwen/Qwen2.5-Math-1.5B (1.5B params, ~3GB) - optimized for math
- Qwen/Qwen2.5-0.5B (0.5B params, ~1GB) - for testing
- Qwen/Qwen3-1.7B (1.7B params, ~4GB) - smaller Qwen3 variant
- meta-llama/Llama-3.1-8B (8B params, ~16GB) - requires HuggingFace login
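
The sizes listed above (which are rounded) match storing the weights in a 16-bit format such as bf16 or fp16, at 2 bytes per parameter. Here is a minimal sketch of that arithmetic, assuming 16-bit weights. It counts weights only. Activations, gradients, and optimizer state are excluded, which is why training needs much more GPU memory than these figures:

```python
def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Approximate memory for model weights alone, in GB (1 GB = 1e9 bytes).

    Assumes 16-bit (bf16/fp16) weights by default; activations, gradients,
    and optimizer state are NOT included.
    """
    return num_params * bytes_per_param / 1e9

# Nominal parameter counts taken from the model list above
models = {
    "Qwen/Qwen3-4B": 4e9,
    "Qwen/Qwen2.5-Math-1.5B": 1.5e9,
    "Qwen/Qwen2.5-0.5B": 0.5e9,
    "Qwen/Qwen3-1.7B": 1.7e9,
    "meta-llama/Llama-3.1-8B": 8e9,
}
for name, n in models.items():
    print(f"{name}: ~{weight_memory_gb(n):.1f} GB of weights")
```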

**Requirements:**
- GPU runtime (A100 recommended for Qwen3-4B)
- ~24GB+ GPU memory for Qwen3-4B
- ~16GB+ GPU memory for Qwen 1.5B models

**Before running:**
1. Go to Runtime → Change runtime type → Select GPU (A100 if available)
2. For Llama models: log in to HuggingFace (see Section 5)

## 1. Check GPU

In [None]:
# Verify GPU is available
!nvidia-smi
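
`nvidia-smi` can also print only the fields you need to check the memory requirements listed at the top (for example, ~24GB+ for Qwen3-4B). The sketch below parses that query output. The helper names are hypothetical, and `query_gpus()` only works on a machine with an NVIDIA driver installed:

```python
import subprocess

def parse_gpu_query(csv_text: str) -> list[tuple[str, int]]:
    """Parse `nvidia-smi --query-gpu=name,memory.total --format=csv,noheader,nounits`.

    Each non-empty line looks like "NVIDIA A100-SXM4-40GB, 40960" (memory in MiB).
    """
    gpus = []
    for line in csv_text.strip().splitlines():
        if not line.strip():
            continue
        name, mem_mib = (field.strip() for field in line.rsplit(",", 1))
        gpus.append((name, int(mem_mib)))
    return gpus

def query_gpus() -> list[tuple[str, int]]:
    """Run nvidia-smi and return (name, total memory in MiB) for each GPU."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_gpu_query(out)

def meets_requirement(gpus: list[tuple[str, int]], min_gib: float) -> bool:
    """True if at least one GPU was detected and every GPU has >= min_gib GiB."""
    return bool(gpus) and all(mib / 1024 >= min_gib for _, mib in gpus)
```

On a GPU runtime, `meets_requirement(query_gpus(), 24)` gives a quick yes/no check before you start the Qwen3-4B run.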

## 2. Clone Repository

In [None]:
# Clone the repository
!git clone https://github.com/bearbearyu1223/assignment5-alignment.git
%cd assignment5-alignment
!git checkout han/dev

## 3. Install Dependencies

In [None]:
# Install uv (fast Python package manager)
!curl -LsSf https://astral.sh/uv/install.sh | sh

# Add uv to PATH for this session
import os
os.environ['PATH'] = f"{os.path.expanduser('~')}/.local/bin:{os.environ['PATH']}"
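
Each time the cell above is re-run, it prepends `~/.local/bin` to `PATH` again. This sketch is an idempotent alternative that adds the same directory only when it is missing. It then uses `shutil.which` to confirm that `uv` can be found. The helper name is hypothetical:

```python
import os
import shutil

def prepend_to_path(directory: str, path: str) -> str:
    """Prepend `directory` to a PATH string unless it is already listed."""
    entries = [p for p in path.split(os.pathsep) if p]
    if directory in entries:
        return path
    return os.pathsep.join([directory] + entries)

uv_dir = os.path.expanduser("~/.local/bin")
os.environ["PATH"] = prepend_to_path(uv_dir, os.environ.get("PATH", ""))
print("uv found at:", shutil.which("uv"))  # None means uv is not on PATH yet
```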

In [None]:
# Install dependencies (without flash-attn to avoid build issues)
!uv sync

## 4. Download Training Data

In [None]:
# Download MATH dataset
!uv run python scripts/download_math.py

In [None]:
# Verify data is downloaded
!head -2 data/math/train.jsonl
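
`head` prints raw lines. To see which fields each record has, you can parse the first few lines as JSON. This sketch assumes only that the file is JSONL (one JSON object per line). The exact field names depend on `download_math.py` and are not assumed here:

```python
import json
import os
from itertools import islice

def preview_jsonl(path: str, n: int = 2) -> list[dict]:
    """Load the first `n` lines of a JSONL file as dicts, skipping blank lines."""
    with open(path) as f:
        return [json.loads(line) for line in islice(f, n) if line.strip()]

path = "data/math/train.jsonl"
if os.path.exists(path):
    for record in preview_jsonl(path):
        print(sorted(record))  # the field names present in each record
```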

## 5. Download Model

**Default: Qwen3-4B**, a 4B-parameter model from the Qwen3 family.

### Alternative Models (Optional)
- Qwen2.5-Math-1.5B: Smaller, optimized for math
- Qwen2.5-0.5B: For quick testing
- Llama-3.1-8B: Requires HuggingFace login

In [None]:
# List available models
!uv run python scripts/download_model.py --list

In [None]:
# Download Qwen3-4B (default, ~8GB)
!uv run python scripts/download_model.py --model-name Qwen/Qwen3-4B

In [None]:
# Alternative: Download smaller model (uncomment if needed)
# !uv run python scripts/download_model.py --model-name Qwen/Qwen2.5-Math-1.5B

# Alternative: Download Llama (requires HuggingFace login first)
# from huggingface_hub import login
# login()  # This will prompt for your HF token
# !uv run python scripts/download_model.py --model-name meta-llama/Llama-3.1-8B

## 6. Evaluate Base Model (Zero-Shot)

First, let's evaluate the base Qwen3-4B model on 100 MATH test examples to establish a baseline.
This shows the model's performance **before** fine-tuning.

In [None]:
# Evaluate base Qwen3-4B model (zero-shot) on 100 test examples
!uv run python scripts/run_math_eval.py \
    --model-name-or-path models/qwen3-4b \
    --output-path outputs/qwen3_base_eval.jsonl \
    --backend transformers \
    --num-samples 100

In [None]:
# Show base model accuracy
import json
with open('outputs/qwen3_base_eval.jsonl') as f:
    results = [json.loads(line) for line in f]
correct = sum(1 for r in results if r.get('metrics', {}).get('answer_reward', 0) == 1.0)
total = len(results)
base_accuracy = correct / total * 100
print(f"Base Qwen3-4B Accuracy (Zero-Shot): {correct}/{total} = {base_accuracy:.1f}%")
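
With only 100 evaluation problems, each problem counts for a full percentage point, and sampling noise is large. The sketch below computes a 95% Wilson score interval for the accuracy. The Wilson interval is a standard choice for binomial proportions. It is an add-on here, not something `run_math_eval.py` computes:

```python
import math

def wilson_interval(correct: int, total: int, z: float = 1.96) -> tuple[float, float]:
    """Wilson score interval for a binomial proportion (z=1.96 gives ~95%), as fractions."""
    if total == 0:
        return (0.0, 1.0)
    p = correct / total
    denom = 1 + z**2 / total
    center = (p + z**2 / (2 * total)) / denom
    margin = z * math.sqrt(p * (1 - p) / total + z**2 / (4 * total**2)) / denom
    return (center - margin, center + margin)

lo, hi = wilson_interval(50, 100)
print(f"50/100 correct -> 95% CI {lo:.1%} to {hi:.1%}")
```

At 50/100, the interval runs from roughly 40% to 60%. A few points of difference between two 100-example runs can therefore be noise. Call `wilson_interval(correct, total)` to get the interval for the base-model result from the cell above.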

## 7. Fine-Tune with SFT (100 Training Examples)

Now let's fine-tune the model on 100 MATH training examples to see the impact of SFT.

In [None]:
# Fine-tune Qwen3-4B on 100 training examples
!uv run python scripts/run_sft.py --auto \
    --model-name-or-path models/qwen3-4b \
    --train-data-path data/math/train.jsonl \
    --output-dir outputs/sft_qwen3_100 \
    --num-samples 100 \
    --num-epochs 1
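
This notebook does not show the internals of `scripts/run_sft.py`. A common way to implement SFT, assumed here for illustration, is next-token cross-entropy computed only on the response tokens, with the prompt tokens masked out. Below is a minimal pure-Python sketch of that masked loss over per-token log-probabilities. The inputs are hypothetical; a real implementation computes the same quantity from model logits in PyTorch:

```python
import math

def masked_nll(token_logprobs: list[float], response_mask: list[int]) -> float:
    """Mean negative log-likelihood over tokens where response_mask == 1.

    token_logprobs[i] is log p(token_i | tokens_<i) under the model;
    prompt tokens carry mask 0, so they do not contribute to the loss.
    """
    if len(token_logprobs) != len(response_mask):
        raise ValueError("logprobs and mask must have the same length")
    selected = [lp for lp, m in zip(token_logprobs, response_mask) if m]
    if not selected:
        raise ValueError("mask selects no response tokens")
    return -sum(selected) / len(selected)

# Toy example: 3 prompt tokens followed by 2 response tokens
logprobs = [math.log(0.1), math.log(0.2), math.log(0.3), math.log(0.5), math.log(0.25)]
mask = [0, 0, 0, 1, 1]
print(f"loss = {masked_nll(logprobs, mask):.4f}")
```

Implementations differ in whether they average the loss per token (as here) or normalize it per sequence or by a constant.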

## 8. View Training Results

In [None]:
# Display training curves
from IPython.display import Image, display
display(Image(filename='outputs/sft_qwen3_100/training_curves.png'))

In [None]:
# View training metrics
import json
with open('outputs/sft_qwen3_100/training_metrics.json') as f:
    metrics = json.load(f)
print(f"Initial loss: {metrics['losses'][0]:.4f}")
print(f"Final loss: {metrics['losses'][-1]:.4f}")
print(f"Total steps: {len(metrics['steps'])}")
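
With only 100 training examples, the per-step loss can be noisy. An exponential moving average makes the trend easier to read. This sketch relies only on the `losses` list loaded from `training_metrics.json` above. The smoothing factor `alpha` is an arbitrary choice:

```python
def ema(values: list[float], alpha: float = 0.1) -> list[float]:
    """Exponential moving average: s_0 = v_0, s_t = alpha * v_t + (1 - alpha) * s_{t-1}."""
    if not 0 < alpha <= 1:
        raise ValueError("alpha must be in (0, 1]")
    smoothed: list[float] = []
    for v in values:
        smoothed.append(v if not smoothed else alpha * v + (1 - alpha) * smoothed[-1])
    return smoothed
```

For example, `ema(metrics['losses'])[-1]` gives a less noisy estimate of the final loss than `metrics['losses'][-1]`.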

## 9. Evaluate Fine-Tuned Model

Now let's evaluate the fine-tuned model on the same 100 test examples to measure the improvement.

In [None]:
# Evaluate fine-tuned model on 100 test examples
!uv run python scripts/run_math_eval.py \
    --model-name-or-path outputs/sft_qwen3_100/final \
    --output-path outputs/qwen3_sft_eval.jsonl \
    --backend transformers \
    --num-samples 100

## 10. Compare Results: Before vs After SFT

In [None]:
# Compare base vs fine-tuned model accuracy
import json

def load_accuracy(path):
    """Return (correct, total, accuracy in %) for an eval JSONL file.

    An example counts as correct when its answer_reward metric equals 1.0.
    """
    with open(path) as f:
        results = [json.loads(line) for line in f if line.strip()]
    correct = sum(1 for r in results if r.get('metrics', {}).get('answer_reward', 0) == 1.0)
    total = len(results)
    accuracy = correct / total * 100 if total else 0.0
    return correct, total, accuracy

base_correct, base_total, base_accuracy = load_accuracy('outputs/qwen3_base_eval.jsonl')
sft_correct, sft_total, sft_accuracy = load_accuracy('outputs/qwen3_sft_eval.jsonl')

# Display comparison
print("=" * 50)
print(f"MATH Evaluation Results ({base_total} test examples)")
print("=" * 50)
print(f"Base Qwen3-4B (Zero-Shot): {base_correct}/{base_total} = {base_accuracy:.1f}%")
print(f"Fine-Tuned Qwen3-4B (SFT): {sft_correct}/{sft_total} = {sft_accuracy:.1f}%")
print("-" * 50)
# Both accuracies are percentages, so their difference is in percentage points
improvement = sft_accuracy - base_accuracy
print(f"Improvement: {improvement:+.1f} percentage points")
print("=" * 50)
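
Both runs score the same problems, so a paired comparison tells you more than the two accuracies alone. The sketch below runs an exact McNemar test on the discordant pairs, reusing the `answer_reward == 1.0` criterion. It assumes both JSONL files list the same problems in the same order. Check that, for example by comparing a problem field, before trusting the result. The helper names are hypothetical:

```python
import json
import math

def load_correctness(path: str) -> list[bool]:
    """Per-example correctness, using answer_reward == 1.0 as in the cells above."""
    with open(path) as f:
        return [json.loads(line).get('metrics', {}).get('answer_reward', 0) == 1.0
                for line in f if line.strip()]

def mcnemar_exact(base: list[bool], sft: list[bool]) -> tuple[int, int, float]:
    """Return (fixed, broken, two-sided exact p-value) for paired correctness lists.

    fixed = base wrong and SFT right; broken = base right and SFT wrong.
    Under H0, each discordant pair is equally likely to go either way: Binomial(n, 0.5).
    """
    if len(base) != len(sft):
        raise ValueError("results must be paired example-by-example")
    fixed = sum(1 for b, s in zip(base, sft) if not b and s)
    broken = sum(1 for b, s in zip(base, sft) if b and not s)
    n = fixed + broken
    if n == 0:
        return fixed, broken, 1.0
    k = min(fixed, broken)
    tail = sum(math.comb(n, i) for i in range(k + 1)) / 2**n
    return fixed, broken, min(1.0, 2 * tail)
```

Example: `mcnemar_exact(load_correctness('outputs/qwen3_base_eval.jsonl'), load_correctness('outputs/qwen3_sft_eval.jsonl'))` returns how many problems SFT fixed, how many it broke, and a p-value for the difference.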

## 11. Full Training and Evaluation (Optional)

For the full experiment, train on the entire MATH training set and evaluate on all test examples.

In [None]:
# Full training on entire MATH dataset (uncomment when ready)
# This will take several hours, depending on the GPU

# !uv run python scripts/run_sft.py --auto \
#     --model-name-or-path models/qwen3-4b \
#     --train-data-path data/math/train.jsonl \
#     --output-dir outputs/sft_qwen3_full \
#     --num-epochs 1 \
#     --learning-rate 2e-5

In [None]:
# Full evaluation on all test examples (uncomment after full training)
# !uv run python scripts/run_math_eval.py \
#     --model-name-or-path outputs/sft_qwen3_full/final \
#     --output-path outputs/qwen3_sft_full_eval.jsonl \
#     --backend transformers

## 12. Save Model to Google Drive (Optional)

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Copy trained model to Drive
!cp -r outputs/sft_qwen3_100/final /content/drive/MyDrive/sft_qwen3_math
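
If `/content/drive/MyDrive/sft_qwen3_math` already exists, re-running `cp -r` copies `final` into it as a nested `final/` folder instead of refreshing the files. The sketch below uses `shutil.copytree(..., dirs_exist_ok=True)` (Python 3.8+), which copies into the existing destination and overwrites matching files. The helper name is hypothetical:

```python
import shutil
from pathlib import Path

def copy_checkpoint(src: str, dest: str) -> Path:
    """Copy a checkpoint directory to `dest`, overwriting files that already exist there."""
    src_path = Path(src)
    if not src_path.is_dir():
        raise FileNotFoundError(f"checkpoint directory not found: {src}")
    return Path(shutil.copytree(src_path, dest, dirs_exist_ok=True))
```

Usage: `copy_checkpoint('outputs/sft_qwen3_100/final', '/content/drive/MyDrive/sft_qwen3_math')`.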

## 13. Multi-GPU Training (Lambda Labs / Cloud)

For faster training with multiple GPUs, use `accelerate launch`:

### Quick Setup on Lambda Labs

```bash
# SSH into your Lambda instance
ssh ubuntu@your-instance-ip

# Clone and setup
git clone https://github.com/bearbearyu1223/assignment5-alignment.git
cd assignment5-alignment
git checkout han/dev

# Install uv and put it on PATH for this shell
curl -LsSf https://astral.sh/uv/install.sh | sh
source ~/.local/bin/env

# Install dependencies (if the project defines a `cuda` extra, use: uv sync --extra cuda)
uv sync

# Download model and data
uv run python scripts/download_model.py --model-name Qwen/Qwen3-1.7B
uv run python scripts/download_math.py

# Run with AUTO mode (auto-detects GPUs and optimal settings)
uv run accelerate launch --multi_gpu scripts/run_sft.py --auto \
    --model-name-or-path models/qwen3-1.7b \
    --train-data-path data/math/train.jsonl \
    --output-dir outputs/sft_qwen3

# Evaluate the trained model
uv run python scripts/run_math_eval.py \
    --model-name-or-path outputs/sft_qwen3/final \
    --output-path outputs/sft_qwen3_eval.jsonl
```

### Scaling Guide

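Effective batch = per-GPU batch size × gradient accumulation steps × number of GPUs.

| Model | GPUs | Per-GPU Batch | Grad Accum Steps | Effective Batch | VRAM/GPU |
|-------|------|---------------|------------------|-----------------|----------|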
| Qwen3-4B | 1 | 2 | 8 | 16 | ~24GB |
| Qwen3-4B | 2 | 2 | 4 | 16 | ~24GB |
| Qwen3-4B | 4 | 4 | 1 | 16 | ~24GB |
| Qwen 1.5B | 1 | 4 | 4 | 16 | ~16GB |
| Llama 8B | 1 | 1 | 8 | 8 | ~40GB |
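
A small helper can check the table's rows and extend them to other GPU counts with the same target effective batch. It is an illustrative sketch. How `--auto` actually chooses these values is not shown in this notebook:

```python
def grad_accum_steps(effective_batch: int, per_gpu_batch: int, num_gpus: int) -> int:
    """Gradient accumulation steps so that per_gpu_batch * steps * num_gpus == effective_batch."""
    denom = per_gpu_batch * num_gpus
    if effective_batch % denom != 0:
        raise ValueError(f"{effective_batch} is not divisible by {per_gpu_batch} x {num_gpus}")
    return effective_batch // denom

# Rows of the scaling table: (model, GPUs, per-GPU batch, grad accum, effective batch)
table = [
    ("Qwen3-4B", 1, 2, 8, 16),
    ("Qwen3-4B", 2, 2, 4, 16),
    ("Qwen3-4B", 4, 4, 1, 16),
    ("Qwen 1.5B", 1, 4, 4, 16),
    ("Llama 8B", 1, 1, 8, 8),
]
for model, gpus, bs, accum, eff in table:
    assert grad_accum_steps(eff, bs, gpus) == accum, model
print("Qwen3-4B on 8 GPUs, per-GPU batch 2 -> grad accum:", grad_accum_steps(16, 2, 8))
```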