# PyTorch DDP Fashion MNIST Training Example

This example demonstrates how to train a convolutional neural network to classify images using the [Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset and [PyTorch Distributed Data Parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).

This notebook walks you through running the example locally and then scaling PyTorch DDP across multiple nodes with a Kubeflow TrainJob.

## Install the Kubeflow SDK

You need to install the Kubeflow SDK to interact with Kubeflow APIs:

In [None]:
# TODO (astefanutti): Change to the Kubeflow SDK when it's available.
!pip install git+https://github.com/kubeflow/training-operator.git@master#subdirectory=sdk_v2

## Install the PyTorch dependencies

You also need to install PyTorch and Torchvision to be able to run the example locally:

In [None]:
!pip install torch==2.5.1
!pip install torchvision==0.20.1

## Define the training function

In [10]:
def train_fashion_mnist():
    import os

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F
    from torch import nn
    from torch.utils.data import DataLoader, DistributedSampler
    from torchvision import datasets, transforms

    # Define the PyTorch CNN model to be trained
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 20, 5, 1)
            self.conv2 = nn.Conv2d(20, 50, 5, 1)
            self.fc1 = nn.Linear(4 * 4 * 50, 500)
            self.fc2 = nn.Linear(500, 10)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.max_pool2d(x, 2, 2)
            x = F.relu(self.conv2(x))
            x = F.max_pool2d(x, 2, 2)
            x = x.view(-1, 4 * 4 * 50)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return F.log_softmax(x, dim=1)

    # Use NCCL if a GPU is available, otherwise use Gloo as the communication backend
    device, backend = ("cuda", "nccl") if torch.cuda.is_available() else ("cpu", "gloo")

    print(f"Using Device: {device}, Backend: {backend}")

    # Setup PyTorch Distributed
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    dist.init_process_group(backend=backend)

    print(
        "Distributed Training for WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}".format(
            dist.get_world_size(),
            dist.get_rank(),
            local_rank,
        )
    )

    # Create the model and load it into the device
    device = torch.device(f"{device}:{local_rank}")
    model = nn.parallel.DistributedDataParallel(Net().to(device))

    # Retrieve the Fashion-MNIST dataset
    if local_rank == 0:
        # Only download the dataset from local rank 0
        dataset = datasets.FashionMNIST(
            "./data",
            train=True,
            download=True,
            transform=transforms.Compose([transforms.ToTensor()]),
        )
        dist.barrier()
    else:
        # Wait for local rank 0 to complete downloading the dataset and load it
        dist.barrier()
        dataset = datasets.FashionMNIST(
            "./data",
            train=True,
            download=False,
            transform=transforms.Compose([transforms.ToTensor()]),
        )

    # Shard the dataset across workers
    sampler = DistributedSampler(dataset)
    train_loader = DataLoader(
        dataset,
        batch_size=100,
        sampler=sampler,
        pin_memory=torch.cuda.is_available(),
    )

    # Setup the optimization loop
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # TODO(astefanutti): add parameters to the training function
    for epoch in range(1, 5):
        # Reshuffle the shards so each epoch uses a different sample ordering
        sampler.set_epoch(epoch)
        model.train()

        # Iterate over mini-batches from the training set
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            # Copy the data to the GPU device if available
            inputs, labels = inputs.to(device), labels.to(device)
            # Forward pass
            outputs = model(inputs)
            loss = F.nll_loss(outputs, labels)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0 and dist.get_rank() == 0:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(inputs),
                        len(train_loader.dataset),
                        100.0 * batch_idx / len(train_loader),
                        loss.item(),
                    )
                )

    # Wait for the distributed training to complete
    dist.barrier()
    if dist.get_rank() == 0:
        print("Training is finished")

    # Finally clean up PyTorch distributed
    dist.destroy_process_group()
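
The `4 * 4 * 50` input size of `fc1` follows from the shape arithmetic of the two convolution and pooling stages. The sketch below reproduces that arithmetic in plain Python, assuming 28x28 single-channel Fashion-MNIST images and no padding or dilation; `conv2d_out` is a helper introduced here for illustration and is not part of PyTorch.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Side length produced by a square convolution or pooling window (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                     # Fashion-MNIST images are 28x28
side = conv2d_out(side, kernel=5)             # conv1: 28 -> 24
side = conv2d_out(side, kernel=2, stride=2)   # max_pool2d: 24 -> 12
side = conv2d_out(side, kernel=5)             # conv2: 12 -> 8
side = conv2d_out(side, kernel=2, stride=2)   # max_pool2d: 8 -> 4

# conv2 has 50 output channels, so the flattened feature vector has 4 * 4 * 50 values
flattened = side * side * 50
print(side, flattened)  # prints "4 800"
```

If you change the kernel sizes or the input resolution, the same arithmetic gives the new value to use in `nn.Linear` and `x.view`.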

## Dry-run the training locally

In [17]:
import os

# Set the Torch Distributed env variables so the training function can be run in the notebook
# See https://pytorch.org/docs/stable/elastic/run.html#environment-variables
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "1234"

# Run the training function locally
train_fashion_mnist()

Using Device: cpu, Backend: gloo
Distributed Training for WORLD_SIZE: 1, RANK: 0, LOCAL_RANK: 0
Training is finished
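
The dry-run above hard-codes `MASTER_PORT` to `1234`, which fails if another process already uses that port. As a possible alternative, the sketch below asks the operating system for a free port using only the Python standard library; `find_free_port` is a hypothetical helper name, not part of PyTorch or the Kubeflow SDK.

```python
import socket


def find_free_port() -> int:
    """Ask the OS for a currently unused TCP port on the loopback interface."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 lets the OS choose any free port
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]


port = find_free_port()
print(port)
```

You could then set `os.environ["MASTER_PORT"] = str(find_free_port())` before calling the training function. The port is released before PyTorch binds to it, so a small race window remains; this is usually acceptable for a local dry-run.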


## Scale PyTorch DDP with Kubeflow TrainJob

You can use `TrainingClient()` from the Kubeflow SDK to communicate with Kubeflow APIs and scale your training function across multiple PyTorch training nodes.

Kubeflow Trainer creates a `TrainJob` resource and automatically sets the environment variables that PyTorch needs to run in a distributed environment.

In [13]:
from kubeflow.training import Trainer, TrainingClient
client = TrainingClient()

## List the Training Runtimes

You can get the list of available Training Runtimes to start your TrainJob:

In [17]:
for runtime in client.list_runtimes():
    print(runtime)

Runtime(name='torch-distributed', phase='pre-training', accelerator='Unknown', accelerator_count='Unknown')


Each Training Runtime shows whether you can use it for pre-training or post-training.
It also shows the available accelerator type and the number of available accelerators.

## Run the distributed TrainJob

In [18]:
job_name = client.train(
    # Use one of the training runtimes installed on your Kubernetes cluster
    runtime_ref="torch-distributed",
    trainer=Trainer(
        func=train_fashion_mnist,
        # Set how many worker Pods you want the job to be distributed into
        num_nodes=4,
        # Set the resources for each worker Pod
        resources_per_node={
            "cpu": 1,
            "memory": "16Gi",
            # Uncomment to distribute the TrainJob on nodes with GPUs
            #"nvidia.com/gpu": 1,
        },
    ),
)

## Check the TrainJob components

You can check the details of the TrainJob that's created:

In [19]:
job = client.get_job(job_name)
print(job)

TrainJob(name='t4a503bc3394', runtime_ref='torch-distributed', creation_timestamp=datetime.datetime(2025, 1, 30, 15, 1, 46, tzinfo=tzutc()), components=[Component(name='trainer-node-0', status='Running', device='gpu', device_count='1', pod_name='t4a503bc3394-trainer-node-0-0-xh4mb'), Component(name='trainer-node-1', status='Running', device='gpu', device_count='1', pod_name='t4a503bc3394-trainer-node-0-1-4vjkq'), Component(name='trainer-node-2', status='Running', device='gpu', device_count='1', pod_name='t4a503bc3394-trainer-node-0-2-f422f'), Component(name='trainer-node-3', status='Running', device='gpu', device_count='1', pod_name='t4a503bc3394-trainer-node-0-3-grcdm')], status='Created')


Since the TrainJob is distributed across 4 nodes, it creates 4 components: `trainer-node-0`, ..., `trainer-node-3`, and you can get the individual status of each component.
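
One way to summarize the per-component status is to count components by their `status` value. The sketch below is an illustration only: the `Component` dataclass is a hypothetical stand-in whose `name` and `status` fields are inferred from the printed output above, and `summarize_status` is a helper introduced here, not an SDK function.

```python
from collections import Counter
from dataclasses import dataclass


@dataclass
class Component:
    """Hypothetical stand-in mirroring two fields of the printed components."""

    name: str
    status: str


def summarize_status(components):
    """Count how many components are in each status."""
    return dict(Counter(component.status for component in components))


# Stand-in data matching the 4 running trainer nodes shown above
components = [Component(f"trainer-node-{i}", "Running") for i in range(4)]
summary = summarize_status(components)
print(summary)  # prints "{'Running': 4}"
```

Assuming the object returned by `client.get_job()` exposes the same attributes as its printed representation, you could pass `job.components` to the helper instead of the stand-in list.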

## Watch the TrainJob logs

In [20]:
_ = client.get_job_logs(job_name, follow=True)

[trainer-node]: Using Device: cuda, Backend: nccl
[trainer-node]: Distributed Training for WORLD_SIZE: 4, RANK: 0, LOCAL_RANK: 0
[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 26.4M/26.4M [00:01<00:00, 15.2MB/s]
[trainer-node]: Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 29.5k/29.5k [00:00<00:00, 324kB/s]
[trainer-node]: Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMN

Each node processes its assigned shard of the Fashion-MNIST dataset.
As the `TrainJob` is distributed across 4 nodes, and the dataset contains a total of 60,000 samples, each node processes 15,000 samples per epoch.
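
The per-node sample count can be checked with a small model of the sharding. The sketch below is a simplified approximation of how `DistributedSampler` splits indices (strided assignment after padding, without shuffling), written in plain Python; `shard_indices` is a hypothetical helper and not the PyTorch implementation.

```python
import math


def shard_indices(dataset_len: int, world_size: int, rank: int) -> list:
    """Return the sample indices assigned to one rank (no shuffling)."""
    per_rank = math.ceil(dataset_len / world_size)
    total = per_rank * world_size
    indices = list(range(dataset_len))
    # Pad by wrapping around so every rank gets the same number of samples
    indices += indices[: total - dataset_len]
    # Each rank takes every world_size-th index, starting at its own rank
    return indices[rank:total:world_size]


shards = [shard_indices(60_000, world_size=4, rank=r) for r in range(4)]
samples_per_node = len(shards[0])
# With batch_size=100, as in the training function above
batches_per_epoch = math.ceil(samples_per_node / 100)
print(samples_per_node, batches_per_epoch)  # prints "15000 150"
```

Because 60,000 divides evenly by 4, no padding is needed here and the four shards are disjoint.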

## Delete the TrainJob

In [17]:
client.delete_job(job_name)