In [1]:
import os
import sys
import time
import numpy as np
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

In [2]:
from vgg import VGGNet
from fcn import FCNs, FCN8s, FCN16s, FCN32s

## Set Configs

In [16]:
# --- Data ---
# Intended mini-batch size (not referenced by the cells visible in this notebook).
BATCH_SIZE = 6

# --- Optimizer (RMSprop) ---
LR = 0.0001            # initial learning rate
MOMENTUM = 0           # momentum term handed to the optimizer
WEIGHT_DECAY = 0.00001 # weight-decay coefficient handed to the optimizer

# --- Learning-rate schedule (StepLR) ---
STEP_SIZE = 50         # number of scheduler steps between two decays
GAMMA = 0.5            # multiplicative decay factor applied at each decay

# --- Model / training length ---
N_CLASS = 20           # passed to the FCN constructor as n_class
N_EPOCHS = 500         # number of passes of the training loop

In [4]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')

## Set Data Loader

In [21]:
# Placeholder: an empty list, so iterating over it yields no batches.
# TODO: replace with a real loader; the training loop indexes each batch
# with the keys 'x' and 'y'.
train_loader = list()

In [22]:
# Placeholder: an empty list; no visible cell consumes it yet.
# TODO: replace with a real validation loader once validation is implemented.
valid_loader = list()

## Init VGG Model

In [13]:
# Build the VGG backbone that is later handed to the FCN as `pretrained_net`.
# NOTE(review): requires_grad=True / remove_fc=True presumably keep the backbone
# trainable and drop its fully connected head -- confirm against vgg.py.
vgg_model = VGGNet(requires_grad=True, remove_fc=True)
# Move the backbone to the selected device (trailing ';' hides the cell output).
vgg_model.to(device);

## Init FCN Model

In [14]:
# Wrap the VGG backbone in the FCN head, predicting N_CLASS output channels.
fcn_model = FCNs(pretrained_net=vgg_model, n_class=N_CLASS)
# Move the full network to the selected device (trailing ';' hides the cell output).
fcn_model.to(device);

## Set Loss Function

In [18]:
# Binary cross-entropy on raw logits, placed on the same device as the model.
# Module.to() returns the module itself, so chaining keeps the same object.
bce_loss = nn.BCEWithLogitsLoss().to(device)
bce_loss

BCEWithLogitsLoss()

## Set Optimizer

In [19]:
# RMSprop over every parameter of the FCN, configured from the constants above.
optimizer = optim.RMSprop(
    fcn_model.parameters(),
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)

In [20]:
# StepLR multiplies the learning rate by GAMMA once every STEP_SIZE calls to
# scheduler.step(); the decay period is therefore STEP_SIZE, counted in
# whatever unit scheduler.step() is called (per epoch or per batch).
scheduler = lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

## Train [FCN](https://arxiv.org/pdf/1605.06211.pdf) Network

In [24]:
# Train the FCN for N_EPOCHS passes over train_loader, stepping the LR schedule
# once per epoch and saving a checkpoint of the whole model after every epoch.

# torch.save does not create missing directories, so make sure the target exists.
os.makedirs('./weights', exist_ok=True)

for epoch in range(1, N_EPOCHS+1):

    tick = time.time()

    # Make sure the network is in training mode (a future validation step
    # would typically switch it to eval mode).
    fcn_model.train()

    for i, batch in enumerate(train_loader):

        optimizer.zero_grad()

        # Tensor.to() is NOT in-place: it returns the moved tensor, so the
        # result must be assigned. Using `device` (instead of .cuda()) keeps
        # the loop runnable on CPU-only machines.
        x = batch['x'].to(device)
        y = batch['y'].to(device)

        output = fcn_model(x)
        loss = bce_loss(output, y)
        loss.backward()

        optimizer.step()

        # Log every 10th batch to keep the cell output readable.
        if i % 10 == 0:
            print(f'Epoch: {epoch}, Batch: {i}, Loss: {loss.item()}')

    # Advance the LR schedule once per epoch, so the learning rate decays
    # every STEP_SIZE epochs rather than every STEP_SIZE batches.
    scheduler.step()

    print(f'Finish Epoch {epoch}, Time Elapsed {time.time()-tick}')
    # NOTE(review): the file name ends with a bare '.', which looks like a
    # truncated extension -- left unchanged in case other code loads this path.
    torch.save(fcn_model, './weights/fcn_model.')

    # validate the network

---