# Get Started with PyTorch FSDP2 and Ray Train (LIVE)

Streamlined notebook for ~1 hr hands-on workshops. Same ViT-on-FashionMNIST setup; compare with [DeepSpeed_RayTrain_Tutorial_LIVE.ipynb](./DeepSpeed_RayTrain_Tutorial_LIVE.ipynb) for the alternative backend.



This notebook demonstrates how to train large models using PyTorch's Fully Sharded Data Parallel (FSDP2) with Ray Train. FSDP2 enables model sharding across multiple GPUs, reducing memory footprint compared to standard DDP.

**Learning Objectives:**
1. Configure FSDP2 sharding for distributed training
2. Use PyTorch Distributed Checkpoint (DCP) for sharded model checkpointing
3. Load trained models for inference


This notebook walks you through a high-level overview of using FSDP2 with Ray Train.

<div class="alert alert-block alert-info">

Here is the roadmap for this notebook:

<ol>
  <li>What is FSDP?</li>
  <li>FSDP vs DDP simplified</li>
  <li>How to use FSDP (v2) with Ray Train?</li>
</ol>
</div>

## `Step 0`: What is FSDP2?




[FSDP2](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html) is PyTorch's native solution for training large models:

- Shards model parameters, gradients, and optimizer states across workers
- All-gathers parameters during forward pass, then re-shards
- Enables training models larger than single GPU memory

**When to use FSDP2:**
- Model exceeds single GPU memory
- You want native PyTorch integration
- Building custom training loops

### FSDP Workflow

Below is a diagram (taken and adapted from PyTorch) that shows the FSDP workflow:

<img src="https://anyscale-materials.s3.us-west-2.amazonaws.com/ray-train-deep-dive/FSDP.png" width="800">


Here is a table explaining the different phases, their steps and what happens:

<table>
  <thead>
    <tr>
      <th>Phase</th>
      <th>Step</th>
      <th>What Happens</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>Initialization</td>
      <td>Parameter sharding</td>
      <td>Each rank stores only its own shard of every parameter.</td>
    </tr>
    <tr>
      <td>Forward pass</td>
      <td>
        <ol>
          <li>all_gather</li>
          <li>Compute</li>
          <li>Free full weights</li>
        </ol>
      </td>
      <td>Ranks gather one another’s shards to reconstruct the full weights (parameters), execute the forward computation, then immediately free the gathered non-local shards.</td>
    </tr>
    <tr>
      <td>Backward pass</td>
      <td>
        <ol>
          <li>all_gather</li>
          <li>Back-propagate</li>
          <li>reduce_scatter</li>
          <li>Free full weights</li>
        </ol>
      </td>
      <td>Shards are re-gathered, gradients are computed, then reduced-and-scattered so each rank keeps only its own gradient shard; full replicas are discarded again.</td>
    </tr>
  </tbody>
</table>


### FSDP vs DDP simplified

Let's go over a toy example (inspired by [this guide on parallelism from Hugging Face](https://huggingface.co/docs/transformers/v4.13.0/en/parallelism)) comparing how FSDP and DDP operate.

### 1  Toy model

| **La** | **Lb** | **Lc** |
|:------:|:------:|:------:|
| a₀ | b₀ | c₀ |
| a₁ | b₁ | c₁ |
| a₂ | b₂ | c₂ |

Layer `La` contains the weights `[a₀, a₁, a₂]`

Total parameters = **9 scalars** (3 per layer).

---

### 2  Parameter layout at rest  

#### 2.1  DDP  (full replication)

| **GPU** | **La**           | **Lb**           | **Lc**           |
|---------|------------------|------------------|------------------|
| 0       | a₀ a₁ a₂         | b₀ b₁ b₂         | c₀ c₁ c₂         |
| 1       | a₀ a₁ a₂         | b₀ b₁ b₂         | c₀ c₁ c₂         |
| 2       | a₀ a₁ a₂         | b₀ b₁ b₂         | c₀ c₁ c₂         |

*Each worker stores **100 %** of the model (parameters + optimizer states + gradients).*

---

#### 2.2  FSDP  (parameter sharding)

| **GPU** | **La** | **Lb** | **Lc** |
|---------|:------:|:------:|:------:|
| 0       | a₀     | b₀     | c₀     |
| 1       | a₁     | b₁     | c₁     |
| 2       | a₂     | b₂     | c₂     |

*At rest each worker keeps only **1∕3** of every layer—so memory ≈ 33 % of DDP.*
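
The 33 % figure can be made concrete with a back-of-the-envelope estimate. The sketch below is only a rough model: it assumes float32 values and an Adam-style optimizer that keeps two extra states per parameter, and it ignores activations and temporary all-gather buffers. The helper name `per_gpu_bytes` is hypothetical.

```python
def per_gpu_bytes(num_params: int, num_gpus: int, sharded: bool,
                  bytes_per_value: int = 4,
                  optimizer_states_per_param: int = 2) -> int:
    """Rough at-rest memory for parameters, gradients, and optimizer states."""
    # One value for the parameter, one for its gradient, plus optimizer states
    # (assumption: an Adam-style optimizer keeps two states per parameter).
    values_per_param = 2 + optimizer_states_per_param
    total = num_params * values_per_param * bytes_per_value
    if not sharded:
        return total                    # DDP: every GPU holds everything
    return -(-total // num_gpus)        # FSDP: ceil(total / num_gpus)


# The 9-scalar toy model on 3 GPUs
toy_ddp = per_gpu_bytes(9, num_gpus=3, sharded=False)
toy_fsdp = per_gpu_bytes(9, num_gpus=3, sharded=True)
print(toy_ddp, toy_fsdp)  # 144 48
```

The same function applied to a hypothetical 1-billion-parameter model on 8 GPUs gives about 16 GB per GPU for DDP and about 2 GB per GPU for FSDP under these assumptions.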

---

### 3  Execution flow (GPU 0’s perspective)

#### 3.1  FSDP

<details>
<summary><strong>Forward pass (Layer La)</strong></summary>

1. **Need:** a₀ a₁ a₂  
2. **Has locally:** a₀  
3. **Gather:** a₁ from GPU 1, a₂ from GPU 2  
4. **Compute:** *y = La(x)* with full weights  
5. **Free:** a₁, a₂ (optional halfway-release)

Repeat for **Lb** then **Lc**.
</details>

<details>
<summary><strong>Backward pass (Layer Lc)</strong></summary>

1. **Need:** c₀ c₁ c₂  
2. **Has locally:** c₀  
3. **Gather:** c₁ from GPU 1, c₂ from GPU 2  
4. **Compute:** ∂L/∂c₀, ∂L/∂c₁, ∂L/∂c₂  
5. **Reduce-scatter:** average gradients; GPU k keeps only its own shard (cₖ)  
6. **Free:** c₁, c₂ (optional halfway-release)

Work upstream through **Lb → La** in reverse order.
</details>

<details>
<summary><strong>Optimizer step</strong></summary>

*Purely local.*  
GPU 0 updates a₀, b₀, c₀ using the averaged gradient shards already resident in memory. No extra communication.
</details>
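
The three phases above can be traced with a small simulation. This is a toy sketch that uses plain Python lists in place of tensors and does not call PyTorch; the helper names (`all_gather`, `reduce_scatter`, `toy_grad`, `fsdp_step`) and the toy loss are assumptions made for illustration.

```python
def all_gather(shards):
    """Every rank receives the concatenation of all shards."""
    full = [w for shard in shards for w in shard]
    return [list(full) for _ in shards]


def reduce_scatter(per_rank_grads):
    """Average the full gradients across ranks; rank k keeps only shard k."""
    world = len(per_rank_grads)
    n = len(per_rank_grads[0])
    avg = [sum(g[i] for g in per_rank_grads) / world for i in range(n)]
    size = n // world
    return [avg[k * size:(k + 1) * size] for k in range(world)]


def toy_grad(full_weights, x):
    # Toy loss: 0.5 * x * sum(w_i^2), so dL/dw_i = x * w_i
    return [x * w for w in full_weights]


def fsdp_step(shards, inputs, lr):
    gathered = all_gather(shards)                       # 1. gather full weights
    grads = [toy_grad(w, x) for w, x in zip(gathered, inputs)]  # 2. local backward
    grad_shards = reduce_scatter(grads)                 # 3. average + scatter
    # 4. purely local optimizer step (plain SGD) on each rank's own shard
    return [[w - lr * g for w, g in zip(shard, gshard)]
            for shard, gshard in zip(shards, grad_shards)]


shards = [[1.0], [2.0], [3.0]]   # layer La: GPU k owns a_k
inputs = [1.0, 2.0, 3.0]         # each GPU sees a different batch
new_shards = fsdp_step(shards, inputs, lr=0.25)
print(new_shards)  # [[0.5], [1.0], [1.5]]
```

Each rank ends up with the same value a DDP replica would hold for that weight, because averaging the full gradients and then slicing is equivalent to slicing the all-reduced gradient.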

---

#### 3.2  DDP  (for comparison)

| Phase             | What GPU 0 already owns         | Communication |
|-------------------|---------------------------------|---------------|
| **Forward**       | Full La, Lb, Lc                 | None          |
| **Backward**      | Full La, Lb, Lc + grads         | **All-reduce**|
| **Optimizer step**| Full params & full grads        | None          |
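
To compare the communication cost of the two approaches, the sketch below uses the standard bandwidth-only cost model for ring collectives (an assumption: latency and overlap with compute are ignored). Under that model an all-reduce moves about twice as much data per rank as an all-gather or a reduce-scatter. The function names are hypothetical.

```python
def ring_volume(size_bytes: float, world_size: int, collective: str) -> float:
    """Approximate bytes sent per rank for a ring-based collective."""
    factor = {"all_reduce": 2, "all_gather": 1, "reduce_scatter": 1}[collective]
    return factor * (world_size - 1) * size_bytes / world_size


def ddp_step_volume(param_bytes: float, world_size: int) -> float:
    # DDP: one all-reduce over the gradients in the backward pass
    return ring_volume(param_bytes, world_size, "all_reduce")


def fsdp_step_volume(param_bytes: float, world_size: int,
                     reshard_after_forward: bool = True) -> float:
    # FSDP: all-gather in forward, a second all-gather in backward if the
    # parameters were freed after forward, plus one reduce-scatter for gradients
    gathers = 2 if reshard_after_forward else 1
    return (gathers * ring_volume(param_bytes, world_size, "all_gather")
            + ring_volume(param_bytes, world_size, "reduce_scatter"))


toy_bytes = 9 * 4  # the 9-scalar toy model in float32
print(ddp_step_volume(toy_bytes, 3))          # 48.0
print(fsdp_step_volume(toy_bytes, 3))         # 72.0
print(fsdp_step_volume(toy_bytes, 3, False))  # 48.0
```

Under this model, FSDP with resharding after forward sends roughly 1.5x the bytes of DDP, and matches DDP when the gathered parameters are kept for the backward pass.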

---

### 4  Key take-aways

|                         | **DDP**                           | **FSDP**                               |
|-------------------------|-----------------------------------|----------------------------------------|
| **Memory footprint**    | Replicated (× #GPUs)              | 1∕#GPUs at rest; slightly higher during gather |
| **Communication**       | All-reduce once per layer (backward) | Gather + reduce-scatter per layer      |
| **Optimizer states**    | Fully replicated                  | Sharded—1∕#GPUs memory|
| **Implementation ease** | Very simple                       | More knobs (wrapping policy, offloading, prefetch) |
| **When to prefer**      | Fits in memory; small models      | Large models that would otherwise OOM  |

> **Rule of thumb:**  
> *If you can replicate the model comfortably, DDP wins on simplicity and sometimes speed.  
> If you’re out of memory or pushing model scale, FSDP (or ZeRO-style sharding) is the way forward.*

---

### 5  Why we wrapped each layer separately

For illustration we treated **“one layer = one FSDP unit.”**  
In practice you can:

* **Cluster layers** into larger FSDP units to reduce the number of gather calls.
* **Wrap only the largest sub-modules** (e.g., big embeddings, attention blocks) and leave tiny layers unsharded.

Experiment with *auto-wrap policies* to find the sweet spot between memory savings and communication overhead.
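
The trade-off between memory savings and communication overhead can be sketched numerically. The example below assumes that only one FSDP unit is unsharded at a time (no prefetching) and counts parameters rather than bytes. `fsdp_unit_stats` and the layer sizes are hypothetical.

```python
def fsdp_unit_stats(layer_params, group_size, world_size):
    """Group consecutive layers into FSDP units and summarize the trade-off."""
    groups = [layer_params[i:i + group_size]
              for i in range(0, len(layer_params), group_size)]
    unit_sizes = [sum(group) for group in groups]
    return {
        # One all-gather per unit in the forward pass
        "all_gathers_per_forward": len(unit_sizes),
        # The largest unit determines how many full parameters are live at once
        "peak_unsharded_params": max(unit_sizes),
        # At-rest share per GPU is independent of the grouping
        "sharded_params_per_gpu": -(-sum(layer_params) // world_size),
    }


layers = [1_000_000] * 12  # 12 equally sized blocks (illustrative numbers)
for group_size in (1, 4, 12):
    print(group_size, fsdp_unit_stats(layers, group_size, world_size=4))
```

Larger units mean fewer collectives but a higher peak of unsharded parameters; the at-rest footprint stays the same.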


### When to Consider FSDP

- **Model no longer fits** on a single GPU even with mixed precision.  
- **Batch size is GPU memory-bound** under classic DDP.  
- **You have multiple GPUs** with sufficient interconnect bandwidth.  
- **You already use DDP** but need to push to larger architectures (e.g., multi-billion-parameter transformers).  
- **You want minimal code changes**: apply `torch.distributed.fsdp.fully_shard` (FSDP2) to the layers you want to shard.  

FSDP lets you step beyond the memory limits of traditional data parallelism while keeping your training loop largely unchanged.

### What is FSDP?

Fully Sharded Data Parallel (FSDP) is a parallelism method that combines the advantages of data and model parallelism for distributed training.

## `Step 1`: Environment Setup

Check Ray cluster status and install dependencies.

<div style="display: flex; justify-content: center;">
  <img src="./images/img_1.png" alt="FSDP Diagram" style="width: 80%;" />
</div>

In [1]:
# Check Ray cluster status
!ray status

Node status
---------------------------------------------------------------
Active:
 (no active nodes)
Idle:
 1 head
 1 1xL4:16CPU-64GB-1
 1 1xL4:16CPU-64GB-2
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Total Usage:
 0.0/32.0 CPU
 0.0/2.0 GPU
 0.0/2.0 anyscale/accelerator_shape:1xL4
 0.0/1.0 anyscale/cpu_only:true
 0.0/1.0 anyscale/node-group:1xL4:16CPU-64GB-1
 0.0/1.0 anyscale/node-group:1xL4:16CPU-64GB-2
 0.0/1.0 anyscale/node-group:head
 0.0/3.0 anyscale/provider:aws
 0.0/3.0 anyscale/region:us-west-2
 0B/160.00GiB memory
 0B/44.60GiB object_store_memory

From request_resources:
 (none)
Pending Demands:
 (no resource demands)

In [2]:
# Stdlib imports
import os
import tempfile

# Ray Train imports
import ray
import ray.train
import ray.train.torch

# PyTorch core and FSDP2 imports
import torch
from torch.distributed.fsdp import (
    fully_shard,
    FSDPModule,
    CPUOffloadPolicy,
    MixedPrecisionPolicy,
)

# PyTorch Distributed Checkpoint (DCP) imports
from torch.distributed.checkpoint.state_dict import (
    get_state_dict,
    set_state_dict,
    get_model_state_dict,
    StateDictOptions
)
from torch.distributed.device_mesh import init_device_mesh 
from torch.distributed.checkpoint.stateful import Stateful
import torch.distributed.checkpoint as dcp

# PyTorch training components
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader

# Computer vision components
from torchvision.models import VisionTransformer
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

## `Step 2`: Model Definition

We use a Vision Transformer (ViT) whose repeated encoder blocks make it a good fit for demonstrating FSDP2's per-block sharding.

Let's go over a simple example of how to use FSDP with Ray Train and PyTorch.

<div style="display: flex; justify-content: center;">
  <img src="./images/img_2.png" alt="Train Diagram" style="width: 80%;" />
</div>

Below is a sample training function that we will use to train our model.

In [None]:
def train_func(config):
    """
    Training function executed by EACH worker (GPU).
    Ray Train spawns N copies of this function — one per GPU.
    All workers run the same code but process different data batches.
    """

    # ==========================================================
    # Step 1 : Initialize the model (on CPU, not sharded yet)
    # ==========================================================
    model = init_model(config["hidden_dim"])

    # ==========================================================
    # Step 2 : Move the model to this worker's assigned GPU
    # ==========================================================
    device = ray.train.torch.get_device()
    torch.cuda.set_device(device)
    model.to(device)

    # ==========================================================
    # Step 3 : Apply FSDP2 sharding (now each worker owns only a slice of parameters)
    #          Must come AFTER model is on GPU!
    # ==========================================================
    prepare_model(
        model,
        skip_model_shard=config["skip_model_shard"],
        skip_cpu_offload=config["skip_cpu_offload"],
        use_float16=config["use_float16"],
    )

    # ==========================================================
    # Step 4 : Define loss function and optimizer (sharded states under FSDP)
    # ==========================================================
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=config.get("learning_rate", 0.001))

    # ==========================================================
    # Step 5 : Resume from checkpoint if available (sharded, supports flexible recovery)
    # ==========================================================
    start_epoch = 0
    loaded_checkpoint = ray.train.get_checkpoint()
    if loaded_checkpoint:
        start_epoch = load_fsdp_checkpoint(model, optimizer, loaded_checkpoint)

    # ==========================================================
    # Step 6 : Prepare the distributed FashionMNIST data loader
    #          Each worker gets a different shard automatically
    # ==========================================================
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
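    # NOTE: train=False loads the smaller 10,000-image test split;
    # use train=True for the 60,000-image training split.
    train_data = FashionMNIST(
        root=data_dir, train=False, download=True, transform=transform
    )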
    train_loader = DataLoader(
        train_data,
        batch_size=config.get("batch_size", 128),
        shuffle=True,
        num_workers=2,
    )
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    world_rank = ray.train.get_context().get_world_rank()

    # ==========================================================
    # Step 7 : Main training loop (standard PyTorch, FSDP handles sharding transparently)
    # ==========================================================
    epochs = config["epochs"]

    for epoch in range(start_epoch, epochs):
        # Reset the running statistics so the reported loss is a per-epoch average
        running_loss = 0.0
        num_batches = 0

        # ==========================================================
        # Ensure proper random shuffling by epoch for DistributedSampler
        # ==========================================================
        if ray.train.get_context().get_world_size() > 1:
            train_loader.sampler.set_epoch(epoch)

        for images, labels in train_loader:

            # ==========================================================
            # Forward pass: model(images) triggers all-gather and computation
            # ==========================================================
            outputs = model(images)
            loss = criterion(outputs, labels)

            # ==========================================================
            # Backward pass: loss.backward() handles sharded gradients
            # Local optimizer step updates only this shard
            # ==========================================================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # ==========================================================
        # Step 8 : Save and report metrics and checkpoint after every epoch
        #          Checkpointing is sharded—fast and memory-efficient
        # ==========================================================
        avg_loss = running_loss / num_batches
        metrics = {"loss": avg_loss, "epoch": epoch + 1}
        report_metrics_and_save_fsdp_checkpoint(model, optimizer, metrics, epoch + 1)

        # ==========================================================
        # Only rank 0 logs metrics to prevent duplicate outputs
        # ==========================================================
        if world_rank == 0:
            print(metrics)

    # ==========================================================
    # Step 9 : Save the final model for inference ("all_gather" full state onto rank 0)
    # ==========================================================
    save_model_for_inference(model, world_rank)
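
The call to `train_loader.sampler.set_epoch(epoch)` in the training loop matters because a distributed sampler derives its shuffle from the seed plus the epoch. The sketch below is a simplified stand-in written with the standard library, not the PyTorch `DistributedSampler` itself; `shard_indices` is a hypothetical helper.

```python
import random


def shard_indices(dataset_len, world_size, rank, epoch, seed=0):
    """Return the sample indices one rank would process in a given epoch."""
    indices = list(range(dataset_len))
    # Every rank builds the SAME permutation because the seed is shared
    random.Random(seed + epoch).shuffle(indices)
    # Pad with leading indices so the list divides evenly across ranks
    pad = (-len(indices)) % world_size
    indices += indices[:pad]
    # Rank r takes every world_size-th index starting at r
    return indices[rank::world_size]


for epoch in range(2):
    for rank in range(2):
        print(epoch, rank, shard_indices(10, world_size=2, rank=rank, epoch=epoch))
```

Within an epoch the ranks receive disjoint slices of one shared permutation. Without the epoch in the seed, every epoch would replay the same order.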


### Initialize the model
Initialize a Vision Transformer model for FashionMNIST classification

In [4]:
def init_model(hidden_dim) -> torch.nn.Module:
    # Create a ViT model with architecture suitable for 28x28 images
    model = VisionTransformer(
        image_size=28,         # FashionMNIST image size
        patch_size=7,          # Divide 28x28 into 4x4 patches of 7x7 pixels each
        num_layers=12,         # Number of transformer encoder layers
        num_heads=8,           # Number of attention heads per layer
        hidden_dim=hidden_dim, # Hidden dimension size
        mlp_dim=768,           # MLP dimension in transformer blocks
        num_classes=10,        # FashionMNIST has 10 classes
    )

    # Modify the patch embedding layer for grayscale images (1 channel instead of 3)
    model.conv_proj = torch.nn.Conv2d(
        in_channels=1,            # FashionMNIST is grayscale (1 channel)
        out_channels=hidden_dim,  # Must match the hidden_dim
        kernel_size=7,            # Match patch_size
        stride=7,                 # Non-overlapping patches
    )

    return model
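
The patch arithmetic behind these settings can be checked by hand. The sketch below assumes one class token is prepended to the patch sequence and uses `hidden_dim=128` purely as an example value (the notebook reads it from `config`). `vit_patch_stats` is a hypothetical helper.

```python
def vit_patch_stats(image_size, patch_size, in_channels, hidden_dim):
    """Sequence length and patch-embedding parameter count for a ViT."""
    assert image_size % patch_size == 0, "image must divide evenly into patches"
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    return {
        "num_patches": num_patches,
        "seq_len": num_patches + 1,  # patches plus one class token
        # Conv2d weights (out * in * k * k) plus one bias per output channel
        "conv_proj_params": hidden_dim * in_channels * patch_size ** 2 + hidden_dim,
    }


grayscale = vit_patch_stats(28, 7, in_channels=1, hidden_dim=128)
rgb = vit_patch_stats(28, 7, in_channels=3, hidden_dim=128)
print(grayscale)  # {'num_patches': 16, 'seq_len': 17, 'conv_proj_params': 6400}
print(rgb["conv_proj_params"])  # 18944
```

Replacing the 3-channel projection with a 1-channel one therefore shrinks the patch embedding to roughly a third of its size.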

## `Step 3`: Prepare the model and FSDP2 Sharding Configuration

To prepare the model, we use PyTorch's `fully_shard` (FSDP2) function in three steps:
1. Create a device mesh
2. Apply `fully_shard` to each block we want to shard
3. Apply `fully_shard` to the model itself

> ```
> ViT Model
>   ├── Patch Embedding      ← wrapped by root fully_shard(model)
>   ├── Encoder Block 1      ← fully_shard(block)
>   ├── Encoder Block 2      ← fully_shard(block)
>   ├── ...                  
>   ├── Encoder Block 12     ← fully_shard(block)
>   └── Classification Head  ← wrapped by root fully_shard(model)
> ```

In [6]:
def prepare_model(
    model: torch.nn.Module,
    skip_model_shard: bool,
    skip_cpu_offload: bool,
    use_float16: bool,
):
    """
    Apply FSDP2 sharding to the model.

    This function performs the following:
      1. Creates a device mesh (sets up GPU topology)
      2. Shards each encoder block individually
      3. Shards the root model (top-level)
    
    After running, each GPU keeps only a fraction (1/N) of the total parameters.
    """

    # =================================================
    # Step 1: Device Mesh Creation
    # =================================================
    #
    # The device mesh defines how GPUs are organized for distributed training.
    # Here, we set up a 1D mesh for pure data parallelism: all GPUs lined up along one dimension.
    #
    world_size = ray.train.get_context().get_world_size()
    mesh = init_device_mesh(
        device_type="cuda",                # Use CUDA-enabled GPUs
        mesh_shape=(world_size,),          # 1D mesh: a flat group of GPUs
        mesh_dim_names=("data_parallel",)  # Naming the dimension for clarity
    )

    # =================================================
    # Step 2: CPU Offloading Configuration (Optional)
    # =================================================
    #
    # CPU offloading will transfer parameter shards that aren't actively in use
    # from GPU VRAM to CPU RAM, saving GPU memory at the expense of a bit more data transfer time.
    # Useful if GPU memory is a tight resource.
    #
    offload_policy = CPUOffloadPolicy() if not skip_cpu_offload else None

    # =================================================
    # Step 3: Mixed Precision Configuration (Optional)
    # =================================================
    #
    # Mixed precision runs forward/backward compute on float16 (half precision) copies of
    # the parameters, reducing memory and communication and using GPU tensor cores.
    # float16 works on most GPUs; on hardware with bfloat16 support (e.g. A100/H100),
    # bfloat16 is usually the safer choice because of its wider dynamic range.
    #
    mp_policy_float16 = MixedPrecisionPolicy(
        param_dtype=torch.float16,    # Unsharded parameters are cast to float16 for compute
        reduce_dtype=torch.float16,   # Reduce the gradients in float16 as well
    )
    default_policy = MixedPrecisionPolicy(
        param_dtype=None,             # Keep original precision (float32)
        reduce_dtype=None,
        output_dtype=None,
        cast_forward_inputs=True      # Inputs will match parameter dtype
    )
    mp_policy = mp_policy_float16 if use_float16 else default_policy

    # =================================================
    # Step 4: Sharding Each Encoder Block Separately
    # =================================================
    #
    # Each transformer encoder block becomes its own FSDP unit.
    # This means that during the forward pass, only a single block's parameters are gathered onto GPU at a time,
    # dramatically reducing peak memory usage.
    #
    # reshard_after_forward:
    #   - True  → Free parameters after forward (memory savings, more comms)
    #   - False → Keep parameters in memory for backward (faster, but uses more memory)
    # skip_model_shard=True is equivalent to reshard_after_forward=False
    #
    for encoder_block in model.encoder.layers.children():
        fully_shard(
            encoder_block,
            mesh=mesh,
            reshard_after_forward=not skip_model_shard,
            offload_policy=offload_policy,
            mp_policy=mp_policy,
        )

    # =================================================
    # Step 5: Sharding the Root Model
    # =================================================
    #
    # This wraps the remaining parts of the model—such as patch embedding,
    # positional encoding, and the classification head—with FSDP as well.
    # Top-level sharding is essential for FSDP2's coordination and memory efficiency.
    #
    fully_shard(
        model,
        mesh=mesh,
        reshard_after_forward=not skip_model_shard,
        offload_policy=offload_policy,
        mp_policy=mp_policy,
    )

Here is a table of the main keyword arguments for `fully_shard`:

| Parameter | What it controls | Typical values & when to use |
|-----------|-----------------|------------------------------|
| `reshard_after_forward` | Whether to free (reshard) the gathered parameters right after the forward pass | **`True`** – free parameters after forward and re-gather them in backward; maximum memory savings at the cost of an extra all-gather. This is the usual behavior for non-root modules when the argument is left unset.<br>**`False`** – keep the gathered parameters in memory through the backward pass; faster but uses more memory. |
| `offload_policy` | Whether inactive parameter shards are moved to CPU RAM | **No offloading** (default; this notebook passes `None`) – fastest, keeps all data on GPU.<br>**`CPUOffloadPolicy()`** – offloads parameters to CPU RAM; frees GPU memory at the cost of extra PCIe traffic; enables larger models on GPUs with limited memory. |
| `mp_policy` | Controls mixed precision settings for parameters and gradients | **`MixedPrecisionPolicy()`** (default) – no casting; uses the model's native precision (typically float32).<br>**`MixedPrecisionPolicy(param_dtype=torch.float16)`** – compute with half-precision parameters for memory savings and to leverage tensor cores.<br>**`MixedPrecisionPolicy(param_dtype=torch.bfloat16)`** – use bfloat16 on newer architectures (A100, H100) for better numerical stability with tensor core acceleration.<br>**`MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)`** – commonly used configuration; bfloat16 parameters with float32 gradient reduction for better numerical stability during reduce-scatter operations. |
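
The difference between float16 and bfloat16 can be illustrated without a GPU. The sketch below uses the standard library `struct` module: format `"e"` is IEEE half precision, and bfloat16 is emulated by keeping only the top 16 bits of the float32 encoding. Truncation instead of round-to-nearest is a simplification, and both helper names are hypothetical.

```python
import struct


def to_float16(x: float) -> float:
    """Round-trip x through IEEE half precision (struct format 'e')."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses values beyond the float16 range; report them as inf
        return float("inf") if x > 0 else float("-inf")


def to_bfloat16(x: float) -> float:
    """Emulate bfloat16 by truncating the float32 bit pattern to 16 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


# Range: float16 overflows and underflows where bfloat16 does not
print(to_float16(70000.0), to_bfloat16(70000.0))  # inf 69632.0
print(to_float16(1e-8), to_bfloat16(1e-8) > 0)    # 0.0 True

# Precision: float16 keeps more mantissa bits than bfloat16
print(to_float16(1.001), to_bfloat16(1.001))      # 1.0009765625 1.0
```

bfloat16 trades mantissa precision for the same exponent range as float32, which is why large or tiny gradient values survive in bfloat16 but not in float16.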


## `Step 4`: Distributed Checkpointing with PyTorch

`torch.distributed.checkpoint` (DCP) enables saving and loading models from multiple ranks in parallel. You can use this module to save on any number of ranks in parallel, and then re-shard across differing cluster topologies at load time.

PyTorch Distributed Checkpoint (DCP) provides efficient checkpointing for sharded models:
- Each worker saves only its shard (parallel I/O)
- Automatic resharding on load if worker count changes
- Full optimizer state support for training resumption
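
The first two points can be sketched with a toy model of sharded save and load. This is not how DCP is implemented: a dictionary stands in for the checkpoint files, and the loader rebuilds the full list for simplicity, which a real sharded loader avoids. All function names are hypothetical.

```python
def shard(values, world_size):
    """Split a flat list into contiguous shards tagged with global offsets."""
    size = -(-len(values) // world_size)  # ceil division
    return [{"offset": r * size, "data": values[r * size:(r + 1) * size]}
            for r in range(world_size)]


def toy_save(shards):
    # Each worker writes only its own shard; one dict entry stands in for one file
    return {f"rank{r}.bin": s for r, s in enumerate(shards)}


def toy_load(checkpoint, total_len, new_world_size):
    # The offsets make it possible to place each saved piece correctly,
    # regardless of how many workers wrote the checkpoint
    full = [None] * total_len
    for piece in checkpoint.values():
        start = piece["offset"]
        full[start:start + len(piece["data"])] = piece["data"]
    return shard(full, new_world_size)


params = list(range(9))                      # the 9-scalar toy model
ckpt = toy_save(shard(params, world_size=3))  # saved from 3 workers
resharded = toy_load(ckpt, total_len=9, new_world_size=2)
print([s["data"] for s in resharded])  # [[0, 1, 2, 3, 4], [5, 6, 7, 8]]
```

Because every saved piece records where it belongs in the global tensor, the number of workers at load time does not have to match the number at save time.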


### Defining a Stateful object

We use a `Stateful` object to handle calling the distributed state dict methods on the model and optimizer.

This is a useful wrapper for checkpointing the application state. Because the object complies with the `Stateful` protocol, PyTorch DCP automatically calls `state_dict`/`load_state_dict` as needed in the `dcp.save`/`dcp.load` APIs.
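
The idea behind the protocol can be shown with a minimal sketch. The `toy_save` and `toy_load` functions below are hypothetical JSON-based stand-ins for `dcp.save` and `dcp.load`; they only mimic the pattern of delegating to `state_dict()` and `load_state_dict()` and do nothing distributed.

```python
import json
import os
import tempfile


class Progress:
    """Any object exposing state_dict() / load_state_dict() fits the pattern."""

    def __init__(self, epoch=0):
        self.epoch = epoch

    def state_dict(self):
        return {"epoch": self.epoch}

    def load_state_dict(self, state):
        self.epoch = state["epoch"]


def toy_save(state, path):
    # Ask stateful values for their state before writing; pass others through
    plain = {k: v.state_dict() if hasattr(v, "state_dict") else v
             for k, v in state.items()}
    with open(path, "w") as f:
        json.dump(plain, f)


def toy_load(state, path):
    with open(path) as f:
        plain = json.load(f)
    for key, value in state.items():
        if hasattr(value, "load_state_dict"):
            value.load_state_dict(plain[key])  # restore in place
        else:
            state[key] = plain[key]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ckpt.json")
    toy_save({"app": Progress(epoch=3)}, path)
    restored = Progress()
    toy_load({"app": restored}, path)

print(restored.epoch)  # 3
```

The caller passes the same dictionary shape to save and load, and the object restores itself in place, which is the same calling pattern used with `AppState` in this notebook.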

In [None]:

class AppState(Stateful):
    """
    AppState is a checkpointable wrapper for model, optimizer, and epoch state.

    This class implements the `Stateful` protocol and provides:
      - `state_dict()`: Gathers the full checkpoint state using FSDP2-aware utilities.
      - `load_state_dict()`: Restores both model and optimizer state, handling possible resharding and distributed setups.

    Purpose:
      - Makes distributed checkpointing simple and robust.
      - Handles sharded model state via FSDP2 Fully Qualified Names (FQNs).
      - epoch is also checkpointed for resume support.

    FSDP2 NOTE:
        - Use `get_state_dict` and `set_state_dict` from `torch.distributed.checkpoint.state_dict`.
        - Never call model.state_dict() directly with FSDP2!
        - State dict keys are fully qualified names (FQNs) such as
          'encoder.layers.encoder_layer_0.self_attention.in_proj_weight'.
    """

    def __init__(self, model, optimizer=None, epoch=0):
        self.model = model
        self.optimizer = optimizer
        self.epoch = epoch  # Track training progress for resume

    def state_dict(self):
        # ============================================================
        # Step 1: Gather and return complete checkpoint state
        # ============================================================
        #
        # FSDP2: Use `get_state_dict()` instead of model.state_dict():
        #    - Collects sharded state dict with global param names (FQNs)
        #    - Returns both model and optimizer state
        #
        # CAUTION: Only use `get_state_dict()` with FSDP2!
        #
        model_state_dict, optimizer_state_dict = get_state_dict(
            self.model, self.optimizer
        )
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
            "epoch": self.epoch,  # Allows checkpointing epoch for resume
        }

    def load_state_dict(self, state_dict):
        # ============================================================
        # Step 2: Restore state from checkpoint dictionary
        # ============================================================
        #
        # FSDP2: Use `set_state_dict()` for distributed, sharded recovery:
        #    - Handles model, optimizer shards, and resharding if world_size changes
        #    - Automatically maps FQNs
        #
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )
        self.epoch = state_dict["epoch"]  # Ensure resume from checkpoint is correct
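
`AppState` exposes exactly two methods because the checkpointing layer only needs a way to ask an object for its state and a way to hand that state back. The toy sketch below imitates that round trip in plain Python, with JSON standing in for DCP and a dictionary standing in for model weights. `ToyAppState`, `toy_save`, and `toy_load` are hypothetical names used for illustration only; they are not part of PyTorch or Ray.

```python
import json
import os
import tempfile


class ToyAppState:
    """Plain-Python stand-in for AppState: holds 'weights' and an epoch counter."""

    def __init__(self, weights, epoch=0):
        self.weights = weights
        self.epoch = epoch

    def state_dict(self):
        # Called at save time: return everything needed to resume.
        return {"model": dict(self.weights), "epoch": self.epoch}

    def load_state_dict(self, state_dict):
        # Called at load time: restore in place from the saved dictionary.
        self.weights = dict(state_dict["model"])
        self.epoch = state_dict["epoch"]


def toy_save(app, directory):
    # Hypothetical saver: asks the object for its state, then persists it.
    with open(os.path.join(directory, "state.json"), "w") as f:
        json.dump(app.state_dict(), f)


def toy_load(app, directory):
    # Hypothetical loader: reads the state and hands it back to the object.
    with open(os.path.join(directory, "state.json")) as f:
        app.load_state_dict(json.load(f))


with tempfile.TemporaryDirectory() as ckpt_dir:
    trained = ToyAppState({"layer.weight": [0.1, 0.2]}, epoch=3)
    toy_save(trained, ckpt_dir)

    fresh = ToyAppState({"layer.weight": [0.0, 0.0]})  # epoch defaults to 0
    toy_load(fresh, ckpt_dir)

print(fresh.epoch, fresh.weights)
```

Because the epoch travels with the weights, a resumed run can pick up at the right epoch instead of restarting from zero.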



### Saving Distributed Checkpoints when training

This function performs two critical operations:
1. Saves the current model and optimizer state using distributed checkpointing
2. Reports metrics to Ray Train for tracking

In [None]:
def report_metrics_and_save_fsdp_checkpoint(
    model: FSDPModule,
    optimizer: torch.optim.Optimizer,
    metrics: dict,
    epoch: int,
) -> None:
    """
    Save distributed FSDP model & optimizer state, and report metrics to Ray Train.

    This function:
      - Saves a *sharded* model & optimizer checkpoint using DCP (across all workers)
      - Reports metrics and checkpoint to Ray Train for experiment tracking
    """
    # ========================================================================
    # Step 1: Create a temporary directory for the distributed checkpoint
    # ========================================================================
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:

        # =====================================================================
        # Step 2: Compose app state for FSDP checkpointing
        # (Combines model, optimizer, and epoch into a single checkpointed object)
        # =====================================================================
        state_dict = {"app": AppState(model, optimizer, epoch)}

        # =====================================================================
        # Step 3: Save the distributed sharded checkpoint using DCP
        # (Ensures model & optimizer shards are persisted on all workers)
        # =====================================================================
        dcp.save(
            state_dict=state_dict,
            checkpoint_id=temp_checkpoint_dir,
        )

        # =====================================================================
        # Step 4: Inform Ray Train about this checkpoint and current metrics
        #         - checkpoint: wraps this worker's local directory; Ray Train
        #           uploads each worker's shard files to shared storage
        #         - metrics: dictionary with current validation/train metrics
        # =====================================================================
        checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
        ray.train.report(
            metrics,              # experiment metrics to track
            checkpoint=checkpoint # distributed checkpoint to save
        )

### Saving a final and full checkpoint for inference

For inference, we want to save an unsharded copy of the model. This is an expensive operation because all model parameters must be gathered from every rank onto a single rank.

This function consolidates the distributed model weights into a single checkpoint file that can be used for inference without FSDP.

In [None]:
def save_model_for_inference(
    model: FSDPModule,
    world_rank: int,
) -> None:
    """
    Consolidate distributed FSDP model weights into a single checkpoint
    file for inference, gathering all parameters onto rank 0.

    Args:
        model (FSDPModule): The FSDP-wrapped model.
        world_rank (int): Current process rank.
    """
    # ========================================================================
    # Step 1: Create a temporary directory to store the unsharded full model
    # ========================================================================
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        save_file = os.path.join(temp_checkpoint_dir, "full-model.pt")

        # =====================================================================
        # Step 2: Gather full model weights (from all shards) onto rank 0
        #         - full_state_dict=True gathers all weights to rank 0
        #         - cpu_offload=True moves tensors to CPU to reduce GPU memory usage
        # =====================================================================
        model_state_dict = get_model_state_dict(
            model=model,
            options=StateDictOptions(
                full_state_dict=True,   # Gather all model weights
                cpu_offload=True        # Offload to CPU for serialization
            ),
        )

        checkpoint = None

        # =====================================================================
        # Step 3: On rank 0, save the full model, and register a Ray checkpoint
        # =====================================================================
        if world_rank == 0:
            torch.save(model_state_dict, save_file)
            # Prepare checkpoint directory for Ray Train
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)

        # =====================================================================
        # Step 4: Report the inference checkpoint to Ray Train on all ranks
        #         - Only rank 0 will have a non-None checkpoint object
        # =====================================================================
        ray.train.report(
            metrics={},
            checkpoint=checkpoint,
            checkpoint_dir_name="full_model",
        )

<div class="alert alert-info">

**NOTE:** In PyTorch, if both `cpu_offload` and `full_state_dict` are set to `True`, only rank 0 receives the full state dict; every other rank receives an empty one.

</div>
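
The behaviour in the note explains why `save_model_for_inference` guards `torch.save` with `if world_rank == 0`. The sketch below is a plain-Python toy model of that behaviour, not the PyTorch implementation: `toy_get_full_state_dict` is a hypothetical helper, and lists stand in for tensor shards. In the real API every rank still makes the call, since the gather is a collective operation.

```python
def toy_get_full_state_dict(shards_by_rank, rank):
    """Toy model of a full, CPU-offloaded gather: only rank 0 gets the result."""
    if rank != 0:
        return {}  # other ranks take part in the gather but keep nothing
    full = {}
    for name in shards_by_rank[0]:
        # Concatenate each parameter's shards in rank order.
        full[name] = [x for shard in shards_by_rank for x in shard[name]]
    return full


shards_by_rank = [
    {"head.weight": [1.0, 2.0]},  # shard held by rank 0
    {"head.weight": [3.0, 4.0]},  # shard held by rank 1
]

results = {}
for rank in range(len(shards_by_rank)):
    state = toy_get_full_state_dict(shards_by_rank, rank)
    results[rank] = state
    # Only a rank holding the full weights should write the inference file.
    print(f"rank {rank}: {len(state)} parameter(s), saves file: {rank == 0}")
```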

### Loading distributed checkpoints

This function handles distributed checkpoint loading with automatic resharding support. It can restore checkpoints even when the number of workers differs from the original training run.

In [None]:
# ============================================================
# load_fsdp_checkpoint: Load a Distributed Checkpoint w/ Resharding
# ============================================================
# Loads a distributed checkpoint using torch DCP, restoring model and optimizer state.
# Also handles resharding if loading occurs with a different world size.
def load_fsdp_checkpoint(
    model: FSDPModule,
    optimizer: torch.optim.Optimizer,
    ckpt: ray.train.Checkpoint,
) -> int:
    try:
        with ckpt.as_directory() as checkpoint_dir:
            # Wrap the model and optimizer for DCP loading
            app_state = AppState(model, optimizer)
            state_dict = {"app": app_state}

            # Load the distributed checkpoint (reshard if necessary)
            dcp.load(
                state_dict=state_dict,
                checkpoint_id=checkpoint_dir,
            )

            # Return epoch from updated app_state
            return app_state.epoch
    except Exception as e:
        raise RuntimeError(f"Checkpoint loading failed: {e}") from e
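
To make "resharding" concrete, the toy sketch below splits a flat list of values across ranks, then redistributes the same values for a different world size. It is a plain-Python illustration under simplifying assumptions: `shard` and `reshard` are hypothetical helpers, and the real loader works from checkpoint metadata rather than rebuilding the whole tensor in one place.

```python
def shard(values, world_size):
    """Split a flat list into world_size contiguous chunks (later ones may be shorter)."""
    chunk = -(-len(values) // world_size)  # ceiling division
    return [values[i * chunk:(i + 1) * chunk] for i in range(world_size)]


def reshard(shards, new_world_size):
    """Redistribute previously saved shards for a different number of ranks."""
    full = [x for s in shards for x in s]  # flattened only for illustration
    return shard(full, new_world_size)


saved = shard(list(range(8)), world_size=2)  # checkpoint written by 2 workers
loaded_on_4 = reshard(saved, new_world_size=4)
loaded_on_3 = reshard(saved, new_world_size=3)

print(saved)
print(loaded_on_4)
print(loaded_on_3)
```

The values are the same in every layout; only the ownership per rank changes, which is why a checkpoint saved with 2 workers can be restored with a different count.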


## `Step 5`: Final Training Configuration 

The training function runs on each worker:
1. Initialize and shard model with FSDP2
2. Run training loop with distributed data loading
3. Save checkpoints using PyTorch DCP

Configure scaling and resource requirements.

In [12]:
# ============================================================
# ScalingConfig: HOW to run (infrastructure)
# ============================================================
# This tells Ray Train:
#   - Spawn 2 workers (one per GPU)
#   - Each worker needs one GPU
# Changing num_workers changes how many GPUs participate in FSDP.
# More workers = each GPU holds less of the model = lower memory per GPU.

scaling_config = ray.train.ScalingConfig(
    num_workers=2,           # Number of GPU workers (= number of FSDP shards)
    use_gpu=True,            # Each worker gets one GPU
)
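
The claim that more workers means less memory per GPU can be checked with back-of-the-envelope arithmetic. The sketch below is a rough estimate only, assuming fp32 values, an Adam-style optimizer with two state values per parameter, and perfectly even sharding. It ignores activations and temporary all-gather buffers, which can dominate in practice. `model_state_gib_per_gpu` is a hypothetical helper, and the 180M figure is the approximate parameter count quoted in this notebook.

```python
def model_state_gib_per_gpu(
    num_params,
    num_workers,
    bytes_per_value=4,             # fp32
    optimizer_values_per_param=2,  # Adam keeps two moment estimates per parameter
):
    """Rough per-GPU size of weights + gradients + optimizer state when fully sharded."""
    values_per_param = 1 + 1 + optimizer_values_per_param  # weight + gradient + optimizer
    total_bytes = num_params * values_per_param * bytes_per_value
    return total_bytes / num_workers / 2**30


for workers in (1, 2, 4):
    gib = model_state_gib_per_gpu(180_000_000, workers)
    print(f"{workers} worker(s): ~{gib:.2f} GiB of model state per GPU")
```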

Define the training hyperparameters and FSDP options.

In [13]:
# ============================================================
# train_loop_config: WHAT to run (hyperparameters)
# ============================================================
# Walk through each parameter:
#   - epochs=1: Just 1 epoch for demo speed. Try 3+ for better accuracy.
#   - hidden_dim=3840: Intentionally large to need FSDP. The model will
#     be ~180M parameters, which is tight for a single L4 (24GB).
#
# FSDP KNOBS (the interesting part!):
#   - skip_model_shard=True  → reshard_after_forward=False (keep params)
#   - skip_cpu_offload=True  → No CPU offloading
#   - use_float16=False      → Full fp32 precision
#
# EXERCISE: After the first run, try changing these:
#   1. skip_model_shard=False → enables resharding (lower memory)
#   2. skip_cpu_offload=False → enables CPU offload (even lower memory)
#   3. use_float16=True       → half precision (half the memory, faster)
#   4. hidden_dim=7680        → double the model size (will it OOM?)

train_loop_config = {
    "epochs": 1,                  # Number of training epochs
    "learning_rate": 0.001,       # Adam learning rate
    "batch_size": 128,            # Batch size per worker (not global!)
    "skip_model_shard": True,     # True = keep params after forward (faster)
    "skip_cpu_offload": True,     # True = no CPU offloading (faster)
    "hidden_dim": 3840,           # ViT hidden dim (controls model size)
    "use_float16": False,         # True = use float16 mixed precision
}
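
The `skip_*` flags are phrased as negatives, so it can help to spell out what a given config actually turns on. The sketch below is a plain-Python reading of the comments above. `describe_run` is a hypothetical helper, and the mapping is an assumption based on those comments rather than on the body of `train_func`, which is not shown in this section.

```python
def describe_run(config, num_workers):
    """Translate the workshop's config flags into the settings they imply."""
    return {
        # skip_model_shard=True -> reshard_after_forward=False (keep params after forward)
        "reshard_after_forward": not config["skip_model_shard"],
        # skip_cpu_offload=True -> no CPU offloading
        "cpu_offload": not config["skip_cpu_offload"],
        "param_dtype": "float16" if config["use_float16"] else "float32",
        # batch_size is per worker, so the global batch scales with worker count
        "global_batch_size": config["batch_size"] * num_workers,
    }


example_config = {
    "batch_size": 128,
    "skip_model_shard": True,
    "skip_cpu_offload": True,
    "use_float16": False,
}

summary = describe_run(example_config, num_workers=2)
for key, value in summary.items():
    print(f"{key}: {value}")
```

With the defaults used here, the first run keeps parameters gathered after the forward pass, stays on GPU, trains in fp32, and processes 256 samples per global step across 2 workers.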


## `Step 6`: Launch Distributed Training

Ray Train's `TorchTrainer` handles worker spawning, process group initialization, and checkpoint coordination.

In [14]:
# ============================================================
#                Launch distributed training!
# ============================================================

# TorchTrainer is the main entry point for Ray Train.
# It takes your training function, config, and handles all the distributed plumbing:
#     - Worker spawning and GPU assignment
#     - Process group initialization (NCCL backend)
#     - Checkpoint coordination between workers
#     - Fault tolerance (auto-restart on failure)

import uuid

# ------------------------------------------------------------
# Generate a unique experiment name (folder) for this run
# ------------------------------------------------------------
training_name = "fsdp_mnist_" + str(uuid.uuid4())[:8]

trainer = ray.train.torch.TorchTrainer(
    train_func,                               # The function each worker runs
    scaling_config=scaling_config,            # How many workers, what GPUs
    train_loop_config=train_loop_config,      # Hyperparameters
    run_config=ray.train.RunConfig(
        storage_path="/mnt/cluster_storage/", # Shared storage for checkpoints
        name=training_name,                   # Experiment name (unique for each run)
        failure_config=ray.train.FailureConfig(
            max_failures=2                    # Auto-retry up to 2 times on worker failure
        ),
        worker_runtime_env={
            # These env vars configure the Kineto profiler to avoid log warnings.
            "env_vars": {
                "KINETO_USE_DAEMON": "1",
                "KINETO_DAEMON_INIT_DELAY_S": "5"
            }
        },
    ),
)

# trainer.fit() blocks until training completes.
#     INITIALIZING -> SCHEDULING -> RUNNING -> SHUTTING_DOWN -> FINISHED
result = trainer.fit()

2026-02-18 23:31:54,136 INFO worker.py:1821 -- Connecting to existing Ray cluster at address: 10.0.23.119:6379...
2026-02-18 23:31:54,147 INFO worker.py:1998 -- Connected to Ray cluster. View the dashboard at https://session-ffbqdd398vb4g8i97u3tsubr23.i.anyscaleuserdata.com
2026-02-18 23:31:54,152 INFO packaging.py:463 -- Pushing file package 'gcs://_ray_pkg_338ba8a2c55854004cd37d310db5092fe44dd789.zip' (1.25MiB) to Ray cluster...
2026-02-18 23:31:54,158 INFO packaging.py:476 -- Successfully pushed file package 'gcs://_ray_pkg_338ba8a2c55854004cd37d310db5092fe44dd789.zip'.
(TrainController pid=45328) [State Transition] INITIALIZING -> SCHEDULING.
(TrainController pid=45328) Attempting to start training worker group of size 2 with the following resources: [{'GPU': 1}] * 2
(TrainController pid=45328) [FailurePolicy] RETRY
(TrainController pid=45328)   Source: controller
(TrainController pid=45328)   Error count: 1 (max allo

(TrainController pid=45328) [State Transition] SHUTTING_DOWN -> FINISHED.


## `Step 7`: Inspect Training Artifacts

Training artifacts include:
- `checkpoint_*/` - Epoch checkpoints with distributed shards
- `full_model/` - Consolidated model for inference

In [17]:
# ============================================================
# List training artifacts on shared storage
# ============================================================
#   - checkpoint_* = sharded (for resuming training)
#   - full_model/ = consolidated (for inference)
storage_path = f"/mnt/cluster_storage/{training_name}/"
print(f"Artifacts in {storage_path}:")
!ls -ltra $storage_path

Artifacts in /mnt/cluster_storage/fsdp_mnist_bd7ac8fb/:
total 24
drwxr-xr-x 23 ray  1000 6144 Feb 18 23:31 ..
-rw-r--r--  1 ray users    0 Feb 18 23:31 .validate_storage_marker
drwxr-xr-x  2 ray users 6144 Feb 18 23:34 checkpoint_2026-02-18_23-34-44.513953
drwxr-xr-x  4 ray users 6144 Feb 18 23:35 .
drwxr-xr-x  2 ray users 6144 Feb 18 23:35 full_model
-rw-r--r--  1 ray users  335 Feb 18 23:36 checkpoint_manager_snapshot.json
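
The same inspection can be done in Python, which is handy when a script needs to find the newest sharded checkpoint. The sketch below is a minimal example assuming the directory layout shown in the listing above. `summarize_artifacts` is a hypothetical helper, and the demo builds a throwaway directory instead of touching `/mnt/cluster_storage/`.

```python
import tempfile
from pathlib import Path


def summarize_artifacts(run_dir):
    """List sharded checkpoint directories and check for the consolidated model."""
    run_dir = Path(run_dir)
    # is_dir() filters out checkpoint_manager_snapshot.json, which also matches the glob.
    checkpoints = sorted(p.name for p in run_dir.glob("checkpoint_*") if p.is_dir())
    return {
        "sharded_checkpoints": checkpoints,
        # Timestamped names sort chronologically, so the last one is the newest.
        "latest_sharded": checkpoints[-1] if checkpoints else None,
        "has_full_model": (run_dir / "full_model").is_dir(),
    }


# Demo on a temporary directory that mimics the listing above.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("checkpoint_2026-02-18_23-34-44.513953", "full_model"):
        (Path(tmp) / name).mkdir()
    (Path(tmp) / "checkpoint_manager_snapshot.json").write_text("{}")
    summary = summarize_artifacts(tmp)

print(summary)
```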


## `Step 8`: Load Model for Inference

The consolidated model (`full-model.pt`) is a standard PyTorch checkpoint that works without FSDP2.

In [18]:
# ============================================================
# Load the consolidated model for inference
# ============================================================

# 1. Create the same architecture as used during training
model = init_model(train_loop_config["hidden_dim"])

# 2. Load the full model checkpoint (standard PyTorch, no FSDP required)
model_state_dict = torch.load(
    f"/mnt/cluster_storage/{training_name}/full_model/full-model.pt",
    map_location='cpu'  # Loads model on CPU, works even without GPU
)

# 3. Load the checkpointed weights into the model
model.load_state_dict(model_state_dict)

<All keys matched successfully>
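
`<All keys matched successfully>` means the checkpoint and the freshly built model agree on every parameter name. When they do not, comparing the two key sets usually shows the cause faster than reading the error. The sketch below is a plain-Python illustration; `diff_state_dict_keys` is a hypothetical helper and the key names are made up for the example.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Report names the model expects but the checkpoint lacks, and vice versa."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    return {
        "missing": sorted(model_keys - checkpoint_keys),
        "unexpected": sorted(checkpoint_keys - model_keys),
    }


# Made-up key names, for illustration only.
model_keys = ["conv_proj.weight", "heads.head.weight", "heads.head.bias"]
checkpoint_keys = ["conv_proj.weight", "heads.head.weight", "extra.bias"]

report = diff_state_dict_keys(model_keys, checkpoint_keys)
print(report)
```

With a real model, the two inputs would be `model.state_dict().keys()` and the keys of the dictionary returned by `torch.load`.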

Load some test data

In [20]:
# ============================================================
# Prepare test data for inference
# ============================================================
# Use the same preprocessing as during training to ensure compatibility.
# The FashionMNIST test set contains 10,000 images across 10 classes.

transform = Compose([
    ToTensor(),
    Normalize((0.5,), (0.5,))
])

test_data = FashionMNIST(
    root=".",
    train=False,
    download=True,
    transform=transform
)

test_data

Dataset FashionMNIST
    Number of datapoints: 10000
    Root location: .
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5,), std=(0.5,))
           )
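
The transform shown above changes the value range the model sees. `ToTensor()` scales raw 8-bit pixels from [0, 255] to [0.0, 1.0], and `Normalize((0.5,), (0.5,))` then shifts and scales that to [-1.0, 1.0]. The sketch below reproduces the arithmetic for a single pixel in plain Python; `normalize_pixel` is a hypothetical helper for illustration.

```python
def normalize_pixel(raw, mean=0.5, std=0.5):
    """Apply the ToTensor + Normalize arithmetic to one raw pixel value."""
    scaled = raw / 255.0          # ToTensor: uint8 [0, 255] -> float [0.0, 1.0]
    return (scaled - mean) / std  # Normalize: subtract mean, divide by std


for raw in (0, 128, 255):
    print(f"raw={raw:3d} -> normalized={normalize_pixel(raw):+.3f}")
```

A model trained on inputs in [-1.0, 1.0] should be fed inputs in the same range at inference time, which is why the same transform is reused for the test set.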

Run inference

In [None]:
# ============================================================
# Run a single inference prediction on the test dataset
# ============================================================

# Set model to evaluation mode (disables dropout)
model.eval()

# Disable gradient computation for inference (saves memory and is faster)
with torch.no_grad():
    # Index the dataset (not test_data.data) so the training-time transform
    # (ToTensor + Normalize) is applied to the image.
    first_image, actual_label = test_data[0]
    # Add a batch dimension: [channels, height, width] -> [1, channels, height, width]
    output = model(first_image.unsqueeze(0))

    # Get the predicted class (index with highest score)
    predicted_label = output.argmax().item()
    
    # Map label indices to human-readable class names
    class_names = [
        "T-shirt", "Trouser", "Pullover", "Dress", "Coat",
        "Sandal", "Shirt", "Sneaker", "Bag", "Ankle Boot"
    ]

    print(f"Predicted: {predicted_label} ({class_names[predicted_label]}), "
          f"Actual: {actual_label} ({class_names[actual_label]})")



Predicted: 4 (Coat), Actual: 9 (Ankle Boot)
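
A single `argmax` hides how confident the model was. Converting the raw scores to probabilities and listing the top few classes gives a fuller picture, which is useful after only one epoch of training. The sketch below does this in plain Python. The logits are made-up numbers for illustration, not output from the trained model, and `softmax` and `top_k` are hypothetical helpers.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities (max subtracted for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, class_names, k=3):
    """Return the k most likely (class name, probability) pairs."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(class_names[i], probs[i]) for i in ranked[:k]]


class_names = [
    "T-shirt", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle Boot",
]
# Made-up scores for one image; a real run would use output[0].tolist().
example_logits = [0.1, -1.2, 0.3, 0.0, 2.0, -0.5, 0.4, 1.1, 0.2, 3.5]

predictions = top_k(example_logits, class_names)
for name, prob in predictions:
    print(f"{name}: {prob:.2%}")
```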


## Summary



This tutorial covered:
1. **FSDP2 sharding** - Sharding model parameters across GPUs using `fully_shard()`
2. **Ray Train integration** - Multi-GPU training with automatic process group management
3. **PyTorch DCP** - Sharded checkpointing with automatic resharding on load
4. **Inference** - Loading the consolidated model for inference without FSDP

**Next Steps:**
- Add CPU offloading: `CPUOffloadPolicy()` for memory-constrained scenarios
- Add mixed precision: `MixedPrecisionPolicy(param_dtype=torch.float16)`
- Try [DeepSpeed tutorial](./DeepSpeed_RayTrain_Tutorial.ipynb) for comparison

**Resources:**
- [PyTorch FSDP Tutorial](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
- [Ray Train Documentation](https://docs.ray.io/en/latest/train/getting-started-pytorch.html)