# PASCAL VOC - RefineNet

## 20 semantic classes + background

### RefineNet based on ResNet-101

In [None]:
import six
import sys
import json
sys.path.append('../../')

from models.resnet import rf101

In [None]:
from utils.helpers import prepare_img
from dataset_loader import EgoHandsDatasetLoader

In [None]:
%matplotlib inline

import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from livelossplot import PlotLosses
from PIL import Image

In [None]:
# Array loaded from the repo's utils folder.
# NOTE(review): presumably the class-index -> RGB colour map used for
# visualising predictions; it is not referenced in the visible cells.
cmap = np.load('../../utils/cmap.npy')

# Pick the compute device from what is actually available. The original
# hard-coded "cuda:0" and left `has_cuda` unused, so every later
# `.to(device)` call failed on CPU-only machines.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if has_cuda else "cpu")

# Demo images: every .jpg directly under img_dir.
img_dir = '../imgs/VOC/'
imgs = glob.glob('{}*.jpg'.format(img_dir))

# 20 semantic classes + background.
n_classes = 21

In [None]:
# Build the network in three stages: construct RefineNet-101 with pretrained
# weights in eval mode, move it to the selected device, then wrap it in
# DataParallel. The resulting object is bound to `net`.
_backbone = rf101(n_classes, pretrained=True)
_backbone = _backbone.eval().to(device)
net = nn.DataParallel(_backbone)

In [None]:
# NOTE: batch_size is defined here but not referenced by any code visible in
# this notebook.
batch_size = 16

# Augmentation / normalisation pipelines kept for reference only; they are
# commented out and therefore not applied to any of the splits below.
#
# transform_train = transforms.Compose([
#     transforms.RandomCrop(32, padding=4),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
# ])
#
# transform_test = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
# ])

# The three splits share root, device and shuffle settings; only the JSON
# data file differs.
_loader_kwargs = dict(dataset_root='../../', device=device, shuffle=True)

train_set = EgoHandsDatasetLoader(datafile='../../train_data.json', **_loader_kwargs)
valid_set = EgoHandsDatasetLoader(datafile='../../valid_data.json', **_loader_kwargs)
test_set = EgoHandsDatasetLoader(datafile='../../test_data.json', **_loader_kwargs)

# Phase name -> loader, as looked up by the training loop.
dataloaders = dict(train=train_set, validation=valid_set)

In [None]:
# Number of demo images matched by the glob over img_dir; not referenced by
# any other code visible in this notebook.
n_rows = len(imgs)

def _build_target(labels, outputs, target_class):
    """Turn a single-channel mask into a float target shaped like ``outputs``.

    The mask is resized to the spatial size of ``outputs``, binarised at 50%
    intensity and written as 0/1 into channel ``target_class``; every other
    channel stays zero.

    Assumes ``labels`` is a 2-D uint8 mask with values in 0..255 and
    ``outputs`` is (1, C, H, W) -- TODO confirm against the dataset loader.
    """
    _, n_channels, out_h, out_w = outputs.shape
    # cv2.resize expects the size as (width, height).
    resized = cv2.resize(labels, (out_w, out_h), interpolation=cv2.INTER_CUBIC)
    mask = (resized / 255.0) >= 0.5

    target = np.zeros((n_channels, out_h, out_w), dtype=np.float32)
    # The channel index already identifies the class, so the value stored in
    # it must be a probability (0 or 1). The original wrote 15 here, which is
    # outside the [0, 1] target range that BCE-style losses require.
    target[target_class] = mask
    return torch.from_numpy(target).to(outputs.device).unsqueeze(0)


def train_model(criterion, optimizer, num_epochs=100, target_class=15):
    """Run train/validation passes of the module-level ``net``.

    Iterates over ``dataloaders['train']`` and ``dataloaders['validation']``
    one sample at a time, builds a per-channel target for each sample and
    optimises ``net`` during the train phase only.

    Args:
        criterion: Loss called as ``criterion(outputs, target)`` where
            ``target`` is a float tensor with the same shape as ``outputs``.
        optimizer: Optimizer stepped once per training sample.
        num_epochs: Number of full train + validation passes.
        target_class: Output channel that receives the foreground mask
            (default 15, the channel the original code labelled "Person").

    Returns:
        None. The per-sample loss is printed; the mean loss of each phase is
        collected in a per-epoch ``logs`` dict under ``'log loss'`` /
        ``'val_log loss'``.
    """
    liveloss = PlotLosses()  # only needed by the disabled plotting below

    for epoch in range(num_epochs):
        logs = {}
        for phase in ['train', 'validation']:
            is_train = phase == 'train'
            if is_train:
                net.train()
            else:
                net.eval()

            loss_sum = 0.0
            n_samples = 0

            for inputs, labels in dataloaders[phase]:
                img_inp = torch.tensor(prepare_img(inputs).transpose(2, 0, 1)[None]).float()
                img_inp = img_inp.to(device)

                # Skip autograd bookkeeping during validation.
                with torch.set_grad_enabled(is_train):
                    outputs = net(img_inp)
                    target = _build_target(labels, outputs, target_class)
                    loss = criterion(outputs, target)

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # Accumulate a plain float so the graph of each step can be
                # freed, and so the epoch figure is a mean over all samples
                # rather than just the last one.
                batch_loss = loss.item()
                loss_sum += batch_loss
                n_samples += 1
                print('loss: ', batch_loss)

            if n_samples == 0:
                # An empty loader previously raised NameError (or silently
                # re-logged the previous phase's loss); skip it instead.
                print('no samples for phase: ', phase)
                continue

            prefix = ''
            if phase == 'validation':
                prefix = 'val_'

            logs[prefix + 'log loss'] = loss_sum / n_samples

        # Live plotting was disabled in the original notebook; kept for reference:
        # liveloss.update(logs)
        # liveloss.draw()

In [None]:
# Loss and optimiser for the fine-tuning run.
# NOTE(review): nn.BCELoss requires its input to already be probabilities in
# [0, 1]. If rf101 returns raw logits (its head is not visible here), this
# should probably be nn.BCEWithLogitsLoss -- confirm against the model.
criterion = nn.BCELoss()
optimizer = optim.RMSprop(net.parameters(), lr=1e-3)

# A single epoch: one train pass followed by one validation pass.
train_model(criterion, optimizer, num_epochs=1)