In [1]:
%matplotlib inline
import scipy.misc
from matplotlib import pyplot
from glob import glob
import os 
import numpy as np
import re
import random
import time
from imageio import imread
from skimage.transform import resize
import torch
import torch.nn as nn
from torch.nn.functional import relu
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torch.nn.functional as F
from data_helper import MyDataset

In [2]:
batch_size, n_class, h, w = 10, 2, 224, 320
image_shape = (h, w, 3)
data_dir = 'data_road'

myTrainDataset = MyDataset(data_dir, image_shape=image_shape, isTrain=True, n_class=n_class)
myDataLoader = DataLoader(myTrainDataset, batch_size=batch_size, shuffle=True)

myTestDataset = MyDataset(data_dir, image_shape=image_shape, isTrain=False, n_class=n_class)
myTestDataLoader = DataLoader(myTestDataset, batch_size=batch_size, shuffle=True)

In [3]:
def VGG(pretrained=True):
    model = torchvision.models.vgg16(pretrained=pretrained)
    if pretrained:
        model.load_state_dict(model.state_dict())
    return model

In [4]:
# https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/surgery.py
def get_upsampling_weight(in_channels, out_channels, kernel_size):
    """Make a 2D bilinear kernel suitable for upsampling"""
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * \
           (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
                      dtype=np.float64)
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.from_numpy(weight).float()

In [5]:
class FCN32(nn.Module):
    def __init__(self, n_class=2, requires_grad=True):
        super(FCN32, self).__init__()
        self.requires_grad = requires_grad
        # conv1
        self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100)
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2
        # conv2
        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4
        # conv3
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8
        # conv4
        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16
        # conv5
        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32
        # FC6
        self.fc6 = nn.Conv2d(512, 4096, 7)
        self.drop6 = nn.Dropout2d()
        # FC7
        self.fc7 = nn.Conv2d(4096, 4096, 1)
        self.drop7 = nn.Dropout2d()
        # score
        self.score = nn.Conv2d(4096, n_class, 1)
        self.final = nn.ConvTranspose2d(n_class, n_class, 64, stride=32, bias=False)
        # initialize weights and bias
        self._initialize_weights()
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.zero_()
                if m.bias is not None:
                    m.bias.data.zero_()
            if isinstance(m, nn.ConvTranspose2d):
                assert m.kernel_size[0] == m.kernel_size[1]
                initial_weight = get_upsampling_weight(m.in_channels, m.out_channels, m.kernel_size[0])
                m.weight.data.copy_(initial_weight)
    def forward(self, x):
        # conv1
        _ = relu(self.conv1_1(x), inplace=self.requires_grad)
        _ = relu(self.conv1_2(_), inplace=self.requires_grad)
        _ = self.pool1(_)
        # conv2
        _ = relu(self.conv2_1(_), inplace=self.requires_grad)
        _ = relu(self.conv2_2(_), inplace=self.requires_grad)
        _ = self.pool2(_)
        # conv3
        _ = relu(self.conv3_1(_), inplace=self.requires_grad)
        _ = relu(self.conv3_2(_), inplace=self.requires_grad)
        _ = relu(self.conv3_3(_), inplace=self.requires_grad)
        _ = self.pool3(_)
        # conv4
        _ = relu(self.conv4_1(_), inplace=self.requires_grad)
        _ = relu(self.conv4_2(_), inplace=self.requires_grad)
        _ = relu(self.conv4_3(_), inplace=self.requires_grad)
        _ = self.pool4(_)
        # conv5
        _ = relu(self.conv5_1(_), inplace=self.requires_grad)
        _ = relu(self.conv5_2(_), inplace=self.requires_grad)
        _ = relu(self.conv5_3(_), inplace=self.requires_grad)
        _ = self.pool5(_)
        # FC6
        _ = relu(self.fc6(_), inplace=True)
        _ = self.drop6(_)
        # FC7
        _ = relu(self.fc7(_), inplace=True)
        _ = self.drop7(_)
        # score
        _ = self.score(_)
        # final
        _f = self.final(_)
        _f = _f[:, :, 19:19+x.size()[2], 19:19+x.size()[3]].contiguous()
        return _f
    def copy_params_from_vgg16(self, vgg16):
        features = [
            self.conv1_1, None, self.conv1_2, None, None,
            self.conv2_1, None, self.conv2_2, None, None,
            self.conv3_1, None, self.conv3_2, None, self.conv3_3, None, None, 
            self.conv4_1, None, self.conv4_2, None, self.conv4_3, None, None,
            self.conv5_1, None, self.conv5_2, None, self.conv5_3, None, None,
        ]
        for l1, l2 in zip(vgg16.features, features):
            if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
                assert l1.weight.size() == l2.weight.size()
                assert l1.bias.size() == l2.bias.size()
                l2.weight.data = l1.weight.data
                l2.bias.data = l1.bias.data
                if not self.requires_grad:
                    l2.requires_grad = False
        for i, name in zip([0, 3], ['fc6', 'fc7']):
            l1 = vgg16.classifier[i]
            l2 = getattr(self, name)
            l2.weight.data = l1.weight.data.view(l2.weight.size())
            l2.bias.data = l1.bias.data.view(l2.bias.size())

In [6]:
vgg16_model = VGG(pretrained=True)
fcn32_model = FCN32(n_class=n_class)
fcn32_model.copy_params_from_vgg16(vgg16_model)

In [7]:
lr = 1e-3
w_decay = 1e-5
step_size = 30
gamma = 0.5

cross_entropy = nn.CrossEntropyLoss()
optimizer = optim.Adam(fcn32_model.parameters(), lr=lr, weight_decay=w_decay)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)  # decay LR by a factor of 0.5 every 30 epochs

In [8]:
epochs = 3

for epoch in range(epochs):
    scheduler.step()
    ts = time.time()
    for n_iter, batch in enumerate(myDataLoader):
        optimizer.zero_grad()
        inputs, labels = torch.autograd.Variable(batch['X']), torch.autograd.Variable(batch['Y'])
        outputs = fcn32_model(inputs)
        loss = cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
        print("epoch: {}, iter: {}, loss: {}".format(epoch+1, n_iter+1, loss.item()))
    print("Finish epoch: {}, time elapsed {}".format(epoch+1, time.time() - ts))
    
# model_path = os.path.join('model')
# if not os.path.exists(model_path):
#     os.mkdir(model_path)
# torch.save(fcn32_model, os.path.join(model_path, "saved.pb"))

RuntimeError: invalid argument 3: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4 at /Users/administrator/nightlies/2018_12_21/wheel_build_dirs/wheel_3.7/pytorch/aten/src/THNN/generic/SpatialClassNLLCriterion.c:59

In [None]:
inputs = next(iter(myTestDataLoader))
outputs = fcn32_model(torch.autograd.Variable(inputs['X']))
outputs = outputs.data.cpu().numpy()

In [None]:
for i in range(2):
    N, H, W = outputs[i].shape
    im_softmax = outputs[i].transpose(1, 2, 0)[:, :, 1]
    segmentation = (im_softmax > 0.2).reshape(H, W, 1)
    mask = np.dot(segmentation, np.array([[0, 255, 0, 127]]))
    mask = scipy.misc.toimage(mask, mode="RGBA")

    street_im = scipy.misc.toimage(inputs['X'][i])
    street_im.paste(mask, box=None, mask=mask)
    pyplot.figure()
    pyplot.imshow(street_im)
    pyplot.show()