In [2]:
from functools import partial
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

In [3]:
def load_data(data_dir = 'data'):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    trainset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform
    )
    testset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform
    )
    
    return trainset, testset

In [4]:
class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2  = nn.Conv2d(6, 16, 5)
        self.f1 = nn.Linear(16 * 5 * 5, l1)
        self.f2 = nn.Linear(l1, l2)
        self.f3 = nn.Linear(l2, 10)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.viwew(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
        

In [5]:
config = {
    "l1" : tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
    "l2" : tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
    "lr" : tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([2, 4, 6, 8])
}

num_samples=10 
max_num_epochs=10
gpus_per_trial=2

scheduler = ASHAScheduler(
    metric="loss",
    mode="min",
    max_t=max_num_epochs,
    grace_period=1,
    reduction_factor=2
)

repoter = CLIReporter(
    metric_columns=["loss", "accuracy", "training_iteration"]
)

In [7]:
def train_cifar(config, checkpoint_dir=None, data_dir=None):
    
    net = Net(config["l1"], config["l2"])
    device = "mps"
    net.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
    
    if checkpoint_dir:
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)
    
    trainset, testset = load_data(data_dir)
    
    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs])
    
    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size = int(config["batch_size"]),
        shuffle = True,
        num_workers=8
    )
    
    valloader = torch.utils.DataLoader(
        val_subset,
        batch_size = int(config("batch_size")),
        shuffle = True,
        num_workers=8
    )
    
    for epoch in range(10):
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / epoch_steps))
                running_loss = 0.0    
    return None

In [8]:
result = tune.run(
    partial(train_cifar, data_dir=data_dir)
)

NameError: name 'data_dir' is not defined