# Train U-Net Network

In [None]:
import os
import sys
import time
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

In [None]:
from unet import UNet

In [None]:
from utils import calculate_iou, calculate_acc, moving_average
from camvid import CamVid

In [None]:
# Apply matplotlib's 'ggplot' stylesheet to every figure drawn later on.
plt.style.use('ggplot')

## Set Configs

In [None]:
# Optimisation hyper-parameters.
BATCH_SIZE = 1        # samples per mini-batch in the train/valid loaders
# NOTE(review): 0.1 is a large step size for RMSprop -- confirm it is intended.
LR = 0.1              # learning rate handed to the optimizer
WEIGHT_DECAY = 1e-8   # weight decay handed to the optimizer

# Number of full passes over the training set.
N_EPOCHS = 5

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')

## Set Data Loader

In [None]:
# Location of the CamVid data, relative to this notebook.
DATASET_DIR = '../datasets/camvid/data'

# Spatial size every image and label is resized to by the transforms.
HEIGHT = 224
WIDTH = 224

# Value passed as num_workers to each DataLoader.
WORKERS = 4

In [None]:
# Input images: resize to (HEIGHT, WIDTH), then convert to a tensor.
data_transform = transforms.Compose([
    transforms.Resize((HEIGHT, WIDTH)),
    transforms.ToTensor(),
])

# Labels: same pipeline, but resized with nearest-neighbour resampling so the
# resize does not blend neighbouring label values into new ones.
label_transform = transforms.Compose([
    transforms.Resize((HEIGHT, WIDTH), Image.NEAREST),
    transforms.ToTensor(),
])

In [None]:
# Training split and its loader; samples are reshuffled on every pass.
train_set = CamVid(
    DATASET_DIR,
    mode='train',
    data_transform=data_transform,
    label_transform=label_transform,
)

train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=WORKERS,
)

In [None]:
# Validation split and its loader, built with the same transforms and batch
# size as the training loader.
valid_set = CamVid(
    DATASET_DIR,
    mode='valid',
    data_transform=data_transform,
    label_transform=label_transform,
)

valid_loader = DataLoader(
    valid_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=WORKERS,
)

In [None]:
# Dataset and loader used for the final evaluation; one sample per batch.
# NOTE(review): this is built with mode='valid', i.e. the same split as
# valid_set, so the "test" metrics are computed on validation data. Confirm
# whether CamVid offers a separate test mode and switch to it if so.
test_set = CamVid(
    DATASET_DIR,
    mode='valid',
    data_transform=data_transform,
    label_transform=label_transform,
)

test_loader = DataLoader(
    test_set,
    batch_size=1,
    shuffle=True,
    num_workers=WORKERS,
)

In [None]:
# Number of segmentation classes: one per entry in CamVid.color_encoding.
# Passed to UNet as num_classes and used when computing IoU.
N_CLASS = len(CamVid.color_encoding)

## Init UNet Model

In [None]:
# Build the U-Net (num_channels=3, num_classes=N_CLASS) and move it to the
# selected device in one expression.
unet_model = UNet(num_channels=3, num_classes=N_CLASS).to(device)

## Set Loss Function

In [None]:
# Binary cross-entropy computed directly on raw logits (the sigmoid is applied
# inside the loss), placed on the same device as the model.
bce_loss = nn.BCEWithLogitsLoss()
bce_loss.to(device)

## Set Optimizer

In [None]:
# RMSprop over all U-Net parameters, using the configured learning rate and
# weight decay.
optimizer = optim.RMSprop(
    unet_model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

## Train [U-Net](https://arxiv.org/pdf/1505.04597.pdf) Network

In [None]:
# Train for N_EPOCHS epochs. After each epoch the model is evaluated on the
# validation loader, and a checkpoint is written whenever the mean validation
# loss improves on the best value seen so far.

# Per-epoch mean losses. They are accumulated across epochs (one entry per
# epoch) so the loss curves can be plotted once training has finished.
train_loss = []
valid_loss = []

# Lowest mean validation loss seen so far. This must be initialised outside
# the epoch loop; resetting it every epoch would make every epoch "the best"
# and write a checkpoint each time.
best_loss = np.inf

for epoch in range(1, N_EPOCHS + 1):

    tick = time.time()

    # ---- training pass ----
    unet_model.train()
    epoch_train_loss = []

    for batch in train_loader:

        optimizer.zero_grad()

        # Tensor.to() returns the moved tensor rather than modifying it in
        # place, so the result has to be bound; otherwise the batch stays on
        # the CPU while the model may live on the GPU.
        # batch[0] is fed to the network and batch[2] to the loss as target.
        inputs = batch[0].to(device)
        labels = batch[2].to(device)

        outputs = unet_model(inputs)
        loss = bce_loss(outputs, labels)
        loss.backward()

        optimizer.step()

        epoch_train_loss.append(loss.item())

    # ---- validation pass ----
    unet_model.eval()
    epoch_valid_loss = []
    accuracy = []
    total_iou = []

    # No gradients are needed for evaluation; skipping them saves memory and
    # time.
    with torch.no_grad():
        for batch in valid_loader:

            inputs = batch[0].to(device)
            labels = batch[2].to(device)

            outputs = unet_model(inputs)
            loss = bce_loss(outputs, labels)

            epoch_valid_loss.append(loss.item())

            # Collapse the channel axis (axis 1) to a per-pixel class index.
            # assumes outputs and labels are both (N, N_CLASS, h, w) -- the
            # labels presumably one-hot encoded; confirm against CamVid.
            prediction = outputs.cpu().numpy().argmax(axis=1)
            target = labels.cpu().numpy().argmax(axis=1)

            for pred, label in zip(prediction, target):
                total_iou.append(calculate_iou(pred, label, N_CLASS))
                accuracy.append(calculate_acc(pred, label))

    # calculate average IoU, accuracy & loss
    total_iou = np.array(total_iou).T  # N_CLASS * valid_len
    iou = np.nanmean(total_iou, axis=1)
    accuracy = np.array(accuracy).mean()

    mean_train_loss = np.mean(epoch_train_loss)
    mean_valid_loss = np.mean(epoch_valid_loss)
    train_loss.append(mean_train_loss)
    valid_loss.append(mean_valid_loss)

    # Always checkpoint the first epoch, afterwards only on improvement.
    if epoch == 1 or mean_valid_loss < best_loss:

        # set the lower valid loss as best loss
        best_loss = mean_valid_loss

        os.makedirs('./weights', exist_ok=True)
        # torch.save pickles the whole module object; the file name is kept
        # unchanged so existing checkpoints keep the same naming scheme.
        torch.save(unet_model, f'./weights/fcn_model_loss{best_loss}.hdf5')

    print(f'Epoch {epoch}, Valid Loss: {mean_valid_loss:.6f},',
          f'Accuracy: {accuracy:.6f}, mIoU: {np.nanmean(iou):.6f}, Time Taken: {time.time()-tick:.2f}s')

## Evaluate The Network

In [None]:
# Pass the recorded losses through the project's moving_average helper and
# draw both curves on the same axes.
all_train_loss = moving_average(train_loss).tolist()
all_valid_loss = moving_average(valid_loss).tolist()

for curve, curve_label in ((all_train_loss, "Train Loss"),
                           (all_valid_loss, "Valid Loss")):
    plt.plot(curve, label=curve_label)

plt.title('Train & Valid Loss Metric of Training Process')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

# Write the figure to ./images (created if missing) before displaying it.
if not os.path.exists('./images'):
    os.makedirs('./images')
plt.savefig('./images/plot_train_validation_loss.png')
plt.show()

## Test The Network

In [None]:
# Evaluate the trained U-Net on the test loader and report the mean loss,
# the mean accuracy (calculate_acc) and the mean IoU (calculate_iou).
unet_model.eval()
accuracy = []
total_iou = []
test_loss = []

# No gradients are needed for evaluation; skipping them saves memory and time.
with torch.no_grad():
    for batch in test_loader:

        # Tensor.to() returns the moved tensor rather than modifying it in
        # place, so the result has to be bound; otherwise the batch stays on
        # the CPU while the model may live on the GPU.
        inputs = batch[0].to(device)
        labels = batch[2].to(device)

        outputs = unet_model(inputs)
        loss = bce_loss(outputs, labels)

        test_loss.append(loss.item())

        # Collapse the channel axis (axis 1) to a per-pixel class index.
        # assumes outputs and labels are both (N, N_CLASS, h, w) -- the labels
        # presumably one-hot encoded; confirm against CamVid.
        prediction = outputs.cpu().numpy().argmax(axis=1)
        target = labels.cpu().numpy().argmax(axis=1)

        for pred, label in zip(prediction, target):
            total_iou.append(calculate_iou(pred, label, N_CLASS))
            accuracy.append(calculate_acc(pred, label))

# calculate average IoU, accuracy & loss
total_iou = np.array(total_iou).T  # N_CLASS * test_len
iou = np.nanmean(total_iou, axis=1)
accuracy = np.array(accuracy).mean()

test_loss = np.mean(test_loss)

print(f'U-Net - Test Loss: {test_loss:.6f}, Accuracy: {accuracy:.6f}, mIoU: {np.nanmean(iou):.6f}')

---