# Using `LocalTrainerClient` for Fashion-MNIST image classification with PyTorch DDP

This notebook uses the `LocalTrainerClient` to train an image classification model on the Fashion-MNIST dataset.

The `LocalTrainerClient` runs training jobs locally in Docker containers. No Kubernetes cluster is required.
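
Because the jobs run in Docker containers, it can help to confirm that Docker is reachable before going further. The cell below is a minimal sketch, not part of the Kubeflow SDK: the helper name `docker_is_available` is hypothetical, and it assumes the `docker` CLI is on your `PATH` and that `docker info` succeeds only when the daemon is running.

```python
import shutil
import subprocess


def docker_is_available(timeout_seconds: float = 10.0) -> bool:
    """Return True if the docker CLI is installed and the daemon answers `docker info`.

    Hypothetical helper for this notebook; it is not provided by the Kubeflow SDK.
    """
    # The CLI binary must be discoverable on PATH.
    if shutil.which("docker") is None:
        return False
    try:
        # `docker info` exits with a non-zero code when the daemon is unreachable.
        result = subprocess.run(
            ["docker", "info"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            timeout=timeout_seconds,
            check=False,
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0


print("Docker available:", docker_is_available())
```

If this prints `False`, start Docker before creating the trainer client.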

## Install the Kubeflow SDK

In [None]:
# TODO (eoinfennessy): Update repo and branch before merging.
!pip install git+https://github.com/eoinfennessy/trainer.git@add-local-trainer-client#subdirectory=sdk

## Define the training function

This function trains a CNN model on the Fashion-MNIST dataset using PyTorch Distributed Data Parallel (DDP).
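
The first fully connected layer in the model expects `4 * 4 * 50` input features. The sketch below shows where that number comes from, assuming 28x28 grayscale Fashion-MNIST images and the standard output-size formula for unpadded convolution and pooling layers. The helper `output_size` is a hypothetical name introduced only for this illustration.

```python
def output_size(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Spatial output size of a conv or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                    # Fashion-MNIST images are 28x28
size = output_size(size, kernel=5, stride=1)  # conv1: nn.Conv2d(1, 20, 5, 1)
size = output_size(size, kernel=2, stride=2)  # F.max_pool2d(x, 2, 2)
size = output_size(size, kernel=5, stride=1)  # conv2: nn.Conv2d(20, 50, 5, 1)
size = output_size(size, kernel=2, stride=2)  # F.max_pool2d(x, 2, 2)

channels = 50                                # output channels of conv2
flattened_features = size * size * channels

print(f"Feature map: {size}x{size}, flattened features: {flattened_features}")
```

If you change the kernel sizes or the input resolution, the input dimension of `fc1` and the `x.view(...)` call need to change accordingly.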


In [None]:
def train_pytorch():
    import os

    import torch
    from torch import nn
    import torch.nn.functional as F

    from torchvision import datasets, transforms
    import torch.distributed as dist
    from torch.utils.data import DataLoader, DistributedSampler

    # [1] Configure CPU/GPU device and distributed backend.
    # Kubeflow Trainer will automatically configure the distributed environment.
    device, backend = ("cuda", "nccl") if torch.cuda.is_available() else ("cpu", "gloo")
    dist.init_process_group(backend=backend)

    local_rank = int(os.getenv("LOCAL_RANK", 0))
    print(
        "Distributed Training with WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}.".format(
            dist.get_world_size(),
            dist.get_rank(),
            local_rank,
        )
    )

    # [2] Define PyTorch CNN Model to be trained.
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 20, 5, 1)
            self.conv2 = nn.Conv2d(20, 50, 5, 1)
            self.fc1 = nn.Linear(4 * 4 * 50, 500)
            self.fc2 = nn.Linear(500, 10)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.max_pool2d(x, 2, 2)
            x = F.relu(self.conv2(x))
            x = F.max_pool2d(x, 2, 2)
            x = x.view(-1, 4 * 4 * 50)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return F.log_softmax(x, dim=1)

    # [3] Attach model to the correct device.
    device = torch.device(f"{device}:{local_rank}")
    model = nn.parallel.DistributedDataParallel(Net().to(device))
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # [4] Get the Fashion-MNIST dataset and distribute it across all available devices.
    dataset = datasets.FashionMNIST(
        "./data",
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    )
    train_loader = DataLoader(
        dataset,
        batch_size=100,
        sampler=DistributedSampler(dataset),
    )

    # [5] Define the training loop.
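    for epoch in range(3):
        # Reshuffle each rank's shard every epoch; without this call the
        # DistributedSampler reuses the same ordering for every epoch.
        train_loader.sampler.set_epoch(epoch)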
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            # Attach tensors to the device.
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = F.nll_loss(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if batch_idx % 10 == 0 and dist.get_rank() == 0:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(inputs),
                        len(train_loader.dataset),
                        100.0 * batch_idx / len(train_loader),
                        loss.item(),
                    )
                )

    # Wait for the training to complete and destroy the PyTorch distributed process group.
    dist.barrier()
    if dist.get_rank() == 0:
        print("Training is finished")
    dist.destroy_process_group()

## Create the trainer client

The `LocalTrainerClient` exposes the same interface as the `TrainerClient`.

To move your job to a Kubernetes cluster, switch to the `TrainerClient`; the rest of
the notebook works without modification. The cell below selects the client from the
`KUBEFLOW_TRAINER_EXEC_MODE` environment variable and defaults to `local`.

In [None]:
import os

from kubeflow.trainer import LocalTrainerClient, TrainerClient, CustomTrainer

exec_mode = os.getenv("KUBEFLOW_TRAINER_EXEC_MODE", "local")

if exec_mode == "local":
    client = LocalTrainerClient()
else:
    client = TrainerClient()

## List runtimes

Predefined training runtimes are included as part of the package. Currently, only the `torch-distributed` runtime is included.

In [None]:
client.list_runtimes()

## Get runtime

In [None]:
runtime = client.get_runtime("torch-distributed")

## Start the training job

This job uses the `torch-distributed` runtime to run the `train_pytorch` training function. The job runs on four worker nodes, each in its own container.
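
With four nodes, the `DistributedSampler` in the training function gives each rank its own shard of the data. The sketch below estimates the shard size and the number of batches per rank. It assumes the 60,000-image Fashion-MNIST training split, the sampler's default `drop_last=False` behavior, and the batch size of 100 used in `train_pytorch`. The helper `per_rank_batches` is a hypothetical name used only here.

```python
import math


def per_rank_batches(dataset_size: int, world_size: int, batch_size: int) -> tuple:
    """Return (samples per rank, batches per rank) for an evenly sharded dataset."""
    # With drop_last=False the sampler rounds the shard size up.
    samples_per_rank = math.ceil(dataset_size / world_size)
    batches_per_rank = math.ceil(samples_per_rank / batch_size)
    return samples_per_rank, batches_per_rank


samples, batches = per_rank_batches(dataset_size=60_000, world_size=4, batch_size=100)
print(f"Each rank sees {samples} samples in {batches} batches per epoch")
```

This also explains the shape of the logs: rank 0 prints a sample counter measured against `len(train_loader.dataset)`, which is the full dataset, while the percentage is measured against rank 0's own batches. The counter therefore stays well below the dataset size even as the percentage approaches 100.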

In [None]:
job_name = client.train(
    runtime=runtime,
    trainer=CustomTrainer(
        func=train_pytorch,
        num_nodes=4,
    ),
)

## Follow job logs

In [None]:
_ = client.get_job_logs(job_name, follow=True)

## Optional: Examine Docker resources

In a terminal, run the following to list the containers running the training job:

```shell
docker ps -a --filter label=trainer.kubeflow.org/train-job-name
```

Example output:

```
CONTAINER ID   IMAGE                                           COMMAND                   CREATED          STATUS                     PORTS     NAMES
e116e42af00a   pytorch/pytorch:2.5.0-cuda12.4-cudnn9-runtime   "bash -c '\nread -r -…"   11 minutes ago   Exited (0) 7 minutes ago             kubeflow-trainer-l20b00d08df5-3
0f9d891edd74   pytorch/pytorch:2.5.0-cuda12.4-cudnn9-runtime   "bash -c '\nread -r -…"   11 minutes ago   Exited (0) 7 minutes ago             kubeflow-trainer-l20b00d08df5-2
1672a222d360   pytorch/pytorch:2.5.0-cuda12.4-cudnn9-runtime   "bash -c '\nread -r -…"   11 minutes ago   Exited (0) 7 minutes ago             kubeflow-trainer-l20b00d08df5-1
e1f11b5fad3c   pytorch/pytorch:2.5.0-cuda12.4-cudnn9-runtime   "bash -c '\nread -r -…"   11 minutes ago   Exited (0) 7 minutes ago             kubeflow-trainer-l20b00d08df5-0
```

Run the following to see the Docker network:

```shell
docker network ls --filter label=trainer.kubeflow.org/train-job-name
```

Example output:

```
NETWORK ID     NAME                            DRIVER    SCOPE
157a6ed5752f   kubeflow-trainer-l20b00d08df5   bridge    local
```
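
If you prefer to stay inside the notebook, the same listings can be produced from Python. This is a rough sketch that shells out to the `docker` CLI with the label filter shown above. The helper name `docker_list_command` and the chosen `--format` columns are assumptions made for illustration, not part of the Kubeflow SDK.

```python
import shutil
import subprocess

# Label taken from the docker commands shown above.
JOB_LABEL = "trainer.kubeflow.org/train-job-name"


def docker_list_command(resource: str) -> list:
    """Build the docker CLI command that lists job containers or networks."""
    label_filter = f"label={JOB_LABEL}"
    if resource == "containers":
        return ["docker", "ps", "-a", "--filter", label_filter,
                "--format", "{{.Names}}\t{{.Status}}"]
    if resource == "networks":
        return ["docker", "network", "ls", "--filter", label_filter,
                "--format", "{{.Name}}\t{{.Driver}}"]
    raise ValueError(f"Unknown resource: {resource!r}")


# Only call docker if the CLI is installed; otherwise print a hint.
if shutil.which("docker") is not None:
    for resource in ("containers", "networks"):
        completed = subprocess.run(
            docker_list_command(resource),
            capture_output=True,
            text=True,
            check=False,
        )
        print(f"--- {resource} ---")
        print(completed.stdout.strip() or completed.stderr.strip())
else:
    print("docker CLI not found on PATH")
```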

## Delete the job

In [None]:
client.delete_job(job_name)