# Get Started with PyTorch FSDP2 and Ray Train



This notebook demonstrates how to train large models using PyTorch's Fully Sharded Data Parallel (FSDP2) with Ray Train. FSDP2 enables model sharding across multiple GPUs, reducing memory footprint compared to standard DDP.

**Learning Objectives:**
1. Configure FSDP2 sharding for distributed training
2. Use PyTorch Distributed Checkpoint (DCP) for sharded model checkpointing
3. Load trained models for inference


This notebook walks through a high-level overview of using FSDP2 with Ray Train.

<div class="alert alert-block alert-info">

Here is the roadmap for this notebook:

<ol>
  <li>What is FSDP?</li>
  <li>FSDP vs DDP simplified</li>
  <li>How to use FSDP (v2) with Ray Train?</li>
</ol>
</div>


## What is FSDP2?

[FSDP2](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html) is PyTorch's native solution for training large models:

- Shards model parameters, gradients, and optimizer states across workers
- All-gathers parameters before each forward and backward computation, then re-shards
- Enables training models larger than single GPU memory

**When to use FSDP2:**
- Model exceeds single GPU memory
- You want native PyTorch integration
- Building custom training loops

### FSDP Workflow

Below is a diagram (adapted from PyTorch) that shows the FSDP workflow:

<img src="https://anyscale-materials.s3.us-west-2.amazonaws.com/ray-train-deep-dive/FSDP.png" width="800">


Here is a table explaining the different phases, their steps and what happens:

<table>
  <thead>
    <tr>
      <th>Phase</th>
      <th>Step</th>
      <th>What Happens</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Initialization</td>
      <td>Parameter sharding</td>
      <td>Each rank stores only its own shard of every parameter.</td>
    </tr>
    <tr>
      <td>Forward pass</td>
      <td>
        <ol>
          <li>all_gather</li>
          <li>Compute</li>
          <li>Free full weights</li>
        </ol>
      </td>
      <td>Ranks gather one another’s shards to reconstruct the full weights (parameters), execute the forward computation, then immediately free the gathered copies, keeping only their own shards.</td>
    </tr>
    <tr>
      <td>Backward pass</td>
      <td>
        <ol>
          <li>all_gather</li>
          <li>Back-propagate</li>
          <li>reduce_scatter</li>
          <li>Free full weights</li>
        </ol>
      </td>
      <td>Shards are re-gathered, gradients are computed, then reduced-and-scattered so each rank keeps only its own gradient shard; full replicas are discarded again.</td>
    </tr>
  </tbody>
</table>
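
The forward-pass row of the table can be sketched in plain Python. This is a toy simulation, not FSDP itself: lists stand in for GPU tensors, the numeric weights and the "layer" arithmetic are made up for illustration, and `all_gather` here is an ordinary function rather than a real collective.

```python
# Toy simulation of the FSDP forward pass; lists stand in for GPU tensors.
WORLD_SIZE = 3
full_model = {"La": [1.0, 2.0, 3.0], "Lb": [4.0, 5.0, 6.0], "Lc": [7.0, 8.0, 9.0]}

# Initialization: rank r keeps element r of every layer.
shards = [{name: [w[r]] for name, w in full_model.items()} for r in range(WORLD_SIZE)]


def all_gather(layer):
    """Concatenate every rank's shard of `layer`, in rank order."""
    return [x for rank_shards in shards for x in rank_shards[layer]]


def forward(x, rank):
    resident = sum(len(s) for s in shards[rank].values())  # scalars held at rest
    peak = resident
    for layer in full_model:
        weights = all_gather(layer)                      # 1. all_gather
        extra = len(weights) - len(shards[rank][layer])  # remote shards held temporarily
        peak = max(peak, resident + extra)
        x = sum(w * x for w in weights)                  # 2. compute (made-up toy layer)
        del weights                                      # 3. free full weights
    return x, resident, peak


output, resident, peak = forward(1.0, rank=0)
print(output, resident, peak)
```

Because only one layer is gathered at a time, the peak number of scalars held by a rank stays well below the full model size.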


## FSDP vs DDP simplified

Let's go over a toy example (inspired by [this guide on parallelism from Hugging Face](https://huggingface.co/docs/transformers/v4.13.0/en/parallelism)) comparing how FSDP and DDP operate.

### 1  Toy model

| **La** | **Lb** | **Lc** |
|:------:|:------:|:------:|
| a₀ | b₀ | c₀ |
| a₁ | b₁ | c₁ |
| a₂ | b₂ | c₂ |

Layer `La` contains the weights `[a₀, a₁, a₂]`.

Total parameters = **9 scalars** (3 per layer).

---

### 2  Parameter layout at rest  

#### 2.1  DDP  (full replication)

| **GPU** | **La**           | **Lb**           | **Lc**           |
|---------|------------------|------------------|------------------|
| 0       | a₀ a₁ a₂         | b₀ b₁ b₂         | c₀ c₁ c₂         |
| 1       | a₀ a₁ a₂         | b₀ b₁ b₂         | c₀ c₁ c₂         |
| 2       | a₀ a₁ a₂         | b₀ b₁ b₂         | c₀ c₁ c₂         |

*Each worker stores **100 %** of the model (parameters + optimizer states + gradients).*

---

#### 2.2  FSDP  (parameter sharding)

| **GPU** | **La** | **Lb** | **Lc** |
|---------|:------:|:------:|:------:|
| 0       | a₀     | b₀     | c₀     |
| 1       | a₁     | b₁     | c₁     |
| 2       | a₂     | b₂     | c₂     |

*At rest each worker keeps only **1∕3** of every layer—so memory ≈ 33 % of DDP.*
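
To put rough numbers on the two layouts, here is a back-of-the-envelope estimate. It assumes float32 weights and gradients plus an Adam-style optimizer with two float32 states per parameter (16 bytes per parameter in total), a hypothetical 1-billion-parameter model, and 3 GPUs. Activations, buffers, and temporary all-gather memory are ignored.

```python
BYTES_PER_PARAM = 4 + 4 + 8  # fp32 weight + fp32 gradient + two fp32 Adam moments


def model_state_gib(num_params, num_gpus, sharded):
    """Per-GPU memory for weights, gradients, and optimizer states, in GiB."""
    total_bytes = num_params * BYTES_PER_PARAM
    per_gpu_bytes = total_bytes / num_gpus if sharded else total_bytes
    return per_gpu_bytes / 2**30


num_params = 1_000_000_000  # hypothetical model size
ddp = model_state_gib(num_params, 3, sharded=False)
fsdp = model_state_gib(num_params, 3, sharded=True)
print(f"DDP:  {ddp:.1f} GiB per GPU")
print(f"FSDP: {fsdp:.1f} GiB per GPU at rest")
```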

---

### 3  Execution flow (GPU 0’s perspective)

#### 3.1  FSDP

<details>
<summary><strong>Forward pass (Layer La)</strong></summary>

1. **Need:** a₀ a₁ a₂  
2. **Has locally:** a₀  
3. **Gather:** a₁ from GPU 1, a₂ from GPU 2  
4. **Compute:** *y = La(x)* with full weights  
5. **Free:** a₁, a₂ (optional; keeping them avoids a second gather in the backward pass at the cost of memory)

Repeat for **Lb** then **Lc**.
</details>

<details>
<summary><strong>Backward pass (Layer Lc)</strong></summary>

1. **Need:** c₀ c₁ c₂  
2. **Has locally:** c₀  
3. **Gather:** c₁ from GPU 1, c₂ from GPU 2  
4. **Compute:** ∂L/∂c₀, ∂L/∂c₁, ∂L/∂c₂  
5. **Reduce-scatter:** average gradients across GPUs; GPU k keeps only the gradient for its own shard (∂L/∂cₖ)  
6. **Free:** c₁, c₂

Work upstream through **Lb → La** in reverse order.
</details>

<details>
<summary><strong>Optimizer step</strong></summary>

*Purely local.*  
GPU 0 updates a₀, b₀, c₀ using the averaged gradient shards already resident in memory. No extra communication.
</details>
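
The reduce-scatter and the local optimizer step can be sketched the same way. This is a plain-Python illustration with made-up gradient values; plain SGD stands in for the optimizer to keep the arithmetic visible, and `reduce_scatter_mean` is an ordinary function, not a real collective.

```python
WORLD_SIZE = 3
LR = 0.5

# Full gradient of layer Lc as computed on each rank from its own mini-batch.
local_grads = [
    [3.0, 6.0, 9.0],    # rank 0
    [6.0, 9.0, 12.0],   # rank 1
    [9.0, 12.0, 15.0],  # rank 2
]
# Parameter shards of Lc: rank k owns c_k.
param_shards = [[1.0], [2.0], [3.0]]


def reduce_scatter_mean(grads_per_rank):
    """Average gradients across ranks, then give rank k only element k."""
    n = len(grads_per_rank)
    averaged = [sum(g[i] for g in grads_per_rank) / n for i in range(n)]
    return [[averaged[k]] for k in range(n)]


grad_shards = reduce_scatter_mean(local_grads)

# Optimizer step: each rank updates only its own shard, with no communication.
for rank in range(WORLD_SIZE):
    param_shards[rank] = [
        p - LR * g for p, g in zip(param_shards[rank], grad_shards[rank])
    ]

print(grad_shards)
print(param_shards)
```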

---

#### 3.2  DDP  (for comparison)

| Phase             | What GPU 0 already owns         | Communication |
|-------------------|---------------------------------|---------------|
| **Forward**       | Full La, Lb, Lc                 | None          |
| **Backward**      | Full La, Lb, Lc + grads         | **All-reduce**|
| **Optimizer step**| Full params & full grads        | None          |

---

### 4  Key take-aways

|                         | **DDP**                           | **FSDP**                               |
|-------------------------|-----------------------------------|----------------------------------------|
| **Memory footprint**    | Full copy on every GPU            | 1∕#GPUs at rest; higher while a unit is gathered |
| **Communication**       | Gradient all-reduce during backward | All-gather (forward and backward) + reduce-scatter per FSDP unit |
| **Optimizer states**    | Fully replicated                  | Sharded (1∕#GPUs memory) |
| **Implementation ease** | Very simple                       | More knobs (wrapping policy, offloading, prefetch) |
| **When to prefer**      | Fits in memory; small models      | Large models that would otherwise OOM  |

> **Rule of thumb:**  
> *If you can replicate the model comfortably, DDP wins on simplicity and sometimes speed.  
> If you’re out of memory or pushing model scale, FSDP (or ZeRO-style sharding) is the way forward.*
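
The "sometimes speed" part of the rule of thumb comes down to communication volume. The sketch below assumes ring-based collectives and that FSDP frees the gathered weights after the forward pass, so they are gathered again in backward. Under those assumptions FSDP moves about 1.5 times as much data per step as DDP.

```python
def ring_volume(num_params, num_gpus):
    """Elements each GPU sends in one ring all-gather or one ring reduce-scatter."""
    return num_params * (num_gpus - 1) / num_gpus


def ddp_volume(num_params, num_gpus):
    # A ring all-reduce is a reduce-scatter followed by an all-gather.
    return 2 * ring_volume(num_params, num_gpus)


def fsdp_volume(num_params, num_gpus):
    # All-gather in forward, all-gather again in backward, then reduce-scatter.
    return 3 * ring_volume(num_params, num_gpus)


ratio = fsdp_volume(9, 3) / ddp_volume(9, 3)
print(ratio)
```

Real runs overlap communication with compute, so the wall-clock difference is usually smaller than this ratio suggests.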

---

### 5  Why we wrapped each layer separately

For illustration we treated **“one layer = one FSDP unit.”**  
In practice you can:

* **Cluster layers** into larger FSDP units to reduce the number of gather calls.
* **Wrap only the largest sub-modules** (e.g., big embeddings, attention blocks) and leave tiny layers unsharded.

Experiment with *auto-wrap policies* to find the sweet spot between memory savings and communication overhead.
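
A rough way to reason about that sweet spot is a simplified model that ignores prefetching and overlap: larger units mean fewer gathers per forward pass but more parameters held unsharded at once. The layer count and size below are hypothetical.

```python
import math


def wrapping_tradeoff(num_layers, params_per_layer, layers_per_unit):
    """Return (gathers per forward pass, peak unsharded parameters)."""
    units = math.ceil(num_layers / layers_per_unit)
    peak_unsharded = min(layers_per_unit, num_layers) * params_per_layer
    return units, peak_unsharded


for layers_per_unit in (1, 4, 12):
    gathers, peak = wrapping_tradeoff(12, 1_000_000, layers_per_unit)
    print(f"{layers_per_unit:>2} layers/unit: {gathers:>2} gathers, peak {peak:,} params")
```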


## When to Consider FSDP

- **Model no longer fits** on a single GPU even with mixed precision.  
- **Batch size is GPU memory-bound** under classic DDP.  
- **You have multiple GPUs** with sufficient interconnect bandwidth.  
- **You already use DDP** but need to push to larger architectures (e.g., multi-billion-parameter transformers).  
- **You want minimal code changes**: call `torch.distributed.fsdp.fully_shard` on the modules you want sharded (FSDP2), or wrap them with `torch.distributed.fsdp.FullyShardedDataParallel` (FSDP1).  

FSDP lets you step beyond the memory limits of traditional data parallelism while keeping your training loop largely unchanged.

### What is FSDP?

Fully Sharded Data Parallel (FSDP) is a parallelism method that combines the advantages of data and model parallelism for distributed training.

## Step 1: Environment Setup

Check the Ray cluster status and import the dependencies.

In [1]:
# Check Ray cluster status
!ray status

Node status
---------------------------------------------------------------
Active:
 1 head
 1 1xL4:16CPU-64GB-2
Idle:
 1 1xL4:16CPU-64GB-1
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Total Usage:
 0.0/32.0 CPU
 0.0/2.0 GPU
 0.0/2.0 anyscale/accelerator_shape:1xL4
 0.0/1.0 anyscale/cpu_only:true
 0.0/1.0 anyscale/node-group:1xL4:16CPU-64GB-1
 0.0/1.0 anyscale/node-group:1xL4:16CPU-64GB-2
 0.0/1.0 anyscale/node-group:head
 0.0/3.0 anyscale/provider:aws
 0.0/3.0 anyscale/region:us-west-2
 0B/160.00GiB memory
 10.41KiB/44.64GiB object_store_memory

From request_resources:
 (none)
Pending Demands:
 (no resource demands)

In [2]:
# Stdlib imports
import os
import tempfile

# Ray Train imports
import ray
import ray.train
import ray.train.torch

# PyTorch core and FSDP2 imports
import torch
from torch.distributed.fsdp import (
    fully_shard,
    FSDPModule,
    CPUOffloadPolicy,
    MixedPrecisionPolicy,
)

# PyTorch Distributed Checkpoint (DCP) imports
from torch.distributed.checkpoint.state_dict import (
    get_state_dict,
    set_state_dict,
    get_model_state_dict,
    StateDictOptions
)
from torch.distributed.device_mesh import init_device_mesh 
from torch.distributed.checkpoint.stateful import Stateful
import torch.distributed.checkpoint as dcp

# PyTorch training components
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader

# Computer vision components
from torchvision.models import VisionTransformer
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

## Step 2: Model Definition

We use a Vision Transformer (ViT). Its stack of identical encoder blocks makes it a good fit for demonstrating FSDP2's per-layer sharding.

The training function below runs on every Ray Train worker. It calls helper functions (`init_model`, `prepare_model`, and the checkpointing utilities) that are defined in the sections that follow.

In [3]:
def train_func(config):
    # Step 1: Initialize the model
    model = init_model(config["hidden_dim"])

    # Configure device and move model to GPU
    device = ray.train.torch.get_device()
    torch.cuda.set_device(device)
    model.to(device)

    # Step 2: Apply FSDP2 sharding to the model
    prepare_model(
        model,
        skip_model_shard=config["skip_model_shard"],
        skip_cpu_offload=config["skip_cpu_offload"],
        use_float16=config["use_float16"],
    )

    # Step 3: Initialize loss function and optimizer
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=config.get("learning_rate", 0.001))

    # Step 4: Load from checkpoint if available (for resuming training)
    start_epoch = 0
    loaded_checkpoint = ray.train.get_checkpoint()
    if loaded_checkpoint:
        start_epoch = load_fsdp_checkpoint(model, optimizer, loaded_checkpoint)

    # Step 5: Prepare training data
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(
        root=data_dir, train=True, download=True, transform=transform
    )
    train_loader = DataLoader(
        train_data, batch_size=config.get("batch_size", 128), shuffle=True, num_workers=2
    )
    # Prepare data loader for distributed training
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    world_rank = ray.train.get_context().get_world_rank()

    # Step 6: Main training loop
    epochs = config["epochs"]

    for epoch in range(start_epoch, epochs):
        # Reset the loss accumulators so each epoch reports its own average
        running_loss = 0.0
        num_batches = 0

        # Set epoch for distributed sampler to ensure proper shuffling
        if ray.train.get_context().get_world_size() > 1:
            train_loader.sampler.set_epoch(epoch)

        for images, labels in train_loader:
            # Note: Data is automatically moved to the correct device by prepare_data_loader
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Standard training step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Track metrics
            running_loss += loss.item()
            num_batches += 1

        # Step 7: Report metrics and save checkpoint after each epoch
        avg_loss = running_loss / num_batches
        metrics = {"loss": avg_loss, "epoch": epoch + 1}
        report_metrics_and_save_fsdp_checkpoint(model, optimizer, metrics, epoch + 1)

        # Log metrics from rank 0 only to avoid duplicate outputs
        if world_rank == 0:
            print(metrics)

    # Step 8: Save the final model for inference
    save_model_for_inference(model, world_rank)
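
The `set_epoch` call in the training loop matters because a distributed sampler derives its shuffle from a seed plus the current epoch. The following is a plain-Python imitation of that idea, not the PyTorch implementation: `random.Random` stands in for the sampler's generator, and `shard_indices` is a hypothetical helper.

```python
import random


def shard_indices(dataset_size, num_replicas, rank, epoch, seed=0):
    """Shuffle with seed + epoch, then give each rank every num_replicas-th index."""
    indices = list(range(dataset_size))
    random.Random(seed + epoch).shuffle(indices)
    return indices[rank::num_replicas]


rank0 = shard_indices(10, num_replicas=2, rank=0, epoch=0)
rank1 = shard_indices(10, num_replicas=2, rank=1, epoch=0)

# Within an epoch, the ranks see disjoint samples that together cover the dataset.
print(sorted(rank0 + rank1) == list(range(10)))
# Without advancing the epoch, a rank would see the same order again.
print(rank0 == shard_indices(10, num_replicas=2, rank=0, epoch=0))
```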


### Initialize the model
Initialize a Vision Transformer model for FashionMNIST classification.

In [4]:
def init_model(hidden_dim) -> torch.nn.Module:
    # Create a ViT model with architecture suitable for 28x28 images
    model = VisionTransformer(
        image_size=28,         # FashionMNIST image size
        patch_size=7,          # Divide 28x28 into 4x4 patches of 7x7 pixels each
        num_layers=12,         # Number of transformer encoder layers
        num_heads=8,           # Number of attention heads per layer
        hidden_dim=hidden_dim, # Hidden dimension size
        mlp_dim=768,           # MLP dimension in transformer blocks
        num_classes=10,        # FashionMNIST has 10 classes
    )

    # Modify the patch embedding layer for grayscale images (1 channel instead of 3)
    model.conv_proj = torch.nn.Conv2d(
        in_channels=1,            # FashionMNIST is grayscale (1 channel)
        out_channels=hidden_dim,  # Must match the hidden_dim
        kernel_size=7,            # Match patch_size
        stride=7,                 # Non-overlapping patches
    )

    return model
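
As a quick sanity check on these hyperparameters, the arithmetic below shows the sequence length the encoder sees. It assumes the usual ViT layout of one token per patch plus one class token, and `vit_sequence_length` is a hypothetical helper, not part of torchvision.

```python
def vit_sequence_length(image_size, patch_size):
    """Number of tokens: one per non-overlapping patch, plus one class token."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side**2 + 1


seq_len = vit_sequence_length(image_size=28, patch_size=7)
head_dim = 3840 // 8  # hidden_dim // num_heads, using the values from this notebook
print(seq_len, head_dim)
```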

## Step 3: Prepare the model and FSDP2 Sharding Configuration

To prepare the model, we use PyTorch's `fully_shard` function (FSDP2) in three steps:
1. Create a device mesh
2. Apply `fully_shard` to each block we want to shard
3. Apply `fully_shard` to the model itself

In [5]:
def prepare_model(
    model: torch.nn.Module,
    skip_model_shard: bool,
    skip_cpu_offload: bool,
    use_float16: bool,
):
    # Step 1: Create 1D device mesh for data parallel sharding
    world_size = ray.train.get_context().get_world_size()
    mesh = init_device_mesh(
        device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("data_parallel",)
    )

    # Step 2: Configure CPU offloading policy (optional)
    offload_policy = CPUOffloadPolicy() if not skip_cpu_offload else None

    # Step 3: Configure mixed precision policy (optional)
    mp_policy_float16 = MixedPrecisionPolicy(
        param_dtype=torch.float16,  # Cast unsharded parameters to half precision for forward/backward compute
        reduce_dtype=torch.float16,  # Reduce gradients in half precision
    )
    default_policy = MixedPrecisionPolicy(
        param_dtype=None,
        reduce_dtype=None,
        output_dtype=None,
        cast_forward_inputs=True
    )
    mp_policy = mp_policy_float16 if use_float16 else default_policy

    # Step 4: Apply sharding to each transformer encoder block
    for encoder_block in model.encoder.layers.children():
        fully_shard(
            encoder_block,
            mesh=mesh,
            reshard_after_forward=not skip_model_shard,
            offload_policy=offload_policy,
            mp_policy=mp_policy,
        )

    # Step 5: Apply sharding to the root model
    # This wraps the entire model and enables top-level FSDP2 functionality
    fully_shard(
        model,
        mesh=mesh,
        reshard_after_forward=not skip_model_shard,
        offload_policy=offload_policy,
        mp_policy=mp_policy,
    )
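
The order of the `fully_shard` calls matters. The sketch below assumes the grouping rule described in the PyTorch `fully_shard` documentation: each call forms one communication group from the module's parameters, excluding those already claimed by an earlier call on a submodule. The module tree and parameter names are simplified, hypothetical stand-ins for the ViT.

```python
# Hypothetical module tree: module path -> parameter names it contains
# (including parameters of its descendants).
module_params = {
    "encoder.layers.0": ["encoder.layers.0.weight"],
    "encoder.layers.1": ["encoder.layers.1.weight"],
    "root": [
        "conv_proj.weight",
        "encoder.layers.0.weight",
        "encoder.layers.1.weight",
        "heads.weight",
    ],
}


def assign_groups(shard_order):
    """Give each sharded module the parameters no earlier call has claimed."""
    claimed = set()
    groups = {}
    for module in shard_order:
        groups[module] = [p for p in module_params[module] if p not in claimed]
        claimed.update(groups[module])
    return groups


bottom_up = assign_groups(["encoder.layers.0", "encoder.layers.1", "root"])
root_only = assign_groups(["root"])
print(bottom_up["root"])       # only the leftovers go into the root group
print(len(root_only["root"]))  # everything lands in one large group
```

Sharding only the root would put every parameter into a single group, so the whole model would be gathered at once.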

Here is a table of the main keyword arguments for `fully_shard`:

| Parameter | What it controls | Typical values & when to use |
|-----------|-----------------|------------------------------|
| `reshard_after_forward` | Whether to free (reshard) the gathered parameters right after the forward pass | **`True`** (default behavior for non-root modules) – free the unsharded parameters after forward and all-gather them again in backward; lowest memory, more communication.<br>**`False`** – keep the unsharded parameters in memory until backward finishes; avoids the extra all-gather but uses more memory. |
| `offload_policy` | Whether sharded state is kept in CPU RAM between uses | **`OffloadPolicy()`** (default) – no offloading; fastest, keeps all data on GPU. This notebook passes `None` to leave offloading disabled.<br>**`CPUOffloadPolicy()`** – offloads sharded parameters, gradients, and optimizer states to CPU RAM; frees GPU memory at the cost of extra PCIe traffic; enables larger models on GPUs with limited memory. |
| `mp_policy` | Controls mixed precision settings for parameters and gradients | **`MixedPrecisionPolicy()`** (default) – no casting; compute and reduce in the model's original dtype (typically float32).<br>**`MixedPrecisionPolicy(param_dtype=torch.float16)`** – all-gather and compute in float16 to reduce memory and communication and to leverage tensor cores; the sharded parameters and the optimizer step stay in the original dtype.<br>**`MixedPrecisionPolicy(param_dtype=torch.bfloat16)`** – use bfloat16 on hardware that supports it (for example A100, H100); it keeps float32's exponent range, which gives better numerical stability than float16.<br>**`MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)`** – commonly used configuration; bfloat16 compute with float32 gradient reduction for better numerical stability during reduce-scatter operations.|
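
The stability difference between float16 and bfloat16 comes from their exponent ranges, which can be seen without a GPU. The sketch below uses only the standard library. `to_bfloat16` is a hypothetical helper that emulates bfloat16 by truncating a float32 to its top 16 bits, which is a simplification because real conversions round rather than truncate.

```python
import struct


def to_float16(x):
    """Round-trip through IEEE half precision; raises OverflowError if out of range."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x):
    """Emulate bfloat16 by keeping only the top 16 bits of a float32."""
    high_bytes = struct.pack(">f", x)[:2]
    return struct.unpack(">f", high_bytes + b"\x00\x00")[0]


print(to_float16(65504.0))  # the largest finite float16 value survives
try:
    to_float16(70000.0)
except OverflowError as err:
    print("float16 overflow:", err)
print(to_bfloat16(70000.0))  # representable, with reduced precision
```

A value such as 70000 overflows float16 but remains representable (coarsely) in bfloat16, which is why float16 training typically needs loss scaling.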


## Step 4: Distributed Checkpointing with PyTorch

The `torch.distributed.checkpoint` module enables saving and loading models from multiple ranks in parallel. You can use this module to save on any number of ranks in parallel, and then re-shard across differing cluster topologies at load time.

PyTorch Distributed Checkpoint (DCP) provides efficient checkpointing for sharded models:
- Each worker saves only its shard (parallel I/O)
- Automatic resharding on load if worker count changes
- Full optimizer state support for training resumption
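
Resharding on load can be pictured with plain lists. This is only a conceptual sketch: `shard`, `save`, and `load` are hypothetical helpers, and unlike this toy version, DCP does not need to rebuild the full tensor in one place.

```python
def shard(values, world_size):
    """Split a flat list into world_size contiguous chunks."""
    chunk = -(-len(values) // world_size)  # ceiling division
    return [values[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def save(shards):
    """Each rank writes only its own shard (one 'file' per rank)."""
    return {rank: list(s) for rank, s in enumerate(shards)}


def load(saved, new_world_size):
    """Read the saved shards and split them for a different number of ranks."""
    flat = [x for rank in sorted(saved) for x in saved[rank]]
    return shard(flat, new_world_size)


weights = [0, 1, 2, 3, 4, 5]
saved = save(shard(weights, world_size=2))
print(saved)
print(load(saved, new_world_size=3))
```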


### Defining a Stateful object

We take advantage of a Stateful object to handle calling distributed state dict methods on the model and optimizer.

This is a useful wrapper for checkpointing the application state. Because the object complies with the `Stateful` protocol, PyTorch DCP automatically calls `state_dict`/`load_state_dict` as needed in the `dcp.save`/`dcp.load` APIs.

In [6]:
class AppState(Stateful):
    def __init__(self, model, optimizer=None, epoch=0):
        self.model = model
        self.optimizer = optimizer
        self.epoch = epoch

    def state_dict(self):
        # this line automatically manages FSDP2 FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(
            self.model, self.optimizer
        )
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
            "epoch": self.epoch,
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )
        # Load epoch state
        self.epoch = state_dict["epoch"]
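
The protocol itself is small. The standard-library sketch below shows the calling pattern with hypothetical stand-ins: `toy_save` and `toy_load` play the role of a save/load API and are not DCP functions. The point to notice is that loading restores the object in place, which is why the epoch is read from the wrapper after loading.

```python
import json
import os
import tempfile


class EpochState:
    """Hypothetical stateful object: exposes state_dict() and load_state_dict()."""

    def __init__(self, epoch=0):
        self.epoch = epoch

    def state_dict(self):
        return {"epoch": self.epoch}

    def load_state_dict(self, state):
        self.epoch = state["epoch"]


def toy_save(objects, directory):
    """Stand-in for a save API: asks each object for its state and writes it."""
    for key, obj in objects.items():
        with open(os.path.join(directory, f"{key}.json"), "w") as f:
            json.dump(obj.state_dict(), f)


def toy_load(objects, directory):
    """Stand-in for a load API: restores each object in place."""
    for key, obj in objects.items():
        with open(os.path.join(directory, f"{key}.json")) as f:
            obj.load_state_dict(json.load(f))


with tempfile.TemporaryDirectory() as ckpt_dir:
    toy_save({"app": EpochState(epoch=3)}, ckpt_dir)
    restored = EpochState()
    toy_load({"app": restored}, ckpt_dir)

print(restored.epoch)
```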


### Saving Distributed Checkpoints when training

This function performs two critical operations:
1. Saves the current model and optimizer state using distributed checkpointing
2. Reports metrics to Ray Train for tracking

In [7]:
def report_metrics_and_save_fsdp_checkpoint(
    model: FSDPModule, optimizer: torch.optim.Optimizer, metrics: dict, epoch: int
) -> None:

    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        # Perform a distributed checkpoint with DCP
        state_dict = {"app": AppState(model, optimizer, epoch)}
        dcp.save(state_dict=state_dict, checkpoint_id=temp_checkpoint_dir)

        # Report each checkpoint shard from all workers
        # This saves the checkpoint to shared cluster storage for persistence
        checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
        ray.train.report(metrics, checkpoint=checkpoint)

### Saving a final and full checkpoint for inference

For inference, we want to save an unsharded copy of the model. This is an expensive operation because all model parameters must be gathered from every rank onto a single rank.

This function consolidates the distributed model weights into a single checkpoint file that can be used for inference without FSDP.

In [8]:
def save_model_for_inference(model: FSDPModule, world_rank: int) -> None:
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        save_file = os.path.join(temp_checkpoint_dir, "full-model.pt")

        # Step 1: All-gather the model state across all ranks
        # This reconstructs the complete model from distributed shards
        model_state_dict = get_model_state_dict(
            model=model,
            options=StateDictOptions(
                full_state_dict=True,  # Reconstruct full model
                cpu_offload=True,  # Move to CPU to save GPU memory
            ),
        )

        checkpoint = None

        # Step 2: Save the complete model (rank 0 only)
        if world_rank == 0:
            torch.save(model_state_dict, save_file)

            # Create checkpoint for shared storage
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)

        # Step 3: Report the final checkpoint to Ray Train
        ray.train.report(metrics={}, checkpoint=checkpoint, checkpoint_dir_name="full_model")

<div class="alert alert-info">

**NOTE:** In PyTorch, if both `cpu_offload` and `full_state_dict` are set to `True`, only rank 0 receives the full state dict; all other ranks receive an empty one.

</div>
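
The behavior in the note explains why only rank 0 writes the file. A toy model of it, with plain dictionaries instead of tensors and a hypothetical `gather_full_state` helper:

```python
def gather_full_state(shards_by_rank, rank):
    """Return the full state on rank 0 and an empty dict on every other rank."""
    if rank != 0:
        return {}
    return {
        name: [x for shard in shards_by_rank for x in shard[name]]
        for name in shards_by_rank[0]
    }


shards_by_rank = [{"La": [1.0]}, {"La": [2.0]}, {"La": [3.0]}]
for rank in range(3):
    print(rank, gather_full_state(shards_by_rank, rank))
```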

### Loading distributed checkpoints

This function handles distributed checkpoint loading with automatic resharding support. It can restore checkpoints even when the number of workers differs from the original training run.

In [9]:
def load_fsdp_checkpoint(
    model: FSDPModule, optimizer: torch.optim.Optimizer, ckpt: ray.train.Checkpoint
) -> int:
    try:
        with ckpt.as_directory() as checkpoint_dir:
            # Create state wrapper for DCP loading
            app_state = AppState(model, optimizer)
            state_dict = {"app": app_state}

            # Load the distributed checkpoint
            dcp.load(state_dict=state_dict, checkpoint_id=checkpoint_dir)

            # Return the epoch stored in the restored state
            return app_state.epoch
            return app_state.epoch

    except Exception as e:
        raise RuntimeError(f"Checkpoint loading failed: {e}") from e
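
The returned epoch feeds directly into `range(start_epoch, epochs)` in the training loop. Since the checkpoint stores `epoch + 1` (the number of completed epochs), a resumed run continues with the next epoch. A small sketch, using a hypothetical `epochs_to_run` helper:

```python
def epochs_to_run(total_epochs, start_epoch=0):
    """Epoch indices the training loop will execute."""
    return list(range(start_epoch, total_epochs))


print(epochs_to_run(3))                 # fresh run
print(epochs_to_run(3, start_epoch=1))  # resumed after one completed epoch
print(epochs_to_run(1, start_epoch=1))  # nothing left to train
```

With `"epochs": 1`, a run restored from its first checkpoint has no epochs left and goes straight to saving the final model.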


## Step 5: Final Training Configuration 

The training function runs on each worker:
1. Initialize and shard model with FSDP2
2. Run training loop with distributed data loading
3. Save checkpoints using PyTorch DCP

Configure scaling and resource requirements.

In [10]:
scaling_config = ray.train.ScalingConfig(
    num_workers=2, use_gpu=True, resources_per_worker={"accelerator_type:L4": 0.0001}
)

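Define the hyperparameters and FSDP2 options passed to the training function. With `skip_model_shard` and `skip_cpu_offload` set to `True`, this run keeps the gathered parameters resident after the forward pass and does not offload to CPU.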

In [11]:
train_loop_config = {
    "epochs": 1,
    "learning_rate": 0.001,
    "batch_size": 128,
    "skip_model_shard": True,
    "skip_cpu_offload": True,
    "hidden_dim": 3840,
    "use_float16": False,
}



## Step 6: Launch Distributed Training

Ray Train's `TorchTrainer` handles worker spawning, process group initialization, and checkpoint coordination.

In [12]:
trainer = ray.train.torch.TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    train_loop_config=train_loop_config,
    run_config=ray.train.RunConfig(
        storage_path="/mnt/cluster_storage/",
        name="fsdp_mnist",
        failure_config=ray.train.FailureConfig(max_failures=2),
        worker_runtime_env={
            "env_vars": {"KINETO_USE_DAEMON": "1", "KINETO_DAEMON_INIT_DELAY_S": "5"}
        },
    ),
)
result = trainer.fit()

2026-02-18 01:48:54,476	INFO worker.py:1821 -- Connecting to existing Ray cluster at address: 10.0.229.255:6379...
2026-02-18 01:48:54,487	INFO worker.py:1998 -- Connected to Ray cluster. View the dashboard at https://session-ffbqdd398vb4g8i97u3tsubr23.i.anyscaleuserdata.com
2026-02-18 01:48:54,684	INFO packaging.py:463 -- Pushing file package 'gcs://_ray_pkg_6462b622da9fc2f79ba5e96e1c91e3f0eea96575.zip' (82.47MiB) to Ray cluster...
2026-02-18 01:48:55,011	INFO packaging.py:476 -- Successfully pushed file package 'gcs://_ray_pkg_6462b622da9fc2f79ba5e96e1c91e3f0eea96575.zip'.
(TrainController pid=59639) [State Transition] INITIALIZING -> SCHEDULING.
(TrainController pid=59639) Attempting to start training worker group of size 2 with the following resources: [{'accelerator_type:L4': 0.0001, 'GPU': 1}] * 2
(RayTrainWorker pid=12661, ip=10.0.210.195) INFO:2026-02-18 01:49:01 12661:12661 init.cpp:148] Registering daemon config loader, cpuOnly =

## Step 7: Inspect Training Artifacts

Training artifacts include:
- `checkpoint_*/` - Epoch checkpoints with distributed shards
- `full_model/` - Consolidated model for inference

In [13]:
# List artifacts
storage_path = "/mnt/cluster_storage/fsdp_mnist/"
print(f"Artifacts in {storage_path}:")
!ls -ltra $storage_path

(TrainController pid=59639) [State Transition] SHUTTING_DOWN -> FINISHED.


Artifacts in /mnt/cluster_storage/fsdp_mnist/:
total 24
drwxr-xr-x 22 ray  1000 6144 Feb 18 01:48 ..
-rw-r--r--  1 ray users    0 Feb 18 01:48 .validate_storage_marker
drwxr-xr-x  2 ray users 6144 Feb 18 01:50 checkpoint_2026-02-18_01-50-30.327889
drwxr-xr-x  4 ray users 6144 Feb 18 01:51 .
drwxr-xr-x  2 ray users 6144 Feb 18 01:51 full_model
-rw-r--r--  1 ray users  334 Feb 18 01:52 checkpoint_manager_snapshot.json


## Step 8: Load Model for Inference

The consolidated model (`full-model.pt`) is a standard PyTorch checkpoint that works without FSDP2.

In [14]:
model = init_model(train_loop_config["hidden_dim"])
model_state_dict = torch.load("/mnt/cluster_storage/fsdp_mnist/full_model/full-model.pt", map_location='cpu')
model.load_state_dict(model_state_dict)

(autoscaler +6m28s) Tip: use `ray status` to view detailed cluster status. To disable these messages, set RAY_SCHEDULER_EVENTS=0.
(autoscaler +6m28s) Memory cgroup out of memory: Killed process 20108 (python) total-vm:15808636kB, anon-rss:9593532kB, file-rss:312064kB, shmem-rss:0kB, UID:1000 pgtables:20884kB oom_score_adj:-998


<All keys matched successfully>

Load some test data

In [15]:
transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
test_data = FashionMNIST(
    root=".", train=False, download=True, transform=transform
)
test_data

Dataset FashionMNIST
    Number of datapoints: 10000
    Root location: .
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5,), std=(0.5,))
           )

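Run inference on one test image. Note that the cell below reads raw pixel values from `test_data.data`, which bypasses the `ToTensor`/`Normalize` transform applied during training; indexing the dataset (`test_data[0]`) returns the normalized tensor instead.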

In [16]:
model.eval()
with torch.no_grad():
    out = model(test_data.data[0].reshape(1, 1, 28, 28).float())
    predicted_label = out.argmax().item()
    test_label = test_data.targets[0].item()
    print(f"{predicted_label=} {test_label=}")

predicted_label=4 test_label=9


## Summary



This tutorial covered:
1. **FSDP2 sharding** - Distributed model parameters across GPUs using `fully_shard()`
2. **Ray Train integration** - Multi-GPU training with automatic process group management
3. **PyTorch DCP** - Sharded checkpointing with automatic resharding on load
4. **Inference** - Loading consolidated model for single-GPU inference

**Next Steps:**
- Add CPU offloading: `CPUOffloadPolicy()` for memory-constrained scenarios
- Add mixed precision: `MixedPrecisionPolicy(param_dtype=torch.float16)`
- Try [DeepSpeed tutorial](./DeepSpeed_RayTrain_Tutorial.ipynb) for comparison

**Resources:**
- [PyTorch FSDP Tutorial](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
- [Ray Train Documentation](https://docs.ray.io/en/latest/train/getting-started-pytorch.html)