# MRI NIH : UNet

To vizualize on tensor board :

tensorboard --logdir runs

## Imports

In [2]:
import torch
from dataset import MRI2DSegDataset
import transforms
import json
from torchvision import transforms as torch_transforms
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

## Hyperparameters

In [3]:
parameters = json.load(open('/Users/frpau_local/Documents/nih/data/luisa_with_gt/parameters.json'))
print json.dumps(parameters, indent=4)

{
    "training": {
        "learning_rate": 0.0001, 
        "optimizer": "sgd", 
        "loss_function": "dice", 
        "sgd_momentum": 0.9
    }, 
    "transforms": {
        "flip_rate": 0.5, 
        "ratio_range": [
            0.75, 
            1.25
        ], 
        "elastic_rate": 0.3, 
        "sigma_range": [
            3.5, 
            4
        ], 
        "alpha_range": [
            10, 
            15
        ], 
        "crop_size": [
            224, 
            128
        ], 
        "max_angle": 20, 
        "scale_range": [
            0.5, 
            1
        ]
    }
}


## Create dataset

In [4]:
toTensor = transforms.ToTensor()
toPIL = transforms.ToPIL()
randomVFlip = transforms.RandomVerticalFlip()
randomResizedCrop = transforms.RandomResizedCrop(parameters["transforms"]["crop_size"], scale=parameters["transforms"]["scale_range"], ratio=parameters["transforms"]["ratio_range"])
randomRotation = transforms.RandomRotation(parameters["transforms"]["max_angle"])
elasticTransform = transforms.ElasticTransform(parameters["transforms"]["alpha_range"], parameters["transforms"]["sigma_range"], parameters["transforms"]["elastic_rate"])

composed = torch_transforms.Compose([toPIL,randomVFlip,randomRotation,randomResizedCrop, elasticTransform, toTensor])

dataset = MRI2DSegDataset("/Users/frpau_local/Documents/nih/data/luisa_with_gt/filenames_csf_gm_nawm.txt", transform = composed)

  from ._conv import register_converters as _register_converters


## Define net

In [5]:
class DownConv(nn.Module):
    def __init__(self, in_feat, out_feat, drop_rate=0.4, bn_momentum=0.1):
        super(DownConv, self).__init__()
        self.conv1 = nn.Conv2d(in_feat, out_feat, kernel_size=3, padding=1)
        self.conv1_bn = nn.BatchNorm2d(out_feat, momentum=bn_momentum)
        self.conv1_drop = nn.Dropout2d(drop_rate)
        
        self.conv2 = nn.Conv2d(out_feat, out_feat, kernel_size=3, padding=1)
        self.conv2_bn = nn.BatchNorm2d(out_feat, momentum=bn_momentum)
        self.conv2_drop = nn.Dropout2d(drop_rate)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.conv1_bn(x)
        x = self.conv1_drop(x)
        
        x = F.relu(self.conv2(x))
        x = self.conv2_bn(x)
        x = self.conv2_drop(x)        
        return x
    
class UpConv(nn.Module):
    def __init__(self, in_feat, out_feat, drop_rate=0.4, bn_momentum=0.1):
        super(UpConv, self).__init__()
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.downconv = DownConv(in_feat, out_feat, drop_rate, bn_momentum)
    
    def forward(self, x, y):
        x = self.up1(x)
        x = torch.cat([x, y], dim=1)
        x = self.downconv(x)
        return x

class UNet(nn.Module):
    def __init__(self, drop_rate=0.4, bn_momentum=0.1):
        super(UNet, self).__init__()
        
        #Downsampling path
        self.conv1 = DownConv(1, 64, drop_rate, bn_momentum)
        self.mp1 = nn.MaxPool2d(2)

        self.conv2 = DownConv(64, 128, drop_rate, bn_momentum)
        self.mp2 = nn.MaxPool2d(2)    

        self.conv3 = DownConv(128, 256, drop_rate, bn_momentum)
        self.mp3 = nn.MaxPool2d(2)          

        # Bottom
        self.conv4 = DownConv(256, 256, drop_rate, bn_momentum)

        # Upsampling path
        self.up1 = UpConv(512, 256, drop_rate, bn_momentum)
        self.up2 = UpConv(384, 128, drop_rate, bn_momentum)
        self.up3 = UpConv(192, 64, drop_rate, bn_momentum)

        self.conv9 = nn.Conv2d(64, 4, kernel_size=3, padding=1)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.mp1(x1)

        x3 = self.conv2(x2)
        x4 = self.mp2(x3)
        
        x5 = self.conv3(x4)
        x6 = self.mp3(x5)    
        
        # Bottom
        x7 = self.conv4(x6)
        
        # Up-sampling
        x8 = self.up1(x7, x5)
        x9 = self.up2(x8, x3)
        x10 = self.up3(x9, x1)
        
        x11 = self.conv9(x10)
        preds = F.sigmoid(x11)        
        
        return preds


In [6]:
def dice(pred, gt):
    return -(2*torch.dot(pred, Variable(gt))+0.0000000001)/(torch.sum(pred)+torch.sum(gt)+0.0000000001)

def dice_loss(pred, gts):
    pred_size = pred.size()
    bg_gt = torch.ones([pred_size[0],1,pred_size[2], pred_size[3]])
    zeros = torch.zeros([pred_size[0],1,pred_size[2], pred_size[3]])
    for gt in gts:
        bg_gt = torch.max(bg_gt - gt, zeros)
    loss = dice(pred[::,0,::,::], bg_gt)
    for i in range(len(gts)):
        loss = loss+dice(pred[::,i+1,::,::], gts[i])
    return loss

In [7]:
net = UNet()
dataloader = DataLoader(dataset, batch_size=parameters["training"]["batch_size"], shuffle=True, num_workers=4)
if parameters["training"]["optimizer"]=="sgd":
    optimizer = optim.SGD(net.parameters(), lr=parameters["training"]['learning_rate'], momentum=parameters["training"]['sgd_momentum'])
if parameters["training"]["loss_function"]=="dice":
    loss_function = dice_loss
writer = SummaryWriter()

## training

In [9]:
# add hyperparameters to description
writer.add_text("hyperparameters", json.dumps(parameters, indent=4))

for epoch in tqdm(range(10)):

    for i_batch, sample_batched in enumerate(dataloader):
        output =  net(Variable(sample_batched['input']))
        loss = loss_function(output, sample_batched['gt'])
        #print loss
        
        # Visualization
        if not i_batch%5:
            n_iter = epoch*len(dataloader)+i_batch
            
            # loss
            writer.add_scalar('data/loss', loss.data[0], n_iter)
            
            #images
            input_image = vutils.make_grid(sample_batched['input'][0]/torch.max(sample_batched['input'][0]), normalize=True, scale_each=True)
            writer.add_image('Input image', input_image, n_iter)
            output_bg = vutils.make_grid(output[0,0,::,::].data, normalize=True, scale_each=True)
            writer.add_image('Output background', output_bg, n_iter)
            for i in range(len(sample_batched['gt'])):
                output_image = vutils.make_grid(output[0,i+1,::,::].data, normalize=True, scale_each=True)
                writer.add_image('Output class '+str(i), output_image, n_iter)
                gt_image = vutils.make_grid(sample_batched['gt'][i][0,::,::], normalize=True, scale_each=True)
                writer.add_image('gt class '+str(i), gt_image, n_iter)
                
            # net parameters histograms
            for name, param in net.named_parameters():
                writer.add_histogram(name, param.clone().cpu().data.numpy(), n_iter)

                
writer.export_scalars_to_json("./all_scalars.json")
writer.close()

print "training complete"
    

100%|██████████| 10/10 [08:56<00:00, 54.04s/it]
