In [1]:
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from filelock import FileLock
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
import ray
from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.tune.schedulers import ASHAScheduler

In [2]:
def load_data(data_dir="./data"):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # We add FileLock here because multiple workers will want to
    # download data, and this may cause overwrites since
    # DataLoader is not threadsafe.
    with FileLock(os.path.expanduser("~/.data.lock")):
        trainset = torchvision.datasets.CIFAR10(
            root=data_dir, train=True, download=True, transform=transform)

        testset = torchvision.datasets.CIFAR10(
            root=data_dir, train=False, download=True, transform=transform)

    return trainset, testset

In [3]:
class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [4]:
def train_cifar(config):
    net = Net(config["l1"], config["l2"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    # To restore a checkpoint, use `session.get_checkpoint()`.
    loaded_checkpoint = session.get_checkpoint()
    if loaded_checkpoint:
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
           model_state, optimizer_state = torch.load(os.path.join(loaded_checkpoint_dir, "checkpoint.pt"))
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    data_dir = os.path.abspath("./data")
    trainset, testset = load_data(data_dir)

    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs])

    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=8)
    valloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=8)

    for epoch in range(10):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / epoch_steps))
                running_loss = 0.0

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        for i, data in enumerate(valloader, 0):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.cpu().numpy()
                val_steps += 1

        # Here we save a checkpoint. It is automatically registered with
        # Ray Tune and can be accessed through `session.get_checkpoint()`
        # API in future iterations.
        os.makedirs("my_model", exist_ok=True)
        torch.save(
            (net.state_dict(), optimizer.state_dict()), "my_model/checkpoint.pt")
        checkpoint = Checkpoint.from_directory("my_model")
        session.report({"loss": (val_loss / val_steps), "accuracy": correct / total}, checkpoint=checkpoint)
    print("Finished Training")

In [5]:
def _test_accuracy(net, device="cpu"):
    trainset, testset = load_data()

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2)

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

In [6]:
def test_best_model(best_result):
    best_trained_model = Net(best_result.config["l1"], best_result.config["l2"])
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    best_trained_model.to(device)

    checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")

    model_state, optimizer_state = torch.load(checkpoint_path)
    best_trained_model.load_state_dict(model_state)

    trainset, testset = load_data()

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2)

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = best_trained_model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()


    print("Best trial test set accuracy: {}".format(correct / total))

In [7]:
config = {
    "l1": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
    "l2": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([2, 4, 8, 16]),
}

In [8]:
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
    config = {
        "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16])
    }
    scheduler = ASHAScheduler(
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2)
    
    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_cifar),
            resources={"cpu": 2, "gpu": gpus_per_trial}
        ),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            scheduler=scheduler,
            num_samples=num_samples,
        ),
        param_space=config,
    )
    results = tuner.fit()
    
    best_result = results.get_best_result("loss", "min")

    print("Best trial config: {}".format(best_result.config))
    print("Best trial final validation loss: {}".format(
        best_result.metrics["loss"]))
    print("Best trial final validation accuracy: {}".format(
        best_result.metrics["accuracy"]))

    test_best_model(best_result)

main(num_samples=2, max_num_epochs=2, gpus_per_trial=0)

0,1
Current time:,2023-07-30 11:22:17
Running for:,00:01:27.21
Memory:,6.5/27.4 GiB

Trial name,status,loc,batch_size,lr,iter,total time (s),loss,accuracy
train_cifar_1572f_00000,RUNNING,172.25.212.162:13280,16,0.0650363,,,,
train_cifar_1572f_00001,TERMINATED,172.25.212.162:13281,16,0.0126473,2.0,58.9053,1.47403,0.4745


[2m[36m(train_cifar pid=13281)[0m Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /home/luca/ray_results/train_cifar_2023-07-30_11-20-47/train_cifar_1572f_00001_1_batch_size=16,lr=0.0126_2023-07-30_11-20-50/data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]
  0%|          | 32768/170498071 [00:00<18:42, 151886.89it/s]
  0%|          | 65536/170498071 [00:00<18:40, 152133.44it/s]
  0%|          | 98304/170498071 [00:00<18:34, 152839.64it/s]
  0%|          | 229376/170498071 [00:00<08:31, 332709.71it/s]
  0%|          | 458752/170498071 [00:01<04:44, 596811.87it/s]
  0%|          | 851968/170498071 [00:01<03:56, 715811.37it/s]
  1%|          | 1671168/170498071 [00:01<01:45, 1593809.72it/s]
  1%|▏         | 2555904/170498071 [00:01<01:13, 2284854.95it/s]
  2%|▏         | 3145728/170498071 [00:02<01:09, 2420286.00it/s]
  2%|▏         | 3768320/170498071 [00:02<00:55, 2988489.39it/s]
  2%|▏         | 4161536/170498071 [00:02<00:57, 2894371.11it/s]
  3%|▎         | 4521984/170498071 [00:02<01:03, 2614205.92it/s]
  3%|▎         | 5111808/170498071 [00:02<01:01, 2692175.63it/s]
  3%|▎         | 5799936/170498071 [00:03<00:57, 2853634.65it/s]
  4%|▍         | 6488064/170498071 [00:03<00:55, 2972146.88

[2m[36m(train_cifar pid=13281)[0m Extracting /home/luca/ray_results/train_cifar_2023-07-30_11-20-47/train_cifar_1572f_00001_1_batch_size=16,lr=0.0126_2023-07-30_11-20-50/data/cifar-10-python.tar.gz to /home/luca/ray_results/train_cifar_2023-07-30_11-20-47/train_cifar_1572f_00001_1_batch_size=16,lr=0.0126_2023-07-30_11-20-50/data
[2m[36m(train_cifar pid=13281)[0m Files already downloaded and verified
[2m[36m(train_cifar pid=13280)[0m Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /home/luca/ray_results/train_cifar_2023-07-30_11-20-47/train_cifar_1572f_00000_0_batch_size=16,lr=0.0650_2023-07-30_11-20-50/data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]
  0%|          | 32768/170498071 [00:00<37:15, 76252.20it/s]
  0%|          | 65536/170498071 [00:00<26:20, 107805.59it/s]
  0%|          | 98304/170498071 [00:00<22:54, 123998.15it/s]
  0%|          | 131072/170498071 [00:01<28:34, 99346.47it/s]
  0%|          | 163840/170498071 [00:01<31:42, 89538.18it/s]
  0%|          | 196608/170498071 [00:02<33:37, 84396.99it/s]
  0%|          | 229376/170498071 [00:02<34:48, 81539.17it/s]
  0%|          | 262144/170498071 [00:02<30:53, 91854.34it/s]
  0%|          | 294912/170498071 [00:03<32:52, 86277.56it/s]
  0%|          | 327680/170498071 [00:03<33:06, 85665.66it/s]
  0%|          | 360448/170498071 [00:04<35:33, 79731.71it/s]
  0%|          | 393216/170498071 [00:04<36:05, 78567.17it/s]
  0%|          | 425984/170498071 [00:04<32:29, 87250.93it/s]
  0%|          | 458752/170498071 [00:05<32:16, 87825.95it/s]
  0%|          | 491520/170498071 [00:05<36:28, 77669.16it/s]
  0%|          | 524288/1

[2m[36m(train_cifar pid=13281)[0m [1,  2000] loss: 1.814


  0%|          | 557056/170498071 [00:06<33:15, 85173.49it/s]
  0%|          | 589824/170498071 [00:06<32:54, 86070.58it/s]
  0%|          | 622592/170498071 [00:07<32:24, 87381.30it/s]
  0%|          | 655360/170498071 [00:07<30:00, 94353.83it/s]
  0%|          | 688128/170498071 [00:07<33:47, 83762.35it/s]
  0%|          | 720896/170498071 [00:08<34:49, 81262.86it/s]
  0%|          | 753664/170498071 [00:08<35:33, 79557.36it/s]
  0%|          | 786432/170498071 [00:09<41:36, 67972.42it/s]
  0%|          | 819200/170498071 [00:09<40:18, 70150.82it/s]
  0%|          | 851968/170498071 [00:10<39:20, 71868.19it/s]
  1%|          | 884736/170498071 [00:10<38:42, 73023.70it/s]
  1%|          | 917504/170498071 [00:11<39:29, 71563.01it/s]
  1%|          | 950272/170498071 [00:11<38:46, 72867.05it/s]
  1%|          | 983040/170498071 [00:12<45:01, 62748.44it/s]
  1%|          | 1015808/170498071 [00:12<42:38, 66232.72it/s]
  1%|          | 1048576/170498071 [00:13<52:08, 54163.07it/s]
  1%| 

[2m[36m(train_cifar pid=13281)[0m [2,  2000] loss: 1.532


  1%|          | 1277952/170498071 [00:18<54:23, 51856.53it/s]
  1%|          | 1310720/170498071 [00:18<54:44, 51506.93it/s]
  1%|          | 1343488/170498071 [00:19<53:56, 52270.54it/s]
  1%|          | 1376256/170498071 [00:20<1:01:06, 46126.86it/s]
  1%|          | 1409024/170498071 [00:21<1:03:52, 44115.53it/s]
  1%|          | 1441792/170498071 [00:22<1:05:10, 43236.43it/s]
  1%|          | 1474560/170498071 [00:22<1:04:04, 43962.87it/s]
  1%|          | 1507328/170498071 [00:23<1:12:15, 38979.74it/s]
  1%|          | 1540096/170498071 [00:24<1:19:45, 35304.54it/s]
  1%|          | 1572864/170498071 [00:25<1:13:55, 38085.91it/s]
  1%|          | 1605632/170498071 [00:26<1:08:23, 41159.18it/s]
  1%|          | 1638400/170498071 [00:26<1:02:32, 44994.69it/s]
  1%|          | 1671168/170498071 [00:27<55:24, 50775.38it/s]  
  1%|          | 1703936/170498071 [00:27<54:36, 51518.14it/s]
  1%|          | 1736704/170498071 [00:28<52:40, 53401.31it/s]
  1%|          | 1769472/170498071 

Best trial config: {'l1': 32, 'l2': 16, 'lr': 0.012647291777407418, 'batch_size': 16}
Best trial final validation loss: 1.4740296248435973
Best trial final validation accuracy: 0.4745
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 65536/170498071 [00:00<18:36, 152612.61it/s]

  0%|          | 458752/170498071 [00:01<04:44, 596957.17it/s]

In [None]:
print(f"#trials={len(result.trials)}")
print(f"time={time.time()-start_time}")
best_trial = result.get_best_trial("loss", "min", "all")
print("Best trial config: {}".format(best_trial.config))
print("Best trial final validation loss: {}".format(
    best_trial.metric_analysis["loss"]["min"]))
print("Best trial final validation accuracy: {}".format(
    best_trial.metric_analysis["accuracy"]["max"]))

best_trained_model = Net(2**best_trial.config["l1"],
                         2**best_trial.config["l2"])
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
    if gpus_per_trial > 1:
        best_trained_model = nn.DataParallel(best_trained_model)
best_trained_model.to(device)

checkpoint_value = getattr(best_trial.checkpoint, "dir_or_data", None) or best_trial.checkpoint.value
checkpoint_path = os.path.join(checkpoint_value, "checkpoint")

model_state, optimizer_state = torch.load(checkpoint_path)
best_trained_model.load_state_dict(model_state)

test_acc = _test_accuracy(best_trained_model, device)
print("Best trial test set accuracy: {}".format(test_acc))