## Ray training with PyTorch

In [1]:
# !pip install 'ray[default]'

In [2]:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms

import ray
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.torch import TrainingOperator
# https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
from ray.util.sgd.torch.resnet import ResNet18


In [4]:
# Connect to the remote cluster through Ray Client. The address must be
# "ray://<host>:<port>" with no whitespace between the scheme and the host.
ray.init("ray://10.3.249.181:10001")

In [3]:
def cifar_creator(config):
    """Returns dataloaders to be used in `train` and `validate`.

    Args:
        config: Mapping that must contain the key ``"batch"``; its value is
            used as ``batch_size`` for both loaders.

    Returns:
        A ``(train_loader, validation_loader)`` tuple. The training loader
        reads the CIFAR-10 training split and the validation loader reads
        the held-out test split, so validation metrics are not computed on
        images the model was trained on.
    """
    # Convert images to tensors, then normalize each channel with the given
    # per-channel mean and standard deviation.
    tfms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])  # meanstd transformation
    train_loader = DataLoader(
        CIFAR10(root="~/data", train=True, download=True, transform=tfms),
        batch_size=config["batch"])
    # train=False selects the test split; without it CIFAR10 defaults to the
    # training split and validation would simply re-score the training data.
    validation_loader = DataLoader(
        CIFAR10(root="~/data", train=False, download=True, transform=tfms),
        batch_size=config["batch"])
    return train_loader, validation_loader

def optimizer_creator(model, config):
    """Build the SGD optimizer for ``model``.

    Args:
        model: Module whose ``parameters()`` are handed to the optimizer.
        config: Mapping that must contain the key ``"lr"`` (learning rate).

    Returns:
        A ``torch.optim.SGD`` instance over the model's parameters.
    """
    trainable_params = model.parameters()
    learning_rate = config["lr"]
    return torch.optim.SGD(trainable_params, lr=learning_rate)

# Assemble the training operator from the creator callables; the result is
# what gets passed to TorchTrainer as `training_operator_cls`.
CustomTrainingOperator = TrainingOperator.from_creators(
    # model and its training data
    model_creator=ResNet18,
    data_creator=cifar_creator,
    # optimization: optimizer factory and loss
    optimizer_creator=optimizer_creator,
    loss_creator=torch.nn.CrossEntropyLoss,
)


In [5]:
# Create the distributed trainer.
#   config["lr"]    -> read by optimizer_creator
#   config["batch"] -> read by cifar_creator
trainer = TorchTrainer(
    training_operator_cls=CustomTrainingOperator,
    config=dict(lr=0.01, batch=64),
    num_workers=2,  # number of parallel workers
    use_gpu=torch.cuda.is_available(),  # use GPUs only if CUDA is visible here
    use_tqdm=True,
)

[2m[36m(pid=1925)[0m 2021-08-26 20:07:53,474	INFO distributed_torch_runner.py:58 -- Setting up process group for: tcp://10.1.229.3:60591 [rank=0]
[2m[36m(pid=1930)[0m 2021-08-26 20:07:53,509	INFO distributed_torch_runner.py:58 -- Setting up process group for: tcp://10.1.229.3:60591 [rank=1]


[2m[36m(pid=1925)[0m Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /root/data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]
  0%|          | 361472/170498071 [00:00<00:47, 3560279.01it/s]
  2%|▏         | 2802688/170498071 [00:00<00:10, 15706361.63it/s]
  4%|▎         | 5981184/170498071 [00:00<00:07, 22966786.72it/s]
  5%|▌         | 8832000/170498071 [00:00<00:06, 25134476.46it/s]
  7%|▋         | 11682816/170498071 [00:00<00:06, 26269610.42it/s]
  9%|▊         | 14517248/170498071 [00:00<00:05, 26928265.64it/s]
 11%|█         | 18991104/170498071 [00:00<00:04, 32728906.50it/s]
 14%|█▍        | 24266752/170498071 [00:00<00:03, 39084635.31it/s]
 18%|█▊        | 30659584/170498071 [00:00<00:02, 46836857.39it/s]
 23%|██▎       | 38471680/170498071 [00:01<00:02, 56480128.74it/s]
 28%|██▊       | 48194560/170498071 [00:01<00:01, 68939832.40it/s]
 35%|███▍      | 59291648/170498071 [00:01<00:01, 81716631.59it/s]
 41%|████▏     | 70641664/170498071 [00:01<00:01, 91335166.57it/s]
 48%|████▊     | 81762304/170498071 [00:01<00:00, 97332670.47it/s]
 55%|█████▍    | 92948

[2m[36m(pid=1925)[0m Extracting /root/data/cifar-10-python.tar.gz to /root/data
[2m[36m(pid=1925)[0m Files already downloaded and verified
[2m[36m(pid=1930)[0m Files already downloaded and verified
[2m[36m(pid=1930)[0m Files already downloaded and verified


In [None]:
# Run training and keep whatever trainer.train() returns in `stats`,
# then print the result of trainer.validate().
stats = trainer.train()
validation_stats = trainer.validate()
print(validation_stats)

In [None]:
# Write the trainer's state dict to a local checkpoint file, then shut the
# trainer down and report completion.
checkpoint_path = "checkpoint.pt"
trainer_state = trainer.state_dict()
torch.save(trainer_state, checkpoint_path)
trainer.shutdown()
print("success!")