In [2]:
import os

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms

import src.modelling.utils as utils
from src.modelling.resnet_model import ResNet

In [3]:
DATA_DIR = "./data/"
BATCH_SIZE = 1
EPOCHS = 5
MAX_LR = 0.01
GRAD_CLIP = 0.1
WEIGHT_DECAY = 1e-4

In [4]:
image_transformer = transforms.Compose([
    transforms.RandomCrop(64, padding=4, padding_mode='reflect'),
    transforms.Resize(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010), inplace=True)
])


train_loader = DataLoader(
    dataset=ImageFolder(os.path.join(DATA_DIR, "train"), transform=image_transformer),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

valid_loader = DataLoader(
    dataset=ImageFolder(os.path.join(DATA_DIR, "valid"), transform=image_transformer),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

test_loader = DataLoader(
    dataset=ImageFolder(os.path.join(DATA_DIR, "test"), transform=image_transformer),
    batch_size=BATCH_SIZE,
)

In [5]:
model = ResNet(in_channels=3, n_classes=4)

In [6]:
optimizer = torch.optim.Adam(model.parameters(), lr=MAX_LR, weight_decay=WEIGHT_DECAY)

In [None]:
utils.train_model(
    model=model, 
    epochs=EPOCHS,
    train_loader=train_loader,
    valid_loader=valid_loader,
    optimizer=optimizer,
    max_lr=MAX_LR,
    weight_decay=WEIGHT_DECAY,
    grad_clip=GRAD_CLIP,
)

tensor(1.4391, grad_fn=<NllLossBackward0>)
tensor(0.9498, grad_fn=<NllLossBackward0>)
tensor(1.8616, grad_fn=<NllLossBackward0>)
tensor(1.2598, grad_fn=<NllLossBackward0>)
tensor(1.7506, grad_fn=<NllLossBackward0>)
tensor(1.1251, grad_fn=<NllLossBackward0>)
tensor(1.2681, grad_fn=<NllLossBackward0>)
tensor(1.4753, grad_fn=<NllLossBackward0>)
tensor(1.3357, grad_fn=<NllLossBackward0>)
tensor(2.0526, grad_fn=<NllLossBackward0>)
tensor(1.4497, grad_fn=<NllLossBackward0>)
tensor(1.5442, grad_fn=<NllLossBackward0>)
tensor(1.6347, grad_fn=<NllLossBackward0>)
tensor(0.9558, grad_fn=<NllLossBackward0>)
tensor(1.7569, grad_fn=<NllLossBackward0>)
tensor(2.0158, grad_fn=<NllLossBackward0>)
tensor(1.3123, grad_fn=<NllLossBackward0>)
tensor(1.9049, grad_fn=<NllLossBackward0>)
tensor(1.6064, grad_fn=<NllLossBackward0>)
tensor(1.7232, grad_fn=<NllLossBackward0>)
tensor(1.2404, grad_fn=<NllLossBackward0>)
tensor(1.0550, grad_fn=<NllLossBackward0>)
tensor(1.7949, grad_fn=<NllLossBackward0>)
tensor(2.14