In [1]:
import os
import sys
import time
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

In [2]:
from vgg import VGGNet
from fcn import FCNs, FCN8s, FCN16s, FCN32s

In [3]:
from camvid import CamVid

## Configuration

In [4]:
# Optimisation hyper-parameters.
# NOTE: BATCH_SIZE is configured together with the data loaders further down; it used to be
# set here as well and was silently overridden there, so only one definition is kept.
LR = 1e-4
MOMENTUM = 0
WEIGHT_DECAY = 1e-5
STEP_SIZE = 50   # StepLR: multiply the LR by GAMMA every STEP_SIZE scheduler steps
GAMMA = 0.5

N_EPOCHS = 500

# Seed the RNGs so weight initialisation and data shuffling are repeatable across runs
# (some GPU kernels may still be nondeterministic).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [5]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')

## Set Up Data Loaders

In [6]:
# Data-loading settings.
DATASET_DIR = '../datasets/camvid/data'  # relative to the notebook's working directory
HEIGHT = WIDTH = 224                     # target size for both images and labels
BATCH_SIZE = 10
WORKERS = 4                              # DataLoader worker processes

In [7]:
# Images: resize to the network input size, then convert to a float tensor.
data_transform = transforms.Compose([
    transforms.Resize((HEIGHT, WIDTH)),
    transforms.ToTensor(),
])

# Labels: nearest-neighbour resizing so no interpolated (non-existent) label values appear.
# NOTE(review): ToTensor rescales 8-bit images to [0, 1]; confirm CamVid expects that for labels.
label_transform = transforms.Compose([
    transforms.Resize((HEIGHT, WIDTH), Image.NEAREST),
    transforms.ToTensor(),
])

In [8]:
# Training split; batches are reshuffled every epoch.
train_set = CamVid(
    DATASET_DIR,
    mode='train',
    data_transform=data_transform,
    label_transform=label_transform,
)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=WORKERS,
)

In [9]:
# Validation split. Evaluation does not benefit from shuffling, and a fixed order makes
# per-batch metrics comparable between runs, so shuffle is off.
valid_set = CamVid(DATASET_DIR, mode='valid',
                   data_transform=data_transform, label_transform=label_transform)

valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=WORKERS)

In [10]:
# Held-out evaluation split; a fixed order (no shuffle) keeps results comparable between runs.
# FIXME(review): this "test" set is built from mode='valid', i.e. the same data as valid_set.
# Switch to the dataset's test split if CamVid supports one (check the modes in camvid.py).
test_set = CamVid(DATASET_DIR, mode='valid',
                  data_transform=data_transform, label_transform=label_transform)

test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=WORKERS)

In [13]:
# Number of segmentation classes: one per entry in the dataset's colour-encoding table.
N_CLASS = len(CamVid.color_encoding)

## Initialize the VGG Backbone

In [14]:
# Trainable VGG feature extractor without its fully connected head; Module.to returns self.
vgg_model = VGGNet(requires_grad=True, remove_fc=True).to(device)

## Initialize the FCN Model

In [15]:
# FCN head on top of the VGG backbone, producing N_CLASS output channels; Module.to returns self.
fcn_model = FCNs(pretrained_net=vgg_model, n_class=N_CLASS).to(device)

## Set Loss Function

In [16]:
# Element-wise sigmoid + binary cross-entropy on the raw logits.
# NOTE(review): presumably the targets are one-hot masks shaped like the output — confirm.
bce_loss = nn.BCEWithLogitsLoss()
bce_loss.to(device)

BCEWithLogitsLoss()

## Set Optimizer

In [17]:
# RMSprop over all parameters exposed by the FCN model.
optimizer = torch.optim.RMSprop(
    fcn_model.parameters(),
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)

In [18]:
# StepLR multiplies the learning rate by GAMMA every STEP_SIZE calls to scheduler.step().
# It must be stepped once per epoch (after the optimizer updates) for this to mean
# "decay every STEP_SIZE epochs".
scheduler = lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

## Train [FCN](https://arxiv.org/pdf/1605.06211.pdf) Network

In [None]:
for epoch in range(1, N_EPOCHS + 1):

    tick = time.time()
    fcn_model.train()  # ensure training-mode behaviour for dropout/batch-norm layers, if any

    for i, batch in enumerate(train_loader):

        optimizer.zero_grad()

        # Tensor.to() is NOT in-place: it returns a new tensor, so the result must be reassigned,
        # otherwise inputs stay on the CPU while the model lives on `device`.
        # (torch.autograd.Variable has been a deprecated no-op wrapper since PyTorch 0.4.)
        # NOTE(review): assumes batch[0] is the image batch and batch[2] a float target with the
        # same shape as the network output, as BCEWithLogitsLoss requires — confirm in camvid.py.
        x = batch[0].to(device)
        y = batch[2].to(device)

        output = fcn_model(x)
        loss = bce_loss(output, y)
        loss.backward()

        optimizer.step()

        if i % 10 == 0:
            print(f'Epoch: {epoch}, Batch: {i}, Loss: {loss.item()}')

    # Step the LR schedule once per epoch so STEP_SIZE counts epochs; stepping it inside the
    # batch loop would decay the LR every STEP_SIZE *batches* and drive it towards zero early.
    scheduler.step()

    print(f'Finish Epoch {epoch}, Time Elapsed {time.time()-tick}')
    # TODO: checkpoint periodically; saving fcn_model.state_dict() is preferred over the whole model.
    # torch.save(fcn_model, './weights/fcn_model')

Epoch: 1, Batch: 0, Loss: 0.7244716882705688


---