# Train Segmentation with Atrous Convolution

#### Loss
This project uses Binary Cross-Entropy (BCE) loss and Dice loss.

#### Metrics
![Segmentation metrics](./imgs_doc/metrics.png "Metrics")

#### References
* https://arxiv.org/pdf/1709.00179.pdf
* https://medium.com/beyondminds/a-simple-guide-to-semantic-segmentation-effcf83e7e54
* https://medium.com/dair-ai/medical-imaging-analysis-mri-cnn-pytorch-4877e64e7303
* https://medium.com/udacity-pytorch-challengers/a-brief-overview-of-loss-functions-in-pytorch-c0ddb78068f7
* https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py
* https://arxiv.org/pdf/1702.03275.pdf
* https://medium.com/syncedreview/facebook-ai-proposes-group-normalization-alternative-to-batch-normalization-fb0699bffae7
* https://www.jeremyjordan.me/evaluating-image-segmentation-models/
* https://github.com/martinkersner/py_img_seg_eval
* https://medicaltorch.readthedocs.io/en/stable/
* https://arxiv.org/pdf/1702.05659.pdf
* https://medium.com/@dmitrijtichonov/debunking-loss-functions-in-deep-learning-4b9abc4c8d4c
* https://github.com/meetshah1995/pytorch-semseg
* https://discuss.pytorch.org/t/leaf-variable-has-been-moved-into-the-graph-interior/18679/9
* https://github.com/EKami/carvana-challenge/tree/original_unet
* https://imgaug.readthedocs.io/en/latest/source/installation.html
* https://medium.com/earthcube-stories/techsecret-how-to-use-deep-learning-on-satellite-imagery-episode-1-playing-with-the-loss-8fc05c90a63a
* https://towardsdatascience.com/understanding-semantic-segmentation-with-unet-6be4f42d4b47

#### Andrew Ng on Accuracy/Precision/Recall
Accuracy is not a meaningful metric when the dataset is imbalanced (skewed). For example, a model that always predicts that nobody has cancer will look very accurate (e.g. 99.999%), but its recall will be zero.
* https://www.youtube.com/watch?v=k1JGvqr56Yk&list=PLLssT5z_DsK-h9vYZkQkYNWcItqhlRJLN&index=66
* https://www.youtube.com/watch?v=wGw6R8AbcuI
* https://www.youtube.com/watch?v=W5meQnGACGo

#### Distributed Training in PyTorch
* https://pytorch.org/tutorials/intermediate/dist_tuto.html
* https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
* https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
* https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html
* https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255

In [1]:
import sat_utils
import seg_loss
import seg_metrics
import seg_models

import numpy as np
import pickle
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from random import randint
from tqdm import tqdm

# Pytorch stuff
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import torch
import torch.utils.data as utils
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torch.optim import lr_scheduler

# Warm-Up Scheduler
# https://github.com/ildoonet/pytorch-gradual-warmup-lr
from warmup_scheduler import GradualWarmupScheduler

# Library for augmentations on batch of numpy/tensors
from imgaug import augmenters as iaa

# ---- Hardware ----
# Use the first CUDA device when one is present, otherwise the CPU.
# To force CPU execution, override with: device = 'cpu'
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('Device:', device)

num_gpu = torch.cuda.device_count()
print('Number of GPUs Available:', num_gpu)

# ---- Hyper-parameters ----
lr = 6.521e-2         # other values noted: 0.001, 0.0001 (good with Dice), 0.007
warm_up_epochs = 40   # length of the gradual LR warm-up, in epochs
l2_norm = 1e-7        # weight decay handed to the optimizer
gamma = 0.1           # decay factor for the step LR scheduler
batch_size = 128      # smaller values noted: 32, 20
num_epochs = 500
step_size = 200       # epochs between step-scheduler decays

Device: cuda:0
Number of GPUs Available: 8


#### Load data from pickle (not scalable) and create the data loaders

In [2]:
# Read the raw samples; both pickles are read as mappings of per-sample arrays.
raw_inputs = sat_utils.read_pickle_data('./data/input.pickle')
raw_labels = sat_utils.read_pickle_data('./data/label.pickle')

# Stack the mapping values into dense arrays: inputs go through get_rgb and
# labels are divided by 255 so they lie in [0, 1].
X = np.stack(list(map(sat_utils.get_rgb, raw_inputs.values())))
Y = np.stack([label / 255.0 for label in raw_labels.values()])
del raw_inputs, raw_labels  # drop the mappings once they are stacked

# Hold out 10% of the samples for validation (fixed seed for reproducibility).
X_t, X_v, Y_t, Y_v = train_test_split(X, Y, test_size=1/10, random_state=42)
for split_name, split_array in (('X_t:', X_t), ('Y_t:', Y_t), ('X_v:', X_v), ('Y_v:', Y_v)):
    print(split_name, split_array.shape)

# Labels keep their (N, C, H, W) layout with one channel per class, as
# BCEWithLogitsLoss expects. (CrossEntropyLoss would instead need (N, H, W)
# class-index labels.)
tensor_x_t, tensor_y_t, tensor_x_v, tensor_y_v = [
    torch.Tensor(arr) for arr in (X_t, Y_t, X_v, Y_v)
]

# Random flips, applied to each training batch inside the training loop.
augmenters = [
    iaa.Fliplr(0.5),  # flip 50% of the images horizontally
    iaa.Flipud(0.5),  # flip 50% of the images vertically
    # iaa.Affine(rotate=(-10, 10)),  # optional small rotation
]
seq_augm = iaa.Sequential(augmenters)

dataset_train = utils.TensorDataset(tensor_x_t, tensor_y_t)
dataset_val = utils.TensorDataset(tensor_x_v, tensor_y_v)
# Shuffle only the training set; the validation order stays fixed.
dataloader_train = utils.DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
dataloader_val = utils.DataLoader(dataset_val, batch_size=batch_size, shuffle=False)

X_t: (2038, 3, 76, 76)
Y_t: (2038, 3, 76, 76)
X_v: (227, 3, 76, 76)
Y_v: (227, 3, 76, 76)


In [3]:
print('Input:', tensor_x_t.shape)
print('Label:', tensor_y_t.shape)

# Labels are (N, C, H, W) with one channel per class, so C is the class count.
num_classes = tensor_y_t.size(1)
batch_val = tensor_y_v.size(0)  # number of validation samples
print('num_classes:', num_classes)

# Sanity check of the label range (expected to be within [0, 1]).
label_max = torch.max(tensor_y_t).item()
label_min = torch.min(tensor_y_t).item()
print('Max val on label:', label_max)
print('Min val on label:', label_min)

Input: torch.Size([2038, 3, 76, 76])
Label: torch.Size([2038, 3, 76, 76])
num_classes: 3
Max val on label: 1.0
Min val on label: 0.0


#### Define Model

In [4]:
# Build the atrous (dilated-convolution) segmentation network for
# `num_classes` classes, with the input channel count taken from the data.
in_channels = tensor_x_t.shape[1]
model = seg_models.AtrousSeg(num_classes=num_classes, num_channels=in_channels)
# Quick smoke test:
# resp = model(torch.rand(1, in_channels, 76, 76))

#### Start Tensorboard Interface

In [5]:
# TensorBoard logs go to the default "runs/" directory
# (use SummaryWriter('./logs') to write elsewhere).
writer = SummaryWriter()
# Trace the model graph with a random single-sample batch. Its (C, H, W) is
# taken from the training tensors rather than a hard-coded 76x76, so the trace
# stays valid if the tile size or channel count changes.
dummy_x = torch.rand(1, *tensor_x_t.shape[1:])
writer.add_graph(model, dummy_x)



#### Distribute model on available GPUs
In this case we use DataParallel, which replicates the model on every available GPU and splits each batch among them.

In [6]:
# Replicate the model across GPUs only when more than one is visible;
# in every case, move the (possibly wrapped) model to `device`.
use_data_parallel = num_gpu > 1
if use_data_parallel:
    print("Let's use", num_gpu, "GPUs!")
    model = nn.DataParallel(model)
model.to(device)

Let's use 8 GPUs!


DataParallel(
  (module): AtrousSeg(
    (model): Sequential(
      (0): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (2): ReLU()
      (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): ReLU()
      (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2))
      (8): ReLU()
      (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2))
      (11): ReLU()
      (12): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (13): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), dilation=(3, 3))
      (14): Re

#### Initialize Losses and Optimizers

In [7]:
# Loss for multi-label segmentation: labels hold one channel per class, so the
# prediction and the target are both (batch, n_class, rows, cols), with target
# values in [0, 1]. BCEWithLogitsLoss applies the sigmoid internally, so the
# model is expected to output raw logits (no final sigmoid).
# Alternative: CrossEntropyLoss takes predictions (batch, n_class, rows, cols)
# and labels (batch, rows, cols) holding class indices.
#loss_fn = nn.CrossEntropyLoss()
loss_fn = nn.BCEWithLogitsLoss()

# Regression losses also accept a prediction and a label of identical shape
# (batch, channels, rows, cols):
#loss_fn = nn.MSELoss()
#loss_fn = nn.SmoothL1Loss()

# Adam with L2 weight decay.
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_norm)

# Step decay: multiply the LR by `gamma` every `step_size` epochs.
# NOTE: this only takes effect if exp_lr_scheduler.step() is called once per epoch.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

# Reduce the LR (by the default factor of 0.1) when the metric passed to
# scheduler_plateau.step() has not improved for `patience` epochs.
scheduler_plateau = lr_scheduler.ReduceLROnPlateau(optimizer, patience=4, verbose=True)

# Optional gradual warm-up over `warm_up_epochs` epochs, handing over to
# exp_lr_scheduler afterwards:
#scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch=warm_up_epochs, after_scheduler=exp_lr_scheduler)

#### Train Model

In [None]:
import os

# A checkpoint is written at the end of every epoch. Create the directory up
# front so torch.save does not fail after the first (long) epoch has finished.
os.makedirs('./model_save', exist_ok=True)

iteration_count = 0      # global step for per-batch training scalars
iteration_count_val = 0  # global step for per-batch validation scalars
for epoch in tqdm(range(num_epochs), desc='Training'):
    # ---------------- Train step ----------------
    model.train()
    # Sum of the batch losses for this epoch; fed to the plateau scheduler below.
    running_loss = 0.0
    for imgs, labels in dataloader_train:
        # imgaug works on numpy arrays laid out as (batch, rows, cols, channels),
        # so move the channel axis last. Passing the labels as `heatmaps` in the
        # same call makes imgaug apply the same random flips to them.
        imgs_aug, labels_aug = seq_augm(images=np.moveaxis(imgs.numpy(), 1, 3),
                                        heatmaps=np.moveaxis(labels.numpy(), 1, 3))

        # Move the channel axis back to (batch, channels, rows, cols), convert
        # back to tensors and send them to the training device.
        imgs = torch.from_numpy(np.moveaxis(imgs_aug, 3, 1)).to(device)
        labels = torch.from_numpy(np.moveaxis(labels_aug, 3, 1)).to(device)

        optimizer.zero_grad()
        outputs = model(imgs)

        # Optimised objective: base criterion (loss_fn) + Dice loss. The IoU
        # loss is computed for logging; other combinations (e.g. Dice or IoU
        # alone) can be swapped in for `loss`.
        iou_value = seg_loss.iou_loss(outputs, labels)
        dice_value = seg_loss.dice_loss(outputs, labels)
        criterion = loss_fn(outputs, labels)
        loss = criterion + dice_value
        loss_scalar = loss.item()
        running_loss += loss_scalar

        loss.backward()
        optimizer.step()

        writer.add_scalar('loss/', loss_scalar, iteration_count)
        writer.add_scalar('base_loss/', criterion.item(), iteration_count)
        writer.add_scalar('iou_loss/', iou_value.item(), iteration_count)
        writer.add_scalar('dice_loss/', dice_value.item(), iteration_count)
        iteration_count += 1

    # Log the current learning rate of every parameter group once per epoch.
    for param_group in optimizer.param_groups:
        curr_learning_rate = param_group['lr']
        writer.add_scalar('learning_rate/', curr_learning_rate, epoch)

    # Send the last training batch to TensorBoard. `outputs` still carries the
    # autograd graph, so detach it before min-max normalising it for display.
    out_norm = sat_utils.img_minmax_norm_torch(outputs.detach())
    labels_norm = sat_utils.img_minmax_norm_torch(labels)
    imgs_norm = sat_utils.img_minmax_norm_torch(imgs)
    writer.add_images('img_t', imgs_norm, epoch)
    writer.add_images('out_mask_t', out_norm[:, 0, :, :].unsqueeze(1), epoch)
    writer.add_images('out_sep_t', out_norm[:, 1, :, :].unsqueeze(1), epoch)
    writer.add_images('out_border_t', out_norm[:, 2, :, :].unsqueeze(1), epoch)
    writer.add_images('label_mask_t', labels_norm[:, 0, :, :].unsqueeze(1), epoch)
    writer.add_images('label_sep_t', labels_norm[:, 1, :, :].unsqueeze(1), epoch)
    writer.add_images('label_border_t', labels_norm[:, 2, :, :].unsqueeze(1), epoch)

    # ---------------- Validation ----------------
    model.eval()
    with torch.no_grad():
        for imgs, labels in dataloader_val:
            labels = labels.to(device)
            imgs = imgs.to(device)
            outputs = model(imgs)
            iou_value = seg_metrics.iou(outputs, labels)
            dice_value = seg_metrics.dice(outputs, labels)
            writer.add_scalar('iou_val/', iou_value.item(), iteration_count_val)
            writer.add_scalar('dice_val/', dice_value.item(), iteration_count_val)
            iteration_count_val += 1

    # Send the last validation batch to TensorBoard.
    out_norm = sat_utils.img_minmax_norm_torch(outputs)
    labels_norm = sat_utils.img_minmax_norm_torch(labels)
    imgs_norm = sat_utils.img_minmax_norm_torch(imgs)
    writer.add_images('img_v', imgs_norm, epoch)
    writer.add_images('out_mask_v', out_norm[:, 0, :, :].unsqueeze(1), epoch)
    writer.add_images('label_mask_v', labels_norm[:, 0, :, :].unsqueeze(1), epoch)

    # Save the whole (possibly DataParallel-wrapped) model for this epoch.
    torch.save(model, './model_save/model_'+str(epoch)+'.cpkt')

    # Learning-rate decay: the plateau scheduler watches the summed training
    # loss. The step / warm-up schedulers are alternatives kept for reference.
    #exp_lr_scheduler.step(epoch)
    #scheduler_warmup.step()
    scheduler_plateau.step(running_loss)

Training:   3%|▎         | 13/500 [11:19<6:49:52, 50.50s/it]

Epoch    12: reducing learning rate of group 0 to 6.5210e-03.


Training:   4%|▎         | 18/500 [15:28<6:45:02, 50.42s/it]

Epoch    17: reducing learning rate of group 0 to 6.5210e-04.


Training:   5%|▌         | 25/500 [21:10<6:28:14, 49.04s/it]

Epoch    24: reducing learning rate of group 0 to 6.5210e-05.


Training:   6%|▌         | 29/500 [24:27<6:27:37, 49.38s/it]

Epoch    28: reducing learning rate of group 0 to 6.5210e-06.


Training:   7%|▋         | 33/500 [27:45<6:26:56, 49.71s/it]

Epoch    32: reducing learning rate of group 0 to 6.5210e-07.


Training:   7%|▋         | 37/500 [31:01<6:21:26, 49.43s/it]

Epoch    36: reducing learning rate of group 0 to 6.5210e-08.


Training:   8%|▊         | 41/500 [34:16<6:13:07, 48.77s/it]

Epoch    40: reducing learning rate of group 0 to 6.5210e-09.


Training:   9%|▉         | 45/500 [37:34<6:11:35, 49.00s/it]

#### Load Specific Model

In [None]:
# Load the checkpoint written after the final training epoch
# (file name pattern: ./model_save/model_<epoch>.cpkt).
# torch.save stored the whole module, so torch.load unpickles arbitrary
# objects: only load checkpoint files you produced yourself.
# map_location remaps the stored tensors onto the current `device`.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects full-module pickles; pass weights_only=False there if loading fails.
final_checkpoint = './model_save/model_' + str(num_epochs - 1) + '.cpkt'
model = torch.load(final_checkpoint, map_location=device)
model.eval()

#### Test Model

In [None]:
@interact(idx_img=widgets.IntSlider(min=0,max=tensor_x_t.shape[0]-1), th_mask_iteractive=widgets.IntSlider(min=0,max=100), use_threshold = False)
def testModel(idx_img, th_mask_iteractive=50, use_threshold=False):
    """Run the model on one training sample, print its scores and plot it.

    ``interact`` passes every widget declared in the decorator to the function
    as a keyword argument, so all three must be accepted here.

    Args:
        idx_img: index of the sample in ``tensor_x_t`` / ``tensor_y_t``.
        th_mask_iteractive: threshold in percent (0-100) applied to the
            sigmoid of the prediction when ``use_threshold`` is set.
        use_threshold: when True, plot the binarised sigmoid prediction;
            otherwise plot the raw model output.
    """
    model.eval()
    with torch.no_grad():
        # Add a batch dimension of 1 to the sample and its label.
        img = tensor_x_t[idx_img].unsqueeze(0).to(device)
        pred = model(img)
        label = tensor_y_t[idx_img].unsqueeze(0).to(device)
        dice_value = seg_metrics.dice(pred, label)
        iou_value = seg_metrics.iou(pred, label)
        dice_loss_value = seg_loss.dice_loss(pred, label)
        iou_loss_value = seg_loss.iou_loss(pred, label)

        if use_threshold:
            # Binarise the per-pixel sigmoid probabilities at the chosen threshold.
            pred = (torch.sigmoid(pred) > th_mask_iteractive / 100.0).float()

    # Min-max normalise the input; imshow needs channels last, (C, H, W) -> (H, W, C).
    img_numpy = sat_utils.img_minmax_norm(img.cpu().squeeze().numpy())
    img_numpy_m = np.moveaxis(img_numpy, 0, 2)

    pred_numpy = pred.cpu().squeeze().numpy()
    label_numpy = tensor_y_t[idx_img].squeeze().numpy()

    print('Dice Val:', dice_value)
    print('IoU Val:', iou_value)
    print('Dice Loss:', dice_loss_value)
    print('IoU Loss:', iou_loss_value)

    if num_classes > 1:
        # Prediction channel 0 minus channel 1.
        # NOTE(review): the training loop logs channel 1 as "sep" and channel 2
        # as "border", but the titles below call channel 1 "Border" — confirm.
        mask_border = pred_numpy[0,:,:] - pred_numpy[1,:,:]

        f, axarr = plt.subplots(1, 5, figsize=(15,15))
        axarr[0].imshow(img_numpy_m) #4 With 8 channels
        axarr[0].title.set_text('Original')
        axarr[1].imshow(pred_numpy[0,:,:])
        axarr[1].title.set_text('Prediction Mask')
        axarr[2].imshow(pred_numpy[1,:,:])
        axarr[2].title.set_text('Prediction Border')
        axarr[3].imshow(mask_border)
        axarr[3].title.set_text('Subtracted')
        axarr[4].imshow(label_numpy[0,:,:])
        axarr[4].title.set_text('Label')
    else:
        f, axarr = plt.subplots(1, 3, figsize=(15,15))
        axarr[0].imshow(img_numpy[0,:,:]) #4 With 8 channels
        axarr[0].title.set_text('Original')
        axarr[1].imshow(pred_numpy[:,:])
        axarr[1].title.set_text('Prediction Mask')
        axarr[2].imshow(label_numpy)
        axarr[2].title.set_text('Label')