In [1]:
import os

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms

import src.modelling.utils as utils
from src.modelling.resnet_model import ResNet

%load_ext autoreload
%autoreload 2

In [2]:
DATA_DIR = "./data/"
BATCH_SIZE = 8
EPOCHS = 5
MAX_LR = 0.01
GRAD_CLIP = 0.1
WEIGHT_DECAY = 1e-4

In [3]:
image_transformer = transforms.Compose([
    transforms.RandomCrop(64, padding=4, padding_mode='reflect'),
    transforms.Resize(64),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
        inplace=True
    )
])


train_loader = DataLoader(
    dataset=ImageFolder(os.path.join(DATA_DIR, "train"), transform=image_transformer),
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
)

valid_loader = DataLoader(
    dataset=ImageFolder(os.path.join(DATA_DIR, "valid"), transform=image_transformer),
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
)

test_loader = DataLoader(
    dataset=ImageFolder(os.path.join(DATA_DIR, "test"), transform=image_transformer),
    batch_size=BATCH_SIZE,
    pin_memory=True,
    num_workers=4,
)

In [None]:
model = ResNet(in_channels=3, n_classes=4)

optimizer = torch.optim.Adam(model.parameters(), lr=MAX_LR, weight_decay=WEIGHT_DECAY)

torch.cuda.empty_cache()
utils.train_model(
    model=model, 
    epochs=2,
    train_loader=train_loader,
    valid_loader=valid_loader,
    optimizer=optimizer,
    max_lr=MAX_LR,
    weight_decay=WEIGHT_DECAY,
    grad_clip=GRAD_CLIP,
)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 493/493 [00:52<00:00,  9.40it/s]


Epoch [0] | lr: 0.00040 | train loss: 1.1180 | valid loss: 0.5702 | accuracy: 0.7328


  6%|██████▊                                                                                                                | 28/493 [00:02<00:47,  9.75it/s]

In [None]:
model = model.to(torch.device("cpu"))
torch.cuda.empty_cache()