## Model Training

This notebook trains a model locally and, when `WANDB_LOG` is enabled, logs its metadata, training and evaluation metrics, and artifacts to the public [W&B experiment](https://wandb.ai/mikasenghaas/bsc?workspace=user-mikasenghaas).

In [None]:
import sys
sys.path.insert(0, "../src")

In [None]:
import os
import random

import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from timeit import default_timer

# custom scripts
from config import *
from utils import *
from model import MODELS, FinetunedImageClassifier
from transform import ImageTransformer
from data import ImageDataset

## Hyperparameters

In [None]:
# specify model args
MODEL = "resnet18"      # backbone name; must be a key of MODELS
PRETRAINED = True       # forwarded to FinetunedImageClassifier(pretrained=...)

# explicit check instead of `assert`, which is stripped when Python runs with -O
if MODEL not in MODELS:
    raise ValueError(f"Specified model has to be one of {list(MODELS.keys())}")

In [None]:
# data args: dataset path, classes to include and the `ratio` forwarded to
# ImageDataset (presumably the fraction of samples kept — confirm in data.py)
# NOTE(review): FILEPATH is not passed to ImageDataset anywhere in this notebook
FILEPATH, INCLUDE_CLASSES, RATIO = PROCESSED_DATA_PATH, CLASSES, 1.0

In [None]:
# training args
MAX_EPOCHS = 1      # number of passes over the training split
BATCH_SIZE = 32     # samples per DataLoader batch
LR = 1e-4           # Adam learning rate
STEP_SIZE = 5       # StepLR: epochs between learning-rate decays
GAMMA = 0.1         # StepLR: multiplicative decay factor

In [None]:
# W&B args (empty values are handed to wandb.init as None)
WANDB_LOG = False   # set to True to log this run to W&B
WANDB_NAME = ""     # run name
WANDB_GROUP = ""    # run group
WANDB_TAGS = []     # run tags

In [None]:
# start a W&B run (only when logging is enabled)
import wandb

if WANDB_LOG:
    # empty name/group/tags are passed as None instead of empty values
    wandb.init(
        project="bsc",
        group=WANDB_GROUP or None,
        name=WANDB_NAME or None,
        tags=WANDB_TAGS or None,
    )

    # keep the best value of each metric in the run summary
    for metric_name, summary_kind in (
        ("training_loss", "min"),
        ("validation_loss", "min"),
        ("training_accuracy", "max"),
        ("validation_accuracy", "max"),
    ):
        wandb.define_metric(metric_name, summary=summary_kind)

## Load Data, Transforms and Model

In [None]:
# initialise one dataset and one loader per split
data = {
    split: ImageDataset(split=split, include_classes=INCLUDE_CLASSES, ratio=RATIO)
    for split in SPLITS
}
# DataLoader does not shuffle by default; shuffle only the training split so
# each epoch sees batches in a new random order, while val/test stay deterministic
loader = {
    split: DataLoader(data[split], batch_size=BATCH_SIZE, shuffle=(split == "train"))
    for split in SPLITS
}

# class <-> id mappings are taken from the training split
id2class, class2id = data["train"].id2class, data["train"].class2id

In [None]:
# transform applied to every image batch before it is passed to the model
transform = ImageTransformer()

In [None]:
# initialise the classifier; its output size is the number of included classes
model = FinetunedImageClassifier(
    model_name=MODEL,
    num_classes=len(INCLUDE_CLASSES),
    pretrained=PRETRAINED,
    id2class=id2class,
    class2id=class2id,
)

## Train Model

In [None]:
# loss, Adam optimiser and a step-wise learning-rate decay (x GAMMA every STEP_SIZE epochs)
criterion = nn.CrossEntropyLoss()  # pyright: ignore
optim = torch.optim.Adam(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=STEP_SIZE, gamma=GAMMA)

In [None]:
def train(model, transform, train_loader, val_loader, criterion, optim, scheduler):
    """Train ``model`` for ``MAX_EPOCHS`` epochs, validating after each epoch.

    Reads the notebook-level globals ``DEVICE``, ``MAX_EPOCHS`` and ``WANDB_LOG``.

    Args:
        model: classifier to train; moved to ``DEVICE`` in place.
        transform: callable applied to every input batch before the forward pass.
        train_loader: DataLoader yielding ``(inputs, labels)`` training batches.
        val_loader: DataLoader yielding validation batches, or ``None`` to skip
            validation.
        criterion: loss called as ``criterion(logits, labels)``; assumed to use
            mean reduction (the ``nn.CrossEntropyLoss()`` default).
        optim: optimiser over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.

    Returns:
        The trained model (the same object that was passed in).
    """
    model.to(DEVICE)
    epoch_width = len(str(MAX_EPOCHS))
    batch_width = len(str(len(train_loader)))

    pbar = tqdm(range(MAX_EPOCHS))
    pbar.set_description('XXX/XX (XX.Xms/ XX.Xms) - Train: X.XXX (XX.X%) - Val: X.XXX (XX.X%)')
    train_loss, val_loss = 0.0, 0.0
    train_acc, val_acc = 0.0, 0.0
    # totals over all epochs, used for the per-sample timings logged to W&B
    total_training_time, total_inference_time, total_samples = 0.0, 0.0, 0
    for epoch in pbar:
        running_loss, running_correct, samples_seen = 0.0, 0, 0
        running_training_time, running_inference_time = 0.0, 0.0
        model.train()
        for batch_num, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            # the last batch may be smaller than BATCH_SIZE, so count real samples
            batch_size = labels.size(0)

            # zero the parameter gradients
            optim.zero_grad()

            # forward pass (also reported separately as "inference" time)
            # NOTE(review): on asynchronous devices (e.g. CUDA) these wall-clock
            # timings are only approximate unless the device is synchronised.
            start = default_timer()
            logits = model(transform(inputs))
            running_inference_time += default_timer() - start

            # compute predictions and loss
            preds = torch.argmax(logits, 1)
            loss = criterion(logits, labels)

            # backprop error and update weights
            loss.backward()
            optim.step()

            # training time covers forward + backward + optimiser step
            running_training_time += default_timer() - start

            # criterion returns the batch *mean*; re-weight by batch size so the
            # running value is a true per-sample average over the epoch
            running_loss += loss.item() * batch_size
            running_correct += (preds == labels).sum().item()
            samples_seen += batch_size

            train_acc = running_correct / samples_seen
            train_loss = running_loss / samples_seen

            train_ms = running_training_time / samples_seen * 1000
            infer_ms = running_inference_time / samples_seen * 1000
            pbar.set_description(
                f'{str(epoch).zfill(epoch_width)}/{str(batch_num).zfill(batch_width)} '
                f'({round(train_ms, 1)}ms | {round(infer_ms, 1)}ms) '
                f'- Train: {train_loss:.3f} ({(train_acc * 100):.1f}%) '
                f'- Val: {val_loss:.3f} ({(val_acc * 100):.1f}%)')

            # log running train metrics (and the latest val metrics) every batch
            if WANDB_LOG:
                wandb.log({
                    'training_accuracy': train_acc,
                    'validation_accuracy': val_acc,
                    'training_loss': train_loss,
                    'validation_loss': val_loss})

        total_training_time += running_training_time
        total_inference_time += running_inference_time
        total_samples += samples_seen

        if val_loader is not None:
            running_loss, running_correct, val_samples = 0.0, 0, 0
            model.eval()
            # evaluation needs no autograd graph; skipping it saves memory and time
            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs = inputs.to(DEVICE)
                    labels = labels.to(DEVICE)
                    batch_size = labels.size(0)

                    logits = model(transform(inputs))
                    preds = torch.argmax(logits, 1)
                    loss = criterion(logits, labels)

                    # accumulate sample-weighted loss and correct predictions
                    running_loss += loss.item() * batch_size
                    running_correct += (preds == labels).sum().item()
                    val_samples += batch_size

            if val_samples > 0:
                val_loss = running_loss / val_samples
                val_acc = running_correct / val_samples

            pbar.set_description(
                f'{str(epoch).zfill(epoch_width)}/00 '
                f'- Train: {train_loss:.3f} ({(train_acc * 100):.1f}%) '
                f'- Val: {val_loss:.3f} ({(val_acc * 100):.1f}%)')

        # adjust learning rate
        scheduler.step()

    # log average training-step time and forward-pass time per sample, in ms
    if WANDB_LOG and total_samples > 0:
        wandb.config.update({
            "training_time_per_sample_ms": round(total_training_time / total_samples * 1000, 1),
            "inference_time_per_sample_ms": round(total_inference_time / total_samples * 1000, 1),
            })

    return model

In [None]:
trained_model = train(
    model,
    transform,
    train_loader=loader["train"],
    val_loader=loader["val"],
    criterion=criterion,
    optim=optim,
    scheduler=scheduler,
)

## Example Prediction

In [None]:
# put the model in inference mode on the CPU
trained_model.eval()
trained_model.to('cpu')

# load one batch of images from the test split
images, labels = next(iter(loader["test"]))
test_id2class = data["test"].id2class

# predict without building an autograd graph
with torch.no_grad():
    logits = trained_model(transform(images))
probs = torch.softmax(logits, dim=1)
_, preds = torch.max(probs, 1)

# show images with the true label (test-split mapping) and the predicted
# label (mapping of the training split the model was built with)
show_images(images, titles=[
    f"True: {test_id2class[label.item()]}\nPred: {id2class[pred.item()]}"
    for label, pred in zip(labels, preds)
])