In [None]:
import os

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from data.utils import root
from data.dataset import MoanaDataset
from data.transform import (
    ToPILImage,
    RandomHorizontalFlip,
    RandomVerticalFlip,
    RandomDiscreteRotation,
    ToTensor,
)
from data.plot import imshow_image, imshow_label

from model.modules import RDUNet

# Automatically reload edited modules (e.g. the local data.* and model.*
# packages imported above) before each cell runs, so changes to them are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Build the Dataset and DataLoader

In [None]:
# Training split: NCCOS 2007 imagery under the project data root.
# The second positional argument (512, 512) is presumably the tile/crop size
# produced by MoanaDataset — TODO confirm against data.dataset.
# Augmentation pipeline (custom transforms from data.transform; presumably they
# act on image and label jointly — confirm): convert to PIL, random horizontal
# and vertical flips, a rotation drawn from {0, 90, 180, 270} degrees, then
# back to tensors.
XY_data_train = MoanaDataset(
    os.path.join(root(), "nccos", "2007"),
    (512, 512),
    transform=transforms.Compose([
        ToPILImage(),
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        RandomDiscreteRotation([0, 90, 180, 270]),
        ToTensor(),
    ]),
)

# Seed the loader so the shuffle order is reproducible across kernel restarts.
# DataLoader also derives each worker's RNG seeds from this generator, so
# augmentations drawing from those RNGs become reproducible as well.
LOADER_SEED = 0

XY_load_train = DataLoader(
    XY_data_train,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

#### Display one training batch

In [None]:
# Pull one batch from the shuffled, augmented training loader and plot the
# imagery and its labels as a visual check of the data pipeline.
# `images` and `labels` are reused by the single-sample cells below.
# NOTE(review): XY_load_train uses batch_size=1, so this batch holds a single
# sample; the second argument (4) to imshow_image/imshow_label is defined in
# data.plot — confirm it is compatible with a batch of one.
images, labels = next(iter(XY_load_train))
imshow_image(images, 4)
imshow_label(labels, 4)

## Build the Model, Loss Function, and Optimizer

In [None]:
# Seed the global torch RNG so the model's weight initialisation (and any
# later draws from the global RNG) is reproducible across kernel restarts.
torch.manual_seed(0)

# (3, 512, 512) looks like the (channels, height, width) of one input and 4
# presumably the number of output classes. cross_entropy treats dim 1 of the
# model output as the class axis, so that size must cover every label value.
# TODO confirm against model.modules.RDUNet.
model = RDUNet((3, 512, 512), 4, channels=32, depth=5)

# Note: Adam's weight_decay adds an L2 term to the gradient (coupled decay);
# torch.optim.AdamW implements decoupled weight decay if that is the intent.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-3)

def loss_func(Y_hat, Y):
    """Cross-entropy between predicted class scores and integer label maps.

    Args:
        Y_hat: raw model output (logits); ``cross_entropy`` treats dim 1 as
            the class axis.
        Y: label tensor, presumably shaped (N, 1, H, W) — dim 1 is squeezed
            away when it has size 1. Values are class indices and are cast to
            ``long`` because ``cross_entropy`` requires integer targets.

    Returns:
        Scalar tensor: the cross-entropy averaged over all target elements
        (the default ``reduction='mean'``).
    """
    targets = Y.squeeze(1).long()
    return torch.nn.functional.cross_entropy(Y_hat, targets)

#### Sanity check: forward pass on a single sample

In [None]:
# Forward the first sample of the batch (re-wrapped as a batch of one) through
# the untrained model to check the output shape.
# Run in eval mode so dropout/batch-norm layers (if RDUNet has any) behave
# deterministically and do not update running statistics from this check.
# Restore the previous mode afterwards so training is unaffected.
was_training = model.training
model.eval()
with torch.no_grad():
    output = model(images[0].unsqueeze(0))
model.train(was_training)
print(output.shape)

#### Compute the loss for that sample

In [None]:
# Score the single-sample prediction against its own label. Slicing with
# [:1] keeps the leading batch dimension, matching the model output's batch of one.
loss = loss_func(output, labels[:1])
print(loss)

## Build the Training Loop

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

NUM_EPOCHS = 8
# Print the mean training loss every LOG_EVERY iterations. The same constant
# is used as the divisor, so the printed value is a true mean over the window.
# (The original divided by 2000 while printing every 2 steps, which
# under-reported the loss by 1000x.)
LOG_EVERY = 2

# Module.to moves parameters in place, so the optimizer built earlier still
# references them.
model.to(device)
# Ensure training-mode behaviour (dropout, batch-norm statistics) even if an
# earlier cell left the model in eval mode.
model.train()

for epoch in range(NUM_EPOCHS):

    running_loss = 0.0
    for i, (image, label) in enumerate(XY_load_train):

        image = image.to(device)
        label = label.to(device)

        # Standard step: clear stale grads, forward, loss, backprop, update.
        optimizer.zero_grad()
        output = model(image)
        loss = loss_func(output, label)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % LOG_EVERY == LOG_EVERY - 1:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0