# Train Segmentation with Atrous Convolution


#### References
* https://arxiv.org/pdf/1709.00179.pdf
* https://medium.com/beyondminds/a-simple-guide-to-semantic-segmentation-effcf83e7e54
* https://medium.com/dair-ai/medical-imaging-analysis-mri-cnn-pytorch-4877e64e7303
* https://medium.com/udacity-pytorch-challengers/a-brief-overview-of-loss-functions-in-pytorch-c0ddb78068f7
* https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py
* https://arxiv.org/pdf/1702.03275.pdf
* https://medium.com/syncedreview/facebook-ai-proposes-group-normalization-alternative-to-batch-normalization-fb0699bffae7
* https://www.jeremyjordan.me/evaluating-image-segmentation-models/
* https://github.com/martinkersner/py_img_seg_eval
* https://medicaltorch.readthedocs.io/en/stable/
* https://github.com/meetshah1995/pytorch-semseg
* https://discuss.pytorch.org/t/leaf-variable-has-been-moved-into-the-graph-interior/18679/9

#### Andrew Ng on Accuracy/Precision/Recall
Accuracy is not a useful metric when the dataset is imbalanced (skewed). For example, a model that always predicts that a patient has no cancer will be very accurate (e.g. 99.999%), but its recall will be zero.
* https://www.youtube.com/watch?v=k1JGvqr56Yk&list=PLLssT5z_DsK-h9vYZkQkYNWcItqhlRJLN&index=66
* https://www.youtube.com/watch?v=wGw6R8AbcuI
* https://www.youtube.com/watch?v=W5meQnGACGo

In [1]:
import sat_utils
import numpy as np
import pickle
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from random import randint
from tqdm import tqdm

# Pytorch stuff
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import torch
import torch.utils.data as utils
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torch.optim import lr_scheduler

from skimage.filters import threshold_adaptive, threshold_otsu, threshold_local

# Compute device: prefer the second GPU (cuda:1) when it exists, otherwise the first
# GPU, otherwise the CPU. Hard-coding 'cuda:1' fails on single-GPU machines as soon
# as a tensor or the model is moved to the device.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')
print('Device:', device)

# Training hyper-parameters
lr = 0.0001          # Adam learning rate (0.001 was also tried)
l2_norm = 0.0000001  # Adam weight_decay (L2 penalty)
gamma = 0.1          # StepLR multiplicative learning-rate decay
batch_size = 20
num_epochs = 500
step_size = 200      # StepLR: decay the learning rate every `step_size` epochs
SMOOTH = 1e-11       # Smoothing term of the IoU functions (avoids 0/0)
SMOOTH = 1e-11

Device: cuda:1


#### Load data from pickle (not scalable) and create the data loader

In [2]:
# Both pickles are iterated with .values(), so they are presumably dicts keyed by
# sample id. TODO confirm, and confirm that input and label dicts share the same
# key order, because samples are paired positionally below.
X = sat_utils.read_pickle_data('./data/input.pickle')
Y = sat_utils.read_pickle_data('./data/label.pickle')

# Inputs: the RGB view of every sample, stacked along a new leading batch axis.
rgb_samples = (torch.Tensor(sat_utils.get_rgb(sample)) for sample in X.values())
tensor_x = torch.stack(list(rgb_samples))

# Labels: divided by 255 so the regression-style targets (MSE / soft IoU / soft
# Dice) end up in [0, 1], as the printed min/max of the label confirm.
# Other encodings were tried here: integer class-index labels (N, H, W) for
# CrossEntropyLoss, and float per-channel labels (N, C, H, W) for BCEWithLogitsLoss.
label_samples = [torch.Tensor(mask / 255.0) for mask in Y.values()]
tensor_y = torch.stack(label_samples)

dataset_train = utils.TensorDataset(tensor_x, tensor_y)
dataloader_train = utils.DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True,
)

In [3]:
# Sanity-check the loaded tensors. The model gets one output channel per label
# channel, so dimension 1 of the labels fixes the number of classes.
for caption, shown in (('Input:', tensor_x), ('Label:', tensor_y)):
    print(caption, shown.shape)

num_classes = tensor_y.shape[1]
print('num_classes:', num_classes)

label_max, label_min = tensor_y.max().item(), tensor_y.min().item()
print('Max val on label:', label_max)
print('Min val on label:', label_min)

Input: torch.Size([40, 3, 76, 76])
Label: torch.Size([40, 3, 76, 76])
num_classes: 3
Max val on label: 1.0
Min val on label: 0.0


#### Custom Loss Functions
![Metrics](./imgs_doc/metrics.png "Metrics")

In [13]:
# Reference: https://gist.github.com/wassname/7793e2058c5c9dacb5212c0ac0b18a8a
# https://arxiv.org/pdf/1606.04797v1.pdf
def dice_loss(model_outputs, labels):
    """Soft (differentiable) Dice loss in the V-Net formulation.

    Both tensors are flattened into single vectors, so the score is pooled over
    the whole batch and all channels rather than averaged per channel.

    Args:
        model_outputs: predictions of any shape, expected on the same value
            range as ``labels`` (e.g. sigmoid outputs in [0, 1]).
        labels: targets with the same number of elements as ``model_outputs``.

    Returns:
        0-dim tensor ``1 - dice`` (0 means perfect overlap), usable as a loss.
    """
    smooth = 1.

    # reshape (not view) so non-contiguous inputs, e.g. transposed or permuted
    # tensors, are accepted instead of raising a RuntimeError.
    iflat = model_outputs.reshape(-1)
    tflat = labels.reshape(-1)
    intersection = torch.abs(iflat * tflat).sum()

    # Denominator uses squared sums, as in the V-Net paper.
    return 1 - ((2. * intersection + smooth) /
                ((iflat * iflat).sum() + (tflat * tflat).sum() + smooth))


# Hard-threshold Dice for evaluation only: the binarisation is not differentiable,
# so this cannot serve as a training loss (use dice_loss for that).
def dice_val(model_outputs, labels):
    """Dice computed on binarised predictions and labels.

    Each channel of ``model_outputs`` is binarised at the midpoint between its
    minimum and maximum over the batch. ``labels`` are binarised with ``> 0``.

    Args:
        model_outputs: tensor of shape (batch, channels, rows, cols).
        labels: tensor with the same number of elements as ``model_outputs``.

    Returns:
        0-dim float tensor equal to ``1 - dice`` (0 means perfect overlap), the
        same convention as ``dice_loss``.
    """
    num_ch = model_outputs.shape[1]
    # The mask is a hard threshold, so no gradient is needed. Detaching also makes
    # the function usable outside torch.no_grad() on tensors that require grad.
    outputs = model_outputs.detach()

    # Per-channel min/max over batch, rows and cols.
    per_channel = outputs.transpose(0, 1).reshape(num_ch, -1)
    ch_max = per_channel.max(dim=1)[0]
    ch_min = per_channel.min(dim=1)[0]
    # Midpoint of each channel's range. Note that (max - min) / 2 would be half
    # the range, which equals the midpoint only when min == 0.
    thresholds = ((ch_max + ch_min) / 2.0).view(1, num_ch, 1, 1)

    bin_model_outputs = (outputs > thresholds).float()
    bin_labels = (labels.detach() > 0).float().to(outputs.device)

    smooth = 1.

    iflat = bin_model_outputs.reshape(-1)
    tflat = bin_labels.reshape(-1)
    intersection = (iflat * tflat).sum()

    return 1 - ((2. * intersection + smooth) /
                (iflat.sum() + tflat.sum() + smooth))

In [14]:
# Reference: https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
def iou_loss(model_outputs, labels):
    """Soft (differentiable) IoU / Jaccard loss, pooled over the whole batch.

    Args:
        model_outputs: predictions, e.g. (batch, channels, rows, cols); they must
            be on the same value range as ``labels`` (0..1).
        labels: targets broadcast-compatible with ``model_outputs``.

    Returns:
        0-dim tensor ``1 - IoU``, so minimising it maximises the overlap.
    """
    # Absolute values keep negative entries from cancelling out in the sums.
    overlap = (model_outputs * labels).abs().sum()
    combined = model_outputs.abs().sum() + labels.abs().sum()
    # union = |A| + |B| - overlap; SMOOTH avoids 0/0 for empty masks.
    soft_iou = (overlap + SMOOTH) / (combined - overlap + SMOOTH)
    return 1 - soft_iou

# Hard-threshold IoU for evaluation only: the binarisation is not differentiable,
# so this cannot serve as a training loss (use iou_loss for that).
# Reference: https://www.kaggle.com/iezepov/fast-iou-scoring-metric-in-pytorch-and-numpy
def iou_val(model_outputs, labels, smooth=None):
    """Mean IoU (Jaccard index) of binarised predictions against binarised labels.

    Each channel of ``model_outputs`` is binarised at the midpoint between its
    minimum and maximum over the batch. ``labels`` are binarised with ``> 0``.
    IoU is computed per (sample, channel) pair and then averaged.

    Args:
        model_outputs: tensor of shape (batch, channels, rows, cols).
        labels: tensor of the same shape.
        smooth: additive smoothing for numerator and denominator. Defaults to
            the module-level ``SMOOTH`` constant.

    Returns:
        0-dim float tensor with the mean IoU (1 means perfect overlap).
    """
    if smooth is None:
        smooth = SMOOTH
    num_ch = model_outputs.shape[1]
    outputs = model_outputs.detach()

    # Per-channel min/max over batch, rows and cols.
    per_channel = outputs.transpose(0, 1).reshape(num_ch, -1)
    ch_max = per_channel.max(dim=1)[0]
    ch_min = per_channel.min(dim=1)[0]
    # Midpoint of each channel's range. Note that (max - min) / 2 would be half
    # the range, which equals the midpoint only when min == 0.
    thresholds = ((ch_max + ch_min) / 2.0).view(1, num_ch, 1, 1)

    bin_model_outputs = outputs > thresholds
    bin_labels = (labels.detach() > 0).to(outputs.device)

    intersection = (bin_model_outputs & bin_labels).float().sum((2, 3))  # 0 if Truth=0 or Prediction=0
    union = (bin_model_outputs | bin_labels).float().sum((2, 3))         # 0 only if both are 0
    # Smoothing avoids 0/0 when both masks are empty.
    iou = (intersection + smooth) / (union + smooth)
    return iou.mean()

#### Define Model

In [6]:
# Atrous (dilated) fully convolutional segmentation network. For a 76x76 input the
# convolutions shrink the feature map to 16x16. A final bilinear upsampling resizes
# it to `output_size` (76x76 by default).
class AtrousSeg(nn.Module):
    """Atrous-convolution segmentation network with a sigmoid output.

    The stack has three stages: front, LFE and head. Every convolution except the
    last is followed by ReLU and then BatchNorm2d. Spatial sizes in the comments
    below assume a 76x76 input.

    Because ``forward`` always applies a sigmoid, pair this model with losses that
    expect probabilities (MSE, soft IoU / Dice), not BCEWithLogitsLoss.

    Args:
        num_classes: number of output channels (one per class).
        num_channels: number of input channels.
        output_size: (rows, cols) that the final logits are bilinearly resized
            to. Defaults to (76, 76), the input size used in this notebook.
    """

    def __init__(self, num_classes=1, num_channels=8, output_size=(76, 76)):
        super().__init__()
        self.model = nn.Sequential(
            # Front: 76 -> 50
            nn.Conv2d(num_channels, 64, kernel_size=3, stride=1, padding=1, dilation=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, dilation=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=0, dilation=2),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0, dilation=2),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=0, dilation=3),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0, dilation=3),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0, dilation=3),
            nn.ReLU(),
            nn.BatchNorm2d(256),

            # LFE (dilation decreasing 3 -> 2 -> 1): 50 -> 32
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0, dilation=3),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, dilation=3),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, dilation=3),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, dilation=2),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, dilation=2),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, dilation=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, dilation=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),

            # Head: 32 -> 16 (7x7 kernel, dilation 3, padding 1), then 1x1 convs.
            nn.Conv2d(256, 1024, kernel_size=7, stride=1, padding=1, dilation=3),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.Conv2d(1024, 1024, kernel_size=1, stride=1, dilation=1),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.Conv2d(1024, num_classes, kernel_size=1, stride=1, dilation=1),
            # Resize the logits to the requested output resolution.
            nn.UpsamplingBilinear2d(size=output_size),
        )
        # Kaiming init for every convolution (the stack contains no Linear layers).
        for m in self.model:
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Return per-pixel probabilities: the sigmoid of the upsampled logits."""
        return torch.sigmoid(self.model(x))

In [7]:
# One input channel per channel of tensor_x, one output channel per label channel.
model = AtrousSeg(num_channels=tensor_x.shape[1], num_classes=num_classes)

In [8]:
# No log_dir given: TensorBoard's default "runs" directory is used.
writer = SummaryWriter()
# Trace the model once on a random input with the training shape so TensorBoard
# can display its graph.
graph_input_shape = (1, tensor_x.shape[1], 76, 76)
dummy_x = torch.rand(*graph_input_shape)
writer.add_graph(model, dummy_x)



In [9]:
model.to(device=device)  # moves parameters/buffers in place; the cell displays the module

AtrousSeg(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2))
    (7): ReLU()
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2))
    (10): ReLU()
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), dilation=(3, 3))
    (13): ReLU()
    (14): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (15): Conv2d(256, 256, kernel_size=(3, 3), stri

#### Initialize Losses and Optimizers

In [10]:
# Loss options:
# * nn.CrossEntropyLoss expects predictions (batch, n_class, rows, cols) and labels
#   (batch, rows, cols) holding class indices; nn.BCEWithLogitsLoss expects
#   per-channel labels of the same shape as the predictions.
# * Regression losses (nn.MSELoss, nn.SmoothL1Loss) accept predictions and labels
#   of the same shape (batch, channels, rows, cols).
# NOTE(review): loss_fn is not used by the training loop below, which optimises
# iou_loss directly.
loss_fn = nn.MSELoss()

# Adam with an L2 penalty (weight_decay). StepLR multiplies the learning rate by
# `gamma` every `step_size` scheduler steps.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=l2_norm,
)
exp_lr_scheduler = lr_scheduler.StepLR(
    optimizer,
    step_size=step_size,
    gamma=gamma,
)

#### Train Model

In [11]:
import os

# torch.save does not create missing directories, so create the checkpoint folder
# before the first epoch ends.
os.makedirs('./model_save', exist_ok=True)

iteration_count = 0
for epoch in tqdm(range(num_epochs), desc='Training'):
    # Train step
    model.train()
    for imgs, labels in dataloader_train:
        # Send inputs/labels to the training device
        labels = labels.to(device)
        imgs = imgs.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)

        # The optimised loss is the soft IoU loss. Alternatives tried in this
        # notebook: loss_fn(outputs, labels), dice_loss, or a weighted sum of both.
        iou_value = iou_loss(outputs, labels)
        loss = iou_value
        # Dice is only logged, so keep it out of the autograd graph.
        with torch.no_grad():
            dice_value = dice_loss(outputs, labels)

        loss.backward()
        optimizer.step()
        writer.add_scalar('loss/', loss.item(), iteration_count)
        iteration_count += 1

    # Per-epoch metrics: soft IoU / Dice loss values (1 - score) of the last batch.
    # Do not assign these to `iou_val` / `dice_val`: those names are the
    # hard-threshold metric functions used by the test cell, and rebinding them to
    # tensors makes that cell fail with "'Tensor' object is not callable".
    writer.add_scalar('iou/', iou_value.item(), epoch)
    writer.add_scalar('dice/', dice_value.item(), epoch)

    # Learning rate used during this epoch (logged before the scheduler updates it).
    for param_group in optimizer.param_groups:
        curr_learning_rate = param_group['lr']
        writer.add_scalar('learning_rate/', curr_learning_rate, epoch)
    # StepLR counts epochs: advance it once per epoch, after the optimizer steps.
    exp_lr_scheduler.step()

    # Send min-max normalised images of the last batch to TensorBoard
    # (detached: logging needs no gradients).
    out_norm = sat_utils.img_minmax_norm_torch(outputs.detach())
    labels_norm = sat_utils.img_minmax_norm_torch(labels)
    imgs_norm = sat_utils.img_minmax_norm_torch(imgs)
    writer.add_images('Image', imgs_norm, epoch)
    if num_classes > 1:
        # Channels 0 / 1 / 2 are logged as mask / between / border.
        writer.add_images('out_mask', out_norm[:, 0, :, :].unsqueeze(1), epoch)
        writer.add_images('out_between', out_norm[:, 1, :, :].unsqueeze(1), epoch)
        writer.add_images('out_border', out_norm[:, 2, :, :].unsqueeze(1), epoch)
        if len(labels_norm.shape) > 3:
            writer.add_images('label_mask', labels_norm[:, 0, :, :].unsqueeze(1), epoch)
            writer.add_images('label_between', labels_norm[:, 1, :, :].unsqueeze(1), epoch)
            writer.add_images('label_border', labels_norm[:, 2, :, :].unsqueeze(1), epoch)
        else:
            writer.add_images('label_mask', labels_norm.unsqueeze(1), epoch)
    else:
        writer.add_images('out_mask', out_norm, epoch)
        writer.add_images('label_mask', labels_norm.unsqueeze(1), epoch)

    # Save the whole (pickled) model once per epoch.
    torch.save(model, './model_save/model_' + str(epoch) + '.cpkt')
    

  "type " + obj.__name__ + ". It won't be checked "
Training:  25%|██▌       | 127/500 [16:01<47:48,  7.69s/it]

KeyboardInterrupt: 

#### Load Specific Model

In [16]:
# SECURITY: torch.load unpickles arbitrary Python objects; only load checkpoints
# you created yourself.
# map_location remaps tensors saved on another device (training here ran on
# 'cuda:1') onto the current `device`. This keeps loading working on CPU-only or
# single-GPU machines and keeps the model on the same device as the test inputs.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects whole
# pickled modules; pass weights_only=False on those versions.
model = torch.load('./model_save/model_126.cpkt', map_location=device)
model.eval()

AtrousSeg(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2))
    (7): ReLU()
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2))
    (10): ReLU()
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), dilation=(3, 3))
    (13): ReLU()
    (14): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (15): Conv2d(256, 256, kernel_size=(3, 3), stri

#### Test Model

In [17]:
@interact(idx_img=widgets.IntSlider(min=0, max=tensor_x.shape[0]-1))
def testModel(idx_img):
    """Plot sample ``idx_img`` with the model prediction and label, and print metrics.

    ``dice_val``, ``dice_loss`` and ``iou_loss`` report ``1 - score``;
    ``iou_val`` reports the IoU itself.

    Args:
        idx_img: index into ``tensor_x`` / ``tensor_y``, driven by the slider.
    """
    model.eval()
    with torch.no_grad():
        img = tensor_x[idx_img].unsqueeze(0).to(device)
        # The model's forward already applies a sigmoid, so no extra one here.
        pred = model(img)
        label = tensor_y[idx_img].unsqueeze(0).to(device)
        dice_value = dice_val(pred, label)
        iou_value = iou_val(pred, label)
        dice_loss_value = dice_loss(pred, label)
        iou_loss_value = iou_loss(pred, label)

    # (channels, rows, cols) -> (rows, cols, channels), min-max scaled for imshow.
    img_numpy = sat_utils.img_minmax_norm(img.cpu().squeeze().numpy())
    img_numpy = np.moveaxis(img_numpy, 0, 2)
    # NOTE(review): second normalisation after moveaxis; redundant if
    # img_minmax_norm is a global min-max. Confirm and drop.
    img_numpy = sat_utils.img_minmax_norm(img_numpy)
    pred_numpy = pred.cpu().squeeze().numpy()
    label_numpy = tensor_y[idx_img].squeeze().numpy()

    print('Dice Val:', dice_value.item())
    print('IoU Val:', iou_value.item())
    print('Dice Loss:', dice_loss_value.item())
    print('IoU Loss:', iou_loss_value.item())

    if num_classes > 1:
        # NOTE(review): the training cell logs channel 1 as "between" and channel 2
        # as "border", but this plot treats channel 1 as the border. Confirm which
        # channel holds the borders.
        mask_border = pred_numpy[0, :, :] - pred_numpy[1, :, :]

        _, axarr = plt.subplots(1, 5, figsize=(15, 15))
        axarr[0].imshow(img_numpy)
        axarr[0].title.set_text('Original')
        axarr[1].imshow(pred_numpy[0, :, :])
        axarr[1].title.set_text('Prediction Mask')
        axarr[2].imshow(pred_numpy[1, :, :])
        axarr[2].title.set_text('Prediction Border')
        axarr[3].imshow(mask_border)
        axarr[3].title.set_text('Subtracted')
        axarr[4].imshow(label_numpy[0, :, :])
        axarr[4].title.set_text('Label')
    else:
        _, axarr = plt.subplots(1, 3, figsize=(15, 15))
        # img_numpy is (rows, cols, channels): show the whole image, not row 0.
        axarr[0].imshow(img_numpy)
        axarr[0].title.set_text('Original')
        axarr[1].imshow(pred_numpy[:, :])
        axarr[1].title.set_text('Prediction Mask')
        axarr[2].imshow(label_numpy)
        axarr[2].title.set_text('Label')

interactive(children=(IntSlider(value=0, description='idx_img', max=39), Output()), _dom_classes=('widget-inte…