# Get Started with PyTorch FSDP2 and Ray Train

This notebook demonstrates how to train large models using PyTorch's Fully Sharded Data Parallel (FSDP2) with Ray Train. FSDP2 shards the model across multiple GPUs, so each GPU holds only a slice of the parameters, gradients, and optimizer state instead of the full replica that standard Distributed Data Parallel (DDP) keeps.

**Learning Objectives:**
1. Configure FSDP2 sharding for distributed training
2. Use PyTorch Distributed Checkpoint (DCP) for sharded model checkpointing
3. Load trained models for inference

## What is FSDP2?

[FSDP2](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html) is PyTorch's native solution for training large models:

- Shards model parameters, gradients, and optimizer states across workers
- All-gathers each layer's parameters just before they are needed in the forward and backward passes, then re-shards them to free memory
- Enables training models larger than single GPU memory
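
As a rough illustration of why sharding helps, the sketch below estimates per-GPU memory for parameters, gradients, and Adam optimizer state with and without sharding. It assumes fp32 everywhere (4 bytes per parameter, 4 per gradient, 8 for Adam's two moment buffers) and ignores activations, temporary all-gather buffers, and padding. `estimate_state_bytes` is a hypothetical helper, not part of PyTorch.

```python
def estimate_state_bytes(num_params: int, world_size: int, sharded: bool) -> int:
    """Approximate per-GPU bytes for params + grads + Adam state (fp32 assumed)."""
    bytes_per_param = 4 + 4 + 8  # parameter, gradient, Adam exp_avg + exp_avg_sq
    total = num_params * bytes_per_param
    # With full sharding, each rank holds roughly 1/world_size of every tensor.
    return total // world_size if sharded else total


num_params = 1_006_090  # parameter count of the ViT defined in Step 2
replicated = estimate_state_bytes(num_params, world_size=2, sharded=False)
sharded = estimate_state_bytes(num_params, world_size=2, sharded=True)
print(f"Replicated (DDP-style): {replicated / 2**20:.1f} MiB per GPU")
print(f"Sharded (FSDP-style):   {sharded / 2**20:.1f} MiB per GPU")
```

For a model this small the saving is negligible; the same arithmetic is what matters once the parameter count reaches the billions.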

**When to use FSDP2:**
- Model exceeds single GPU memory
- You want native PyTorch integration
- Building custom training loops

## Prerequisites

- Ray cluster with GPU workers (this example uses 2x GPUs)
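- PyTorch with CUDA support, recent enough to export `fully_shard` from `torch.distributed.fsdp` (2.6 or later; this notebook was run on 2.10)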
- Shared storage accessible from all workers (e.g., `/mnt/cluster_storage/`)
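
A quick way to sanity-check the storage prerequisite is to try creating a file under the path. This is a minimal sketch: `storage_is_writable` is a hypothetical helper, and it only checks the node it runs on, so it does not prove that every worker sees the same mount.

```python
import os
import tempfile


def storage_is_writable(path: str) -> bool:
    """Return True if a temporary file can be created under `path` on this node."""
    if not os.path.isdir(path):
        return False
    try:
        # The file is deleted automatically when the context exits.
        with tempfile.NamedTemporaryFile(dir=path):
            pass
        return True
    except OSError:
        return False


print(storage_is_writable("/mnt/cluster_storage/"))
```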

## Step 1: Environment Setup

Check Ray cluster status and install dependencies.

In [1]:
# Check Ray cluster status
!ray status

Node status
---------------------------------------------------------------
Active:
 1 head
 1 1xL4:16CPU-64GB-2
Idle:
 1 1xL4:16CPU-64GB-1
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Total Usage:
 0.0/32.0 CPU
 0.0/2.0 GPU
 0.0/2.0 anyscale/accelerator_shape:1xL4
 0.0/1.0 anyscale/cpu_only:true
 0.0/1.0 anyscale/node-group:1xL4:16CPU-64GB-1
 0.0/1.0 anyscale/node-group:1xL4:16CPU-64GB-2
 0.0/1.0 anyscale/node-group:head
 0.0/3.0 anyscale/provider:aws
 0.0/3.0 anyscale/region:us-west-2
 0B/160.00GiB memory
 16.30KiB/44.64GiB object_store_memory

From request_resources:
 (none)
Pending Demands:
 (no resource demands)

In [2]:
%%bash
pip install -q torch torchvision


#################

Local packages tensorboard are not supported across cluster, please check our documentations for workarounds: https://docs.anyscale.com/configuration/dependency-management/dependency-development

#################

Successfully registered `torch, torchvision` packages to be installed on all cluster nodes.
View and update dependencies here: https://console.anyscale.com/cld_g54aiirwj1s8t9ktgzikqur41k/prj_f1j47h9srml4cyg962id75ms2e/workspaces/expwrk_p5rbudbzwfjvieqiireatn2pzp?workspace-tab=dependencies


In [3]:
# Verify installation and check versions
import torch
import ray

print(f"PyTorch version: {torch.__version__}")
print(f"Ray version: {ray.__version__}")

PyTorch version: 2.10.0+cu128
Ray version: 2.53.0
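
If you want to guard against an outdated install programmatically, compare parsed version numbers rather than raw strings, since `"2.10.0" < "2.6.0"` is true as a string comparison. The sketch below is illustrative: `version_tuple` is a hypothetical helper, and the minimum of 2.6 is an assumption you should adjust to your own requirements.

```python
def version_tuple(version: str) -> tuple[int, int]:
    """Parse (major, minor) from a version string such as '2.10.0+cu128'."""
    major, minor = version.split("+")[0].split(".")[:2]
    return int(major), int(minor)


MIN_TORCH = (2, 6)  # assumed minimum, not taken from the notebook
installed = "2.10.0+cu128"  # value printed by the cell above
print(version_tuple(installed) >= MIN_TORCH)
```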


In [4]:
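# Setup
import os

# Opt in to Ray Train V2. Set this before `ray.train` is first imported,
# because the flag is read when the package is loaded.
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"

import tempfile
import uuid
import torch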

## Step 2: Model Definition

We use a Vision Transformer (ViT). Its stack of identical encoder blocks makes it a good fit for demonstrating FSDP2's per-layer sharding.

In [5]:
from torchvision.models import VisionTransformer
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

def init_model():
    """Initialize Vision Transformer for FashionMNIST (28x28 grayscale, 10 classes)."""
    model = VisionTransformer(
        image_size=28, patch_size=7, num_layers=10, num_heads=2,
        hidden_dim=128, mlp_dim=128, num_classes=10,
    )
    # Modify for grayscale input
    model.conv_proj = torch.nn.Conv2d(1, 128, kernel_size=7, stride=7)
    return model

# Verify model
test_model = init_model()
print(f"Model parameters: {sum(p.numel() for p in test_model.parameters()):,}")
del test_model

Model parameters: 1,006,090


## Step 3: FSDP2 Sharding Configuration

FSDP2's `fully_shard` API shards model parameters across workers:

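- **Device Mesh**: Describes how the workers are arranged. Here it is a 1-D mesh of `world_size` GPUs with a single data-parallel dimension named `dp`
- **Per-layer sharding**: Shard each encoder block individually, so only one block's full parameters need to be gathered at a time
- **Resharding**: `reshard_after_forward=True` frees the all-gathered weights after the forward pass and gathers them again for backward, trading extra communication for lower memory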

In [6]:
from torch.distributed.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
import ray.train

def shard_model(model):
    """Apply FSDP2 sharding to the model."""
    world_size = ray.train.get_context().get_world_size()
    
    # Create device mesh for data parallelism
    mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))
    
    # Shard each encoder block individually
    for block in model.encoder.layers.children():
        fully_shard(block, mesh=mesh, reshard_after_forward=True)
    
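    # Shard the root model last so it picks up the parameters that are not
    # inside an encoder block (patch projection, class token, head, ...)
    fully_shard(model, mesh=mesh, reshard_after_forward=True)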
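
To make the sharding concrete, the sketch below computes how many rows of a parameter each rank would hold, assuming each tensor is split along dimension 0 in the style of `torch.chunk` (equal-sized chunks, with the last ranks receiving fewer or zero rows). This is a simplified model of the layout: `dim0_shard_rows` is a hypothetical helper, and padding of uneven shards is ignored.

```python
import math


def dim0_shard_rows(num_rows: int, world_size: int) -> list[int]:
    """Rows of a dim-0-sharded parameter held by each rank (padding ignored)."""
    chunk = math.ceil(num_rows / world_size)
    return [max(0, min(chunk, num_rows - rank * chunk)) for rank in range(world_size)]


# A 128x128 MLP weight from the ViT above, split across the 2 workers used here.
print(dim0_shard_rows(128, world_size=2))
# The 10-row classification head on a hypothetical 4-worker run splits unevenly.
print(dim0_shard_rows(10, world_size=4))
```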

## Step 4: Distributed Checkpointing

PyTorch Distributed Checkpoint (DCP) provides efficient checkpointing for sharded models:
- Each worker saves only its shard (parallel I/O)
- Automatic resharding on load if worker count changes
- Full optimizer state support for training resumption
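
The resharding idea can be illustrated with a toy model in plain Python: each simulated rank saves only its slice together with the global offset, and a loader with a different worker count uses those offsets to rebuild its own slices. This is a conceptual sketch, not DCP's on-disk format or planner; `save_sharded` and `load_resharded` are hypothetical names.

```python
def save_sharded(values, world_size):
    """Each simulated rank writes only its own slice, tagged with its global offset."""
    chunk = -(-len(values) // world_size)  # ceiling division
    return [
        {"offset": rank * chunk, "data": values[rank * chunk:(rank + 1) * chunk]}
        for rank in range(world_size)
    ]


def load_resharded(saved, new_world_size):
    """Rebuild the global layout from offsets, then split for the new worker count."""
    total = sum(len(piece["data"]) for piece in saved)
    full = [None] * total
    for piece in saved:
        start = piece["offset"]
        full[start:start + len(piece["data"])] = piece["data"]
    chunk = -(-total // new_world_size)
    return [full[rank * chunk:(rank + 1) * chunk] for rank in range(new_world_size)]


weights = list(range(10))
saved = save_sharded(weights, world_size=2)          # saved by 2 workers
reloaded = load_resharded(saved, new_world_size=3)   # loaded by 3 workers
print(reloaded)
```

The toy loader materializes the full list for simplicity; a real sharded loader would aim to read only the regions that overlap each rank's new slice.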

In [7]:
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict, get_model_state_dict, StateDictOptions
from torch.distributed.checkpoint.stateful import Stateful
import torch.distributed.checkpoint as dcp

class AppState(Stateful):
    """Wrapper for DCP checkpointing."""
    def __init__(self, model, optimizer=None, epoch=None):
        self.model, self.optimizer, self.epoch = model, optimizer, epoch

    def state_dict(self):
        model_sd, optim_sd = get_state_dict(self.model, self.optimizer)
        return {"model": model_sd, "optim": optim_sd, "epoch": self.epoch}

    def load_state_dict(self, state_dict):
        set_state_dict(self.model, self.optimizer,
                      model_state_dict=state_dict["model"],
                      optim_state_dict=state_dict["optim"])
        self.epoch = state_dict.get("epoch")

In [8]:
def load_checkpoint(model, optimizer, ckpt):
    """Load FSDP checkpoint (handles resharding automatically)."""
    with ckpt.as_directory() as ckpt_dir:
        app_state = AppState(model, optimizer)
        dcp.load(state_dict={"app": app_state}, checkpoint_id=ckpt_dir)
    return app_state.epoch

In [9]:
def save_checkpoint(model, optimizer, metrics, epoch):
    """Save FSDP checkpoint and report metrics."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        dcp.save(state_dict={"app": AppState(model, optimizer, epoch)}, checkpoint_id=tmp_dir)
        ray.train.report(metrics, checkpoint=ray.train.Checkpoint.from_directory(tmp_dir))

In [10]:
def save_model_for_inference(model, world_rank):
    """Consolidate sharded model for inference (rank 0 saves full model)."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_sd = get_model_state_dict(model, options=StateDictOptions(full_state_dict=True, cpu_offload=True))
        ckpt = None
        if world_rank == 0:
            torch.save(model_sd, os.path.join(tmp_dir, "full-model.pt"))
            ckpt = ray.train.Checkpoint.from_directory(tmp_dir)
        ray.train.report({}, checkpoint=ckpt, checkpoint_dir_name="full_model")

## Step 5: Training Function

The training function runs on each worker:
1. Initialize and shard model with FSDP2
2. Run training loop with distributed data loading
3. Save checkpoints using PyTorch DCP

In [11]:
import ray.train.torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader

def train_func(config):
    """FSDP2 training function."""
    # Model setup
    model = init_model()
    device = ray.train.torch.get_device()
    torch.cuda.set_device(device)
    model.to(device)
    shard_model(model)
    
    # Training setup
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=config.get('lr', 0.001))
    
    # Resume from checkpoint if available
    start_epoch = 0
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        start_epoch = load_checkpoint(model, optimizer, checkpoint) + 1
    
    # Data loading
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    train_data = FashionMNIST(root=tempfile.gettempdir(), train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=config.get('batch_size', 64), shuffle=True)
    train_loader = ray.train.torch.prepare_data_loader(train_loader)
    
    # Context
    world_rank = ray.train.get_context().get_world_rank()
    
    # Training loop
    for epoch in range(start_epoch, config.get('epochs', 1)):
        if ray.train.get_context().get_world_size() > 1:
            train_loader.sampler.set_epoch(epoch)
        
        total_loss, num_batches = 0.0, 0
        for images, labels in train_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        
        avg_loss = total_loss / num_batches
        save_checkpoint(model, optimizer, {"loss": avg_loss, "epoch": epoch}, epoch)
        if world_rank == 0:
            print(f"Epoch {epoch}: loss={avg_loss:.4f}")
    
    # Save final model for inference
    save_model_for_inference(model, world_rank)
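
The `set_epoch` call in the loop above matters because the distributed sampler derives its shuffle from a seed plus the epoch number. The sketch below imitates that partitioning in plain Python (shuffle with `seed + epoch`, pad so every rank gets the same count, then take a strided slice per rank). It is a simplified stand-in for `torch.utils.data.DistributedSampler`, and `rank_indices` is a hypothetical helper.

```python
import math
import random


def rank_indices(dataset_len, world_size, rank, epoch, seed=0):
    """Indices one rank would read in a given epoch (simplified sampler logic)."""
    indices = list(range(dataset_len))
    random.Random(seed + epoch).shuffle(indices)  # same order on every rank
    # Pad by repeating leading indices so all ranks get the same number of samples.
    total = math.ceil(dataset_len / world_size) * world_size
    indices += indices[: total - dataset_len]
    return indices[rank:total:world_size]  # strided slice: disjoint across ranks


for epoch in range(2):
    for rank in range(2):
        print(f"epoch={epoch} rank={rank}: {rank_indices(10, 2, rank, epoch)}")
```

Without the epoch term, every epoch would replay the same order on each worker.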

## Step 6: Launch Distributed Training

Ray Train's `TorchTrainer` handles worker spawning, process group initialization, and checkpoint coordination.

In [12]:
import ray.train.torch

# Configuration
experiment_name = f"fsdp_{uuid.uuid4().hex[:8]}"
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)
run_config = ray.train.RunConfig(storage_path="/mnt/cluster_storage/", name=experiment_name)
train_config = {"epochs": 1, "lr": 0.001, "batch_size": 64}

print(f"Experiment: {experiment_name}")

Experiment: fsdp_b2f564ce


In [13]:
# Run training
trainer = ray.train.torch.TorchTrainer(
    train_loop_per_worker=train_func,
    scaling_config=scaling_config,
    train_loop_config=train_config,
    run_config=run_config,
)
result = trainer.fit()
print(f"Training complete! Checkpoint: {result.checkpoint}")

2026-02-02 06:50:28,293 INFO worker.py:1821 -- Connecting to existing Ray cluster at address: 10.0.140.201:6379...
2026-02-02 06:50:28,305 INFO worker.py:1998 -- Connected to Ray cluster. View the dashboard at https://session-ffbqdd398vb4g8i97u3tsubr23.i.anyscaleuserdata.com
2026-02-02 06:50:28,308 INFO packaging.py:463 -- Pushing file package 'gcs://_ray_pkg_490ef89b3fa7416cf8cfa71609535e02c19cb158.zip' (0.58MiB) to Ray cluster...
2026-02-02 06:50:28,311 INFO packaging.py:476 -- Successfully pushed file package 'gcs://_ray_pkg_490ef89b3fa7416cf8cfa71609535e02c19cb158.zip'.
(TrainController pid=82530) [State Transition] INITIALIZING -> SCHEDULING.
(TrainController pid=82530) Attempting to start training worker group of size 2 with the following resources: [{'GPU': 1}] * 2
(TrainController pid=82530) [FailurePolicy] RETRY
(TrainController pid=82530)   Source: controller
(TrainController pid=82530)   Error count: 1 (max all

Training complete! Checkpoint: Checkpoint(filesystem=local, path=/mnt/cluster_storage/fsdp_b2f564ce/full_model)


(TrainController pid=82530) [State Transition] SHUTTING_DOWN -> FINISHED.


## Step 7: Inspect Training Artifacts

Training artifacts include:
- `checkpoint_*/` - Epoch checkpoints with distributed shards
- `full_model/` - Consolidated model for inference

In [14]:
# List artifacts
storage_path = f"/mnt/cluster_storage/{experiment_name}/"
print(f"Artifacts in {storage_path}:")
for item in sorted(os.listdir(storage_path)):
    print(f"  {item}/" if os.path.isdir(os.path.join(storage_path, item)) else f"  {item}")

Artifacts in /mnt/cluster_storage/fsdp_b2f564ce/:
  .validate_storage_marker
  checkpoint_2026-02-02_06-52-14.180406/
  checkpoint_manager_snapshot.json
  full_model/
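
To pick out the most recent epoch checkpoint from such a listing, one option is to rely on the timestamped directory names, which sort chronologically as strings. The sketch below assumes the `checkpoint_<timestamp>` naming shown above and skips `checkpoint_manager_snapshot.json`, which shares the prefix but is a file. `latest_checkpoint_dir` is a hypothetical helper, and the second directory name in the demo is made up.

```python
import os
import tempfile


def latest_checkpoint_dir(storage_path):
    """Return the newest `checkpoint_*` directory under `storage_path`, or None."""
    names = [
        name for name in os.listdir(storage_path)
        if name.startswith("checkpoint_")
        and os.path.isdir(os.path.join(storage_path, name))
    ]
    return os.path.join(storage_path, max(names)) if names else None


# Demo on a throwaway directory that mimics the listing above.
with tempfile.TemporaryDirectory() as root:
    for name in ("checkpoint_2026-02-02_06-52-14.180406",
                 "checkpoint_2026-02-02_06-55-00.000000",  # made-up second epoch
                 "full_model"):
        os.mkdir(os.path.join(root, name))
    open(os.path.join(root, "checkpoint_manager_snapshot.json"), "w").close()
    latest = latest_checkpoint_dir(root)
    print(os.path.basename(latest))
```

In the notebook itself you would call `latest_checkpoint_dir(storage_path)`.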


## Step 8: Load Model for Inference

The consolidated model (`full-model.pt`) is a standard PyTorch checkpoint that works without FSDP2.

In [15]:
# Load model for inference
model_path = f"/mnt/cluster_storage/{experiment_name}/full_model/full-model.pt"
print(f"Loading from: {model_path}")

Loading from: /mnt/cluster_storage/fsdp_b2f564ce/full_model/full-model.pt


In [16]:
inference_model = init_model()
inference_model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True))
inference_model.eval()
print("Model loaded.")

Model loaded.


In [17]:
# Test inference
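transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
test_data = FashionMNIST(root="/tmp", train=False, download=True, transform=transform)
with torch.no_grad():
    # Index the dataset (not `test_data.data`) so the sample goes through the
    # same ToTensor/Normalize transform that was used during training.
    image, label = test_data[0]
    output = inference_model(image.unsqueeze(0))  # add batch dim: (1, 1, 28, 28)
print(f"Inference output shape: {output.shape}")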

Inference output shape: torch.Size([1, 10])
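
The 10 output values are unnormalized class scores (logits), one per FashionMNIST class. A minimal sketch of mapping them to a label follows. `CLASS_NAMES` is the standard FashionMNIST class ordering, `predict_label` is a hypothetical helper, and `example_logits` holds made-up numbers rather than real model output.

```python
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def predict_label(logits):
    """Return the class name with the highest score."""
    best = max(range(len(logits)), key=lambda i: logits[i])
    return CLASS_NAMES[best]


example_logits = [0.1, -1.2, 0.3, 0.0, 0.2, -0.5, 0.4, 1.7, -0.3, 0.9]
print(predict_label(example_logits))
```

With the notebook's tensors, the equivalent call would be `predict_label(output[0].tolist())`.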


## Summary

This tutorial covered:
1. **FSDP2 sharding** - Distributing model parameters, gradients, and optimizer states across GPUs with `fully_shard()`
2. **Ray Train integration** - Multi-GPU training with automatic process group management
3. **PyTorch DCP** - Sharded checkpointing with automatic resharding on load
4. **Inference** - Loading the consolidated model on a single device, without FSDP2

**Next Steps:**
- Add CPU offloading: pass `offload_policy=CPUOffloadPolicy()` to `fully_shard` for memory-constrained scenarios
- Add mixed precision: pass `mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16)` to `fully_shard`
- Try [DeepSpeed tutorial](./DeepSpeed_RayTrain_Tutorial.ipynb) for comparison

**Resources:**
- [PyTorch FSDP Tutorial](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
- [Ray Train Documentation](https://docs.ray.io/en/latest/train/getting-started-pytorch.html)