In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import models
from torchvision.models.vgg import VGG
from sklearn.metrics import confusion_matrix
import pandas as pd
import scipy.misc
import random
import sys

if '/opt/ros/kinetic/lib/python2.7/dist-packages' in sys.path:
    sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')
import cv2

from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

from matplotlib import pyplot as plt
import numpy as np
import time
import os

In [2]:
means     = np.array([103.939, 116.779, 123.68]) / 255. # mean of three channels in the order of BGR
h, w      = 480, 640
val_h     = h
val_w     = w
n_class = 21

In [3]:
class FCN16s(nn.Module):

    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu    = nn.ReLU(inplace = True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1     = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn2     = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn3     = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn4     = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn5     = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        output = self.pretrained_net(x)
        x5 = output['x5']  # size=(N, 512, x.H/32, x.W/32)
        x4 = output['x4']  # size=(N, 512, x.H/16, x.W/16)

        score = self.relu(self.deconv1(x5))               # size=(N, 512, x.H/16, x.W/16)
        score = self.bn1(score + x4)                      # element-wise add, size=(N, 512, x.H/16, x.W/16)
        score = self.bn2(self.relu(self.deconv2(score)))  # size=(N, 256, x.H/8, x.W/8)
        score = self.bn3(self.relu(self.deconv3(score)))  # size=(N, 128, x.H/4, x.W/4)
        score = self.bn4(self.relu(self.deconv4(score)))  # size=(N, 64, x.H/2, x.W/2)
        score = self.bn5(self.relu(self.deconv5(score)))  # size=(N, 32, x.H, x.W)
        score = self.classifier(score)                    # size=(N, n_class, x.H/1, x.W/1)
        
        return score

In [4]:
class VGGNet(VGG):
    def __init__(self, pretrained=True, model='vgg16', requires_grad=True, remove_fc=True, show_params=False):
        super().__init__(make_layers(cfg[model]))
        self.ranges = ranges[model]

        if pretrained:
            exec("self.load_state_dict(models.%s(pretrained=True).state_dict())" % model)

        if not requires_grad:
            for param in super().parameters():
                param.requires_grad = False

        if remove_fc:  # delete redundant fully-connected layer params, can save memory
            del self.classifier

        if show_params:
            for name, param in self.named_parameters():
                print(name, param.size())

    def forward(self, x):
        output = {}

        # get the output of each maxpooling layer (5 maxpool in VGG net)
        for idx in range(len(self.ranges)):
            for layer in range(self.ranges[idx][0], self.ranges[idx][1]):      
                x = self.features[layer](x)
            output["x%d"%(idx+1)] = x
        return output

In [5]:
ranges = {
    'vgg11': ((0, 3), (3, 6),  (6, 11),  (11, 16), (16, 21)),
    'vgg13': ((0, 5), (5, 10), (10, 15), (15, 20), (20, 25)),
    'vgg16': ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)),
    'vgg19': ((0, 5), (5, 10), (10, 19), (19, 28), (28, 37))
}

# cropped version from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
cfg = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

In [6]:
use_gpu = torch.cuda.is_available()
num_gpu = list(range(torch.cuda.device_count()))

vgg_model = VGGNet(requires_grad=True, remove_fc=True)
fcn_model = FCN16s(pretrained_net=vgg_model, n_class=n_class)
use_gpu = False
'''if use_gpu:
    ts = time.time()
    vgg_model = vgg_model.cuda()
    fcn_model = fcn_model.cuda()
    fcn_model = nn.DataParallel(fcn_model, device_ids=num_gpu)
    print("Finish cuda loading, time elapsed {}".format(time.time() - ts))'''

'if use_gpu:\n    ts = time.time()\n    vgg_model = vgg_model.cuda()\n    fcn_model = fcn_model.cuda()\n    fcn_model = nn.DataParallel(fcn_model, device_ids=num_gpu)\n    print("Finish cuda loading, time elapsed {}".format(time.time() - ts))'

In [7]:
class product_dataset(Dataset):

    def __init__(self, csv_file, phase, n_class=n_class, flip_rate=0.):
        self.data      = pd.read_csv(csv_file)
        self.means     = means
        self.n_class   = n_class
        self.flip_rate = flip_rate
        if phase == 'train':
            self.flip_rate = 0.5

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name   = self.data.iloc[idx, 0]
        img        = cv2.imread(os.path.join(data_dir, img_name),cv2.IMREAD_UNCHANGED)
        label_name = self.data.iloc[idx, 1]
        label      = cv2.imread(os.path.join(data_dir, label_name), cv2.IMREAD_GRAYSCALE)
        origin_img = img
        if random.random() < self.flip_rate:
            img   = np.fliplr(img)
            label = np.fliplr(label)

        # reduce mean
        img = img[:, :, ::-1]  # switch to BGR
        
        img = np.transpose(img, (2, 0, 1)) / 255.
        img[0] -= self.means[0]
        img[1] -= self.means[1]
        img[2] -= self.means[2]

        # convert to tensor
        img = torch.from_numpy(img.copy()).float()
        label = torch.from_numpy(label.copy()).long()

        # create one-hot encoding
        h, w = label.size()
        target = torch.zeros(self.n_class, h, w)
        
        '''for i in range(n_class - 1):
            if brand_class[i] in img_name:
                target[i+1][label != 0] = 1
                label[label != 0] = i+1
                #print(brand_class[i], img_name,label_name)'''
        for i in range(n_class):
            target[i][label == i] = 1
        
        target[0][label == 0] = 1
        #print(np.unique(label))
        
 
        sample = {'X': img, 'Y': target, 'l': label, 'origin': origin_img}

        return sample

In [8]:
model_use  = "unity" # "unity" "histogram" "box_gan"
data_dir  = os.path.join("/media/arg_ws3/5E703E3A703E18EB/data/mm_FCN", model_use)
model_dir = os.path.join("/media/arg_ws3/5E703E3A703E18EB/research/mm_fcn/models", model_use)
if not os.path.exists(data_dir):
    print("Data not found!")
val_file   = os.path.join(data_dir, "predict.csv")
val_data   = product_dataset(csv_file = val_file, phase = 'val', flip_rate = 0)
val_loader = DataLoader(val_data, batch_size = 4, num_workers = 0)

dataiter = iter(val_loader)

In [9]:
def prediction(model_name):
    
    # load pretrain models
    state_dict = torch.load(os.path.join(model_dir, model_name))
    fcn_model.load_state_dict(state_dict)
    
    batch = dataiter.next()
    if use_gpu:
        inputs = Variable(batch['X'].cuda())
    else:
        inputs = Variable(batch['X'])
    img    = batch['origin'] 
    label  = batch['l']
    
    output = fcn_model(inputs)
    output = output.data.cpu().numpy()

    N, _, h, w = output.shape
    pred = output.transpose(0, 2, 3, 1).reshape(-1, n_class).argmax(axis = 1).reshape(N, h, w)

    # show images
    plt.figure(figsize = (10, 12))
    img = img.numpy()
    for i in range(N):
        img[i] = cv2.cvtColor(img[i], cv2.COLOR_BGR2RGB)
        plt.subplot(N, 3, i*3 + 1)
        plt.title("origin_img")
        plt.imshow(img[i])
        #print(np.unique(_img[i]))

        plt.subplot(N, 3, i*3 + 2)
        plt.title("label_img")
        plt.imshow(label[i],cmap = "brg",vmin = 0, vmax = n_class - 1)

        plt.subplot(N, 3, i*3 + 3)
        plt.title("prediction")
        plt.imshow(pred[i],cmap = "brg",vmin = 0, vmax = n_class - 1)

    plt.show()

In [10]:
prediction("FCNs_unity_batch10_epoch0_RMSprop_lr0.0001.pkl")

RuntimeError: Error(s) in loading state_dict for FCN16s:
	Missing key(s) in state_dict: "pretrained_net.features.0.weight", "pretrained_net.features.0.bias", "pretrained_net.features.2.weight", "pretrained_net.features.2.bias", "pretrained_net.features.5.weight", "pretrained_net.features.5.bias", "pretrained_net.features.7.weight", "pretrained_net.features.7.bias", "pretrained_net.features.10.weight", "pretrained_net.features.10.bias", "pretrained_net.features.12.weight", "pretrained_net.features.12.bias", "pretrained_net.features.14.weight", "pretrained_net.features.14.bias", "pretrained_net.features.17.weight", "pretrained_net.features.17.bias", "pretrained_net.features.19.weight", "pretrained_net.features.19.bias", "pretrained_net.features.21.weight", "pretrained_net.features.21.bias", "pretrained_net.features.24.weight", "pretrained_net.features.24.bias", "pretrained_net.features.26.weight", "pretrained_net.features.26.bias", "pretrained_net.features.28.weight", "pretrained_net.features.28.bias", "deconv1.weight", "deconv1.bias", "bn1.weight", "bn1.running_var", "bn1.running_mean", "bn1.bias", "deconv2.weight", "deconv2.bias", "bn2.weight", "bn2.running_var", "bn2.running_mean", "bn2.bias", "deconv3.weight", "deconv3.bias", "bn3.weight", "bn3.running_var", "bn3.running_mean", "bn3.bias", "deconv4.weight", "deconv4.bias", "bn4.weight", "bn4.running_var", "bn4.running_mean", "bn4.bias", "deconv5.weight", "deconv5.bias", "bn5.weight", "bn5.running_var", "bn5.running_mean", "bn5.bias", "classifier.weight", "classifier.bias". 
	Unexpected key(s) in state_dict: "module.pretrained_net.features.0.weight", "module.pretrained_net.features.0.bias", "module.pretrained_net.features.2.weight", "module.pretrained_net.features.2.bias", "module.pretrained_net.features.5.weight", "module.pretrained_net.features.5.bias", "module.pretrained_net.features.7.weight", "module.pretrained_net.features.7.bias", "module.pretrained_net.features.10.weight", "module.pretrained_net.features.10.bias", "module.pretrained_net.features.12.weight", "module.pretrained_net.features.12.bias", "module.pretrained_net.features.14.weight", "module.pretrained_net.features.14.bias", "module.pretrained_net.features.17.weight", "module.pretrained_net.features.17.bias", "module.pretrained_net.features.19.weight", "module.pretrained_net.features.19.bias", "module.pretrained_net.features.21.weight", "module.pretrained_net.features.21.bias", "module.pretrained_net.features.24.weight", "module.pretrained_net.features.24.bias", "module.pretrained_net.features.26.weight", "module.pretrained_net.features.26.bias", "module.pretrained_net.features.28.weight", "module.pretrained_net.features.28.bias", "module.deconv1.weight", "module.deconv1.bias", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", "module.bn1.num_batches_tracked", "module.deconv2.weight", "module.deconv2.bias", "module.bn2.weight", "module.bn2.bias", "module.bn2.running_mean", "module.bn2.running_var", "module.bn2.num_batches_tracked", "module.deconv3.weight", "module.deconv3.bias", "module.bn3.weight", "module.bn3.bias", "module.bn3.running_mean", "module.bn3.running_var", "module.bn3.num_batches_tracked", "module.deconv4.weight", "module.deconv4.bias", "module.bn4.weight", "module.bn4.bias", "module.bn4.running_mean", "module.bn4.running_var", "module.bn4.num_batches_tracked", "module.deconv5.weight", "module.deconv5.bias", "module.bn5.weight", "module.bn5.bias", "module.bn5.running_mean", "module.bn5.running_var", "module.bn5.num_batches_tracked", "module.classifier.weight", "module.classifier.bias". 