In [79]:
import numpy as np
import torch
from torch.autograd import Variable
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision
import torch.nn as nn
from torch.optim import lr_scheduler
from torchvision import datasets, transforms, models
import torch.nn.functional as F
import timeit
import time
from skimage import io, transform
from torch.utils.data import Dataset, DataLoader
import os
from __future__ import division, print_function
import shutil
%matplotlib inline
import glob

In [2]:
# The model FCNS 32
class fcns32(nn.Module):
    
    
    def __init__(self, n_classes =21, learn_bilinear=False):
        super(fcns32, self).__init__()
        self.learn_bilinear = learn_bilinear
        self.n_classes = n_classes
        
        # Creating the convolutional blocks
        self.conv1block = nn.Sequential(
                            nn.Conv2d(3, 64, 3, padding=100),
                            nn.ReLU(inplace=True),
                            nn.Conv2d(64, 64, 3, padding=1),
                            nn.ReLU(),
                            nn.MaxPool2d(2, stride=2, ceil_mode=True),
                            )
        self.conv2block = nn.Sequential(
                            nn.Conv2d(64, 128, 3, padding=1),
                            nn.ReLU(inplace=True),
                            nn.Conv2d(128, 128, 3, padding=1),
                            nn.ReLU(inplace=True),
                            nn.MaxPool2d(2, stride=2, ceil_mode=True),)
        self.conv3block = nn.Sequential(
                            nn.Conv2d(128, 256, 3, padding=1),
                            nn.ReLU(inplace=True),
                            nn.Conv2d(256,256,3, padding=1),
                            nn.ReLU(inplace=True),
                            nn.Conv2d(256,256,3, padding=1),
                            nn.ReLU(inplace=True),
                            nn.MaxPool2d(2, stride=2, ceil_mode=True),
                            )
        self.conv4block = nn.Sequential(
                            nn.Conv2d(256, 512, 3, padding=1),
                            nn.ReLU(inplace=True),
                            nn.Conv2d(512,512,3, padding=1),
                            nn.ReLU(inplace=True),
                            nn.Conv2d(512,512,3, padding=1),
                            nn.ReLU(inplace=True),
                            nn.MaxPool2d(2, stride=2, ceil_mode=True),
                            )
        self.conv5block = nn.Sequential(
                            nn.Conv2d(512, 512, 3, padding=1),
                            nn.ReLU(inplace=True),
                            nn.Conv2d(512,512,3, padding=1),
                            nn.ReLU(inplace=True),
                            nn.Conv2d(512,512,3, padding=1),
                            nn.ReLU(inplace=True),
                            nn.MaxPool2d(2, stride=2, ceil_mode=True),
                            )
        
        self.classifier = nn.Sequential(
                            nn.Conv2d(512, 4096, 7),
                            nn.ReLU(inplace=True),
                            nn.Dropout2d(),
                            nn.Conv2d(4096, 4096, 1),
                            nn.ReLU(inplace=True),
                            nn.Dropout2d(),
                            nn.Conv2d(4096, n_classes, 1),
                            )
        
    
    # The forward pass
    def forward(self, x):
        conv1 = self.conv1block(x)
        conv2 = self.conv2block(conv1)
        conv3 = self.conv3block(conv2)
        conv4 = self.conv4block(conv3)
        conv5 = self.conv5block(conv4)
        
        output = self.classifier(conv5)
        
        out = F.upsample_bilinear(output, x.size()[2:])
        
        return out
    
    #Using the pretrained network
    def init_vgg16_params(self, vgg16, copy_fc8=True):
        blocks = [self.conv1block,
                 self.conv2block,
                 self.conv3block,
                 self.conv4block,
                 self.conv5block]
        
        ranges = [[0,4], [5,9], [10, 16], [17, 23], [24, 29]]
        features = list(vgg16.features.children())
        print(features)
        
        for i, conv in enumerate(blocks):
            for l1, l2 in zip(features[ranges[i][0]:ranges[i][1]], conv):
                if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
                    print(l1.weight.size(), i)
                    print(l2.weight.size(), i)
                    assert l1.weight.size() == l2.weight.size()
                    assert l1.bias.size() == l2.bias.size()
                    l2.weight.data = l1.weight.data
                    l2.bias.data = l1.bias.data
        
        for i1, i2 in zip([0,3], [0,3]):
            l1 = vgg16.classifier[i1]
            l2 = self.classifier[i2]
            l2.weight.data = l1.weight.data.view(l2.weight.size())
            l2.bias.data = l1.bias.data.view(l2.bias.size())
            
        n_class = self.classifier[6].weight.size()[0]
        
        if copy_fc8:
            l1 = vgg16.classifier[6]
            l2 = self.classifier[6]
            l2.weight.data = l1.weight.data[:n_class, :].view(l2.weight.size())
            l2.bias.data = l1.bias.data[:n_class]

In [7]:
def train_model(model, optimizer, criterion, scheduler, num_epochs=25):
    since = time.time()
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                model.train(True)
            else:
                model.train(False)
                
            running_loss = 0.0
            running_corrects = 0
            
    
    

In [3]:
model = fcns32()
vgg16 = models.vgg16(pretrained=True)
model.init_vgg16_params(vgg16)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = 0.01)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=0.7, gamma=0.1)

[Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU (inplace), Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU (inplace), MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)), Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU (inplace), Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU (inplace), MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)), Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU (inplace), Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU (inplace), Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU (inplace), MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)), Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU (inplace), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU (inplace), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), 

In [18]:
# Data collection for training set

trainText = '/Users/navneetmkumar/Documents/Paper Implementations/VOC2012/ImageSets/Segmentation/train.txt'

def getTrainNames(filename):
    segment_images = []
    f = open(filename)
    filecontents = f.readlines()
    for line in filecontents:
        img_name = line.strip('\n')
        segment_images.append(img_name)
    return segment_images

In [20]:
train_images = getTrainNames(trainText)
print(len(train_images))

1464


In [59]:
if not os.path.exists('/Users/navneetmkumar/Documents/Paper Implementations/train-inputs'):
    os.makedirs('/Users/navneetmkumar/Documents/Paper Implementations/train-inputs')

if not os.path.exists('/Users/navneetmkumar/Documents/Paper Implementations/train-targets'):
    os.makedirs('/Users/navneetmkumar/Documents/Paper Implementations/train-targets')

image_dir = '/Users/navneetmkumar/Documents/Paper Implementations/VOC2012/JPEGImages'
targets_dir = '/Users/navneetmkumar/Documents/Paper Implementations/VOC2012/SegmentationClass'

def copyFiles(directory, targets=False):
    for f in os.listdir(directory):
        f_name = os.path.basename(f)
        f_name = f_name.split(".")[0]
        if f_name in train_images:
            f = os.path.join(directory, f)
            if targets:
                shutil.copy2(f, '/Users/navneetmkumar/Documents/Paper Implementations/train-targets/')
            else:
                shutil.copy2(f, '/Users/navneetmkumar/Documents/Paper Implementations/train-inputs/')

In [55]:
copyFiles(image_dir)

In [60]:
copyFiles(targets_dir, targets=True)

In [61]:
# Setting up the validation set
valText = '/Users/navneetmkumar/Documents/Paper Implementations/VOC2012/ImageSets/Segmentation/val.txt'
val_images = getTrainNames(valText)
print(len(val_images))

1449


In [66]:
if not os.path.exists('/Users/navneetmkumar/Documents/Paper Implementations/val-inputs'):
    os.makedirs('/Users/navneetmkumar/Documents/Paper Implementations/val-inputs')

if not os.path.exists('/Users/navneetmkumar/Documents/Paper Implementations/val-targets'):
    os.makedirs('/Users/navneetmkumar/Documents/Paper Implementations/val-targets')

image_dir = '/Users/navneetmkumar/Documents/Paper Implementations/VOC2012/JPEGImages'
targets_dir = '/Users/navneetmkumar/Documents/Paper Implementations/VOC2012/SegmentationClass'

def copyFiles(directory, targets=False):
    for f in os.listdir(directory):
        f_name = os.path.basename(f)
        f_name = f_name.split(".")[0]
        if f_name in val_images:
            f = os.path.join(directory, f)
            if targets:
                shutil.copy2(f, '/Users/navneetmkumar/Documents/Paper Implementations/val-targets/')
            else:
                shutil.copy2(f, '/Users/navneetmkumar/Documents/Paper Implementations/val-inputs/')

In [67]:
copyFiles(image_dir)

In [68]:
copyFiles(targets_dir, targets=True)

In [95]:
# Display the validation images
val_im_dir = '/Users/navneetmkumar/Documents/Paper Implementations/val-inputs'
val_t_dir = '/Users/navneetmkumar/Documents/Paper Implementations/val-targets'

val_im_list = glob.glob(os.path.join(val_im_dir, '*.jpg'))
val_t_list = glob.glob(os.path.join(val_t_dir, '*.png'))

def showImage(image, dire):
    f = os.path.join(dire, image)
    plt.imshow(io.imread(f))