In [2]:
%matplotlib inline

In [1]:
import os
import matplotlib.pyplot as plt
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.transforms import Compose, ToTensor, Resize

import random
import datetime

import nibabel as nib
import torch
import torch.nn as nn


In [31]:
class DoubleConvBlock(nn.Module):
    """Basic U-Net building block made of two identical stages.

        Each stage is Conv2d (padding='same') -> BatchNorm2d -> ReLU -> Dropout(p=0.4),
        so the spatial size of the input is preserved and only the channel count changes.
    """

    def __init__(self, in_channels, out_channels, kernel_size = 3):
        """ Builds the two stages.

        Args:
            in_channels (int): channels of the incoming tensor.
            out_channels (int): channels produced by both stages.
            kernel_size (int): kernel size shared by both convolutions.
        """
        super().__init__()

        stages = []
        stage_in = in_channels
        for _ in range(2):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size, padding='same'),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout(0.4),
            ]
            # The second stage consumes what the first one produced
            stage_in = out_channels

        # Attribute name and layer order are kept so state_dict keys stay "conv.<idx>.*"
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Runs the input through both stages and returns the result."""
        return self.conv(x)

class Encoder(nn.Module):
    """Encoder (contracting path) of the U-Net.

        It consists of several double convolutional blocks; every block except the
        last one is followed by a 2x2 max pooling layer.

    Args:
        channels (List[int]): Channel sizes; block ``i`` maps ``channels[i]`` to
            ``channels[i+1]``.

    Example:
        channels = [1, 64, 128, 256, 512]

    """

    def __init__(self, channels):
        super(Encoder, self).__init__()
        self.encoderBlocks = nn.ModuleList()

        # Adds a convolutional block followed by a max pooling layer (except the last one)
        last_level = len(channels) - 2
        for level in range(len(channels) - 1):
            self.encoderBlocks.append(
                DoubleConvBlock(channels[level], channels[level + 1]))

            if level < last_level:
                self.encoderBlocks.append(nn.MaxPool2d(kernel_size=2))

    def forward(self, x):
        """Runs the input through every block.

        Args:
            x (torch.Tensor): input batch.

        Returns:
            List[torch.Tensor]: output of each convolutional block, ordered from
            the first (shallowest) to the last (deepest) block. Pooling outputs
            are not stored.
        """
        features_encoder = []
        for encoder_block in self.encoderBlocks:
            x = encoder_block(x)

            # Save output of each convolutional block (these feed the skip connections)
            if isinstance(encoder_block, DoubleConvBlock):
                features_encoder.append(x)

        return features_encoder


class Decoder(nn.Module):
    """Decoder (expanding path) of the U-Net.

        Each level is a transposed convolution that doubles the spatial size,
        followed by a double convolutional block applied to the upsampled tensor
        concatenated with the matching encoder feature map.

    Args:
        channels (List[int]): A list of channels for the convolutional blocks,
            in decreasing order.

    Example:
        channels = [512, 256, 128, 64]

    """

    def __init__(self, channels):
        super(Decoder, self).__init__()
        self.decoderBlocks = nn.ModuleList()

        # Add an upconvolution followed by a double convolutional block for each level
        for i in range(len(channels)-1):
            self.decoderBlocks.append(nn.ConvTranspose2d(
                channels[i], channels[i+1], 2, 2))
            # The block receives the upsampled tensor (channels[i+1]) concatenated
            # with the skip connection, so it is declared with channels[i] inputs;
            # this only matches when the skip has channels[i] - channels[i+1] channels.
            self.decoderBlocks.append(
                DoubleConvBlock(channels[i], channels[i+1]))

    def _center_crop(self, feature, target_size):
        """Crops the input tensor to the spatial size of a reference tensor.

        Args:
            feature (torch.Tensor): 4-D tensor to crop.
            target_size (torch.Tensor): 4-D reference tensor; only its last two
                dimensions (height and width) are used.

        Returns:
            cropped feature (torch.Tensor): centered window of ``feature`` with
            the height and width of ``target_size``.
        """
        _, _, H, W = target_size.shape
        _, _, h, w = feature.shape

        # Calculate the starting indices for the crop
        h_start = (h - H) // 2
        w_start = (w - W) // 2

        # Crop and return the tensor
        return feature[:, :, h_start:h_start+H, w_start:w_start+W]

    def forward(self, x, features_encoder): #-> torch.Tensor:
        """Upsamples ``x`` level by level, merging the encoder features.

        Args:
            x (torch.Tensor): deepest encoder output.
            features_encoder (List[torch.Tensor]): skip connections ordered from
                the deepest to the shallowest level (one per decoder level).

        Returns:
            torch.Tensor: output of the last double convolutional block.
        """
        for i, decoder_block in enumerate(self.decoderBlocks):

            # Concatenate the output of the encoder with the output of the decoder
            if isinstance(decoder_block, DoubleConvBlock):
                # Blocks alternate upconv/conv, so conv block i uses skip i//2.
                # A separate name is used so the list of skips is never overwritten.
                skip = self._center_crop(features_encoder[i//2], x)
                x = torch.cat([x, skip], dim=1)

            # Apply the upconv or double convolutional block
            x = decoder_block(x)
        return x

class UNet(nn.Module):
    """U-Net assembled from an Encoder, a Decoder and a 1x1 output convolution.

    Args:
        channels (List[int]): Channel sizes for the encoder; the first entry is
            the number of input channels. The decoder uses the same list reversed,
            without the input-channel entry.
        out_channels (int): The number of output channels.

    Example:
        model = UNet(channels=[1, 64, 128, 256, 512], out_channels=1)
    """

    def __init__(self, channels, out_channels):
        super().__init__()
        # Reverse the list and drop its last element (the network input channels)
        decoder_channels = channels[::-1][:-1]

        self.encoder = Encoder(channels)
        self.decoder = Decoder(decoder_channels)
        # 1x1 convolution from the shallowest feature width to the requested outputs
        self.output = nn.Conv2d(channels[1], out_channels, kernel_size=1)

    def forward(self, x):
        """Returns the raw output of the final 1x1 convolution (no activation applied)."""
        # Deepest feature map first; the remaining ones act as skip connections
        pyramid = self.encoder(x)[::-1]
        bottleneck, skips = pyramid[0], pyramid[1:]

        decoded = self.decoder(bottleneck, skips)
        return self.output(decoded)


In [32]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid-activated predictions.

    The loss is ``1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``
    where ``p = sigmoid(inputs)`` and ``t = targets``, both flattened over every
    dimension (the whole batch is treated as one set of pixels).

    Args:
        weight: accepted for signature compatibility; not used.
        size_average: accepted for signature compatibility; not used.
    """
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Computes the loss.

        Args:
            inputs (torch.Tensor): raw model outputs (a sigmoid is applied here).
            targets (torch.Tensor): ground truth with the same number of elements.
            smooth (float): constant added to numerator and denominator; keeps
                the ratio defined when both tensors are all zeros.

        Returns:
            torch.Tensor: scalar loss.
        """
        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)

        #flatten label and prediction tensors
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. slices
        # taken along a non-leading dimension
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)

        return 1 - dice

In [38]:
class DatasetMRI(Dataset):
    """Dataset yielding patches of one 2-D slice per NIfTI volume.

    For every row of the CSV file the image and segmentation volumes are loaded,
    the slice with the largest segmentation sum is selected, the image slice is
    min-max normalized and both slices are split into square patches.

    Args:
        csv_file (str): path to a CSV file with the columns ``id`` (folder name),
            ``p_id`` (image file name) and ``p_id_seg`` (segmentation file name).
        root_dir (str): directory containing one folder per ``id``.
        augment (bool): if True, a random blur may be applied to each image patch.
        patch_size (int): side length of the square patches.
    """

    def __init__(self, csv_file, root_dir, augment = False, patch_size = 256):
        super(DatasetMRI, self).__init__()

        # Train and test csv file
        self.csv_file = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.augment = augment
        self.patch_size = patch_size


    def __len__(self):
        """Number of rows in the CSV file."""
        return len(self.csv_file)

    def __getitem__(self,index):
        """Loads one sample.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: image patches and segmentation
            patches, each of shape (num_patches, 1, patch_size, patch_size).
        """
        row = self.csv_file.iloc[index]
        # Get folder name at specified index
        folder_name = row['id']
        # Get the filenames at the specified index
        filename_img = row["p_id"]
        filename_seg = row["p_id_seg"]

        # Full path to the image file
        image_path = os.path.join(self.root_dir, folder_name, filename_img)
        segmentation_path = os.path.join(self.root_dir, folder_name, filename_seg)

        # MRI scan
        image = self.read_nifti(image_path)
        # Segmentation
        seg = self.read_nifti(segmentation_path)

        # Get biggest tumor slice
        biggest_img, biggest_seg = self.get_slice(image, seg)
        norm_img = self.normalize(biggest_img)

        # Add channel dimension and convert to tensor
        image_tensor = torch.from_numpy(norm_img).unsqueeze(0).float()
        seg_tensor = torch.from_numpy(biggest_seg).unsqueeze(0).float()

        image_tensor = self.partition(image_tensor, self.patch_size)
        seg_tensor = self.partition(seg_tensor, self.patch_size)

        # Only the image is augmented; the blur does not move pixels, so the
        # segmentation stays aligned
        if self.augment:
            image_tensor = [self.apply_augmentation(chunk) for chunk in image_tensor]

        image_tensor = torch.stack(image_tensor)
        seg_tensor = torch.stack(seg_tensor)

        return image_tensor, seg_tensor

    def read_nifti(self, filepath):
        """Read a NIfTI file.

            Args:
                filepath (str): Path to the NIfTI file.
            Returns:
                image (numpy array)
        """
        # Load NIfTI file
        img = nib.load(filepath)
        # Get the image data as a numpy array
        img_data = img.get_fdata()

        return img_data

    def partition(self, image, patch_size):
        """Partition the image into patch_size squares.

        Patches on the bottom/right border that would be smaller than
        patch_size are zero-padded on their bottom/right side.

        Args:
            image (Tensor): tensor of shape (channels, height, width)
            patch_size (int): side length of each patch

        Returns:
            patches (List): patches in row-major order, each of shape
                (channels, patch_size, patch_size)

        """
        patches = []
        _, height, width = image.shape
        for i in range(0, height, patch_size):
            for j in range(0, width, patch_size):
                patch = image[:, i:i+patch_size, j:j+patch_size]
                # Padding if the patch is smaller than patch_size
                if patch.shape[1] < patch_size or patch.shape[2] < patch_size:
                    # Pad order is (left, right, top, bottom) for the last two dims
                    pad = torch.nn.functional.pad(patch, (0, patch_size-patch.shape[2], 0, patch_size-patch.shape[1]), 'constant', 0)
                    patches.append(pad)
                else:
                    patches.append(patch)
        return patches

    def get_slice(self, image, seg):
        """Choose the slice where the biggest tumour map is.

        Slices are taken along the third axis. The "size" of a slice is the sum
        of its segmentation values; the first slice with the largest positive
        sum wins, and slice 0 is used when no slice has a positive sum.

        Args:
            image (array): MRI volume
            seg (array): segmentation volume

        Returns:
            biggest_img (array): corresponds to biggest slice from image
            biggest_seg (array): corresponds to biggest slice from segmentation map
        """

        max_area = 0
        idx_slice = 0

        for i in range(seg.shape[2]):
            # Calculate the area of the tumor in the current slice
            tumor_area = np.sum(seg[:, :, i])
            # Update the maximum area and slice index if the current slice has a larger area
            if tumor_area > max_area:
                max_area = tumor_area
                idx_slice = i

        # Get biggest slice
        biggest_img = image[:,:,idx_slice]
        biggest_seg = seg[:,:,idx_slice]

        return biggest_img, biggest_seg

    def normalize(self, image):
        """Min-max normalize the image to the range [0, 1].

        Args:
            image (numpy.ndarray): The image to be normalized.

        Returns:
            numpy.ndarray: ``(image - min) / (max - min)``. A constant image
            (max == min) is returned as all zeros to avoid dividing by zero.
        """
        min_val = np.min(image)
        max_val = np.max(image)
        value_range = max_val - min_val

        if value_range == 0:
            return np.zeros_like(image, dtype=float)

        return (image - min_val) / value_range


    def apply_augmentation(self, image_tensor):
        """Applies, with probability 0.5, a Gaussian blur to the image.

        The blur uses an anisotropic kernel (5 x 9) with sigma sampled in [2, 5],
        intended to imitate motion blur.
        """

        # Apply random Gaussian blur
        if random.random() > 0.5:
            blur_transform = transforms.GaussianBlur(kernel_size=(5, 9), sigma=(2, 5))
            image_tensor = blur_transform(image_tensor)

        return image_tensor


In [39]:
# Build one DataLoader per split; the CSV files already contain the split.
# Only the training split is augmented and shuffled.

batch_size = 5
rootdir = os.path.abspath('data_images')

# Training split
mri_dataset = DatasetMRI(
    csv_file=os.path.abspath('training.csv'),
    root_dir=rootdir,
    augment=True,
    patch_size=256,
)
train_loader = DataLoader(mri_dataset, batch_size=batch_size, shuffle=True)

# Validation split
val_dataset = DatasetMRI(
    csv_file=os.path.abspath('validation.csv'),
    root_dir=rootdir,
    augment=False,
    patch_size=256,
)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Test split
test_dataset = DatasetMRI(
    csv_file=os.path.abspath('test.csv'),
    root_dir=rootdir,
    augment=False,
    patch_size=256,
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Visualize num_samples samples from train_loader
num_samples = 5
count = 0

for imgs, lbs in train_loader:
    # Batches are (batch, patches, 1, H, W). Patch 1 is shown when the slice was
    # split into several patches; otherwise fall back to the only patch.
    patch_idx = min(1, imgs.shape[1] - 1)

    for j in range(len(imgs)):
        # Same bound as the outer check, so exactly num_samples pairs are shown
        if count >= num_samples:
            break
        plt.figure()
        plt.imshow(imgs[j][patch_idx,0,:,:],cmap='gray')
        plt.title(f'Sample {count + 1}: image')
        plt.show()

        plt.figure()
        plt.imshow(lbs[j][patch_idx,0,:,:],cmap='gray')
        plt.title(f'Sample {count + 1}: segmentation')
        plt.show()
        count+=1
    if count >= num_samples:
        break



In [None]:
# Initialize the model, loss function, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(channels=[1, 64, 128, 256, 512, 1024], out_channels=1).to(device)
criterion = DiceLoss()  # Adjust the loss function as per your task
optimizer = optim.Adam(model.parameters(), lr=0.001)
# NOTE(review): `verbose` is reported as deprecated by recent PyTorch releases —
# confirm against the installed version.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

# torch.save does not create missing directories
os.makedirs("models_unet", exist_ok=True)

num_epochs = 20
train_losses = []
val_losses = []
best_val_loss = np.inf
for epoch in range(num_epochs):
    # Training loop
    model.train()
    train_loss = 0.0
    train_steps = 0
    for image, true_mask in train_loader:
        # Batches are (batch, patches, 1, H, W): iterate over the patch axis and
        # feed the model one (batch, 1, H, W) tensor per patch position.
        for partition in range(image.shape[1]):
            # contiguous() because slicing along dim 1 yields a strided view
            images = image[:, partition].contiguous().to(device)
            masks = true_mask[:, partition].contiguous().to(device)

            optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_steps += 1

    # Mean loss per optimization step (guarded against an empty loader)
    train_loss /= max(train_steps, 1)
    train_losses.append(train_loss)

    # Validation loop (uses the validation split; the test split stays untouched)
    model.eval()
    val_loss = 0.0
    val_steps = 0
    with torch.no_grad():
        for val_image, val_mask in val_loader:
            for partition in range(val_image.shape[1]):
                images = val_image[:, partition].contiguous().to(device)
                masks = val_mask[:, partition].contiguous().to(device)

                outputs = model(images)
                loss = criterion(outputs, masks)

                val_loss += loss.item()
                val_steps += 1

    val_loss /= max(val_steps, 1)
    val_losses.append(val_loss)

    scheduler.step(val_loss)


    print(f"Epoch [{epoch + 1}/{num_epochs}], Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}")

    # Keep a checkpoint every time the validation loss improves
    if val_loss < best_val_loss:
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        best_val_loss = val_loss
        torch.save(model.state_dict(), f"models_unet/unet_model_{epoch}_{timestamp}.pth")



In [19]:
# Path of the checkpoint loaded by the prediction cell below.
# NOTE(review): the training cell writes checkpoints as
# "models_unet/unet_model_{epoch}_{timestamp}.pth" — confirm this file exists
# or point this at one of those checkpoints.
model_path = 'unet_model.pth'

In [None]:
# Visualize test predictions
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(channels=[1, 64, 128, 256, 512, 1024], out_channels=1).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
model.load_state_dict(torch.load(model_path, map_location=device))

model.eval()
with torch.no_grad():
    for i, (images, masks) in enumerate(test_loader):
        # Batches are (batch, patches, 1, H, W): take one patch position at a
        # time so the model receives a (batch, 1, H, W) tensor.
        for partition in range(images.shape[1]):
            image = images[:, partition].contiguous().to(device)
            mask = masks[:, partition].contiguous().to(device)
            outputs = model(image)

            for batch in range(image.shape[0]):
                img = image[batch].cpu().numpy().squeeze()
                msk = mask[batch].cpu().numpy().squeeze()
                # The model returns logits; apply the sigmoid before thresholding
                output = torch.sigmoid(outputs[batch]).cpu().numpy().squeeze()
                # Binarize the output for visualization
                output = (output > 0.5).astype(np.uint8)

                # Plot the original image, true mask, and predicted mask
                fig, axes = plt.subplots(1, 3, figsize=(15, 5))
                axes[0].imshow(img, cmap='gray')
                axes[0].set_title('Original Image')
                axes[0].axis('off')

                axes[1].imshow(msk, cmap='gray')
                axes[1].set_title('Ground truth')
                axes[1].axis('off')

                axes[2].imshow(output, cmap='gray')
                axes[2].set_title('Predicted segmentation')
                axes[2].axis('off')

                plt.show()
                # One figure per patch: release it so memory does not grow
                plt.close(fig)