In [None]:
import os
from typing import Dict

import torch
import torch.nn.functional as F

import ray
import ray.train as train
from ray.train.trainer import Trainer
from ray.train.callbacks import JsonLoggerCallback
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [None]:
# Fetch the FashionMNIST training split into ~/data; ToTensor() is applied to every image.
training_data = datasets.FashionMNIST(
    root="~/data", train=True, transform=ToTensor(), download=True
)
# Same dataset, held-out test split (train=False).
test_data = datasets.FashionMNIST(
    root="~/data", train=False, transform=ToTensor(), download=True
)

In [None]:
# Define model-1: a three-layer MLP that returns raw class scores (logits)
class NeuralNetwork(nn.Module):
    """Fully connected classifier over flattened 28x28 inputs.

    The network is Linear(784, 512) -> ReLU -> Linear(512, 512) -> ReLU ->
    Linear(512, 10). The output is left as unnormalised logits so it can be
    fed directly to ``nn.CrossEntropyLoss``, which applies log-softmax itself.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # No activation after the last Linear layer: a trailing ReLU would
        # clamp every logit to >= 0, which prevents the network from pushing
        # down the score of wrong classes and cripples cross-entropy training.
        # ReLU has no parameters, so state_dict keys are unaffected.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Flatten ``x`` per sample and return a ``(batch, 10)`` tensor of logits."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

In [None]:
# Define model-2: a smaller MLP with dropout that emits log-probabilities
class Classifier(nn.Module):
    """784 -> 120 -> 120 -> 10 fully connected network.

    Dropout (p=0.2) follows each hidden ReLU. The output goes through
    ``log_softmax`` over dim 1, i.e. the form ``nn.NLLLoss`` consumes.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 120)
        self.fc2 = nn.Linear(120, 120)
        self.fc3 = nn.Linear(120, 10)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Return per-class log-probabilities for a batch ``x``."""
        # Collapse everything after the batch dimension into one feature axis.
        batch = x.shape[0]
        hidden = x.view(batch, -1)
        # Both hidden layers share the same Linear -> ReLU -> Dropout pattern.
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(F.relu(layer(hidden)))
        return F.log_softmax(self.fc3(hidden), dim=1)

In [None]:
# Define accuracy function
def accuracy_fn(y_pred, y_true):
    """Return the percentage (0-100) of positions where ``y_pred`` equals ``y_true``.

    The denominator is ``len(y_pred)``.
    """
    matches = torch.eq(y_pred, y_true)
    n_hits = matches.sum().item()
    fraction = n_hits / len(y_pred)
    return fraction * 100

In [None]:
def train_epoch(dataloader, model, loss_fn, optimizer, epoch):
    """Run one optimisation pass over every batch in ``dataloader``.

    Args:
        dataloader: iterable yielding ``(X, y)`` batches.
        model: the ``nn.Module`` being trained; switched to training mode here.
        loss_fn: callable taking ``(pred, y)`` and returning a scalar loss tensor.
        optimizer: optimizer whose parameters are updated once per batch.
        epoch: index of the current epoch (currently unused; kept so the
            signature matches ``validate_epoch``).
    """
    # validate_epoch() puts the model in eval mode and never switches it back,
    # so training mode must be re-enabled here. Otherwise dropout stays off
    # from the second epoch onwards.
    model.train()
    for X, y in dataloader:
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [None]:
def validate_epoch(dataloader, model, loss_fn, epoch):
    """Evaluate ``model`` on ``dataloader`` and return the mean per-batch loss.

    The model is switched to eval mode and left there; gradients are disabled
    for the duration of the pass. When ``epoch`` is a positive multiple of 50,
    the average loss and the accuracy are printed.

    Args:
        dataloader: sized iterable yielding ``(X, y)`` batches.
        model: the ``nn.Module`` to evaluate.
        loss_fn: callable taking ``(pred, y)`` and returning a scalar loss tensor.
        epoch: index of the current epoch; only controls whether to print.

    Returns:
        float: sum of the per-batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: if ``dataloader`` has no batches.
    """
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).sum().item()
            # Count the samples actually iterated rather than using
            # len(dataloader.dataset): a sharded/distributed loader only
            # yields part of the dataset.
            seen += len(y)
    test_loss /= num_batches
    if epoch > 0 and epoch % 50 == 0:
        # Sample-weighted accuracy: averaging per-batch accuracies would give
        # a short final batch the same weight as a full one.
        acc = 100 * correct / seen
        print(f"Epoc: {epoch}, Avg validation loss: {test_loss:.2f}, Avg validation accuracy: {acc:.2f}%") 
        print("--" * 40)
    return test_loss

In [None]:
def train_func(config: Dict):
    """Per-worker training loop executed by Ray Train.

    Reads hyperparameters from ``config``, wraps the module-level
    ``training_data`` / ``test_data`` in data loaders, then alternates
    ``train_epoch`` and ``validate_epoch`` for the requested number of epochs,
    reporting each validation loss through ``train.report``.

    Recognised ``config`` keys (all optional):
        batch_size (default 64), lr (default 1e-3), epochs (default 20),
        momentum (default 0.9),
        model_type: truthy selects ``Classifier``, otherwise ``NeuralNetwork``,
        loss_fn: loss callable; when absent or ``None`` a loss matching the
            selected model is used.

    Returns:
        list: the validation loss returned by ``validate_epoch`` for each epoch.
    """
    batch_size = config.get("batch_size", 64)
    lr = config.get("lr", 1e-3)
    epochs = config.get("epochs", 20)
    momentum = config.get("momentum", 0.9)
    model_type = config.get("model_type", None)

    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=batch_size)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    # Prepare to use Ray integrated wrappers around PyTorch's Dataloaders
    train_dataloader = train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = train.torch.prepare_data_loader(test_dataloader)

    # Create model.
    use_classifier = bool(model_type)
    model = Classifier() if use_classifier else NeuralNetwork()
    # Prepare to use Ray integrated wrappers around PyTorch's model
    model = train.torch.prepare_model(model)

    # Get the objective loss function. The fallback has to match the model's
    # output: Classifier already applies log_softmax (-> NLLLoss), whereas
    # NeuralNetwork returns unnormalised scores (-> CrossEntropyLoss).
    loss_fn = config.get("loss_fn")
    if loss_fn is None:
        loss_fn = nn.NLLLoss() if use_classifier else nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    loss_results = []

    for e in range(epochs):
        train_epoch(train_dataloader, model, loss_fn, optimizer, e)
        loss = validate_epoch(test_dataloader, model, loss_fn, e)
        train.report(loss=loss)
        loss_results.append(loss)

    return loss_results

In [None]:
def train_fashion_mnist(num_workers=2, use_gpu=False):
    """Train on FashionMNIST with a Ray Train ``Trainer`` and return its result.

    Args:
        num_workers: number of training workers passed to ``Trainer``.
        use_gpu: forwarded to ``Trainer``; set True when the workers have GPUs.

    Returns:
        Whatever ``trainer.run`` returns for ``train_func`` (presumably one
        entry per worker; confirm against the Ray Train docs for your version).
    """
    trainer = Trainer(
        backend="torch", num_workers=num_workers, use_gpu=use_gpu)
    trainer.start()
    try:
        result = trainer.run(
            train_func=train_func,
            config={
                "lr": 1e-3,
                "batch_size": 128,
                "epochs": 150,
                "momentum": 0.9,
                "model_type": 0,                     # change to 1 for second NN model
                "loss_fn": nn.CrossEntropyLoss()     # change to nn.NLLLoss() for the second model
            },
            callbacks=[JsonLoggerCallback()])
    finally:
        # Always release the workers, even when training raises; otherwise
        # they would be left running after a failed run.
        trainer.shutdown()
    return result

In [None]:
# Run-time settings used by the cells below.
number_of_workers = 2                        # passed to train_fashion_mnist() as num_workers
use_gpu = False                              # change to True if using a Ray cluster with GPUs
address = "anyscale://ray_train_ddp_cluster" # use your anyscale cluster here; only read by the commented-out ray.init(address=address) call

In [None]:
# Start Ray on this machine. ignore_reinit_error=True lets this cell be re-run
# in the same kernel without raising because Ray is already initialised.
ray.init(ignore_reinit_error=True)                           # run locally
# Alternative: connect to the cluster named in `address` instead of running locally.
#ray.init(address=address)                                   # run on a Ray cluster on Anyscale

In [None]:
%%time
# Launch the full training run; %%time (which must stay on the first line of
# the cell) reports how long it took.
results = train_fashion_mnist(num_workers=number_of_workers, use_gpu=use_gpu)

In [None]:
# Number of logical CPUs on this machine. os.cpu_count() is used instead of
# grepping /proc/cpuinfo, which only exists on Linux.
logical_cpus = os.cpu_count()
logical_cpus