# CS4MS: U-Net Brain MR Multiple Sclerosis Segmentation (OpenMS Dataset)

*This notebook is based on this GitHub repository:
https://github.com/usuyama/pytorch-unet*

*and the Open MS dataset from the Laboratory of Imaging Technologies at the University of Ljubljana:
Lesjak, Žiga, et al. “A novel public MR image dataset of multiple sclerosis patients with lesion segmentations based on multi-rater consensus.” Neuroinformatics (2017): 1-13.*

https://github.com/muschellij2/open_ms_data<br>
http://lit.fe.uni-lj.si/tools.php?lang=eng

Run the code below to download the Open MS dataset from GitHub.

In [0]:
# install urlpath, see below
!pip install urlpath
# import required python modules
# copy/time are used by train_model further below
import copy
import time
# urllib is a library for downloading things
import urllib.request
# pathlib helps with local paths
from pathlib import Path

# torch is needed by OpenMSDataset.__getitem__, which runs before the U-Net cell imports torch
import torch
# urlpath offers a pathlib-like API for urls
from urlpath import URL

# create data folder in colab environment
data_root = Path('data')
data_root.mkdir(exist_ok=True)

# github repo with the multiple sclerosis data
open_ms_url = URL('https://github.com/muschellij2/open_ms_data/raw/master/cross_sectional/MNI/')

# remote nifty file name for each local file type; each file is stored locally as '<type>.nii.gz'
nifty_file_names = {'input_volume':'T1_N4_noneck_reduced_winsor_regtoFLAIR_brain_N4_regtoMNI.nii.gz','mask':'GOLD_STANDARD_N4_noneck_reduced_winsor_regtoFLAIR_regtoMNI.nii.gz'}
# set to False to skip files that already exist locally (makes re-running this cell cheap)
overwrite = True

# download t1 and segmentation nifty files for each patient 01-30
for pat_id in range(1, 31):
  patient_folder = f'patient{pat_id:02d}'
  dest_local_path = data_root / patient_folder
  dest_local_path.mkdir(exist_ok=True)
  for ftype, fname in nifty_file_names.items():
    src_url = open_ms_url / patient_folder / fname
    # check existence of the SAME path we write to (the remote file name is never saved locally)
    dest_local_file = dest_local_path / f'{ftype}.nii.gz'
    if overwrite or not dest_local_file.is_file():
      print(f'Downloading {src_url.parent.name}/{src_url.name} to {dest_local_file}')
      urllib.request.urlretrieve(str(src_url), dest_local_file)

print('Successfully downloaded all files!')

Load the downloaded NIfTI files and build a PyTorch dataloader from them.

In [0]:
# get module for loading nifty files (a file format similar to DICOM)
!pip install SimpleITK
import SimpleITK as sitk
import numpy as np

# read nifty file from path and convert it to a numpy array using the library we just loaded
def load_np_from_nifty(path):
    """Read the NIfTI file at `path` with SimpleITK and return its voxel data as a numpy array.

    path: str or Path-like; it is converted to str because sitk.ReadImage expects a string.
    """
    image = sitk.ReadImage(str(path))
    voxels = sitk.GetArrayFromImage(image)
    return voxels

In [0]:
# load one input/mask pair (patient 01) and look at shapes and value ranges
input_path = data_root / 'patient01/input_volume.nii.gz'
mask_path = data_root / 'patient01/mask.nii.gz'
test_input, test_mask = load_np_from_nifty(input_path), load_np_from_nifty(mask_path)

print('shape of input volume: ', test_input.shape)
print('shape of mask volume: ', test_mask.shape)
print('min and max values of input volume: ', test_input.min(), test_input.max())
print('unique values of mask:', np.unique(test_mask))

In [0]:
# let's print the central slices of both
import matplotlib.pyplot as plt 

def show_img(img, title=None):
  """Display a 2-D image (numpy array or CPU tensor) in grayscale.

  img: 2-D array-like accepted by matplotlib's imshow.
  title: optional figure title (new, defaults to no title).
  """
  # explicit figure/axes API; the original created an unused `fig` via the pyplot state machine
  fig, ax = plt.subplots()
  ax.imshow(img, cmap='gray')
  if title is not None:
    ax.set_title(title)
  plt.show()

for volume in (test_input, test_mask):
  # middle index along each axis; later cells reuse center_slice_num from the last iteration
  center_slice_num = np.array(volume.shape) // 2
  for axis in range(3):
    # np.take with a scalar index removes `axis`, i.e. the same 2-D slice as volume[..., c, ...]
    show_img(np.take(volume, center_slice_num[axis], axis=axis))

In [0]:
# slicing along the second axis (axis 1) gives square images, so that is the axis used for training slices
show_img(np.take(test_input, center_slice_num[1], axis=1))

In [0]:
# test a split of the volume into 2-D slices along axis 1; element i equals test_input[:, i, :]
test_slices = list(np.moveaxis(test_input, 1, 0))
print(np.stack(test_slices).shape)
show_img(test_slices[center_slice_num[1]])

In [0]:
# next we build a dataset with the volumes
from torch.utils.data import Dataset

INPUT_SIZE = 182   # side length of the square slices taken along axis 1 of each volume
NUM_CLASSES = 2    # one-hot mask channels: background, lesion
NUM_CHANNELS = 1   # single input channel (T1 image)


class OpenMSDataset(Dataset):
    """2-D slice dataset built from the OpenMS input volumes and lesion masks.

    Each selected patient's input volume is min-max normalized to [0, 1] and split
    along axis 1 into slices of shape (NUM_CHANNELS, INPUT_SIZE, INPUT_SIZE). All
    slices of all patients are concatenated into one flat list; masks are split the
    same way, so index ``i`` addresses the same slice in both.

    Args:
        patient_selection: iterable of patient ids (1-30) to load.
        data_root: folder containing ``patientXX/input_volume.nii.gz`` and
            ``patientXX/mask.nii.gz``.
        transform: optional callable applied to the image tensor.
        target_transform: optional callable applied to the one-hot mask tensor.

    Raises:
        ValueError: if ``patient_selection`` is empty.
    """

    def __init__(self, patient_selection=np.arange(1,31), data_root=Path('data'), transform=None, target_transform=None):
        patient_dirs = [data_root / f'patient{pat_id:02d}' for pat_id in patient_selection]
        if not patient_dirs:
            raise ValueError('patient_selection must contain at least one patient id')
        # load volumes from disk, normalize, split into slices and flatten into one list
        input_volumes = [load_np_from_nifty(patient_dir / 'input_volume.nii.gz') for patient_dir in patient_dirs]
        self.input_slices = np.concatenate([self._preprocess_volume(volume) for volume in input_volumes])
        # masks are split the same way (no normalization) so indices line up with input_slices
        mask_volumes = [load_np_from_nifty(patient_dir / 'mask.nii.gz') for patient_dir in patient_dirs]
        self.target_masks = list(np.concatenate([self._split_volume(volume) for volume in mask_volumes]))
        self.transform = transform
        self.target_transform = target_transform

    def _preprocess_volume(self, volume):
        """Normalize ``volume`` to [0, 1] and split it into slices."""
        return self._split_volume(self._normalize_volume(volume))

    # backward-compatible alias for the previously misspelled method name
    _prepcocess_volume = _preprocess_volume

    def _normalize_volume(self, volume):
        """Min-max scale ``volume`` to [0, 1] as float32; a constant volume maps to all zeros."""
        # work in float32: avoids integer overflow in max - min and matches the model's dtype
        volume = np.asarray(volume, dtype=np.float32)
        vmin = np.min(volume)
        span = np.max(volume) - vmin
        if span == 0:
            # a blank volume would otherwise produce 0/0 = NaN everywhere
            return np.zeros(volume.shape, dtype=np.float32)
        return (volume - vmin) / span

    def _split_volume(self, volume):
        """Split ``volume`` along axis 1 into an array of (NUM_CHANNELS, INPUT_SIZE, INPUT_SIZE) slices."""
        return np.array([volume[:, slice_num, :].reshape(NUM_CHANNELS, INPUT_SIZE, INPUT_SIZE)
                         for slice_num in range(volume.shape[1])])

    def __len__(self):
        return len(self.input_slices)

    def __getitem__(self, idx):
        """Return ``[image, mask]``: float32 image (NUM_CHANNELS, H, W) and one-hot mask (2, H, W)."""
        # .float() guarantees float32 input for the float32 model weights
        image = torch.from_numpy(self.input_slices[idx]).float()
        mask = torch.from_numpy(self.target_masks[idx]).float()
        # one-hot encode: channel 0 = background (mask == 0), channel 1 = lesion (mask == 1)
        background = (mask == 0).float()
        lesion = (mask == 1).float()
        mask = torch.cat((background, lesion), dim=0)

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)
        return [image, mask]

In [0]:
# dataset over all 30 patients, used only to look at random examples below
all_dataset = OpenMSDataset(patient_selection=np.arange(1, 31))

In [0]:
# get a random slice with its mask
random_idx = np.random.randint(len(all_dataset))
test_input_slice, test_mask_slice = all_dataset[random_idx]

print(test_input_slice.shape, test_mask_slice.shape)

# show the input slice and the lesion channel (index 1) of its one-hot mask
show_img(test_input_slice.reshape(INPUT_SIZE,INPUT_SIZE))
show_img(test_mask_slice[1].reshape(INPUT_SIZE,INPUT_SIZE))

# run this cell multiple times to see different image pairs

In [0]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# patients 1-20 train, 21-25 validate; 26-30 are held out for the test predictions at the end
train_set = OpenMSDataset(patient_selection=np.arange(1,21))
val_set = OpenMSDataset(patient_selection=np.arange(21,26))

image_datasets = {
    'train': train_set, 'val': val_set
}

batch_size = 25

dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
    # validation gains nothing from shuffling; a fixed order keeps evaluation deterministic
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0),
}

dataset_sizes = {
    phase: len(dataset) for phase, dataset in image_datasets.items()
}

dataset_sizes

In [0]:
import torchvision.utils

# pull one (shuffled) training batch and report shapes plus min/max/mean/std of inputs and masks
inputs, masks = next(iter(dataloaders['train']))

print(inputs.shape, masks.shape)
for batch_tensor in (inputs, masks):
    values = batch_tensor.numpy()
    print(values.min(), values.max(), values.mean(), values.std())

# show one input slice and the lesion channel of its mask
sample_idx = 3
show_img(inputs[sample_idx].reshape(INPUT_SIZE, INPUT_SIZE))
show_img(masks[sample_idx][1].reshape(INPUT_SIZE, INPUT_SIZE))

**Define U-Net network architecture**

Original idea and paper by Olaf Ronneberger et al., University of Freiburg: https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/

![alt text](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)

A good Medium post on this topic:
https://towardsdatascience.com/understanding-semantic-segmentation-with-unet-6be4f42d4b47

In [0]:
# define parts of unet
# taken from https://github.com/milesial/Pytorch-UNet

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two stacked units of (3x3 conv with padding 1 -> BatchNorm -> ReLU).

    Channels go in_channels -> out_channels; the padding keeps H and W unchanged.
    The six layers live in ``self.double_conv`` (an nn.Sequential, indices 0-5).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # first unit maps in_channels -> out_channels, second keeps out_channels
        for conv_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run both conv units on ``x``."""
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder stage: 2x2 max pooling (halves H and W) followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Downsample ``x`` and map in_channels -> out_channels."""
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder stage: upsample by 2, pad to the skip connection's size, concatenate, then DoubleConv.

    in_channels is the channel count AFTER concatenating the upsampled map with the skip map.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # bilinear upsampling has no weights and keeps the channel count; DoubleConv reduces it afterwards
        self.up = (nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
                   if bilinear
                   else nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2))
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1`` (decoder features) and fuse it with ``x2`` (skip features)."""
        upsampled = self.up(x1)
        # tensors are N, C, H, W: pad so the upsampled map matches the skip map when sizes were odd
        missing_h = x2.size(2) - upsampled.size(2)
        missing_w = x2.size(3) - upsampled.size(3)
        pad_left = missing_w // 2
        pad_top = missing_h // 2
        upsampled = F.pad(upsampled, [pad_left, missing_w - pad_left,
                                      pad_top, missing_h - pad_top])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        return self.conv(torch.cat([x2, upsampled], dim=1))


class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to one output map per class."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` without changing H and W."""
        projected = self.conv(x)
        return projected

In [0]:
# define U-Net

import torch.nn.functional as F


class UNet(nn.Module):
    """U-Net assembled from DoubleConv / Down / Up / OutConv.

    Four downsampling stages (64 -> 512 channels) and four upsampling stages that
    concatenate the matching encoder features (skip connections).

    Args:
        n_channels: number of input image channels.
        n_classes: number of output maps (one per class).
        bilinear: bilinear upsampling (True) or transposed convolutions (False) in the decoder.

    forward(x) returns raw logits with the same H and W as ``x``; no sigmoid/softmax is applied.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # decoder: each Up's in_channels = upsampled channels + skip channels
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # encoder: keep every stage's output for the skip connections
        features = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            features.append(down(features[-1]))
        # decoder: start at the bottleneck and consume skip features deepest-first
        x = features.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            x = up(x, features.pop())
        return self.outc(x)

In [0]:
from torchsummary import summary
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# model instance used here only to print the torchsummary overview of the architecture
model = UNet(n_channels=NUM_CHANNELS, n_classes=NUM_CLASSES, bilinear=True).to(device)

summary(model, input_size=(NUM_CHANNELS, INPUT_SIZE, INPUT_SIZE))

In [0]:
# define DICE loss
import torch
from torch.autograd import Function

class DiceCoeff(Function):
    """Dice coeff for individual examples.

    NOTE(review): this uses the legacy autograd.Function style (instance methods,
    ``self.saved_variables``). In this notebook it is only used by ``dice_coeff``,
    which calls ``forward`` directly as a plain method, so the custom ``backward``
    below is not invoked there; the returned tensor is differentiated by regular
    autograd instead. Newer PyTorch versions deprecate instantiating Function
    subclasses, and ``saved_variables`` is a deprecated name for ``saved_tensors``
    — TODO confirm against the installed version before relying on ``backward``.
    """

    def forward(self, input, target):
        """Return (2 * <input, target> + eps) / (sum(input) + sum(target) + eps).

        input, target: tensors with the same number of elements; both are flattened
        with ``view(-1)``, so each must be viewable as 1-D (e.g. contiguous).
        """
        self.save_for_backward(input, target)
        # eps keeps the ratio defined when input and target are both all zeros
        eps = 0.0001
        self.inter = torch.dot(input.view(-1), target.view(-1))
        self.union = torch.sum(input) + torch.sum(target) + eps

        t = (2 * self.inter.float() + eps) / self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        """Gradient of 2 * inter / union w.r.t. ``input`` (the eps in the numerator is ignored); no gradient for ``target``."""

        input, target = self.saved_variables
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                         / (self.union * self.union)
        if self.needs_input_grad[1]:
            grad_target = None

        return grad_input, grad_target


def dice_coeff(input, target):
    """Dice coeff for batches: mean of the per-example DiceCoeff over the first dimension.

    input, target: batched tensors iterated in lockstep along dim 0.

    Returns:
        A 1-element float tensor on ``input``'s device.

    Raises:
        ValueError: if there is no (input, target) pair to average.
    """
    # allocate the accumulator on input's own device (the old code always used the default CUDA device)
    s = torch.zeros(1, device=input.device)
    n_pairs = 0
    for pred_i, target_i in zip(input, target):
        s = s + DiceCoeff().forward(pred_i, target_i)
        n_pairs += 1

    # previously an empty batch crashed with UnboundLocalError on the loop index
    if n_pairs == 0:
        raise ValueError('dice_coeff needs at least one (input, target) pair')
    return s / n_pairs


def dice_loss(pred, target, smooth = 1.):
    """Soft Dice loss averaged over samples and channels.

    pred, target: tensors of shape (N, C, *spatial), e.g. (N, C, H, W) or
    (N, C, D, H, W). ``pred`` holds probabilities (calc_loss passes sigmoid outputs)
    and ``target`` the one-hot mask. Dice is computed per (sample, channel) over all
    spatial dimensions; ``smooth`` is added to numerator and denominator so a pair
    of empty maps gives loss 0 instead of 0/0.

    Returns:
        Scalar tensor: mean of ``1 - dice`` over all samples and channels.

    Raises:
        ValueError: if the tensors have no spatial dimension (fewer than 3 dims).
    """
    if pred.dim() < 3:
        raise ValueError('dice_loss expects tensors shaped (N, C, *spatial)')
    # reduce over every spatial dim (the old code hard-coded exactly two)
    spatial_dims = tuple(range(2, pred.dim()))

    intersection = (pred * target).sum(dim=spatial_dims)
    denominator = pred.sum(dim=spatial_dims) + target.sum(dim=spatial_dims)

    loss = 1 - (2. * intersection + smooth) / (denominator + smooth)

    return loss.mean()

In [0]:
from collections import defaultdict
import torch.nn.functional as F

def calc_loss(pred, target, metrics, bce_weight=0.5):
    """Weighted sum of BCE-with-logits and soft Dice loss; also accumulates metrics.

    Args:
        pred: raw model logits, same shape as ``target``.
        target: one-hot float mask.
        metrics: dict-like accumulator (e.g. defaultdict(float)); 'bce', 'dice' and
            'loss' are each increased by the batch value times the batch size.
        bce_weight: weight of the BCE term; the Dice term gets ``1 - bce_weight``.

    Returns:
        The combined loss tensor (differentiable).
    """
    bce = F.binary_cross_entropy_with_logits(pred, target)

    # torch.sigmoid replaces the deprecated F.sigmoid
    pred = torch.sigmoid(pred)
    dice = dice_loss(pred, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    # weight by batch size so print_metrics can divide by the number of samples
    n_samples = target.size(0)
    metrics['bce'] += bce.item() * n_samples
    metrics['dice'] += dice.item() * n_samples
    metrics['loss'] += loss.item() * n_samples

    return loss

def print_metrics(metrics, epoch_samples, phase):
    """Print '<phase>: name: value, ...' with every accumulated metric divided by ``epoch_samples``."""
    parts = ["{}: {:4f}".format(name, total / epoch_samples) for name, total in metrics.items()]
    print("{}: {}".format(phase, ", ".join(parts)))

def train_model(model, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and return it with the weights of its best validation epoch.

    Each epoch runs a 'train' phase (forward, calc_loss, backward, optimizer step)
    then a 'val' phase (forward + calc_loss only). Batches come from the global
    ``dataloaders`` dict and are moved to the global ``device``.

    Args:
        model: network to train (already on ``device``).
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch after the training phase.
        num_epochs: number of train+val epochs.

    Returns:
        ``model`` with the lowest-validation-loss weights loaded.

    Raises:
        ValueError: if a phase's dataloader yields no samples.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                # LR that will be used for this epoch
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])

                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)

            if phase == 'train':
                # since PyTorch 1.1 the scheduler must step AFTER optimizer.step();
                # stepping first skips the initial learning rate
                scheduler.step()

            if epoch_samples == 0:
                raise ValueError(f'dataloader for phase {phase!r} yielded no samples')

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples

            # deep copy the model
            if phase == 'val' and epoch_loss < best_loss:
                print("saving best model")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

The Dice score (Dice coefficient) is one of the most common metrics for segmentation tasks:

![alt text](https://miro.medium.com/max/486/1*yUd5ckecHjWZf6hGrdlwzA.png)

More information:


*   https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
*   https://towardsdatascience.com/metrics-to-evaluate-your-semantic-segmentation-model-6bcb99639aa2

In [0]:
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# seed torch's RNG so weight init and training-batch shuffling are repeatable
# (GPU kernels may still be non-deterministic)
torch.manual_seed(0)
model = UNet(n_channels=NUM_CHANNELS, n_classes=NUM_CLASSES, bilinear=True).to(device)

# Observe that all parameters are being optimized
optimizer_ft = optim.Adam(model.parameters(), lr=1e-4)

# multiply the LR by 0.1 every 25 epochs (no effect for runs shorter than 25 epochs)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=25, gamma=0.1)

model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=2)

In [0]:
# prediction

import math

# evaluation mode: BatchNorm uses its running statistics
model.eval()

# patients 26-30 were used neither for training nor for validation
test_dataset = OpenMSDataset(patient_selection=np.arange(26, 31))
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=0)

In [0]:
inputs, labels = next(iter(test_loader))

model.eval()
# inference only: no autograd graph needed
with torch.no_grad():
    # the U-Net outputs logits; apply sigmoid (as calc_loss does in training) before thresholding at 0.5
    pred = torch.sigmoid(model(inputs.to(device)))

inputs = inputs.cpu().numpy()
labels = labels.cpu().numpy()
pred = pred.cpu().numpy()
print(pred.shape)

# input slice, ground-truth lesion channel, predicted lesion channel
show_img(inputs[0][0])
show_img(labels[0][1])
show_img(pred[0][1] > 0.5)