In [1]:
import os

from pathlib import Path
from datetime import datetime

import matplotlib.pyplot as plt

import torch

from torchvision import datasets, transforms

from torch.utils.data import DataLoader
from torch.utils.data import random_split

from torchmetrics import Accuracy, Precision, Recall

from source.resnet import TorchModel, ResNet34
from source.callback import CompositeCallback, ClassificationReporter, Profiler, Saver, DefaultCallback
from source.plotting import matplotlib_imshow

In [2]:
# --- Dataset locations -------------------------------------------------------
cwd = Path(os.getcwd())

train_dir = cwd / "imagenette2-320" / "train"
test_dir = cwd / "imagenette2-320" / "test"

# --- Preprocessing -----------------------------------------------------------
CROP_SIZE = (224, 224)
# Per-channel mean/std used for normalization (the values torchvision documents
# for ImageNet-trained models).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]
# Fraction of the train folder kept for training; the rest becomes validation.
TRAIN_FRACTION = 0.8
# Fixed seed so the train/validation membership is identical on every run.
SPLIT_SEED = 0

# Training pipeline: the only stochastic step is the horizontal flip.
tsfm_train = transforms.Compose([
    transforms.CenterCrop(size=CROP_SIZE),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Evaluation pipeline: fully deterministic (no augmentation).
tsfm_test = transforms.Compose([
    transforms.CenterCrop(size=CROP_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# --- Datasets ----------------------------------------------------------------
# Two views of the same train folder: one augmented (for training) and one
# deterministic (for validation). ImageFolder enumerates files in sorted order,
# so an index refers to the same image in both views.
train_full_aug = datasets.ImageFolder(root=train_dir, transform=tsfm_train)
train_full_eval = datasets.ImageFolder(root=train_dir, transform=tsfm_test)
testset = datasets.ImageFolder(root=test_dir, transform=tsfm_test)

# Compare the class lists themselves, not just their lengths: label indices
# are only comparable when names and order agree.
assert train_full_aug.classes == testset.classes, (
    "train and test folders expose different class lists"
)

classes = train_full_aug.classes
num_classes = len(classes)

# --- Train / validation split --------------------------------------------------
card = int(len(train_full_aug) * TRAIN_FRACTION)
split_rng = torch.Generator().manual_seed(SPLIT_SEED)
trainset, val_split = random_split(
    train_full_aug,
    [card, len(train_full_aug) - card],
    generator=split_rng,
)
# Re-point the validation indices at the non-augmented view so validation
# samples are never randomly flipped.
valset = torch.utils.data.Subset(train_full_eval, val_split.indices)

assert len(trainset) + len(valset) == len(train_full_aug)

In [3]:
# Batch loaders for the three splits. Validation uses a smaller batch size
# than train/test.
# NOTE(review): the evaluation loaders (test/val) are shuffled as well —
# confirm this is intentional (e.g. for the reporter's sample figures).
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
testloader = DataLoader(testset, batch_size=32, shuffle=True)
valloader = DataLoader(valset, batch_size=16, shuffle=True)

# Attach the class-name list to every loader as a `classes` attribute.
for loader in (trainloader, testloader, valloader):
    loader.classes = classes

In [4]:
# Seed torch's global RNG so repeated runs start from the same random state.
SEED = 0
torch.manual_seed(SEED)

# Size the classifier from the dataset's class list instead of a literal, so
# the model stays in sync with whatever folders were loaded.
model = ResNet34(num_classes=num_classes)
criterion = torch.nn.CrossEntropyLoss(reduction="sum")
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
metrics = [Accuracy(), Precision(), Recall()]


# Run directory layout: <cwd>/log/<model repr>/<timestamp>.
model_repr = str(model.extra_repr())
# `time` is a formatted timestamp *string* (not the stdlib module); the name is
# kept because the test cell below reads it.
time = datetime.now().strftime("%b%d_%H-%M-%S")

log_dir = cwd / "log" / model_repr / time
# Checkpoints are written next to the logs.
save_dir = log_dir

callback = CompositeCallback([
    ClassificationReporter(log_dir),
    Profiler.make_default(log_dir),
    Saver(save_dir)
])


torchmodel = TorchModel(model, optimizer, criterion, metrics, callback)
# NOTE(review): the stored output of this cell shows a 1-epoch progress bar,
# so it was produced with a different `epochs` value — confirm 10_000 is intended.
torchmodel.train(trainloader, valloader, epochs=10_000)

Epoch loop:   0%|          | 0/1 [00:00<?, ?epoch/s]

Batch loop:   0%|          | 0/2 [00:00<?, ?batch/s]

((tensor([1.2313]), [tensor(0.5000), tensor(0.5000), tensor(0.5000)]),
 (tensor([0.7010]),
  [tensor(0.4667), tensor(0.4667), tensor(0.4667)],
  <Figure size 1296x518.4 with 5 Axes>))

In [5]:
# Evaluate on the held-out test loader. As used below, test_info[0] is the
# loss and test_info[1][i] is the value of the i-th entry of torchmodel.metrics.
test_info = torchmodel.test(testloader)


print("Loss:", test_info[0])

for i, metric in enumerate(torchmodel.metrics):
    metric_name = type(metric).__name__
    print(f"{metric_name}Top{metric.top_k}:", test_info[1][i])


# To view in Tensorboard
# Use a dedicated name for the log location: rebinding `test_dir` here would
# overwrite the dataset path defined in the data-loading cell.
test_log_dir = log_dir / "test"

reporter = ClassificationReporter(test_log_dir)
reporter.report(torchmodel, test_info)

Loss: tensor([0.6866])
AccuracyTop1: tensor(0.5556)
PrecisionTop1: tensor(0.5556)
RecallTop1: tensor(0.5556)
