# Train Segmentation with Atrous Convolution


#### References
* https://arxiv.org/pdf/1709.00179.pdf
* https://medium.com/beyondminds/a-simple-guide-to-semantic-segmentation-effcf83e7e54
* https://medium.com/dair-ai/medical-imaging-analysis-mri-cnn-pytorch-4877e64e7303
* https://medium.com/udacity-pytorch-challengers/a-brief-overview-of-loss-functions-in-pytorch-c0ddb78068f7
* https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py

In [46]:
import sat_utils
import numpy as np
import pickle
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from random import randint

# Pytorch stuff
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch
import torch.utils.data as utils
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torch.optim import lr_scheduler

from skimage.filters import threshold_adaptive, threshold_otsu, threshold_local

# Select the compute device. GPU index 1 is preferred, but a machine with a
# single GPU only exposes index 0: hard-coding 'cuda:1' there passes the
# is_available() check and then fails when tensors are moved to the device.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device('cuda', gpu_index)
else:
    device = torch.device('cpu')
print('Device:', device)

# Training hyper-parameters (consumed by the optimizer, scheduler, data
# loader and training loop further down in this notebook).
lr = 0.001            # initial learning rate passed to Adam
l2_norm = 0.0000001   # passed to Adam as weight_decay
gamma = 0.1           # StepLR multiplicative decay factor
batch_size = 20       # samples per DataLoader batch
num_epochs = 1000     # number of passes over the training set
step_size = 200       # StepLR period, in epochs

Device: cuda:1


#### Load data from pickle (bad: not scalable) and create the data loader

In [2]:
# Read the pickled inputs and labels; both are used as dicts below
# (only their values are consumed, in iteration order).
X = sat_utils.read_pickle_data('./data/input.pickle')
Y = sat_utils.read_pickle_data('./data/label.pickle')

# One float tensor per input sample, built from sat_utils.get_rgb
# (presumably the RGB bands of the sample -- verify in sat_utils).
rgb_tensors = [torch.Tensor(sat_utils.get_rgb(sample)) for sample in X.values()]
tensor_x = torch.stack(rgb_tensors)

# Only the first plane of every label is kept; expand_dims restores a leading
# axis of length 1 so labels are (1, H, W). They stay float tensors (a
# LongTensor variant was tried earlier and is not used).
label_tensors = [
    torch.Tensor(np.expand_dims(mask[0, :, :], axis=0))
    for mask in Y.values()
]
tensor_y = torch.stack(label_tensors)

# Wrap the stacked tensors in a dataset and a batched loader.
dataset_train = utils.TensorDataset(tensor_x, tensor_y)
dataloader_train = utils.DataLoader(dataset_train, batch_size=batch_size)

In [3]:
# Sanity check: show the shapes of the stacked input and label tensors.
print('Input:', tensor_x.size())
print('Label:', tensor_y.size())

Input: torch.Size([40, 3, 76, 76])
Label: torch.Size([40, 1, 76, 76])


In [4]:
class CrossEntropyLoss2d(nn.Module):
    """Cross-entropy loss for dense, per-pixel class predictions.

    Log-softmax is applied over the class dimension (dim 1) of the logits and
    the result is passed to ``nn.NLLLoss``.

    Args:
        weight: optional per-class weights, forwarded to ``nn.NLLLoss``.
        size_average: kept for backward compatibility. ``True`` or ``None``
            averages the loss, any other falsy value sums it.
    """

    def __init__(self, weight=None, size_average=True):
        super(CrossEntropyLoss2d, self).__init__()
        # nn.NLLLoss2d and the size_average flag are deprecated; translate the
        # flag into the equivalent `reduction` of the plain nn.NLLLoss.
        if size_average is None or size_average:
            reduction = 'mean'
        else:
            reduction = 'sum'
        self.nll_loss = nn.NLLLoss(weight=weight, reduction=reduction)

    def forward(self, inputs, targets):
        """Return the loss of raw class scores `inputs` against `targets`.

        `inputs` carries the classes in dimension 1; `targets` holds class
        indices as expected by ``nn.NLLLoss``.
        """
        # dim is given explicitly: the implicit choice is deprecated and would
        # select dim 0 for 3-D inputs.
        return self.nll_loss(F.log_softmax(inputs, dim=1), targets)

#### Define Model

In [5]:
# Input 76x76 -> 16x16 map after the head -> bilinearly upsampled to output_size (76x76 by default)
class AtrousSeg(nn.Module):
    """Fully-convolutional segmentation network built from dilated convolutions.

    The network is a single ``nn.Sequential`` (``self.model``) of
    Conv2d -> ReLU -> BatchNorm2d triples, followed by a 1x1 classification
    convolution and a bilinear upsampling layer. Every convolution has
    stride 1, so each one shrinks the spatial size by
    ``dilation * (kernel_size - 1) - 2 * padding``; over the whole stack that
    is 60 pixels per dimension (76x76 in, 16x16 before upsampling).

    Args:
        num_classes: number of output channels of the final 1x1 convolution.
        num_channels: number of channels of the input images.
        output_size: spatial size the prediction is upsampled to. The default
            (76, 76) reproduces the previously hard-coded value.
    """

    def __init__(self, num_classes=1, num_channels=8, output_size=(76, 76)):
        super().__init__()

        # (out_channels, kernel_size, padding, dilation) for every
        # Conv2d -> ReLU -> BatchNorm2d triple, in network order.
        conv_specs = [
            # Front
            (64, 3, 1, 1),
            (64, 3, 1, 1),
            (128, 3, 0, 2),
            (128, 3, 0, 2),
            (256, 3, 0, 3),
            (256, 3, 0, 3),
            (256, 3, 0, 3),
            # LFE
            (256, 3, 0, 3),
            (256, 3, 1, 3),
            (256, 3, 1, 3),
            (256, 3, 1, 2),
            (256, 3, 1, 2),
            (256, 3, 1, 1),
            (256, 3, 1, 1),
            # Head
            (1024, 7, 1, 3),
            (1024, 1, 0, 1),
        ]

        layers = []
        in_channels = num_channels
        for out_channels, kernel_size, padding, dilation in conv_specs:
            layers.append(nn.Conv2d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=1,
                                    padding=padding, dilation=dilation))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm2d(out_channels))
            in_channels = out_channels

        # Per-pixel class scores, then back to the requested resolution.
        layers.append(nn.Conv2d(in_channels, num_classes, kernel_size=1, stride=1, dilation=1))
        # nn.Upsample(mode='bilinear', align_corners=True) is the documented
        # replacement for the deprecated nn.UpsamplingBilinear2d.
        layers.append(nn.Upsample(size=output_size, mode='bilinear', align_corners=True))

        # Layers are added in a flat sequence so their indices inside
        # self.model (and hence state_dict keys) follow the order above.
        self.model = nn.Sequential(*layers)

        # Initialize Weights: Kaiming-normal for every convolution (biases keep
        # PyTorch's default initialization).
        for m in self.model:
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Run `x` through the sequential stack and return the upsampled scores."""
        return self.model(x)

In [6]:
# Build the network with one output channel and as many input channels as
# the loaded images have (dimension 1 of tensor_x).
in_channels = tensor_x.size(1)
model = AtrousSeg(num_classes=1, num_channels=in_channels)

In [7]:
# TensorBoard writer; event files go to ./logs.
writer = SummaryWriter(log_dir='./logs')

# Random single-sample batch with the same channel count as the data and a
# 76x76 spatial size, used only to trace the model graph for TensorBoard.
dummy_x = torch.rand((1, tensor_x.size(1), 76, 76))
writer.add_graph(model, input_to_model=dummy_x)



In [8]:
# Move the model's parameters and buffers to the selected device. As the last
# expression of the cell, the returned module is echoed as the cell output.
model.to(device)

AtrousSeg(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2))
    (7): ReLU()
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2))
    (10): ReLU()
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), dilation=(3, 3))
    (13): ReLU()
    (14): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (15): Conv2d(256, 256, kernel_size=(3, 3), stri

In [9]:
# Loss: mean squared error between the network output and the label.
# Alternatives that were tried and are currently disabled:
# nn.CrossEntropyLoss(), CrossEntropyLoss2d(), nn.SmoothL1Loss().
loss_fn = nn.MSELoss()

# Adam with weight decay, plus a step schedule that multiplies the learning
# rate by `gamma` every `step_size` scheduler steps.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=lr,
    weight_decay=l2_norm,
)
exp_lr_scheduler = lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=step_size,
    gamma=gamma,
)

#### Train Model

In [None]:
# Global batch counter, used as the x-axis of the TensorBoard loss curve.
iteration_count = 0
# For all epochs
for epoch in range(num_epochs):
    # Train step
    model.train()
    # For all batches of the training set
    for imgs, labels in dataloader_train:
        # Send inputs/labels to the compute device
        labels = labels.to(device)
        imgs = imgs.to(device)

        optimizer.zero_grad()

        outputs = model(imgs)

        loss = loss_fn(outputs, labels)

        loss.backward()
        optimizer.step()

        # Log the per-batch loss to TensorBoard
        writer.add_scalar('loss/', loss.item(), iteration_count)
        iteration_count += 1

    # Advance the learning-rate schedule once per epoch, after the optimizer
    # updates. Passing `epoch` to step() is deprecated, and stepping once per
    # batch is unnecessary; this yields the same learning rate for every epoch.
    exp_lr_scheduler.step()

    # Loss reported here is the one of the last batch of the epoch
    print('Epoch:', epoch, 'of:', num_epochs, 'loss:', loss.item())

    # Show one random sample of the last batch: input channel 0, prediction
    # and label. The index is bounded by the actual batch size, because the
    # last batch can hold fewer than `batch_size` samples.
    img_idx = randint(0, imgs.shape[0] - 1)
    f, axarr = plt.subplots(1, 3)
    axarr[0].imshow(imgs[img_idx, 0, :, :].cpu())  # 4 with 8 channels
    axarr[1].imshow(outputs[img_idx, 0, :, :].detach().cpu())
    axarr[2].imshow(labels[img_idx, 0, :, :].cpu())
    plt.show()
    

#### Test Model

In [107]:
@interact(idx_img=widgets.IntSlider(min=0,max=tensor_x.shape[0]-1))
def testModel(idx_img):
    """Plot input, prediction, two binarizations and label for one sample.

    Args:
        idx_img: index into `tensor_x` / `tensor_y`, driven by the slider.

    Side effects: puts `model` in eval mode and draws a 1x5 matplotlib figure
    (Original, Prediction, Threshold, Label, Adaptive).
    """
    model.eval()
    with torch.no_grad():
        # Add a batch dimension of 1 and run the sample through the model
        img = tensor_x[idx_img].unsqueeze(0).to(device)
        pred = model(img)

    # Drop the singleton dimensions and go back to NumPy for plotting
    img_numpy = img.cpu().squeeze().numpy()
    pred_numpy = pred.cpu().squeeze().numpy()
    label_numpy = tensor_y[idx_img].squeeze().numpy()

    # Very naive global threshold: the midpoint of the prediction's value
    # range. (max - min) / 2 is only the midpoint when min == 0, so the
    # minimum is taken into account here. (Author's note: a learned approach,
    # e.g. a GAN, should replace this.)
    threshold = (np.max(pred_numpy) + np.min(pred_numpy)) / 2

    # Local (adaptive) threshold; it also does not work well because the block
    # size has to be fixed. threshold_adaptive was removed from scikit-image;
    # comparing against threshold_local gives the same binary image.
    block_size = 81
    local_threshold = threshold_local(pred_numpy, block_size=block_size, method='gaussian')
    bin_adaptive = pred_numpy > local_threshold

    panels = [
        ('Original', img_numpy[0, :, :]),  # 4 with 8 channels
        ('Prediction', pred_numpy),
        ('Threshold', pred_numpy > threshold),
        ('Label', label_numpy),
        ('Adaptive', bin_adaptive),
    ]
    f, axarr = plt.subplots(1, len(panels), figsize=(15, 15))
    for ax, (title, image) in zip(axarr, panels):
        ax.imshow(image)
        ax.title.set_text(title)

interactive(children=(IntSlider(value=0, description='idx_img', max=39), Output()), _dom_classes=('widget-inte…