<a href="https://colab.research.google.com/github/niemand-01/ML-Demo/blob/master/Hyperparameter_tuning_pytorch_RayTune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from functools import partial
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
!pip install ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

In [3]:
# load data

def load_data(data_dir="./data"):
    """Return the CIFAR-10 ``(trainset, testset)`` pair rooted at ``data_dir``.

    Both splits are requested with ``download=True`` and share one transform:
    conversion to a tensor followed by per-channel normalisation with mean 0.5
    and standard deviation 0.5.

    Args:
        data_dir: Root directory handed to ``torchvision.datasets.CIFAR10``.

    Returns:
        A ``(trainset, testset)`` tuple of CIFAR-10 datasets.
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    transform = transforms.Compose([to_tensor, normalize])

    # Everything except the train/test flag is identical for the two splits.
    shared_kwargs = dict(root=data_dir, download=True, transform=transform)

    trainset = torchvision.datasets.CIFAR10(train=True, **shared_kwargs)
    testset = torchvision.datasets.CIFAR10(train=False, **shared_kwargs)
    return trainset, testset

In [4]:
# Simple CNN with 2 tunable parameters:
# l1 and l2, the output widths of fc1 and fc2 (i.e. the input widths of fc2 and fc3)
class Net(nn.Module):
    """Two conv+pool stages followed by three fully connected layers.

    Args:
        l1: Output width of ``fc1`` (and input width of ``fc2``).
        l2: Output width of ``fc2`` (and input width of ``fc3``).
    """

    def __init__(self, l1=120, l2=84):
        super().__init__()
        # Attribute names double as state_dict keys; checkpoints depend on them.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=l1)
        self.fc2 = nn.Linear(in_features=l1, out_features=l2)
        self.fc3 = nn.Linear(in_features=l2, out_features=10)

    def forward(self, x):
        """Map a batch of 3-channel images to 10 raw class scores per row."""
        # Convolutional feature extractor; the same pooling layer is reused.
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))

        # Flatten to rows of 16*5*5 values (the size reached from 32x32 inputs).
        flat = features.view(-1, 16 * 5 * 5)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [8]:
def train(config, checkpoint_dir=None, data_dir=None):
    """Ray Tune trainable: fit ``Net`` on CIFAR-10 and report validation metrics.

    Args:
        config: Dict of hyperparameters for this trial. The keys ``"l1"``,
            ``"l2"``, ``"lr"`` and ``"batch_size"`` are read.
        checkpoint_dir: Directory containing a ``"checkpoint"`` file holding a
            ``(model_state, optimizer_state)`` tuple to resume from. When
            falsy, training starts from freshly initialised weights.
        data_dir: Dataset root forwarded to ``load_data``. When ``None``,
            ``load_data``'s own default location is used.

    Side effects:
        Trains for 10 epochs. After every epoch it writes a checkpoint through
        ``tune.checkpoint_dir`` and reports the validation ``loss`` and
        ``accuracy`` through ``tune.report``.

    NOTE(review): ``tune.checkpoint_dir`` / ``tune.report`` and the
    ``checkpoint_dir`` argument belong to Ray's older function-trainable API;
    confirm they exist in the Ray version this notebook installs.
    """
    # `config` is the hyperparameter dict Ray Tune sampled for this trial.
    net = Net(config["l1"], config["l2"])

    # Use whatever GPU(s) are visible to this process. With more than one GPU
    # the model is wrapped in DataParallel, which mirrors what main() does
    # (for gpus_per_trial > 1) before it loads the best checkpoint.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    # Built after the model is on its device so it tracks the final parameters.
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    # Resume from an earlier checkpoint when Ray Tune supplies one.
    if checkpoint_dir:
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint")
        )
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    # Hold out 20% of the training set for validation; the test set is unused here.
    trainset, _ = load_data(data_dir) if data_dir is not None else load_data()
    n_train = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [n_train, len(trainset) - n_train]
    )

    batch_size = int(config["batch_size"])
    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
    )
    # Evaluation does not depend on sample order, so no shuffling is needed.
    validloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
    )

    for epoch in range(10):
        # ---- training pass ----
        running_loss = 0.0
        running_steps = 0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_steps += 1
            if i % 2000 == 1999:
                # Mean loss over the mini-batches seen since the last print.
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / running_steps))
                # Reset both accumulators so the next average covers only its own window.
                running_loss = 0.0
                running_steps = 0

        # ---- validation pass ----
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for inputs, labels in validloader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # .item() keeps the accumulator a plain Python float.
                val_loss += criterion(outputs, labels).item()
                val_steps += 1

        # Hand the checkpoint and this epoch's metrics to Ray Tune.
        with tune.checkpoint_dir(epoch) as epoch_checkpoint_dir:
            path = os.path.join(epoch_checkpoint_dir, "checkpoint")
            torch.save((net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)

    # Printed once, after the final epoch.
    print("Finished Training")

In [6]:
# configure the search space

# Search space sampled by Ray Tune, one draw per trial.
config = dict(
    # Layer widths: a power of two from 2**2 to 2**8 (randint's upper bound is exclusive).
    l1=tune.sample_from(lambda _spec: 2 ** np.random.randint(2, 9)),
    l2=tune.sample_from(lambda _spec: 2 ** np.random.randint(2, 9)),
    # Learning rate: log-uniform between 1e-4 and 1e-1.
    lr=tune.loguniform(1e-4, 1e-1),
    # Mini-batch size: one of the listed values.
    batch_size=tune.choice([2, 4, 8, 16]),
)

In [None]:
# main function
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=0):
    """Run the Ray Tune search and score the best trial on the CIFAR-10 test set.

    Args:
        num_samples: Number of configurations drawn from the search space.
        max_num_epochs: Passed to the ASHA scheduler as ``max_t``.
        gpus_per_trial: GPUs reserved for each trial (each trial also gets
            2 CPUs).

    Prints the best trial's config, its final validation loss and accuracy,
    and the test-set accuracy of its last checkpoint.
    """
    data_dir = os.path.abspath("./data")
    # Download the dataset once up front; trials receive the same absolute path.
    load_data(data_dir)
    config = {
        "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16])
    }
    # A scheduler decides which trials to run, stop, or pause
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2)
    # print progress
    reporter = CLIReporter(
        # parameter_columns=["l1", "l2", "lr", "batch_size"],
        metric_columns=["loss", "accuracy", "training_iteration"])

    # run tuning
    result = tune.run(
        partial(train, data_dir=data_dir),
        resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter)

    # Best trial = lowest validation loss at its last reported result.
    best_trial = result.get_best_trial("loss", "min", "last")
    print("Best trial config: {}".format(best_trial.config))
    print("Best trial final validation loss: {}".format(
        best_trial.last_result["loss"]))
    print("Best trial final validation accuracy: {}".format(
        best_trial.last_result["accuracy"]))

    # Rebuild the winning architecture and restore its checkpointed weights.
    best_trained_model = Net(best_trial.config["l1"], best_trial.config["l2"])
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if gpus_per_trial > 1:
            best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

    best_checkpoint_dir = best_trial.checkpoint.value
    # The checkpoint also stores the optimizer state, which is not needed here.
    model_state, _optimizer_state = torch.load(os.path.join(
        best_checkpoint_dir, "checkpoint"))
    best_trained_model.load_state_dict(model_state)

    test_acc = _test_accuracy(best_trained_model, device, data_dir)
    print("Best trial test set accuracy: {}".format(test_acc))


def _test_accuracy(net, device="cpu", data_dir="./data"):
    """Return the fraction of CIFAR-10 test images that ``net`` classifies correctly.

    Args:
        net: Model to evaluate; expected to already live on ``device``.
        device: Device the input batches are moved to before the forward pass.
        data_dir: Dataset root forwarded to ``load_data``.
    """
    _, testset = load_data(data_dir)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2)

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total


if __name__ == "__main__":
    # Default run: 10 sampled configurations, ASHA capped at 10 epochs, no GPU.
    # Raise gpus_per_trial to reserve GPU(s) for every trial.
    main(gpus_per_trial=0, max_num_epochs=10, num_samples=10)