# Ray Tune - A Deeper Dive Using MNIST with PyTorch

© 2019-2020, Anyscale. All Rights Reserved

![Anyscale Academy](../images/AnyscaleAcademy_Logo_clearbanner_141x100.png)

The [previous notebook](01-Understanding-Hyperparameter-Tuning.ipynb) explained the concept of hyperparameter tuning/optimization (HPO) and walked through the basics of using [Ray Tune](https://ray.readthedocs.io/en/latest/tune.html). 

Now we'll use another example to explore more of the API features. We'll use the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset of handwritten digits and train a [PyTorch](https://pytorch.org/) model to recognize them.

For another excellent example using Ray Tune, see [this post](https://www.recogn.ai/biome-text/documentation/tutorials/Hyperparameter_optimization_with_Ray_Tune.html#download-the-data-and-create-the-vocabulary).

In [None]:
import os 
from torchvision import datasets, transforms
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from filelock import FileLock

## PyTorch Hyperparameter Tuning

Our example will closely follow the code in the [PyTorch MNIST example](https://github.com/pytorch/examples/blob/master/mnist/main.py). However, we will create an even simpler model than the one in the example, although you could try that model and compare its predictions.

Let's start by defining a few global variables for epoch and test sizes. Also define a data location.

In [None]:
EPOCH_SIZE = 512
TEST_SIZE = 256

DATA_ROOT = '../data/mnist'

The following class defines a convolutional neural network.

> **Tip:** Most of these code definitions can be found in `mnist.py`, too.

In [None]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, 192)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)
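
The `192` in `nn.Linear(192, 10)` and `x.view(-1, 192)` follows from the layer shapes. As a rough check, here is a minimal sketch of the arithmetic in plain Python, assuming 28x28 single-channel MNIST images, no padding, and the default pooling stride (equal to the kernel size). The helper names `conv2d_out` and `max_pool2d_out` are our own, not PyTorch functions.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output side length of a square convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def max_pool2d_out(size, kernel):
    """Output side length of max pooling when the stride equals the kernel size."""
    return (size - kernel) // kernel + 1


side = 28                        # MNIST images are 28x28 with one channel
side = conv2d_out(side, 3)       # after conv1 (kernel_size=3): 26
side = max_pool2d_out(side, 3)   # after max_pool2d(..., 3): 8
channels = 3                     # out_channels of conv1
flat_features = channels * side * side
print(flat_features)             # 192
```

If you change the kernel size or the number of output channels in `ConvNet`, this number changes too, and both `192` literals have to be updated together.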

After creating that network, we can now create our data loaders for training and test data. These are just plain [PyTorch `DataLoaders`](https://pytorch.org/docs/1.1.0/data.html?highlight=dataloader#torch.utils.data.DataLoader) with two additions:

1. A `FileLock` is added to ensure that only one process downloads the data on each machine, just in case we have multiple workers per machine in our Ray cluster.
2. The root directory for the data can be specified and it will be created if it doesn't exist.

Otherwise, this code is identical to the [PyTorch example version](https://github.com/pytorch/examples/blob/master/mnist/main.py#L101).

In [None]:
def get_data_loaders():
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # We add FileLock here because multiple workers on the same machine could try to
    # download the data at the same time, which could cause overwrites.
    # You wouldn't need this for single-process training.
    lock_file = f'{DATA_ROOT}/data.lock'
    # Create the data directory if needed; exist_ok avoids an error when another
    # worker creates it first. (`os` is imported at the top of the notebook.)
    os.makedirs(DATA_ROOT, exist_ok=True)
        
    with FileLock(os.path.expanduser(lock_file)):
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(DATA_ROOT, train=True, download=True, transform=mnist_transforms),
            batch_size=64,
            shuffle=True)

        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(DATA_ROOT, train=False, transform=mnist_transforms),
            batch_size=64,
            shuffle=True)
    return train_loader, test_loader

Now we define our training and test functions. The arguments are ordered a bit differently from the original PyTorch example, but the difference is inconsequential. The training function takes a model, an optimizer, the training data loader, and our device, and trains the model for one shortened "epoch" of roughly `EPOCH_SIZE` samples.

In [None]:
def train(model, optimizer, train_loader, device=torch.device("cpu")):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx * len(data) > EPOCH_SIZE:
            return
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
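
To see how much data one call to `train` actually touches, here is a small sketch that replays only the early-exit check from the loop above, with no model involved. It assumes the batch size of 64 used in `get_data_loaders`; the helper `batches_processed` is a name we made up for illustration.

```python
EPOCH_SIZE = 512
BATCH_SIZE = 64  # batch size used by the data loaders above


def batches_processed(limit, batch_size, num_batches=1000):
    """Count batches handled before `batch_idx * batch_size > limit` triggers."""
    processed = 0
    for batch_idx in range(num_batches):
        if batch_idx * batch_size > limit:
            break
        processed += 1
    return processed


n = batches_processed(EPOCH_SIZE, BATCH_SIZE)
print(n, n * BATCH_SIZE)  # 9 576
```

Because the comparison is a strict `>`, the batch at index 8 (where `8 * 64 == 512`) is still processed, so each call trains on 9 batches, slightly more than `EPOCH_SIZE` samples.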

Similarly, for our test function we define a basic _average correct prediction_ metric (accuracy) that we will track. We could add more metrics, but we'll keep it simple.

In [None]:
def test(model, data_loader, device=torch.device("cpu")):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            if batch_idx * len(data) > TEST_SIZE:
                break
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    return correct / total
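
The metric itself is just "fraction of samples whose highest-scoring class equals the label". Here is a minimal sketch of that computation in plain Python, without tensors. The scores and labels below are made up purely for illustration, and the `accuracy` helper is our own name.

```python
def accuracy(outputs, targets):
    """Fraction of rows whose arg-max class index matches the target label."""
    correct = 0
    for scores, target in zip(outputs, targets):
        # Index of the largest score, the plain-Python analogue of torch.max(..., 1)
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == target)
    return correct / len(targets)


# Made-up log-probabilities for four samples and three classes
outputs = [
    [-0.1, -2.5, -3.0],
    [-2.0, -0.3, -1.9],
    [-1.2, -0.9, -0.8],
    [-0.2, -1.9, -3.1],
]
targets = [0, 1, 1, 0]
print(accuracy(outputs, targets))  # 0.75
```

The third sample is the only miss: its largest score is at index 2, but its label is 1.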

Finally, we create a wrapper function for this particular model. All we need to do is specify the configuration (learning rate and momentum) for the model that we would like to train, and the function will do the rest:

1. Retrieve the data with the loaders returned by `get_data_loaders()`
2. Create the `ConvNet` model
3. Optimize the model using _stochastic gradient descent_.

In [None]:
def train_mnist(config):
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config['momentum'])
    for i in range(10):
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)
        print(acc)

### Single-Node Hyperparameter Tuning

Let's show what we might do if we performed hyperparameter tuning on a single machine. We would have to enumerate all the possibilities and either train them serially or use something like multiprocessing to train them in parallel. That setup takes a little bit of work so people often decide to train them serially, which is easiest, but requires the most time.

This is what we might do.

In [None]:
import itertools
conf = {
    "lr": [0.001, 0.01, 0.1],
    "momentum": [0.001, 0.01, 0.1, 0.9]
}

combinations = list(itertools.product(*conf.values()))
print(len(combinations))

In [None]:
for lr, momentum in combinations:
    train_mnist({"lr":lr, "momentum":momentum})
    break # we'll stop this after one run and just use it for illustrative purposes
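
For comparison, here is a rough sketch of what the "train them in parallel yourself" option could look like with only the standard library. It uses a thread pool to keep the example short and portable; for real CPU-bound training you would more likely use a process pool. `fake_train` is a hypothetical stand-in for `train_mnist` that returns a made-up score, so no real results are implied.

```python
import itertools
from concurrent.futures import ThreadPoolExecutor

conf = {
    "lr": [0.001, 0.01, 0.1],
    "momentum": [0.001, 0.01, 0.1, 0.9],
}


def expand_grid(grid):
    """Turn a dict of value lists into a list of per-run config dicts."""
    keys = list(grid)
    return [dict(zip(keys, values)) for values in itertools.product(*grid.values())]


def fake_train(config):
    # Hypothetical stand-in for train_mnist: a made-up score, not a real accuracy
    return 1.0 - abs(config["lr"] - 0.01) - abs(config["momentum"] - 0.9)


configs = expand_grid(conf)
with ThreadPoolExecutor(max_workers=4) as pool:
    # pool.map preserves input order, so scores[i] belongs to configs[i]
    scores = list(pool.map(fake_train, configs))

best_score, best_config = max(zip(scores, configs), key=lambda pair: pair[0])
print(len(configs), best_config)
```

Even this small version has to handle expanding the grid, scheduling the runs, and collecting and ranking results by hand, and it still only uses one machine.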

### Distributed Hyperparameter Tuning with Ray Tune

Ray Tune makes it trivial to move this code from a single node to multiple nodes. Let's see how to use the code we've written with Ray Tune.

First, we set up Ray as before.

In [None]:
import ray
from ray import tune

In [None]:
!../tools/start-ray.sh --check --verbose

In [None]:
ray.init(address='auto', ignore_reinit_error=True)

The first change is we'll perform a strict `grid_search` on our hyperparameters, like we used in the previous lesson. Our hyperparameters are the learning rate, `lr`, and the `momentum`.

In [None]:
config = {
    "lr": tune.grid_search([0.001, 0.01, 0.1]),
    "momentum": tune.grid_search([0.001, 0.01, 0.1, 0.9])
}

Next we modify `train_mnist` to use Tune's "reporting" logger:

In [None]:
def train_mnist(config):
    from ray.tune import report
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config['momentum'])
    for i in range(10):
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)
        report(mean_accuracy=acc)

That's all that we need to change in order for Ray Tune to be able to parallelize our different hyperparameter combinations. 

When we execute a hyperparameter sweep, we perform an **experiment**. Each distinct combination of our different hyperparameters constitutes a single **trial**.

## Tune's Functional vs. Class API

In the previous lesson we used the **functional API**. This API is most convenient for quickly setting up experiments, but it provides less overall flexibility compared to the **class API** [`tune.Trainable`](https://docs.ray.io/en/latest/tune/api_docs/trainable.html#tune-trainable).

We'll try both, starting with the functional API.

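We add a stopping criterion, `stop={"training_iteration": 20}`, which caps each trial at 20 reported iterations so the experiment finishes reasonably quickly while still producing good results. Note that `train_mnist` as defined above loops only 10 times, so with the functional API each trial ends after 10 reports, before the cap is reached. Consider raising the loop bound and the cap if you don't mind waiting longer and want better results.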

In [None]:
analysis_func = tune.run(train_mnist, config=config, stop={"training_iteration": 20})

In [None]:
print("Best config: ", analysis_func.get_best_config(metric="mean_accuracy", mode="max"))

In [None]:
analysis_func.dataframe().sort_values('mean_accuracy', ascending=False).head()

In [None]:
analysis_func.dataframe()[['mean_accuracy', 'config/lr', 'config/momentum']].sort_values('mean_accuracy', ascending=False)

How long did it take? We'll compare this value with a different training run in the next lesson.

In [None]:
stats = analysis_func.stats()
secs = stats["timestamp"] - stats["start_time"]
print(f'{secs:7.2f} seconds, {secs/60.0:7.2f} minutes')

Now let's use the **class API**. Note that `_setup` is called **once per trial**, while the number of times `_train` is called is determined by the `stop` criterion that we pass to the `tune.run` call.
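
To make that lifecycle concrete, here is a deliberately simplified sketch of a driver loop in plain Python. It is not how Tune is implemented; `ToyTrainable` and `run_trial` are hypothetical names, and the "accuracy" update is made up so the example runs without any training.

```python
class ToyTrainable:
    """Hypothetical stand-in for a tune.Trainable subclass."""

    setup_calls = 0

    def _setup(self, config):
        self.lr = config["lr"]
        self.accuracy = 0.0
        self.setup_calls += 1

    def _train(self):
        # Made-up update: each step closes part of the gap to perfect accuracy
        self.accuracy += self.lr * (1.0 - self.accuracy)
        return {"mean_accuracy": self.accuracy}


def run_trial(trainable_cls, config, stop):
    """Simplified picture of one trial: set up once, then train until the stop limit."""
    trainable = trainable_cls()
    trainable._setup(config)             # called once per trial
    iteration = 0
    result = {}
    while iteration < stop["training_iteration"]:
        result = trainable._train()      # called once per training iteration
        iteration += 1
    result["training_iteration"] = iteration
    result["setup_calls"] = trainable.setup_calls
    return result


result = run_trial(ToyTrainable, {"lr": 0.1}, stop={"training_iteration": 20})
print(result["setup_calls"], result["training_iteration"])  # 1 20
```

The point is the division of labor: with the class API, the loop over iterations lives in Tune rather than in our code, which is why `_train` below contains no `for` loop.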

In [None]:
class TrainMNIST(tune.Trainable):
    def _setup(self, config):
        self.config = config
        self.train_loader, self.test_loader = get_data_loaders()
        self.model = ConvNet()
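        # Pass momentum too, so the "momentum" values in the search space take effect
        self.optimizer = optim.SGD(
            self.model.parameters(), lr=self.config["lr"], momentum=self.config["momentum"])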
    
    def _train(self):
        train(self.model, self.optimizer, self.train_loader)
        acc = test(self.model, self.test_loader)
        return {"mean_accuracy": acc}

In [None]:
analysis = tune.run(
    TrainMNIST, 
    config=config, 
    stop={"training_iteration": 20},
    verbose=1
)

In [None]:
print("Best config: ", analysis.get_best_config(metric="mean_accuracy", mode="max"))

In [None]:
# Get a dataframe for analyzing trial results.
df = analysis.dataframe()

In [None]:
analysis.dataframe().sort_values('mean_accuracy', ascending=False).head()

It's easier to see what we want if we project out the interesting columns:

In [None]:
analysis.dataframe()[['mean_accuracy', 'config/lr', 'config/momentum']].sort_values('mean_accuracy', ascending=False)

How long did it take? We'll compare this value with a different training run in the next lesson.

In [None]:
stats = analysis.stats()
secs = stats["timestamp"] - stats["start_time"]
print(f'{secs:7.2f} seconds, {secs/60.0:7.2f} minutes')

The next lesson will explore optimization algorithms that speed up HPO.