# Getting Started - Distributed Model Training #
Ray's quickstart on Distributed Model Training:
https://docs.ray.io/en/latest/ray-overview/getting-started.html

combined with Torch's Quickstart guide:
https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import torchmetrics

In [None]:
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

In [None]:
# Get cpu, gpu or mps device for training.
# device = (
#     "cuda"
#     if torch.cuda.is_available()
#     else "mps"
#     if torch.backends.mps.is_available()
#     else "cpu"
# )
device="cpu"

class NaiveDense(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, inputs):
        inputs = self.flatten(inputs)
        logits = self.linear_relu_stack(inputs)
        return logits

def train_func():
    num_epochs=2
    batch_size=128
    model_state_name = "serial_model.pth"
    train_dataloader = DataLoader(training_data, batch_size=batch_size)

    model = NaiveDense().to(device)
    print(f"Using {device} device")
    print(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    size = len(train_dataloader.dataset)
    model.train()
    for epoch in range(num_epochs):
        print(f"training epoch [{epoch}/{num_epochs}]")
        for batch, (X, y) in enumerate(train_dataloader):
            X, y = X.to(device), y.to(device)

            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 100 == 0:
                loss, current = loss.item(), (batch + 1) * len(X)
                print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
    torch.save(model.state_dict(), model_state_name)
    print(f"Saved PyTorch Model State to {model_state_name}")


This training function can now be executed. note the execution time

In [None]:
%%time
# good to run this anyway to download the packages to tmp if first time
train_func()

In [None]:
def test_func(model_state="model.pth"):
    # note that the dataloader handles the last batch not matching
    # the batch size by automatically adjusting the batch size
    batch_size=512
    test_dataloader = DataLoader(test_data, batch_size=batch_size)
    model = NaiveDense().to(device)
    model.load_state_dict(torch.load(model_state))
    loss_fn = nn.CrossEntropyLoss()
    size = len(test_dataloader.dataset)
    num_batches = len(test_dataloader)
    model.eval()
    test_loss, correct = 0, 0
    for epoch in range(1):
        # test_total = 0
        with torch.no_grad():
            for X, y in test_dataloader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
                # test_total += len(y)
                # print(f"batch size: {len(y)}, Test total: {test_total} / {size}")
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [119]:
test_func("serial_model.pth")

Test Error: 
 Accuracy: 69.9%, Avg loss: 0.797402 



Now let’s convert this to a distributed multi-worker training function!

All you have to do is use the **ray.train.torch.prepare_model** and **ray.train.torch.prepare_data_loader** utility functions to easily setup your model & data for distributed training. This will automatically wrap your model with DistributedDataParallel and place it on the right device, and add DistributedSampler to your DataLoaders.



In [76]:
import ray.train.torch
from ray import train
from ray.train.torch import TorchTrainer
from ray.air import session, Checkpoint, RunConfig, CheckpointConfig, ScalingConfig

In [85]:
def train_func_distributed(config):
    batch_size=128
    train_dataloader = DataLoader(training_data, batch_size=batch_size)
    train_dataloader = train.torch.prepare_data_loader(train_dataloader)

    model = NaiveDense()
    model = train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    size = len(train_dataloader.dataset)
    acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
    # for averaging loss
    mean_valid_loss = torchmetrics.MeanMetric()

    for epoch in range(config["num_epochs"]):
        for batch, (X, y) in enumerate(train_dataloader):
            model.train()
            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # evaluate
            model.eval()
            with torch.no_grad():
                pred = model(X)
                valid_loss = loss_fn(pred, y)
                # save loss in aggregator
                mean_valid_loss(valid_loss)
                acc(pred, y)

            # collect all metrics
            # use .item() to obtain a value that can be reported
            valid_loss = valid_loss.item()
            accuracy_collected = acc.compute().item()
            mean_valid_loss_collected = mean_valid_loss.compute().item()

            # save checkpoints
            state_dict = model.state_dict()
            checkpoint = Checkpoint.from_dict(
                dict(epoch=epoch, model_weights=state_dict)
            )
            session.report({
                "accuracy_collected": accuracy_collected,
                "valid_loss": valid_loss,
                "mean_valid_loss_collected": mean_valid_loss_collected,
            }, checkpoint=checkpoint)

            # reset for next epoch
            acc.reset()
            mean_valid_loss.reset()

Note setting GPU=True here does not work unless you actually have multiple GPUs!!

In [120]:
# %%time
# For GPU Training, set `use_gpu` to True.
use_gpu = False
checkpoint_config = CheckpointConfig(
    num_to_keep=1, checkpoint_score_attribute="valid_loss", checkpoint_score_order="min"
)
trainer = TorchTrainer(
    train_func_distributed,
    train_loop_config={"num_epochs": 3},
    scaling_config=ScalingConfig(num_workers=5, use_gpu=use_gpu),
    run_config=RunConfig(checkpoint_config=checkpoint_config)
)

result = trainer.fit()
print(result.metrics)

2023-05-14 03:34:04,861	INFO data_parallel_trainer.py:357 -- GPUs are detected in your Ray cluster, but GPU training is not enabled for this trainer. To enable GPU training, make sure to set `use_gpu` to True in your scaling config.
[2m[36m(TrainTrainable pid=141220)[0m 2023-05-14 03:34:06,621	INFO data_parallel_trainer.py:357 -- GPUs are detected in your Ray cluster, but GPU training is not enabled for this trainer. To enable GPU training, make sure to set `use_gpu` to True in your scaling config.
[2m[36m(TorchTrainer pid=141220)[0m 2023-05-14 03:34:06,624	INFO data_parallel_trainer.py:357 -- GPUs are detected in your Ray cluster, but GPU training is not enabled for this trainer. To enable GPU training, make sure to set `use_gpu` to True in your scaling config.
[2m[36m(RayTrainWorker pid=141265)[0m 2023-05-14 03:34:12,335	INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=5]
[2m[36m(RayTrainWorker pid=141265)[0m 2023-05-14 03:34:13,908	INFO train

Trial name,accuracy_collected,date,done,hostname,iterations_since_restore,mean_valid_loss_collected,node_ip,pid,should_checkpoint,time_since_restore,time_this_iter_s,time_total_s,timestamp,training_iteration,trial_id,valid_loss
TorchTrainer_b429c_00000,0.0984375,2023-05-14_03-34-14,False,jim-desktop,1,2.2977,192.168.0.135,141220,True,7.49983,7.49983,7.49983,1684049654,1,b429c_00000,2.29408


2023-05-14 03:34:59,881	INFO tune.py:945 -- Total run time: 55.03 seconds (55.02 seconds for the tuning loop).


{'accuracy_collected': 0.6187499761581421, 'valid_loss': 1.4619611501693726, 'mean_valid_loss_collected': 1.491119623184204, 'timestamp': 1684049697, 'time_this_iter_s': 0.11327934265136719, 'should_checkpoint': True, 'done': True, 'training_iteration': 282, 'trial_id': 'b429c_00000', 'date': '2023-05-14_03-34-57', 'time_total_s': 50.72310280799866, 'pid': 141220, 'hostname': 'jim-desktop', 'node_ip': '192.168.0.135', 'config': {'train_loop_config': {'num_epochs': 3}}, 'time_since_restore': 50.72310280799866, 'iterations_since_restore': 282, 'experiment_tag': '0'}


In [121]:
result.checkpoint  # last saved checkpoint
result.best_checkpoints

[(TorchCheckpoint(local_path=/home/jim/ray_results/TorchTrainer_2023-05-14_03-34-04/TorchTrainer_b429c_00000_0_2023-05-14_03-34-04/checkpoint_000281),
  {'accuracy_collected': 0.6187499761581421,
   'valid_loss': 1.4619611501693726,
   'mean_valid_loss_collected': 1.491119623184204,
   'timestamp': 1684049697,
   'time_this_iter_s': 0.11327934265136719,
   'should_checkpoint': True,
   'done': False,
   'training_iteration': 282,
   'trial_id': 'b429c_00000',
   'date': '2023-05-14_03-34-57',
   'time_total_s': 50.72310280799866,
   'pid': 141220,
   'hostname': 'jim-desktop',
   'node_ip': '192.168.0.135',
   'time_since_restore': 50.72310280799866,
   'iterations_since_restore': 282,
   'experiment_tag': '0',
   'config/train_loop_config/num_epochs': 3})]

In [122]:
checkpoint_dict = result.best_checkpoints[0][0].to_dict()
torch.save(checkpoint_dict.get("model_weights"), "parallel_model.pth")

In [123]:
test_func("parallel_model.pth")

Test Error: 
 Accuracy: 60.8%, Avg loss: 1.508966 



unsure why the distributed version has a much lower accuracy...