In [1]:
import os
import sys
import time
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

In [2]:
from vgg import VGGNet
from fcn import FCNs, FCN8s, FCN16s, FCN32s

In [3]:
from utils import calculate_iou, calculate_acc
from camvid import CamVid

## Set Configs

In [4]:
BATCH_SIZE = 6
LR = 1e-4
MOMENTUM = 0
WEIGHT_DECAY = 1e-5
STEP_SIZE = 50
GAMMA = 0.5

N_EPOCHS = 500

In [5]:
is_cuda = torch.cuda.is_available()

if is_cuda: device = torch.device('cuda')
else: device = torch.device('cpu')

## Set Data Loader

In [6]:
DATASET_DIR = '../datasets/camvid/data'
HEIGHT, WIDTH = 224, 224
BATCH_SIZE = 10
WORKERS = 4

In [7]:
data_transform = transforms.Compose([transforms.Resize((HEIGHT, WIDTH)), 
                                     transforms.ToTensor()])

label_transform = transforms.Compose([transforms.Resize((HEIGHT, WIDTH), Image.NEAREST),
                                      transforms.ToTensor()])

In [8]:
train_set = CamVid(DATASET_DIR, mode='train', 
                   data_transform=data_transform, label_transform=label_transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=WORKERS)

In [9]:
valid_set = CamVid(DATASET_DIR, mode='valid', 
                   data_transform=data_transform, label_transform=label_transform)

valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=WORKERS)

In [10]:
test_set = CamVid(DATASET_DIR, mode='valid', 
                  data_transform=data_transform, label_transform=label_transform)

test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=True, num_workers=WORKERS)

In [11]:
N_CLASS = len(CamVid.color_encoding)

## Init VGG Model

In [12]:
vgg_model = VGGNet(requires_grad=True, remove_fc=True); vgg_model.to(device);

## Init FCN Model

In [13]:
fcn_model = FCNs(pretrained_net=vgg_model, n_class=N_CLASS); fcn_model.to(device);

## Set Loss Function

In [14]:
bce_loss = nn.BCEWithLogitsLoss()
bce_loss.to(device)

BCEWithLogitsLoss()

## Set Optimizer

In [15]:
optimizer = torch.optim.RMSprop(fcn_model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

In [16]:
scheduler = lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA) # decay LR by a factor of 0.5 every 30 epochs

## Train [FCN](https://arxiv.org/pdf/1605.06211.pdf) Network

In [None]:
for epoch in range(1, N_EPOCHS+1):
    
    tick = time.time()
    train_loss = []; valid_loss = []; best_loss = np.inf
    
    for i, batch in enumerate(train_loader):
        print(i)
        optimizer.zero_grad()
        
        inputs = Variable(batch[0]); inputs.to(device);
        labels = Variable(batch[2]); labels.to(device);
        
        outputs = fcn_model(inputs)
        loss = bce_loss(outputs, labels)
        loss.backward()
        
        optimizer.step()
        scheduler.step()
        
        train_loss.append(loss.item())
            
    # validate the network
    fcn_model.eval()
    accuracy = []; total_iou = []

    for i, batch in enumerate(valid_loader):
        print(i)
        inputs = Variable(batch[0]); inputs.to(device);
        labels = Variable(batch[2]); labels.to(device);
        
        outputs = fcn_model(inputs)
        loss = bce_loss(outputs, labels)
        
        valid_loss.append(loss.item())
        
        outputs = outputs.data.cpu().numpy()
        N, _, h, w = outputs.shape
        prediction = outputs.transpose(0, 2, 3, 1).reshape(-1, N_CLASS)
        prediction = prediction.argmax(axis=1)
        prediction = prediction.reshape(N, h, w)
        
        labels = labels.cpu().numpy()
        labels = labels.transpose(0, 2, 3, 1).reshape(-1, N_CLASS)
        labels = labels.argmax(axis=1)
        labels = labels.reshape(N, h, w)
        
        for pred, label in zip(prediction, labels):
            iou = calculate_iou(pred, label, N_CLASS)
            acc = calculate_acc(pred, label)
            
            total_iou.append(iou); accuracy.append(acc)
            
    # calculate average IoU, accuracy & loss
    total_iou = np.array(total_iou).T # N_CLASS * valid_len
    iou = np.nanmean(total_iou, axis=1)
    accuracy = np.array(accuracy).mean()
    
    train_loss = train_loss/len(train_loader)
    valid_loss = valid_loss/len(valid_loader)
    
    if epoch == 1 or (valid_loss < best_loss):
        best_loss = valid_loss
        torch.save(fcn_model, f'./weights/fcn_model_loss{best_loss}.hdf5')
    
    print(f'Epoch {epoch}, Valid Loss: {valid_loss}', 
          f'Accuracy: {accuracy}, mIoU: {np.nanmean(iou)}, Time Elapsed {time.time()-tick}')

0
1
2
3
4


---