# PASCAL VOC - RefineNet

## 20 semantic classes + background

### RefineNet based on ResNet-101

In [1]:
import six
import sys
import json
sys.path.append('../../')

from models.resnet import rf101

In [2]:
from utils.helpers import prepare_img
from dataset_loader import EgoHandsDatasetLoader

In [3]:
%matplotlib inline

import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from livelossplot import PlotLosses
from PIL import Image

In [4]:
# Lookup table loaded from disk; in the visible code it is only referenced by
# the commented-out visualisation section.
cmap = np.load('../../utils/cmap.npy')
has_cuda = torch.cuda.is_available()
# Fall back to the CPU when CUDA is unavailable; naming "cuda:0"
# unconditionally makes every later `.to(device)` fail on CPU-only machines.
device = torch.device("cuda:0" if has_cuda else "cpu")
img_dir = '../imgs/VOC/'
imgs = glob.glob('{}*.jpg'.format(img_dir))
# NOTE(review): training below uses nn.CrossEntropyLoss, which needs one
# output channel per class; with a single channel only target value 0 is
# valid. Confirm whether 2 (background + foreground) was intended.
n_classes = 1

In [5]:
# Registry of model label -> factory function used to build the network.
model_inits = {
    'rf_101_voc': rf101,
}

# Build every registered network in eval mode and, when a GPU is available,
# move it onto `device`. `net` stays bound to the last network built because
# the training cell further down refers to it.
models = {}
for key, fun in six.iteritems(model_inits):
    net = fun(n_classes, pretrained=True)
    net = net.eval()
    net = net.to(device) if has_cuda else net
    models[key] = net

In [6]:
batch_size = 16

# Disabled torchvision-style augmentation pipelines, kept for reference.
# transform_train = transforms.Compose([
#     transforms.RandomCrop(32, padding=4),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
# ])

# transform_test = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
# ])

# All three splits share the same root, device and shuffling; only the
# datafile differs.
_loader_kwargs = dict(dataset_root='../../', device=device, shuffle=True)

train_set = EgoHandsDatasetLoader(datafile='../../train_data.json', **_loader_kwargs)
valid_set = EgoHandsDatasetLoader(datafile='../../valid_data.json', **_loader_kwargs)
test_set = EgoHandsDatasetLoader(datafile='../../test_data.json', **_loader_kwargs)

# Phase name -> loader, as looked up by the training loop.
dataloaders = dict(train=train_set, validation=valid_set)

In [7]:
# Figure 2 from the paper: one row per image; one column for the image,
# one for the ground truth, and one per model.
n_rows = len(imgs)
n_cols = 2 + len(models)

# Running subplot index (1-based), used by the commented-out plotting code.
idx = 1
plt.figure(figsize=(16, 12))

def _run_phase(model, criterion, optimizer, phase):
    """Run one pass over ``dataloaders[phase]``.

    Args:
        model: Network being trained/evaluated.
        criterion: Loss callable taking ``(outputs, target)``.
        optimizer: Optimizer stepped after every sample in the 'train' phase.
        phase: Key into the module-level ``dataloaders`` dict; 'train'
            enables gradient updates, anything else runs in eval mode.

    Returns:
        ``(mean_loss, pixel_accuracy)`` as Python floats; both are NaN when
        the loader yields no samples.
    """
    is_train = phase == 'train'
    if is_train:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    correct_pixels = 0
    seen_pixels = 0
    seen_samples = 0

    for inputs, labels in dataloaders[phase]:
        # prepare_img output is rearranged to channel-first and given a
        # leading batch axis of size 1.
        img_inp = torch.tensor(prepare_img(inputs).transpose(2, 0, 1)[None]).float()
        img_inp = img_inp.to(device)

        # Only build the autograd graph when we are actually going to train.
        with torch.set_grad_enabled(is_train):
            outputs = model(img_inp)

            # Resize the mask to whatever spatial size the network produced
            # (cv2 takes (width, height)). Nearest-neighbour keeps label
            # values intact; cubic interpolation would invent new ones.
            out_h, out_w = outputs.shape[-2:]
            resized = cv2.resize(labels[0, :, :], (out_w, out_h),
                                 interpolation=cv2.INTER_NEAREST)
            target = torch.as_tensor(np.asarray(resized, dtype=np.int64)[None],
                                     device=outputs.device)
            # NOTE(review): the loss requires every label value to be smaller
            # than the number of output channels -- confirm the masks are
            # encoded accordingly.
            loss = criterion(outputs, target)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        preds = outputs.argmax(dim=1)
        n_batch = img_inp.size(0)
        total_loss += loss.item() * n_batch
        # Score against the resized target, which has the same shape as preds.
        correct_pixels += (preds == target).sum().item()
        seen_pixels += target.numel()
        seen_samples += n_batch

    if seen_samples == 0:
        # Nothing was iterated; report NaN rather than dividing by zero.
        return float('nan'), float('nan')
    return total_loss / seen_samples, correct_pixels / seen_pixels


def train_model(model, criterion, optimizer, num_epochs=80):
    """Train ``model`` and plot live train/validation metrics.

    Each epoch runs a 'train' phase followed by a 'validation' phase over the
    module-level ``dataloaders`` and feeds the resulting loss and pixel
    accuracy to a ``PlotLosses`` instance.

    Args:
        model: Network to optimise; it is moved to the module-level ``device``.
        criterion: Loss callable taking ``(outputs, target)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of train+validation epochs to run.
    """
    liveloss = PlotLosses()
    # Keep the parameters on the same device the inputs are sent to.
    model = model.to(device)

    for _ in range(num_epochs):
        logs = {}
        for phase in ('train', 'validation'):
            epoch_loss, epoch_acc = _run_phase(model, criterion, optimizer, phase)

            prefix = 'val_' if phase == 'validation' else ''
            logs[prefix + 'log loss'] = epoch_loss
            logs[prefix + 'accuracy'] = epoch_acc

        liveloss.update(logs)
        liveloss.draw()
# with torch.no_grad():
#     for img_path in imgs:
#         img = np.array(Image.open(img_path))
#         msk = cmap[np.array(Image.open(img_path.replace('jpg', 'png')))]
#         orig_size = img.shape[:2][::-1]
        
#         img_inp = torch.tensor(prepare_img(img).transpose(2, 0, 1)[None]).float()
#         if has_cuda:
#             img_inp = img_inp.cuda()
        
#         plt.subplot(n_rows, n_cols, idx)
#         plt.imshow(img)
#         plt.title('img')
#         plt.axis('off')
#         idx += 1
        
#         plt.subplot(n_rows, n_cols, idx)
#         plt.imshow(msk)
#         plt.title('gt')
#         plt.axis('off')
#         idx += 1
        
#         for mname, mnet in six.iteritems(models):
#             segm = mnet.cuda()(img_inp)[0].data.cpu().numpy().transpose(1, 2, 0)
#             segm = cv2.resize(segm, orig_size, interpolation=cv2.INTER_CUBIC)
#             segm = cmap[segm.argmax(axis=2).astype(np.uint8)]
            
#             plt.subplot(n_rows, n_cols, idx)
#             plt.imshow(segm)
#             plt.title(mname)
#             plt.axis('off')
#             idx += 1
#             break

<Figure size 1152x864 with 0 Axes>

In [8]:
# Loss and optimiser for the network left bound to `net` by the model
# initialisation cell.
criterion = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(net.parameters(), lr=1e-4)

train_model(net, criterion, optimizer, num_epochs=80)

TypeError: Broadcast function not implemented for CPU tensors

In [None]:
# Threshold the segmentation into pure black / pure white.
# NOTE(review): in the visible code `segm` is only assigned inside the
# commented-out visualisation section -- confirm it is defined before this
# cell is run.

bw = np.asarray(segm).copy()

# Pixel range is 0...255, so 128 is the midpoint. Both masks are taken
# before any write; zeroing the dark pixels cannot change the bright mask.
is_dark = bw < 128
is_bright = bw >= 128
bw[is_dark] = 0      # Black
bw[is_bright] = 255  # White

plt.imshow(bw)
bw.shape