# U-Net

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import nibabel as nib 
import cv2
import glob
import random
import time

data_dir = '/data/projects/ccm/MRIs'
id_1     = 1   # subject id, zero-padded to 3 digits in the filename
id_2     = 15  # slice id, zero-padded to 3 digits in the filename

# `get_data()` was deprecated and then removed in nibabel 5.0;
# `np.asanyarray(img.dataobj)` is the documented replacement that keeps the stored dtype.
image = np.asanyarray(nib.load(f'{data_dir}/train/img_subj{id_1:>03}_sl_{id_2:>03}.nii').dataobj)
msk = np.asanyarray(nib.load(f'{data_dir}/train/roi_subj{id_1:>03}_sl_{id_2:>03}.nii').dataobj)
# Fold label 2 into background (assumes ROI labels are {0, 1, 2} — TODO confirm).
msk[msk==2] = 0

# Show the slice and its ROI mask side by side.
fig, ax = plt.subplots(1, 2, figsize=(15, 10))
ax[0].imshow(image, cmap='gray')
ax[1].imshow(msk, cmap='gray')

In [None]:
# Sanity check: dimensions of the slice loaded above.
print(np.shape(image))

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader

In [None]:
class conv_block(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Padding 1 with stride 1 keeps the spatial size unchanged.

    Args:
        ch_in: channels of the incoming feature map.
        ch_out: channels produced by both convolutions.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        layers = []
        # First stage maps ch_in -> ch_out, second stage ch_out -> ch_out.
        for stage_in in (ch_in, ch_out):
            layers += [
                nn.Conv2d(stage_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order are kept so state_dict keys stay `conv.<i>.*`.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class up_conv(nn.Module):
    """Upsample by 2 (nn.Upsample default mode), then 3x3 conv -> BatchNorm -> ReLU.

    Args:
        ch_in: channels of the incoming (coarser) feature map.
        ch_out: channels after the convolution.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        norm = nn.BatchNorm2d(ch_out)
        activation = nn.ReLU(inplace=True)
        # Order fixed so state_dict keys stay `up.<i>.*`.
        self.up = nn.Sequential(upsample, conv, norm, activation)

    def forward(self, x):
        return self.up(x)

In [None]:
class single_conv(nn.Module):
    """1x1 convolution -> BatchNorm -> ReLU; changes channel count, keeps spatial size.

    Args:
        ch_in: input channels.
        ch_out: output channels.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        projection = nn.Conv2d(
            in_channels=ch_in,
            out_channels=ch_out,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        self.conv = nn.Sequential(projection, nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)

In [None]:
class Unet(nn.Module):
    """2-D U-Net: four 2x max-pool encoder stages and four mirrored decoder stages.

    Encoder channels grow 64 -> 128 -> 256 -> 512 -> 1024; each decoder stage
    upsamples, concatenates the matching encoder feature map (skip connection)
    and applies a conv_block. Because there are four 2x poolings followed by
    four 2x upsamplings, input height and width should be divisible by 16 or
    the `torch.cat` skip connections will fail on mismatched sizes.

    Args:
        img_ch: number of input image channels.
        output_ch: number of output channels.

    forward(x) returns raw, unbounded logits of shape (B, output_ch, H, W);
    apply a sigmoid (or use a *WithLogits loss) downstream.
    """

    def __init__(self, img_ch=1, output_ch=1):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        # Each Up_conv block sees 2x channels: upsampled features + skip features.
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        # Plain 1x1 projection to logits. No BatchNorm/ReLU here: a ReLU would
        # clamp every logit to >= 0 (sigmoid >= 0.5), so a sigmoid-based loss
        # could never push a pixel towards the background class.
        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path (skip features first, then upsampled features)
        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        return self.Conv_1x1(d2)

In [None]:
def iou(outputs, labels):
    """Mean per-sample Intersection-over-Union of thresholded predictions.

    Args:
        outputs: raw logits, shape (B, ...) e.g. (B, N, H, W); a pixel is
            predicted positive when sigmoid(logit) > 0.5.
        labels: targets with the same shape; positive when > 0.5.

    Returns:
        float: IoU computed over all non-batch dimensions of each sample, then
        averaged over the batch. A sample whose prediction and label are both
        empty scores 1 (thanks to the smoothing term).
    """
    SMOOTH = 1e-6
    preds = torch.sigmoid(outputs) > 0.5
    targets = labels > 0.5
    # Reduce every dimension except the batch axis so the IoU is per image,
    # not per row/column.
    dims = tuple(range(1, preds.ndim))
    intersection = (preds & targets).float().sum(dims)
    union = (preds | targets).float().sum(dims)
    return ((intersection + SMOOTH) / (union + SMOOTH)).mean().item()

def DC(outputs, labels):
    """Mean per-sample Dice coefficient of thresholded predictions.

    Args:
        outputs: raw logits, shape (B, ...) e.g. (B, N, H, W); a pixel is
            predicted positive when sigmoid(logit) > 0.5.
        labels: targets with the same shape; positive when > 0.5.

    Returns:
        float: 2|P∩L| / (|P| + |L|) computed over all non-batch dimensions of
        each sample, then averaged over the batch. The smoothing term is
        applied to numerator and denominator (as in `iou`), so a sample whose
        prediction and label are both empty scores 1.
    """
    SMOOTH = 1e-6
    preds = torch.sigmoid(outputs) > 0.5
    targets = labels > 0.5
    # Per-sample reductions: numerator and denominator must cover the same
    # pixels, otherwise the ratio is not a Dice score.
    dims = tuple(range(1, preds.ndim))
    intersection = (preds & targets).float().sum(dims)
    total = preds.float().sum(dims) + targets.float().sum(dims)
    return ((2 * intersection + SMOOTH) / (total + SMOOTH)).mean().item()

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
def apply_vertical_flip(img, mask):
    """Mirror an image and its mask top-to-bottom (reverse axis 0).

    Returns numpy views with negative strides, not copies.
    """
    flipped_img, flipped_mask = (np.flip(arr, axis=0) for arr in (img, mask))
    return flipped_img, flipped_mask

In [None]:
def apply_horizontal_flip(img, mask):
    """Mirror an image and its mask left-to-right (reverse axis 1).

    Returns numpy views with negative strides, not copies.
    """
    flipped_img, flipped_mask = (np.flip(arr, axis=1) for arr in (img, mask))
    return flipped_img, flipped_mask

In [None]:
def apply_rotation(img, mask, bound_angle = 180):
    """Rotate an image and its mask by the same random angle about the image centre.

    Args:
        img: array of shape (H, W) or (H, W, C).
        mask: array with the same height and width as ``img``.
        bound_angle: the angle (degrees) is drawn uniformly from the integers
            in [-bound_angle, bound_angle].

    Returns:
        (rotated_img, rotated_mask), both keeping the original height/width.
        Pixels rotated in from outside the frame are filled with 0.
    """
    angle = random.randint(-bound_angle, bound_angle)
    height, width = img.shape[:2]  # image may have a trailing channel axis
    # OpenCV points are (x, y) and sizes are (width, height) -- the reverse of
    # numpy's (rows, cols) order. Mixing them up distorts non-square images.
    center = (width / 2, height / 2)
    dsize = (width, height)
    mat = cv2.getRotationMatrix2D(center, angle, 1.0)
    img = cv2.warpAffine(img, mat, dsize, flags=cv2.INTER_CUBIC)
    # Nearest-neighbour keeps mask values restricted to the original labels;
    # cubic interpolation would create fractional/overshooting labels at edges.
    mask = cv2.warpAffine(mask, mat, dsize, flags=cv2.INTER_NEAREST)
    return img, mask

In [None]:
class RandomAugmentation:
    """Apply a random subset of the paired image/mask augmentations.

    Each call draws k uniformly from 0..max_augment_count (inclusive), picks k
    distinct functions from ``augmentations`` in random order, and applies
    them one after another to both the image and the mask.
    """

    augmentations = [apply_horizontal_flip, apply_vertical_flip, apply_rotation]

    def __init__(self, max_augment_count = 3):
        # Cannot apply more distinct augmentations than exist.
        self.max_augment_count = min(max_augment_count, len(self.augmentations))

    def __call__(self, img, mask):
        how_many = random.randint(0, self.max_augment_count)
        for transform in random.sample(self.augmentations, k=how_many):
            img, mask = transform(img, mask)
        return img, mask

In [None]:
class ScanDataset(Dataset):
    """Paired image / ROI-mask NIfTI slices from one split directory.

    Images (``img*``) and masks (``roi*``) are paired by sorted filename
    order, so both globs must yield the same number of files.

    Args:
        data_dir: directory containing the ``img*`` and ``roi*`` files.
        augmentations: callable ``(img, mask) -> (img, mask)`` applied to
            every sample, or a falsy value to disable augmentation. The
            default instance is created once and shared by all datasets
            that rely on it.

    Raises:
        ValueError: if the number of images and masks differs.
    """

    def __init__(self, data_dir, augmentations = RandomAugmentation(3)):
        self.imgs_path = sorted(glob.glob(f'{data_dir}/img*'))
        self.masks_path = sorted(glob.glob(f'{data_dir}/roi*'))
        if len(self.imgs_path) != len(self.masks_path):
            # Pairing is purely positional; a count mismatch would silently
            # attach masks to the wrong images.
            raise ValueError(
                f"{data_dir}: found {len(self.imgs_path)} images but "
                f"{len(self.masks_path)} masks"
            )
        self.augmentations = augmentations

    def __len__(self):
        return len(self.imgs_path)

    def __getitem__(self, idx):
        """Return ``(image_tensor, mask_tensor)`` for sample ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # `get_data()` was removed in nibabel 5.0; this keeps the stored dtype.
        img = np.asanyarray(nib.load(self.imgs_path[idx]).dataobj)
        mask = np.asanyarray(nib.load(self.masks_path[idx]).dataobj)
        # Fold label 2 into background (assumes labels {0, 1, 2} — TODO confirm).
        mask[mask==2] = 0

        if self.augmentations:
            img, mask = self.augmentations(img, mask)

        # Images are intensity-scaled by 1/255 (assumes a 0-255 range — TODO
        # confirm for these MRIs). Masks keep their label values: dividing
        # them by 255 would turn label 1 into ~0.004, breaking the BCE target
        # and every `labels > 0.5` metric.
        return self.to_tensor(img), self.to_tensor(mask, scale=1)

    def to_tensor(self, mat, scale=255):
        """Divide ``mat`` by ``scale`` and convert to a float32 channel-first tensor.

        A 2-D (H, W) array gets a leading channel axis -> (1, H, W); a 3-D
        (H, W, C) array is transposed to (C, H, W).
        """
        mat = mat / scale
        if mat.ndim == 2:
            mat = np.expand_dims(mat, 0)
        elif mat.ndim == 3:
            mat = mat.transpose(2,0,1)
        return torch.from_numpy(mat.astype('float32'))

In [None]:
class ScanDataModule():
    """Builds train/valid/test ``ScanDataset``s under ``data_dir`` and their DataLoaders.

    Expects sub-directories ``train``, ``valid`` and ``test``. Only the
    training split is augmented.

    Args:
        data_dir: root directory holding the three split directories.
        batch_size: batch size for all three loaders.
        shuffle: whether each loader shuffles its dataset.
    """

    def __init__(self, data_dir, batch_size = 4, shuffle = True):
        self.train_dataset = ScanDataset(os.path.join(data_dir, 'train'))
        # Evaluation splits must see the unmodified slices: random flips and
        # rotations would make validation/test metrics noisy and not
        # comparable across epochs.
        self.valid_dataset = ScanDataset(os.path.join(data_dir, 'valid'), augmentations=None)
        self.test_dataset = ScanDataset(os.path.join(data_dir, 'test'), augmentations=None)
        self.batch_size = batch_size
        self.shuffle = shuffle

    def train_loader(self):
        # drop_last avoids a trailing short batch during training.
        return DataLoader(self.train_dataset, batch_size = self.batch_size, shuffle = self.shuffle, drop_last = True)

    def valid_loader(self):
        return DataLoader(self.valid_dataset, batch_size = self.batch_size, shuffle = self.shuffle)

    def test_loader(self):
        return DataLoader(self.test_dataset, batch_size = self.batch_size, shuffle = self.shuffle)

In [None]:
# Loaders for the train/valid/test splits under data_dir: batches of 4, shuffled.
dataloader = ScanDataModule(data_dir=data_dir, shuffle=True, batch_size=4)

In [None]:
# Pull one training batch and show its first image next to its mask.
imgs, masks = next(iter(dataloader.train_loader()))
fig, ax = plt.subplots(1, 2, figsize=(15, 10))
for panel, sample in zip(ax, (imgs[0], masks[0])):
    panel.imshow(sample.squeeze(), cmap='gray', alpha=1)

In [None]:
from tqdm import tqdm #allows us to output a smart progress bar

def _run_epoch(model, loader, criterion, optimizer=None, prefix=""):
    """Run one pass over ``loader``: train when ``optimizer`` is given, else evaluate.

    Shows a tqdm progress bar whose description reports running means, with
    each metric name prefixed by ``prefix``. Returns the mean loss, IoU and
    Dice coefficient over the batches.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    bar = tqdm(loader)
    losses, ious, dcs = [], [], []
    with torch.set_grad_enabled(training):
        for imgs, masks in bar:
            imgs, masks = imgs.float().to(device), masks.float().to(device)
            # Keep the raw logits: BCEWithLogitsLoss and iou/DC apply the
            # sigmoid themselves, so rescaling the output (e.g. min-max to
            # [0, 1]) would squash every probability into [0.5, 0.73].
            logits = model(imgs)
            loss = criterion(logits, masks)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            losses.append(loss.item())
            ious.append(iou(logits, masks))
            dcs.append(DC(logits, masks))
            bar.set_description(
                f"{prefix}loss {np.mean(losses):.5f} "
                f"{prefix}iou {np.mean(ious):.5f} "
                f"{prefix}DC {np.mean(dcs):.5f}"
            )
    return np.mean(losses), np.mean(ious), np.mean(dcs)


def fit(model, dataloader, epochs=150, lr=3e-4):
    """Train ``model`` with Adam + BCE-with-logits, validating after every epoch.

    Args:
        model: network returning raw logits with the same shape as the masks.
        dataloader: object exposing ``train_loader()`` and ``valid_loader()``.
        epochs: number of passes over the training set.
        lr: Adam learning rate.

    Returns:
        dict of per-epoch means with keys 'loss', 'iou', 'DC' (training) and
        'evaluation_loss', 'evaluation_iou', 'evaluation_DC' (validation).
    """
    model.to(device)  # conventional order: move the model before building the optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.BCEWithLogitsLoss()
    hist = {'loss': [], 'iou': [], 'DC': [], 'evaluation_loss': [], 'evaluation_iou': [], 'evaluation_DC': []}
    for epoch in range(1, epochs + 1):
        train_loss, train_iou, train_dc = _run_epoch(
            model, dataloader.train_loader(), criterion, optimizer
        )
        hist['loss'].append(train_loss)
        hist['iou'].append(train_iou)
        hist['DC'].append(train_dc)

        eval_loss, eval_iou, eval_dc = _run_epoch(
            model, dataloader.valid_loader(), criterion, prefix="evaluation_"
        )
        hist['evaluation_loss'].append(eval_loss)
        hist['evaluation_iou'].append(eval_iou)
        hist['evaluation_DC'].append(eval_dc)

        print(f"\nEpoch {epoch}/{epochs} loss {train_loss:.5f} iou {train_iou:.5f} DC {train_dc:.5f} evaluation_loss {eval_loss:.5f} evaluation_iou {eval_iou:.5f} evaluation_DC {eval_dc:.5f}")
    return hist

In [None]:
# Train a fresh U-Net and report the elapsed time. perf_counter() is monotonic,
# so a wall-clock adjustment during a long run cannot distort the measurement.
start = time.perf_counter()
model = Unet()
hist = fit(model, dataloader, epochs=150)
finish = time.perf_counter()
timer = finish - start
print(f'Unet execution time is: {timer} seconds')

In [None]:
import pandas as pd
# One column per tracked metric, one row per epoch; plot all curves together.
df = pd.DataFrame.from_dict(hist)
df.plot(grid=True)
plt.show()

In [None]:
subj, slic = 16, 16

# `get_data()` was removed in nibabel 5.0; this keeps the stored dtype.
img = np.asanyarray(nib.load(f'{data_dir}/test/img_subj{subj:>03}_sl_{slic:>03}.nii').dataobj)
mask = np.asanyarray(nib.load(f'{data_dir}/test/roi_subj{subj:>03}_sl_{slic:>03}.nii').dataobj)
mask[mask==2] = 0

# Scale intensities by 1/255 exactly as ScanDataset.to_tensor does for training
# images, so the model sees inputs on the scale it was trained on.
img_tensor = torch.tensor(np.float32(img / 255)).unsqueeze(0).unsqueeze(0)
mask_tensor = torch.tensor(np.float32(mask)).unsqueeze(0).unsqueeze(0)

model.eval()
with torch.no_grad():
    output = model(img_tensor.to(device))
    # The model emits a single logit channel, so argmax is meaningless here
    # (argmax over axis 0 reduces the size-1 batch axis and is always 0).
    # Threshold the sigmoid at 0.5, the same rule iou/DC use.
    pred_mask = (torch.sigmoid(output) > 0.5).float()

# Pixel counts per predicted class.
print(np.unique(pred_mask.squeeze().cpu().numpy(), return_counts=True))
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(30,10))
ax1.imshow(img, cmap='gray')
ax2.imshow(mask, cmap='gray')
ax3.imshow(pred_mask.squeeze().cpu().numpy(), cmap='gray')
plt.show()