# Get Started with PyTorch DeepSpeed and Ray Train

This notebook demonstrates how to train large models using **DeepSpeed ZeRO** with Ray Train. DeepSpeed enables memory-efficient distributed training by partitioning optimizer states, gradients, and (optionally) model parameters across multiple GPUs—similar in goal to PyTorch FSDP2.

**Learning Objectives:**
1. Configure DeepSpeed ZeRO for distributed training
2. Use DeepSpeed's built-in checkpointing with Ray Train
3. Load trained models for inference

**Key differences from FSDP2 (LIVE):** (1) **Config-driven** setup (JSON/dict) vs Python `fully_shard()`, (2) **Training step**: `model_engine.backward(loss)` + `model_engine.step()`, (3) **Checkpoints**: `model_engine.save_checkpoint()` / `load_checkpoint()`.

## `Step 0`: What is `DeepSpeed`?



[DeepSpeed](https://www.deepspeed.ai/) is Microsoft's deep learning optimization library. Its main feature is **ZeRO (Zero Redundancy Optimizer)**, which partitions training state across GPUs to reduce memory—similar in goal to PyTorch FSDP:

- **ZeRO Stage 1:** Partitions optimizer states across workers  
- **ZeRO Stage 2:** Partitions optimizer states + gradients (we use this here)  
- **ZeRO Stage 3:** Partitions optimizer + gradients + **model parameters** (closest to FSDP)
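
The savings from each stage can be estimated with a back-of-the-envelope sketch. It counts *model states only* (no activations or temporary buffers). It assumes the ZeRO paper's accounting for mixed-precision Adam: 2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and 12 for optimizer state (fp32 master weights, momentum, variance). This notebook trains in fp32 (`use_float16=False`), so its per-parameter byte counts differ. The point here is which terms each stage divides by the number of GPUs. The function name is hypothetical.

```python
def zero_model_state_gb(num_params, num_gpus, stage):
    """Approximate per-GPU memory (decimal GB) for model states under ZeRO.

    Assumes mixed-precision Adam accounting from the ZeRO paper:
    2 bytes/param (fp16 weights) + 2 (fp16 grads) + 12 (fp32 optimizer state).
    Activations, buffers, and fragmentation are ignored.
    """
    weights, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params
    if stage == 0:    # plain data parallel: everything replicated on every GPU
        total = weights + grads + optim
    elif stage == 1:  # optimizer state partitioned
        total = weights + grads + optim / num_gpus
    elif stage == 2:  # optimizer state + gradients partitioned
        total = weights + (grads + optim) / num_gpus
    elif stage == 3:  # optimizer state + gradients + weights partitioned
        total = (weights + grads + optim) / num_gpus
    else:
        raise ValueError(f"unknown ZeRO stage: {stage}")
    return total / 1e9


# 7.5B-parameter model on 64 GPUs (the example used in the ZeRO paper)
for stage in range(4):
    print(f"stage {stage}: {zero_model_state_gb(7.5e9, 64, stage):.2f} GB per GPU")
```

For that example, the estimate falls from about 120 GB per GPU (stage 0) to roughly 31.4, 16.6, and 1.9 GB for stages 1, 2, and 3.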

**Why use DeepSpeed with Ray Train?** Same reason as FSDP: train larger models or use larger batch sizes by spreading memory across GPUs. DeepSpeed uses a **config file or dict** to control behavior instead of Python API calls like `fully_shard()`, and it integrates with Ray Train the same way—each worker runs your training function with one GPU, and Ray handles orchestration and checkpoints.

### DeepSpeed vs FSDP2 (quick comparison)

| | **FSDP2 (LIVE)** | **DeepSpeed (this)** |
|---|------------------|----------------------|
| **Setup** | `fully_shard(block, mesh=...)` + device mesh | Config dict + `deepspeed.initialize(model, optimizer, config)` |
| **Training step** | `optimizer.zero_grad()` → `loss.backward()` → `optimizer.step()` | `model_engine.backward(loss)` → `model_engine.step()` |
| **Checkpointing** | DCP: `AppState`, `dcp.save`/`dcp.load` | `model_engine.save_checkpoint(dir)` / `load_checkpoint(dir)` |
| **Ray Train** | Same: `TorchTrainer`, `ScalingConfig`, `prepare_data_loader` | Same |

## `Step 1`: Environment Setup

Check Ray cluster status and install dependencies (if needed).

In [1]:
# Check Ray cluster status
!ray status

Node status
---------------------------------------------------------------
Active:
 1 head
Idle:
 1 1xL4:16CPU-64GB-2
 1 1xL4:16CPU-64GB-1
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Total Usage:
 0.0/32.0 CPU
 0.0/2.0 GPU
 0.0/2.0 anyscale/accelerator_shape:1xL4
 0.0/1.0 anyscale/cpu_only:true
 0.0/1.0 anyscale/node-group:1xL4:16CPU-64GB-1
 0.0/1.0 anyscale/node-group:1xL4:16CPU-64GB-2
 0.0/1.0 anyscale/node-group:head
 0.0/3.0 anyscale/provider:aws
 0.0/3.0 anyscale/region:us-west-2
 0B/160.00GiB memory
 12.08KiB/44.64GiB object_store_memory

From request_resources:
 (none)
Pending Demands:
 (no resource demands)

In [2]:
# Stdlib imports
import os
import tempfile

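# --- DIFFERENT from FSDP: DeepSpeed needs these env vars on worker nodes that may not have
# the full CUDA toolkit (e.g. no nvcc). FSDP has no such requirement.
# Setting them here only affects the driver process; workers receive them through
# worker_runtime_env (Step 6) and _setup_deepspeed_env() (Step 3).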
os.environ["DS_BUILD_OPS"] = "0"
os.environ["DS_SKIP_CUDA_CHECK"] = "1"

# Ray Train imports
import ray
import ray.train
import ray.train.torch

# PyTorch core imports
import torch
import torch.distributed as dist
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

# Computer vision components
from torchvision.models import VisionTransformer
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

# --- DIFFERENT: We do NOT import FSDP (fully_shard, FSDPModule, etc.) or DCP (dcp.save/load).
# DeepSpeed is imported inside prepare_model() so it only runs on GPU workers, not the driver.

## `Step 2`: Model Definition

We use the same Vision Transformer (ViT) as the FSDP2 LIVE notebook, which is a good fit for demonstrating DeepSpeed ZeRO. The model itself needs no DeepSpeed-specific changes.

In [3]:
# Initialize the model (same as FSDP LIVE — ViT for 28x28 FashionMNIST)
def init_model(hidden_dim):
    model = VisionTransformer(
        image_size=28,
        patch_size=7,
        num_layers=12,
        num_heads=8,
        hidden_dim=hidden_dim,
        mlp_dim=768,
        num_classes=10,
    )
    model.conv_proj = torch.nn.Conv2d(1, hidden_dim, kernel_size=7, stride=7)
    return model
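
Step 6 sets `hidden_dim=3840`, which makes this ViT fairly large. The sketch below counts its parameters by hand. It assumes torchvision's `VisionTransformer` layout: a patch-embedding conv (here the replaced 1-channel `conv_proj`), a class token, learned position embeddings, pre-norm encoder blocks built from `nn.MultiheadAttention` and a two-layer MLP, a final LayerNorm, and a single linear head. The helper name is hypothetical. `sum(p.numel() for p in model.parameters())` remains the authoritative check.

```python
def vit_param_count(image_size=28, patch_size=7, num_layers=12,
                    hidden_dim=768, mlp_dim=768, num_classes=10, in_channels=1):
    """Hand count of parameters for the init_model() ViT (assumes torchvision's layout)."""
    h, m = hidden_dim, mlp_dim
    seq_len = (image_size // patch_size) ** 2 + 1          # patches + class token
    patch_embed = in_channels * h * patch_size ** 2 + h    # replaced conv_proj (weight + bias)
    class_token = h
    pos_embedding = seq_len * h
    per_block = (
        2 * h + 2 * h                      # ln_1 and ln_2 (weight + bias each)
        + 3 * h * h + 3 * h                # attention in_proj (q, k, v) weight + bias
        + h * h + h                        # attention out_proj
        + h * m + m + m * h + h            # MLP: Linear(h, m) then Linear(m, h)
    )
    final_ln = 2 * h
    head = h * num_classes + num_classes
    return (patch_embed + class_token + pos_embedding
            + num_layers * per_block + final_ln + head)


n = vit_param_count(hidden_dim=3840)
print(f"{n:,} parameters, ~{n * 4 / 1e9:.1f} GB of fp32 weights")
```

With `hidden_dim=3840` this works out to roughly 779M parameters, about 3.1 GB for the fp32 weights alone. Gradients and Adam state come on top of that, which is what ZeRO partitions.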

## `Step 3`: DeepSpeed Config and Model Preparation

DeepSpeed uses a **configuration dictionary** (or JSON file) to specify ZeRO stage, precision, and batch size. No device mesh or `fully_shard()`—one config dict and `deepspeed.initialize()` handle device placement, ZeRO partitioning, and optimizer wrapping.
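
The config is plain data, so the same settings can live in a JSON file. `deepspeed.initialize()` accepts either a dict or a path to a JSON file through its `config` argument. DeepSpeed also ties its batch-size keys together: `train_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size`. The stdlib sketch below round-trips a trimmed copy of this notebook's settings through JSON and computes that product. `effective_batch_size` is a hypothetical helper, not part of DeepSpeed.

```python
import json
import os
import tempfile


def effective_batch_size(ds_config, world_size):
    """Samples per optimizer step: micro batch per GPU * accumulation steps * number of GPUs."""
    micro = ds_config["train_micro_batch_size_per_gpu"]
    accum = ds_config.get("gradient_accumulation_steps", 1)
    return micro * accum * world_size


# A trimmed copy of the settings used in this notebook (ZeRO stage 2, fp32)
ds_config = {
    "fp16": {"enabled": False},
    "zero_optimization": {"stage": 2},
    "train_micro_batch_size_per_gpu": 128,
    "gradient_accumulation_steps": 1,
}

# Round-trip through a JSON file, the other config form DeepSpeed accepts
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ds_config.json")
    with open(path, "w") as f:
        json.dump(ds_config, f, indent=2)
    with open(path) as f:
        loaded = json.load(f)

print(loaded == ds_config, effective_batch_size(loaded, world_size=2))
```

With 2 workers and a micro batch of 128, each optimizer step sees 256 samples. If you set `train_batch_size` explicitly, it has to agree with that product.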

In [4]:
# --- Required on clusters where worker nodes have CUDA runtime but NOT the full toolkit (no nvcc).
# DeepSpeed's import triggers nvcc checks; this creates a dummy nvcc so the import succeeds.
# FSDP has no equivalent — it does not invoke nvcc.
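def _setup_deepspeed_env():
    import os
    import tempfile as _tmpfile
    # Same flags as the imports cell; DS_SKIP_CUDA_CHECK skips DeepSpeed's CUDA version-mismatch check
    os.environ["DS_BUILD_OPS"] = "0"
    os.environ["DS_SKIP_CUDA_CHECK"] = "1"
    import torch.utils.cpp_extension
    real_cuda_home = torch.utils.cpp_extension.CUDA_HOME or "/usr/local/cuda"
    # Build a throwaway "CUDA home" whose bin/ directory holds only a stub nvcc script
    fake_cuda = _tmpfile.mkdtemp(prefix="ds_cuda_")
    nvcc_dir = os.path.join(fake_cuda, "bin")
    os.makedirs(nvcc_dir, exist_ok=True)
    # The stub only prints a version banner for the CUDA version PyTorch was built with
    cuda_ver = torch.version.cuda or "12.8"
    with open(os.path.join(nvcc_dir, "nvcc"), "w") as f:
        f.write(f'#!/bin/bash\necho "Cuda compilation tools, release {cuda_ver}, V{cuda_ver}.89"\n')
    os.chmod(os.path.join(nvcc_dir, "nvcc"), 0o755)
    # Point CUDA_HOME at the stub only while DeepSpeed is first imported; later
    # `import deepspeed` statements reuse the cached module, so the import-time check does not rerun
    torch.utils.cpp_extension.CUDA_HOME = fake_cuda
    try:
        import deepspeed  # noqa: F401
    finally:
        # Restore the real path even if the import fails
        torch.utils.cpp_extension.CUDA_HOME = real_cuda_home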

In [5]:
# --- DIFFERENT from FSDP: FSDP uses Python API (device mesh + fully_shard(block), fully_shard(model)).
# DeepSpeed uses a single config dict and one call to deepspeed.initialize().
def get_deepspeed_config(config):
    return {
        "fp16": {
            "enabled": config.get("use_float16", False)
        },
        "zero_optimization": {
            "stage": 2,  # Stage 2 = partition optimizer + gradients (Stage 3 = also params, like FSDP)
            "allgather_bucket_size": 2e8,
            "reduce_bucket_size": 2e8,
            "overlap_comm": True,
            "contiguous_gradients": True,
        },
        "train_micro_batch_size_per_gpu": config.get("batch_size", 128),
        "gradient_accumulation_steps": 1,
        "gradient_clipping": 1.0,
        "steps_per_print": 1000,
    }


def prepare_model(model, config):
    import deepspeed

    ds_config = get_deepspeed_config(config)

    # Standard PyTorch optimizer; DeepSpeed wraps with ZeRO
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config.get("learning_rate", 0.001)
    )

    # DeepSpeed handles device placement & ZeRO optimizer wrapping
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        config=ds_config
    )

    return model_engine, optimizer

## `Step 4`: Checkpointing with DeepSpeed

DeepSpeed provides **built-in checkpointing**, so there is no PyTorch DCP or `AppState` wrapper. Call `model_engine.save_checkpoint(dir)` and `model_engine.load_checkpoint(dir)` on **every** worker, not just rank 0. Each worker writes its own shard of the partitioned state (parallel I/O), similar in spirit to FSDP2's DCP.
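
Conceptually, ZeRO stages 1 and 2 work in three steps. The optimizer state is flattened into one long buffer. That buffer is padded so it divides evenly across ranks. Each rank then owns one contiguous slice, so a checkpoint holds one file per rank. The stdlib sketch below mimics that idea with plain lists. It is a simplification for intuition: the helper and file names are hypothetical, and this is not DeepSpeed's actual file format.

```python
import json
import os
import tempfile


def partition(flat, world_size, pad_value=0.0):
    """Pad a flat buffer to a multiple of world_size and split it into equal contiguous slices."""
    padded = list(flat) + [pad_value] * (-len(flat) % world_size)
    size = len(padded) // world_size
    return [padded[r * size:(r + 1) * size] for r in range(world_size)]


def save_shards(shards, ckpt_dir):
    # Each "rank" writes only its own slice, so the writes can happen in parallel
    for rank, shard in enumerate(shards):
        with open(os.path.join(ckpt_dir, f"rank_{rank}.json"), "w") as f:
            json.dump(shard, f)


def load_full(ckpt_dir, world_size, numel):
    # Concatenate the slices in rank order, then drop the padding
    flat = []
    for rank in range(world_size):
        with open(os.path.join(ckpt_dir, f"rank_{rank}.json")) as f:
            flat.extend(json.load(f))
    return flat[:numel]


optimizer_state = [0.1 * i for i in range(10)]  # stand-in for a flattened fp32 buffer
shards = partition(optimizer_state, world_size=4)
with tempfile.TemporaryDirectory() as tmp:
    save_shards(shards, tmp)
    restored = load_full(tmp, world_size=4, numel=len(optimizer_state))
print([len(s) for s in shards], restored == optimizer_state)
```

Ten values across four ranks are padded to twelve, so each rank stores three. Loading needs every rank's file, which is why all workers must take part in saving.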

In [6]:
# ----------------------------
# Checkpoint Helpers for DeepSpeed + Ray Train
# These handle checkpointing and reporting in a multi-node distributed setting.
# NOTE: DeepSpeed's save_checkpoint/load_checkpoint are different from FSDP's DCP.
# ----------------------------


def report_metrics_and_save_deepspeed_checkpoint(model_engine, metrics, epoch):
    """
    Saves a DeepSpeed checkpoint for the current epoch, then reports metrics to Ray Train.

    Args:
        model_engine: The DeepSpeed engine/wrapped model.
        metrics (dict): Metrics to report (e.g., {"loss": avg_loss}).
        epoch (int): The current epoch (for checkpoint tag).
    """
    with tempfile.TemporaryDirectory() as tmp:
        # Save DeepSpeed checkpoint with current epoch in tag and client_state
        model_engine.save_checkpoint(tmp, tag=f"epoch_{epoch}", client_state={"epoch": epoch})
        # Synchronize all distributed processes before reporting
        dist.barrier()
        # Ray Train expects a checkpoint directory, so create one from the temp path
        checkpoint = ray.train.Checkpoint.from_directory(tmp)
        # Report metrics and checkpoint to Ray Train dashboard/driver
        ray.train.report(metrics, checkpoint=checkpoint)


def save_model_for_inference(model_engine, world_rank):
    """
    Saves a full PyTorch model state_dict for inference. Only rank 0 saves,
    since with ZeRO Stage 2 all params are already replicated per GPU.

    Args:
        model_engine: The DeepSpeed engine/wrapped model.
        world_rank (int): Process rank (only 0 saves).
    """
    with tempfile.TemporaryDirectory() as tmp:
        model_path = os.path.join(tmp, "full-model.pt")
        checkpoint = None

        # Only rank 0 writes the full model state_dict out.
        if world_rank == 0:
            torch.save(model_engine.module.state_dict(), model_path)
            # Create Ray checkpoint from directory (to allow Ray model recovery)
            checkpoint = ray.train.Checkpoint.from_directory(tmp)

        # Ensure all processes synchronize before reporting
        dist.barrier()
        # Empty dict for metrics, specify checkpoint_dir_name for clarity
        ray.train.report({}, checkpoint=checkpoint, checkpoint_dir_name="full_model")


def load_deepspeed_checkpoint(model_engine, ckpt):
    """
    Loads the latest DeepSpeed checkpoint from a Ray Checkpoint.

    Args:
        model_engine: The DeepSpeed engine/wrapped model (already initialized).
        ckpt: A Ray Checkpoint object.

    Returns:
        epoch (int): The epoch number restored from the checkpoint, or 0 if no checkpoint found.
    """
    with ckpt.as_directory() as d:
        # Find all checkpoint tags that match "epoch_*"
        tags = [x for x in os.listdir(d) if x.startswith("epoch_")]
        # Pick the newest tag by its integer epoch; a plain string sort would rank "epoch_10" before "epoch_9"
        tag = max(tags, key=lambda t: int(t.split("_", 1)[1])) if tags else None

        if not tag:
            # No checkpoint found, return 0 (train from scratch)
            return 0

        # Load the checkpoint; client_state can carry arbitrary user state (e.g., epoch)
        _, client_state = model_engine.load_checkpoint(d, tag=tag)
        # Return saved epoch (defaults to 0 if not found)
        return (client_state or {}).get("epoch", 0)
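
Epoch tags are strings, so a plain `sorted()` ranks `epoch_10` before `epoch_9`. Each Ray checkpoint in this notebook holds a single tag, so this rarely bites here. When a directory can hold several tags, ordering by the integer suffix is safer. DeepSpeed also writes a `latest` file next to the tag directories, which the prefix filter skips. A small stdlib sketch follows; `latest_epoch_tag` is a hypothetical helper.

```python
def latest_epoch_tag(names):
    """Return the epoch_* name with the largest integer suffix, or None if there is none."""
    epoch_tags = [n for n in names if n.startswith("epoch_")]
    if not epoch_tags:
        return None
    return max(epoch_tags, key=lambda t: int(t.split("_", 1)[1]))


names = ["epoch_9", "epoch_10", "latest"]  # "latest" mimics DeepSpeed's pointer file
print(sorted(n for n in names if n.startswith("epoch_"))[-1])  # string order: prints epoch_9
print(latest_epoch_tag(names))                                 # numeric order: prints epoch_10
```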

## `Step 5`: Training Function

Below is the training function that runs on each worker. Same structure as FSDP LIVE; only the **model wrapper** (DeepSpeed engine) and **training step** (`model_engine.backward` / `model_engine.step`) differ.
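
`train_func` calls `model_engine.step()` after every micro-batch. DeepSpeed only applies an optimizer update, and then zeroes gradients, at a gradient-accumulation boundary, which comes every `gradient_accumulation_steps` calls. By default `model_engine.backward()` also scales the loss by `1 / gradient_accumulation_steps`. With this notebook's setting of 1, every call updates. The stdlib sketch below simulates only that schedule. `simulate_engine` is hypothetical, and its "gradient" is just the scaled loss, not real autograd.

```python
def simulate_engine(micro_batch_losses, gradient_accumulation_steps):
    """Toy model of the engine's schedule: returns (update_indices, accumulated_grad_per_update).

    Each "gradient" is just the scaled loss, so each accumulated value equals
    the mean loss over that accumulation window.
    """
    updates, accumulated, grad = [], [], 0.0
    for micro_step, loss in enumerate(micro_batch_losses):
        grad += loss / gradient_accumulation_steps               # backward(): scale by 1/GAS, accumulate
        if (micro_step + 1) % gradient_accumulation_steps == 0:  # step(): update only at a boundary
            updates.append(micro_step)
            accumulated.append(grad)
            grad = 0.0                                           # gradients zeroed after the update
    return updates, accumulated


print(simulate_engine([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], gradient_accumulation_steps=4))
print(simulate_engine([1.0, 2.0, 3.0], gradient_accumulation_steps=1))
```

This explains why the loop needs no explicit `zero_grad()` call: the engine folds it into `step()`.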

In [7]:
def train_func(config):
    # --- DIFFERENT from FSDP: workers may not have nvcc; patch CUDA_HOME and import DeepSpeed first.
    _setup_deepspeed_env()
    # --- SAME as FSDP: create model with init_model()
    model = init_model(config["hidden_dim"])
    # --- DIFFERENT: FSDP does model.to(device), then prepare_model() which calls fully_shard().
    # DeepSpeed: one prepare_model() call wraps model + optimizer and handles device.
    model_engine, _ = prepare_model(model, config)
    criterion = CrossEntropyLoss()

    # --- SAME as FSDP: resume from checkpoint via ray.train.get_checkpoint()
    start_epoch = 0
    ckpt = ray.train.get_checkpoint()
    if ckpt:
        start_epoch = load_deepspeed_checkpoint(model_engine, ckpt)

    # --- SAME as FSDP: data loading with prepare_data_loader (distributed sampler)
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    # Train on the 60,000-image training split; the test split is reserved for inference in Step 8
    train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=config.get("batch_size", 128), shuffle=True, num_workers=2)
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    world_rank = ray.train.get_context().get_world_rank()
    world_size = ray.train.get_context().get_world_size()
    epochs = config["epochs"]

    for epoch in range(start_epoch, epochs):
        if world_size > 1:
            train_loader.sampler.set_epoch(epoch)
        running_loss, num_batches = 0.0, 0
        for images, labels in train_loader:
            outputs = model_engine(images)
            loss = criterion(outputs, labels)
            # --- DIFFERENT: FSDP uses optimizer.zero_grad(); loss.backward(); optimizer.step()
            # DeepSpeed: model_engine.backward(loss) then model_engine.step() (engine handles zero_grad)
            model_engine.backward(loss)
            model_engine.step()
            running_loss += loss.item()
            num_batches += 1
        avg_loss = running_loss / num_batches
        metrics = {"loss": avg_loss, "epoch": epoch + 1}
        report_metrics_and_save_deepspeed_checkpoint(model_engine, metrics, epoch + 1)
        if world_rank == 0:
            print(metrics)

    save_model_for_inference(model_engine, world_rank)

<div class="alert alert-block alert-info">

**Note:** We call `_setup_deepspeed_env()` at the start of `train_func` so DeepSpeed can be imported on worker nodes that have CUDA runtime but not the full toolkit (no `nvcc`). Without this, you would see `FileNotFoundError: .../nvcc`.

</div>

## `Step 6`: Final Training Configuration and Launch

Configure scaling and resource requirements (same as FSDP LIVE). Add **DeepSpeed env vars** in `worker_runtime_env` so workers can import DeepSpeed without the full CUDA toolkit.
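
You can check up front whether a node actually lacks `nvcc` before relying on these workarounds. The minimal stdlib sketch below assumes the usual lookup locations: `PATH`, or `$CUDA_HOME/bin` (falling back to `$CUDA_PATH/bin`). `has_nvcc` is a hypothetical helper. On a real cluster you would run it inside a Ray task, so it executes on a worker node rather than on the driver.

```python
import os
import shutil


def has_nvcc(cuda_home=None):
    """Return True if an nvcc binary is found on PATH or under <cuda_home>/bin."""
    if shutil.which("nvcc"):
        return True
    cuda_home = cuda_home or os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
    if cuda_home:
        return os.path.isfile(os.path.join(cuda_home, "bin", "nvcc"))
    return False


print("nvcc available:", has_nvcc())
```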

In [8]:
# Configure scaling and resource requirements (same as FSDP LIVE)
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)

train_loop_config = {
    "epochs": 1,
    "learning_rate": 0.001,
    "batch_size": 128,
    "hidden_dim": 3840,
    "use_float16": False,
}
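
With `num_workers=2` and `batch_size=128` per worker, each optimizer step consumes 2 * 128 = 256 images. `prepare_data_loader` adds a `DistributedSampler`. With its default `drop_last=False`, the sampler pads the dataset so that each rank gets `ceil(len(dataset) / world_size)` samples. The stdlib sketch below estimates steps per epoch for FashionMNIST's 60,000-image training split and 10,000-image test split. `steps_per_epoch` is a hypothetical helper.

```python
import math


def steps_per_epoch(dataset_len, world_size, batch_size):
    """Batches each rank iterates per epoch with DistributedSampler(drop_last=False)."""
    per_rank = math.ceil(dataset_len / world_size)  # sampler pads so every rank gets the same count
    return math.ceil(per_rank / batch_size)         # DataLoader keeps the final partial batch


for split, n in [("train", 60_000), ("test", 10_000)]:
    print(split, steps_per_epoch(n, world_size=2, batch_size=128))
```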

In [9]:
# Generate a unique run name (same pattern as FSDP LIVE)
import random
import string
name = "deepspeed_mnist_" + ''.join(random.choices(string.ascii_letters + string.digits, k=5))

# TorchTrainer is the main entry point for Ray Train (same as FSDP).
# worker_runtime_env must set DS_BUILD_OPS and DS_SKIP_CUDA_CHECK for DeepSpeed on worker nodes.
trainer = ray.train.torch.TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    train_loop_config=train_loop_config,
    run_config=ray.train.RunConfig(
        storage_path="/mnt/cluster_storage/",
        name=name,
        failure_config=ray.train.FailureConfig(max_failures=2),
        worker_runtime_env={
            "env_vars": {"DS_BUILD_OPS": "0", "DS_SKIP_CUDA_CHECK": "1"}
        },
    ),
)
result = trainer.fit()

2026-02-18 21:22:41,536	INFO worker.py:1821 -- Connecting to existing Ray cluster at address: 10.0.91.12:6379...
2026-02-18 21:22:41,548	INFO worker.py:1998 -- Connected to Ray cluster. View the dashboard at https://session-ffbqdd398vb4g8i97u3tsubr23.i.anyscaleuserdata.com 
2026-02-18 21:22:41,753	INFO packaging.py:463 -- Pushing file package 'gcs://_ray_pkg_ae31baf8ffd7c346159f8e63a26c57a688dcdc7b.zip' (83.26MiB) to Ray cluster...
2026-02-18 21:22:42,090	INFO packaging.py:476 -- Successfully pushed file package 'gcs://_ray_pkg_ae31baf8ffd7c346159f8e63a26c57a688dcdc7b.zip'.
(TrainController pid=48334) [State Transition] INITIALIZING -> SCHEDULING.
(TrainController pid=48334) Attempting to start training worker group of size 2 with the following resources: [{'GPU': 1}] * 2
(RayTrainWorker pid=11194, ip=10.0.122.26) Setting up process group for: env:// [rank=0, world_size=2]
(TrainController pid=48334) Started training worker group o

## `Step 7`: Inspect Training Artifacts



Training artifacts include:
- `checkpoint_*/` — Epoch checkpoints with DeepSpeed shards
- `full_model/` — Consolidated model for inference (same as FSDP)
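
The same listing can be done from Python, for example to find the newest epoch checkpoint programmatically. The stdlib sketch below groups a run directory's subfolders by the naming used in this run: timestamped `checkpoint_*` directories plus `full_model`. The helper name and the grouping rule are assumptions based on this layout, not a Ray API.

```python
from pathlib import Path


def summarize_run_dir(run_dir):
    """Group a Ray Train run directory's subfolders by name, newest checkpoint last."""
    run_dir = Path(run_dir)
    checkpoints = sorted(
        p.name for p in run_dir.iterdir()
        if p.is_dir() and p.name.startswith("checkpoint_")  # skips files like checkpoint_manager_snapshot.json
    )
    return {
        "epoch_checkpoints": checkpoints,  # fixed-width timestamps sort chronologically
        "has_full_model": (run_dir / "full_model").is_dir(),
    }


# Usage in this notebook (assumes the run directory from Step 6 exists):
# summarize_run_dir(f"/mnt/cluster_storage/{name}")
```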

In [12]:
# List artifacts
!ls -la /mnt/cluster_storage/{name}/

total 24
drwxr-xr-x  4 ray users 6144 Feb 18 21:26 .
drwxr-xr-x 22 ray  1000 6144 Feb 18 21:22 ..
-rw-r--r--  1 ray users    0 Feb 18 21:22 .validate_storage_marker
drwxr-xr-x  3 ray users 6144 Feb 18 21:25 checkpoint_2026-02-18_21-25-00.927613
-rw-r--r--  1 ray users  335 Feb 18 21:27 checkpoint_manager_snapshot.json
drwxr-xr-x  2 ray users 6144 Feb 18 21:26 full_model


## `Step 8`: Load Model for Inference

The consolidated model (`full-model.pt`) is a standard PyTorch state dict—same as FSDP2. It works without DeepSpeed or any distributed setup.

In [13]:
# Load consolidated model (standard PyTorch state_dict, same as FSDP)
model = init_model(train_loop_config["hidden_dim"])
state = torch.load(f"/mnt/cluster_storage/{name}/full_model/full-model.pt", map_location="cpu")
model.load_state_dict(state)

<All keys matched successfully>

Load some test data

In [14]:
transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
test_data = FashionMNIST(
    root=".", train=False, download=True, transform=transform
)
test_data

Dataset FashionMNIST
    Number of datapoints: 10000
    Root location: .
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5,), std=(0.5,))
           )
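
The printed transform is what the model saw during training. `ToTensor()` scales uint8 pixels from [0, 255] to [0.0, 1.0]. `Normalize((0.5,), (0.5,))` then computes `(x - 0.5) / 0.5`, so inputs land in [-1, 1]. The raw `test_data.data` tensor skips both steps. The stdlib sketch below shows the per-pixel mapping; `normalize_pixel` is a hypothetical helper.

```python
def normalize_pixel(p, mean=0.5, std=0.5):
    """Mimic ToTensor() + Normalize((mean,), (std,)) for one uint8 pixel value."""
    x = p / 255.0            # ToTensor: [0, 255] -> [0.0, 1.0]
    return (x - mean) / std  # Normalize: [0.0, 1.0] -> [-1.0, 1.0] when mean = std = 0.5


print([round(normalize_pixel(p), 3) for p in (0, 51, 255)])
```

Feeding unnormalized values (up to 255) into a model trained on [-1, 1] inputs will generally produce unreliable predictions. Inference should therefore go through the same transform.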

Run inference

In [15]:
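model.eval()
with torch.no_grad():
    # Index the dataset (not test_data.data) so the ToTensor + Normalize transform used in
    # training is applied; raw .data holds unnormalized uint8 pixels in [0, 255]
    image, test_label = test_data[0]
    out = model(image.unsqueeze(0))  # add a batch dimension: (1, 1, 28, 28)
    predicted_label = out.argmax().item()
    print(f"{predicted_label=} {test_label=}")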

predicted_label=4 test_label=9
