In [1]:
import os

from pathlib import Path
from datetime import datetime
from functools import partial

import torch

from torchvision import datasets, transforms

from torch.utils.data import DataLoader
from torch.utils.data import random_split

from torchmetrics import Accuracy, Precision, Recall

from resnet import TorchModel, ResNet, ResNet34, ResNet50
from callback import CompositeCallback, ClassificationReporter, Profiler, Saver, DefaultCallback
from plotting import matplotlib_imshow

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

In [None]:
# Data is expected under ./imagenette2-320/{train,test}, relative to the working directory.
cwd = Path(os.getcwd())

train_dir = cwd / "imagenette2-320" / "train"
test_dir = cwd / "imagenette2-320" / "test"

# Fixed seed so the train/validation split is the same after every kernel restart.
SPLIT_SEED = 42

# Shared by both pipelines. These values match the ImageNet channel statistics
# documented for torchvision's pretrained models.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training pipeline: fixed 224x224 centre crop plus a random horizontal flip.
tsfm_train = transforms.Compose([
    transforms.CenterCrop(size=(224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: same crop and normalisation, no random augmentation.
tsfm_test = transforms.Compose([
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    normalize,
])

trainset = datasets.ImageFolder(root=train_dir, transform=tsfm_train)
testset = datasets.ImageFolder(root=test_dir, transform=tsfm_test)

# ImageFolder derives label indices from the folder names, so the two splits must
# expose the same class list (not merely the same number of classes) for labels to line up.
assert trainset.classes == testset.classes, (
    f"train/test class folders differ: {trainset.classes} vs {testset.classes}"
)

classes = trainset.classes
num_classes = len(classes)

# 80/20 train/validation split of the training folder.
# `valset` is a Subset of the same ImageFolder, so it also goes through `tsfm_train`
# (including the random flip).
card = int(len(trainset) * 0.8)
trainset, valset = random_split(
    trainset,
    [card, len(trainset) - card],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# One shuffled loader per split: train and test are served in batches of 32,
# validation in batches of 5.
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
testloader = DataLoader(testset, batch_size=32, shuffle=True)
valloader = DataLoader(valset, batch_size=5, shuffle=True)

# Attach the class-name list to every loader.
# Presumably read by the reporting callbacks -- confirm against `callback`.
trainloader.classes = classes
testloader.classes = classes
valloader.classes = classes

In [None]:
# Hyperparameter search driver (Ray Tune, `tune.run` API).
# NOTE(review): `tune_dir` is assigned here but not used anywhere in this cell.
tune_dir = cwd / "tune"

max_num_epochs = 1_000  # handed to the ASHA scheduler as `max_t`
num_samples = 10  # handed to `tune.run` as `num_samples`
gpus_per_trial = 2  # GPUs requested per trial through `resources_per_trial`


# Search space handed to `tune.run`.
# NOTE(review): "arch" is a bare list of two already-constructed models, whereas "lr" and
# "batch_size" use tune.* samplers. A bare list is presumably passed through to each trial
# unchanged rather than sampled -- confirm whether tune.choice / tune.grid_search was intended.
# NOTE(review): the `tune(config, ...)` function defined in a later cell reads the keys
# "block_cls", "l1".."l4" and "lr"; apart from "lr", none of them are provided here.
config = {
    "arch": [ResNet34(num_classes), ResNet50(num_classes)],
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([2, 4, 8, 16, 32, 64])
}
# Scheduler configured to minimise the reported "loss" metric.
scheduler = ASHAScheduler(
    metric="loss",
    mode="min",
    max_t=max_num_epochs,
    grace_period=1,
    reduction_factor=2
)
# Console progress table restricted to these three reported columns.
reporter = CLIReporter(
    metric_columns=["loss", "accuracy", "training_iteration"]
)
# NOTE(review): `torchmodel` is not defined in any cell visible above this one, so on a
# fresh kernel this call raises NameError before any trial starts.
# NOTE(review): a later cell defines a function named `tune`; once that cell has run,
# `tune` in this cell no longer refers to the `ray.tune` module.
result = tune.run(
    partial(torchmodel.train, tune=True),
    resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
    config=config,
    num_samples=num_samples,
    scheduler=scheduler,
    progress_reporter=reporter
)


In [None]:
def tune(config, trainloader, valloader, checkpoint_dir=None):
    """Ray Tune trainable: build a ResNet from `config`, train it, checkpoint it and report.

    NOTE(review): defining this function rebinds the name `tune`, which the import cell
    binds to the `ray.tune` module. The body therefore uses a local alias for the module;
    other cells that use `tune.*` after this definition has run will see this function
    instead. Consider renaming the function.

    Args:
        config: Mapping providing "block_cls", "l1", "l2", "l3", "l4" (forwarded to
            `ResNet`) and "lr" (SGD learning rate).
        trainloader: Forwarded to `TorchModel.train` as its first argument.
        valloader: Forwarded to `TorchModel.train` as its second argument.
        checkpoint_dir: Optional directory holding a "checkpoint" file previously written
            by this function. When given, model and optimizer state are restored from it
            before training.
    """
    # Function-scope import on purpose: at module level `tune` is this very function,
    # so `tune.checkpoint_dir` / `tune.report` would raise AttributeError.
    from ray import tune as ray_tune

    # Number of epochs handed to `TorchModel.train`; also used as the checkpoint step.
    epochs = 10_000

    # The number of output classes is hard-coded to 10 here.
    model = ResNet(config["block_cls"], [config["l1"], config["l2"], config["l3"], config["l4"]], num_classes=10)
    criterion = torch.nn.CrossEntropyLoss(reduction="sum")
    optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"], momentum=0.9)

    # Resume from a previous checkpoint; the format mirrors the `torch.save` call below.
    if checkpoint_dir:
        model_state, optimizer_state = torch.load(os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    torchmodel = TorchModel(model, optimizer, criterion)
    train_info, val_info = torchmodel.train(trainloader, valloader, epochs=epochs)

    # The context target gets its own name so it does not shadow the `checkpoint_dir` parameter.
    with ray_tune.checkpoint_dir(step=epochs) as ckpt_dir:
        path = os.path.join(ckpt_dir, "checkpoint")
        torch.save((model.state_dict(), optimizer.state_dict()), path)

    # Report after the checkpoint context has closed, as in Ray's function-API
    # checkpointing examples.
    # NOTE(review): `val_loss`, `val_steps`, `correct` and `total` are not defined anywhere
    # in this function, so this line raises NameError as written. They presumably have to
    # be derived from `val_info` -- confirm against `TorchModel.train`'s return value.
    ray_tune.report(loss=(val_loss / val_steps), accuracy=correct / total)