# Set up dataset and model

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

def get_dataset():
    return datasets.FashionMNIST(
        root="/tmp/data",
        train=True,
        download=True,
        transform=ToTensor(),
    )

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, inputs):
        inputs = self.flatten(inputs)
        logits = self.linear_relu_stack(inputs)
        return logits

# Define single-worker PyTorch training function

In [2]:
def train_func():
    num_epochs = 3
    batch_size = 64

    dataset = get_dataset()
    dataloader = DataLoader(dataset, batch_size=batch_size)

    model = NeuralNetwork()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(num_epochs):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            pred = model(inputs)
            loss = criterion(pred, labels)
            loss.backward()
            optimizer.step()
        print(f"epoch: {epoch}, loss: {loss.item()}")

# Execute training function

In [3]:
train_func()

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:15<00:00, 1657025.38it/s]


Extracting /tmp/data/FashionMNIST/raw/train-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 597431.05it/s]


Extracting /tmp/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 2551066.34it/s]


Extracting /tmp/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 14252328.05it/s]


Extracting /tmp/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw

epoch: 0, loss: 0.9412856101989746
epoch: 1, loss: 0.7699680328369141
epoch: 2, loss: 0.7065638303756714


# Convert this to a distributed multi-worker training function
Use the ray.train.torch.prepare_model and ray.train.torch.prepare_data_loader utility functions to set up your model and data for distributed training. This automatically wraps the model with DistributedDataParallel and places it on the right device, and adds DistributedSampler to the DataLoaders.

In [5]:
import ray.train.torch

def train_func_distributed():
    num_epochs = 3
    batch_size = 64

    dataset = get_dataset()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    dataloader = ray.train.torch.prepare_data_loader(dataloader)

    model = NeuralNetwork()
    model = ray.train.torch.prepare_model(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(num_epochs):
        if ray.train.get_context().get_world_size() > 1:
            dataloader.sampler.set_epoch(epoch)

        for inputs, labels in dataloader:
            optimizer.zero_grad()
            pred = model(inputs)
            loss = criterion(pred, labels)
            loss.backward()
            optimizer.step()
        print(f"epoch: {epoch}, loss: {loss.item()}")

# Instantiate a TorchTrainer with 4 workers, and use it to run the new training function.

In [6]:
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
# For GPU Training, set `use_gpu` to True.
use_gpu = False
trainer = TorchTrainer(
    train_func_distributed,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=use_gpu))
results = trainer.fit()

0,1
Current time:,2024-02-27 14:49:32
Running for:,00:00:46.34
Memory:,6.5/7.7 GiB

Trial name,status,loc
TorchTrainer_ed52b_00000,TERMINATED,192.168.31.225:142601


[36m(RayTrainWorker pid=142658)[0m Setting up process group for: env:// [rank=0, world_size=4]
[36m(TorchTrainer pid=142601)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=142601)[0m - (ip=192.168.31.225, pid=142658) world_rank=0, local_rank=0, node_rank=0
[36m(TorchTrainer pid=142601)[0m - (ip=192.168.31.225, pid=142659) world_rank=1, local_rank=1, node_rank=0
[36m(TorchTrainer pid=142601)[0m - (ip=192.168.31.225, pid=142660) world_rank=2, local_rank=2, node_rank=0
[36m(TorchTrainer pid=142601)[0m - (ip=192.168.31.225, pid=142661) world_rank=3, local_rank=3, node_rank=0
[36m(RayTrainWorker pid=142658)[0m Moving model to device: cpu
[36m(RayTrainWorker pid=142658)[0m Wrapping provided model in DistributedDataParallel.


[36m(RayTrainWorker pid=142658)[0m epoch: 0, loss: 1.722748875617981
[36m(RayTrainWorker pid=142658)[0m epoch: 1, loss: 0.8981984257698059[32m [repeated 4x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)[0m
[36m(RayTrainWorker pid=142658)[0m epoch: 2, loss: 0.7986962795257568[32m [repeated 4x across cluster][0m
Trial TorchTrainer_ed52b_00000 completed. Last result: 


2024-02-27 14:49:32,314	INFO tune.py:1042 -- Total run time: 46.50 seconds (46.33 seconds for the tuning loop).
