In [None]:
import os

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import (
    Accuracy, 
    Loss,
    Fbeta, 
    DiceCoefficient, 
    ConfusionMatrix
)

from data.utils import root
from data.dataset import MoanaDataset
from data.transform import (
    ToPILImage,
    RandomHorizontalFlip,
    RandomVerticalFlip,
    RandomDiscreteRotation,
    RandomCrop,
    ToTensor
)
from data.plot import imshow_image, imshow_label

from model.modules import RDUNet

%load_ext autoreload
%autoreload 2

In [None]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

## Build the Dataset and DataLoader

In [None]:
# Folder handed to MoanaDataset: <project root>/nccos/2007.
nccos_2007_dir = os.path.join(root(), "nccos", "2007")

# Transform pipeline given to the dataset. The final Lambda receives a
# 2-element sequence, keeps element 0 untouched, drops dim 0 of element 1
# when it has size 1, and casts element 1 to int64.
# NOTE(review): presumably element 0 is the image and element 1 the label
# mask -- confirm against data.transform / data.dataset.
augment = transforms.Compose([
    ToPILImage(),
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    RandomDiscreteRotation([0, 90, 180, 270]),
    RandomCrop((256, 256)),
    ToTensor(),
    transforms.Lambda(lambda pair: (pair[0], pair[1].squeeze(0).long()))
])

# NOTE(review): (512, 512) is presumably the tile size read from disk before
# the 256x256 random crop -- confirm against MoanaDataset.
XY_data = MoanaDataset(nccos_2007_dir, (512, 512), transform=augment)

# Split the dataset into a training part and a validation part.
# NOTE(review): 0.8 is presumably the training fraction (80/20 split) --
# confirm against MoanaDataset.split.
XY_train, XY_valid = MoanaDataset.split(XY_data, 0.8)

# Training batches are reshuffled every epoch.
XY_load_train = DataLoader(
    XY_train,
    batch_size=8,
    shuffle=True,
    num_workers=4
)

# Validation is iterated in a fixed order: shuffling brings nothing to
# evaluation, and a fixed order makes the batch served by
# `next(iter(XY_load_valid))` the same dataset indices on every run.
XY_load_valid = DataLoader(
    XY_valid,
    batch_size=8,
    shuffle=False,
    num_workers=4
)

#### Display 1 batch

In [None]:
# Draw one batch from the training loader; `images` and `labels` are reused
# by later cells.
train_batch = next(iter(XY_load_train))
images, labels = train_batch

# NOTE(review): the second argument (4) is presumably the number of samples
# or columns to plot -- confirm against data.plot.
imshow_image(images, 4)
imshow_label(labels, 4)

## Build the Model, Loss, and Optimizer

In [None]:
# Per-sample input shape: the image batch's shape with the batch dim dropped.
input_shape = next(iter(XY_load_train))[0].shape[1:]

# NOTE(review): the second positional argument (4) is presumably the number
# of output classes -- confirm against model.modules.RDUNet.
model = RDUNet(input_shape, 4, channels=32, depth=5)
model.to(device)

# Adam with a small L2 penalty (weight_decay) on all model parameters.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)

def loss_func(Y_hat, Y):
    """Cross-entropy between class logits and integer class labels.

    Args:
        Y_hat: Logits with the class dimension at index 1, e.g. ``(N, C)``
            or ``(N, C, H, W)``.
        Y: Class labels. Either shaped like ``Y_hat`` without its class
            dimension, or carrying a singleton channel dimension at index 1
            (e.g. ``(N, 1, H, W)``). Values are cast to int64.

    Returns:
        Scalar tensor holding the cross-entropy with PyTorch's default
        (mean) reduction.
    """
    # Strip dim 1 only when it is a genuine singleton channel axis, i.e. the
    # labels have as many dims as the logits. An unconditional squeeze(1)
    # raises on 1-D labels and would collapse a real spatial dimension of
    # size 1 when the labels have no channel axis.
    if Y.dim() == Y_hat.dim() and Y.dim() > 1 and Y.size(1) == 1:
        Y = Y.squeeze(1)
    return torch.nn.functional.cross_entropy(Y_hat, Y.long())

#### Run on 1 sample

In [None]:
# Forward the first image of the batch drawn in the "Display 1 batch" cell.
# Drawing a fresh batch here (the loader shuffles) would yield a different
# sample than `labels[0]`, which the loss cell below scores `output` against.
with torch.no_grad():
    output = model(images[0].unsqueeze(0).to(device))
print(output.shape)

#### Compute the loss on 1 sample

In [None]:
# Score the single-sample prediction held in `output` against the first
# label of the displayed batch, re-batched to a leading dimension of 1.
sample_label = labels[0].unsqueeze(0).to(device)
loss = loss_func(output, sample_label)
print(loss)

## Build the Training Loop

In [None]:
# Ignite engine that runs one optimisation step per batch.
trainer = create_supervised_trainer(model, optimizer, loss_func, device=device)

# Metrics computed by the evaluation engine: the loss and the F-beta score
# with beta = 1 (F1).
eval_metrics = {
    'loss': Loss(loss_func),
    'f1': Fbeta(1)
}

# Ignite engine that runs the model over a loader and accumulates the metrics.
evaluator = create_supervised_evaluator(
    model,
    metrics=eval_metrics,
    device=device
)

def log_training_loss(trainer):
    """Print the current epoch and the output of the iteration that just ended.

    Args:
        trainer: The engine that fired the event; its ``state.epoch`` and
            ``state.output`` are printed (the output is formatted as the loss).
    """
    state = trainer.state
    print("Epoch[{}] Loss: {:.2f}".format(state.epoch, state.output))


# Fire after every training iteration.
trainer.add_event_handler(Events.ITERATION_COMPLETED, log_training_loss)

def log_training_results(trainer):
    """Evaluate on the training loader and print the averaged loss and F1.

    Args:
        trainer: The engine that fired the event; only ``state.epoch`` is read.
    """
    evaluator.run(XY_load_train)
    results = evaluator.state.metrics
    epoch = trainer.state.epoch
    template = "Training Results - Epoch: {}  Avg Loss: {:.2f} Avg F1: {}"
    print(template.format(epoch, results['loss'], results['f1']))


# Fire at the end of every training epoch.
trainer.add_event_handler(Events.EPOCH_COMPLETED, log_training_results)

def log_validation_results(trainer):
    """Evaluate on the validation loader and print the averaged loss and F1.

    Args:
        trainer: The engine that fired the event; only ``state.epoch`` is read.
    """
    evaluator.run(XY_load_valid)
    results = evaluator.state.metrics
    epoch = trainer.state.epoch
    template = "Validation Results - Epoch: {}  Avg Loss: {:.2f} Avg F1: {}"
    print(template.format(epoch, results['loss'], results['f1']))


# Fire at the end of every training epoch.
trainer.add_event_handler(Events.EPOCH_COMPLETED, log_validation_results)


In [None]:
# Number of full passes over the training loader.
MAX_EPOCHS = 10

# Train; the registered handlers print progress during the run.
trainer.run(XY_load_train, max_epochs=MAX_EPOCHS)

#### Run on 1 train sample

In [None]:
img, lab = next(iter(XY_load_train))

# Inference only: select eval mode explicitly instead of relying on whichever
# mode the last executed cell left the model in, and skip building an
# autograd graph that is never used.
model.eval()
with torch.no_grad():
    output = model(img[0].unsqueeze(0).to(device))

# Show the input image, its label, and the predicted class map (argmax over
# the class dimension) for the first sample of the batch.
imshow_image(img[0].unsqueeze(0).cpu().detach(), 4)
imshow_label(lab[0].unsqueeze(0).cpu().detach(), 4)
imshow_label(torch.argmax(output.cpu().detach(), dim=1), 4)

#### Run on 1 validation sample

In [None]:
img, lab = next(iter(XY_load_valid))

# Inference only: select eval mode explicitly instead of relying on whichever
# mode the last executed cell left the model in, and skip building an
# autograd graph that is never used.
model.eval()
with torch.no_grad():
    output = model(img[0].unsqueeze(0).to(device))

# Show the input image, its label, and the predicted class map (argmax over
# the class dimension) for the first sample of the batch.
imshow_image(img[0].unsqueeze(0).cpu().detach(), 4)
imshow_label(lab[0].unsqueeze(0).cpu().detach(), 4)
imshow_label(torch.argmax(output.cpu().detach(), dim=1), 4)