In [76]:
from functools import partial
import os
import tempfile
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split, DataLoader
import torchvision
import torchvision.transforms as transforms
from ray import tune
from ray import train
from ray.train import Checkpoint, get_checkpoint
from ray.tune.schedulers import ASHAScheduler
import ray.cloudpickle as pickle

Construct convnet

In [77]:
class CNet(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(CNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [78]:
def load_datasets(data_dir):
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    train_ds = torchvision.datasets.CIFAR10(
        data_dir,
        download=True,
        transform=transform,
        train=True,
    )

    test_ds = torchvision.datasets.CIFAR10(
        data_dir,
        download=True,
        transform=transform,
        train=False,
    )

    return train_ds, test_ds

In [79]:
# load net and optimizer state dicts
# returns current epoch (0 if no checkpoint)
def load_checkpoint(net, optimizer):
    checkpoint = get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            data_path = Path(checkpoint_dir) / 'data.pkl'
            with open(data_path, 'rb') as fp:
                checkpoint_state = pickle.load(fp)
            start_epoch = checkpoint_state['epoch']
            net.load_state_dict(checkpoint_state['net_state_dict'])
            optimizer.load_state_dict(checkpoint_state['optimizer_state_dict'])
    else:
        start_epoch = 0
    return start_epoch

In [80]:
# save state info and report training metrics
def save_checkpoint(net, optimizer, epoch, loss, accuracy):
    checkpoint_data = {
        'epoch': epoch,
        'net_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    with tempfile.TemporaryDirectory() as checkpoint_dir:
        data_path = Path(checkpoint_dir) / 'data.pkl'
        with open(data_path, 'wb') as fp:
            pickle.dump(checkpoint_data, fp)
        checkpoint = Checkpoint.from_directory(checkpoint_dir)
        train.report(
            {'loss': loss, 'accuracy': accuracy},
            checkpoint=checkpoint,
        )

In [81]:
def train_epoch(net, optimizer, criterion, train_dl, device='cpu'):
    net.train()
    running_loss = 0.0
    epoch_steps = 0
    for i, data in enumerate(train_dl):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        epoch_steps += 1
        if i % 2000 == 1999:
            print(f'Batch {i:3d} train loss: {running_loss / epoch_steps:.3f}')
            running_loss = 0.0

In [82]:
def validate_epoch(net, dataloader, criterion, device):
    net.eval()
    val_loss = 0.0
    val_steps, total, correct = 0, 0, 0
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            loss = criterion(outputs, labels)
            val_loss += loss.cpu().numpy()
            val_steps += 1
        loss = val_loss / val_steps
        accuracy = correct / total
        print(f'Validation loss: {loss:.5f}, {accuracy*100:.2f}% correct')
    return loss, accuracy


In [83]:
def get_device():
    device = "cpu"
    #if torch.mps.is_available():
    #    device = "mps"
    #if torch.cuda.is_available():
    #    device = "cuda"
    return device

In [84]:
def get_trainloaders(train_ds_full, config, train_frac=0.8):
    train_ds, val_ds = random_split(train_ds_full, [train_frac, 1-train_frac])
    train_dl, val_dl = [
        DataLoader(ds, batch_size=int(config['batch_size']),
                   shuffle=True, num_workers=8)
        for ds in [train_ds, val_ds]
    ]
    return train_dl, val_dl

In [None]:
def train_cifar(config, train_dataset):
    # build network and load checkpoint info
    net = CNet(config['l1'], config['l2'])
    device = get_device()
    net.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
    start_epoch = load_checkpoint(net, optimizer)

    # build training dataset and dataloaders
    train_dl, val_dl = get_trainloaders(train_dataset, config, 0.8)

    # train and validate network
    for epoch in range(start_epoch, 10):
        print(f'Epoch {epoch} / 10')
        train_epoch(net, optimizer, criterion, train_dl, device)
        loss, accuracy = validate_epoch(net, val_dl, criterion, device)
        save_checkpoint(net, optimizer, epoch, loss, accuracy)

In [86]:
def test_accuracy(net, test_dataset, device='cpu'):
    testloader = DataLoader(
        test_dataset, batch_size=4, shuffle=False, num_works=2
    )
    correct = 0
    total = 0
    net.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

In [93]:
# experiment configuration

data_dir = os.path.abspath('./data')
max_num_epochs=10
num_samples=10

config = {
    'l1': tune.choice([2 ** i for i in range(9)]),
    'l2': tune.choice([2 ** i for i in range(9)]),
    'lr': tune.loguniform(1e-4, 1e-1),
    'batch_size': tune.choice([2 ** i for i in range(1, 5)])
}

scheduler = ASHAScheduler(
    metric='loss',
    mode='min',
    max_t=max_num_epochs,
    grace_period=1,
    reduction_factor=2
)

train_dataset, test_dataset = load_datasets(data_dir)

In [92]:
config = {
    'l1': 120,
    'l2': 84,
    'lr': 1e-3,
    'batch_size': 16,
}

train_cifar(config, train_dataset)

Epoch {epoch} / 10}




Batch 1999 train loss: 2.171


KeyboardInterrupt: 

In [94]:
result = tune.run(
    partial(train_cifar, train_dataset=train_dataset),
    resources_per_trial={'cpu':1, 'gpu': 0},
    config=config,
    num_samples=num_samples,
    scheduler=scheduler
)

2025-02-11 13:37:58,156	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


ValueError: Tracked actor is not managed by this event manager: <TrackedActor 225542504012728263302998860370061529421>

In [None]:
best_trial = result.get_best_trial("loss", "min", "last")
print(f"Best trial config: {best_trial.config}")
print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
print(f"Best trial final validation accuracy: {best_trial.last_result['accuracy']}")

best_trained_model = CNet(best_trial.config["l1"], best_trial.config["l2"])
device = "cpu"
best_trained_model.to(device)

best_checkpoint = result.get_best_checkpoint(trial=best_trial, metric="accuracy", mode="max")
with best_checkpoint.as_directory() as checkpoint_dir:
    data_path = Path(checkpoint_dir) / "data.pkl"
    with open(data_path, "rb") as fp:
        best_checkpoint_data = pickle.load(fp)

    best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])
    test_acc = test_accuracy(best_trained_model, device)
    print("Best trial test set accuracy: {}".format(test_acc))