# Train Segmentation with Atrous Convolution


#### References
* https://arxiv.org/pdf/1709.00179.pdf
* https://medium.com/beyondminds/a-simple-guide-to-semantic-segmentation-effcf83e7e54
* https://medium.com/dair-ai/medical-imaging-analysis-mri-cnn-pytorch-4877e64e7303
* https://medium.com/udacity-pytorch-challengers/a-brief-overview-of-loss-functions-in-pytorch-c0ddb78068f7
* https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py
* https://arxiv.org/pdf/1702.03275.pdf
* https://medium.com/syncedreview/facebook-ai-proposes-group-normalization-alternative-to-batch-normalization-fb0699bffae7
* https://www.jeremyjordan.me/evaluating-image-segmentation-models/
* https://github.com/martinkersner/py_img_seg_eval
* https://medicaltorch.readthedocs.io/en/stable/
* https://github.com/meetshah1995/pytorch-semseg

#### Andrew Ng on Accuracy/Precision/Recall
Accuracy is not a meaningful metric when your dataset is imbalanced (skewed). For example, a model that always predicts that nobody has cancer will be extremely accurate (e.g. 99.999% when the disease is rare), yet its recall will be zero.
* https://www.youtube.com/watch?v=k1JGvqr56Yk&list=PLLssT5z_DsK-h9vYZkQkYNWcItqhlRJLN&index=66
* https://www.youtube.com/watch?v=wGw6R8AbcuI
* https://www.youtube.com/watch?v=W5meQnGACGo

In [1]:
import sat_utils
import numpy as np
import pickle
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from random import randint
from tqdm import tqdm

# Pytorch stuff
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import torch
import torch.utils.data as utils
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torch.optim import lr_scheduler

from skimage.filters import threshold_adaptive, threshold_otsu, threshold_local

# Prefer the second GPU when the machine has more than one, otherwise fall back
# to the first GPU, and finally to the CPU (a hard-coded 'cuda:1' fails on
# single-GPU machines as soon as the model is moved to it).
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')
#device = 'cpu'
print('Device:', device)

lr = 0.0001          # Adam learning rate (0.001 was also tried)
l2_norm = 0.0000001  # Adam weight_decay (L2 penalty)
gamma = 0.1          # StepLR multiplicative decay factor
batch_size = 20
num_epochs = 500
step_size = 200      # StepLR: scheduler steps between learning-rate decays
SMOOTH = 1e-11       # additive smoothing used by iou() to avoid 0/0
SMOOTH = 1e-11

Device: cuda:1


#### Load data from pickle (simple, but not scalable to large datasets) and create the data loader

In [2]:
X = sat_utils.read_pickle_data('./data/input.pickle')
Y = sat_utils.read_pickle_data('./data/label.pickle')

# Inputs: one RGB tensor per sample (built by sat_utils.get_rgb), stacked on dim 0.
tensor_x = torch.stack([torch.Tensor(sat_utils.get_rgb(sample)) for sample in X.values()])

# Labels for the MSE objective: raw label values divided by 255.
tensor_y = torch.stack([torch.Tensor(mask / 255.0) for mask in Y.values()])

# Label variants tried earlier (notes only, not executed):
#   * push the background further away:
#       tensor_y[tensor_y == tensor_y.min()] = -100.0
#   * Cross-Entropy: all 3 classes mixed in one (N, W, H) map
#       tensor_y = torch.stack([torch.Tensor(x/255.0).type(torch.LongTensor) for x in Y.values()])
#   * BCEWithLogitsLoss: 3 classes on 3 channels (N, C, W, H)
#       tensor_y = torch.stack([torch.Tensor(x/255.0).type(torch.FloatTensor) for x in Y.values()])
#   * collapse into a single class-index map:
#       tensor_y[:,2,:,:][tensor_y[:,2,:,:]==2.0] = 0
#       tensor_y = tensor_y[:,0,:,:] + (tensor_y[:,2,:,:] * 2.0)

dataset_train = utils.TensorDataset(tensor_x, tensor_y)
dataloader_train = utils.DataLoader(dataset_train, batch_size=batch_size)

In [3]:
# Sanity-check the tensors built when loading the data.
print('Input:', tensor_x.shape)
print('Label:', tensor_y.shape)

# Labels are channel-first with one channel per class, so dim 1 is the class count.
num_classes = tensor_y.shape[1]
print('num_classes:', num_classes)

print('Max val on label:', tensor_y.max().item())
print('Min val on label:', tensor_y.min().item())

Input: torch.Size([40, 3, 76, 76])
Label: torch.Size([40, 3, 76, 76])
num_classes: 3
Max val on label: 1.0
Min val on label: 0.0


#### Custom Loss Functions
![Metrics overview](./imgs_doc/metrics.png "Metrics")

In [4]:
class LossBinaryISB:
    """Weighted mix of BCE-with-logits and a soft Jaccard (IoU) loss.

    ``loss = (1 - jaccard_weight) * BCEWithLogits + jaccard_weight * (1 - soft_IoU)``

    Args:
        jaccard_weight: mixing factor; 0 gives plain BCE-with-logits, 1 gives
            a pure soft-Jaccard loss.

    ``outputs`` are raw logits (a sigmoid is applied internally for the
    Jaccard term); ``targets`` must have the same shape, and only entries
    equal to 1 count as positives for the Jaccard term.
    """

    def __init__(self, jaccard_weight=0):
        self.nll_loss = nn.BCEWithLogitsLoss()
        self.jaccard_weight = jaccard_weight

    def __call__(self, outputs, targets):
        loss = (1 - self.jaccard_weight) * self.nll_loss(outputs, targets)

        if self.jaccard_weight:
            eps = 1e-15
            jaccard_target = (targets == 1).float()
            jaccard_output = torch.sigmoid(outputs)

            intersection = (jaccard_output * jaccard_target).sum()
            union = jaccard_output.sum() + jaccard_target.sum()
            soft_iou = (intersection + eps) / (union - intersection + eps)

            # The BCE term is already scaled by (1 - w) above; add the Jaccard
            # term scaled by w. (Rescaling the BCE term by w again and giving
            # the Jaccard term (1 - w) made jaccard_weight=1 always return 0.)
            loss = loss + self.jaccard_weight * (1 - soft_iou)
        return loss

In [5]:
def dice_loss(model_outputs, labels, smooth=1.):
    """Hard Dice loss ``1 - Dice`` between binarised predictions and labels.

    Each channel of ``model_outputs`` is binarised at the midpoint of its value
    range (taken over the whole batch); ``labels`` are binarised as
    ``labels > 0``. Both are flattened and one Dice coefficient is computed.

    Args:
        model_outputs: (batch, channels, rows, cols) prediction tensor.
        labels: tensor with the same shape and device as ``model_outputs``.
        smooth: additive smoothing for numerator and denominator (avoids 0/0).

    Returns:
        0-d float tensor on ``model_outputs``' device; 0 means a perfect match.

    Note:
        The hard threshold is not differentiable, so the result carries no
        gradient back to ``model_outputs``: it is a metric, not a training
        signal.
    """
    n_channels = model_outputs.shape[1]
    # Value range per channel over batch and spatial dims -> shape (channels,)
    per_channel = model_outputs.detach().transpose(0, 1).reshape(n_channels, -1)
    ch_min = per_channel.min(dim=1).values
    ch_max = per_channel.max(dim=1).values
    # Midpoint of [min, max]. The former "(max - min) / 2" lies below the
    # midpoint whenever min > 0 (e.g. sigmoid outputs) and over-marks pixels.
    threshold = ((ch_max + ch_min) / 2.0).view(1, n_channels, 1, 1)

    # Build new tensors rather than writing in place into a grad-requiring
    # leaf (which raised a RuntimeError for CPU float inputs).
    bin_model_outputs = (model_outputs > threshold).float()
    bin_labels = (labels > 0).float()

    iflat = bin_model_outputs.reshape(-1)
    tflat = bin_labels.reshape(-1)
    intersection = (iflat * tflat).sum()

    return 1 - ((2. * intersection + smooth) /
                (iflat.sum() + tflat.sum() + smooth))

In [6]:
def CrossEntropyLoss2d(input, target, weight=None, size_average=True):
    """Pixel-wise cross-entropy for segmentation.

    Args:
        input: (N, C, H, W) class scores (logits).
        target: (N, H, W) integer class-index map; pixels equal to 250 are
            ignored.
        weight: optional per-class weight tensor of length C.
        size_average: truthy (or None) -> mean over non-ignored pixels,
            falsy -> sum.

    If ``input`` differs from ``target`` in either spatial dimension, the
    predictions are bilinearly resized to the target size first.
    """
    _, c, h, w = input.size()
    _, ht, wt = target.size()

    # Resize the predictions (not the labels) when *either* dimension differs;
    # with `and`, a single mismatched dimension crashed at the reshape below.
    if h != ht or w != wt:
        input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)

    # (N, C, H, W) -> (N*H*W, C): every pixel becomes one classification sample
    input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    target = target.reshape(-1)
    # `size_average` is deprecated in F.cross_entropy; map it onto `reduction`
    # with the legacy meaning (None behaves like True).
    reduction = 'mean' if (size_average is None or size_average) else 'sum'
    loss = F.cross_entropy(
        input, target, weight=weight, reduction=reduction, ignore_index=250
    )
    return loss

In [7]:
# model_outputs, labels format: BATCH x Channels x ROWs x COLs
# Reference: https://www.kaggle.com/iezepov/fast-iou-scoring-metric-in-pytorch-and-numpy
def iou(model_outputs, labels, smooth=None):
    """Hard IoU (Jaccard index) per image and channel, averaged over both.

    Each channel of ``model_outputs`` is binarised at the midpoint of its value
    range (taken over the whole batch); ``labels`` are binarised as
    ``labels > 0``. Intersection and union are counted per (image, channel)
    over the spatial dims.

    Args:
        model_outputs: (batch, channels, rows, cols) predictions.
        labels: tensor with the same shape and device as ``model_outputs``.
        smooth: additive term avoiding 0/0; defaults to the module-level
            ``SMOOTH`` constant.

    Returns:
        ``(mean_iou, mean_thresholded)`` as 0-d tensors. The thresholded score
        is 0 for IoU <= 0.5, then ``ceil(20 * (IoU - 0.5)) / 10`` capped at 1.0.
    """
    if smooth is None:
        smooth = SMOOTH

    n_channels = model_outputs.shape[1]
    # Threshold each channel at the midpoint of its range. The former
    # "(max - min) / 2" lies below the midpoint whenever min > 0 (e.g. sigmoid
    # outputs) and can mark nearly every pixel positive.
    per_channel = model_outputs.detach().transpose(0, 1).reshape(n_channels, -1)
    threshold = (per_channel.max(dim=1).values + per_channel.min(dim=1).values) / 2.0
    bin_model_outputs = model_outputs > threshold.view(1, n_channels, 1, 1)
    bin_labels = labels > 0

    intersection = (bin_model_outputs & bin_labels).float().sum((2, 3))  # zero if truth or prediction is empty
    union = (bin_model_outputs | bin_labels).float().sum((2, 3))         # zero only if both are empty
    # Smoothing avoids 0/0: an empty prediction on an empty label scores 1
    iou_score = (intersection + smooth) / (union + smooth)

    # This is equal to comparing with thresholds 0.5, 0.55, ..., 0.95
    thresholded = torch.clamp(20 * (iou_score - 0.5), 0, 10).ceil() / 10
    return iou_score.mean(), thresholded.mean()

#### Define Model

In [8]:
# Dilated (atrous) conv net: the unpadded dilated convolutions shrink a 76x76
# input to a 16x16 map at the head; the final bilinear layer resizes the
# per-class maps back to a fixed 76x76.
class AtrousSeg(nn.Module):
    """Atrous-convolution segmentation network.

    Args:
        num_classes: number of output channels (one map per class).
        num_channels: number of input channels.

    ``forward`` returns sigmoid probabilities of shape
    (batch, num_classes, 76, 76) for any input large enough (>= 61x61) to
    survive the unpadded convolutions.
    """

    def __init__(self, num_classes=1, num_channels=8):
        super().__init__()
        self.model = nn.Sequential(
            #nn.BatchNorm2d(num_channels),
            nn.Conv2d(num_channels, 64, kernel_size=3, stride=1, padding=1, dilation = 1), # Front
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, dilation = 1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=0, dilation = 2),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0, dilation = 2),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=0, dilation = 3),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0, dilation = 3),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0, dilation = 3),
            nn.ReLU(),
            nn.BatchNorm2d(256),

            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0, dilation = 3), #LFE
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, dilation = 3),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, dilation = 3),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, dilation = 2),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, dilation = 2),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, dilation = 1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, dilation = 1),
            nn.ReLU(),
            nn.BatchNorm2d(256),

            nn.Conv2d(256, 1024, kernel_size=7, stride=1, padding=1, dilation = 3), # Head (32x32 -> 16x16 for a 76x76 input)
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.Conv2d(1024, 1024, kernel_size=1, stride=1, dilation = 1),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.Conv2d(1024, num_classes, kernel_size=1, stride=1, dilation = 1),
            #nn.Sigmoid(),
            nn.UpsamplingBilinear2d(size=(76, 76)),
        )
        # Initialize weights: Kaiming-normal (fan_out, ReLU gain) for the
        # convolutions; conv biases keep PyTorch's default init. The Linear
        # branch only matters if Linear layers are added to the Sequential.
        for m in self.model:
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Return per-pixel, per-class probabilities (sigmoid of the logits).

        The sigmoid is applied in both train and eval mode. nn.BCEWithLogitsLoss
        applies its own sigmoid, so pair it with ``self.model(x)`` (raw logits)
        rather than with this output.
        """
        return torch.sigmoid(self.model(x))

In [None]:
# Build the network; its input width follows the data's channel count.
model = AtrousSeg(
    num_channels=tensor_x.shape[1],
    num_classes=num_classes,
)
# Smoke test with an 8-channel input: resp = model(torch.rand(1, 8, 76, 76))

In [None]:
# TensorBoard writer. With no log_dir it uses the default "runs" directory
# (an explicit SummaryWriter('./logs') was the alternative).
writer = SummaryWriter()

# Trace the model graph once, using a random single-sample batch shaped like the data.
dummy_x = torch.rand((1, tensor_x.shape[1], 76, 76))
writer.add_graph(model=model, input_to_model=dummy_x)



In [None]:
# Move the model's parameters and buffers to the selected device. Because this
# is the cell's last expression, the returned module is also displayed.
model.to(device)

AtrousSeg(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2))
    (7): ReLU()
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2))
    (10): ReLU()
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), dilation=(3, 3))
    (13): ReLU()
    (14): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (15): Conv2d(256, 256, kernel_size=(3, 3), stri

#### Initialize Losses and Optimizers

In [None]:
# Loss choice:
#   * nn.CrossEntropyLoss expects (batch, n_class, rows, cols) predictions and
#     (batch, rows, cols) class-index labels.
#   * nn.BCEWithLogitsLoss expects raw logits and targets of the same shape.
#   * Regression-style losses (nn.MSELoss, nn.SmoothL1Loss) accept any
#     (batch, channels, rows, cols) tensors for both prediction and label.
loss_fn = nn.MSELoss()

# Adam with a small L2 penalty; StepLR multiplies the learning rate by
# `gamma` every `step_size` scheduler steps.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, weight_decay=l2_norm)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

#### Train Model

In [None]:
iteration_count = 0
# For all epochs
for epoch in tqdm(range(num_epochs), desc='Training'):
    # Train step
    model.train()
    # For all elements on the training set
    for imgs, labels in dataloader_train:
        # Send inputs/labels to the training device
        labels = labels.to(device)
        imgs = imgs.to(device)

        optimizer.zero_grad()

        outputs = model(imgs)

        # dice_loss binarises the outputs with a hard (non-differentiable)
        # threshold, so this term raises the reported loss but gives the model
        # no gradient; loss_fn is what actually drives training.
        dice_val = dice_loss(outputs, labels).to(device)
        loss = loss_fn(outputs, labels) + (1000 * dice_val)

        loss.backward()
        optimizer.step()
        writer.add_scalar('loss/', loss.item(), iteration_count)
        iteration_count += 1

    # Per-epoch logging on the last batch of the epoch; no autograd graph needed.
    with torch.no_grad():
        outputs = outputs.detach()

        # Get Iou, Dice
        iou_val, _ = iou(outputs, labels)
        writer.add_scalar('iou/', iou_val.item(), epoch)
        dice_val = dice_loss(outputs, labels)
        writer.add_scalar('dice/', dice_val.item(), epoch)

        # Learning rate used during this epoch (To display on Tensorboard)
        for param_group in optimizer.param_groups:
            curr_learning_rate = param_group['lr']
            writer.add_scalar('learning_rate/', curr_learning_rate, epoch)

        # Send images to tensorboard
        out_norm = sat_utils.img_minmax_norm_torch(outputs)
        labels_norm = sat_utils.img_minmax_norm_torch(labels)
        imgs_norm = sat_utils.img_minmax_norm_torch(imgs)
        writer.add_images('Image', imgs_norm, epoch)
        if num_classes > 1:
            writer.add_images('out_mask', out_norm[:, 0, :, :].unsqueeze(1), epoch)
            writer.add_images('out_between', out_norm[:, 1, :, :].unsqueeze(1), epoch)
            writer.add_images('out_border', out_norm[:, 2, :, :].unsqueeze(1), epoch)
            if len(labels_norm.shape) > 3:
                writer.add_images('label_mask', labels_norm[:, 0, :, :].unsqueeze(1), epoch)
                writer.add_images('label_between', labels_norm[:, 1, :, :].unsqueeze(1), epoch)
                writer.add_images('label_border', labels_norm[:, 2, :, :].unsqueeze(1), epoch)
            else:
                writer.add_images('label_mask', labels_norm.unsqueeze(1), epoch)
        else:
            writer.add_images('out_mask', out_norm, epoch)
            writer.add_images('label_mask', labels_norm.unsqueeze(1), epoch)

    # StepLR counts epochs: advance it once per epoch, after all optimizer
    # steps. The former per-batch `step(epoch)` call is deprecated and applied
    # each decay one batch late.
    exp_lr_scheduler.step()
    

Training:  32%|███▏      | 158/500 [28:12<58:37, 10.29s/it]  

#### Test Model

In [None]:
@interact(idx_img=widgets.IntSlider(min=0,max=tensor_x.shape[0]-1), th_mask_iteractive=widgets.IntSlider(min=0,max=100), use_threshold = False)
def testModel(idx_img, th_mask_iteractive, use_threshold):
    """Interactively show one sample's prediction next to its label.

    Args:
        idx_img: index into ``tensor_x`` / ``tensor_y``.
        th_mask_iteractive: integer slider value 0..100, used as a percentage
            threshold (value / 100) on the "Subtracted" map when
            ``use_threshold`` is set.
        use_threshold: use the slider threshold instead of the automatic one.
    """
    model.eval()
    with torch.no_grad():
        img = tensor_x[idx_img].unsqueeze(0).to(device)
        pred = model(img)
        label = tensor_y[idx_img].to(device)
        dice_val = dice_loss(pred, label.unsqueeze(0))
        iou_val, th_iou = iou(pred, label.unsqueeze(0))

    img_numpy = sat_utils.img_minmax_norm(img.cpu().squeeze().numpy())
    pred_numpy = pred.cpu().squeeze().numpy()
    # Debug view: AtrousSeg.forward already returns sigmoid outputs, so this
    # applies a second sigmoid on top of them.
    pred_numpy_sig = torch.sigmoid(pred).cpu().squeeze().numpy()
    label_numpy = tensor_y[idx_img].squeeze().numpy()

    print('Dice:', dice_val)
    print('IoU:', iou_val)

    if num_classes > 1:
        mask = pred_numpy[0, :, :]
        # Automatic threshold: midpoint of the mask channel's value range
        th_mask = (np.max(mask) + np.min(mask)) / 2
        # Computed before the branch so it exists for both settings (it used to
        # be set only when use_threshold was False -> UnboundLocalError).
        mask_pred_ts = mask > th_mask

        # Merge Mask and Border: subtract the second channel from the first
        mask_border = mask - pred_numpy[1, :, :]
        if use_threshold:
            # The slider is an integer 0..100, while mask_border is a difference
            # of two sigmoid outputs in (-1, 1); treat the slider as a percentage.
            mask_border_ts = mask_border > (th_mask_iteractive / 100.0)
        else:
            mask_border_ts = mask_border > (th_mask / 2)

        f, axarr = plt.subplots(1, 8, figsize=(15,15))
        axarr[0].imshow(img_numpy[0,:,:]) #4 With 8 channels
        axarr[0].title.set_text('Original')
        axarr[1].imshow(pred_numpy[0,:,:])
        axarr[1].title.set_text('Prediction Mask')
        axarr[2].imshow(pred_numpy[1,:,:])
        axarr[2].title.set_text('Prediction Border')
        axarr[3].imshow(mask_border)
        axarr[3].title.set_text('Subtracted')
        axarr[4].imshow(label_numpy[0,:,:])
        axarr[4].title.set_text('Label')
        axarr[5].imshow(mask_pred_ts)
        axarr[5].title.set_text('Mask Threshold')
        axarr[6].imshow(mask_border_ts)
        axarr[6].title.set_text('Subtracted Threshold')
        axarr[7].imshow(pred_numpy_sig[0,:,:])
        axarr[7].title.set_text('Sigmoid Prediction')
    else:
        f, axarr = plt.subplots(1, 3, figsize=(15,15))
        axarr[0].imshow(img_numpy[0,:,:]) #4 With 8 channels
        axarr[0].title.set_text('Original')
        axarr[1].imshow(pred_numpy[:,:])
        axarr[1].title.set_text('Prediction Mask')
        axarr[2].imshow(label_numpy)
        axarr[2].title.set_text('Label')

#### Quick Test IoU

In [None]:
# Sanity-check the metrics on one sample: prediction vs. label first, then
# each metric on identical inputs. The globals bound here (pred, label, ...)
# are reused by the cells below.
idx_img = 17
model.eval()
with torch.no_grad():
    img = tensor_x[idx_img].unsqueeze(0).to(device)  # add a batch dimension
    pred = model(img)
    label = tensor_y[idx_img].to(device)
    
    # Dice best is value 0 (dice_loss returns 1 - Dice)
    dice_val = dice_loss(pred, label.unsqueeze(0))
    print('Different images(Dice):',dice_val)
    dice_val = dice_loss(label.unsqueeze(0), label.unsqueeze(0))
    print('Equal images(Dice):',dice_val)
    
    # Jaccard best is value 1
    iou_val, _ = iou(pred, label.unsqueeze(0))
    print('Different images(IoU):',iou_val)
    # "Equal images" here compares the prediction with itself
    #iou_val, _ = iou(label.unsqueeze(0), label.unsqueeze(0))
    iou_val, _ = iou(pred, pred)
    print('Equal images(IoU):',iou_val)

In [None]:
# Soft (non-thresholded) IoU per channel between the current prediction and the
# binarised label. `pred` comes from AtrousSeg.forward, which already applies a
# sigmoid, so no second sigmoid is applied here.
jaccard_target = (label.unsqueeze(0)> 0.0).float()
jaccard_output = pred

intersection = (jaccard_output * jaccard_target).sum((2, 3))
# Reduce over the same (rows, cols) axes as the intersection and subtract it,
# so identical maps score 1.
union = jaccard_output.sum((2, 3)) + jaccard_target.sum((2, 3)) - intersection
# Bound to soft_iou (not `iou`) so the iou() metric function is not overwritten.
soft_iou = (intersection + SMOOTH) / (union + SMOOTH)
print(soft_iou)

In [None]:
# Sanity check: soft IoU of the label (rescaled into [0, 1]) against its own
# binarised version; this should be close to 1 when the label is (near-)binary.
jaccard_target = (label.unsqueeze(0)> 0.0).float()
# Negative values clipped by relu, then divided by the label's maximum
jaccard_output = F.relu(label.unsqueeze(0)) / torch.max(label.unsqueeze(0))

intersection = (jaccard_output * jaccard_target).sum()
# Subtract the intersection; without it identical maps only scored ~0.5.
union = jaccard_output.sum() + jaccard_target.sum() - intersection
# Bound to soft_iou (not `iou`) so the iou() metric function is not overwritten.
soft_iou = (intersection + SMOOTH) / (union + SMOOTH)
print(soft_iou)

In [None]:
# Display the current binarised target tensor
jaccard_target

In [None]:
# Display the current value of jaccard_output (the stand-in prediction)
jaccard_output

In [None]:
# Quick check: sigmoid evaluated at 0 (expected 0.5)
F.sigmoid(torch.zeros(1))

In [None]:
# Quick check: tanh evaluated at 0 (expected 0.0)
F.tanh(torch.zeros(1))

In [None]:
# Quick check: tanh evaluated at 1. The multiplier after "*" was missing, which
# made the cell a SyntaxError; multiply by a scale factor here to probe saturation.
F.tanh(torch.ones(1))

In [None]:
# Operator-precedence probe: `*` binds tighter than `&`, so this evaluates
# 1 & (1 * 2) == 1 & 2, a bitwise AND on integer tensors that yields tensor([0]).
torch.ones(1).type(torch.LongTensor) & torch.ones(1).type(torch.LongTensor)*2