# FCN
This notebook contains a test CNN in PyTorch, written to get familiar with this development environment. It also acts as a template for later use.

In [1]:
%matplotlib inline
import numpy as np
import torch
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import os,sys
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import time

# Seed the NumPy and PyTorch random number generators with the same fixed
# value so that results are reproducible from run to run
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x11a8c363bd0>

In [2]:
# Helper functions

def load_image(infilename):
    """Read the image file `infilename` with matplotlib and return its pixel data."""
    return mpimg.imread(infilename)

def img_float_to_uint8(img):
    """Linearly rescale an image to the full uint8 range.

    The smallest value of `img` is mapped to 0 and the largest to 255; values
    in between are scaled linearly and rounded to the nearest integer.

    Args:
        img: numeric image array.

    Returns:
        np.ndarray of dtype uint8 with the same shape as `img`. A constant
        image (max == min) is returned as all zeros.
    """
    shifted = img - np.min(img)
    span = np.max(shifted)
    if span == 0:
        # Constant image: there is no range to stretch, and dividing by the
        # zero span would produce NaNs (0/0) that cannot be cast to uint8.
        return np.zeros(np.shape(shifted), dtype=np.uint8)
    return (shifted / span * 255).round().astype(np.uint8)

# Put an image and its groundtruth side by side
def concatenate_images(img, gt_img):
    """Join an image and its ground truth along axis 1 into a single array.

    When `gt_img` has three dimensions it is joined to `img` unchanged.
    Otherwise the ground truth is converted to uint8 and copied into each of
    the three channels of a new array, which is joined to the uint8 version
    of `img`.
    """
    if len(gt_img.shape) == 3:
        # Ground truth already has a channel axis: join the arrays as they are
        return np.concatenate((img, gt_img), axis=1)

    # Ground truth without a channel axis: replicate it into three channels
    rows, cols = gt_img.shape[0], gt_img.shape[1]
    gt_rgb = np.zeros((rows, cols, 3), dtype=np.uint8)
    gt_u8 = img_float_to_uint8(gt_img)
    for channel in range(3):
        gt_rgb[:, :, channel] = gt_u8
    return np.concatenate((img_float_to_uint8(img), gt_rgb), axis=1)

## Pytorch module

This module contains the FCN based on the CNN module

### Module structure

For this module, we will try to implement the complex diagram described in http://openaccess.thecvf.com/content_cvpr_2018_workshops/papers/w4/Zhou_D-LinkNet_LinkNet_With_CVPR_2018_paper.pdf

Useful links:

- Torch documentation (especially for the input/output sizes of Conv2d and ConvTranspose2d): https://pytorch.org/docs/stable/nn.html

- This link helps to better understand what each argument of Conv2d and ConvTranspose2d does: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

In [3]:
# Fully convolutional encoder-decoder for 100*100 RGB patches: the input is
# reduced to 12*12 feature maps and expanded back to 100*100 per-pixel scores

class Net100(torch.nn.Module):
    """Encoder-decoder FCN that returns per-pixel class scores (logits).

    The encoder is four 3x3 convolutions (3 -> 64 -> 128 -> 256 -> 512
    channels) with a max pooling after each of the first three. The decoder
    is three stride-2 transposed convolutions (512 -> 256 -> 128 -> 64), each
    followed by batch normalization, and a final 3x3 convolution that maps
    the 64 channels to `n_classes` score maps.

    Input:  float tensor of shape (N, 3, H, W). The output has the same
            spatial size as the input only when H and W are 4 times an odd
            number >= 3 (100 = 4 * 25 is the size this network targets).
    Output: float tensor of shape (N, n_classes, H, W) holding raw,
            unnormalized scores. No ReLU/sigmoid is applied on the head:
            torch.nn.CrossEntropyLoss applies log-softmax itself, and a
            sigmoid on top of a ReLU can never output less than 0.5.

    Args:
        n_classes: number of score maps produced by the head. The default of
            2 (background / foreground) is what CrossEntropyLoss needs for
            0/1 class-index targets. Use 1 for a single-logit head (to be
            paired with a sigmoid, e.g. `self.sigmoid` or BCEWithLogitsLoss).
    """

    def __init__(self, n_classes=2):
        super(Net100, self).__init__()

        # Encoder convolutions: 3x3, stride 1, padding 1 keep the spatial size
        self.conv64 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv128 = torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv256 = torch.nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv512 = torch.nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)

        # pooleven halves an even size (100 -> 50 -> 25); poolodd halves an
        # odd size with a 3x3 window (25 -> 12)
        self.pooleven = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.poolodd = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=0)

        # Decoder. The batch-norm attributes keep their original names (renaming
        # them would change the state_dict keys); each one normalizes the
        # output channels of the transposed convolution right above it.
        # No padding / output padding: size n -> 2n + 1 (12 -> 25)
        self.deconv256 = torch.nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=0, output_padding=0)
        self.norm128 = torch.nn.BatchNorm2d(256)

        # padding=1, output_padding=1: size n -> 2n (25 -> 50)
        self.deconv128 = torch.nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.norm64 = torch.nn.BatchNorm2d(128)

        # padding=1, output_padding=1: size n -> 2n (50 -> 100)
        self.deconv64 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.norm3 = torch.nn.BatchNorm2d(64)

        # Classification head: one score map per class
        self.finalconv = torch.nn.Conv2d(64, n_classes, kernel_size=3, padding=1)

        # Not applied in forward(); kept as an attribute so a single-logit
        # head (n_classes=1) can be turned into probabilities by the caller
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Return the (N, n_classes, H, W) score maps for a (N, 3, H, W) batch."""
        # Shape comments below are height*width*channels for a 100*100 input
        x = F.relu(self.conv64(x))   # 100*100*3 -> 100*100*64
        x = self.pooleven(x)         # 100*100*64 -> 50*50*64

        x = F.relu(self.conv128(x))  # 50*50*64 -> 50*50*128
        x = self.pooleven(x)         # 50*50*128 -> 25*25*128

        x = F.relu(self.conv256(x))  # 25*25*128 -> 25*25*256
        x = self.poolodd(x)          # 25*25*256 -> 12*12*256

        x = F.relu(self.conv512(x))  # 12*12*256 -> 12*12*512

        x = self.deconv256(x)        # 12*12*512 -> 25*25*256
        x = self.norm128(x)

        x = self.deconv128(x)        # 25*25*256 -> 50*50*128
        x = self.norm64(x)

        x = self.deconv64(x)         # 50*50*128 -> 100*100*64
        x = self.norm3(x)

        # Raw logits: 100*100*64 -> 100*100*n_classes
        return self.finalconv(x)

This model, when optimized with Adam and the cross-entropy loss, converges to all-black images every time.

## Training the model

In [4]:
# Load the satellite patches (1752 images) as float64 and move the axis at
# position 3 to position 1; load the matching groundtruth as float64
sat_images_100 = np.moveaxis(np.load('balanced_dataset/sat_images_100.npy').astype(np.float64), 3, 1)
gt_images_100 = np.load('balanced_dataset/groundtruth_100.npy').astype(np.float64)

# The first 1500 images are the training set, the remaining 252 images the validation set
train_input, validation_input = sat_images_100[:1500], sat_images_100[1500:]
train_target, validation_target = gt_images_100[:1500], gt_images_100[1500:]

To shorten the computation time, we will use a smaller number of images for testing.

### Instantiate the model, loss function and optimizer

In [5]:
# We will optimize the cross-entropy loss using the Adam algorithm.
# torch.nn.CrossEntropyLoss expects raw per-class scores of shape (N, C, H, W)
# and integer class-index targets of shape (N, H, W) with values in [0, C-1].
net = Net100()
loss_function = torch.nn.CrossEntropyLoss()
# Adam over all parameters of the network, learning rate 1e-3
optimizer = optim.Adam(net.parameters(), lr=0.001)
# Alternative optimizer (disabled): SGD with momentum
# optimizer = optim.SGD(net.parameters(), lr=3.75e-2, momentum=0.9)

In [6]:
def trainNet(net, n_epochs):
    """Train `net` one image at a time and report the losses after each epoch.

    Reads the module-level `train_input`, `train_target`, `validation_input`,
    `validation_target`, `loss_function` and `optimizer`; `optimizer` must
    have been built on the parameters of `net`.

    For every epoch the function performs one optimization step per training
    image, prints the mean training loss of the epoch, then evaluates the
    validation set (evaluation mode, no gradient tracking) and prints the
    mean validation loss. A mean over an empty set is reported as nan.

    Args:
        net: the torch.nn.Module to train.
        n_epochs: number of passes over the training set.

    Returns:
        None. The parameters of `net` are updated in place; its
        training/evaluation mode is restored to what it was on entry.
    """
    #Time for printing
    training_start_time = time.time()

    n_train = np.shape(train_input)[0]
    n_validation = np.shape(validation_input)[0]
    was_training = net.training

    try:
        #Loop for n_epochs
        for epoch in range(n_epochs):

            # Batch norm must use batch statistics while training
            net.train()
            total_loss = 0.0

            for index in range(n_train):
                input_image, target_image = _as_batch(train_input[index], train_target[index])

                #Set the parameter gradients to zero
                optimizer.zero_grad()

                #Forward pass, backward pass, optimize
                outputs = net(input_image)
                loss = loss_function(outputs, target_image)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            #Print statistics: the mean over the epoch, not the last sample only
            mean_train_loss = total_loss / n_train if n_train else float("nan")
            print("Epoch", epoch, ", training loss:", mean_train_loss, ", time elapsed:", time.time() - training_start_time)

            #At the end of the epoch, do a pass on the validation set.
            # eval() freezes the batch-norm running statistics and no_grad()
            # avoids building autograd graphs that are never used.
            net.eval()
            total_val_loss = 0.0
            with torch.no_grad():
                for index in range(n_validation):
                    input_image, target_image = _as_batch(validation_input[index], validation_target[index])

                    #Forward pass
                    val_outputs = net(input_image)
                    total_val_loss += loss_function(val_outputs, target_image).item()

            mean_val_loss = total_val_loss / n_validation if n_validation else float("nan")
            print("Validation loss for epoch", epoch, ":", mean_val_loss)
    finally:
        # Leave the network in the mode the caller had set
        net.train(was_training)

    print("Training finished, took {:.2f}s".format(time.time() - training_start_time))


def _as_batch(image, target):
    """Turn one image and its target into batch-of-one float32 / int64 tensors."""
    # NOTE(review): the cast to long truncates, so targets are presumably
    # already 0/1 valued - confirm against the groundtruth file
    image_batch = torch.as_tensor(image, dtype=torch.float32).unsqueeze(0)
    target_batch = torch.as_tensor(target, dtype=torch.long).unsqueeze(0)
    return image_batch, target_batch

In [None]:
# Train the network for 3 epochs; the losses are printed after every epoch
trainNet(net, 3)

In [None]:
# Run the network on one validation sample (index 3), as a batch of one
input_image = torch.tensor(validation_input[3]).unsqueeze(0)
target_image = torch.tensor(validation_target[3]).unsqueeze(0)

#Forward pass (inference only, so no gradient tracking is needed)
with torch.no_grad():
    val_output = net(input_image.float())
# Keep the last output channel of the single sample. Indexing with -1 instead
# of a hard-coded 1 selects the foreground map of a two-channel output and
# stays valid when the network emits a single channel.
output_image = val_output[0, -1]

In [None]:
# Display the target of the selected validation sample (batch axis removed)
# with the 'Greys_r' colormap
plt.imshow(target_image.squeeze(0), cmap='Greys_r')

In [None]:
# Display the network output for the same sample with the same colormap;
# detach() drops the autograd history so the tensor can be converted to NumPy
plt.imshow(output_image.detach().numpy(), cmap='Greys_r')