# Get Started with PyTorch FSDP2 and Ray Train: A Complete Guide

This notebook shows how to train models that don't fit in a single GPU's memory by combining PyTorch's Fully Sharded Data Parallel (FSDP2) with Ray Train. FSDP2 shards the model across multiple GPUs and nodes, so each GPU holds only a slice of the parameters, gradients, and optimizer states. Standard Distributed Data Parallel (DDP), by contrast, keeps a full copy of all three on every GPU.

In this tutorial, you:
1. Learn how FSDP2 shards model parameters, gradients, and optimizer states across workers
2. Configure memory optimization techniques: CPU offloading, mixed precision, and resharding strategies
3. Use PyTorch Distributed Checkpoint (DCP) for efficient checkpointing of sharded models
4. Profile GPU memory usage with PyTorch's memory snapshot API
5. Load the trained model for inference

## What is FSDP2?

[Fully Sharded Data Parallel (FSDP2)](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html) is PyTorch's native solution for training large models that exceed single GPU memory:

> FSDP2 shards model parameters, gradients, and optimizer states across data-parallel workers. During the forward pass, parameters are all-gathered for computation, then re-sharded. This enables training models that are much larger than what fits on a single GPU.

FSDP2 is a significant improvement over FSDP1:
- **Per-parameter sharding**: Chunks each parameter on dim-0 across workers (vs. flattening/concatenating in FSDP1)
- **DTensor integration**: Better support for tensor parallelism and multi-dimensional parallelism
- **Relaxed constraints**: Handles frozen parameters more naturally
- **Communication-free state dicts**: Sharded state dicts without collective communication
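
To make per-parameter dim-0 sharding concrete, here is a pure-Python sketch (no GPUs or PyTorch needed) that mimics how one 2-D parameter's rows could be split across ranks and reassembled by an all-gather. It assumes a `torch.chunk`-style split, where the chunk size is rounded up so trailing ranks may hold fewer or zero rows. FSDP2 itself works on `DTensor`s and deals with uneven splits internally, so `shard_dim0` and `all_gather_dim0` are illustrative helpers, not FSDP2 APIs.

```python
import math
from typing import List

Row = List[float]


def shard_dim0(rows: List[Row], world_size: int, rank: int) -> List[Row]:
    """Return the rows of a 2-D parameter owned by `rank`, using a torch.chunk-style split."""
    chunk = math.ceil(len(rows) / world_size)  # rows per rank, rounded up
    start = min(rank * chunk, len(rows))
    end = min(start + chunk, len(rows))
    return rows[start:end]  # trailing ranks may get fewer (or zero) rows


def all_gather_dim0(shards: List[List[Row]]) -> List[Row]:
    """Concatenate every rank's shard in rank order to rebuild the full parameter."""
    full: List[Row] = []
    for shard in shards:
        full.extend(shard)
    return full


# A toy 10x2 "weight" split across 4 ranks.
weight = [[float(i), float(i) + 0.5] for i in range(10)]
shards = [shard_dim0(weight, world_size=4, rank=r) for r in range(4)]
print([len(s) for s in shards])  # [3, 3, 3, 1]
assert all_gather_dim0(shards) == weight  # the all-gather recovers the original
```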

## When to Use FSDP2?

Use FSDP2 when:
- Your model doesn't fit in a single GPU's memory even with gradient checkpointing
- You want native PyTorch integration (no external dependencies)
- You need fine-grained control over sharding strategies
- You're building custom training loops with PyTorch's ecosystem
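
A back-of-the-envelope estimate can tell you whether sharding is needed. The sketch below is a rough model, assuming plain FP32 training with Adam: 4 bytes per parameter for weights, 4 for gradients, and 8 for Adam's two moment buffers, or 16 bytes in total. It ignores activations, buffers, and temporary all-gather memory. DDP keeps all of this on every GPU, while full sharding divides it by the number of workers. `per_gpu_state_gib` and the 7B-parameter model are hypothetical.

```python
def per_gpu_state_gib(num_params: int, world_size: int, sharded: bool) -> float:
    """Rough per-GPU memory (GiB) for FP32 parameters, gradients, and Adam states.

    Assumes 16 bytes per parameter (4 param + 4 grad + 8 for Adam's two moments)
    and ignores activations, buffers, and temporary all-gather memory.
    """
    bytes_per_param = 4 + 4 + 8
    total_bytes = num_params * bytes_per_param
    if sharded:
        total_bytes /= world_size  # each rank keeps only its 1/world_size slice
    return total_bytes / 2**30


# Hypothetical 7B-parameter model on 8 GPUs.
ddp = per_gpu_state_gib(7_000_000_000, world_size=8, sharded=False)
fsdp = per_gpu_state_gib(7_000_000_000, world_size=8, sharded=True)
print(f"DDP (replicated): {ddp:.1f} GiB per GPU")
print(f"FSDP (sharded):   {fsdp:.1f} GiB per GPU")
```

Under these assumptions the estimate drops from roughly 104 GiB to roughly 13 GiB per GPU; actual usage is higher once activations and all-gathered parameters are included.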

For comparison with DeepSpeed (another memory optimization solution), see the [DeepSpeed tutorial](./DeepSpeed_RayTrain_Tutorial.ipynb) in this folder.

## Prerequisites

This tutorial requires:
- A Ray cluster with GPU workers (this example uses 2x T4 GPUs)
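- PyTorch 2.6+ with CUDA support (the release that exposes FSDP2's `fully_shard` under `torch.distributed.fsdp`, which this tutorial imports)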
- Shared storage accessible from all workers (e.g., `/mnt/cluster_storage/`)

When running on open-source Ray (without Anyscale), you'll need to:
- Configure your Ray cluster manually
- Set up NFS or cloud storage for checkpointing
- Ensure PyTorch is installed on all worker nodes
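
Before launching training, it can help to confirm that the shared storage path is writable. Here is a minimal standard-library sketch; `storage_is_writable` is a hypothetical helper, and it only checks the machine it runs on. To cover every node you would need to call it from each worker, for example at the start of the training function defined in Step 5.

```python
import os
import tempfile
import uuid


def storage_is_writable(path: str) -> bool:
    """Return True if this process can create and delete a file under `path`."""
    if not os.path.isdir(path):
        return False
    probe = os.path.join(path, f".write_probe_{uuid.uuid4().hex}")  # unique throwaway file
    try:
        with open(probe, "w") as f:
            f.write("ok")
        os.remove(probe)
        return True
    except OSError:
        return False


print(storage_is_writable("/mnt/cluster_storage/"))  # shared storage used in this tutorial
print(storage_is_writable(tempfile.gettempdir()))   # local sanity check
```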

## `Step 1`: Environment Setup

First, let's check the Ray cluster status and install dependencies.

In [None]:
# Check Ray cluster status
!ray status

In [None]:
%%bash
pip install -q torch torchvision matplotlib

In [None]:
# Verify installation and check versions
import torch
import ray

print(f"PyTorch version: {torch.__version__}")
print(f"Ray version: {ray.__version__}")

In [None]:
# Enable the Ray Train V2 API (recommended for the latest features).
# Set this before `ray.train` is imported in this session.
import os
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"

# Standard library imports
import tempfile
import uuid
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

## `Step 2`: Model Definition

We'll use a Vision Transformer (ViT) for this tutorial. A ViT is a stack of identical transformer encoder blocks, a clear, repeated structure that maps naturally onto FSDP2's per-block sharding.

### Key Architecture Decisions:

| Parameter | Value | Rationale |
|-----------|-------|----------|
| `image_size` | 28 | FashionMNIST native resolution |
| `patch_size` | 7 | Creates 4x4 = 16 patches per image |
| `num_layers` | 10 | Sufficient depth for FSDP2 demonstration |
| `hidden_dim` | 128 | Keeps the model small (about 1M parameters) so the demo runs quickly on T4 GPUs |
| `num_classes` | 10 | FashionMNIST categories |
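
As a sanity check on the table, the sketch below works out the patch count and parameter count in plain Python. It assumes torchvision's ViT layout: a `conv_proj` patch embedding, a class token, a learned positional embedding, encoder blocks with two LayerNorms, a fused-QKV `nn.MultiheadAttention`, and a two-layer MLP, followed by a final LayerNorm and a linear head. If that assumption holds, the total should agree with the count printed by the next cell, and it shows that almost all parameters sit in the encoder blocks, which makes them the natural unit to shard. `vit_summary` is a hypothetical helper.

```python
def vit_summary(image_size=28, patch_size=7, num_layers=10, hidden_dim=128,
                mlp_dim=128, num_classes=10, in_channels=1):
    """Patch and parameter counts for a torchvision-style ViT (assumed layout)."""
    assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
    num_patches = (image_size // patch_size) ** 2
    seq_len = num_patches + 1  # +1 for the class token
    h = hidden_dim
    layer_norm = 2 * h                                  # weight + bias
    attention = (3 * h * h + 3 * h) + (h * h + h)       # fused QKV in-proj + out-proj
    mlp = (h * mlp_dim + mlp_dim) + (mlp_dim * h + h)   # two Linear layers
    block = 2 * layer_norm + attention + mlp
    total = (
        in_channels * h * patch_size * patch_size + h   # conv_proj (patch embedding)
        + h                                             # class token
        + seq_len * h                                   # positional embedding
        + num_layers * block                            # encoder blocks
        + layer_norm                                    # final LayerNorm
        + h * num_classes + num_classes                 # classification head
    )
    return {"num_patches": num_patches, "seq_len": seq_len,
            "params_per_block": block, "total_params": total}


print(vit_summary())
```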

In [None]:
import torch
from torchvision.models import VisionTransformer
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

def init_model() -> torch.nn.Module:
    """Initialize a Vision Transformer model for FashionMNIST classification.
    
    The model is configured for 28x28 grayscale images with 10 output classes.
    We modify the patch embedding layer to accept single-channel input.
    
    Returns:
        torch.nn.Module: Configured ViT model
    """
    logger.info("Initializing Vision Transformer model...")

    model = VisionTransformer(
        image_size=28,        # FashionMNIST image size
        patch_size=7,         # Creates 4x4 = 16 patches
        num_layers=10,        # Number of transformer encoder layers
        num_heads=2,          # Attention heads per layer
        hidden_dim=128,       # Embedding dimension
        mlp_dim=128,          # Feed-forward network dimension
        num_classes=10,       # FashionMNIST has 10 classes
    )

    # Modify patch embedding for grayscale images (1 channel instead of 3)
    model.conv_proj = torch.nn.Conv2d(
        in_channels=1,        # Grayscale input
        out_channels=128,     # Match hidden_dim
        kernel_size=7,        # Match patch_size
        stride=7,             # Non-overlapping patches
    )

    return model

# Quick test: verify model initialization
test_model = init_model()
param_count = sum(p.numel() for p in test_model.parameters())
print(f"Model parameters: {param_count:,} ({param_count / 1e6:.2f}M)")
del test_model

## `Step 3`: FSDP2 Sharding Configuration

FSDP2's `fully_shard` API provides several knobs for trading memory against speed. Each option is tagged in the code with a numbered comment, such as `[1]`, that corresponds to the explanations below.

### Key Configuration Options:

**Device Mesh [1]**: A `DeviceMesh` describes the topology of your training cluster as an n-dimensional array of ranks. For pure data parallelism, a 1D mesh whose single dimension spans all workers is enough. For more advanced setups (for example, FSDP2 combined with tensor or pipeline parallelism), you can use a multi-dimensional mesh with one dimension per parallelism type.
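
A plain-Python sketch can show what a mesh dimension means in practice: for each dimension, it lists the groups of ranks that communicate along that dimension. It assumes ranks are laid out row-major, as in `torch.arange(world_size).view(mesh_shape)`. `mesh_groups` is a hypothetical helper, not part of `torch.distributed`, and the 2x2 mesh is only an example of a multi-dimensional layout.

```python
from itertools import product
from typing import Dict, List, Tuple


def mesh_groups(shape: Tuple[int, ...]) -> Dict[int, List[List[int]]]:
    """For a row-major mesh of ranks, list the rank groups along each dimension."""
    groups: Dict[int, List[List[int]]] = {}
    for dim, size in enumerate(shape):
        # Fix the coordinates of every other dimension, then walk along `dim`.
        other_dims = [range(s) for i, s in enumerate(shape) if i != dim]
        dim_groups = []
        for fixed in product(*other_dims):
            group = []
            for k in range(size):
                coord = list(fixed)
                coord.insert(dim, k)
                rank = 0
                for c, s in zip(coord, shape):  # row-major flattening
                    rank = rank * s + c
                group.append(rank)
            dim_groups.append(group)
        groups[dim] = dim_groups
    return groups


print(mesh_groups((2,)))    # 1D data-parallel mesh: {0: [[0, 1]]}
print(mesh_groups((2, 2)))  # {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]}
```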

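**CPU Offloading [2]**: Keeps sharded parameters and gradients on CPU and runs the optimizer step there, so the optimizer states also live on CPU; parameters are copied to GPU only for forward and backward computation. Use it when GPU memory is the bottleneck, and expect slower steps from CPU-GPU transfers and the CPU-side optimizer step.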

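**Mixed Precision [3]**: With `param_dtype=torch.float16`, FSDP2 all-gathers parameters and runs forward/backward in FP16; with `reduce_dtype=torch.float16`, it also reduce-scatters gradients in FP16. The sharded parameters the optimizer updates, and the optimizer states, stay in FP32, so the savings apply to the all-gathered weights, activations, and communication volume rather than to the whole training state. FP16 math also runs faster on tensor cores, such as those on T4 GPUs. If training becomes numerically unstable, `reduce_dtype=torch.float32` trades some communication volume for precision.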
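
The rough byte counts below illustrate which quantities an FP16 policy shrinks. It is a simplified model, assuming FSDP2 keeps sharded parameters in FP32, uses `param_dtype` for the all-gathered copy, and uses `reduce_dtype` for the gradient reduce-scatter; it ignores activations and optimizer states. `mp_footprint` is a hypothetical helper, and 99,584 is the parameter count of one encoder block of the Step 2 ViT under torchvision's layout.

```python
# Bytes per element for the dtypes discussed above.
BYTES = {"float32": 4, "float16": 2, "bfloat16": 2}


def mp_footprint(unit_params: int, world_size: int,
                 param_dtype: str = "float16", reduce_dtype: str = "float16") -> dict:
    """Rough byte counts for one FSDP unit (e.g. one encoder block) under mixed precision.

    Simplified model: sharded parameters stay FP32, the all-gathered copy uses
    `param_dtype`, and gradients are reduce-scattered in `reduce_dtype`.
    """
    return {
        "sharded_params_per_rank": unit_params * BYTES["float32"] // world_size,
        "all_gathered_params": unit_params * BYTES[param_dtype],
        "gradient_reduce_input": unit_params * BYTES[reduce_dtype],
    }


block_params = 99_584  # one encoder block of the Step 2 ViT (assumed torchvision layout)
print(mp_footprint(block_params, 2, "float32", "float32"))  # no mixed precision
print(mp_footprint(block_params, 2))                        # FP16 policy
```

Only the all-gathered copy and the gradient communication halve; the sharded FP32 parameters are unchanged.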

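**Resharding After Forward [4]**: Passing `reshard_after_forward=True` to `fully_shard` (in steps [4] and [5] of the code) frees a module's all-gathered weights after its forward pass. This lowers peak memory at the cost of a second all-gather for that module during the backward pass.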
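
The trade-off shows up in a toy simulation of one forward and backward pass over `num_blocks` sharded blocks, counting how many bytes of all-gathered parameters are alive at once and how many all-gathers run. It ignores prefetching and compute/communication overlap (which can keep an extra block gathered), as well as activations and gradients; `simulate` is a hypothetical helper.

```python
from typing import Set, Tuple


def simulate(num_blocks: int, block_bytes: int, reshard_after_forward: bool) -> Tuple[int, int]:
    """Toy model of one forward + backward pass over FSDP-wrapped blocks.

    Returns (peak bytes of all-gathered parameters, number of all-gathers).
    Ignores prefetching, sharded storage, activations, and gradients.
    """
    gathered: Set[int] = set()  # blocks whose full parameters are materialized
    peak = 0
    all_gathers = 0

    def ensure_gathered(i: int) -> None:
        nonlocal peak, all_gathers
        if i not in gathered:
            gathered.add(i)
            all_gathers += 1
        peak = max(peak, len(gathered) * block_bytes)

    for i in range(num_blocks):            # forward pass
        ensure_gathered(i)
        if reshard_after_forward:
            gathered.discard(i)            # free all-gathered weights right away
    for i in reversed(range(num_blocks)):  # backward pass
        ensure_gathered(i)                 # re-gather if freed after forward
        gathered.discard(i)                # reshard after backward in both modes
    return peak, all_gathers


print(simulate(10, block_bytes=100, reshard_after_forward=True))   # (100, 20)
print(simulate(10, block_bytes=100, reshard_after_forward=False))  # (1000, 10)
```

With resharding, peak memory stays at about one block but every block is gathered twice; without it, all blocks stay gathered until backward but each is gathered only once.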

In [None]:
# FSDP2 imports
from torch.distributed.fsdp import (
    fully_shard,
    FSDPModule,
    CPUOffloadPolicy,
    MixedPrecisionPolicy,
)
from torch.distributed.device_mesh import init_device_mesh

import ray.train

def shard_model(model: torch.nn.Module):
    """Apply FSDP2 sharding to the model.
    
    Sharding strategy:
    1. Shard each transformer encoder block individually
    2. Wrap the entire model with FSDP2
    
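    Per-block sharding means only one block's full parameters need to be
    all-gathered at a time, which bounds peak memory, and it lets FSDP2
    overlap one block's communication with another block's computation.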
    """
    logger.info("Applying FSDP2 sharding to model...")

    # [1] Create device mesh for data parallelism.
    # The DeviceMesh describes the topology of your training cluster.
    # For data parallelism, we use a simple 1D mesh.
    # =================================================================
    world_size = ray.train.get_context().get_world_size()
    mesh = init_device_mesh(
        device_type="cuda",
        mesh_shape=(world_size,),
        mesh_dim_names=("data_parallel",)
    )

    # [2] Configure CPU offloading (reduces GPU memory usage).
    # This stores sharded parameters, gradients, and optimizer states on CPU,
    # copying to GPU only during forward/backward computation.
    # ================================================================
    offload_policy = CPUOffloadPolicy()

    # [3] Configure mixed precision (FP16 compute and gradient reduction).
    # Parameters are all-gathered and used in forward/backward as FP16, and
    # gradients are reduce-scattered in FP16. The sharded parameters that
    # the optimizer updates stay in FP32.
    # ================================================================
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
    )

    # [4] Shard each encoder block (per-layer sharding).
    # Only one block's parameters are all-gathered at a time, bounding peak memory.
    # ========================================================================
    for encoder_block in model.encoder.layers.children():
        fully_shard(
            encoder_block,
            mesh=mesh,
            reshard_after_forward=True,  # Free all-gathered weights after forward
            offload_policy=offload_policy,
            mp_policy=mp_policy
        )

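    # [5] Shard the root model. This groups the parameters not covered by [4]:
    # conv_proj, class_token, the positional embedding, the final LayerNorm, and heads.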
    # =========================
    fully_shard(
        model,
        mesh=mesh,
        reshard_after_forward=True,
        offload_policy=offload_policy,
        mp_policy=mp_policy
    )
    
    logger.info(f"Model sharded across {world_size} workers")

## `Step 4`: Distributed Checkpointing

PyTorch Distributed Checkpoint (DCP) provides efficient checkpointing for sharded models. Key features:

- **Sharded saving**: Each worker saves only its shard (parallel I/O)
- **Automatic resharding**: Load checkpoints even if worker count changes
- **Optimizer state support**: Save full training state for resumption
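
Automatic resharding works because a checkpoint records where each saved shard sits in the global tensor, so a loader with a different number of ranks can read just the ranges that overlap its new shard. The pure-Python sketch below illustrates that idea for a 1-D list of rows; `save_shards` and `load_for_rank` are hypothetical helpers, a conceptual model rather than DCP's planner or file format.

```python
import math
from typing import List, Tuple

Shard = Tuple[int, List[int]]  # (global row offset, rows stored in this shard)


def save_shards(rows: List[int], world_size: int) -> List[Shard]:
    """Split rows on dim 0 and record each shard's global offset, like checkpoint metadata."""
    chunk = math.ceil(len(rows) / world_size)
    return [(r * chunk, rows[r * chunk:(r + 1) * chunk]) for r in range(world_size)]


def load_for_rank(saved: List[Shard], total_rows: int, world_size: int, rank: int) -> List[int]:
    """Rebuild the rows `rank` owns under a new world size from overlapping saved shards."""
    chunk = math.ceil(total_rows / world_size)
    want_start = min(rank * chunk, total_rows)
    want_end = min(want_start + chunk, total_rows)
    out: List[int] = []
    for offset, shard_rows in saved:  # saved shards are in offset order
        lo = max(want_start, offset)
        hi = min(want_end, offset + len(shard_rows))
        if lo < hi:  # this saved shard overlaps the wanted range
            out.extend(shard_rows[lo - offset:hi - offset])
    return out


rows = list(range(12))
saved = save_shards(rows, world_size=2)                        # saved by 2 workers
reloaded = [load_for_rank(saved, 12, 3, r) for r in range(3)]  # loaded by 3 workers
print(reloaded)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
```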

### Architecture

We use PyTorch's `Stateful` protocol to wrap model and optimizer state:

```
AppState (Stateful)
├── model (FSDPModule)
├── optimizer (Adam)
└── epoch (int)
```
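
`AppState` works because DCP calls `state_dict()` on each `Stateful` value before saving and `load_state_dict()` after loading. The pure-Python toy below mimics only that call order, using a `typing.Protocol` named `StatefulLike`. `ToyAppState`, `toy_save`, and `toy_load` are hypothetical stand-ins; real DCP additionally flattens nested dicts, handles tensors and sharding, and writes files.

```python
from typing import Any, Dict, Protocol, runtime_checkable


@runtime_checkable
class StatefulLike(Protocol):
    """Mirrors the two methods DCP looks for on a Stateful object."""
    def state_dict(self) -> Dict[str, Any]: ...
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...


class ToyAppState:
    def __init__(self, weights=None, epoch=None):
        self.weights = weights
        self.epoch = epoch

    def state_dict(self) -> Dict[str, Any]:
        return {"model": dict(self.weights or {}), "epoch": self.epoch}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.weights = dict(state_dict["model"])
        if "epoch" in state_dict:
            self.epoch = state_dict["epoch"]


def toy_save(objs: Dict[str, Any]) -> Dict[str, Any]:
    """Call state_dict() on Stateful values, as DCP does before writing."""
    return {k: (v.state_dict() if isinstance(v, StatefulLike) else v) for k, v in objs.items()}


def toy_load(objs: Dict[str, Any], saved: Dict[str, Any]) -> None:
    """Hand loaded values back through load_state_dict(), as DCP does after reading."""
    for k, v in objs.items():
        if isinstance(v, StatefulLike):
            v.load_state_dict(saved[k])


saved = toy_save({"app": ToyAppState({"w": 1.0}, epoch=3)})
restored = ToyAppState()  # fresh object, like AppState(model, optimizer) before dcp.load
toy_load({"app": restored}, saved)
print(restored.epoch, restored.weights)  # 3 {'w': 1.0}
```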

In [None]:
# DCP imports
from torch.distributed.checkpoint.state_dict import (
    get_state_dict,
    set_state_dict,
    get_model_state_dict,
    StateDictOptions
)
from torch.distributed.checkpoint.stateful import Stateful
import torch.distributed.checkpoint as dcp

class AppState(Stateful):
    """Wrapper for checkpointing application state with DCP.
    
    Implements PyTorch's Stateful protocol for automatic state
    serialization/deserialization during dcp.save/load calls.
    
    The key insight is that get_state_dict/set_state_dict handle
    FSDP2's fully qualified names (FQNs) automatically.
    """

    def __init__(self, model, optimizer=None, epoch=None):
        self.model = model
        self.optimizer = optimizer
        self.epoch = epoch

    def state_dict(self):
        """Extract sharded state dict for saving."""
        model_state_dict, optimizer_state_dict = get_state_dict(
            self.model, self.optimizer
        )
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
            "epoch": self.epoch
        }

    def load_state_dict(self, state_dict):
        """Load sharded state dict (handles resharding automatically)."""
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )
        if "epoch" in state_dict:
            self.epoch = state_dict["epoch"]

In [None]:
def load_fsdp_checkpoint(
    model: FSDPModule, 
    optimizer: torch.optim.Optimizer, 
    ckpt: ray.train.Checkpoint
) -> int | None:
    """Load an FSDP checkpoint for resuming training.
    
    DCP automatically handles resharding if the number of workers
    differs from when the checkpoint was saved.
    
    Args:
        model: FSDP-wrapped model
        optimizer: Optimizer instance
        ckpt: Ray Train checkpoint object
        
    Returns:
        Epoch number from checkpoint, or None if not available
    """
    logger.info("Loading distributed checkpoint...")
    
    try:
        with ckpt.as_directory() as checkpoint_dir:
            app_state = AppState(model, optimizer)
            dcp.load(
                state_dict={"app": app_state},
                checkpoint_id=checkpoint_dir
            )
        logger.info(f"Loaded checkpoint from epoch {app_state.epoch}")
        return app_state.epoch
    except Exception as e:
        logger.error(f"Checkpoint loading failed: {e}")
        raise

In [None]:
def report_metrics_and_save_fsdp_checkpoint(
    model: FSDPModule, 
    optimizer: torch.optim.Optimizer, 
    metrics: dict, 
    epoch: int = 0
) -> None:
    """Save checkpoint and report metrics to Ray Train.
    
    Each worker writes its shard to a local temporary directory, and
    ray.train.report uploads every worker's files into the same
    checkpoint directory on shared storage.
    
    Args:
        model: FSDP-wrapped model
        optimizer: Optimizer instance
        metrics: Dict of metrics (loss, accuracy, etc.)
        epoch: Current epoch number
    """
    logger.info("Saving checkpoint and reporting metrics...")
    
    with tempfile.TemporaryDirectory() as temp_dir:
        # Save distributed checkpoint
        dcp.save(
            state_dict={"app": AppState(model, optimizer, epoch)},
            checkpoint_id=temp_dir
        )
        
        # Report inside the `with` block so the files still exist while
        # Ray Train uploads them to shared storage
        checkpoint = ray.train.Checkpoint.from_directory(temp_dir)
        ray.train.report(metrics, checkpoint=checkpoint)
        
    logger.info(f"Checkpoint saved. Metrics: {metrics}")

In [None]:
def save_model_for_inference(model: FSDPModule, world_rank: int) -> None:
    """Consolidate sharded model into a single file for inference.
    
    This all-gathers parameters to rank 0 and saves a standard
    PyTorch checkpoint compatible with torch.load().
    
    Warning: For very large models, this may exceed CPU memory on rank 0.
    In such cases, use distributed loading for inference instead.
    
    Args:
        model: FSDP-wrapped model
        world_rank: Current worker's rank
    """
    logger.info("Preparing model for inference...")
    
    with tempfile.TemporaryDirectory() as temp_dir:
        save_file = os.path.join(temp_dir, "full-model.pt")

        # All-gather model state to rank 0
        model_state_dict = get_model_state_dict(
            model=model,
            options=StateDictOptions(
                full_state_dict=True,    # Reconstruct full model
                cpu_offload=True,        # Gather to CPU; only rank 0 keeps the full dict
            )
        )

        logger.info("Retrieved complete model state dict")
        checkpoint = None

        # Only rank 0 saves the consolidated checkpoint
        if world_rank == 0:
            torch.save(model_state_dict, save_file)
            logger.info(f"Saved model to {save_file}")
            checkpoint = ray.train.Checkpoint.from_directory(temp_dir)

        # Every rank must call report; only rank 0 passes a checkpoint
        ray.train.report({}, checkpoint=checkpoint, checkpoint_dir_name="full_model")

## `Step 5`: Training Function

The training function runs on each worker. Key responsibilities:

1. **Initialize model** on GPU
2. **Apply FSDP2 sharding** using the `shard_model()` function
3. **Resume from checkpoint** if available
4. **Run training loop** with metric reporting
5. **Save final model** for inference
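
Resuming from a checkpoint (item 3) comes down to a little epoch arithmetic: start at the epoch after the last one that was checkpointed. The toy driver below simulates a crash and one restart to show that no epoch is skipped or repeated once it has been checkpointed. `run_with_resume` is hypothetical; Ray Train's real restart machinery restarts the worker processes and hands the latest checkpoint to `ray.train.get_checkpoint()`, which this sketch does not model.

```python
from typing import List, Optional


def run_with_resume(total_epochs: int, crash_at_epoch: Optional[int] = None) -> List[int]:
    """Toy driver: 'train' epoch by epoch, checkpoint after each, restart once after a crash."""
    checkpoint_epoch: Optional[int] = None  # last epoch saved in a checkpoint
    completed: List[int] = []
    for attempt in range(2):  # first run + one restart (similar in spirit to max_failures=1)
        # Same rule as train_func: resume at latest_epoch + 1, or start at 0.
        start_epoch = checkpoint_epoch + 1 if checkpoint_epoch is not None else 0
        try:
            for epoch in range(start_epoch, total_epochs):
                if attempt == 0 and epoch == crash_at_epoch:
                    raise RuntimeError(f"worker died during epoch {epoch}")
                completed.append(epoch)   # the epoch's training finished
                checkpoint_epoch = epoch  # report + checkpoint at epoch end
            return completed
        except RuntimeError:
            continue  # "restart" from the latest checkpoint
    return completed


print(run_with_resume(5, crash_at_epoch=2))  # [0, 1, 2, 3, 4]
```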

### Memory Profiling

We use PyTorch's CUDA memory snapshot API (available since PyTorch 2.1; these are private, underscore-prefixed functions that may change between releases) to profile GPU memory usage:

```python
torch.cuda.memory._record_memory_history(max_entries=100000)
# ... training ...
torch.cuda.memory._dump_snapshot(path)
```

To visualize a snapshot, drag the `.pickle` file into PyTorch's memory visualizer at https://pytorch.org/memory_viz.
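
You might also summarize a snapshot programmatically. The sketch below replays allocation events to find the peak of live bytes. It assumes the snapshot is a dict with a `device_traces` list (one trace per device) whose entries carry `action`, `addr`, and `size` keys, with `alloc` and `free_completed` actions. As far as we know this matches PyTorch's snapshot format, but check it against your version. The trace keeps only the most recent `max_entries` events and misses memory allocated before recording started, so treat the result as a lower bound on the true peak. `peak_live_bytes` and `summarize_snapshot` are hypothetical helpers.

```python
import pickle
from typing import Dict, Iterable, List, Mapping


def peak_live_bytes(trace: Iterable[Mapping]) -> int:
    """Replay allocation events and return the peak of live bytes seen in the trace.

    Assumes entries have 'action', 'addr', and 'size' keys, with 'alloc' and
    'free_completed' actions. Memory allocated before recording is not counted.
    """
    live: Dict[int, int] = {}  # address -> size of blocks allocated within the trace
    current = peak = 0
    for entry in trace:
        if entry["action"] == "alloc":
            live[entry["addr"]] = entry["size"]
            current += entry["size"]
            peak = max(peak, current)
        elif entry["action"] == "free_completed":
            current -= live.pop(entry["addr"], 0)  # ignore frees of pre-trace blocks
    return peak


def summarize_snapshot(path: str) -> List[int]:
    """Load a pickle written by torch.cuda.memory._dump_snapshot; return per-device peaks."""
    with open(path, "rb") as f:
        snapshot = pickle.load(f)
    return [peak_live_bytes(trace) for trace in snapshot.get("device_traces", [])]
```

For example, after Step 6 finishes you could call `summarize_snapshot(f"/mnt/cluster_storage/{experiment_name}/rank0_memory_snapshot.pickle")`.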

In [None]:
import ray.train.torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader

def train_func(config):
    """Main training function for FSDP2 + Ray Train.
    
    This function runs on each distributed worker. Ray Train handles:
    - Process group initialization
    - Device assignment
    - Checkpoint coordination
    
    Args:
        config: Dict with training hyperparameters
    """
    # === Model Setup ===
    model = init_model()
    
    # Get assigned device and move model
    device = ray.train.torch.get_device()
    torch.cuda.set_device(device)
    model.to(device)
    
    # Apply FSDP2 sharding
    shard_model(model)
    
    # === Optimizer Setup ===
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=config.get('learning_rate', 0.001))
    
    # === Checkpoint Loading ===
    start_epoch = 0
    loaded_checkpoint = ray.train.get_checkpoint()
    if loaded_checkpoint:
        latest_epoch = load_fsdp_checkpoint(model, optimizer, loaded_checkpoint)
        start_epoch = latest_epoch + 1 if latest_epoch is not None else 0
        logger.info(f"Resuming from epoch {start_epoch}")
    
    # === Data Loading ===
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(
        root=data_dir, train=True, download=True, transform=transform
    )
    train_loader = DataLoader(
        train_data,
        batch_size=config.get('batch_size', 64),
        shuffle=True
    )
    # Wrap with DistributedSampler and auto device placement
    train_loader = ray.train.torch.prepare_data_loader(train_loader)
    
    # === Training Context ===
    world_rank = ray.train.get_context().get_world_rank()
    run_name = ray.train.get_context().get_experiment_name()
    
    # === Memory Profiling Setup ===
    torch.cuda.memory._record_memory_history(max_entries=100000)
    
    # === Training Loop ===
    epochs = config.get('epochs', 5)
    
    for epoch in range(start_epoch, epochs):
        # Ensure proper shuffling for distributed sampler
        if ray.train.get_context().get_world_size() > 1:
            train_loader.sampler.set_epoch(epoch)
        
        # Reset loss accumulators so each epoch reports its own average
        running_loss = 0.0
        num_batches = 0
        
        for images, labels in train_loader:
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            num_batches += 1
        
        # Report metrics and save checkpoint
        avg_loss = running_loss / num_batches
        metrics = {"loss": avg_loss, "epoch": epoch}
        report_metrics_and_save_fsdp_checkpoint(model, optimizer, metrics, epoch)
        
        if world_rank == 0:
            logger.info(f"Epoch {epoch}: loss={avg_loss:.4f}")
    
    # === Save Memory Snapshot ===
    try:
        snapshot_path = f"/mnt/cluster_storage/{run_name}/rank{world_rank}_memory_snapshot.pickle"
        torch.cuda.memory._dump_snapshot(snapshot_path)
        logger.info(f"Saved memory snapshot to {snapshot_path}")
    except Exception as e:
        logger.warning(f"Could not save memory snapshot: {e}")
    finally:
        torch.cuda.memory._record_memory_history(enabled=None)
    
    # === Save Final Model ===
    save_model_for_inference(model, world_rank)

## `Step 6`: Launch Distributed Training

Ray Train's `TorchTrainer` handles:
- Spawning workers across the cluster
- Initializing PyTorch distributed process groups
- Coordinating checkpoints to shared storage
- Automatic fault tolerance and restarts

### Configuration Options

| Parameter | Description |
|-----------|-------------|
| `num_workers` | Number of distributed workers (typically = number of GPUs) |
| `use_gpu` | Enable GPU training |
| `storage_path` | Shared storage for checkpoints (NFS, S3, etc.) |
| `max_failures` | Number of retries on worker failure |
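
Rather than hard-coding `num_workers=2`, you might derive it from the cluster's GPU count. A sketch, assuming one GPU per worker and a resource dict shaped like the one returned by `ray.cluster_resources()` (for example `{"CPU": 16.0, "GPU": 2.0}`). `workers_from_resources` is a hypothetical helper, and it counts every GPU in the cluster, including ones other jobs may be using.

```python
from typing import Mapping, Optional


def workers_from_resources(resources: Mapping[str, float], gpus_per_worker: float = 1.0,
                           cap: Optional[int] = None) -> int:
    """Derive num_workers from a cluster resource dict, e.g. ray.cluster_resources()."""
    total_gpus = resources.get("GPU", 0.0)
    n = int(total_gpus // gpus_per_worker)  # whole workers that fit on the available GPUs
    if cap is not None:
        n = min(n, cap)
    if n < 1:
        raise ValueError("no GPUs available for training workers")
    return n


# On a live cluster you could pass ray.cluster_resources(); here, a literal example.
print(workers_from_resources({"CPU": 16.0, "GPU": 2.0}))  # 2
```

The result would then be passed as `ray.train.ScalingConfig(num_workers=..., use_gpu=True)`.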

In [None]:
import ray.train
import ray.train.torch

# Scaling configuration
scaling_config = ray.train.ScalingConfig(
    num_workers=2,      # Use 2 GPU workers
    use_gpu=True        # Enable GPU training
)

# Training hyperparameters
train_loop_config = {
    "epochs": 1,
    "learning_rate": 0.001,
    "batch_size": 64,
}

# Unique experiment name
experiment_name = f"fsdp_mnist_{uuid.uuid4().hex[:8]}"

# Run configuration
run_config = ray.train.RunConfig(
    storage_path="/mnt/cluster_storage/",
    name=experiment_name,
    failure_config=ray.train.FailureConfig(max_failures=1),
)

print(f"Experiment: {experiment_name}")
print(f"Workers: {scaling_config.num_workers}")
print(f"Epochs: {train_loop_config['epochs']}")

In [None]:
# Create and run trainer
trainer = ray.train.torch.TorchTrainer(
    train_loop_per_worker=train_func,
    scaling_config=scaling_config,
    train_loop_config=train_loop_config,
    run_config=run_config,
)

print("Starting FSDP2 training...")
result = trainer.fit()
print("\nTraining completed!")
print(f"Final checkpoint: {result.checkpoint}")

## `Step 7`: Inspect Training Artifacts

After training, the following artifacts are saved to cluster storage:

```
/mnt/cluster_storage/{experiment_name}/
├── checkpoint_*/              # Epoch checkpoints (distributed shards)
│   ├── .metadata             # DCP metadata (written by the coordinator rank)
│   ├── __0_0.distcp          # Rank 0 shard
│   └── __1_0.distcp          # Rank 1 shard
├── full_model/               # Consolidated model for inference
│   └── full-model.pt         # Standard PyTorch checkpoint
├── rank*_memory_snapshot.pickle  # Memory profiling data
└── checkpoint_manager_snapshot.json
```
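
To sanity-check a checkpoint directory before resuming from it, you could look for one shard file per rank. A minimal sketch, assuming the `__<rank>_<index>.distcp` naming shown in the tree above plus the `.metadata` file that DCP's file-system writer also produces. `shard_ranks` and `looks_complete` are hypothetical helpers, and passing the check does not prove the files are uncorrupted.

```python
import re
import tempfile
from pathlib import Path
from typing import Set

# DCP file-system shard names look like __<rank>_<index>.distcp.
SHARD_RE = re.compile(r"^__(\d+)_(\d+)\.distcp$")


def shard_ranks(checkpoint_dir: str) -> Set[int]:
    """Return the ranks that wrote at least one .distcp file into `checkpoint_dir`."""
    ranks: Set[int] = set()
    for path in Path(checkpoint_dir).iterdir():
        match = SHARD_RE.match(path.name)
        if match:
            ranks.add(int(match.group(1)))
    return ranks


def looks_complete(checkpoint_dir: str, world_size: int) -> bool:
    """Heuristic: a .metadata file plus shard files from every rank in 0..world_size-1."""
    has_metadata = (Path(checkpoint_dir) / ".metadata").is_file()
    return has_metadata and shard_ranks(checkpoint_dir) == set(range(world_size))


# Demo on a fake layout; in practice, pass one of the checkpoint_* directories.
with tempfile.TemporaryDirectory() as d:
    for name in (".metadata", "__0_0.distcp", "__1_0.distcp"):
        (Path(d) / name).touch()
    print(looks_complete(d, world_size=2))  # True
```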

In [None]:
# List training artifacts
import os
storage_path = f"/mnt/cluster_storage/{experiment_name}/"
print(f"Artifacts in {storage_path}:")
for item in os.listdir(storage_path):
    full_path = os.path.join(storage_path, item)
    if os.path.isdir(full_path):
        print(f"  {item}/")
    else:
        size_mb = os.path.getsize(full_path) / (1024 * 1024)
        print(f"  {item} ({size_mb:.2f} MB)")

## `Step 8`: Load Model for Inference

The consolidated model (`full-model.pt`) can be loaded without FSDP2 for inference. This is a standard PyTorch checkpoint that works on any device.
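
Inference inputs need the same preprocessing as training. With `Compose([ToTensor(), Normalize((0.5,), (0.5,))])`, a uint8 pixel `p` is first scaled to `p / 255` and then mapped to `(p / 255 - 0.5) / 0.5`, so inputs lie in [-1, 1]. Raw 0-255 values from `test_data.data` would therefore reach the model on a very different scale. The scalar sketch below spells out that arithmetic; `preprocess_pixel` is a hypothetical helper, not a torchvision function.

```python
def preprocess_pixel(p: int, mean: float = 0.5, std: float = 0.5) -> float:
    """Scalar version of ToTensor() followed by Normalize((mean,), (std,)) for a uint8 pixel."""
    if not 0 <= p <= 255:
        raise ValueError("expected a uint8 pixel value")
    x = p / 255.0            # ToTensor: uint8 [0, 255] -> float [0.0, 1.0]
    return (x - mean) / std  # Normalize: (x - mean) / std


print([preprocess_pixel(p) for p in (0, 255)])  # [-1.0, 1.0]
```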

In [None]:
# Path to the saved model
PATH_TO_FULL_MODEL = f"/mnt/cluster_storage/{experiment_name}/full_model/full-model.pt"
print(f"Loading model from: {PATH_TO_FULL_MODEL}")

In [None]:
# Load the trained model
inference_model = init_model()
state_dict = torch.load(PATH_TO_FULL_MODEL, map_location='cpu', weights_only=True)
inference_model.load_state_dict(state_dict)
inference_model.eval()
print("Model loaded successfully!")

In [None]:
# Load test dataset
transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
test_data = FashionMNIST(
    root="/tmp", train=False, download=True, transform=transform
)

# Class labels for FashionMNIST
CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
]

print(f"Test set size: {len(test_data)}")

In [None]:
# Run inference on sample images
print("\nSample predictions:")
print("-" * 50)

with torch.no_grad():
    correct = 0
    for i in range(10):
        # Index the dataset (not test_data.data) so the same ToTensor/Normalize
        # transform used during training is applied
        image, actual = test_data[i]
        output = inference_model(image.unsqueeze(0))  # add batch dim -> (1, 1, 28, 28)
        predicted = output.argmax(dim=1).item()
        correct += int(predicted == actual)
        
        status = "✓" if predicted == actual else "✗"
        print(f"{status} Sample {i}: predicted={CLASSES[predicted]:12s} actual={CLASSES[actual]}")
    
    print("-" * 50)
    print(f"Accuracy on samples: {correct}/10 ({correct*10}%)")

## Summary

In this tutorial, we learned how to:

1. **Configure FSDP2** with memory optimization strategies:
   - CPU offloading for reduced GPU memory
   - Mixed precision (FP16) for faster training
   - Resharding after forward pass

2. **Integrate with Ray Train** for distributed training:
   - `TorchTrainer` for multi-GPU/multi-node training
   - Automatic process group initialization
   - Fault tolerance with checkpoint recovery

3. **Use PyTorch Distributed Checkpoint (DCP)**:
   - Sharded saving for parallel I/O
   - Automatic resharding on load
   - Model consolidation for inference

4. **Profile GPU memory** using PyTorch's memory snapshot API

### Next Steps

- **Scale up**: Try more workers or larger models
- **Hybrid parallelism**: Combine FSDP2 with tensor/pipeline parallelism
- **Production deployment**: Use cloud storage (S3, GCS) for checkpoints
- **Hyperparameter tuning**: Integrate with Ray Tune
- **Try DeepSpeed**: See the [DeepSpeed tutorial](./DeepSpeed_RayTrain_Tutorial.ipynb) for an alternative approach

### Resources

- [PyTorch FSDP2 Tutorial](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
- [Ray Train Documentation](https://docs.ray.io/en/latest/train/getting-started-pytorch.html)
- [PyTorch DCP Guide](https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html)