# MRI NIH : UNet

To visualize the training run in TensorBoard:

tensorboard --logdir runs

## Imports

In [1]:
import torch
from dataset import MRI2DSegDataset
import transforms
import json
from torchvision import transforms as torch_transforms
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

## Hyperparameters

In [2]:
# Load the experiment configuration (the "training", "transforms" and "net"
# sections printed below). A context manager is used so the file handle is
# closed as soon as the JSON has been parsed.
with open('/Users/frpau_local/Documents/nih/data/luisa_with_gt/parameters.json') as parameters_file:
    parameters = json.load(parameters_file)
# A parenthesized single-argument print behaves the same on Python 2 and 3.
print(json.dumps(parameters, indent=4))

{
    "training": {
        "learning_rate": 0.0001, 
        "optimizer": "sgd", 
        "loss_function": "dice", 
        "batch_size": 2, 
        "sgd_momentum": 0.9
    }, 
    "transforms": {
        "flip_rate": 0.5, 
        "ratio_range": [
            0.75, 
            1.25
        ], 
        "elastic_rate": 0.3, 
        "sigma_range": [
            3.5, 
            4
        ], 
        "alpha_range": [
            10, 
            15
        ], 
        "crop_size": [
            224, 
            128
        ], 
        "max_angle": 20, 
        "scale_range": [
            0.5, 
            1
        ]
    }, 
    "net": {
        "drop_rate": 0.01, 
        "bn_momentum": 0.1
    }
}


## Create dataset

In [3]:
# All augmentation settings come from the "transforms" section of the config.
transform_params = parameters["transforms"]

# Individual augmentation steps (project-local `transforms` module).
toTensor = transforms.ToTensor()
toPIL = transforms.ToPIL()
randomVFlip = transforms.RandomVerticalFlip()
randomResizedCrop = transforms.RandomResizedCrop(
    transform_params["crop_size"],
    scale=transform_params["scale_range"],
    ratio=transform_params["ratio_range"]
)
randomRotation = transforms.RandomRotation(transform_params["max_angle"])
elasticTransform = transforms.ElasticTransform(
    transform_params["alpha_range"],
    transform_params["sigma_range"],
    transform_params["elastic_rate"]
)

# Pipeline order: to PIL -> flip -> rotation -> resized crop -> elastic -> to tensor.
augmentation_steps = [
    toPIL,
    randomVFlip,
    randomRotation,
    randomResizedCrop,
    elasticTransform,
    toTensor,
]
composed = torch_transforms.Compose(augmentation_steps)

dataset = MRI2DSegDataset(
    "/Users/frpau_local/Documents/nih/data/luisa_with_gt/filenames_csf_gm_nawm.txt",
    transform=composed
)

  from ._conv import register_converters as _register_converters


## Define net

In [4]:
class DownConv(nn.Module):
    """Two stacked 3x3 convolution stages that keep the spatial size.

    Each stage runs Conv2d (kernel 3, padding 1) -> ReLU -> BatchNorm2d ->
    Dropout2d, in that order.

    Args:
        in_feat: number of input channels of the first convolution.
        out_feat: number of output channels of both convolutions.
        drop_rate: probability passed to both Dropout2d layers.
        bn_momentum: momentum passed to both BatchNorm2d layers.
    """

    def __init__(self, in_feat, out_feat, drop_rate=0.4, bn_momentum=0.1):
        super(DownConv, self).__init__()

        # Stage 1: in_feat -> out_feat
        self.conv1 = nn.Conv2d(in_feat, out_feat, kernel_size=3, padding=1)
        self.conv1_bn = nn.BatchNorm2d(out_feat, momentum=bn_momentum)
        self.conv1_drop = nn.Dropout2d(drop_rate)

        # Stage 2: out_feat -> out_feat
        self.conv2 = nn.Conv2d(out_feat, out_feat, kernel_size=3, padding=1)
        self.conv2_bn = nn.BatchNorm2d(out_feat, momentum=bn_momentum)
        self.conv2_drop = nn.Dropout2d(drop_rate)

    def forward(self, x):
        """Apply both stages to ``x`` and return the resulting feature map."""
        stages = (
            (self.conv1, self.conv1_bn, self.conv1_drop),
            (self.conv2, self.conv2_bn, self.conv2_drop),
        )
        features = x
        for conv, batch_norm, dropout in stages:
            # The activation is applied before the normalization here.
            features = dropout(batch_norm(F.relu(conv(features))))
        return features
    
class UpConv(nn.Module):
    """Decoder step: upsample, concatenate a skip connection, then convolve.

    Args:
        in_feat: channel count fed to the inner DownConv, i.e. the channels of
            the upsampled input plus the channels of the skip tensor.
        out_feat: channel count produced by the inner DownConv.
        drop_rate: dropout probability forwarded to DownConv.
        bn_momentum: batch-norm momentum forwarded to DownConv.
    """

    def __init__(self, in_feat, out_feat, drop_rate=0.4, bn_momentum=0.1):
        super(UpConv, self).__init__()
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.downconv = DownConv(in_feat, out_feat, drop_rate, bn_momentum)

    def forward(self, x, y):
        """Upsample ``x`` by 2, stack it with ``y`` on the channel axis and convolve.

        ``y`` must have the same batch and spatial size as the upsampled ``x``
        for the concatenation along dim 1 to succeed.
        """
        upsampled = self.up1(x)
        merged = torch.cat((upsampled, y), dim=1)
        return self.downconv(merged)

class UNet(nn.Module):
    """Three-level U-Net producing per-pixel, per-channel sigmoid scores.

    The encoder applies three DownConv blocks (64, 128, 256 channels), each
    followed by a 2x2 max-pooling; a 256-channel DownConv sits at the bottom;
    the decoder applies three UpConv blocks that concatenate the matching
    encoder output. Because of the three poolings, the input height and width
    need to be divisible by 8 for the skip concatenations to line up.

    Args:
        drop_rate: dropout probability forwarded to every DownConv/UpConv.
        bn_momentum: batch-norm momentum forwarded to every DownConv/UpConv.
        in_channels: number of channels of the input image (default 1, the
            value that used to be hard-coded).
        n_classes: number of output channels (default 4, the value that used
            to be hard-coded).
    """

    def __init__(self, drop_rate=0.4, bn_momentum=0.1, in_channels=1, n_classes=4):
        super(UNet, self).__init__()

        # Downsampling path
        self.conv1 = DownConv(in_channels, 64, drop_rate, bn_momentum)
        self.mp1 = nn.MaxPool2d(2)

        self.conv2 = DownConv(64, 128, drop_rate, bn_momentum)
        self.mp2 = nn.MaxPool2d(2)

        self.conv3 = DownConv(128, 256, drop_rate, bn_momentum)
        self.mp3 = nn.MaxPool2d(2)

        # Bottom
        self.conv4 = DownConv(256, 256, drop_rate, bn_momentum)

        # Upsampling path; input channels = upsampled features + skip features
        self.up1 = UpConv(512, 256, drop_rate, bn_momentum)  # 256 + 256
        self.up2 = UpConv(384, 128, drop_rate, bn_momentum)  # 256 + 128
        self.up3 = UpConv(192, 64, drop_rate, bn_momentum)   # 128 + 64

        # Final 3x3 convolution mapping the 64 features to the output channels
        self.conv9 = nn.Conv2d(64, n_classes, kernel_size=3, padding=1)

    def forward(self, x):
        """Return sigmoid activations with one channel per output class."""
        x1 = self.conv1(x)
        x2 = self.mp1(x1)

        x3 = self.conv2(x2)
        x4 = self.mp2(x3)

        x5 = self.conv3(x4)
        x6 = self.mp3(x5)

        # Bottom
        x7 = self.conv4(x6)

        # Up-sampling, each step reusing the encoder output of the same scale
        x8 = self.up1(x7, x5)
        x9 = self.up2(x8, x3)
        x10 = self.up3(x9, x1)

        x11 = self.conv9(x10)
        # torch.sigmoid replaces the deprecated F.sigmoid (same computation).
        preds = torch.sigmoid(x11)

        return preds


## Dice loss

In [5]:
def get_bg_gt(gts):
    """Build the background mask as what is left after removing every class mask.

    Starting from a tensor of ones, each ground-truth mask is subtracted in
    turn and the result is clipped at zero, so overlapping masks never produce
    negative values.

    Args:
        gts: non-empty sequence of ground-truth tensors. Sizes 0, 2 and 3 of
            the first one are read, so 4-D tensors are expected (presumably
            (batch, 1, height, width) -- confirm against the dataset).

    Returns:
        Tensor of shape (batch, 1, height, width) holding the background mask,
        allocated on the same device as ``gts[0]``.
    """
    gt_size = gts[0].size()
    # Allocate next to the ground truth so the function also works off-CPU.
    bg_gt = torch.ones([gt_size[0], 1, gt_size[2], gt_size[3]], device=gts[0].device)
    for gt in gts:
        # Equivalent to max(bg_gt - gt, 0) without allocating a zeros tensor.
        bg_gt = torch.clamp(bg_gt - gt, min=0)
    return bg_gt

def dice(pred, gt):
    """Return the negative soft Dice coefficient between ``pred`` and ``gt``.

    Computes ``-(2 * sum(pred * gt) + eps) / (sum(pred) + sum(gt) + eps)`` over
    all elements, so a perfect match gives about -1 and minimizing the value
    maximizes the overlap.

    The intersection is computed on ``pred`` itself rather than ``pred.data``:
    going through ``.data`` detached the numerator from autograd, leaving only
    the denominator to contribute gradients.

    Args:
        pred: tensor of predicted scores.
        gt: ground-truth tensor with the same number of elements as ``pred``.

    Returns:
        Zero-dimensional tensor holding the negative Dice score.
    """
    eps = 0.0000000001  # keeps the ratio defined when both inputs are all zero
    pred_flat = pred.contiguous().view(-1)
    gt_flat = gt.contiguous().view(-1)
    intersection = (pred_flat * gt_flat).sum()
    return -(2 * intersection + eps) / (pred_flat.sum() + gt_flat.sum() + eps)

def dice_loss(pred, gts):
    """Sum the per-channel Dice terms for the background and every class.

    Channel 0 of ``pred`` is scored against the background mask derived from
    ``gts``; channel ``k`` (k >= 1) is scored against ``gts[k - 1]``.

    Args:
        pred: network output indexed as (batch, channel, height, width).
        gts: sequence of ground-truth masks, one per foreground class.

    Returns:
        The sum of the ``dice`` values over all channels.
    """
    total = dice(pred[:, 0, :, :], get_bg_gt(gts))
    for channel, class_gt in enumerate(gts, 1):
        total = total + dice(pred[:, channel, :, :], class_gt)
    return total

## Metrics

In [6]:
def numeric_score(pred, gts):
    """Computation of statistical numerical scores, one entry per ground truth:

    * FP = False Positives
    * FN = False Negatives
    * TP = True Positives
    * TN = True Negatives

    ``pred`` holds integer labels where 0 is the background and label ``k``
    (k >= 1) corresponds to ``gts[k - 1]``. Each ground-truth mask is therefore
    compared against label ``index + 1``; comparing it against label ``index``
    scored every class against its neighbour's predictions.

    Args:
        pred: tensor of predicted labels, broadcastable against each mask.
        gts: sequence of binary ground-truth tensors (values 0 and 1), one per
            foreground class.

    return: tuple (FP, FN, TP, TN) of lists of floats, each of length len(gts)
    """
    np_pred = pred.numpy()
    FP = []
    FN = []
    TP = []
    TN = []
    for label, gt in enumerate(gts, 1):
        np_gt = gt.numpy()
        predicted = np_pred == label
        positive = np_gt == 1
        negative = np_gt == 0
        # Built-in float replaces np.float, which recent NumPy no longer has.
        FP.append(float(np.sum(predicted & negative)))
        FN.append(float(np.sum(~predicted & positive)))
        TP.append(float(np.sum(predicted & positive)))
        TN.append(float(np.sum(~predicted & negative)))
    return FP, FN, TP, TN


def precision_score(FP, FN, TP, TN):
    """Per-class precision (positive predictive value) as a percentage.

    Computes ``TP / (TP + FP) * 100`` for every index of ``FP``; a class with
    no predicted positives scores 0.0. ``FN`` and ``TN`` are accepted only to
    keep the same signature as the other metric functions.
    """
    # PPV
    ppv = []
    for idx, false_pos in enumerate(FP):
        true_pos = TP[idx]
        predicted_pos = true_pos + false_pos
        ppv.append(
            0.0 if predicted_pos <= 0.0
            else np.divide(true_pos, predicted_pos) * 100.0
        )
    return ppv


def recall_score(FP, FN, TP, TN):
    """Per-class recall (true positive rate, sensitivity) as a percentage.

    Computes ``TP / (TP + FN) * 100`` for every index of ``FP``; a class with
    no actual positives scores 0.0. ``TN`` is not used.
    """
    # TPR, sensitivity
    sensitivity = []
    for idx in range(len(FP)):
        true_pos = TP[idx]
        actual_pos = true_pos + FN[idx]
        sensitivity.append(
            0.0 if actual_pos <= 0.0
            else np.divide(true_pos, actual_pos) * 100.0
        )
    return sensitivity


def specificity_score(FP, FN, TP, TN):
    """Per-class specificity (true negative rate) as a percentage.

    Computes ``TN / (TN + FP) * 100`` for every index of ``FP``; a class with
    no actual negatives scores 0.0. ``FN`` and ``TP`` are not used.
    """
    tnr = []
    for idx, false_pos in enumerate(FP):
        true_neg = TN[idx]
        actual_neg = true_neg + false_pos
        tnr.append(
            0.0 if actual_neg <= 0.0
            else np.divide(true_neg, actual_neg) * 100.0
        )
    return tnr


def intersection_over_union(FP, FN, TP, TN):
    """Per-class intersection over union (Jaccard index) as a percentage.

    Computes ``TP / (TP + FP + FN) * 100`` for every index of ``FP``; a class
    whose union is empty scores 0.0. ``TN`` is not used.
    """
    jaccard = []
    for idx, false_pos in enumerate(FP):
        true_pos = TP[idx]
        union = true_pos + false_pos + FN[idx]
        if union <= 0.0:
            jaccard.append(0.0)
            continue
        jaccard.append(true_pos / union * 100.0)
    return jaccard


def accuracy_score(FP, FN, TP, TN):
    """Per-class accuracy as a percentage.

    Computes ``(TP + TN) / (FP + FN + TP + TN) * 100`` for every index of
    ``FP``. When all four counts are zero the score is 0.0, matching the
    zero-denominator convention of the other metric functions instead of
    producing NaN.
    """
    accuracy = []
    for i in range(len(FP)):
        N = FP[i] + FN[i] + TP[i] + TN[i]
        if N <= 0.0:
            # No samples at all for this class: avoid a 0/0 division.
            accuracy.append(0.0)
        else:
            accuracy.append(np.divide(TP[i] + TN[i], N) * 100.0)
    return accuracy

In [7]:
# Build the model, the batch loader, the optimizer and the loss from the config.
net = UNet(drop_rate=parameters["net"]["drop_rate"], bn_momentum=parameters["net"]["bn_momentum"])
dataloader = DataLoader(dataset, batch_size=parameters["training"]["batch_size"], shuffle=True, num_workers=4)

# Fail immediately on an unsupported choice instead of leaving `optimizer` or
# `loss_function` undefined and hitting a NameError inside the training loop.
if parameters["training"]["optimizer"]=="sgd":
    optimizer = optim.SGD(net.parameters(), lr=parameters["training"]['learning_rate'], momentum=parameters["training"]['sgd_momentum'])
else:
    raise ValueError("Unsupported optimizer: %r" % (parameters["training"]["optimizer"],))

if parameters["training"]["loss_function"]=="dice":
    loss_function = dice_loss
else:
    raise ValueError("Unsupported loss function: %r" % (parameters["training"]["loss_function"],))

## Training

In [None]:
# Train for 30 epochs and log loss, metrics, images and weight histograms to
# TensorBoard every 5 batches.
writer = SummaryWriter()
try:
    # add hyperparameters to description
    writer.add_text("hyperparameters", json.dumps(parameters))

    for epoch in tqdm(range(30)):

        loss_agg = 0.

        for i_batch, sample_batched in enumerate(dataloader):
            # Reset gradients from the previous step; backward() adds to
            # .grad, so without this every update would use the sum of all
            # gradients seen so far.
            optimizer.zero_grad()
            output = net(sample_batched['input'])
            loss = loss_function(output, sample_batched['gt'])
            loss.backward()
            optimizer.step()
            loss_agg += loss.item()

            # Visualization
            if not i_batch%5 and i_batch>0:
                n_iter = epoch*len(dataloader)+i_batch

                # loss
                # NOTE(review): the first window of an epoch covers batches
                # 0-5 (six batches) but is still divided by 5 -- confirm intent.
                writer.add_scalar("loss_"+parameters["training"]["loss_function"], loss_agg/5, n_iter)
                loss_agg = 0.

                # Logging works on a detached view so nothing here touches the
                # autograd graph; the argmax is only needed when logging.
                output_vis = output.detach()
                predictions = torch.argmax(output_vis, 1, keepdim=True)

                # metrics
                FP, FN, TP, TN = numeric_score(predictions, sample_batched['gt'])
                precision = precision_score(FP, FN, TP, TN)
                recall = recall_score(FP, FN, TP, TN)
                specificity = specificity_score(FP, FN, TP, TN)
                iou = intersection_over_union(FP, FN, TP, TN)
                accuracy = accuracy_score(FP, FN, TP, TN)
                for i in range(len(sample_batched['gt'])):
                    writer.add_scalar("precision_"+str(i), precision[i], n_iter)
                    writer.add_scalar("recall_"+str(i), recall[i], n_iter)
                    writer.add_scalar("specificity_"+str(i), specificity[i], n_iter)
                    writer.add_scalar("intersection_over_union_"+str(i), iou[i], n_iter)
                    writer.add_scalar("accuracy_"+str(i), accuracy[i], n_iter)

                #images (first sample of the batch only)
                input_image = vutils.make_grid(sample_batched['input'][0]/torch.max(sample_batched['input'][0]), normalize=True, scale_each=True)
                writer.add_image('Input image', input_image, n_iter)
                output_bg = vutils.make_grid(output_vis[0,0,::,::], normalize=True, scale_each=True)
                # Comparison masks are cast to float because make_grid's
                # normalization does in-place float arithmetic.
                pred_bg = vutils.make_grid((predictions[0,0,::,::]==0).float(), normalize=True, scale_each=True)
                writer.add_image('Output background', output_bg, n_iter)
                writer.add_image('Prediction background', pred_bg, n_iter)
                for i in range(len(sample_batched['gt'])):
                    output_image = vutils.make_grid(output_vis[0,i+1,::,::], normalize=True, scale_each=True)
                    pred_image = vutils.make_grid((predictions[0,0,::,::]==i+1).float(), normalize=True, scale_each=True)
                    writer.add_image('Output class '+str(i+1), output_image, n_iter)
                    writer.add_image('Prediction class '+str(i+1), pred_image, n_iter)
                    gt_image = vutils.make_grid(sample_batched['gt'][i][0,::,::], normalize=True, scale_each=True)
                    writer.add_image('gt class '+str(i+1), gt_image, n_iter)
                bg_gt = get_bg_gt(sample_batched['gt'])
                bg_gt_image = vutils.make_grid(bg_gt[0,::,::], normalize=True, scale_each=True)
                writer.add_image('gt background', bg_gt_image, n_iter)

                # net parameters histograms
                for name, param in net.named_parameters():
                    writer.add_histogram(name, param.clone().cpu().data.numpy(), n_iter)

    writer.export_scalars_to_json("./all_scalars.json")
finally:
    # Close the writer even when training is interrupted or raises.
    writer.close()

print("training complete")
    



  0%|          | 0/30 [00:00<?, ?it/s][A[A
[A

  3%|▎         | 1/30 [01:50<53:13, 110.13s/it][A[A

  7%|▋         | 2/30 [03:41<51:37, 110.63s/it][A[A

 10%|█         | 3/30 [05:30<49:26, 109.88s/it][A[A

 13%|█▎        | 4/30 [07:19<47:35, 109.84s/it][A[A

 17%|█▋        | 5/30 [09:10<45:51, 110.06s/it][A[A

 20%|██        | 6/30 [11:01<44:10, 110.43s/it][A[A

 23%|██▎       | 7/30 [12:52<42:23, 110.60s/it][A[A

In [59]:
# Close the TensorBoard SummaryWriter created in the training cell.
# NOTE(review): presumably run by hand after interrupting training (the
# progress output above stops at 7/30 epochs) -- confirm.
writer.close()