diff --git a/apps/sft/actor.py b/apps/sft/actor.py new file mode 100644 index 000000000..8607a39c4 --- /dev/null +++ b/apps/sft/actor.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Abstract Actor class for training/inference actors in Forge. + +This provides a base class that can be extended for different types of actors +(e.g., Trainer, Evaluator, Inferencer, etc.) +""" + +import logging +import math +import os +from abc import ABC, abstractmethod +from typing import Any, Optional + +import torch +from forge.controller import ForgeActor +from monarch.actor import current_rank, current_size +from omegaconf import DictConfig, OmegaConf +from torch import nn +from torchtitan.components.loss import LossFunction +from torchtitan.components.lr_scheduler import LRSchedulersContainer +from torchtitan.components.optimizer import OptimizersContainer +from torchtitan.distributed import ParallelDims +from torchtitan.experiments.forge.engine import ForgeEngine +from torchtitan.experiments.forge.job_config import ForgeJobConfig + +Checkpointer = Any +Dataloader = Any +MetricLogger = Any +Profiler = Any +Tokenizer = Any + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class BaseForgeActor(ForgeActor, ForgeEngine, ABC): + """ + Abstract base class for Forge actors. + + This class handles common initialization, distributed setup, and provides + abstract methods that must be implemented by concrete actor classes. + """ + + job_config: ForgeJobConfig + parallel_dims: ParallelDims + model: list[nn.Module] + loss_fn: Optional[LossFunction] + optimizer: Optional[OptimizersContainer] + lr_scheduler: Optional[LRSchedulersContainer] + checkpointer: Optional[Checkpointer] + tokenizer: Optional[Tokenizer] + metric_logger: Optional[MetricLogger] + profiler: Optional[Profiler] + device: torch.device + + def __init__(self, config: DictConfig): + """ + Initialize the base actor with configuration. + + Args: + config: Configuration dictionary containing job settings + """ + job_config = ForgeJobConfig().to_dict() + job_config = OmegaConf.merge(job_config, config) + + self.current_step = 0 + self.metric_logger = None + self.gradient_accumulation_steps = 1 + self._rank = current_rank().rank + self._size = math.prod(current_size().values()) + + self._init_dist() + super().__init__(job_config) + + def _init_dist(self): + """ + Initialize torch distributed environment. + + Sets up environment variables required for distributed training + in the Monarch actor framework. + """ + env = { + "RANK": str(self._rank), + "LOCAL_RANK": str(self._rank), + "LOCAL_WORLD_SIZE": str(self._size), + "GROUP_RANK": str(self._size), + "GROUP_WORLD_SIZE": str(self._size), + "ROLE_RANK": str(self._rank), + "ROLE_WORLD_SIZE": str(self._size), + "ROLE_NAME": "rank", + "WORLD_SIZE": str(self._size), + "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", + } + os.environ.update(env) + logger.info(f"Initialized distributed environment: {env}") + + @abstractmethod + async def setup(self): + """ + Setup the actor (load data, checkpoint, etc.). + + This method must be implemented by concrete actor classes. + """ + pass + + @abstractmethod + async def run(self): + """ + Main execution logic for the actor. + + This method must be implemented by concrete actor classes. 
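+
+        A typical implementation iterates over a dataloader, runs the
+        forward/backward pass, steps the optimizer, and periodically saves
+        checkpoints (see ``TrainerActor.run`` in apps/sft/trainer_actor.py
+        for a concrete example).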
+ """ + pass + + @abstractmethod + async def cleanup(self): + """ + Cleanup resources (close checkpointer, logger, etc.). + + This method must be implemented by concrete actor classes. + """ + pass + + @abstractmethod + def __repr__(self) -> str: + """String representation of the actor.""" + pass diff --git a/apps/sft/interactive_config_notebook.ipynb b/apps/sft/interactive_config_notebook.ipynb new file mode 100644 index 000000000..9b37a2451 --- /dev/null +++ b/apps/sft/interactive_config_notebook.ipynb @@ -0,0 +1,947 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ๐Ÿš€ The SFT Training Story: From Configuration to Completion\n", + "\n", + "Welcome to an interactive journey through **Supervised Fine-Tuning (SFT)** in Forge!\n", + "\n", + "## What You'll Learn\n", + "\n", + "This notebook tells the complete story of how SFT training works:\n", + "\n", + "1. **๐ŸŽญ The Actor Model** - Understanding TrainerActor (built on Monarch)\n", + "2. **๐Ÿ”ง Setup Phase** - Loading models, data, and checkpoints\n", + "3. **๐Ÿƒ Training Loop** - Forward passes, backprop, optimization\n", + "5. **๐Ÿงน Cleanup** - Saving checkpoints and releasing resources\n", + "\n", + "---\n", + "\n", + "## The Forge Actor Architecture\n", + "\n", + "### What is Monarch?\n", + "\n", + "**Monarch** is Meta's distributed actor framework that powers Forge:\n", + "- ๐ŸŒ **Distributed by design** - Built for multi-node, multi-GPU training\n", + "- ๐ŸŽญ **Actor model** - Encapsulates distributed processes as actors\n", + "- ๐Ÿ“ก **Remote communication** - Seamless RPC between actors\n", + "- ๐Ÿ”ง **Lifecycle management** - Spawn โ†’ Setup โ†’ Run โ†’ Cleanup pattern\n", + "\n", + "Forge leverages Monarch to abstract away distributed training complexity!\n", + "\n", + "For more information on Monarch, visit https://github.com/meta-pytorch/monarch/tree/main/docs\n", + "\n", + "### What is a TrainerActor?\n", + "\n", + "A **TrainerActor** is Forge's Monarch actor for training:\n", + "- ๐ŸŽญ **Manages multiple processes** across GPUs or nodes\n", + "- ๐Ÿ”ง **Controls the lifecycle** using Monarch's actor pattern\n", + "- ๐Ÿ“Š **Coordinates distributed training** with FSDP, tensor parallelism, etc.\n", + "\n", + "Think of it as the conductor of an orchestra - coordinating 8 GPU processes working together!\n", + "\n", + "### The Training Journey (Monarch Actor Lifecycle)\n", + "\n", + "```\n", + "โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”\n", + "โ”‚ 1. Configuration ๐Ÿ“‹ โ”‚ โ† You define parameters\n", + "โ”‚ (model, data, hyperparameters) โ”‚\n", + "โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜\n", + " โ†“\n", + "โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”\n", + "โ”‚ 2. Spawn Actor ๐ŸŽญ [MONARCH] โ”‚ โ† Monarch creates distributed processes\n", + "โ”‚ (launch 8 GPU processes) โ”‚\n", + "โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜\n", + " โ†“\n", + "โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”\n", + "โ”‚ 3. 
Setup Phase ๐Ÿ”ง [MONARCH] โ”‚ โ† Actor.setup() endpoint\n", + "โ”‚ - Initialize model with FSDP โ”‚\n", + "โ”‚ - Load training dataset โ”‚\n", + "โ”‚ - Restore from checkpoint (if any) โ”‚\n", + "โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜\n", + " โ†“\n", + "โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”\n", + "โ”‚ 4. Training Loop ๐Ÿ”„ [MONARCH] โ”‚ โ† Actor.train() endpoint\n", + "โ”‚ FOR each step: โ”‚\n", + "โ”‚ โ†’ Get batch from dataloader โ”‚\n", + "โ”‚ โ†’ Forward pass (compute loss) โ”‚\n", + "โ”‚ โ†’ Backward pass (compute grads) โ”‚\n", + "โ”‚ โ†’ Optimizer step (update weights) โ”‚\n", + "โ”‚ โ†’ [Optional] Run validation โ”‚\n", + "โ”‚ โ†’ [Optional] Save checkpoint โ”‚\n", + "โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜\n", + " โ†“\n", + "โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”\n", + "โ”‚ 5. Cleanup Phase ๐Ÿงน [MONARCH] โ”‚ โ† Actor.cleanup() endpoint\n", + "โ”‚ - Save final checkpoint โ”‚\n", + "โ”‚ - Release GPU memory โ”‚\n", + "โ”‚ - Stop all processes โ”‚\n", + "โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜\n", + "```\n", + "\n", + "### Why This Architecture?\n", + "\n", + "โœ… **Automatic Distribution** - Monarch handles multi-GPU/multi-node complexity \n", + "โœ… **Fault Tolerance** - Checkpointing enables recovery from failures \n", + "โœ… **Flexibility** - Easy to switch between 1 GPU, 8 GPUs, or multiple nodes \n", + "โœ… **Production-Ready** - Used at Meta for large-scale training \n", + "โœ… **Actor Pattern** - Clean separation of concerns with lifecycle methods\n", + "\n", + "#### For more information regarding Forge visit: https://github.com/meta-pytorch/torchforge/tree/main/docs\n", + "---\n", + "\n", + "Let's configure your training!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "# ๐Ÿ“š Part 1: Configuration\n", + "\n", + "## The Foundation - Defining Your Training\n", + "\n", + "Before we can train, we need to tell Forge:\n", + "- **What model** to train (Llama3-8B, Qwen3-32B, etc.)\n", + "- **What data** to use (datasets, batch sizes)\n", + "- **How to train** (learning rate, optimizer, steps)\n", + "- **Where to run** (GPUs, FSDP settings)\n", + "\n", + "Let's start by importing our tools..." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import Dependencies\n", + "\n", + "These imports give us access to:\n", + "- **OmegaConf**: Configuration management\n", + "- **TrainerActor**: The main training orchestrator\n", + "- **SpawnActor**: Helper for creating distributed actors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 1438699627218568, + "loadingStatus": "loaded" + } + }, + "outputs": [], + "source": [ + "import asyncio\n", + "import logging\n", + "from omegaconf import OmegaConf, DictConfig\n", + "\n", + "from apps.sft.trainer_actor import TrainerActor\n", + "from apps.sft.spawn_actor import SpawnActor, run_actor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure Model and Process Settings\n", + "\n", + "Define your model configuration and how many processes to use." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 779738548196921, + "loadingStatus": "loaded" + } + }, + "outputs": [], + "source": [ + "# Model Configuration\n", + "model_config = {\n", + " \"name\": \"llama3\",\n", + " \"flavor\": \"8B\",\n", + " \"hf_assets_path\": \"Path_to_hf_assets\"\n", + "}\n", + "\n", + "# Process Configuration\n", + "processes_config = {\n", + " \"procs\": 8, # Number of processes\n", + " \"with_gpus\": True # Use GPUs\n", + "}\n", + "\n", + "print(\"Model Configuration:\")\n", + "print(OmegaConf.to_yaml(OmegaConf.create(model_config)))\n", + "print(\"\\nProcess Configuration:\")\n", + "print(OmegaConf.to_yaml(OmegaConf.create(processes_config)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure Optimizer and LR Scheduler\n", + "\n", + "### The Optimization Engine\n", + "\n", + "The optimizer controls *how* the model learns from gradients:\n", + "\n", + "**AdamW**: Adaptive learning rates with weight decay\n", + "- Most popular for transformer models\n", + "- Automatically adjusts learning rate per parameter\n", + "- Weight decay prevents overfitting\n", + "\n", + "**Learning Rate (lr)**: Step size for weight updates\n", + "- 1e-5 (0.00001): Conservative, stable for fine-tuning\n", + "- 2e-5 (0.00002): More aggressive, faster convergence\n", + "- Too high โ†’ Model diverges (loss = NaN)\n", + "- Too low โ†’ Very slow learning\n", + "\n", + "**Warmup Steps**: Gradually increase LR from 0 to target\n", + "- Prevents instability at training start\n", + "- 200 steps is typical for fine-tuning\n", + "- Rule of thumb: 5-10% of total training steps" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 837311532606040, + "loadingStatus": "loaded" + } + }, + "outputs": [], + "source": [ + "# Optimizer Configuration\n", + "optimizer_config = {\n", + " \"name\": \"AdamW\",\n", + " \"lr\": 1e-5, # Learning rate\n", + " \"eps\": 1e-8\n", + "}\n", + "\n", + "# Learning Rate Scheduler Configuration\n", + "lr_scheduler_config = {\n", + " \"warmup_steps\": 200 # Number of warmup steps\n", + "}\n", + "\n", + "print(\"Optimizer Configuration:\")\n", + "print(OmegaConf.to_yaml(OmegaConf.create(optimizer_config)))\n", + "print(\"\\nLR Scheduler Configuration:\")\n", + "print(OmegaConf.to_yaml(OmegaConf.create(lr_scheduler_config)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure Training Settings\n", + "\n", + "### Core Training Parameters\n", + "\n", + "**local_batch_size**: Examples processed per GPU per 
step\n", + "- Start with 1 for large models (8B+)\n", + "- Increase to 2-4 if you have memory headroom\n", + "- Global batch = local_batch_size ร— num_GPUs\n", + "\n", + "**seq_len**: Maximum sequence length in tokens\n", + "- 2048 tokens โ‰ˆ 1500 words\n", + "- Longer sequences = more context but slower training\n", + "- Reduce if running out of memory\n", + "\n", + "**steps**: Total number of training iterations\n", + "- 100-500: Quick experiment\n", + "- 1000-5000: Solid fine-tune\n", + "- 10000+: Production training\n", + "\n", + "**dataset**: Training data source (e.g., \"c4\", \"alpaca\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 1331479275018407, + "loadingStatus": "loaded" + } + }, + "outputs": [], + "source": [ + "training_config = {\n", + " \"local_batch_size\": 1, # Batch size per GPU\n", + " \"seq_len\": 2048, # Sequence length\n", + " \"max_norm\": 1.0, # Gradient clipping\n", + " \"steps\": 1000, # Total training steps\n", + " \"compile\": False, # PyTorch compilation\n", + " \"dataset\": \"c4\" # Dataset name\n", + "}\n", + "\n", + "print(\"Training Configuration:\")\n", + "print(OmegaConf.to_yaml(OmegaConf.create(training_config)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure Parallelism Settings\n", + "\n", + "### How Work is Distributed Across GPUs\n", + "\n", + "**FSDP (Fully Sharded Data Parallel)**:\n", + "- Splits model parameters across all GPUs\n", + "- Each GPU holds only a shard (e.g., 1/8th with 8 GPUs)\n", + "- Reduces memory per GPU significantly\n", + "- `data_parallel_shard_degree: -1` โ†’ Auto-use all GPUs\n", + "\n", + "**Other Parallelism Options**:\n", + "- `tensor_parallel_degree`: Split individual layers across GPUs\n", + "- `pipeline_parallel_degree`: Split model into stages\n", + "- Usually kept at 1 for standard fine-tuning\n", + "\n", + "**Why FSDP?**\n", + "- Enables training large models that don't fit on 1 GPU\n", + "- Automatically handles gradient synchronization\n", + "- Near-linear scaling with more GPUs\n", + "\n", + "For more explanation, visit: https://github.com/pytorch/torchtitan/tree/main/docs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "parallelism_config = {\n", + " \"data_parallel_replicate_degree\": 1,\n", + " \"data_parallel_shard_degree\": -1, # -1 means use all available GPUs for FSDP\n", + " \"tensor_parallel_degree\": 1,\n", + " \"pipeline_parallel_degree\": 1,\n", + " \"context_parallel_degree\": 1,\n", + " \"expert_parallel_degree\": 1,\n", + " \"disable_loss_parallel\": False\n", + "}\n", + "\n", + "print(\"Parallelism Configuration:\")\n", + "print(OmegaConf.to_yaml(OmegaConf.create(parallelism_config)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure Checkpoint and Activation Checkpointing\n", + "\n", + "### Saving Your Progress\n", + "\n", + "**Checkpointing**: Periodic saves of model state\n", + "- `interval: 500` โ†’ Save every 500 steps\n", + "- Allows resuming if training is interrupted\n", + "- Final checkpoint saved automatically\n", + "- Includes: model weights, optimizer state, training step\n", + "\n", + "**Activation Checkpointing**: Memory optimization technique\n", + "- Trades compute for memory\n", + "- Recomputes activations during backward pass instead of storing them\n", + "- `mode: selective` โ†’ Only checkpoint specific operations\n", + "- `mode: full` โ†’ More aggressive, saves more 
memory\n", + "\n", + "**When to Use:**\n", + "- Standard checkpointing: Always enable\n", + "- Activation checkpointing: Use when running out of GPU memory\n", + "- Slight slowdown (~10-20%) but can enable training larger models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 1594338181936740, + "loadingStatus": "loaded" + } + }, + "outputs": [], + "source": [ + "# Checkpoint Configuration\n", + "checkpoint_config = {\n", + " \"enable\": True,\n", + " \"folder\": \"Path_to_checkpoint_folder\",\n", + " \"initial_load_path\": \"Path_to_hf_assets\",\n", + " \"initial_load_in_hf\": True,\n", + " \"last_save_in_hf\": True,\n", + " \"interval\": 500, # Save every N steps\n", + " \"async_mode\": \"disabled\"\n", + "}\n", + "\n", + "# Activation Checkpoint Configuration\n", + "activation_checkpoint_config = {\n", + " \"mode\": \"selective\",\n", + " \"selective_ac_option\": \"op\"\n", + "}\n", + "\n", + "print(\"Checkpoint Configuration:\")\n", + "print(OmegaConf.to_yaml(OmegaConf.create(checkpoint_config)))\n", + "print(\"\\nActivation Checkpoint Configuration:\")\n", + "print(OmegaConf.to_yaml(OmegaConf.create(activation_checkpoint_config)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure Communication Settings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 1115922440689149, + "loadingStatus": "loaded" + } + }, + "outputs": [], + "source": [ + "# Communication Configuration\n", + "comm_config = {\n", + " \"trace_buf_size\": 0\n", + "}\n", + "\n", + "print(\"Communication Configuration:\")\n", + "print(OmegaConf.to_yaml(OmegaConf.create(comm_config)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Combine All Configurations\n", + "\n", + "Now let's merge everything into a complete configuration!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 890825846616212, + "loadingStatus": "loaded" + } + }, + "outputs": [], + "source": [ + "# Combine all configs\n", + "complete_config = {\n", + " \"comm\": comm_config,\n", + " \"model\": model_config,\n", + " \"processes\": processes_config,\n", + " \"optimizer\": optimizer_config,\n", + " \"lr_scheduler\": lr_scheduler_config,\n", + " \"training\": training_config,\n", + " \"parallelism\": parallelism_config,\n", + " \"checkpoint\": checkpoint_config,\n", + " \"activation_checkpoint\": activation_checkpoint_config\n", + "}\n", + "\n", + "# Create OmegaConf DictConfig\n", + "cfg = OmegaConf.create(complete_config)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "# ๐ŸŽญ Part 2: The Actor Lifecycle\n", + "\n", + "## Understanding Spawn, Setup, Train, and Cleanup\n", + "\n", + "### Phase 1: Spawn the Actor ๐ŸŽญ\n", + "\n", + "**What's happening:**\n", + "- `SpawnActor` creates a launcher for `TrainerActor`\n", + "- `spawn()` launches 8 Python processes (one per GPU)\n", + "- Each process initializes:\n", + " - CUDA device assignment (GPU 0, 1, 2, ...)\n", + " - Distributed communication (NCCL)\n", + " - Process group setup (RANK, LOCAL_RANK, WORLD_SIZE)\n", + "\n", + "**Behind the scenes:**\n", + "```\n", + "GPU 0: Process 0 (RANK=0, LOCAL_RANK=0)\n", + "GPU 1: Process 1 (RANK=1, LOCAL_RANK=1)\n", + "...\n", + "GPU 7: Process 7 (RANK=7, LOCAL_RANK=7)\n", + "```\n", + "\n", + "All processes are now waiting for instructions!\n", + "### What Happens When You Run This?\n", + "\n", + "1. 
**Spawn** ๐ŸŽญ: Forge creates 8 GPU processes (based on `procs: 8`)\n", + "2. **Setup** ๐Ÿ”ง: Each process loads its shard of the model + data\n", + "3. **Train** ๐Ÿƒ: Training loop runs for 1000 steps\n", + "4. **Cleanup** ๐Ÿงน: Final checkpoint saved, resources released\n", + "\n", + "Uncomment the line below to start training!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create the spawner\n", + "spawner = SpawnActor(TrainerActor, cfg)\n", + "\n", + "# Spawn the actor\n", + "actor = await spawner.spawn()\n", + "print(f\"โœ“ Actor spawned: {actor}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Phase 2: Setup ๐Ÿ”ง [Monarch Endpoint]\n", + "\n", + "**What's happening:**\n", + "\n", + "Monarch calls the `@endpoint` decorated `setup()` method on all 8 actor instances:\n", + "\n", + "```python\n", + "class TrainerActor:\n", + " @endpoint\n", + " async def setup(self):\n", + " # This runs on all 8 GPUs simultaneously\n", + " ...\n", + "```\n", + "\n", + "Each actor instance:\n", + "- **Loads its shard of the model**: With FSDP, each GPU only loads ~1/8th\n", + " - GPU 0 might get layers 0-10\n", + " - GPU 1 gets layers 11-20, etc.\n", + "- **Creates dataloaders**: Same dataset, different random seeds per GPU\n", + "- **Restores checkpoint**: If resuming, loads saved state\n", + "\n", + "**What `setup()` does internally:**\n", + "```python\n", + "@endpoint\n", + "async def setup(self):\n", + " # 1. Initialize model with FSDP sharding\n", + " self.model = load_model_with_fsdp(cfg.model)\n", + " \n", + " # 2. Create training dataloader\n", + " self.train_dataloader = setup_data(\n", + " dataset_path=cfg.dataset.path,\n", + " dataset_split=cfg.dataset.split\n", + " )\n", + " \n", + " # 3. Create validation dataloader (if enabled)\n", + " self.val_dataloader = setup_data(\n", + " dataset_path=cfg.dataset_val.path,\n", + " dataset_split=cfg.dataset_val.split\n", + " )\n", + " \n", + " # 4. Restore from checkpoint (if any)\n", + " self.checkpointer.load(step=self.current_step)\n", + "```\n", + "\n", + "**Monarch magic:**\n", + "- The `@endpoint` decorator makes this method callable remotely\n", + "- Monarch ensures all 8 actors complete setup before proceeding\n", + "- Distributed state (model shards) automatically synchronized\n", + "\n", + "After setup, all 8 GPU actors are synchronized and ready to train!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Setup (load data, checkpoints, etc.)\n", + "await spawner.setup()\n", + "print(\"โœ“ Actor setup complete\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Phase 3: Training Loop ๐Ÿ”„ [Monarch Endpoint]\n", + "\n", + "**What's happening:**\n", + "\n", + "Monarch calls the `@endpoint` decorated `train()` method, which runs the training loop for `cfg.training.steps` iterations. Each step:\n", + "\n", + "```python\n", + "@endpoint\n", + "async def train(self):\n", + " for step in range(current_step, max_steps):\n", + " # 1. Get next batch from dataloader\n", + " batch = next(train_dataloader)\n", + " # Shape: [batch_size, seq_len] per GPU\n", + "\n", + " # 2. Forward pass - compute predictions and loss\n", + " outputs = model(batch['input_ids'])\n", + " loss = compute_loss(outputs, batch['labels'])\n", + "\n", + " # 3. 
Backward pass - compute gradients\n", + " loss.backward()\n", + " # FSDP automatically synchronizes gradients across all GPUs!\n", + "\n", + " # 4. Optimizer step - update model weights\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + "\n", + " # 5. Periodic validation (if enabled)\n", + " if validation_enabled and step % eval_interval == 0:\n", + " val_metrics = evaluate()\n", + " log(f\"Step {step}: Val Loss = {val_metrics['val_loss']}\")\n", + "\n", + " # 6. Periodic checkpointing\n", + " if step % checkpoint_interval == 0:\n", + " save_checkpoint(step)\n", + "```\n", + "\n", + "**Key insights:**\n", + "- **FSDP synchronization**: Gradients automatically reduced across GPUs\n", + "- **Loss should decrease**: If not, check learning rate or data\n", + "- **Validation metrics**: Track generalization on held-out data\n", + "- **Checkpoints**: Resume training if interrupted\n", + "\n", + "**What you'll see:**\n", + "- Training loss decreasing over time\n", + "- Periodic validation metrics (if enabled)\n", + "- Checkpoint saves at regular intervals\n", + "- Step timing information (seconds per step)\n", + "\n", + "**Monarch magic:**\n", + "- The `@endpoint` decorator makes this long-running training loop remotely callable\n", + "- All 8 actor instances run training in sync\n", + "- Monarch handles any RPC timeouts for long-running operations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run training\n", + "await spawner.run()\n", + "print(\"โœ“ Training complete\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Phase 4: Cleanup ๐Ÿงน\n", + "\n", + "**What's happening:**\n", + "\n", + "```python\n", + "def cleanup(self):\n", + " # 1. Save final checkpoint\n", + " self.checkpointer.save(\n", + " step=self.current_step,\n", + " force=True # Always save, even if not at interval\n", + " )\n", + " \n", + " # 2. Release model from GPU memory\n", + " del self.model\n", + " torch.cuda.empty_cache()\n", + " \n", + " # 3. Shutdown distributed process group\n", + " if torch.distributed.is_initialized():\n", + " torch.distributed.destroy_process_group()\n", + " \n", + " # 4. Log final statistics\n", + " log(f\"Training complete!\")\n", + " log(f\"Final step: {self.current_step}\")\n", + " log(f\"Checkpoint saved to: {checkpoint_path}\")\n", + "```\n", + "\n", + "**Why cleanup matters:**\n", + "- โœ… **Saves final state**: Even if you Ctrl+C, final checkpoint is saved\n", + "- โœ… **Frees GPU memory**: Other jobs can now use the GPUs\n", + "- โœ… **Clean shutdown**: Prevents zombie processes\n", + "- โœ… **Logs summary**: Know exactly where training ended\n", + "\n", + "**After cleanup:**\n", + "- Model weights saved to checkpoint folder\n", + "- GPUs are free and available\n", + "- Training can be resumed from last checkpoint\n", + "- All distributed processes cleanly terminated" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "output": { + "id": 742146812207438, + "loadingStatus": "loaded" + } + }, + "outputs": [], + "source": [ + "# Cleanup resources\n", + "await spawner.cleanup()\n", + "print(\"โœ“ Cleanup complete\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Running the Complete Lifecycle\n", + "\n", + "### What Happens When You Run Training?\n", + "\n", + "**The full journey:**\n", + "\n", + "1. 
**Spawn** ๐ŸŽญ \n", + " - Forge creates 8 GPU processes (based on `procs: 8`)\n", + " - Each process gets assigned to a GPU\n", + " - Distributed communication initialized\n", + "\n", + "2. **Setup** ๐Ÿ”ง\n", + " - Each process loads its 1/8th shard of the model\n", + " - Dataloaders created with different random seeds\n", + " - Checkpoint restored if resuming training\n", + "\n", + "3. **Train** ๐Ÿƒ\n", + " - Training loop runs for 1000 steps\n", + " - Loss computed, gradients synced, weights updated\n", + " - Periodic validation and checkpointing\n", + "\n", + "4. **Cleanup** ๐Ÿงน\n", + " - Final checkpoint saved\n", + " - GPU memory released\n", + " - All processes terminated cleanly\n", + "\n", + "**Time estimate:**\n", + "- With 8 GPUs, ~2-3 seconds per step\n", + "- 1000 steps โ‰ˆ 40-50 minutes\n", + "- Plus validation time (if enabled)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "await run_actor(TrainerActor, cfg)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "# Quick Configuration Templates\n", + "\n", + "Here are ready-to-use templates for common scenarios!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Template 1: Quick Test (Single GPU, Small Steps)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "quick_test_config = OmegaConf.create({\n", + " \"comm\": {\"trace_buf_size\": 0},\n", + " \"model\": {\n", + " \"name\": \"llama3\",\n", + " \"flavor\": \"8B\",\n", + " \"hf_assets_path\": \"/tmp/Meta-Llama-3.1-8B-Instruct\"\n", + " },\n", + " \"processes\": {\"procs\": 1, \"with_gpus\": True},\n", + " \"optimizer\": {\"name\": \"AdamW\", \"lr\": 1e-5, \"eps\": 1e-8},\n", + " \"lr_scheduler\": {\"warmup_steps\": 10},\n", + " \"training\": {\n", + " \"local_batch_size\": 1,\n", + " \"seq_len\": 1024,\n", + " \"max_norm\": 1.0,\n", + " \"steps\": 100, # Just 100 steps for quick testing\n", + " \"compile\": False,\n", + " \"dataset\": \"c4\"\n", + " },\n", + " \"parallelism\": {\n", + " \"data_parallel_replicate_degree\": 1,\n", + " \"data_parallel_shard_degree\": 1,\n", + " \"tensor_parallel_degree\": 1,\n", + " \"pipeline_parallel_degree\": 1,\n", + " \"context_parallel_degree\": 1,\n", + " \"expert_parallel_degree\": 1,\n", + " \"disable_loss_parallel\": False\n", + " },\n", + " \"checkpoint\": {\n", + " \"enable\": True,\n", + " \"folder\": \"/tmp/quick_test_checkpoints\",\n", + " \"initial_load_path\": \"/tmp/Meta-Llama-3.1-8B-Instruct/\",\n", + " \"initial_load_in_hf\": True,\n", + " \"last_save_in_hf\": True,\n", + " \"interval\": 50,\n", + " \"async_mode\": \"disabled\"\n", + " },\n", + " \"activation_checkpoint\": {\n", + " \"mode\": \"selective\",\n", + " \"selective_ac_option\": \"op\"\n", + " }\n", + "})\n", + "\n", + "print(\"Quick Test Configuration:\")\n", + "print(OmegaConf.to_yaml(quick_test_config))\n", + "\n", + "# To use: await run_actor(TrainerActor, quick_test_config)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Template 2: Multi-GPU Training (8 GPUs with FSDP)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "multi_gpu_config = OmegaConf.create({\n", + " \"comm\": {\"trace_buf_size\": 0},\n", + " \"model\": {\n", + " \"name\": \"llama3\",\n", + " \"flavor\": \"8B\",\n", + " \"hf_assets_path\": \"/tmp/Meta-Llama-3.1-8B-Instruct\"\n", + " 
},\n", + " \"processes\": {\"procs\": 8, \"with_gpus\": True},\n", + " \"optimizer\": {\"name\": \"AdamW\", \"lr\": 2e-5, \"eps\": 1e-8},\n", + " \"lr_scheduler\": {\"warmup_steps\": 200},\n", + " \"training\": {\n", + " \"local_batch_size\": 2,\n", + " \"seq_len\": 2048,\n", + " \"max_norm\": 1.0,\n", + " \"steps\": 5000,\n", + " \"compile\": False,\n", + " \"dataset\": \"c4\"\n", + " },\n", + " \"parallelism\": {\n", + " \"data_parallel_replicate_degree\": 1,\n", + " \"data_parallel_shard_degree\": 8, # FSDP across 8 GPUs\n", + " \"tensor_parallel_degree\": 1,\n", + " \"pipeline_parallel_degree\": 1,\n", + " \"context_parallel_degree\": 1,\n", + " \"expert_parallel_degree\": 1,\n", + " \"disable_loss_parallel\": False\n", + " },\n", + " \"checkpoint\": {\n", + " \"enable\": True,\n", + " \"folder\": \"/tmp/multi_gpu_checkpoints\",\n", + " \"initial_load_path\": \"/tmp/Meta-Llama-3.1-8B-Instruct/\",\n", + " \"initial_load_in_hf\": True,\n", + " \"last_save_in_hf\": True,\n", + " \"interval\": 500,\n", + " \"async_mode\": \"disabled\"\n", + " },\n", + " \"activation_checkpoint\": {\n", + " \"mode\": \"selective\",\n", + " \"selective_ac_option\": \"op\"\n", + " }\n", + "})\n", + "\n", + "print(\"Multi-GPU Configuration:\")\n", + "print(OmegaConf.to_yaml(multi_gpu_config))\n", + "\n", + "# To use: await run_actor(TrainerActor, multi_gpu_config)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "# Tips & Tricks\n", + "\n", + "## Memory Optimization\n", + "- โฌ‡๏ธ Reduce `seq_len` if running out of memory\n", + "- โฌ‡๏ธ Reduce `local_batch_size` if running out of memory\n", + "- โœ… Enable `activation_checkpoint` for memory savings\n", + "\n", + "## Training Speed\n", + "- โฌ†๏ธ Increase `local_batch_size` for faster training (if memory allows)\n", + "- ๐Ÿš€ Use multiple GPUs with FSDP (`data_parallel_shard_degree > 1`)\n", + "- โšก Enable `compile: true` for PyTorch compilation (experimental)\n", + "\n", + "## Debugging\n", + "- ๐Ÿงช Start with small `steps` (e.g., 10-100) to test quickly\n", + "- ๐Ÿ” Use single GPU first (`procs: 1`)\n", + "- ๐Ÿ“Š Monitor loss values in logs\n", + "\n", + "## Checkpoint Management\n", + "- ๐Ÿ’พ Set `interval` based on how often you want to save\n", + "- ๐Ÿ“ Ensure `folder` path exists and has enough space\n", + "- ๐Ÿ”„ Use `initial_load_path` to resume from checkpoints" + ] + } + ], + "metadata": { + "fileHeader": "", + "fileUid": "924c63b2-fa48-4468-a04b-437f8bd23456", + "isAdHoc": false, + "kernelspec": { + "display_name": "forge (conda)", + "language": "python", + "name": "conda_forge" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/apps/sft/spawn_actor.py b/apps/sft/spawn_actor.py new file mode 100644 index 000000000..af235dfa4 --- /dev/null +++ b/apps/sft/spawn_actor.py @@ -0,0 +1,139 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +SpawnActor - Orchestrates the spawning and lifecycle management of actors. + +This module provides a high-level interface for creating, setting up, running, +and cleaning up different types of actors (e.g., Trainer, Evaluator, etc.) 
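+
+Typical usage (a sketch; assumes a DictConfig ``cfg`` such as the ones built in
+apps/sft/interactive_config_notebook.ipynb):
+
+    spawner = SpawnActor(TrainerActor, cfg)
+    actor = await spawner.spawn()
+    await spawner.setup()
+    await spawner.run()
+    await spawner.cleanup()
+
+or simply ``await run_actor(TrainerActor, cfg)`` to execute the full
+spawn -> setup -> run -> cleanup lifecycle in one call.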
+""" + +import logging +from typing import Any, Type + +from apps.sft.actor import BaseForgeActor +from omegaconf import DictConfig + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class SpawnActor: + """ + Orchestrator for spawning and managing actor lifecycles. + + This class handles the creation, setup, execution, and cleanup of actors + in a standardized way. + """ + + def __init__(self, actor_class: Type[BaseForgeActor], config: DictConfig): + """ + Initialize the spawn actor orchestrator. + + Args: + actor_class: The actor class to instantiate (must inherit from BaseForgeActor) + config: Configuration dictionary for the actor + """ + self.actor_class = actor_class + self.config = config + self.actor = None + + if not issubclass(actor_class, BaseForgeActor): + raise TypeError( + f"actor_class must be a subclass of BaseForgeActor, got {actor_class}" + ) + + async def spawn(self) -> Any: + """ + Spawn the actor instance with the given configuration. + + Returns: + The spawned actor instance + """ + logger.info(f"Spawning {self.actor_class.__name__}...") + + process_cfg = self.config.pop("processes", {}) + + self.actor = await self.actor_class.options(**process_cfg).as_actor(self.config) + + logger.info(f"{self.actor_class.__name__} spawned successfully.") + return self.actor + + async def setup(self): + """ + Setup the spawned actor (load data, checkpoints, etc.). + """ + if self.actor is None: + raise RuntimeError( + "Actor must be spawned before setup. Call spawn() first." + ) + + logger.info(f"Setting up {self.actor_class.__name__}...") + await self.actor.setup.call() + logger.info(f"{self.actor_class.__name__} setup complete.") + + async def run(self): + """ + Run the main execution logic of the actor. + """ + if self.actor is None: + raise RuntimeError( + "Actor must be spawned before running. Call spawn() first." + ) + + logger.info(f"Running {self.actor_class.__name__}...") + await self.actor.run.call() + logger.info(f"{self.actor_class.__name__} execution complete.") + + async def cleanup(self): + """ + Cleanup the actor resources and stop the mesh. + """ + if self.actor is None: + raise RuntimeError( + "Actor must be spawned before cleanup. Call spawn() first." + ) + + logger.info(f"Cleaning up {self.actor_class.__name__}...") + await self.actor.cleanup.call() + + if hasattr(self.actor, "mesh"): + await self.actor.mesh.stop() + + logger.info(f"{self.actor_class.__name__} cleanup complete.") + + async def run_full_lifecycle(self): + """ + Execute the complete actor lifecycle: spawn -> setup -> run -> cleanup. + + This is a convenience method that runs all phases in sequence. + """ + logger.info(f"Starting full lifecycle for {self.actor_class.__name__}...") + + try: + await self.spawn() + await self.setup() + await self.run() + finally: + if self.actor is not None: + await self.cleanup() + + logger.info(f"Full lifecycle complete for {self.actor_class.__name__}.") + + +async def run_actor( + actor_class: Type[BaseForgeActor], + config: DictConfig, +) -> None: + """ + Convenience function to run an actor with full lifecycle management. + + Args: + actor_class: The actor class to instantiate + config: Configuration dictionary for the actor + """ + spawner = SpawnActor(actor_class, config) + await spawner.run_full_lifecycle() diff --git a/apps/sft/trainer_actor.py b/apps/sft/trainer_actor.py new file mode 100644 index 000000000..bd0e4630a --- /dev/null +++ b/apps/sft/trainer_actor.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Trainer actor implementation for SFT training. + +This is a concrete implementation of BaseForgeActor for supervised fine-tuning. +""" + +import logging + +import torch +import torchtitan.experiments.forge.train_spec as forge_train_spec +from apps.sft.actor import BaseForgeActor +from apps.sft.utils import ( + create_context_parallel_context, + log_training_step, + move_batch_to_device, + setup_sft_dataloader, + setup_tokenizer, +) +from monarch.actor import endpoint +from omegaconf import DictConfig + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class TrainerActor(BaseForgeActor): + """ + Concrete trainer actor for supervised fine-tuning. + + Handles training loop, forward/backward passes, and checkpoint management. + """ + + train_spec: forge_train_spec.ForgeTrainSpec + train_dataloader: any + num_training_steps: int + + def __init__(self, config: DictConfig): + """ + Initialize the trainer actor. + + Args: + config: Configuration dictionary containing training settings + """ + super().__init__(config) + self.num_training_steps = self.job_config.training.steps + + @endpoint + async def setup(self): + """ + Setup the trainer (load data, checkpoint, etc.). + """ + logger.info("Setting up trainer actor...") + + self.tokenizer = setup_tokenizer( + hf_assets_path=self.job_config.model.hf_assets_path + ) + + self.train_dataloader = setup_sft_dataloader( + tokenizer=self.tokenizer, + dataset_path="yahma/alpaca-cleaned", + dataset_split="train", + target_tokens_per_pack=self.job_config.training.seq_len, + batch_size=self.job_config.training.local_batch_size, + device=self.device, + ) + + if self.checkpointer: + logger.info("Loading checkpoint...") + self.checkpointer.load(step=self.current_step) + + logger.info("Trainer setup complete.") + + def forward_backward( + self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor + ) -> torch.Tensor: + """ + Perform forward and backward pass. + + Args: + input_dict: Dictionary containing input tokens + labels: Ground truth labels + + Returns: + Computed loss value + """ + model_parts = self.model_parts + parallel_dims = self.parallel_dims + inputs = input_dict["tokens"] + + optional_context_parallel_ctx = create_context_parallel_context( + parallel_dims=parallel_dims, + inputs=inputs, + labels=labels, + model_parts=model_parts, + rotate_method=self.job_config.parallelism.context_parallel_rotate_method, + ) + + if parallel_dims.pp_enabled: + with self.train_context(optional_context_parallel_ctx): + targets, losses = ( + (labels, []) if self.pp_has_last_stage else (None, None) + ) + if self.pp_has_first_stage: + self.pp_schedule.step( + inputs, target=targets, losses=losses, input_batch=inputs + ) + else: + self.pp_schedule.step( + target=targets, losses=losses, input_batch=inputs + ) + + loss = ( + torch.mean(torch.stack(losses)).to(self.device) + if self.pp_has_last_stage + else torch.tensor([-1.0], device=self.device) + ) + else: + with self.train_context(optional_context_parallel_ctx): + assert len(model_parts) == 1 + with self.maybe_enable_amp: + pred = model_parts[0](inputs) + loss = self.loss_fn(pred, labels) + del pred + loss.backward() + + return loss + + def train_step(self, batch: dict[str, torch.Tensor]) -> None: + """ + Execute a single training step. + + Args: + batch: Dictionary containing batch data (tokens, labels, etc.) 
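+
+        Note:
+            ``labels`` is popped from ``batch`` before the forward pass, so the
+            remaining entries (e.g. ``tokens``) are what the model receives as
+            inputs.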
+        """
+        labels = batch.pop("labels")
+        loss = self.forward_backward(batch, labels)
+
+        log_training_step(self.current_step, self.num_training_steps, loss, logger)
+
+        self.optimizers.step()
+        # Reset gradients so they do not accumulate across training steps.
+        self.optimizers.zero_grad()
+        self.lr_schedulers.step()
+
+    @endpoint
+    async def run(self) -> None:
+        """
+        Main training loop.
+        """
+        logger.info("Starting training loop...")
+
+        dataloader = iter(self.train_dataloader)
+        self.optimizers.zero_grad()
+
+        while self.current_step < self.num_training_steps:
+            batch = next(dataloader)
+            batch = move_batch_to_device(batch, self.device)
+
+            self.train_step(batch)
+            self.current_step += 1
+
+            if self.checkpointer:
+                self.checkpointer.save(
+                    curr_step=self.current_step,
+                    last_step=self.current_step == self.num_training_steps,
+                )
+
+        logger.info("Training complete!")
+
+    @endpoint
+    async def cleanup(self) -> None:
+        """
+        Cleanup resources (close checkpointer, logger, etc.).
+        """
+        logger.info("Cleaning up trainer actor...")
+
+        if self.checkpointer:
+            self.checkpointer.close()
+        if self.metric_logger:
+            self.metric_logger.close()
+
+        logger.info("Cleanup complete.")
+
+    def __repr__(self) -> str:
+        return "TrainerActor"
diff --git a/apps/sft/utils.py b/apps/sft/utils.py
new file mode 100644
index 000000000..6d0219805
--- /dev/null
+++ b/apps/sft/utils.py
@@ -0,0 +1,187 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utility functions for SFT training actors.
+
+These utilities handle data loading, model setup, and common operations.
+"""
+
+import logging
+import os
+from functools import partial
+from typing import Any, Optional
+
+import torch
+from forge.data.collate import collate_packed
+from forge.data.datasets.packed import PackedDataset, TextPacker
+from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
+from forge.data.tokenizer import HuggingFaceModelTokenizer
+from torchdata.stateful_dataloader import StatefulDataLoader
+from torchtitan.distributed import ParallelDims, utils as dist_utils
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+def setup_tokenizer(
+    hf_assets_path: str,
+    tokenizer_filename: str = "tokenizer.json",
+    tokenizer_config_filename: str = "tokenizer_config.json",
+    generation_config_filename: str = "generation_config.json",
+) -> HuggingFaceModelTokenizer:
+    """
+    Setup HuggingFace tokenizer from model assets.
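+
+    Example (a sketch; the path is illustrative):
+
+        tokenizer = setup_tokenizer("/tmp/Meta-Llama-3.1-8B-Instruct")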
+ + Args: + hf_assets_path: Path to the directory containing tokenizer files + tokenizer_filename: Name of the tokenizer JSON file + tokenizer_config_filename: Name of the tokenizer config JSON file + generation_config_filename: Name of the generation config JSON file + + Returns: + Initialized HuggingFaceModelTokenizer + """ + tokenizer_json_path = os.path.join(hf_assets_path, tokenizer_filename) + tokenizer_config_path = os.path.join(hf_assets_path, tokenizer_config_filename) + generation_config_path = os.path.join(hf_assets_path, generation_config_filename) + + logger.info(f"Loading tokenizer from: {tokenizer_json_path}") + + tokenizer = HuggingFaceModelTokenizer( + tokenizer_json_path=tokenizer_json_path, + tokenizer_config_json_path=tokenizer_config_path, + generation_config_path=generation_config_path, + ) + + return tokenizer + + +def setup_sft_dataloader( + tokenizer: HuggingFaceModelTokenizer, + dataset_path: str, + dataset_split: str, + target_tokens_per_pack: int, + batch_size: int, + device: torch.device, + padding_idx: int = 0, + message_transform: Optional[Any] = None, +) -> StatefulDataLoader: + """ + Setup dataloader for SFT training. + + Args: + tokenizer: Tokenizer to use for processing text + dataset_path: Path or name of the dataset (e.g., "yahma/alpaca-cleaned") + dataset_split: Dataset split to use (e.g., "train", "validation") + target_tokens_per_pack: Target sequence length for packing + batch_size: Batch size for training + device: Device to move tensors to + padding_idx: Padding token index + message_transform: Transform to convert dataset format to messages + + Returns: + Configured StatefulDataLoader + """ + if message_transform is None: + message_transform = AlpacaToMessages() + + logger.info(f"Loading SFT dataset from: {dataset_path}, split: {dataset_split}") + + dataset = sft_iterable_dataset( + model_transform=tokenizer, + message_transform=message_transform, + path=dataset_path, + split=dataset_split, + ) + + packer = TextPacker(padding_idx=padding_idx) + dataset = PackedDataset( + dataset=dataset, + packer=packer, + target_tokens_per_pack=target_tokens_per_pack, + ) + + dataloader = StatefulDataLoader( + dataset=dataset, + batch_size=batch_size, + collate_fn=partial( + collate_packed, mask_fn=packer.create_block_mask, device=device + ), + ) + + logger.info( + f"Created dataloader with batch_size={batch_size}, target_tokens={target_tokens_per_pack}" + ) + + return dataloader + + +def create_context_parallel_context( + parallel_dims: ParallelDims, + inputs: torch.Tensor, + labels: torch.Tensor, + model_parts: list, + rotate_method: str, +): + """ + Create context parallel context for distributed training. + + Args: + parallel_dims: Parallel dimensions configuration + inputs: Input tensor + labels: Label tensor + model_parts: List of model parts + rotate_method: Context parallel rotation method + + Returns: + Context parallel context or None if CP is not enabled + """ + if not parallel_dims.cp_enabled: + return None + + return dist_utils.create_context_parallel_ctx( + cp_mesh=parallel_dims.world_mesh["cp"], + cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], + cp_seq_dims=[1, 1] + [0 for _ in model_parts], + cp_no_restore_buffers={inputs, labels}, + cp_rotate_method=rotate_method, + ) + + +def move_batch_to_device(batch: dict[str, Any], device: torch.device) -> dict[str, Any]: + """ + Move batch tensors to the specified device. 
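+
+    Note: the input ``batch`` dict is modified in place; non-tensor values are
+    left unchanged.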
+ + Args: + batch: Dictionary containing batch data + device: Target device + + Returns: + Batch with tensors moved to device + """ + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.to(device) + return batch + + +def log_training_step( + step: int, + total_steps: int, + loss: torch.Tensor, + logger: logging.Logger, +): + """ + Log training step information. + + Args: + step: Current training step + total_steps: Total number of training steps + loss: Current loss value + logger: Logger instance + """ + logger.info(f"Step {step}/{total_steps} | Loss: {loss.item():.4f}")