# Hyperparameter Tuning with PyTorch & RayTune

This notebook will walk you through the basics of using [RayTune](https://ray.readthedocs.io/en/latest/tune.html). We'll do so with a PyTorch model in this example.

We'll follow a simple process:
1. We'll first create a model and train it, just like we might on a single node.
2. We'll then make the slight modifications to turn it into a distributed hyperparameter search.
3. We'll then run it on RayTune and see the results.


Let's go ahead and get started, first we're going start off with our core imports. We'll be training on the MNIST dataset with a ConvNet model.

In [24]:
import os 

import ray

from torchvision import datasets, transforms

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from filelock import FileLock

We'll set our global variables for epochs and test size.

In [3]:
EPOCH_SIZE = 512
TEST_SIZE = 256

## Single Node PyTorch Hyperparameter Tuning

Our example will follow nearly the exact same code that you can find in the [PyTorch MNIST example here](https://github.com/pytorch/examples/blob/master/mnist/main.py).

You'll see that we create an even simpler model than in that example, however you can use that one if you wish to try and make some better predictions.

In [25]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, 192)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

After creating that network, we can now create our data loaders for training and test data. These are just plain [PyTorch dataloaders](https://pytorch.org/docs/1.1.0/data.html?highlight=dataloader#torch.utils.data.DataLoader) except that we've added a `FileLock` to ensure that only one process downloads the data on each machine (if we have multiple workers / machine on our Ray cluster).

Other than that, there's nothing that's changed from the [PyTorch example version](https://github.com/pytorch/examples/blob/master/mnist/main.py#L101).

In [5]:
def get_data_loaders():
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # We add FileLock here because multiple workers will want to
    # download data, and this may cause overwrites since
    # DataLoader is not threadsafe.
    # This is only relevant in the distributed 
    with FileLock(os.path.expanduser("~/data.lock")):
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                "/tmp/data",
                train=True,
                download=True,
                transform=mnist_transforms),
            batch_size=64,
            shuffle=True)

        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST("/tmp/data", train=False, transform=mnist_transforms),
            batch_size=64,
            shuffle=True)
    return train_loader, test_loader

We defined how we're going to download / load the data [and preprocess it]. Now it's time to define our training and test functions. While the arguments are a bit switched up from the PyTorch tutorial we've referenced, the difference is inconsequential. We're going to take an optimizer, a model, the train loader, specify our device and then train the model.

In [6]:
def train(model, optimizer, train_loader, device=torch.device("cpu")):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx * len(data) > EPOCH_SIZE:
            return
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

It's the same story for our test model. We've defined some basic `average correct prediction` metric that we'll be tracking here. We could add / calculate more as well - we're just keeping it simple.

In [7]:
def test(model, data_loader, device=torch.device("cpu")):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            if batch_idx * len(data) > TEST_SIZE:
                break
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    return correct / total

Lastly, we'll create a wrapper function for this particular model. In doing so all we need to do is specify the configuration for the model that we would like to train and the function gets the data, creates the model, and optimizes it accordingly.

In [26]:
def train_mnist(config):
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config['momentum'])
    for i in range(10):
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)
        print(acc)

### Single-Node Hyperparameter Tuning

Now, let's show what we might have to do if we were going to perform hyperparameter tuning on a single machine. We would have to enumerate all the possibilities and either train them serially or use something like multiprocessing to train them in parallel. That setup takes a little bit of work so often times people opt to train them serially and just wait for it to take a long time.

This is what that might end up looking like.

In [42]:
import itertools
conf = {
    "lr": [0.001, 0.01, 0.1],
    "momentum": [0.001, 0.01, 0.1, 0.9]
}

combinations = list(itertools.product(*conf.values()))
print(len(combinations))

12


In [40]:
for lr, momentum in combinations:
    train_mnist({"lr":lr, "momentum":momentum})
    break # we'll stop this after one run and just use it for illustrative purposes

0.175
0.228125
0.30625
0.36875
0.409375
0.571875
0.61875
0.6625
0.7125
0.75625


### RayTune: Distributed Hyperparameter Tuning

Now, we've seen how you might approach the problem in a single node world. With RayTune, it becomes trivial to move your code from a single node to multiple nodes. Let's take a look at the changes that we're going to need to do achieve that.

First, let's import Ray and initialize our Ray application on the cluster.

In [92]:
import ray

ray.shutdown()
# ray.init(address='auto')
from ray import tune

The first minor change is that we'll specify that we want to perform a strict `grid_search` on our hyperparameters.

In [93]:
conf = {
    "lr": tune.grid_search([0.001, 0.01, 0.1]),
    "momentum": tune.grid_search([0.001, 0.01, 0.1, 0.9])
}

Now let's take our simple training function and add a single line: `tune.track.log(mean_accuracy=acc)`.

That's all that we need to change in order for RayTune to be able to parallelize our different hyperparameter combinations. When we're executing a hyperparameter sweep, we're executing an **experiment**. Each distinct combination of our different hyperparameters is a single **trials**.

In the following example, we're using the **functional API**, this makes it easy to get something up and running but does provide overall less control than the **class API** [`tune.Trainable`].

In [94]:
def train_mnist(config):
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(model.parameters(), lr=config["lr"])
    for i in range(10):
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)
        tune.track.log(mean_accuracy=acc)

Here's an example of the **class API**. Note that `_setup` is called **once per trial**. While the number of times `_train` is called is determined by the parameter that we pass to the `tune.run` call in the cell now. `stop={"training_iteration": 10}`.

In [106]:
class TrainMNIST(tune.Trainable):
    def _setup(self, config):
        self.config = config
        self.train_loader, self.test_loader = get_data_loaders()
        self.model = ConvNet()
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.config["lr"])
    
    def _train(self):
        train(self.model, self.optimizer, self.train_loader)
        acc = test(self.model, self.test_loader)
        return {"mean_accuracy": acc}

In [107]:
analysis = tune.run(TrainMNIST, config=conf, stop={"training_iteration": 10})
# # to run using the functional API, run the following
# analysis = tune.run(train_mnist, config=conf)

Trial name,status,loc,lr,momentum
TrainMNIST_00000,RUNNING,,0.001,0.001
TrainMNIST_00001,PENDING,,0.01,0.001
TrainMNIST_00002,PENDING,,0.1,0.001
TrainMNIST_00003,PENDING,,0.001,0.01
TrainMNIST_00004,PENDING,,0.01,0.01
TrainMNIST_00005,PENDING,,0.1,0.01
TrainMNIST_00006,PENDING,,0.001,0.1
TrainMNIST_00007,PENDING,,0.01,0.1
TrainMNIST_00008,PENDING,,0.1,0.1
TrainMNIST_00009,PENDING,,0.001,0.9


Result for TrainMNIST_00000:
  date: 2020-04-10_12-34-36
  done: false
  experiment_id: 62e48d2d36e54eaca656a0426bd327dc
  experiment_tag: 0_lr=0.001,momentum=0.001
  hostname: billmp.lan
  iterations_since_restore: 1
  mean_accuracy: 0.10625
  node_ip: 192.168.1.13
  pid: 23084
  time_since_restore: 0.32193803787231445
  time_this_iter_s: 0.32193803787231445
  time_total_s: 0.32193803787231445
  timestamp: 1586547276
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: '00000'
  
Result for TrainMNIST_00003:
  date: 2020-04-10_12-34-36
  done: false
  experiment_id: ce9f54360cee4c85acb48a2d277f02e5
  experiment_tag: 3_lr=0.001,momentum=0.01
  hostname: billmp.lan
  iterations_since_restore: 1
  mean_accuracy: 0.096875
  node_ip: 192.168.1.13
  pid: 23083
  time_since_restore: 0.3322930335998535
  time_this_iter_s: 0.3322930335998535
  time_total_s: 0.3322930335998535
  timestamp: 1586547276
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: '00003'
  
Result 

Trial name,status,loc,lr,momentum,acc,iter,total time (s)
TrainMNIST_00000,TERMINATED,,0.001,0.001,0.096875,10.0,3.43669
TrainMNIST_00001,TERMINATED,,0.01,0.001,0.74375,10.0,3.42372
TrainMNIST_00002,RUNNING,192.168.1.13:23070,0.1,0.001,0.81875,9.0,3.10756
TrainMNIST_00003,TERMINATED,,0.001,0.01,0.090625,10.0,3.44806
TrainMNIST_00004,TERMINATED,,0.01,0.01,0.76875,10.0,3.49915
TrainMNIST_00005,RUNNING,192.168.1.13:23066,0.1,0.01,0.825,9.0,3.09171
TrainMNIST_00006,RUNNING,192.168.1.13:23095,0.001,0.1,0.1125,7.0,2.40808
TrainMNIST_00007,RUNNING,192.168.1.13:23094,0.01,0.1,0.271875,7.0,2.71676
TrainMNIST_00008,RUNNING,,0.1,0.1,,,
TrainMNIST_00009,RUNNING,,0.001,0.9,,,


Result for TrainMNIST_00002:
  date: 2020-04-10_12-34-40
  done: true
  experiment_id: 3ee3a41a6d97443a8861775d1c0cf30e
  experiment_tag: 2_lr=0.1,momentum=0.001
  hostname: billmp.lan
  iterations_since_restore: 10
  mean_accuracy: 0.853125
  node_ip: 192.168.1.13
  pid: 23070
  time_since_restore: 3.7694034576416016
  time_this_iter_s: 0.66184401512146
  time_total_s: 3.7694034576416016
  timestamp: 1586547280
  timesteps_since_restore: 0
  training_iteration: 10
  trial_id: '00002'
  
Result for TrainMNIST_00005:
  date: 2020-04-10_12-34-40
  done: true
  experiment_id: 9f1f537e7a6a4857ab692f2e62670359
  experiment_tag: 5_lr=0.1,momentum=0.01
  hostname: billmp.lan
  iterations_since_restore: 10
  mean_accuracy: 0.8375
  node_ip: 192.168.1.13
  pid: 23066
  time_since_restore: 3.804111957550049
  time_this_iter_s: 0.7124001979827881
  time_total_s: 3.804111957550049
  timestamp: 1586547280
  timesteps_since_restore: 0
  training_iteration: 10
  trial_id: '00005'
  
Result for TrainM

Trial name,status,loc,lr,momentum,acc,iter,total time (s)
TrainMNIST_00000,TERMINATED,,0.001,0.001,0.096875,10,3.43669
TrainMNIST_00001,TERMINATED,,0.01,0.001,0.74375,10,3.42372
TrainMNIST_00002,TERMINATED,,0.1,0.001,0.853125,10,3.7694
TrainMNIST_00003,TERMINATED,,0.001,0.01,0.090625,10,3.44806
TrainMNIST_00004,TERMINATED,,0.01,0.01,0.76875,10,3.49915
TrainMNIST_00005,TERMINATED,,0.1,0.01,0.8375,10,3.80411
TrainMNIST_00006,TERMINATED,,0.001,0.1,0.1375,10,4.05427
TrainMNIST_00007,TERMINATED,,0.01,0.1,0.33125,10,4.12035
TrainMNIST_00008,TERMINATED,,0.1,0.1,0.8625,10,2.44107
TrainMNIST_00009,TERMINATED,,0.001,0.9,0.153125,10,2.28965


In [108]:
print("Best config: ", analysis.get_best_config(metric="mean_accuracy"))

Best config:  {'lr': 0.1, 'momentum': 0.1}


In [109]:
# Get a dataframe for analyzing trial results.
df = analysis.dataframe()

In [110]:
df.sort_values('mean_accuracy', ascending=False).head()

Unnamed: 0,mean_accuracy,done,timesteps_total,episodes_total,training_iteration,experiment_id,date,timestamp,time_this_iter_s,time_total_s,...,hostname,node_ip,time_since_restore,timesteps_since_restore,iterations_since_restore,trial_id,experiment_tag,config/lr,config/momentum,logdir
11,0.88125,True,,,10,57e8999280cb45ccb58932f5ffc490fb,2020-04-10_12-34-44,1586547284,0.219276,2.29302,...,billmp.lan,192.168.1.13,2.29302,0,10,11,"11_lr=0.1,momentum=0.9",0.1,0.9,/Users/williamchambers/ray_results/TrainMNIST/...
8,0.8625,True,,,10,9c012f6a63254dcf84feaa437842d0d8,2020-04-10_12-34-44,1586547284,0.219407,2.441065,...,billmp.lan,192.168.1.13,2.441065,0,10,8,"8_lr=0.1,momentum=0.1",0.1,0.1,/Users/williamchambers/ray_results/TrainMNIST/...
2,0.853125,True,,,10,3ee3a41a6d97443a8861775d1c0cf30e,2020-04-10_12-34-40,1586547280,0.661844,3.769403,...,billmp.lan,192.168.1.13,3.769403,0,10,2,"2_lr=0.1,momentum=0.001",0.1,0.001,/Users/williamchambers/ray_results/TrainMNIST/...
5,0.8375,True,,,10,9f1f537e7a6a4857ab692f2e62670359,2020-04-10_12-34-40,1586547280,0.7124,3.804112,...,billmp.lan,192.168.1.13,3.804112,0,10,5,"5_lr=0.1,momentum=0.01",0.1,0.01,/Users/williamchambers/ray_results/TrainMNIST/...
4,0.76875,True,,,10,a0f936131af7463ea0cad3ca3dc8f9e2,2020-04-10_12-34-40,1586547280,0.465088,3.499153,...,billmp.lan,192.168.1.13,3.499153,0,10,4,"4_lr=0.01,momentum=0.01",0.01,0.01,/Users/williamchambers/ray_results/TrainMNIST/...


In [101]:
analysis = tune.run(train_mnist, config=conf)

Trial name,status,loc,lr,momentum
train_mnist_00000,RUNNING,,0.001,0.001
train_mnist_00001,PENDING,,0.01,0.001
train_mnist_00002,PENDING,,0.1,0.001
train_mnist_00003,PENDING,,0.001,0.01
train_mnist_00004,PENDING,,0.01,0.01
train_mnist_00005,PENDING,,0.1,0.01
train_mnist_00006,PENDING,,0.001,0.1
train_mnist_00007,PENDING,,0.01,0.1
train_mnist_00008,PENDING,,0.1,0.1
train_mnist_00009,PENDING,,0.001,0.9


Result for train_mnist_00000:
  date: 2020-04-10_12-30-52
  done: false
  experiment_id: 481535da8ac64c22a40d6ebe5816aea1
  experiment_tag: 0_lr=0.001,momentum=0.001
  hostname: billmp.lan
  iterations_since_restore: 1
  mean_accuracy: 0.165625
  node_ip: 192.168.1.13
  pid: 22981
  time_since_restore: 0.5230550765991211
  time_this_iter_s: 0.5230550765991211
  time_total_s: 0.5230550765991211
  timestamp: 1586547052
  timesteps_since_restore: 0
  training_iteration: 0
  trial_id: '00000'
  
Result for train_mnist_00002:
  date: 2020-04-10_12-30-52
  done: false
  experiment_id: 1b8189b437c646de8623cb4db8dda0dc
  experiment_tag: 2_lr=0.1,momentum=0.001
  hostname: billmp.lan
  iterations_since_restore: 1
  mean_accuracy: 0.515625
  node_ip: 192.168.1.13
  pid: 22984
  time_since_restore: 0.5195770263671875
  time_this_iter_s: 0.5195770263671875
  time_total_s: 0.5195770263671875
  timestamp: 1586547052
  timesteps_since_restore: 0
  training_iteration: 0
  trial_id: '00002'
  
Result f

Trial name,status,loc,lr,momentum,acc,iter,total time (s)
train_mnist_00000,RUNNING,192.168.1.13:22981,0.001,0.001,0.178125,8.0,3.46004
train_mnist_00001,RUNNING,192.168.1.13:22980,0.01,0.001,0.746875,8.0,3.51447
train_mnist_00002,RUNNING,192.168.1.13:22984,0.1,0.001,0.88125,8.0,3.40058
train_mnist_00003,RUNNING,192.168.1.13:23014,0.001,0.01,0.103125,5.0,2.199
train_mnist_00004,RUNNING,192.168.1.13:23015,0.01,0.01,0.33125,5.0,2.34423
train_mnist_00005,RUNNING,192.168.1.13:23016,0.1,0.01,0.6625,5.0,2.08736
train_mnist_00006,RUNNING,192.168.1.13:23017,0.001,0.1,0.228125,4.0,1.95937
train_mnist_00007,RUNNING,192.168.1.13:23018,0.01,0.1,0.515625,5.0,2.15796
train_mnist_00008,PENDING,,0.1,0.1,,,
train_mnist_00009,PENDING,,0.001,0.9,,,


Result for train_mnist_00008:
  date: 2020-04-10_12-30-57
  done: false
  experiment_id: 699bd7607e754533832ce576663b4253
  experiment_tag: 8_lr=0.1,momentum=0.1
  hostname: billmp.lan
  iterations_since_restore: 1
  mean_accuracy: 0.521875
  node_ip: 192.168.1.13
  pid: 23029
  time_since_restore: 0.41529178619384766
  time_this_iter_s: 0.41529178619384766
  time_total_s: 0.41529178619384766
  timestamp: 1586547057
  timesteps_since_restore: 0
  training_iteration: 0
  trial_id: 00008
  
Result for train_mnist_00010:
  date: 2020-04-10_12-30-57
  done: false
  experiment_id: a0c7de6bd33c4200b78ad6d2069317ed
  experiment_tag: 10_lr=0.01,momentum=0.9
  hostname: billmp.lan
  iterations_since_restore: 1
  mean_accuracy: 0.16875
  node_ip: 192.168.1.13
  pid: 23025
  time_since_restore: 0.4081389904022217
  time_this_iter_s: 0.4081389904022217
  time_total_s: 0.4081389904022217
  timestamp: 1586547057
  timesteps_since_restore: 0
  training_iteration: 0
  trial_id: '00010'
  
Result for t

Trial name,status,loc,lr,momentum,acc,iter,total time (s)
train_mnist_00000,TERMINATED,,0.001,0.001,0.19375,9,3.88122
train_mnist_00001,TERMINATED,,0.01,0.001,0.753125,9,4.02972
train_mnist_00002,TERMINATED,,0.1,0.001,0.890625,9,3.81148
train_mnist_00003,TERMINATED,,0.001,0.01,0.11875,9,3.80629
train_mnist_00004,TERMINATED,,0.01,0.01,0.521875,9,4.09952
train_mnist_00005,TERMINATED,,0.1,0.01,0.890625,9,3.68201
train_mnist_00006,TERMINATED,,0.001,0.1,0.29375,9,3.95327
train_mnist_00007,TERMINATED,,0.01,0.1,0.68125,9,3.75076
train_mnist_00008,TERMINATED,,0.1,0.1,0.896875,9,2.49928
train_mnist_00009,TERMINATED,,0.001,0.9,0.165625,9,2.56056


In [102]:
# Get a dataframe for analyzing trial results.
df = analysis.dataframe()

In [103]:
df.sort_values('mean_accuracy', ascending=False).head()

Unnamed: 0,mean_accuracy,trial_id,training_iteration,time_this_iter_s,done,timesteps_total,episodes_total,experiment_id,date,timestamp,...,pid,hostname,node_ip,time_since_restore,timesteps_since_restore,iterations_since_restore,experiment_tag,config/lr,config/momentum,logdir
8,0.896875,8,9,0.226021,False,,,699bd7607e754533832ce576663b4253,2020-04-10_12-30-59,1586547059,...,23029,billmp.lan,192.168.1.13,2.499276,0,10,"8_lr=0.1,momentum=0.1",0.1,0.1,/Users/williamchambers/ray_results/train_mnist...
2,0.890625,2,9,0.4109,False,,,1b8189b437c646de8623cb4db8dda0dc,2020-04-10_12-30-55,1586547055,...,22984,billmp.lan,192.168.1.13,3.81148,0,10,"2_lr=0.1,momentum=0.001",0.1,0.001,/Users/williamchambers/ray_results/train_mnist...
5,0.890625,5,9,0.380373,False,,,65e4bd8f7f3e4689be98db0a3f0005b9,2020-04-10_12-30-57,1586547057,...,23016,billmp.lan,192.168.1.13,3.682014,0,10,"5_lr=0.1,momentum=0.01",0.1,0.01,/Users/williamchambers/ray_results/train_mnist...
11,0.86875,11,9,0.186408,False,,,de1cfc0d7aa64fc9bd74d616885e3506,2020-04-10_12-31-00,1586547060,...,23030,billmp.lan,192.168.1.13,2.291056,0,10,"11_lr=0.1,momentum=0.9",0.1,0.9,/Users/williamchambers/ray_results/train_mnist...
1,0.753125,1,9,0.515246,False,,,8e84b35eeb2640dfbf6eb611c818b54f,2020-04-10_12-30-56,1586547056,...,22980,billmp.lan,192.168.1.13,4.02972,0,10,"1_lr=0.01,momentum=0.001",0.01,0.001,/Users/williamchambers/ray_results/train_mnist...


# Conclusion

In this example we learned about how to perform distributed hyperparameter tuning with RayTune. We took a sweep that we had to run locally and ran it in a distributed fashion with basically zero code changes. We learned about the different `tunable` types and how to manipulate them. See [the documentation for more information](https://ray.readthedocs.io/en/latest/tune.html).