**Ray Tune**

This notebook is from the PyTorch tutorial on [Hyperparameter Tuning with Ray](https://pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html).

In [1]:
import os
from functools import partial

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision
import torchvision.transforms as T

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

## Data Loaders
Wrap the data loaders in their own function and pass a global data directory.
This way we can share a data directory between different trials

In [2]:
DATA_DIR = '/home/evan/.Data/cifar10'

def load_data(data_dir=DATA_DIR):
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform)
    
    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform)
    
    return trainset, testset

## Configurable Neural Network
We can only tune those parameters that are configurable.
In this example, we can specify the layer sizes of the fully connected layers.

In [3]:
class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool  = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1   = nn.Linear(16 * 5 * 5, l1)
        self.fc2   = nn.Linear(l1, l2)
        self.fc3   = nn.Linear(l2, 10)
        
    def forward(self, x):
        h = self.pool(F.relu(self.conv1(x)))
        h = self.pool(F.relu(self.conv2(h)))
        h = h.view(-1, 16 * 5 * 5)
        h = F.relu(self.fc1(h))
        h = F.relu(self.fc2(h))
        h = self.fc3(h)
        return h

## Train Function
We wrap an entire training script into a function that accepts a config of tunable hyperparameters, and optionally some relevant directories for checkpointing and data.

The training function is fairly standard/vanilla except that **we communicate our validation metrics to Ray Tune**. Ray Tune uses these metrics to decide which hyperparameter configuration lead to the best results. These metrics can also be used to **stop bad performing trials early** in order to avoid wasting resources on those trials.

The **checkpoint saving** is optional, however, it is necessary if we wanted to use advanced schedulers like [**Population Based Training**](https://docs.ray.io/en/master/tune/tutorials/tune-advanced-tutorial.html). Also, by saving the checkpoint, we can later load the trained models and validate them on a test set.

In [4]:
def train_cifar(cfg, checkpoint_dir=None, data_dir=None):
    """ A complete training script, wrapped in a function.
    Parameters from our config, ``cfg``, are tunable.
    """
    # Make network, with FC layer sizes from config.
    net = Net(cfg['l1'], cfg['l2'])
    
    # Set device.
    ## Defaults to CPU, but will use GPU(s) if they are available.
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)
    
    # Optimizer and objective.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=cfg['lr'], momentum=0.9)
    
    # Load checkpoints if available.
    if checkpoint_dir:
        checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint')
        model_state, optimizer_state = torch.load(checkpoint_path)
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)
    
    # Load data.
    trainset, testset = load_data(data_dir)
    
    # Split train into 80/20 train and validation sets.
    N = len(trainset)
    test_abs = int(N * 0.8)
    train_subset, val_subset = random_split(trainset, [test_abs, N - test_abs])
    
    # Make dataloaders.
    batch_size = int(cfg['batch_size'])
    data_loader = partial(DataLoader, batch_size=batch_size, 
                          shuffle=True, num_workers=8)
    train_loader = data_loader(train_subset)
    val_loader   = data_loader(val_subset)
    
    # Training and validation helper functions.
    def _train(epoch, epoch_steps):
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            # Get the inputs; data : list[inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Zero the parameter gradients.
            optimizer.zero_grad()
            
            # Forward --> Backward --> Optimize.
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            # Print statistics.
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:
                print(f"[{epoch + 1}, {i + 1}] loss: "
                      f"{running_loss / epoch_steps: .3f}")
                running_loss = 0.0
    
    @torch.no_grad()
    def _validation():
        val_loss  = 0.0
        val_steps = 0
        total   = 0
        correct = 0
        for data in val_loader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Forward & predictions
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total   += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            loss = criterion(outputs, labels)
            val_loss  += loss.cpu().numpy()
            val_steps += 1
        
        loss = val_loss / val_steps
        accuracy = correct / total
        return loss, accuracy
    
    # Checkpoint helper function.
    def _checkpoint(epoch):
        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, 'checkpoint')
            torch.save((net.state_dict(), optimizer.state_dict()), path)
            
    # Train loop.
    for epoch in range(10):
        epoch_steps = 0
        
        # Training and validation.
        _train(epoch, epoch_steps)
        val_loss, val_acc = _validation()
        
        # Save checkpoint and report current performance.
        #_checkpoint(epoch)
        #tune.report(loss=val_loss, accuracy=val_acc)
    
    print("\nFinished training!")

## Test Accuracy

In [5]:
@torch.no_grad()
def test_accuracy(net, device='cpu'):
    # Load the data.
    _, testset = load_data()
    test_loader = DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
    
    # Test loop.
    correct = total = 0
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total   += labels.size(0)
        correct += (predicted == labels).sum().item()
    return correct / total    

## Configuring the Search Space
Lastly, we need to define Ray Tune's search space.

The `tune.sample_from()` function makes it possible to define your own sample methods to obtain hyperparameters.

In our example, the `l1` and `l2` parameters should be powers of 2 between 4 and 256 (so either 4, 8, 16, 32, 128, 256). The learning rate, `lr`, should be uniformly sampled between 0.0001 and 0.1. Lastly, the batch size is a choice between 2, 4, 8, and 16.

At each trial, Ray Tune will now randomly sample a combination of parameters from these search spaces. It will then train a number of models in parallel and find the best performing one among these.

We also use the `ASHAScheduler`, which will terminate bad performing trials early.

We wrap the `train_cifar` function with `functools.partial` to set the constant `data_dir` arg.

In [6]:
config = {
    'l1': tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
    'l2': tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
    'lr': tune.loguniform(1e-4, 1e-1),
    'batch_size': tune.choice([2, 4, 8, 16]),
}

We can also tell Ray Tune what resources should be available for each trial.

You can specify the number of CPUs, which are then available e.g. to increase the `num_workers` of the torch `DataLoader` instances. The selected number of GPUs are made visible to torch in each trial. Trials do not need to have access to GPUs that haven't been requested for them, so you don't have to care about two trials using the same set of resources.

Here we can also specify **fractional GPUs**, so something like `gpus_per_trial=0.5` is completely valid. The trials will then share GPUs among each other. You just have to make sure that the models still fit in the GPU memory.

After training the models, we will find the best performing one and load the trained network from the checkpoint file. We then obtain the test set accuracy and report everything by printing.

In [7]:
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=1):
    data_dir = DATA_DIR
    load_data(data_dir) # Run here so we can download, if need be.
    
    scheduler = ASHAScheduler(
        metric='loss',
        mode='min',
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2)
    
    reporter = CLIReporter(
        # parameter_columns=['l1', 'l2', 'lr', 'batch_size'],
        metric_columns=['loss', 'accuracy', 'training_iteration'])
    
    # Run the search.
    result = tune.run(
        partial(train_cifar, data_dir=DATA_DIR),
        resources_per_trial={'cpu': 1, 'gpu': gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,)
    
    # Print results.
    best_trial = result.get_best_trial('loss', 'min', 'last')
    print(f"Best trial config: {best_trial.config}")
    print("Best trial final validation loss: "
          f"{best_trial.last_result['loss']}")
    print("Best trial final validation accuracy: "
          f"{best_trial.last_result['accuracy']}")
    
    # Get best model's test accuracy.
    best_trained_model = Net(best_trial.config['l1'], best_trial.config['l2'])
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda:0'
        if gpus_per_trial > 1:
            best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)
    
    best_checkpoint_dir = os.path.join(best_trial.checkpoint.value, 'checkpoint')
    model_state, optimizer_state = torch.load(best_checkpoint_dir)
    best_trained_model.load_state_dict(model_state)
    
    test_acc = test_accuracy(best_trained_model, device)
    print(f"Best trial test set accuracy: {test_acc}")

In [8]:
if __name__ == '__main__':
    # You can change the number of GPUs per trial here.
    main(num_samples=10, max_num_epochs=10, gpus_per_trial=1)

Files already downloaded and verified
Files already downloaded and verified


2021-03-17 06:46:52,706	INFO services.py:1172 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m
2021-03-17 06:46:54,180	INFO registry.py:64 -- Detected unknown callable for trainable. Converting to class.


== Status ==
Memory usage on this node: 6.6/23.4 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Resources requested: 1/8 CPUs, 1/1 GPUs, 0.0/11.04 GiB heap, 0.0/3.81 GiB objects (0/1.0 accelerator_type:GTX)
Result logdir: /home/evan/ray_results/DEFAULT_2021-03-17_06-46-54
Number of trials: 1/10 (1 RUNNING)
+---------------------+----------+-------+--------------+------+------+------------+
| Trial name          | status   | loc   |   batch_size |   l1 |   l2 |         lr |
|---------------------+----------+-------+--------------+------+------+------------|
| DEFAULT_1fa08_00000 | RUNNING  |       |            8 |    4 |  128 | 0.00246176 |
+---------------------+----------+-------+--------------+------+------+------------+


[2m[36m(pid=38440)[0m Files already downloaded and verified
[2m[36m(pid=38440)[0m Files already downloaded and verified
[2m[36m(pid=38440)[0m [1, 2000] loss:  2.307
[2m[36m(pid=3

[2m[36m(pid=38439)[0m Files already downloaded and verified
[2m[36m(pid=38439)[0m Files already downloaded and verified
[2m[36m(pid=38439)[0m [1, 2000] loss:  2.309
[2m[36m(pid=38439)[0m [1, 4000] loss:  1.152
[2m[36m(pid=38439)[0m [1, 6000] loss:  0.767
[2m[36m(pid=38439)[0m [1, 8000] loss:  0.573
[2m[36m(pid=38439)[0m [1, 10000] loss:  0.452
[2m[36m(pid=38439)[0m [1, 12000] loss:  0.366
[2m[36m(pid=38439)[0m [1, 14000] loss:  0.303
[2m[36m(pid=38439)[0m [1, 16000] loss:  0.254
[2m[36m(pid=38439)[0m [1, 18000] loss:  0.215
[2m[36m(pid=38439)[0m [1, 20000] loss:  0.188
[2m[36m(pid=38439)[0m [2, 2000] loss:  1.810
[2m[36m(pid=38439)[0m [2, 4000] loss:  0.877
[2m[36m(pid=38439)[0m [2, 6000] loss:  0.581
[2m[36m(pid=38439)[0m [2, 8000] loss:  0.428
[2m[36m(pid=38439)[0m [2, 10000] loss:  0.336
[2m[36m(pid=38439)[0m [2, 12000] loss:  0.278
[2m[36m(pid=38439)[0m [2, 14000] loss:  0.235
[2m[36m(pid=38439)[0m [2, 16000] loss:  0.204


[2m[36m(pid=38436)[0m Files already downloaded and verified
[2m[36m(pid=38436)[0m [1, 2000] loss:  1.846
[2m[36m(pid=38436)[0m [2, 2000] loss:  1.431
[2m[36m(pid=38436)[0m [3, 2000] loss:  1.283
[2m[36m(pid=38436)[0m [4, 2000] loss:  1.191
[2m[36m(pid=38436)[0m [5, 2000] loss:  1.127
[2m[36m(pid=38436)[0m [6, 2000] loss:  1.073
[2m[36m(pid=38436)[0m [7, 2000] loss:  1.023
[2m[36m(pid=38436)[0m [8, 2000] loss:  0.979
[2m[36m(pid=38436)[0m [9, 2000] loss:  0.945
[2m[36m(pid=38436)[0m [10, 2000] loss:  0.927
Trial DEFAULT_1fa08_00004 completed. Last result: 
== Status ==
Memory usage on this node: 8.8/23.4 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Resources requested: 0/8 CPUs, 0/1 GPUs, 0.0/11.04 GiB heap, 0.0/3.81 GiB objects (0/1.0 accelerator_type:GTX)
Result logdir: /home/evan/ray_results/DEFAULT_2021-03-17_06-46-54
Number of trials: 6/10 (1 PENDING, 5 TERMINATED)
+------

KeyboardInterrupt: 