In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torchvision.datasets import ImageFolder

from classification.network.resnet import SeResNet
from classification.utils.loaders import VehicleDataLoader
from classification.utils.transforms import VehicleTransform

train_dataset = ImageFolder("../data/vehicles/train")

# Derive the number of classes from the class folders on disk so the model head,
# metrics and checkpoints all agree with the dataset.
num_classes = len(train_dataset.classes)
transform = VehicleTransform(size=(224, 224))

train_loader = VehicleDataLoader(
    train_dataset,
    train_transform=transform.train_transform,
    eval_transform=transform.eval_transform,
    batch_size=8,
    shuffle=True,
)
train_loader.train()
# Use the dataset-derived class count instead of a hard-coded 6.
model = SeResNet(num_classes=num_classes)

# Smoke-test a single batch. No backward pass is run here, so skip building
# an autograd graph.
x, y = next(iter(train_loader))
with torch.no_grad():
    pred = model(x)

# Sanity check: one row of class scores per sample in the batch.
assert pred.shape == (len(y), num_classes), pred.shape

In [25]:
# Show the transform pipeline attached to the loader's dataset; after
# train_loader.train() this is expected to be the augmenting train pipeline.
train_loader.dataset.transform

Compose(
    Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
    RandomHorizontalFlip(p=0.5)
    RandomRotation(degrees=[-15.0, 15.0], interpolation=nearest, expand=False, fill=0)
    ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.8, 1.2), hue=None)
    RandomGrayscale(p=0.2)
    ToTensor()
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
)

In [26]:
# Raw logits for the sample batch from the freshly constructed (untrained) model.
pred

tensor([[-0.0272,  0.0333, -0.0345,  0.0185,  0.0499, -0.0281],
        [-0.0272,  0.0333, -0.0345,  0.0185,  0.0499, -0.0281],
        [-0.0424,  0.2013,  0.0552,  0.1123,  0.1894,  0.0087],
        [-0.0424,  0.2013,  0.0552,  0.1123,  0.1894,  0.0087],
        [ 1.9684,  0.3220, -2.3137,  2.1032,  1.4017,  1.1368],
        [-0.0424,  0.2013,  0.0552,  0.1123,  0.1894,  0.0087],
        [-0.0272,  0.0333, -0.0345,  0.0185,  0.0499, -0.0281],
        [-0.0424,  0.2013,  0.0552,  0.1123,  0.1894,  0.0087]])

In [27]:
import torch

# Turn the logits into per-class probabilities; each row sums to 1.
pred.softmax(dim=-1)

tensor([[0.1618, 0.1719, 0.1606, 0.1694, 0.1747, 0.1616],
        [0.1618, 0.1719, 0.1606, 0.1694, 0.1747, 0.1616],
        [0.1458, 0.1860, 0.1607, 0.1702, 0.1838, 0.1534],
        [0.1458, 0.1860, 0.1607, 0.1702, 0.1838, 0.1534],
        [0.2982, 0.0575, 0.0041, 0.3412, 0.1692, 0.1298],
        [0.1458, 0.1860, 0.1607, 0.1702, 0.1838, 0.1534],
        [0.1618, 0.1719, 0.1606, 0.1694, 0.1747, 0.1616],
        [0.1458, 0.1860, 0.1607, 0.1702, 0.1838, 0.1534]])

In [28]:
from torcheval.metrics import MulticlassAccuracy

# Per-class accuracy (average=None), sized from the dataset rather than a hard-coded 6.
metric = MulticlassAccuracy(average=None, num_classes=num_classes)
# update() returns the metric object itself; the trailing ';' hides its repr.
metric.update(pred, y);

<torcheval.metrics.classification.accuracy.MulticlassAccuracy at 0x240ab17c170>

In [20]:
# Per-class accuracy for the updated batch (average=None reports each class
# separately); nan entries presumably mark classes with no samples in this batch.
metric.compute()

tensor([0.5000,    nan, 0.3333, 0.0000, 0.0000,    nan])

In [210]:
import os
from pathlib import Path
from time import strftime
import torch
import torch.nn as nn
import torchinfo
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from classification.network import layers


class ModelCheckpoint:
    """Periodically save, and later restore, model/optimizer/scheduler state.

    Each run writes into its own ``checkpoints_dir/run-<timestamp>/`` directory,
    created lazily on the first ``save`` call. ``load`` restores the newest
    ``*.pt`` file found anywhere under ``checkpoints_dir`` (across all runs).

    Args:
        model: Module whose ``state_dict`` is checkpointed.
        optimizer: Optimizer whose ``state_dict`` is checkpointed.
        scheduler: LR scheduler whose ``state_dict`` is checkpointed.
        freq: Write a checkpoint every ``freq`` epochs; must be >= 1.

    Raises:
        ValueError: If ``freq`` is less than 1.
    """

    def __init__(
        self, model: Module, optimizer: Optimizer, scheduler: LRScheduler, freq: int = 1
    ) -> None:
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.freq = int(freq)
        # Fail fast: freq == 0 would otherwise surface later as a
        # ZeroDivisionError from the modulo in ``save``.
        if self.freq < 1:
            raise ValueError(f"freq must be >= 1, got {freq!r}")
        self.checkpoints_dir = Path("./checkpoints/")
        self.run_dir: Path | None = None
        # Metric history restored by ``load``; keyed by metric name (e.g. "train_loss").
        self.history: dict[str, list[float]] = dict()

    def __call__(self, history: dict[str, list[float]]) -> None:
        """Callback entry point; equivalent to ``save(history)``."""
        self.save(history)

    def save(self, history: dict[str, list[float]]) -> None:
        """Write a checkpoint when the current epoch is a multiple of ``freq``.

        The current epoch is taken to be ``len(history["train_loss"])``.

        Args:
            history: Per-epoch metrics recorded so far; must contain "train_loss".
        """
        if self.run_dir is None:
            self.run_dir = self.checkpoints_dir / strftime("run-%Y-%m-%d-%H-%M-%S")
            self.run_dir.mkdir(parents=True, exist_ok=True)

        epoch = len(history["train_loss"])
        if epoch % self.freq != 0:
            return

        torch.save(
            {
                # Store a plain dict: ``load`` uses weights_only=True, which only
                # unpickles a restricted set of builtin types.
                "history": dict(history),
                "model": self.model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "scheduler": self.scheduler.state_dict(),
            },
            self.run_dir / f"checkpoint-epoch-{epoch:03d}.pt",
        )

    def load(self) -> None:
        """Restore state from the most recently written checkpoint.

        Sets ``self.history`` and loads the model, optimizer and scheduler state
        dicts in place. ``run_dir`` is not restored, so later saves go to a new
        run directory.

        Raises:
            FileNotFoundError: If no ``*.pt`` file exists under ``checkpoints_dir``.
        """
        # Use mtime rather than ctime. On POSIX, st_ctime is the inode-change
        # time, not the creation time, so it is not a portable "newest" key.
        latest_checkpoint_path = max(
            self.checkpoints_dir.rglob("*.pt"),
            key=lambda path: path.stat().st_mtime,
            default=None,
        )
        if latest_checkpoint_path is None:
            raise FileNotFoundError("No checkpoints found")

        # Map to CPU so checkpoints written on a GPU can be opened anywhere;
        # load_state_dict then copies the values into the existing objects.
        checkpoint = torch.load(
            latest_checkpoint_path, weights_only=True, map_location=torch.device("cpu")
        )
        self.history = checkpoint["history"]
        self.model.load_state_dict(checkpoint["model"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.scheduler.load_state_dict(checkpoint["scheduler"])

In [23]:
from classification.utils.callbacks import ModelCheckpoint

# The head size must match the checkpointed weights; derive it from the dataset
# instead of hard-coding 6.
model = SeResNet(num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
model_checkpoint = ModelCheckpoint(model, optimizer, scheduler, freq=1)
# Restore the most recent checkpoint into the objects created above.
model_checkpoint.load()

In [24]:
import pandas as pd

# Training history as a table: one row per epoch, one column per logged metric.
history_df = pd.DataFrame(model_checkpoint.history)
history_df

Unnamed: 0,epoch,train_loss,train_accuracy,val_loss,val_accuracy
0,1,1.313808,0.605052,1.306793,0.641372
1,2,1.131823,0.729233,1.11188,0.736525
2,3,1.007131,0.75832,0.974949,0.775483
3,4,0.939409,0.766607,0.956383,0.756894
4,5,0.901797,0.801223,0.911345,0.770414
5,6,0.833613,0.772192,0.852732,0.767659
6,7,0.790328,0.785958,0.819621,0.78233
7,8,0.75588,0.835709,0.757462,0.838788
8,9,0.69357,0.83885,0.659808,0.866951
9,10,0.693816,0.834892,0.659488,0.829865
