Implementing the PyTorch example from the Ray Train guide: https://docs.ray.io/en/latest/train/train.html

In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim

from ray import train
import ray.train.torch
from ray.train import Trainer

from ray.util import connect as ray_connect
from ray.util import disconnect as ray_disconnect
from ray.util.client import ray as rayclient

In [None]:
# Set to True to attach this notebook to a remote Ray cluster through Ray Client.
# The head node's host is taken from the RAY_CLUSTER environment variable.
REMOTE = False
if REMOTE:
    # Close any existing client session before opening a new one.
    if rayclient.is_connected():
        ray_disconnect()

    ray_connect(f"{os.environ['RAY_CLUSTER']}:10001")

In [2]:
# Problem dimensions for the toy regression task.
num_samples = 20
input_size = 10
layer_size = 15
output_size = 5

# Fixed seed so the random dataset (and the local training run that follows)
# reproduce under Restart Kernel -> Run All.
SEED = 0


class NeuralNetwork(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Linear.

    Args:
        in_features: Input width. Defaults to the module-level ``input_size``.
        hidden_features: Hidden-layer width. Defaults to ``layer_size``.
        out_features: Output width. Defaults to ``output_size``.

    Defaults are resolved when an instance is constructed (not when the class
    is defined), so ``NeuralNetwork()`` reads the current module-level sizes.
    """

    def __init__(self, in_features=None, hidden_features=None, out_features=None):
        super().__init__()
        in_features = input_size if in_features is None else in_features
        hidden_features = layer_size if hidden_features is None else hidden_features
        out_features = output_size if out_features is None else out_features

        self.layer1 = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(hidden_features, out_features)

    def forward(self, input):
        """Apply layer1, ReLU, then layer2 to ``input``."""
        return self.layer2(self.relu(self.layer1(input)))


# In this example we use a randomly generated dataset.
# NOTE: the name ``input`` shadows the Python builtin; it is kept because the
# training cells below read these globals by name.
torch.manual_seed(SEED)
input = torch.randn(num_samples, input_size)
labels = torch.randn(num_samples, output_size)

In [3]:
def train_func():
    """Train a fresh NeuralNetwork on the global ``input``/``labels`` in-process.

    Runs three full-batch SGD epochs (lr=0.1) against an MSE objective and
    prints the loss computed before each optimizer step.
    """
    epochs = 3
    net = NeuralNetwork()
    criterion = nn.MSELoss()
    sgd = optim.SGD(net.parameters(), lr=0.1)

    for step in range(epochs):
        batch_loss = criterion(net(input), labels)
        sgd.zero_grad()
        batch_loss.backward()
        sgd.step()
        print(f"epoch: {step}, loss: {batch_loss.item()}")

In [4]:
# Sanity check: run the plain single-process training loop locally before distributing it.
train_func()

epoch: 0, loss: 1.1057108640670776
epoch: 1, loss: 1.0742485523223877
epoch: 2, loss: 1.0476206541061401


In [5]:
def train_func_distributed():
    """Per-worker training loop executed by Ray Train.

    Same loop as the local version, but the model is wrapped with
    ``train.torch.prepare_model`` so Ray can set it up for distributed
    execution. Ray Train passes a ``config`` dict to training functions that
    take one argument, so keep this signature parameter-less unless
    ``trainer.run`` is also given ``config=``.
    """
    num_epochs = 3
    model = NeuralNetwork()
    model = train.torch.prepare_model(model)
    # prepare_model may place the model on a GPU (e.g. with use_gpu=True); put the
    # data on the same device so the forward pass does not mix CPU and CUDA tensors.
    # On CPU, .to() returns the original tensors unchanged.
    device = next(model.parameters()).device
    features = input.to(device)
    targets = labels.to(device)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    for epoch in range(num_epochs):
        output = model(features)
        loss = loss_fn(output, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"epoch: {epoch}, loss: {loss.item()}")

In [6]:
# Ray Train settings: number of worker processes and whether to train on GPUs.
NUM_WORKERS = 4
USE_GPU = False  # set to True for GPU training
trainer = Trainer(backend="torch", num_workers=NUM_WORKERS, use_gpu=USE_GPU)

2022-05-18 17:37:19,773	INFO trainer.py:223 -- Trainer logs will be logged in: /home/mcliffor/ray_results/train_2022-05-18_17-37-19


In [7]:
# For GPU training, construct the Trainer with `use_gpu=True`, e.g.
# Trainer(backend="torch", num_workers=4, use_gpu=True).

trainer.start()
try:
    # Every worker executes train_func_distributed; `results` collects one
    # return value per worker.
    results = trainer.run(train_func_distributed)
finally:
    # Release the worker actors even if training raises, so a re-run starts clean.
    trainer.shutdown()

[2m[36m(BaseWorkerMixin pid=3131801)[0m 2022-05-18 17:37:45,391	INFO torch.py:335 -- Setting up process group for: env:// [rank=3, world_size=4]
[2m[36m(BaseWorkerMixin pid=3131801)[0m [W socket.cpp:558] [c10d] The client socket cannot be initialized to connect to [::ffff:192.168.1.222]:39107 (errno: 97 - Address family not supported by protocol).
[2m[36m(BaseWorkerMixin pid=3131801)[0m [W socket.cpp:558] [c10d] The client socket cannot be initialized to connect to [::ffff:192.168.1.222]:39107 (errno: 97 - Address family not supported by protocol).
[2m[36m(BaseWorkerMixin pid=3131800)[0m 2022-05-18 17:37:45,377	INFO torch.py:335 -- Setting up process group for: env:// [rank=2, world_size=4]
[2m[36m(BaseWorkerMixin pid=3131800)[0m [W socket.cpp:558] [c10d] The client socket cannot be initialized to connect to [::ffff:192.168.1.222]:39107 (errno: 97 - Address family not supported by protocol).
[2m[36m(BaseWorkerMixin pid=3131800)[0m [W socket.cpp:558] [c10d] The client 

[2m[36m(BaseWorkerMixin pid=3131801)[0m epoch: 0, loss: 1.0745793581008911
[2m[36m(BaseWorkerMixin pid=3131801)[0m epoch: 1, loss: 1.0589888095855713
[2m[36m(BaseWorkerMixin pid=3131801)[0m epoch: 2, loss: 1.0449146032333374
[2m[36m(BaseWorkerMixin pid=3131800)[0m epoch: 0, loss: 1.0745793581008911
[2m[36m(BaseWorkerMixin pid=3131800)[0m epoch: 1, loss: 1.0589888095855713
[2m[36m(BaseWorkerMixin pid=3131800)[0m epoch: 2, loss: 1.0449146032333374
[2m[36m(BaseWorkerMixin pid=3131798)[0m epoch: 0, loss: 1.0745793581008911
[2m[36m(BaseWorkerMixin pid=3131798)[0m epoch: 1, loss: 1.0589888095855713
[2m[36m(BaseWorkerMixin pid=3131798)[0m epoch: 2, loss: 1.0449146032333374
[2m[36m(BaseWorkerMixin pid=3131799)[0m epoch: 0, loss: 1.0745793581008911
[2m[36m(BaseWorkerMixin pid=3131799)[0m epoch: 1, loss: 1.0589888095855713
[2m[36m(BaseWorkerMixin pid=3131799)[0m epoch: 2, loss: 1.0449146032333374
