# MRI NIH : UNet

To visualize the training on TensorBoard:

tensorboard --logdir < path to runs directory >

## Imports

In [31]:
import sys
import os
sys.path.insert(0, '../')
os.environ["CUDA_VISIBLE_DEVICES"] = '1' # number of the GPU to use if cuda is enabled
from dataset import *
import transforms
import json
from torchvision import transforms as torch_transforms
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
from models import UNet
import losses
from metrics import *
from datetime import datetime

## Paths

In [16]:
# Output location: one sub-directory per training run (TensorBoard logs and saved models).
path_runs = "/Users/frpau_local/Documents/nih/code/runs"

# Input locations: hyperparameter file and the lists of training / validation file names.
# NOTE(review): these are absolute, machine-specific paths; adapt them before running elsewhere.
path_parameters = "/Users/frpau_local/Documents/nih/data/luisa_with_gt/parameters.json"
path_filenames_training = "/Users/frpau_local/Documents/nih/data/luisa_with_gt/filenames_training.txt"
path_filenames_validation = "/Users/frpau_local/Documents/nih/data/luisa_with_gt/filenames_validation.txt"

## Device

In [17]:
# Select the compute device: the first *visible* CUDA device when CUDA is
# available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Identifier of the GPU exposed to this process, read back from the
# CUDA_VISIBLE_DEVICES environment variable (falls back to "0" when unset).
# This name was previously used without ever being defined, which raised a
# NameError as soon as CUDA was available.
gpu_number = os.environ.get("CUDA_VISIBLE_DEVICES", "0")

# Single-argument print(...) behaves identically on Python 2 and Python 3.
print("working on {}".format(device))
if torch.cuda.is_available():
    print("using GPU number {}".format(gpu_number))

working on cpu


## Hyperparameters

In [18]:
# Load the hyperparameters from the JSON file. The context manager guarantees
# that the file handle is closed once parsing is done (the handle used to be
# left open).
with open(path_parameters) as parameters_file:
    parameters = json.load(parameters_file)

# Echo the configuration so that it is stored with the notebook output.
print(json.dumps(parameters, indent=4))

{
    "training": {
        "poly_schedule_p": 0.9, 
        "optimizer": "adam", 
        "write_param_histograms": false, 
        "learning_rate": 1e-06, 
        "batch_size": 10, 
        "loss_function": "dice", 
        "lr_schedule": "poly", 
        "nb_epochs": 10
    }, 
    "transforms": {
        "flip_rate": 0.5, 
        "ratio_range": [
            0.75, 
            1.25
        ], 
        "elastic_rate": 0.3, 
        "sigma_range": [
            3.5, 
            4
        ], 
        "alpha_range": [
            10, 
            15
        ], 
        "crop_size": [
            224, 
            128
        ], 
        "max_angle": 20, 
        "scale_range": [
            0.5, 
            1
        ]
    }, 
    "net": {
        "drop_rate": 0.3, 
        "bn_momentum": 0.1
    }
}


## Create dataset

In [19]:
# --- Individual transformations -------------------------------------------
# All tunable values are read from the "transforms" section of `parameters`.
# NOTE(review): the "flip_rate" hyperparameter is not passed to
# RandomVerticalFlip here; confirm whether that is intended.
toTensor = transforms.ToTensor()
toPIL = transforms.ToPIL()
randomVFlip = transforms.RandomVerticalFlip()
randomResizedCrop = transforms.RandomResizedCrop(
    parameters["transforms"]["crop_size"],
    scale=parameters["transforms"]["scale_range"],
    ratio=parameters["transforms"]["ratio_range"]
)
randomRotation = transforms.RandomRotation(parameters["transforms"]["max_angle"])
elasticTransform = transforms.ElasticTransform(
    parameters["transforms"]["alpha_range"],
    parameters["transforms"]["sigma_range"],
    parameters["transforms"]["elastic_rate"]
)
centerCrop = transforms.CenterCrop2D(parameters["transforms"]["crop_size"])

# --- Composed transformations ---------------------------------------------
# A composed transformation should always start with toPIL (the other
# transformations are made to work on PIL images) and end with toTensor
# (the network expects tensors as input).
composed = torch_transforms.Compose([
    toPIL,
    randomVFlip,
    randomRotation,
    randomResizedCrop,
    elasticTransform,
    toTensor,
])
crop_val = torch_transforms.Compose([toPIL, centerCrop, toTensor])

# --- Datasets ---------------------------------------------------------------
# Datasets need at least a toTensor transformation, or a composed
# transformation ending with toTensor, since the network expects tensors.
training_dataset = MRI2DSegDataset(path_filenames_training, transform=composed)
validation_dataset = MRI2DSegDataset(path_filenames_validation, transform=crop_val)

# --- Data loaders -----------------------------------------------------------
# Both loaders use the configured batch size and shuffle their samples.
training_dataloader = DataLoader(
    training_dataset,
    batch_size=parameters["training"]["batch_size"],
    shuffle=True
)
validation_dataloader = DataLoader(
    validation_dataset,
    batch_size=parameters["training"]["batch_size"],
    shuffle=True
)

## Define net

In [20]:
class DownConv(nn.Module):
    """Two successive 3x3 convolution stages that keep the spatial size.

    Each stage applies, in this order: convolution (padding 1), ReLU,
    batch normalisation and 2D dropout.

    :param in_feat: number of input channels
    :param out_feat: number of output channels (of both stages)
    :param drop_rate: probability given to the Dropout2d layers
    :param bn_momentum: momentum given to the BatchNorm2d layers
    """

    def __init__(self, in_feat, out_feat, drop_rate=0.4, bn_momentum=0.1):
        super(DownConv, self).__init__()

        # First stage: in_feat -> out_feat
        self.conv1 = nn.Conv2d(in_feat, out_feat, kernel_size=3, padding=1)
        self.conv1_bn = nn.BatchNorm2d(out_feat, momentum=bn_momentum)
        self.conv1_drop = nn.Dropout2d(drop_rate)

        # Second stage: out_feat -> out_feat
        self.conv2 = nn.Conv2d(out_feat, out_feat, kernel_size=3, padding=1)
        self.conv2_bn = nn.BatchNorm2d(out_feat, momentum=bn_momentum)
        self.conv2_drop = nn.Dropout2d(drop_rate)

    def forward(self, x):
        """Run both stages on `x` and return the resulting feature map."""
        first_stage = self.conv1_drop(self.conv1_bn(F.relu(self.conv1(x))))
        second_stage = self.conv2_drop(self.conv2_bn(F.relu(self.conv2(first_stage))))
        return second_stage
    
class UpConv(nn.Module):
    """Decoder stage: upsample, concatenate with a skip connection, convolve.

    :param in_feat: number of channels after concatenation (upsampled input
        channels plus skip-connection channels)
    :param out_feat: number of output channels
    :param drop_rate: dropout probability forwarded to the DownConv block
    :param bn_momentum: batch-norm momentum forwarded to the DownConv block
    """

    def __init__(self, in_feat, out_feat, drop_rate=0.4, bn_momentum=0.1):
        super(UpConv, self).__init__()
        # Doubles the spatial resolution with bilinear interpolation.
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.downconv = DownConv(in_feat, out_feat, drop_rate, bn_momentum)

    def forward(self, x, y):
        """Upsample `x`, stack it with the skip tensor `y` along the channel
        axis (dim 1) and apply the convolution block."""
        upsampled = self.up1(x)
        merged = torch.cat((upsampled, y), dim=1)
        return self.downconv(merged)

class UNet(nn.Module):
    """Three-level U-Net producing one sigmoid activation map per class.

    The encoder halves the spatial resolution three times (MaxPool2d(2)) and
    the decoder doubles it three times before concatenating with the matching
    encoder feature map, so the input height and width must be divisible by 8
    for the concatenations to line up.

    :param drop_rate: dropout probability used in every convolution block
    :param bn_momentum: batch-norm momentum used in every convolution block
    :param in_channels: number of channels of the input image (default 1,
        the value that used to be hard-coded)
    :param n_classes: number of output channels / classes (default 4, the
        value that used to be hard-coded)
    """

    def __init__(self, drop_rate=0.4, bn_momentum=0.1, in_channels=1, n_classes=4):
        super(UNet, self).__init__()

        # Downsampling path
        self.conv1 = DownConv(in_channels, 64, drop_rate, bn_momentum)
        self.mp1 = nn.MaxPool2d(2)

        self.conv2 = DownConv(64, 128, drop_rate, bn_momentum)
        self.mp2 = nn.MaxPool2d(2)

        self.conv3 = DownConv(128, 256, drop_rate, bn_momentum)
        self.mp3 = nn.MaxPool2d(2)

        # Bottom
        self.conv4 = DownConv(256, 256, drop_rate, bn_momentum)

        # Upsampling path. Input channel counts are the sum of the upsampled
        # features and the skip connection: 256+256, 256+128, 128+64.
        self.up1 = UpConv(512, 256, drop_rate, bn_momentum)
        self.up2 = UpConv(384, 128, drop_rate, bn_momentum)
        self.up3 = UpConv(192, 64, drop_rate, bn_momentum)

        # Final projection to one channel per class.
        self.conv9 = nn.Conv2d(64, n_classes, kernel_size=3, padding=1)

    def forward(self, x):
        """Return per-class sigmoid activations with the same spatial size as `x`."""
        # Encoder
        x1 = self.conv1(x)
        x2 = self.mp1(x1)

        x3 = self.conv2(x2)
        x4 = self.mp2(x3)

        x5 = self.conv3(x4)
        x6 = self.mp3(x5)

        # Bottom
        x7 = self.conv4(x6)

        # Decoder, each stage fed with the matching encoder output
        x8 = self.up1(x7, x5)
        x9 = self.up2(x8, x3)
        x10 = self.up3(x9, x1)

        x11 = self.conv9(x10)
        # torch.sigmoid replaces the deprecated F.sigmoid (same result).
        preds = torch.sigmoid(x11)

        return preds


## Dice loss

In [21]:
def get_bg_gt(gts):
    """Compute the background mask implied by a list of class masks.

    The background starts as all ones and every ground-truth mask is
    subtracted from it; values are clamped at zero so that overlapping masks
    cannot produce negative values.

    :param gts: non-empty list of ground-truth tensors indexed as
        (batch, channel, height, width); sizes are read from the first one
    :return: tensor of size (batch, 1, height, width) located on the same
        device as the ground truths
    """
    gt_size = gts[0].size()
    # Allocate on the device of the ground truths so the function also works
    # for tensors that already live on the GPU.
    bg_gt = torch.ones([gt_size[0], 1, gt_size[2], gt_size[3]], device=gts[0].device)
    for gt in gts:
        # clamp(min=0) is the element-wise max with zero.
        bg_gt = torch.clamp(bg_gt - gt, min=0)
    return bg_gt

def dice_loss(pred, gts):
    """Weighted (generalised) Dice loss over several classes.

    Each class i is weighted by 1 / (sum(gts[i])**2 + eps). The loss is

        1 - 2 * sum_i(w_i * intersection_i) / (sum_i(w_i * union_i) + eps)

    where intersection_i = sum(pred[:, i] * gts[i]) and
    union_i = sum(pred[:, i]) + sum(gts[i]).

    Two defects of the previous version are fixed here:

    * eps was added *outside* the weight's denominator, so a class absent
      from the batch produced an infinite weight and a NaN loss;
    * the intersection was computed on ``pred.data``, which detached it from
      the autograd graph: the gradient only flowed through the union term.

    Note that a class absent from the batch gets a very large (but finite)
    weight, which strongly penalises any activation predicted for it.

    :param pred: network output indexed as (batch, class, height, width)
    :param gts: list of ground-truth tensors, one per class, each holding as
        many elements as ``pred[:, i]``
    :return: scalar tensor containing the loss
    """
    eps = 0.0000000001
    weights = []
    intersections = []
    unions = []

    for i, gt in enumerate(gts):
        pred_flat = pred[:, i].contiguous().view(-1)
        gt_flat = gt.contiguous().view(-1).to(pred_flat.dtype)
        gt_sum = gt_flat.sum()

        # eps inside the denominator keeps the weight finite for empty masks.
        weights.append(1. / (gt_sum ** 2 + eps))
        # Stay on the autograd graph so the numerator contributes gradients.
        intersections.append((pred_flat * gt_flat).sum())
        unions.append(pred_flat.sum() + gt_sum)

    numerator = sum([w * inter for w, inter in zip(weights, intersections)])
    denominator = sum([w * union for w, union in zip(weights, unions)]) + eps
    return 1. - 2 * numerator / denominator

## Metrics

In [22]:
def numeric_score(pred, gts):
    """Computation of statistical numerical scores, one value per class:

    * FP = False Positives
    * FN = False Negatives
    * TP = True Positives
    * TN = True Negatives

    A background mask is computed with get_bg_gt and prepended to `gts`, so
    class 0 is the background and `gts` itself must NOT already contain it.
    Class i is considered predicted wherever ``pred == i``; a ground-truth
    element counts as positive when it equals 1 and negative when it equals 0.

    :param pred: CPU tensor of predicted class indices
    :param gts: list of CPU ground-truth tensors (background excluded)
    :return: tuple (FP, FN, TP, TN) of lists of floats
    """
    np_pred = pred.numpy()
    np_gts = [gt.numpy() for gt in [get_bg_gt(gts)] + list(gts)]

    FP, FN, TP, TN = [], [], [], []
    for i, np_gt in enumerate(np_gts):
        # Boolean masks are computed once and reused for the four counts.
        predicted = (np_pred == i)
        not_predicted = (np_pred != i)
        positive = (np_gt == 1)
        negative = (np_gt == 0)

        # Built-in float replaces np.float, which was removed from NumPy.
        FP.append(float(np.sum(predicted & negative)))
        FN.append(float(np.sum(not_predicted & positive)))
        TP.append(float(np.sum(predicted & positive)))
        TN.append(float(np.sum(not_predicted & negative)))
    return FP, FN, TP, TN


def precision_score(FP, FN, TP, TN):
    """Per-class precision (positive predictive value) in percent.

    Computed as TP / (TP + FP) * 100; a class with no predicted positive
    (TP + FP <= 0) gets 0.0. FN and TN are accepted for signature symmetry
    with the other scores and are not used.

    :return: list with one value per class
    """
    precision = []
    for i, false_pos in enumerate(FP):
        predicted_pos = TP[i] + false_pos
        score = 0.0 if predicted_pos <= 0.0 else np.divide(TP[i], predicted_pos) * 100.0
        precision.append(score)
    return precision


def recall_score(FP, FN, TP, TN):
    """Per-class recall (true positive rate, sensitivity) in percent.

    Computed as TP / (TP + FN) * 100; a class with no actual positive
    (TP + FN <= 0) gets 0.0. TN is not used; FP only sets the class count.

    :return: list with one value per class
    """
    TPR = []
    for i, _ in enumerate(FP):
        actual_pos = TP[i] + FN[i]
        score = 0.0 if actual_pos <= 0.0 else np.divide(TP[i], actual_pos) * 100.0
        TPR.append(score)
    return TPR


def specificity_score(FP, FN, TP, TN):
    """Per-class specificity (true negative rate) in percent.

    Computed as TN / (TN + FP) * 100; a class with no actual negative
    (TN + FP <= 0) gets 0.0. FN and TP are not used.

    :return: list with one value per class
    """
    TNR = []
    for i, false_pos in enumerate(FP):
        actual_neg = TN[i] + false_pos
        score = 0.0 if actual_neg <= 0.0 else np.divide(TN[i], actual_neg) * 100.0
        TNR.append(score)
    return TNR


def intersection_over_union(FP, FN, TP, TN):
    """Per-class intersection over union (Jaccard index) in percent.

    Computed as TP / (TP + FP + FN) * 100; a class for which that
    denominator is <= 0 gets 0.0. TN is not used.

    :return: list with one value per class
    """
    IOU = []
    for i, false_pos in enumerate(FP):
        union = TP[i] + false_pos + FN[i]
        score = 0.0 if union <= 0.0 else TP[i] / union * 100.0
        IOU.append(score)
    return IOU


def accuracy_score(FP, FN, TP, TN):
    """Per-class accuracy in percent.

    Computed as (TP + TN) / (FP + FN + TP + TN) * 100. When a class has no
    counted element at all (total <= 0) the score is 0.0, consistently with
    the other score functions, instead of the NaN that 0/0 used to produce.

    :return: list with one value per class
    """
    accuracy = []
    for i in range(len(FP)):
        N = FP[i] + FN[i] + TP[i] + TN[i]
        if N <= 0.0:
            # Nothing was counted for this class: avoid dividing by zero.
            accuracy.append(0.0)
        else:
            accuracy.append(np.divide(TP[i] + TN[i], N) * 100.0)
    return accuracy

def write_metrics(writer, predictions, gts, loss, epoch, tag):
    """
    Write scalar metrics to tensorboard

    The loss is written once as "loss_<tag>"; every other metric is written
    once per class as "<metric>_<class index>_<tag>".

    :param writer: SummaryWriter object to write on
    :param predictions: tensor containing predictions
    :param gts: array of tensors containing ground truth (forwarded unchanged
        to numeric_score)
    :param loss: loss value to log
    :param epoch: int, number of the iteration
    :param tag: string to specify which dataset is used (e.g. "training" or "validation")
    """
    counts = numeric_score(predictions, gts)  # (FP, FN, TP, TN)

    # (scalar name prefix, per-class values), in the order they are written.
    per_class_metrics = [
        ("precision_", precision_score(*counts)),
        ("recall_", recall_score(*counts)),
        ("specificity_", specificity_score(*counts)),
        ("intersection_over_union_", intersection_over_union(*counts)),
        ("accuracy_", accuracy_score(*counts)),
    ]

    writer.add_scalar("loss_"+tag, loss, epoch)

    nb_classes = len(per_class_metrics[0][1])
    for class_index in range(nb_classes):
        suffix = str(class_index)+"_"+tag
        for prefix, values in per_class_metrics:
            writer.add_scalar(prefix+suffix, values[class_index], epoch)


def write_images(writer, input, output, predictions, gts, epoch, tag):
    """
    Write images to tensorboard

    For every class i (one per entry of `gts`) three images are written: the
    network output channel i, the binary mask ``predictions == i`` and the
    ground truth i.

    :param writer: SummaryWriter object to write on
    :param input: tensor containing input values
    :param output: tensor containing output values, indexed by class first
    :param predictions: tensor containing predictions (class indices)
    :param gts: array of tensors containing ground truth
    :param epoch: int, number of the iteration
    :param tag: string to specify which dataset is used (e.g. "training" or "validation")
    """
    # Scale the input by its maximum; the lower bound avoids dividing by zero
    # on an all-zero image.
    input_max = max(float(torch.max(input)), 0.00000001)
    input_image = vutils.make_grid(input/input_max, normalize=True)
    writer.add_image('Input '+tag, input_image, epoch)

    for i in range(len(gts)):
        output_image = vutils.make_grid(output[i,::,::], normalize=True)
        writer.add_image('Output class '+str(i)+' '+tag, output_image, epoch)

        # The comparison yields a byte/bool mask holding 0/1. Integer images
        # are interpreted on a 0-255 scale, so the mask is cast to float to
        # be displayed as black/white instead of almost entirely black.
        pred_mask = (predictions == i).float()
        pred_image = vutils.make_grid(pred_mask, normalize=False)
        writer.add_image('Prediction class '+str(i)+' '+tag, pred_image, epoch)

        # Cast to float so that normalisation also works on integer masks.
        gt_image = vutils.make_grid(gts[i].float(), normalize=True)
        writer.add_image('GT class '+str(i)+' '+tag, gt_image, epoch)

In [23]:
# Build the network from the "net" hyperparameters and move it to the
# selected device (Module.to returns the module itself).
net = UNet(
    drop_rate=parameters["net"]["drop_rate"],
    bn_momentum=parameters["net"]["bn_momentum"]
).to(device)

## loss, optimizer and lr schedule

In [24]:
# --- Optimizer ---------------------------------------------------------------
if parameters["training"]["optimizer"]=="sgd":
    optimizer = optim.SGD(net.parameters(), lr=parameters["training"]['learning_rate'], momentum=parameters["training"]['sgd_momentum'])
elif parameters["training"]["optimizer"]=="adam":
    optimizer = optim.Adam(net.parameters(), lr=parameters["training"]['learning_rate'])
else:
    # Fail here rather than with a NameError later in the training cell.
    raise ValueError("unsupported optimizer: {}".format(parameters["training"]["optimizer"]))

# --- Loss --------------------------------------------------------------------
if parameters["training"]["loss_function"]=="dice":
    loss_function = dice_loss
else:
    raise ValueError("unsupported loss function: {}".format(parameters["training"]["loss_function"]))

# --- Learning-rate schedule --------------------------------------------------
if parameters["training"]["lr_schedule"]=="cosine":
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, parameters["training"]["nb_epochs"])
elif parameters["training"]["lr_schedule"]=="poly":
    # Default exponent of the polynomial decay when none is configured.
    parameters["training"].setdefault('poly_schedule_p', 0.9)

    def lr_lambda(epoch):
        """Polynomial decay factor (1 - epoch/nb_epochs) ** p.

        The division is forced to floating point: with Python 2 integer
        division the ratio was always 0, so the learning rate never decayed.
        The base is clamped at 0 so that a fractional power is never applied
        to a negative number.
        """
        progress = float(epoch) / parameters["training"]["nb_epochs"]
        return max(0.0, 1.0 - progress) ** parameters["training"]["poly_schedule_p"]

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
else:
    raise ValueError("unsupported lr schedule: {}".format(parameters["training"]["lr_schedule"]))

## training

In [32]:
# Each run gets its own directory, named after the start time, holding the
# TensorBoard logs and the saved models.
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
writer = SummaryWriter(log_dir=os.path.join(path_runs, current_time))
writer.add_text("hyperparameters", json.dumps(parameters)) # add hyperparameters to description
# Name of the directory of the current run (to save the model in that
# directory). It is taken from the name given to the writer: os.listdir()
# returns entries in arbitrary order, so its last element is not guaranteed
# to be the current run.
last_run_dir = current_time

# Lowest validation loss seen so far. Starting from +inf lets the first epoch
# register as an improvement (starting from 0 meant no model was ever saved).
best_loss = float("inf")
batch_length = len(training_dataloader)

try:
    for epoch in tqdm(range(parameters["training"]["nb_epochs"])):

        ## Training ##

        loss_sum = 0.
        # NOTE(review): stepping the scheduler at the start of the epoch is the
        # convention of PyTorch < 1.1; newer versions expect it after
        # optimizer.step(). Confirm against the installed version.
        scheduler.step()
        net.train()

        # Read the learning rate actually used by the optimizer.
        writer.add_scalar("learning_rate", optimizer.param_groups[0]['lr'], epoch)

        for i_batch, sample_batched in enumerate(training_dataloader):
            optimizer.zero_grad()
            input = sample_batched['input'].to(device)
            output =  net(input)
            gts = [get_bg_gt(sample_batched['gt'])]+sample_batched['gt'] # make an array of ground truths (with the computed background gt mask)
            loss = loss_function(output, [gt.to(device) for gt in gts])
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()/batch_length

        # Metrics and images below are computed on the LAST batch of the epoch only.
        predictions = torch.argmax(output, 1, keepdim=True).to("cpu") # get predicted class for each pixel (on cpu to compute metrics)

        # metrics
        # numeric_score (called by write_metrics) prepends the background mask
        # itself, so it receives the raw ground truths: passing `gts`, which
        # already contains the background, shifted every class by one.
        write_metrics(writer, predictions, sample_batched['gt'], loss_sum, epoch, "training")

        # images (first sample of the last batch)
        input_for_image = sample_batched['input'][0]
        output_for_image = output[0,::,::,::]
        pred_for_image = predictions[0,0,::,::]
        gts_for_image = [gt[0,::,::] for gt in gts]

        write_images(writer, input_for_image, output_for_image, pred_for_image, gts_for_image, epoch, "training")

        if "write_param_histograms" in parameters["training"].keys() and parameters["training"]["write_param_histograms"]:
            # write net parameters histograms (make the training significantly slower)
            for name, param in net.named_parameters():
                writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch)

        ## Validation ##

        loss_sum = 0.
        net.eval()

        # No gradient is needed for validation: skip building the autograd graph.
        with torch.no_grad():
            for i_batch, sample_batched in enumerate(validation_dataloader):
                output =  net(sample_batched['input'].to(device))
                gts = [get_bg_gt(sample_batched['gt'])]+sample_batched['gt']
                loss = loss_function(output, [gt.to(device) for gt in gts])
                loss_sum += loss.item()/len(validation_dataloader)

        predictions = torch.argmax(output, 1, keepdim=True).to("cpu")

        # Keep a copy of the model with the lowest validation loss so far.
        if loss_sum < best_loss:
            best_loss = loss_sum
            torch.save(net, path_runs+"/"+last_run_dir+"/best_model.pt")

        # metrics (raw ground truths, see the training section)
        write_metrics(writer, predictions, sample_batched['gt'], loss_sum, epoch, "validation")

        #images
        input_for_image = sample_batched['input'][0]
        output_for_image = output[0,::,::,::]
        pred_for_image = predictions[0,0,::,::]
        gts_for_image = [gt[0,::,::] for gt in gts]

        write_images(writer, input_for_image, output_for_image, pred_for_image, gts_for_image, epoch, "validation")

    writer.export_scalars_to_json(path_runs+"/"+last_run_dir+"/all_scalars.json")
finally:
    # Close the writer even when the training is interrupted.
    writer.close()

torch.save(net, path_runs+"/"+last_run_dir+"/final_model.pt")

print("training complete, model saved at "+path_runs+"/"+last_run_dir+"/final_model.pt")
    





  0%|          | 0/10 [00:00<?, ?it/s][A[A[A[A


[A[A[A

KeyboardInterrupt: 

In [None]:
# Close the SummaryWriter manually; useful when the training cell above was
# interrupted (e.g. KeyboardInterrupt) before reaching its own close call.
writer.close()