# 🚀 The SFT Training Story: From Configuration to Completion

Welcome to an interactive journey through **Supervised Fine-Tuning (SFT)** in Forge!

## What You'll Learn

This notebook tells the complete story of how SFT training works:

1. **🎭 The Actor Model** - Understanding TrainerActor
2. **🔧 Setup Phase** - Loading models, data, and checkpoints
3. **🏃 Training Loop** - Forward passes, backprop, optimization
4. **🧹 Cleanup** - Saving checkpoints and releasing resources

---

## The Forge Actor Architecture

### What is a TrainerActor?

Think of a **TrainerActor** as the conductor of an orchestra:
- 🎭 **Manages multiple processes** across GPUs or nodes
- 🔧 **Controls the lifecycle** of training (setup → train → cleanup)
- 📊 **Coordinates distributed training** with FSDP, tensor parallelism, etc.

### The Training Journey

```
┌─────────────────────────────────────────┐
│  1. Configuration 📋                    │  ← You define parameters
│     (model, data, hyperparameters)      │
└──────────────┬──────────────────────────┘
               ↓
┌─────────────────────────────────────────┐
│  2. Spawn Actor 🎭                      │  ← Forge creates distributed processes
│     (launch 8 GPU processes)            │
└──────────────┬──────────────────────────┘
               ↓
┌─────────────────────────────────────────┐
│  3. Setup Phase 🔧                      │  ← Load model, data, checkpoints
│     - Initialize model with FSDP        │
│     - Load training dataset             │
│     - Restore from checkpoint (if any)  │
└──────────────┬──────────────────────────┘
               ↓
┌─────────────────────────────────────────┐
│  4. Training Loop 🔄                    │  ← The main training process
│     FOR each step:                      │
│       → Get batch from dataloader       │
│       → Forward pass (compute loss)     │
│       → Backward pass (compute grads)   │
│       → Optimizer step (update weights) │
│       → [Optional] Run validation       │
│       → [Optional] Save checkpoint      │
└──────────────┬──────────────────────────┘
               ↓
┌─────────────────────────────────────────┐
│  5. Cleanup Phase 🧹                    │  ← Save final state
│     - Save final checkpoint             │
│     - Release GPU memory                │
│     - Stop all processes                │
└─────────────────────────────────────────┘
```

### Why This Architecture?

✅ **Automatic Distribution** - Forge handles multi-GPU/multi-node complexity  
✅ **Fault Tolerance** - Checkpointing enables recovery from failures  
✅ **Flexibility** - Easy to switch between 1 GPU, 8 GPUs, or multiple nodes  
✅ **Production-Ready** - Used at Meta for large-scale training

---

Let's configure your training!

---

# 📚 Part 1: Configuration

## The Foundation - Defining Your Training

Before we can train, we need to tell Forge:
- **What model** to train (Llama3-8B, Qwen3-32B, etc.)
- **What data** to use (datasets, batch sizes)
- **How to train** (learning rate, optimizer, steps)
- **Where to run** (GPUs, FSDP settings)

Let's start by importing our tools...

## Import Dependencies

These imports give us access to:
- **OmegaConf**: Configuration management
- **TrainerActor**: The main training orchestrator
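- **SpawnActor**: Helper for creating distributed actors and driving each lifecycle phase
- **run_actor**: Convenience function that runs the whole lifecycle in one call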

In [None]:
import asyncio
import logging
from omegaconf import OmegaConf, DictConfig

from apps.sft.trainer_actor import TrainerActor
from apps.sft.spawn_actor import SpawnActor, run_actor

## Configure Model and Process Settings

Define your model configuration and how many processes to use.

In [None]:
# Model Configuration
model_config = {
    "name": "llama3",
    "flavor": "8B",
    "hf_assets_path": "Path_to_hf_assets"
}

# Process Configuration
processes_config = {
    "procs": 8,        # Number of processes
    "with_gpus": True  # Use GPUs
}

print("Model Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(model_config)))
print("\nProcess Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(processes_config)))

## Configure Optimizer and LR Scheduler

### The Optimization Engine

The optimizer controls *how* the model learns from gradients:

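**AdamW**: Adam with decoupled weight decay
- The standard optimizer for fine-tuning transformer models
- Adapts the step size per parameter using running averages of each gradient and its square
- Weight decay shrinks the weights slightly every step, which regularizes against overfitting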

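To make the per-parameter adaptation concrete, here is a minimal pure-Python sketch of a single AdamW update for one scalar weight. It follows the update rule of `torch.optim.AdamW` (decoupled weight decay plus bias-corrected moment estimates); the helper name `adamw_step` is hypothetical, and `weight_decay=0.01` and the betas are PyTorch's defaults, assumed here because the config below sets only `lr` and `eps`.

```python
import math


def adamw_step(param, grad, m, v, step, lr=1e-5, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update for a scalar parameter; `step` counts from 1."""
    # Decoupled weight decay: shrink the weight directly, independent of the gradient
    param = param * (1 - lr * weight_decay)
    # Running averages of the gradient (m) and of the squared gradient (v)
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction compensates for m and v starting at zero
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    # Normalized update: early on this is roughly lr * sign(grad), whatever the gradient's scale
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


small, _, _ = adamw_step(1.0, 0.5, 0.0, 0.0, step=1, weight_decay=0.0)
large, _, _ = adamw_step(1.0, 100.0, 0.0, 0.0, step=1, weight_decay=0.0)
print(small, large)
```

With weight decay switched off, a gradient of 0.5 and a gradient of 100 both move the weight by almost exactly `lr` on the first step: dividing by the root of the squared-gradient average cancels the gradient's scale.
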
**Learning Rate (lr)**: Step size for weight updates
- 1e-5 (0.00001): Conservative, stable for fine-tuning
- 2e-5 (0.00002): More aggressive, faster convergence
- Too high → Model diverges (loss = NaN)
- Too low → Very slow learning

**Warmup Steps**: Gradually increase LR from 0 to target
- Prevents instability at training start
- 200 steps is typical for fine-tuning
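- Rule of thumb: 5-10% of total training steps (50-100 steps for the 1000-step run configured here, so the 200 used below is on the conservative side)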

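A minimal sketch of a linear warmup schedule in plain Python, assuming the learning rate ramps linearly and then stays at the target; Forge's actual scheduler may also decay the rate after warmup, which is not modeled here. `warmup_lr` is a hypothetical helper.

```python
def warmup_lr(step, base_lr=1e-5, warmup_steps=200):
    """Linear warmup: ramp from near 0 to base_lr over warmup_steps, then hold."""
    if warmup_steps <= 0:
        return base_lr
    # (step + 1) so the very first update already uses a small, non-zero LR
    return base_lr * min(1.0, (step + 1) / warmup_steps)


for step in (0, 99, 199, 500):
    print(step, warmup_lr(step))
```

At step 0 the rate is `lr / 200`, it reaches half the target at step 99, and it holds at `1e-5` from step 199 onward.
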
In [None]:
# Optimizer Configuration
optimizer_config = {
    "name": "AdamW",
    "lr": 1e-5,    # Learning rate
    "eps": 1e-8
}

# Learning Rate Scheduler Configuration
lr_scheduler_config = {
    "warmup_steps": 200  # Number of warmup steps
}

print("Optimizer Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(optimizer_config)))
print("\nLR Scheduler Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(lr_scheduler_config)))

## Configure Training Settings

### Core Training Parameters

**local_batch_size**: Examples processed per GPU per step
- Start with 1 for large models (8B+)
- Increase to 2-4 if you have memory headroom
- Global batch = local_batch_size × data-parallel degree (the number of GPUs when only FSDP is used, as in this notebook)

**seq_len**: Maximum sequence length in tokens
- 2048 tokens ≈ 1500 words
- Longer sequences = more context but slower training
- Reduce if running out of memory

**steps**: Total number of training iterations
- 100-500: Quick experiment
- 1000-5000: Solid fine-tune
- 10000+: Production training

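**dataset**: Training data source (e.g., "c4", "alpaca")

**max_norm**: Gradient clipping threshold; gradients are rescaled so their total norm is at most this value

**compile**: Whether to compile the model with `torch.compile` (faster steps after a one-time compilation cost)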

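These parameters combine into the amount of data seen per optimizer step. A back-of-envelope sketch, assuming pure data parallelism (8 FSDP ranks, as configured below) and that every sequence is filled to the full `seq_len`, so the token counts are an upper bound; `batch_stats` is a hypothetical helper.

```python
def batch_stats(local_batch_size, seq_len, dp_degree, steps):
    """Global batch size and token counts for a purely data-parallel run."""
    global_batch = local_batch_size * dp_degree  # sequences per optimizer step
    tokens_per_step = global_batch * seq_len     # upper bound: assumes full-length sequences
    return {
        "global_batch": global_batch,
        "tokens_per_step": tokens_per_step,
        "total_tokens": tokens_per_step * steps,
    }


stats = batch_stats(local_batch_size=1, seq_len=2048, dp_degree=8, steps=1000)
print(stats)
```

For the configuration below this gives a global batch of 8 sequences, 16,384 tokens per step, and about 16.4 million tokens over 1000 steps.
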
In [None]:
training_config = {
    "local_batch_size": 1,  # Batch size per GPU
    "seq_len": 2048,         # Sequence length
    "max_norm": 1.0,         # Gradient clipping
    "steps": 1000,           # Total training steps
    "compile": False,        # PyTorch compilation
    "dataset": "c4"          # Dataset name
}

print("Training Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(training_config)))

## Configure Parallelism Settings

### How Work is Distributed Across GPUs

**FSDP (Fully Sharded Data Parallel)**:
- Splits model parameters (and their gradients and optimizer states) across all GPUs
- Each GPU holds only a shard (e.g., 1/8th with 8 GPUs)
- Reduces memory per GPU significantly
- `data_parallel_shard_degree: -1` → Auto-use all GPUs
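
The sketch below is a simplified picture of how FSDP shards one parameter. PyTorch's FSDP2 (`fully_shard`) splits each parameter along its first dimension, so every GPU holds a slice of every layer; this pure-Python model only computes which rows each rank owns and ignores the padding the real implementation adds for uneven splits. `fsdp_row_shards` is a hypothetical helper.

```python
def fsdp_row_shards(num_rows, world_size):
    """Return the (start, end) row range each rank owns for one parameter."""
    rows_per_rank = -(-num_rows // world_size)  # ceiling division
    shards = []
    for rank in range(world_size):
        start = min(rank * rows_per_rank, num_rows)
        end = min(start + rows_per_rank, num_rows)
        shards.append((start, end))  # may be empty (start == end) for trailing ranks
    return shards


# A 4096 x 4096 weight matrix sharded across 8 GPUs
print(fsdp_row_shards(4096, 8)[:2])
```

For a 4096 × 4096 weight on 8 GPUs, each rank owns 512 rows; with more ranks than rows, the trailing ranks simply own an empty slice.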

**Other Parallelism Options**:
- `tensor_parallel_degree`: Split individual layers across GPUs
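- `pipeline_parallel_degree`: Split model into stages
- `context_parallel_degree`: Split long sequences across GPUs
- `expert_parallel_degree`: Distribute the experts of mixture-of-experts layers across GPUs
- Usually kept at 1 for standard fine-tuning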

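The parallel degrees are not independent: in torchtitan, whose option names these appear to mirror, their product must equal the total number of processes, and a shard degree of `-1` is filled in from whatever remains. A small sketch of that bookkeeping under this assumption (expert parallelism is left out for simplicity); `resolve_shard_degree` is a hypothetical helper.

```python
def resolve_shard_degree(world_size, dp_replicate=1, dp_shard=-1, tp=1, pp=1, cp=1):
    """Resolve a -1 FSDP degree and check that all degrees multiply to world_size."""
    other = dp_replicate * tp * pp * cp
    if dp_shard == -1:
        if world_size % other != 0:
            raise ValueError(f"world_size {world_size} is not divisible by {other}")
        dp_shard = world_size // other  # use every remaining GPU for FSDP
    if dp_replicate * dp_shard * tp * pp * cp != world_size:
        raise ValueError("parallel degrees must multiply to world_size")
    return dp_shard


print(resolve_shard_degree(8), resolve_shard_degree(8, tp=2))
```

With `procs: 8` and all other degrees at 1, `-1` resolves to an FSDP degree of 8; setting `tensor_parallel_degree: 2` would leave 4.
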
**Why FSDP?**
- Enables training large models that don't fit on 1 GPU
- Automatically handles gradient synchronization
- Scales well as GPUs are added, although communication overhead grows with the number of GPUs

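A rough estimate of why sharding matters, assuming weights, gradients, and both AdamW moment buffers are all stored in fp32 (16 bytes per parameter) and ignoring activations and temporary buffers; mixed-precision setups change the constants. `training_state_gb_per_gpu` is a hypothetical helper.

```python
def training_state_gb_per_gpu(num_params, num_gpus, bytes_per_param=16):
    """Approximate weights + grads + optimizer state per GPU, in GB (1e9 bytes)."""
    # 16 bytes/param = fp32 weights (4) + fp32 grads (4) + two fp32 AdamW moments (8)
    return num_params * bytes_per_param / num_gpus / 1e9


for gpus in (1, 8):
    print(gpus, "GPU(s):", training_state_gb_per_gpu(8e9, gpus), "GB per GPU")
```

By this estimate an 8B-parameter model needs about 128 GB of training state on a single GPU, more than an 80 GB GPU holds, but about 16 GB per GPU when sharded across 8.
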
For more explanation, visit: https://github.com/pytorch/torchtitan/tree/main/docs

In [None]:
parallelism_config = {
    "data_parallel_replicate_degree": 1,
    "data_parallel_shard_degree": -1,  # -1 means use all available GPUs for FSDP
    "tensor_parallel_degree": 1,
    "pipeline_parallel_degree": 1,
    "context_parallel_degree": 1,
    "expert_parallel_degree": 1,
    "disable_loss_parallel": False
}

print("Parallelism Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(parallelism_config)))

## Configure Checkpoint and Activation Checkpointing

### Saving Your Progress

**Checkpointing**: Periodic saves of model state
- `interval: 500` → Save every 500 steps
- Allows resuming if training is interrupted
- Final checkpoint saved automatically
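- Includes: model weights, optimizer state, training step
- `initial_load_path` / `initial_load_in_hf`: where the starting weights come from, and whether they are in Hugging Face format
- `last_save_in_hf`: write the final checkpoint in Hugging Face format
- `async_mode: disabled`: save synchronously (training pauses while the checkpoint is written)

**Activation Checkpointing**: Memory optimization technique
- Trades compute for memory
- Recomputes activations during the backward pass instead of storing them
- `mode: selective` → Only checkpoint specific operations (`selective_ac_option: op` selects at the operator level)
- `mode: full` → Checkpoints every layer; saves the most memory at the highest recompute cost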

**When to Use:**
- Standard checkpointing: Always enable
- Activation checkpointing: Use when running out of GPU memory
- Slight slowdown (~10-20%) but can enable training larger models
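
A sketch of which steps produce a checkpoint, assuming steps are counted from 1, a save happens whenever the completed step count is a multiple of `interval`, and a final save is always forced at the end; `checkpoint_steps` is a hypothetical helper.

```python
def checkpoint_steps(total_steps, interval, start_step=0):
    """Steps at which a checkpoint is written, including the forced final save."""
    saves = [s for s in range(start_step + 1, total_steps + 1) if s % interval == 0]
    if not saves or saves[-1] != total_steps:
        saves.append(total_steps)  # final checkpoint is always written
    return saves


print(checkpoint_steps(1000, 500))
print(checkpoint_steps(1100, 500))
```

With `steps: 1000` and `interval: 500`, saves happen at 500 and 1000. A crash just before step 1000 would lose up to 499 steps of work, so a smaller interval trades more time spent saving for less lost progress.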

In [None]:
# Checkpoint Configuration
checkpoint_config = {
    "enable": True,
    "folder": "Path_to_checkpoint_folder",
    "initial_load_path": "Path_to_hf_assets",
    "initial_load_in_hf": True,
    "last_save_in_hf": True,
    "interval": 500,           # Save every N steps
    "async_mode": "disabled"
}

# Activation Checkpoint Configuration
activation_checkpoint_config = {
    "mode": "selective",
    "selective_ac_option": "op"
}

print("Checkpoint Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(checkpoint_config)))
print("\nActivation Checkpoint Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(activation_checkpoint_config)))

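## Configure Communication Settings

`trace_buf_size` sets the size of PyTorch's NCCL flight-recorder buffer, which records recent collective operations to help debug hangs; `0` disables it.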

In [None]:
# Communication Configuration
comm_config = {
    "trace_buf_size": 0
}

print("Communication Configuration:")
print(OmegaConf.to_yaml(OmegaConf.create(comm_config)))

## Combine All Configurations

Now let's merge everything into a complete configuration!

In [None]:
# Combine all configs
complete_config = {
    "comm": comm_config,
    "model": model_config,
    "processes": processes_config,
    "optimizer": optimizer_config,
    "lr_scheduler": lr_scheduler_config,
    "training": training_config,
    "parallelism": parallelism_config,
    "checkpoint": checkpoint_config,
    "activation_checkpoint": activation_checkpoint_config
}

# Create OmegaConf DictConfig
cfg = OmegaConf.create(complete_config)


---

# 🎭 Part 2: The Actor Lifecycle

## Understanding Spawn, Setup, Train, and Cleanup

### Phase 1: Spawn the Actor 🎭

**What's happening:**
- `SpawnActor` creates a launcher for `TrainerActor`
- `spawn()` launches 8 Python processes (one per GPU)
- Each process initializes:
  - CUDA device assignment (GPU 0, 1, 2, ...)
  - Distributed communication (NCCL)
  - Process group setup (RANK, LOCAL_RANK, WORLD_SIZE)

**Behind the scenes:**
```
GPU 0: Process 0 (RANK=0, LOCAL_RANK=0)
GPU 1: Process 1 (RANK=1, LOCAL_RANK=1)
...
GPU 7: Process 7 (RANK=7, LOCAL_RANK=7)
```

All processes are now waiting for instructions!
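
The table above shows a single node. When training spans several nodes, the usual `torch.distributed` convention is that `LOCAL_RANK` selects the GPU on each machine while `RANK` is unique across all machines. A pure-Python sketch of that mapping, assuming the same number of processes on every node; `rank_layout` is a hypothetical helper, not a Forge API.

```python
def rank_layout(num_nodes, procs_per_node):
    """List the RANK / LOCAL_RANK / WORLD_SIZE assignment for every process."""
    world_size = num_nodes * procs_per_node
    layout = []
    for node in range(num_nodes):
        for local_rank in range(procs_per_node):
            layout.append({
                "RANK": node * procs_per_node + local_rank,  # unique across all nodes
                "LOCAL_RANK": local_rank,                    # index on this node
                "WORLD_SIZE": world_size,                    # total number of processes
                "device": f"cuda:{local_rank}",              # GPU chosen by LOCAL_RANK
            })
    return layout


# The boundary between node 0 and node 1 in a 2 x 8 setup
for entry in rank_layout(num_nodes=2, procs_per_node=8)[6:10]:
    print(entry)
```
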

### What Happens When You Run This?

1. **Spawn** 🎭: Forge creates 8 GPU processes (based on `procs: 8`)
2. **Setup** 🔧: Each process loads its shard of the model + data
3. **Train** 🏃: Training loop runs for 1000 steps
4. **Cleanup** 🧹: Final checkpoint saved, resources released

The cell below runs step 1 (spawn); the cells after it run setup, training, and cleanup one phase at a time.

In [None]:
# Create the spawner
spawner = SpawnActor(TrainerActor, cfg)

# Spawn the actor
actor = await spawner.spawn()
print(f"✓ Actor spawned: {actor}")

### Phase 2: Setup 🔧

**What's happening:**
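- **Model Loading**: Each process loads its shard of the model
  - FSDP splits every parameter tensor into 8 pieces, so each GPU stores a slice of *every* layer (not a contiguous block of layers, which would be pipeline parallelism)
  - Before a layer runs, its full weights are temporarily gathered from all GPUs, then freed again
  - At rest, each GPU holds only ~1/8th of the full model
- **Dataset Loading**: Training and validation dataloaders created
  - Same dataset, but each data-parallel rank reads a different portion of it
  - Ensures each GPU sees different data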
- **Checkpoint Loading**: If resuming, restore training state
  - Model weights, optimizer state, current step number

**What `setup()` does internally:**
```python
def setup(self):
    # Simplified sketch of TrainerActor.setup(); load_model_with_fsdp, setup_data,
    # and validation_enabled are illustrative names, not the exact Forge API.
    cfg = self.cfg

    # 1. Initialize model with FSDP
    self.model = load_model_with_fsdp(cfg.model)

    # 2. Create training dataloader
    self.train_dataloader = setup_data(
        dataset_path=cfg.dataset.path,
        dataset_split=cfg.dataset.split,
    )

    # 3. Create validation dataloader (only if validation is enabled)
    if self.validation_enabled:
        self.val_dataloader = setup_data(
            dataset_path=cfg.dataset_val.path,
            dataset_split=cfg.dataset_val.split,
        )

    # 4. Restore from checkpoint (if any)
    self.checkpointer.load(step=self.current_step)
```

After setup, all 8 GPUs are synchronized and ready to train!

In [None]:
# Setup (load data, checkpoints, etc.)
await spawner.setup()
print("✓ Actor setup complete")

### Phase 3: Training Loop 🔄

**What's happening:**

The training loop runs for `cfg.training.steps` iterations. Each step:

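```python
import torch

# Simplified sketch of the training loop; loss_fn, lr_scheduler, eval_interval,
# and save_checkpoint are illustrative names, not the exact Forge API.
train_iter = iter(train_dataloader)  # a DataLoader must be wrapped with iter() before next()

for step in range(current_step, max_steps):
    # 1. Get next batch from the dataloader
    batch = next(train_iter)
    # Shape: [local_batch_size, seq_len] per GPU

    # 2. Forward pass - compute predictions and loss
    logits = model(batch["input_ids"])
    loss = loss_fn(logits, batch["labels"])

    # 3. Backward pass - compute gradients
    loss.backward()
    # FSDP reduce-scatters gradients: each GPU ends up with the averaged
    # gradient for its own parameter shard

    # 4. Clip gradients (training.max_norm), update weights, advance the LR schedule
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    # 5. Periodic validation (if enabled)
    if validation_enabled and (step + 1) % eval_interval == 0:
        val_metrics = evaluate()
        log(f"Step {step + 1}: Val Loss = {val_metrics['val_loss']}")

    # 6. Periodic checkpointing (step + 1 = number of completed steps)
    if (step + 1) % checkpoint_interval == 0:
        save_checkpoint(step + 1)
```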

**Key insights:**
- **FSDP synchronization**: Gradients automatically reduced across GPUs
- **Loss should decrease**: If not, check learning rate or data
- **Validation metrics**: Track generalization on held-out data
- **Checkpoints**: Resume training if interrupted

**What you'll see:**
- Training loss decreasing over time
- Periodic validation metrics (if enabled)
- Checkpoint saves at regular intervals
- Step timing information (seconds per step)
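
The loop above only runs inside Forge. As a self-contained illustration of the same sequence (forward, backward, clip to `max_norm`, warmed-up learning rate, periodic checkpoint), here is a toy pure-Python version that fits the line `y = 3x + 1` with full-batch gradient descent. The model, data, and the `train_toy` and `clip_by_global_norm` helpers are all made up for illustration, and plain gradient descent stands in for AdamW.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale gradients so their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads]


def train_toy(steps=200, lr=0.1, max_norm=1.0, warmup_steps=20, checkpoint_interval=50):
    data = [(x / 10, 3 * (x / 10) + 1) for x in range(-10, 11)]  # points on y = 3x + 1
    w, b = 0.0, 0.0  # model parameters
    losses, checkpoints = [], []
    for step in range(steps):
        # Forward pass: mean squared error over the whole batch
        errors = [(w * x + b) - y for x, y in data]
        loss = sum(e * e for e in errors) / len(data)
        # Backward pass: analytic gradients of the loss w.r.t. w and b
        grad_w = sum(2 * e * x for e, (x, _) in zip(errors, data)) / len(data)
        grad_b = sum(2 * e for e in errors) / len(data)
        # Gradient clipping, then an update with a linearly warmed-up learning rate
        grad_w, grad_b = clip_by_global_norm([grad_w, grad_b], max_norm)
        current_lr = lr * min(1.0, (step + 1) / warmup_steps)
        w -= current_lr * grad_w
        b -= current_lr * grad_b
        losses.append(loss)
        # Periodic "checkpoint": record the step number instead of writing files
        if (step + 1) % checkpoint_interval == 0:
            checkpoints.append(step + 1)
    return losses, checkpoints, (w, b)


losses, checkpoints, (w, b) = train_toy()
print(f"loss {losses[0]:.3f} -> {losses[-1]:.6f}, w={w:.3f}, b={b:.3f}, saved at {checkpoints}")
```

As in the real loop, the loss should fall steadily; if it does not, the learning rate or the data are the first things to check.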

In [None]:
# Run training
await spawner.run()
print("✓ Training complete")

### Phase 4: Cleanup 🧹

**What's happening:**

```python
import torch
import torch.distributed as dist


def cleanup(self):
    # Simplified sketch of TrainerActor.cleanup(); `log` and
    # `self.checkpoint_folder` are illustrative names, not the exact Forge API.

    # 1. Save final checkpoint
    self.checkpointer.save(
        step=self.current_step,
        force=True,  # Always save, even if not at an interval boundary
    )

    # 2. Release model from GPU memory
    del self.model
    torch.cuda.empty_cache()

    # 3. Shut down the distributed process group
    if dist.is_initialized():
        dist.destroy_process_group()

    # 4. Log final statistics
    log("Training complete!")
    log(f"Final step: {self.current_step}")
    log(f"Checkpoint saved to: {self.checkpoint_folder}")
```

**Why cleanup matters:**
- ✅ **Saves final state**: `force=True` writes a checkpoint even when the last step isn't a multiple of `interval`
- ✅ **Frees GPU memory**: Other jobs can now use the GPUs
- ✅ **Clean shutdown**: Prevents zombie processes
- ✅ **Logs summary**: Know exactly where training ended

**After cleanup:**
- Model weights saved to checkpoint folder
- GPUs are free and available
- Training can be resumed from last checkpoint
- All distributed processes cleanly terminated
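
The notebook calls each phase in its own cell. When you script the lifecycle yourself, wrapping the phases in `try`/`finally` makes sure `cleanup()` still runs if setup or training raises, so GPUs and process groups are released. A minimal sketch using only the `spawner` methods shown in this notebook; `run_lifecycle`, `FakeSpawner`, and `demo_failure` are hypothetical names, and whether `run_actor` already does this internally is not covered here.

```python
import asyncio  # only needed to run the demo outside Jupyter: asyncio.run(demo_failure())


async def run_lifecycle(spawner):
    """Spawn, set up, and train, guaranteeing cleanup() once spawning succeeded."""
    await spawner.spawn()
    try:
        await spawner.setup()
        await spawner.run()
    finally:
        # Runs whether training finished normally or a phase raised an exception
        await spawner.cleanup()


class FakeSpawner:
    """Hypothetical stand-in that records calls; run() fails to simulate a crash."""

    def __init__(self):
        self.calls = []

    async def spawn(self):
        self.calls.append("spawn")

    async def setup(self):
        self.calls.append("setup")

    async def run(self):
        self.calls.append("run")
        raise RuntimeError("simulated failure during training")

    async def cleanup(self):
        self.calls.append("cleanup")


async def demo_failure():
    fake = FakeSpawner()
    try:
        await run_lifecycle(fake)
    except RuntimeError:
        pass  # the error still propagates, but only after cleanup() has run
    return fake.calls
```

In this notebook you would call `await demo_failure()` (or `await run_lifecycle(SpawnActor(TrainerActor, cfg))` for a real run); in a plain script, wrap the call in `asyncio.run(...)`.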

In [None]:
# Cleanup resources
await spawner.cleanup()
print("✓ Cleanup complete")

---

## Running the Complete Lifecycle

### What Happens When You Run Training?

**The full journey:**

1. **Spawn** 🎭 
   - Forge creates 8 GPU processes (based on `procs: 8`)
   - Each process gets assigned to a GPU
   - Distributed communication initialized

2. **Setup** 🔧
   - Each process loads its 1/8th shard of the model
   - Dataloaders created so that each GPU reads different data
   - Checkpoint restored if resuming training

3. **Train** 🏃
   - Training loop runs for 1000 steps
   - Loss computed, gradients synced, weights updated
   - Periodic validation and checkpointing

4. **Cleanup** 🧹
   - Final checkpoint saved
   - GPU memory released
   - All processes terminated cleanly

**Time estimate:**
- With 8 GPUs, ~2-3 seconds per step (varies with hardware, `seq_len`, and batch size)
- 1000 steps × 2-3 s ≈ 33-50 minutes
- Plus validation time (if enabled)
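
The time estimate is simple arithmetic. A sketch assuming a fixed per-step time (the 2-3 s figure is itself an estimate that depends on hardware, `seq_len`, and batch size); `estimate_minutes` is a hypothetical helper.

```python
def estimate_minutes(steps, seconds_per_step):
    """Wall-clock minutes for the training steps alone."""
    return steps * seconds_per_step / 60


low = estimate_minutes(1000, 2.0)
high = estimate_minutes(1000, 3.0)
# Excludes validation passes and checkpoint saves
print(f"{low:.0f}-{high:.0f} minutes")
```

1000 steps at 2-3 s each comes to roughly 33-50 minutes before validation and checkpoint saves.
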
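
The cell below runs all four phases in a single call with `run_actor`. Use it instead of the step-by-step cells above rather than in addition to them; otherwise training runs a second time.
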
In [None]:

await run_actor(TrainerActor, cfg)


---

# Quick Configuration Templates

Here are ready-to-use templates for common scenarios!

## Template 1: Quick Test (Single GPU, Few Steps)

In [None]:
quick_test_config = OmegaConf.create({
    "comm": {"trace_buf_size": 0},
    "model": {
        "name": "llama3",
        "flavor": "8B",
        "hf_assets_path": "/tmp/Meta-Llama-3.1-8B-Instruct"
    },
    "processes": {"procs": 1, "with_gpus": True},
    "optimizer": {"name": "AdamW", "lr": 1e-5, "eps": 1e-8},
    "lr_scheduler": {"warmup_steps": 10},
    "training": {
        "local_batch_size": 1,
        "seq_len": 1024,
        "max_norm": 1.0,
        "steps": 100,  # Just 100 steps for quick testing
        "compile": False,
        "dataset": "c4"
    },
    "parallelism": {
        "data_parallel_replicate_degree": 1,
        "data_parallel_shard_degree": 1,
        "tensor_parallel_degree": 1,
        "pipeline_parallel_degree": 1,
        "context_parallel_degree": 1,
        "expert_parallel_degree": 1,
        "disable_loss_parallel": False
    },
    "checkpoint": {
        "enable": True,
        "folder": "/tmp/quick_test_checkpoints",
        "initial_load_path": "/tmp/Meta-Llama-3.1-8B-Instruct/",
        "initial_load_in_hf": True,
        "last_save_in_hf": True,
        "interval": 50,
        "async_mode": "disabled"
    },
    "activation_checkpoint": {
        "mode": "selective",
        "selective_ac_option": "op"
    }
})

print("Quick Test Configuration:")
print(OmegaConf.to_yaml(quick_test_config))

# To use: await run_actor(TrainerActor, quick_test_config)

## Template 2: Multi-GPU Training (8 GPUs with FSDP)

In [None]:
multi_gpu_config = OmegaConf.create({
    "comm": {"trace_buf_size": 0},
    "model": {
        "name": "llama3",
        "flavor": "8B",
        "hf_assets_path": "/tmp/Meta-Llama-3.1-8B-Instruct"
    },
    "processes": {"procs": 8, "with_gpus": True},
    "optimizer": {"name": "AdamW", "lr": 2e-5, "eps": 1e-8},
    "lr_scheduler": {"warmup_steps": 200},
    "training": {
        "local_batch_size": 2,
        "seq_len": 2048,
        "max_norm": 1.0,
        "steps": 5000,
        "compile": False,
        "dataset": "c4"
    },
    "parallelism": {
        "data_parallel_replicate_degree": 1,
        "data_parallel_shard_degree": 8,  # FSDP across 8 GPUs
        "tensor_parallel_degree": 1,
        "pipeline_parallel_degree": 1,
        "context_parallel_degree": 1,
        "expert_parallel_degree": 1,
        "disable_loss_parallel": False
    },
    "checkpoint": {
        "enable": True,
        "folder": "/tmp/multi_gpu_checkpoints",
        "initial_load_path": "/tmp/Meta-Llama-3.1-8B-Instruct/",
        "initial_load_in_hf": True,
        "last_save_in_hf": True,
        "interval": 500,
        "async_mode": "disabled"
    },
    "activation_checkpoint": {
        "mode": "selective",
        "selective_ac_option": "op"
    }
})

print("Multi-GPU Configuration:")
print(OmegaConf.to_yaml(multi_gpu_config))

# To use: await run_actor(TrainerActor, multi_gpu_config)

---

# Tips & Tricks

## Memory Optimization
- ⬇️ Reduce `seq_len` if running out of memory
- ⬇️ Reduce `local_batch_size` if running out of memory
- ✅ Enable activation checkpointing (`activation_checkpoint.mode: selective`, or `full` for larger savings)

## Training Speed
- ⬆️ Increase `local_batch_size` for faster training (if memory allows)
- 🚀 Use multiple GPUs with FSDP (`data_parallel_shard_degree > 1`)
- ⚡ Set `training.compile: True` to compile the model with `torch.compile` (experimental)

## Debugging
- 🧪 Start with small `steps` (e.g., 10-100) to test quickly
- 🔍 Use single GPU first (`procs: 1`)
- 📊 Monitor loss values in logs

## Checkpoint Management
- 💾 Set `interval` based on how often you want to save
- 📁 Ensure `folder` path exists and has enough space
- 🔄 Use `initial_load_path` to resume from checkpoints
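
To check the most recent checkpoint saved in a folder, you can list its contents. This sketch assumes torchtitan-style `step-N` subdirectories inside `checkpoint.folder`; adjust the pattern if your layout differs. `latest_checkpoint` is a hypothetical helper.

```python
import os
import re
import tempfile

_STEP_DIR = re.compile(r"^step-(\d+)$")


def latest_checkpoint(folder):
    """Return (step, path) of the highest-numbered step-N directory, or None."""
    if not os.path.isdir(folder):
        return None
    best = None
    for name in os.listdir(folder):
        match = _STEP_DIR.match(name)
        path = os.path.join(folder, name)
        if match and os.path.isdir(path):
            step = int(match.group(1))
            if best is None or step > best[0]:
                best = (step, path)
    return best


# Demo on a throwaway folder laid out like a checkpoint directory
with tempfile.TemporaryDirectory() as demo_dir:
    for step in (50, 500, 1000):
        os.makedirs(os.path.join(demo_dir, f"step-{step}"))
    print(latest_checkpoint(demo_dir))
```

Comparing step numbers as integers matters: a plain string sort would rank `step-500` above `step-1000`.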