### Import Libraries

In [None]:
from tqdm import tqdm
import os
import time
from random import randint
 
import gc 
import numpy as np
from scipy import stats
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR
from sklearn.model_selection import KFold

import nibabel as nib
import pydicom as pdm
import nilearn as nl
import nilearn.plotting as nlplt
import h5py

import matplotlib.pyplot as plt
from matplotlib import cm


import seaborn as sns
import imageio
from skimage.transform import resize
from skimage.util import montage

from IPython.display import clear_output


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import MSELoss

import albumentations as A
from albumentations import Compose, HorizontalFlip


from monai.metrics.utils import do_metric_reduction
from monai.metrics.utils import get_mask_edges, get_surface_distance
from monai.metrics import CumulativeIterationMetric

from medcam import medcam

import warnings
warnings.simplefilter("ignore")

### Load an example of the 3D Brain MRI

In [None]:
import nibabel as nib
import numpy as np

# FLAIR volume and manual segmentation of training patient 275.
sample_filename = r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\BraTS2020_TrainingData\MICCAI_BraTS2020_TrainingData\BraTS20_Training_275\BraTS20_Training_275_flair.nii'
sample_filename_mask = r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\BraTS2020_TrainingData\MICCAI_BraTS2020_TrainingData\BraTS20_Training_275\BraTS20_Training_275_seg.nii'

# Each NIfTI file is read into an array and rotated by 90 degrees in the
# plane of its first two axes (np.rot90 defaults) before it is displayed.
sample_img = np.rot90(np.asanyarray(nib.load(sample_filename).dataobj))
sample_mask = np.rot90(np.asanyarray(nib.load(sample_filename_mask).dataobj))

print("img shape ->", sample_img.shape)
print("mask shape ->", sample_mask.shape)

In [None]:
# All modalities of patient 275 are read from the same absolute folder that
# the FLAIR, T2 and segmentation files use, so the cell no longer depends on
# the notebook's current working directory.
sample_dir = r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\BraTS2020_TrainingData\MICCAI_BraTS2020_TrainingData\BraTS20_Training_275'

# T1
sample_filename2 = os.path.join(sample_dir, 'BraTS20_Training_275_t1.nii')
sample_img2 = np.rot90(np.asanyarray(nib.load(sample_filename2).dataobj))

# T2
sample_filename3 = os.path.join(sample_dir, 'BraTS20_Training_275_t2.nii')
sample_img3 = np.rot90(np.asanyarray(nib.load(sample_filename3).dataobj))

# T1 contrast-enhanced
sample_filename4 = os.path.join(sample_dir, 'BraTS20_Training_275_t1ce.nii')
sample_img4 = np.rot90(np.asanyarray(nib.load(sample_filename4).dataobj))

# The segmentation uses labels 1, 2 and 4. Three binary masks are derived
# from it; any value other than 1/2/4 is left untouched, as before.

# Whole tumour (WT): labels 1, 2 and 4 all become 1.
mask_WT = sample_mask.copy()
mask_WT[(mask_WT == 2) | (mask_WT == 4)] = 1

# Tumour core (TC): labels 1 and 4 become 1, label 2 is removed.
mask_TC = sample_mask.copy()
mask_TC[mask_TC == 2] = 0
mask_TC[mask_TC == 4] = 1

# Enhancing tumour (ET): only label 4 becomes 1, labels 1 and 2 are removed.
mask_ET = sample_mask.copy()
mask_ET[(mask_ET == 1) | (mask_ET == 2)] = 0
mask_ET[mask_ET == 4] = 1

### Visualize a slice of the Brain Tumour and the MRIs

In [None]:
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt

# Legend-for-imshow recipe:
# https://stackoverflow.com/questions/25482876/how-to-add-legend-to-imshow-in-matplotlib
slice_idx = 70  # slice along the third axis shown in every panel

fig = plt.figure(figsize=(20, 10))

# 2 rows x 4 columns; both rows have the same height.
gs = gridspec.GridSpec(nrows=2, ncols=4, height_ratios=[1, 1])

# Top row: one panel (with its own colorbar) per MRI modality.
modality_axes, modality_images = [], []
for column, (panel_title, volume) in enumerate([
    ("FLAIR", sample_img),
    ("T1", sample_img2),
    ("T2", sample_img3),
    ("T1 contrast", sample_img4),
]):
    modality_ax = fig.add_subplot(gs[0, column])
    modality_im = modality_ax.imshow(volume[:, :, slice_idx], cmap='bone')
    modality_ax.set_title(panel_title, fontsize=18, weight='bold', y=-0.2)
    fig.colorbar(modality_im)
    modality_axes.append(modality_ax)
    modality_images.append(modality_im)

# Keep the original per-panel names available to later cells.
ax0, ax1, ax2, ax3 = modality_axes
flair, t1, t2, t1ce = modality_images

# Bottom row: the three binary masks stacked on one wide panel.
ax4 = fig.add_subplot(gs[1, 0:4])

l1 = ax4.imshow(mask_WT[:, :, slice_idx], cmap='summer')
# Tumour core drawn on top of the whole tumour; zero voxels are masked out.
l2 = ax4.imshow(np.ma.masked_where(mask_TC[:, :, slice_idx] == False, mask_TC[:, :, slice_idx]),
                cmap='rainbow', alpha=0.6)
# Enhancing tumour drawn on top of the tumour core.
l3 = ax4.imshow(np.ma.masked_where(mask_ET[:, :, slice_idx] == False, mask_ET[:, :, slice_idx]),
                cmap='winter', alpha=0.6)

ax4.set_title("Segmented Mask", fontsize=20, weight='bold', y=-0.1)

# Hide the pixel-coordinate ticks on every panel.
for panel in [ax0, ax1, ax2, ax3, ax4]:
    panel.set_axis_off()

# Colour each layer uses for the value 1.
colors = [im.cmap(im.norm(1)) for im in [l1, l2, l3]]

# Labels are listed in the same order as the layers l1, l2, l3. Because the
# later layers are painted over the earlier ones, the part of the whole-tumour
# layer that stays visible is the edema and the part of the tumour-core layer
# that stays visible is the non-enhancing core.
labels = ['Peritumoral Edema ', 'Non-Enhancing tumor core', 'GD-enhancing tumor']
patches = [mpatches.Patch(color=layer_color, label=f"{layer_label}")
           for layer_color, layer_label in zip(colors, labels)]
# Use the patches as legend handles.
plt.legend(handles=patches, bbox_to_anchor=(1.1, 0.65), loc=2, borderaxespad=0.4, fontsize='xx-large',
           title='Mask Labels', title_fontsize=18, edgecolor="black", facecolor='#c5c6c7')

plt.suptitle("Multimodal Scans -  Data | Manually-segmented mask - Target", fontsize=20, weight='bold')

# fig.savefig("data_sample.png", format="png",  pad_inches=0.2, transparent=False, bbox_inches='tight')
# fig.savefig("data_sample.svg", format="svg",  pad_inches=0.2, transparent=False, bbox_inches='tight')

### Generating a 3D GIF of the Brain Tumour

In [None]:
class Image3dToGIF3d:
    """
    Render a 3D volume as voxels on 3D axes and optionally export a rotating GIF.

    Parameters:
        img_dim: cube shape the volume is resized to.
        figsize: figure size for plotting in inches.
        binary: treat the volume as a 0/1 mask (different colormap and transform).
        normalizing: min-max normalize a non-binary cube before colouring it.
    """

    def __init__(self,
                 img_dim: tuple = (55, 55, 55),
                 figsize: tuple = (15, 10),
                 binary: bool = False,
                 normalizing: bool = True,
                 ):
        """Store the rendering options (and echo the target cube shape)."""
        self.img_dim = img_dim
        print(img_dim)
        self.figsize = figsize
        self.binary = binary
        self.normalizing = normalizing

    def _explode(self, data: np.ndarray):
        """
        Spread the voxels apart: the first three axes grow from n to 2n - 1
        and the original values land on the even indices, zeros in between.
        Any trailing axes (e.g. RGBA) keep their size.
        """
        dims = np.array(data.shape)
        spread_dims = np.concatenate([dims[:3] * 2 - 1, dims[3:]])
        spread = np.zeros(spread_dims, dtype=data.dtype)
        spread[::2, ::2, ::2] = data
        return spread

    def _expand_coordinates(self, indices: np.ndarray):
        """Shift every odd-indexed grid plane by one along its own axis (in place)."""
        x, y, z = indices
        x[1::2, :, :] += 1
        y[:, 1::2, :] += 1
        z[:, :, 1::2] += 1
        return x, y, z

    def _normalize(self, arr: np.ndarray):
        """Min-max normalize the array to the range [0, 1]."""
        lowest = np.min(arr)
        span = np.max(arr) - lowest
        return (arr - lowest) / span

    def _scale_by(self, arr: np.ndarray, factor: int):
        """
        Stretch the values around their mean.
        Parameters:
            arr: 3d image to scale.
            factor: multiplier applied to the deviation from the mean.
        """
        # Adding the mean back keeps the mean of the result unchanged.
        center = np.mean(arr)
        deviation = arr - center
        return deviation * factor + center

    def get_transformed_data(self, data: np.ndarray):
        """Prepare a volume for plotting: normalization, scaling, resizing."""
        if not self.binary:
            # Shift/clip and gamma-correct the normalized intensities, then
            # stretch them around the mean before resizing to the target cube.
            normalized = np.clip(self._normalize(data) - 0.1, 0, 1) ** 0.4
            stretched = np.clip(self._scale_by(normalized, 2) - 0.1, 0, 1)
            return resize(stretched, self.img_dim, preserve_range=True)

        # Binary masks are only resized and forced back to {0, 1}.
        resized_mask = resize(data, self.img_dim, preserve_range=True)
        return np.clip(resized_mask.astype(np.uint8), 0, 1).astype(np.float32)

    def plot_cube(self,
                  cube,
                  title: str = '',
                  init_angle: int = 0,
                  make_gif: bool = False,
                  path_to_save: str = 'filename.gif'
                  ):
        """
        Plot 3d data as voxels.
        Parameters:
            cube: 3d data.
            title: title for the figure.
            init_angle: azimuth of the initial view (from 0-360).
            make_gif: if True, save one frame every 5 degrees and write a GIF.
            path_to_save: path of the GIF file.
        """
        if self.binary:
            facecolors = cm.winter(cube)
            print("binary")
        else:
            if self.normalizing:
                cube = self._normalize(cube)
            facecolors = cm.gist_stern(cube)
            print("not binary")

        # The voxel value doubles as its alpha channel.
        facecolors[:, :, :, -1] = cube
        facecolors = self._explode(facecolors)

        # Only voxels with non-zero alpha are drawn.
        filled = facecolors[:, :, :, -1] != 0
        grid = np.indices(np.array(filled.shape) + 1)
        x, y, z = self._expand_coordinates(grid)

        with plt.style.context("dark_background"):
            fig = plt.figure(figsize=self.figsize)
            ax = fig.add_subplot(projection='3d')

            ax.view_init(30, init_angle)
            ax.set_xlim(right=self.img_dim[0] * 2)
            ax.set_ylim(top=self.img_dim[1] * 2)
            ax.set_zlim(top=self.img_dim[2] * 2)
            ax.set_title(title, fontsize=18, y=1.05)

            ax.voxels(x, y, z, filled, facecolors=facecolors, shade=False)

            if not make_gif:
                plt.show()
            else:
                frames = []
                for angle in tqdm(range(0, 360, 5)):
                    ax.view_init(30, angle)
                    frame_file = f"{angle}.png"
                    # Frames are written to the working directory and kept there.
                    plt.savefig(frame_file, dpi=120, format='png', bbox_inches='tight')
                    frames.append(imageio.imread(frame_file))
                imageio.mimsave(path_to_save, frames)
                plt.close()

                
class ShowResult:
    """Plot ground-truth and predicted masks over the first image channel as slice montages."""

    def mask_preprocessing(self, mask):
        """
        Split a mask batch into one montage per channel.

        Takes the first sample of `mask` and builds a montage from each of its
        first three channels, returned in channel order (named WT, TC, ET here).
        NOTE(review): no tensor-to-numpy conversion happens here, so `mask`
        presumably has to be array-like already - confirm with the callers.
        """
        print(mask.shape)
        mask_WT, mask_TC, mask_ET = (
            montage(mask[0, channel, :, :, :]) for channel in range(3)
        )
        return mask_WT, mask_TC, mask_ET

    def image_preprocessing(self, image):
        """
        Return a montage of the first channel of `image`, used as the
        background for the ground-truth and prediction overlays.
        """
        volume = image.squeeze().cpu().detach().numpy()
        return montage(volume[0, :, :, :])

    def plot(self, image, ground_truth, prediction):
        """Show the ground-truth and predicted masks side by side over the image montage."""
        image = self.image_preprocessing(image)
        gt_layers = self.mask_preprocessing(ground_truth)
        pr_layers = self.mask_preprocessing(prediction)

        fig, axes = plt.subplots(1, 2, figsize=(35, 30))
        for panel in axes:
            panel.axis("off")

        # One colormap per mask channel, drawn in channel order.
        layer_cmaps = ('summer', 'rainbow', 'Wistia')
        for panel, heading, layers in (
            (axes[0], "Ground Truth", gt_layers),
            (axes[1], "Prediction", pr_layers),
        ):
            panel.set_title(heading, fontsize=35, weight='bold')
            panel.imshow(image, cmap='bone')
            for layer, layer_cmap in zip(layers, layer_cmaps):
                # Zero pixels are masked so the image stays visible underneath.
                panel.imshow(np.ma.masked_where(layer == False, layer),
                             cmap=layer_cmap, alpha=0.6)

        plt.tight_layout()

        plt.show()
        


def merging_two_gif(path1: str, path2: str, name_to_save: str):
    """
    Merge two GIFs side by side, frame by frame.

    Parameters:
        path1: path to gif with ground truth (left half).
        path2: path to gif with prediction (right half).
        name_to_save: name for saving new GIF.

    If the inputs have different lengths, only as many frames as the shorter
    one has are written. All three file handles are closed even when reading
    or stacking a frame raises.
    """
    # Local import: contextlib is not among the file-level imports.
    from contextlib import closing

    with closing(imageio.get_reader(path1)) as gif1, \
            closing(imageio.get_reader(path2)) as gif2:
        number_of_frames = min(gif1.get_length(), gif2.get_length())

        with closing(imageio.get_writer(name_to_save)) as new_gif:
            for _ in range(number_of_frames):
                left_frame = gif1.get_next_data()
                right_frame = gif2.get_next_data()
                # Concatenate along the width axis to place the frames side by side.
                new_gif.append_data(np.hstack((left_frame, right_frame)))
    
#merging_two_gif('BraTS20_Training_001_flair_3d.gif',
#                'BraTS20_Training_001_flair_3d.gif', 
#                'result.gif')

def get_all_csv_file(root: str) -> list:
    """
    Collect the paths of all ``.csv`` files under `root`, searched recursively.

    Parameters:
        root: directory to walk.
    Returns:
        Sorted list of file paths (directory joined with file name), so the
        result is the same on every run. The number of files found is printed.
    """
    csv_paths = sorted(
        os.path.join(dirname, filename)
        for dirname, _, filenames in os.walk(root)
        for filename in filenames
        if filename.endswith(".csv")
    )
    print(f"Extracted {len(csv_paths)} csv files.")
    return csv_paths

### Path Directories

In [None]:
class GlobalConfig:
    """Project-wide locations (data, checkpoints, training logs) and the random seed."""

    # Everything below lives under this project folder.
    root_dir = r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour'

    # Dataset locations and the CSV that indexes the patients.
    train_root_dir = root_dir + r'\BraTS2020_TrainingData\MICCAI_BraTS2020_TrainingData'
    test_root_dir = root_dir + r'\test_df'
    path_to_csv = 'tumourCSV.csv'

    # Directories where each model's checkpoints are saved.
    UNet_checkpoint_dir = root_dir + r'\UNet model'
    ResUNet_checkpoint_dir = root_dir + r'\ResUNet model'
    Att_checkpoint_dir = root_dir + r'\AttUNet model'

    # Training log of each model, stored next to its checkpoints.
    train_logs_path = UNet_checkpoint_dir + r'\train_log.csv'
    ResUNet_train_logs_path = ResUNet_checkpoint_dir + r'\train_log.csv'
    AttUNet_train_logs_path = Att_checkpoint_dir + r'\train_log.csv'

    seed = 55
    
def seed_everything(seed: int):
    """
    Seed every random number generator this notebook can draw from.

    Covers Python's `random` module (whose `randint` the file imports),
    NumPy's global generator, PyTorch on the CPU and, when CUDA is
    available, PyTorch on every GPU.

    Parameters:
        seed: the seed value applied to all generators.
    """
    # Local import: the file-level imports only bring in `random.randint`.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    
# Shared configuration object used by the cells below; the random number
# generators are seeded once with its `seed` value.
config = GlobalConfig()
seed_everything(config.seed)

### Create a 3D scatterplot of the brain using Plotly (Interactive)

In [None]:
class ImageReader:
    """
    Read one modality of a BraTS training scan and its segmentation, slice by slice.

    Parameters:
        root: folder containing the `BraTS20_Training_XXX` patient folders.
        img_size: side length every slice is resized to.
        normalize: divide each slice by its own maximum and keep it as float32;
            otherwise the slice is cast to uint8.
        single_class: collapse all tumour labels into a single 0/1 mask.
    """

    def __init__(self, root:str, img_size:int=256, normalize:bool=False, single_class:bool=False):
        pad_size = 256 if img_size > 256 else 224
        # Pad small slices up to pad_size, then resize to the requested size.
        self.resize = A.Compose(
            [
                A.PadIfNeeded(min_height=pad_size, min_width=pad_size, value=0),
                A.Resize(img_size, img_size)
            ]
        )
        self.normalize=normalize
        self.single_class=single_class
        self.root=root

    def read_file(self, path:str) -> dict:
        """
        Load the scan at `path` and the segmentation stored next to it.

        The segmentation path is obtained by swapping the part of `path`
        after its last underscore (e.g. `flair.nii`) for `seg.nii`.

        Returns:
            dict with `scan` and `segmentation` (slices stacked along axis 0)
            and `orig_shape` (shape of the volume as stored on disk).
        """
        scan_type = path.split('_')[-1]
        raw_image = nib.load(path).get_fdata()
        # Swap only the trailing suffix, so an identical substring earlier in
        # the path (e.g. in a folder name) is left alone.
        mask_path = path[:len(path) - len(scan_type)] + 'seg.nii'
        raw_mask = nib.load(mask_path).get_fdata()
        processed_frames, processed_masks = [], []
        for frame_idx in range(raw_image.shape[2]):
            frame = raw_image[:, :, frame_idx]
            mask = raw_mask[:, :, frame_idx]
            if self.normalize:
                # All-zero slices are left as they are to avoid dividing by zero.
                if frame.max() > 0:
                    frame = frame/frame.max()
                frame = frame.astype(np.float32)
            else:
                # NOTE(review): a uint8 cast cannot represent intensities above
                # 255 - confirm this is intended for un-normalized scans.
                frame = frame.astype(np.uint8)
            resized = self.resize(image=frame, mask=mask)
            processed_frames.append(resized['image'])
            processed_masks.append(1*(resized['mask'] > 0) if self.single_class else resized['mask'])
        return {
            'scan': np.stack(processed_frames, 0),
            'segmentation': np.stack(processed_masks, 0),
            'orig_shape': raw_image.shape
        }

    def load_patient_scan(self, idx:int, scan_type:str='flair') -> dict:
        """
        Load one modality of the training patient with number `idx`.

        Parameters:
            idx: patient number; zero-padded to three digits in the file name.
            scan_type: modality suffix in the file name, e.g. `flair` or `t1`.
        """
        patient_id = str(idx).zfill(3)
        scan_filename = f'{self.root}/BraTS20_Training_{patient_id}/BraTS20_Training_{patient_id}_{scan_type}.nii'
        return self.read_file(scan_filename)
    
import plotly.graph_objects as go
import numpy as np
import plotly


def generate_3d_scatter(
    x:np.array, y:np.array, z:np.array, colors:np.array,
    size:int=3, opacity:float=0.2, scale:str='Teal',
    hover:str='skip', name:str='MRI'
) -> go.Scatter3d:
    """
    Build a marker-only Plotly 3D scatter trace.

    Parameters:
        x, y, z: point coordinates.
        colors: value(s) mapped onto the colorscale.
        size, opacity: marker size and transparency.
        scale: name of the Plotly colorscale.
        hover: value passed as the trace's `hoverinfo`.
        name: trace name.
    """
    marker_style = dict(
        size=size, opacity=opacity,
        color=colors, colorscale=scale
    )
    trace = go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers', hoverinfo=hover,
        marker=marker_style,
        name=name
    )
    return trace


class ImageViewer3d():
    """
    Turn a scan loaded by an `ImageReader` into an interactive 3D scatter figure.

    Parameters:
        reader: object whose `load_patient_scan` provides the scan dict.
        mri_downsample: keep every n-th non-zero MRI voxel.
        mri_colorscale: Plotly colorscale used for the MRI points.
    """
    def __init__(
        self, reader:ImageReader, mri_downsample:int=10, mri_colorscale:str='Ice'
    ) -> None:
        self.reader = reader
        self.mri_downsample = mri_downsample
        self.mri_colorscale = mri_colorscale

    def load_clean_mri(self, image:np.array, orig_dim:int) -> dict:
        """Return coordinates and intensities of a subsample of the non-zero voxels."""
        # Ratio between the resized and the original in-plane size; dividing
        # by it maps x/y back to the original scale.
        shape_offset = image.shape[1]/orig_dim
        z, x, y = (image > 0).nonzero()
        # Only every mri_downsample-th voxel is kept for the resulting image.
        keep = slice(None, None, self.mri_downsample)
        x, y, z = x[keep], y[keep], z[keep]
        return {
            'x': x/shape_offset,
            'y': y/shape_offset,
            'z': z,
            'colors': image[z, x, y],
        }

    def load_tumor_segmentation(self, image:np.array, orig_dim:int) -> dict:
        """Return one coordinate dict per tumour class (1, 2 and 4)."""
        shape_offset = image.shape[1]/orig_dim
        # Keep every 1st, 3rd and 5th voxel of tumour classes 1 (necrotic
        # core), 2 (edema) and 4 (enhancing tumour) respectively.
        sampling = {1: 1, 2: 3, 4: 5}
        tumors = {}
        for class_idx, stride in sampling.items():
            z, x, y = (image == class_idx).nonzero()
            x, y, z = x[::stride], y[::stride], z[::stride]
            tumors[class_idx] = {
                'x': x/shape_offset,
                'y': y/shape_offset,
                'z': z,
                'colors': class_idx/4,
            }
        return tumors

    def collect_patient_data(self, scan:dict) -> tuple:
        """Build the four scatter traces of a scan and count the markers they contain."""
        orig_dim = scan['orig_shape'][0]
        clean_mri = self.load_clean_mri(scan['scan'], orig_dim)
        tumors = self.load_tumor_segmentation(scan['segmentation'], orig_dim)

        tumor_markers = sum(tumors[class_idx]['x'].shape[0] for class_idx in tumors)
        markers_created = clean_mri['x'].shape[0] + tumor_markers

        traces = [
            generate_3d_scatter(**clean_mri, scale=self.mri_colorscale, opacity=0.3, hover='skip', name='Brain MRI'),
            generate_3d_scatter(**tumors[1], opacity=0.90, hover='all', name='Necrotic tumor core'),
            generate_3d_scatter(**tumors[2], opacity=0.05, hover='all', name='Peritumoral invaded tissue'),
            generate_3d_scatter(**tumors[4], opacity=0.30, hover='all', name='GD-enhancing tumor'),
        ]
        return traces, markers_created

    def get_3d_scan(self, patient_idx:int, scan_type:str='flair') -> go.Figure:
        """Load a patient's scan through the reader and return the finished figure."""
        scan = self.reader.load_patient_scan(patient_idx, scan_type)
        data, num_markers = self.collect_patient_data(scan)

        fig = go.Figure(data=data)
        text_font = dict(family="Courier New, monospace", size=14)
        page_margin = dict(l=0, r=0, b=0, t=30)
        fig.update_layout(
            title=f"[Patient id:{patient_idx}] brain MRI scan ({num_markers} points)",
            legend_title="Pixel class (click to enable/disable)",
            font=text_font,
            margin=page_margin,
            legend=dict(itemsizing='constant')
        )
        return fig

# tumour visualization time starts(tv0)
# Start time of the tumour visualization (tv0); it is not read again in this cell.
tv0 = time.time() 
# Slices are resized to 128x128 and normalized; all tumour labels are kept.
reader = ImageReader(config.train_root_dir, img_size=128, normalize=True, single_class=False)
# Only every 25th non-zero MRI voxel is plotted.
viewer = ImageViewer3d(reader, mri_downsample=25)

# Build the figure for patient index 100 (FLAIR) and render it inline.
fig = viewer.get_3d_scan(100, 'flair')
plotly.offline.iplot(fig)

# Preprocessing

Check if the pretrained model exists

In [None]:
def check_exist(checkpoint_dir):
    """
    Return the path of the most recent `last_epoch_model*` checkpoint.

    The epoch number is the text between the last underscore and the first
    following dot of the file name (e.g. `last_epoch_model_12.pth` -> 12).
    Matching files without such a numeric suffix are ignored.

    Parameters:
        checkpoint_dir: directory that holds the checkpoint files.
    Returns:
        Full path of the checkpoint with the highest epoch number, or None
        when the directory contains no numbered checkpoint.
    """
    numbered_checkpoints = []
    for file_name in os.listdir(checkpoint_dir):
        if not file_name.startswith("last_epoch_model"):
            continue
        try:
            epoch = int(file_name.split('_')[-1].split('.')[0])
        except ValueError:
            # No epoch number in the name: it cannot be ranked, so skip it.
            continue
        numbered_checkpoints.append((epoch, file_name))

    if not numbered_checkpoints:
        return None

    _, latest_checkpoint_file = max(numbered_checkpoints)
    return os.path.join(checkpoint_dir, latest_checkpoint_file)




Merge the two CSV files together

In [None]:
survival_info_df = pd.read_csv(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\BraTS2020_TrainingData\MICCAI_BraTS2020_TrainingData\survival_info.csv')
name_mapping_df = pd.read_csv(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\BraTS2020_TrainingData\MICCAI_BraTS2020_TrainingData\name_mapping.csv')

# Give the ID column of the name mapping the same name it has in the
# survival table so the two can be joined on it.
name_mapping_df = name_mapping_df.rename(columns={'BraTS_2020_subject_ID': 'Brats20ID'})

# Right join: every patient of the name mapping is kept, with or without
# survival information.
df = pd.merge(survival_info_df, name_mapping_df, on="Brats20ID", how="right")

df.sample(10)

Generate the path to the Brain Tumour Segmentation Dataset

In [None]:
paths = []
for id_ in df['Brats20ID']:
    # The second-to-last part of the ID names the phase, e.g.
    # 'BraTS20_Training_001' -> 'Training'.
    phase = id_.split("_")[-2]
    if phase != 'Training':
        print(phase)
        phase_root = config.test_root_dir
    else:
        phase_root = config.train_root_dir
    path = os.path.join(phase_root, id_)
    paths.append(path)

df['path'] = paths

# Patient BraTS20_Training_355 is dropped because of unavailable/missing data.
df = df[df['Brats20ID'] != 'BraTS20_Training_355'].reset_index(drop=True)
df.to_csv('tumourCSV.csv')
df

Stratified K-fold splitting was tried as well, but it is not used in the final work because it gave inferior results compared to train_test_split.

In [None]:
# train_data = df.loc[df['Age'].notnull()].reset_index(drop=True)

# train_data["Age_rank"] =  train_data["Age"] // 10 * 10

# skf = StratifiedKFold(
#     n_splits=7, random_state=config.seed, shuffle=True
# )

# # enumeratng all entries for defining the fold number 
# # assigning the fold number in increment order 
# for i, (train_index, val_index) in enumerate(
#         skf.split(train_data, train_data["Age_rank"])
#         ):
#         train_data.loc[val_index, "fold"] = i
# train_df = train_data.loc[train_data['fold'] != 0].reset_index(drop=True)
# val_df = train_data.loc[train_data['fold'] == 0].reset_index(drop=True)

# # selecting the rows where the AGE col. is null --> test_df 
# test_df = df.loc[~df['Age'].notnull()].reset_index(drop=True)
#print("train_df ->", train_df_copy.shape, "val_df ->", val_df.shape, "test_df ->", test_df.shape)

Use train_test_split to obtain the training, validation and testing datasets

In [None]:
# 70/30 shuffled split with a fixed random_state so it can be reproduced;
# the 30 % part is the hold-out set (called test_df at this point).
train_df_copy, test_df = train_test_split(df, test_size=0.30, random_state=10, shuffle=True)
# Renumber both parts from 0; the hold-out copy is stored as test_df_copy.
train_df_copy, test_df_copy = train_df_copy.reset_index(drop=True), test_df.reset_index(drop=True)

print(train_df_copy.shape, test_df_copy.shape)

In [None]:
# # splitting of the data wasn't done for train , test &  validation data 

val_df = test_df_copy.iloc[:len(test_df_copy)*2//3]
test_df = test_df_copy.iloc[len(test_df_copy)*2//3:]

print("train_df ->", train_df_copy.shape, "val_df ->", val_df.shape, "test_df ->", test_df.shape)

Create the new csv files based on the training, testing and validation dataset

In [None]:
# Write each split to its own CSV in the working directory, without the index column.
train_df_copy.to_csv("train_df.csv", index=False)
test_df.to_csv("test_df.csv", index=False)
val_df.to_csv("val_df.csv", index=False)

In [None]:
def get_augmentations(phase):
    """
    Return the transformation pipeline applied to image/mask pairs.

    The pipeline is currently empty, so images and masks pass through
    unchanged, and `phase` is not used. Shape checking between image and
    mask is turned off.
    """
    no_transforms = []
    return Compose(no_transforms, is_check_shapes=False)


def get_dataloader(
    dataset: torch.utils.data.Dataset,
    path_to_csv: str,
    phase: str,
    fold: int = 0,
    batch_size: int = 1,
    num_workers: int = 0
    ):
    """
    Build the DataLoader for one phase of the experiment.

    The CSV is split 70/30 with a fixed random_state; the 30 % hold-out is
    then divided by position into validation (first two thirds) and test
    (last third).

    Parameters:
        dataset: dataset class, called as `dataset(df, phase)`.
        path_to_csv: CSV file listing the patients.
        phase: "train", "valid" or "test"; any other value selects the
            complete, unsplit table.
        fold: not used.
        batch_size: number of samples per batch.
        num_workers: number of loader worker processes.
    Returns:
        DataLoader over the selected split, iterated in order (no shuffling).
        Each item of the dataset is a dict with the patient id, image and mask.
    """
    full_df = pd.read_csv(path_to_csv)

    train_part, holdout_part = train_test_split(full_df, test_size=0.3, random_state=10, shuffle=True)
    train_part = train_part.reset_index(drop=True)
    holdout_part = holdout_part.reset_index(drop=True)

    cut = len(holdout_part)*2//3
    test_part = holdout_part.iloc[cut:].reset_index(drop=True)
    val_part = holdout_part.iloc[:cut].reset_index(drop=True)

    splits = {"train": train_part, "valid": val_part, "test": test_part}
    selected_df = splits.get(phase, full_df)

    return DataLoader(
        dataset(selected_df, phase),
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=False,
    )


class BratsDataset(Dataset):
    """
    Dataset that yields the four MRI modalities of a patient and the
    three-channel (WT, TC, ET) segmentation mask.

    Parameters:
        df: table with at least the columns `Brats20ID` and `path`.
        phase: phase name, forwarded to `get_augmentations`.
        is_resize: crop every volume with `resize` before further processing.
    """

    def __init__(self, df: pd.DataFrame, phase: str="test", is_resize: bool=True):
        self.df = df
        self.phase = phase
        self.augmentations = get_augmentations(phase)
        # File-name suffixes of the modalities, in channel order.
        self.data_types = ['_flair.nii', '_t1.nii', '_t1ce.nii', '_t2.nii']
        self.is_resize = is_resize

    def __len__(self):
        """Number of patients in the table."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Load, crop and normalize one patient; returns a dict with `Id`, `image` and `mask`."""
        # Patient ID stored at row label `idx`.
        id_ = self.df.loc[idx, 'Brats20ID']
        # Folder of the first row that carries this ID.
        root_path = self.df.loc[self.df['Brats20ID'] == id_]['path'].values[0]

        # Load all modalities.
        images = []
        for data_type in self.data_types:
            img = self.load_img(os.path.join(root_path, id_ + data_type))
            if self.is_resize:
                img = self.resize(img)
            images.append(self.normalize(img))

        # Stack to (modality, ...) and reverse the order of the three spatial axes.
        img = np.stack(images)
        img = np.moveaxis(img, (0, 1, 2, 3), (0, 3, 2, 1))

        # The segmentation is loaded for every phase, including "test".
        mask = self.load_img(os.path.join(root_path, id_ + "_seg.nii"))
        if self.is_resize:
            mask = self.resize(mask)
        mask = self.preprocess_mask_labels(mask)

        # The same transformation pipeline is applied to image and mask.
        augmented = self.augmentations(image=img.astype(np.float32),
                                       mask=mask.astype(np.float32))

        return {
            "Id": id_,
            "image": augmented['image'],
            "mask": augmented['mask'],
            }

    def load_img(self, file_path):
        """Read a NIfTI file and return its data as an array."""
        data = nib.load(file_path)
        data = np.asarray(data.dataobj)
        return data

    def normalize(self, data: np.ndarray):
        """
        Min-max normalize `data` to [0, 1].

        A constant volume (max == min) has no range to scale by; it is
        returned as all zeros instead of the NaNs a division by zero gives.
        """
        data_min = np.min(data)
        value_range = np.max(data) - data_min
        if value_range == 0:
            return np.zeros_like(data, dtype=np.float64)
        return (data - data_min) / value_range

    def resize(self, data: np.ndarray):
        """Crop the volume to the fixed window [40:210, 40:210, 20:120]."""
        # Per the original author, this window does not remove the slices
        # that contain the brain tumour.
        data = data[40:210, 40:210, 20:120]
        return data

    def preprocess_mask_labels(self, mask: np.ndarray):
        """
        Convert the label volume (labels 1, 2, 4) into three binary channels.

        Channel 0 is the whole tumour (labels 1, 2, 4), channel 1 the tumour
        core (labels 1, 4) and channel 2 the enhancing tumour (label 4).
        Values other than 1, 2 and 4 are left unchanged. The three spatial
        axes of the result are in reverse order, matching the image.
        """
        # Whole tumour: every label becomes 1.
        mask_WT = mask.copy()
        mask_WT[(mask_WT == 2) | (mask_WT == 4)] = 1

        # Tumour core: label 2 is removed, labels 1 and 4 become 1.
        mask_TC = mask.copy()
        mask_TC[mask_TC == 2] = 0
        mask_TC[mask_TC == 4] = 1

        # Enhancing tumour: labels 1 and 2 are removed, label 4 becomes 1.
        mask_ET = mask.copy()
        mask_ET[(mask_ET == 1) | (mask_ET == 2)] = 0
        mask_ET[mask_ET == 4] = 1

        mask = np.stack([mask_WT, mask_TC, mask_ET])
        mask = np.moveaxis(mask, (0, 1, 2, 3), (0, 3, 2, 1))

        return mask

Load the training dataset

In [None]:
# Training loader; with the default batch_size of 1 its length is the number of training patients.
dataloader = get_dataloader(dataset=BratsDataset, path_to_csv='tumourCSV.csv', phase='train')
len(dataloader)


Load the testing dataset

In [None]:
# Test loader, built from the same CSV with phase='test'.
test_dataloader = get_dataloader(dataset=BratsDataset, path_to_csv='tumourCSV.csv', phase='test')
len(test_dataloader)

Load the validation dataset

In [None]:
# Validation loader; get_dataloader expects the phase name 'valid' for this split.
val_dataloader = get_dataloader(dataset=BratsDataset, path_to_csv='tumourCSV.csv', phase='valid')
len(val_dataloader)


### Load one of the brain MRIs to visualize all the ground truth tumour segmentations

In [None]:

data = next(iter(dataloader))
print('Shape: ', data['Id'], data['image'].shape, data['mask'].shape)

fig, ax = plt.subplots(1, 3, figsize = (20, 20))

idx = 20
# Tensors are indexed as (sample in batch, channel, then the three spatial
# axes). The mask channels are WT, TC and ET; one slice of each is shown.
for channel in range(3):
    ax[channel].imshow(data['mask'][0, channel, idx, :, :], cmap='bone')

# First image channel and first mask channel of the sample, as numpy arrays.
img_tensor = data['image'].squeeze()[0].cpu().detach().numpy()
mask_tensor = data['mask'].squeeze()[0].squeeze().cpu().detach().numpy()

print("Num uniq Image values :", len(np.unique(img_tensor)))
print("Min/Max Image values:", img_tensor.min(), img_tensor.max())
print("Num uniq Mask values:", np.unique(mask_tensor, return_counts=True))

# Tile all slices into one picture each and overlay the mask on the image;
# zero mask pixels are hidden so the image stays visible.
image = montage(img_tensor)
mask = montage(mask_tensor)

fig, ax = plt.subplots(1, 1, figsize = (20, 20))
ax.imshow(image, cmap ='bone')
ax.imshow(np.ma.masked_where(mask == False, mask),
           cmap='cool', alpha=0.6)

### Define the Evaluation Metrics and Loss Function

In [None]:
def dice_coef_metric(probabilities: torch.Tensor,
                     truth: torch.Tensor, 
                     treshold: float = 0.5,
                     eps: float = 1e-9) -> float:
    """
    Calculate the mean Dice score (a.k.a. F1) over a batch.

    Params:
        probabilities: model outputs after the activation function;
            the first dimension indexes the samples of the batch.
        truth: ground-truth tensor with the same shape as `probabilities`.
        treshold: probabilities >= this value count as positive
            (the misspelled name is kept for caller compatibility).
        eps: small additive term applied to the numerator.
    Returns:
        Mean of the per-sample Dice scores as a Python float. A sample whose
        prediction and truth are both empty scores 1.0.
    """
    predictions = (probabilities >= treshold).float()
    assert predictions.shape == truth.shape, "prediction and truth shapes must match"

    scores = []
    for prediction, truth_ in zip(predictions, truth):
        # Convert every reduction to a plain float so the final mean never has
        # to turn torch tensors (possibly on GPU) into a NumPy array.
        truth_sum = float(truth_.sum())
        prediction_sum = float(prediction.sum())
        if truth_sum == 0 and prediction_sum == 0:
            # Nothing to segment and nothing predicted: perfect agreement.
            scores.append(1.0)
            continue
        intersection = 2.0 * float((truth_ * prediction).sum())
        scores.append((intersection + eps) / (truth_sum + prediction_sum))
    return float(np.mean(scores))


def jaccard_coef_metric(probabilities: torch.Tensor,
               truth: torch.Tensor,
               treshold: float = 0.5,
               eps: float = 1e-9) -> float:
    """
    Calculate the mean Jaccard index (a.k.a. IoU) over a batch.

    Params:
        probabilities: model outputs after the activation function;
            the first dimension indexes the samples of the batch.
        truth: ground-truth tensor with the same shape as `probabilities`.
        treshold: probabilities >= this value count as positive
            (the misspelled name is kept for caller compatibility).
        eps: small additive term applied to numerator and denominator.
    Returns:
        Mean of the per-sample Jaccard scores as a Python float. A sample whose
        prediction and truth are both empty scores 1.0.
    """
    predictions = (probabilities >= treshold).float()
    assert predictions.shape == truth.shape, "prediction and truth shapes must match"

    scores = []
    for prediction, truth_ in zip(predictions, truth):
        # Plain floats keep the final mean independent of tensor device/dtype.
        truth_sum = float(truth_.sum())
        prediction_sum = float(prediction.sum())
        if truth_sum == 0 and prediction_sum == 0:
            # Nothing to segment and nothing predicted: perfect agreement.
            scores.append(1.0)
            continue
        intersection = float((prediction * truth_).sum())
        union = (prediction_sum + truth_sum) - intersection + eps
        scores.append((intersection + eps) / union)
    return float(np.mean(scores))
            
  

class Meter:
    """Accumulates per-batch Dice and IoU scores and reports their means."""
    def __init__(self, treshold: float = 0.5):
        # `treshold` (sic) is the probability cut-off passed to both metrics.
        self.threshold: float = treshold
        self.dice_scores: list = []
        self.iou_scores: list = []
        # Kept for a Hausdorff-distance metric; nothing in this class fills it.
        self.haus_scores: list = []
    
    def update(self, logits: torch.Tensor, targets: torch.Tensor):
        """
        Score one batch and store the results.

        Takes raw model outputs (logits) and targets, applies a sigmoid,
        then appends the batch Dice and IoU (computed by dice_coef_metric and
        jaccard_coef_metric) to the internal lists.
        """
        probs = torch.sigmoid(logits)
        self.dice_scores.append(dice_coef_metric(probs, targets, self.threshold))
        self.iou_scores.append(jaccard_coef_metric(probs, targets, self.threshold))
    
    def get_metrics(self) -> tuple:
        """
        Returns: a (dice, iou) tuple holding the mean of the accumulated
        scores. Both values are NaN if update() was never called.
        """
        dice = np.mean(self.dice_scores)
        iou = np.mean(self.iou_scores)
        return dice, iou


class DiceLoss(nn.Module):
    """Soft Dice loss computed over the whole batch at once."""
    def __init__(self, eps: float = 1e-9):
        super().__init__()
        # Small additive term applied to the numerator of the Dice ratio.
        self.eps = eps
        
    def forward(self,
                logits: torch.Tensor,
                targets: torch.Tensor) -> torch.Tensor:
        """
        Params:
            logits: raw model outputs; a sigmoid is applied here.
            targets: ground truth with the same number of elements per sample.
        Returns:
            Scalar tensor `1 - dice`, where the intersection and the sums are
            taken over every element of the batch (not averaged per sample).
        """
        num = targets.size(0)
        # reshape (not view) so non-contiguous inputs, e.g. permuted or
        # transposed tensors, are accepted instead of raising a RuntimeError.
        probability = torch.sigmoid(logits).reshape(num, -1)
        targets = targets.reshape(num, -1)
        assert probability.shape == targets.shape
        
        intersection = 2.0 * (probability * targets).sum()
        union = probability.sum() + targets.sum()
        dice_score = (intersection + self.eps) / union
        return 1.0 - dice_score
        
        
class BCEDiceLoss(nn.Module):
    """Training objective: binary cross-entropy on logits plus Dice loss."""
    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self,
                logits: torch.Tensor,
                targets: torch.Tensor) -> torch.Tensor:
        """
        Params:
            logits: raw model outputs (before any sigmoid).
            targets: ground-truth masks, same shape as `logits`.
        Returns:
            The unweighted sum of the BCE term and the Dice term.
        """
        assert logits.shape == targets.shape
        dice_term = self.dice(logits, targets)
        bce_term = self.bce(logits, targets)
        return bce_term + dice_term
    
# helper functions for testing.  
def dice_coef_metric_per_classes(probabilities: np.ndarray,
                                    truth: np.ndarray,
                                    treshold: float = 0.33,
                                    eps: float = 1e-9,
                                    classes=('WT', 'TC', 'ET')) -> dict:
    """
    Calculate the Dice score of every sample for each class separately.

    Params:
        probabilities: model outputs after the activation function, indexed
            as [sample][class][...].
        truth: model targets with the same shape as `probabilities`.
        treshold: probabilities >= this value count as positive
            (the misspelled name is kept for caller compatibility).
        eps: small additive term applied to the numerator.
        classes: sequence of class names; entry i names channel i. The default
            is an immutable tuple so it cannot be altered between calls.
    Returns:
        dict mapping each class name to a list with one Dice score per sample.
        A channel whose prediction and truth are both empty scores 1.0.
    """
    scores = {key: list() for key in classes}
    num_classes = probabilities.shape[1]
    predictions = (probabilities >= treshold).astype(np.float32)
    assert predictions.shape == truth.shape

    for sample_pred, sample_truth in zip(predictions, truth):
        for class_ in range(num_classes):
            prediction = sample_pred[class_]
            truth_ = sample_truth[class_]
            truth_sum = truth_.sum()
            prediction_sum = prediction.sum()
            if truth_sum == 0 and prediction_sum == 0:
                scores[classes[class_]].append(1.0)
                continue
            intersection = 2.0 * (truth_ * prediction).sum()
            scores[classes[class_]].append(
                (intersection + eps) / (truth_sum + prediction_sum))

    return scores


def jaccard_coef_metric_per_classes(probabilities: np.ndarray,
               truth: np.ndarray,
               treshold: float = 0.33,
               eps: float = 1e-9,
               classes=('WT', 'TC', 'ET')) -> dict:
    """
    Calculate the Jaccard index of every sample for each class separately.

    Params:
        probabilities: model outputs after the activation function, indexed
            as [sample][class][...].
        truth: model targets with the same shape as `probabilities`.
        treshold: probabilities >= this value count as positive
            (the misspelled name is kept for caller compatibility).
        eps: small additive term applied to numerator and denominator.
        classes: sequence of class names; entry i names channel i. The default
            is an immutable tuple so it cannot be altered between calls.
    Returns:
        dict mapping each class name to a list with one Jaccard score per
        sample. A channel whose prediction and truth are both empty scores 1.0.
    """
    scores = {key: list() for key in classes}
    num_classes = probabilities.shape[1]

    # Binarise: a voxel is segmented when its probability reaches the threshold.
    predictions = (probabilities >= treshold).astype(np.float32)
    assert predictions.shape == truth.shape

    for sample_pred, sample_truth in zip(predictions, truth):
        for class_ in range(num_classes):
            prediction = sample_pred[class_]
            truth_ = sample_truth[class_]
            truth_sum = truth_.sum()
            prediction_sum = prediction.sum()
            if truth_sum == 0 and prediction_sum == 0:
                scores[classes[class_]].append(1.0)
                continue
            intersection = (prediction * truth_).sum()
            union = (prediction_sum + truth_sum) - intersection + eps
            scores[classes[class_]].append((intersection + eps) / union)

    return scores



### Create a Trainer Class to train the UNet models

In [None]:
from tqdm import tqdm
import time

class Trainer():
    """
    Runs the train/validation loop for a segmentation network.

    Args:
        net: neural network for mask prediction.
        dataset: dataset class handed to get_dataloader.
        criterion: module computing the objective loss from (logits, targets).
        lr: learning rate for the Adam optimizer.
        accumulation_steps: number of samples to accumulate before an optimizer
            step; it is divided by batch_size to get the number of batches.
        batch_size: data batch size of one forward/backward pass.
        fold: fold index forwarded to get_dataloader.
        num_epochs: upper bound (exclusive) of the epoch counter.
        path_to_csv: path to the csv file forwarded to get_dataloader.
        model_type: directory where checkpoints and train_log.csv are written;
            it is also scanned for "last_epoch_model_<N>" files to resume from.
        display_plot: if True, plot the training history after the last epoch.

    Attributes:
        phases: ["train", "valid"].
        dataloaders: dict with one data loader per phase.
        losses / dice_scores / jaccard_scores / time: dicts holding one list
            per phase with one entry per epoch.
        epoch_value: last epoch number found on disk (int, 0 if none).
    """
    def __init__(self,
                 net: nn.Module,
                 dataset: torch.utils.data.Dataset,
                 criterion: nn.Module,
                 lr: float,
                 accumulation_steps: int,
                 batch_size: int,
                 fold: int,
                 num_epochs: int,
                 path_to_csv: str,
                 model_type: str,
                 display_plot: bool = True       
                 
                ):

        """Initialization."""
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("device:", self.device)
        self.display_plot = display_plot
        self.net = net.to(self.device)
        self.criterion = criterion
        self.optimizer = Adam(self.net.parameters(), lr=lr)
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode="min",
                                           patience=2, verbose=True)
        # Clamp to 1: with batch_size > accumulation_steps the integer division
        # gives 0, which would later divide the loss by zero.
        self.accumulation_steps = max(1, accumulation_steps // batch_size)
        self.phases = ["train", "valid"]
        self.num_epochs = num_epochs
        self.model_type = model_type
        self.epoch_value = self.check_epoch_number(self.model_type) 
        
        self.dataloaders = {
            phase: get_dataloader(
                dataset = dataset,
                path_to_csv = path_to_csv,
                phase = phase,
                fold = fold,
                batch_size = batch_size,
                num_workers = 0
            )
            for phase in self.phases
        }

        self.best_loss = float("inf")
        
        # Per-phase histories, one entry appended per epoch.
        self.losses = {phase: [] for phase in self.phases}
        self.dice_scores = {phase: [] for phase in self.phases}
        self.jaccard_scores = {phase: [] for phase in self.phases}
        self.time = {phase: [] for phase in self.phases}
         
    def _compute_loss_and_outputs(self,
                                  images: torch.Tensor,
                                  targets: torch.Tensor):
        """Move a batch to the device, run the network and return (loss, logits)."""
        images = images.to(self.device)
        targets = targets.to(self.device)
        logits = self.net(images)
        loss = self.criterion(logits, targets)
        return loss, logits
        
    def _do_epoch(self, epoch: int, phase: str):
        """
        Run one pass over the loader of `phase`, record loss, Dice, IoU and
        elapsed seconds in the history dicts, and return the mean loss.
        Weights are only updated when phase == "train".
        """
        start_time = time.time()
        meter = Meter()
        dataloader = self.dataloaders[phase]
        
        total_batches = len(dataloader)
        running_loss = 0.0
        is_training = phase == "train"

        progress_bar = tqdm(dataloader, desc=f"{phase} epoch: {epoch}", unit="batch", dynamic_ncols=True)

        if is_training:
            self.net.train()
        else:
            self.net.eval()

        for itr, data_batch in enumerate(progress_bar):
            images, targets = data_batch['image'], data_batch['mask']
            
            loss, logits = self._compute_loss_and_outputs(images, targets)
            # Scale so the accumulated gradient matches one larger batch.
            loss = loss / self.accumulation_steps
            
            if is_training:
                loss.backward()
                    
                # Step once enough batches were accumulated. Also step on the
                # final batch so leftover gradients are applied now instead of
                # leaking into the first update of the next epoch.
                is_last_batch = (itr + 1) == total_batches
                if (itr + 1) % self.accumulation_steps == 0 or is_last_batch:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                        
            running_loss += loss.item()
            progress_bar.set_postfix({"loss": running_loss / (itr + 1)})
            meter.update(logits.detach().cpu(), targets.detach().cpu())

        # Undo the accumulation scaling to report the true mean batch loss.
        epoch_loss = (running_loss * self.accumulation_steps) / total_batches
        epoch_dice, epoch_iou = meter.get_metrics()

        self.losses[phase].append(epoch_loss)
        self.dice_scores[phase].append(epoch_dice)
        self.jaccard_scores[phase].append(epoch_iou)

        total_time = round(time.time() - start_time, 2)
        self.time[phase].append(total_time)
        return epoch_loss
        
    def run(self, check_path):
        """
        Train from the epoch after the last one found on disk up to
        num_epochs - 1.

        After every epoch the model and the logs are written via
        _save_train_history. Whenever the validation loss improves, older
        "best_model_*" files in `check_path` are deleted and a new
        best_model_<epoch>.pth is written to the `model_type` directory
        (callers in this file pass the same directory for both).
        """
        first_epoch = int(self.epoch_value) + 1
        last_epoch = self.num_epochs - 1
        
        for epoch in range(first_epoch, self.num_epochs):
            self._do_epoch(epoch, "train")
            with torch.no_grad():
                val_loss = self._do_epoch(epoch, "valid")
                print(f"BCEDiceLoss for epoch {epoch} is : " , val_loss ) 
                self.scheduler.step(val_loss)
            # range() stops at num_epochs - 1, so that is the final epoch; the
            # old comparison against num_epochs could never be true.
            if self.display_plot and epoch == last_epoch:
                self._plot_train_history()
                
            if val_loss < self.best_loss:
                print(f"\n{'#'*20}\nSaved new checkpoint\n{'#'*20}\n")
                self.best_loss = val_loss

                # Keep a single best checkpoint: drop the previous ones first.
                for file_name in os.listdir(check_path):
                    if file_name.startswith("best_model_"):
                        os.remove(os.path.join(check_path, file_name))
                torch.save(self.net, f"{self.model_type}/best_model_{epoch}.pth")
            
            self._save_train_history(epoch)
            print()
        # No extra save after the loop: every epoch already persisted the model
        # and logs, and the former argument-less call raised a TypeError.
            
    def _plot_train_history(self):
        """Plot loss, Dice and Jaccard histories of both phases."""
        data = [self.losses, self.dice_scores, self.jaccard_scores]
        colors = ['deepskyblue', "crimson"]
        # The history dicts are keyed by phase name, i.e. 'train' / 'valid'.
        labels = [
            f"""
            train loss {self.losses['train'][-1]}
            val loss {self.losses['valid'][-1]}
            """,
            
            f"""
            train dice score {self.dice_scores['train'][-1]}
            val dice score {self.dice_scores['valid'][-1]} 
            """, 
                  
            f"""
            train jaccard score {self.jaccard_scores['train'][-1]}
            val jaccard score {self.jaccard_scores['valid'][-1]}
            """
        ]
        
        clear_output(True)

        fig, axes = plt.subplots(3, 1, figsize=(8, 10))
        for i, ax in enumerate(axes):
            ax.plot(data[i]['valid'], c=colors[0], label="val")
            ax.plot(data[i]['train'], c=colors[-1], label="train")
            ax.set_title(labels[i])
            ax.legend(loc="upper right")
                
        plt.tight_layout()
        plt.show()
            
    def load_pretrain_model(self,
                             state_path: str):
        """
        Load weights from `state_path` into the current network.

        The file may hold a whole pickled nn.Module (what this class writes)
        or a plain state dict.
        """
        # torch.load unpickles arbitrary objects: only load trusted files.
        # map_location lets a checkpoint saved on GPU be restored on CPU.
        # NOTE(review): recent PyTorch releases default to weights_only=True,
        # which rejects full-model pickles - verify against the installed version.
        pretrain = torch.load(state_path, map_location=self.device)
        if isinstance(pretrain, nn.Module):
            pretrain = pretrain.state_dict()
        self.net.load_state_dict(pretrain)
        print("Pretrain model loaded")

    def check_epoch_number(self, checkpoint_dir):
        """
        Return the highest epoch number N among files named
        "last_epoch_model..._N.<ext>" in `checkpoint_dir`, or 0 if there are
        none. Always returns an int.
        """
        epoch_numbers = [
            int(file_name.split('_')[-1].split('.')[0])
            for file_name in os.listdir(checkpoint_dir)
            if file_name.startswith("last_epoch_model")
        ]
        return max(epoch_numbers, default=0)
        
    def _save_train_history(self, epoch):
        """Write the model of `epoch` and the training logs to files."""
        # os.path.join instead of a literal backslash: "\l" only worked as a
        # separator on Windows; elsewhere the checkpoint was written outside
        # the directory and could not be found again when resuming.
        torch.save(self.net,
                   os.path.join(self.model_type, f"last_epoch_model_{epoch}.pth"))

        histories = [self.losses, self.dice_scores, self.jaccard_scores, self.time]
        suffixes = ["_loss", "_dice", "_jaccard", "_time"]
        # One column per (phase, metric), e.g. "train_loss", "valid_dice".
        columns = {
            phase + suffix: values
            for history, suffix in zip(histories, suffixes)
            for phase, values in history.items()
        }
        pd.DataFrame(columns).to_csv(f"{self.model_type}/train_log.csv", index=False)



# Creating and Defining the UNet models

### UNet

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive (Conv3d 3x3x3 -> GroupNorm -> ReLU) stages."""
    def __init__(self, in_channels, out_channels, num_groups=8):
        super().__init__()
        stages = []
        # The first stage maps in_channels -> out_channels, the second keeps
        # out_channels; both preserve the spatial size (stride 1, padding 1).
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv3d(stage_in, out_channels, kernel_size=3, stride=1, padding=1),
                nn.GroupNorm(num_groups=num_groups, num_channels=out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to `x`."""
        out = self.double_conv(x)
        return out

    
class Down(nn.Module):
    """Encoder step: 2x max-pooling followed by a DoubleConv.

    Note: a later cell of this file defines another class with the same name.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pooling = nn.MaxPool3d(2, 2)
        convs = DoubleConv(in_channels, out_channels)
        self.encoder = nn.Sequential(pooling, convs)

    def forward(self, x):
        """Downsample `x` and run it through the convolution stages."""
        encoded = self.encoder(x)
        return encoded

    
class Up(nn.Module):
    """Decoder step: upsample, pad to the skip size, concatenate, DoubleConv.

    Note: a later cell of this file defines another class with the same name.
    """

    def __init__(self, in_channels, out_channels, trilinear=True):
        super().__init__()

        half = in_channels // 2
        self.up = (
            nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
            if trilinear
            else nn.ConvTranspose3d(half, half, kernel_size=2, stride=2)
        )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """
        x1: features coming from the deeper level (upsampled here).
        x2: skip-connection features that x1 is aligned to and joined with.
        """
        x1 = self.up(x1)

        # F.pad expects (left, right) pairs starting from the last dimension,
        # so walk the spatial dims 4, 3, 2 and split each size gap in two.
        padding = []
        for dim in (4, 3, 2):
            gap = x2.size()[dim] - x1.size()[dim]
            padding += [gap // 2, gap - gap // 2]
        x1 = F.pad(x1, padding)

        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)

    
class Out(nn.Module):
    """Output head: a 1x1x1 convolution mapping features to class channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return the per-voxel outputs for `x` (no activation applied)."""
        logits = self.conv(x)
        return logits


class UNet3d(nn.Module):
    """3D U-Net with four encoder and four decoder levels.

    in_channels: channels of the input volume.
    n_classes: channels of the output.
    n_channels: base width; deeper levels use 2x, 4x and 8x this value.
    """

    def __init__(self, in_channels, n_classes, n_channels):
        super().__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.n_channels = n_channels
        width = n_channels

        # Encoder: channel count doubles per level except at the bottleneck.
        self.conv = DoubleConv(in_channels, width)
        self.enc1 = Down(width, 2 * width)
        self.enc2 = Down(2 * width, 4 * width)
        self.enc3 = Down(4 * width, 8 * width)
        self.enc4 = Down(8 * width, 8 * width)

        # Decoder: each input width is skip channels + upsampled channels.
        self.dec1 = Up(16 * width, 4 * width)
        self.dec2 = Up(8 * width, 2 * width)
        self.dec3 = Up(4 * width, width)
        self.dec4 = Up(2 * width, width)
        self.out = Out(width, n_classes)

    def forward(self, x):
        """Return the segmentation output for the input volume `x`."""
        # Collect the feature map of every level; the last one is the bottleneck.
        features = [self.conv(x)]
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            features.append(encoder(features[-1]))

        # Walk back up, pairing each decoder with the matching skip features.
        mask = features.pop()
        for decoder in (self.dec1, self.dec2, self.dec3, self.dec4):
            mask = decoder(mask, features.pop())

        return self.out(mask)
    


### Train the UNet model

In [None]:
# No hard-coded .to('cuda'): Trainer moves the network to CUDA when it is
# available and falls back to CPU otherwise, so this cell also runs without a GPU.
model = UNet3d(in_channels=4, n_classes=3, n_channels=24)


trainer = Trainer(net=model,
                  dataset=BratsDataset,
                  criterion=BCEDiceLoss(),
                  lr=5e-4,
                  accumulation_steps=4,
                  batch_size=1,
                  fold=0,
                  num_epochs=200,
                  path_to_csv = config.path_to_csv,
                  model_type = config.UNet_checkpoint_dir
                  )


# Resume only when check_exist reports an existing checkpoint (non-None).
pretrained_path = check_exist(config.UNet_checkpoint_dir)
if pretrained_path is not None:
    trainer.load_pretrain_model(pretrained_path)
    
    # Restore the per-epoch histories so new epochs are appended to them.
    train_logs = pd.read_csv(config.train_logs_path)
    for phase in ("train", "valid"):
        trainer.losses[phase] = train_logs.loc[:, f"{phase}_loss"].to_list()
        trainer.dice_scores[phase] = train_logs.loc[:, f"{phase}_dice"].to_list()
        trainer.jaccard_scores[phase] = train_logs.loc[:, f"{phase}_jaccard"].to_list()
        trainer.time[phase] = train_logs.loc[:, f"{phase}_time"].to_list()



trainer.run(config.UNet_checkpoint_dir)

### ResUNet 

In [None]:
class ResBlock(nn.Module):
    """Residual block: (GroupNorm -> ReLU -> Conv3d 3x3x3) twice, added to a
    1x1x1 convolution of the input."""
    def __init__(self, in_channels, out_channels, num_groups=8):
        super().__init__()
        layers = []
        # First stage normalises in_channels, second stage out_channels;
        # both convolutions produce out_channels.
        for stage_in in (in_channels, out_channels):
            layers.extend([
                nn.GroupNorm(num_groups=num_groups, num_channels=stage_in),
                nn.ReLU(inplace=True),
                nn.Conv3d(stage_in, out_channels, kernel_size=3, stride=1, padding=1),
            ])
        self.residual_block = nn.Sequential(*layers)

        # Projection so the shortcut has out_channels and can be added.
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        """Return residual_block(x) + conv(x)."""
        shortcut = self.conv(x)
        residual = self.residual_block(x)
        return residual + shortcut

    
class Down(nn.Module):
    """Encoder step of the residual U-Net: 2x max-pooling, then a ResBlock.

    Note: this definition replaces the earlier class of the same name.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pooling = nn.MaxPool3d(2, 2)
        block = ResBlock(in_channels, out_channels)
        self.encoder = nn.Sequential(pooling, block)

    def forward(self, x):
        """Downsample `x` and run it through the residual block."""
        encoded = self.encoder(x)
        return encoded

    
class Up(nn.Module):
    """Decoder step of the residual U-Net: upsample, pad to the skip size,
    concatenate, then a ResBlock.

    Note: this definition replaces the earlier class of the same name.
    """

    def __init__(self, in_channels, out_channels, trilinear=True):
        super().__init__()

        half = in_channels // 2
        self.up = (
            nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
            if trilinear
            else nn.ConvTranspose3d(half, half, kernel_size=2, stride=2)
        )
        self.conv = ResBlock(in_channels, out_channels)

    def forward(self, x1, x2):
        """
        x1: features coming from the deeper level (upsampled here).
        x2: skip-connection features that x1 is aligned to and joined with.
        """
        x1 = self.up(x1)

        # F.pad expects (left, right) pairs starting from the last dimension,
        # so walk the spatial dims 4, 3, 2 and split each size gap in two.
        padding = []
        for dim in (4, 3, 2):
            gap = x2.size()[dim] - x1.size()[dim]
            padding += [gap // 2, gap - gap // 2]
        x1 = F.pad(x1, padding)

        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)

    
class Out(nn.Module):
    """Output head: a 1x1x1 convolution wrapped in an nn.Sequential.

    Note: this definition replaces the earlier class of the same name; its
    parameters are stored under "conv.0.*" because of the Sequential wrapper.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        head = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        self.conv = nn.Sequential(head)

    def forward(self, x):
        """Return the per-voxel outputs for `x` (no activation applied)."""
        logits = self.conv(x)
        return logits

class FirstLayer(nn.Module):
    """Input block of the residual U-Net.

    Main path: Conv3d -> ReLU -> GroupNorm -> Conv3d (all 3x3x3, same size),
    added to a bias-free 1x1x1 convolution of the input.
    """

    def __init__(self, in_channels, out_channels, num_groups = 8):
        super().__init__()
        first_conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        activation = nn.ReLU(inplace=True)
        norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
        second_conv = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.residual_block = nn.Sequential(first_conv, activation, norm, second_conv)

        # Projection so the shortcut has out_channels and can be added.
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        """Return residual_block(x) + conv(x)."""
        shortcut = self.conv(x)
        residual = self.residual_block(x)
        return residual + shortcut


class ResUNet3d(nn.Module):
    """3D U-Net built from residual blocks, four encoder and decoder levels.

    in_channels: channels of the input volume.
    n_classes: channels of the output.
    n_channels: base width; deeper levels use 2x, 4x and 8x this value.
    """

    def __init__(self, in_channels, n_classes, n_channels):
        super().__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.n_channels = n_channels
        width = n_channels

        # Encoder: channel count doubles per level except at the bottleneck.
        self.conv = FirstLayer(in_channels, width)
        self.enc1 = Down(width, 2 * width)
        self.enc2 = Down(2 * width, 4 * width)
        self.enc3 = Down(4 * width, 8 * width)
        self.enc4 = Down(8 * width, 8 * width)

        # Decoder: each input width is skip channels + upsampled channels.
        self.dec1 = Up(16 * width, 4 * width)
        self.dec2 = Up(8 * width, 2 * width)
        self.dec3 = Up(4 * width, width)
        self.dec4 = Up(2 * width, width)
        self.out = Out(width, n_classes)

    def forward(self, x):
        """Return the segmentation output for the input volume `x`."""
        # Collect the feature map of every level; the last one is the bottleneck.
        features = [self.conv(x)]
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            features.append(encoder(features[-1]))

        # Walk back up, pairing each decoder with the matching skip features.
        mask = features.pop()
        for decoder in (self.dec1, self.dec2, self.dec3, self.dec4):
            mask = decoder(mask, features.pop())

        return self.out(mask)
    

### Train the ResUNet model

In [None]:
# No hard-coded .to('cuda'): Trainer moves the network to CUDA when it is
# available and falls back to CPU otherwise, so this cell also runs without a GPU.
model2 = ResUNet3d(in_channels=4, n_classes=3, n_channels=24)
trainer = Trainer(net=model2,
                  dataset=BratsDataset,
                  criterion=BCEDiceLoss(),
                  lr=5e-4,
                  accumulation_steps=4,
                  batch_size=1,
                  fold=0,
                  num_epochs=200,
                  path_to_csv = config.path_to_csv,
                  model_type = config.ResUNet_checkpoint_dir
                  )

# Resume only when check_exist reports an existing checkpoint (same guard as
# the UNet cell). Testing the configured directory name instead would pass
# None to load_pretrain_model on a first run.
pretrained_path = check_exist(config.ResUNet_checkpoint_dir)
if pretrained_path is not None:
    trainer.load_pretrain_model(pretrained_path)
    
    # Restore the per-epoch histories so new epochs are appended to them.
    train_logs = pd.read_csv(config.ResUNet_train_logs_path)
    for phase in ("train", "valid"):
        trainer.losses[phase] = train_logs.loc[:, f"{phase}_loss"].to_list()
        trainer.dice_scores[phase] = train_logs.loc[:, f"{phase}_dice"].to_list()
        trainer.jaccard_scores[phase] = train_logs.loc[:, f"{phase}_jaccard"].to_list()
        trainer.time[phase] = train_logs.loc[:, f"{phase}_time"].to_list()

trainer.run(config.ResUNet_checkpoint_dir)

### AttUNet

#### Attention Modules

In [None]:

class ChannelAttention(nn.Module):
    """Channel attention: a shared bottleneck MLP scores the globally
    average-pooled and max-pooled features; their sigmoid-gated sum rescales
    each channel of the input.

    ch: number of input channels.
    ratio: reduction factor of the bottleneck (hidden width is ch // ratio).
    """
    def __init__(self, ch, ratio=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)
        self.ratio = ratio
        self.sigmoid = nn.Sigmoid()
        self.channel = ch

        hidden = self.channel // self.ratio
        # 1x1x1 convolutions act as fully connected layers on the pooled vector.
        self.mlp = nn.Sequential(
            nn.Conv3d(self.channel, hidden, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.Conv3d(hidden, self.channel, kernel_size=1, bias=False)
        )

    def forward(self, x):
        """Return `x` multiplied by its per-channel attention weights."""
        avg_score = self.mlp(self.avg_pool(x))
        max_score = self.mlp(self.max_pool(x))
        gate = self.sigmoid(avg_score + max_score)
        return x * gate


class SpatialAttention(nn.Module):
    """Spatial attention: an attention map is computed from the channel-wise
    mean and max of the input and used to re-weight the input features.

    ch: number of channels of the attention map; must equal the channel count
        of the input so the map can be multiplied with it element-wise.
    kernel_size: size of the convolution applied to the [mean, max] pair.
    """
    def __init__(self, ch, kernel_size=7):
        super().__init__()

        self.conv1 = nn.Conv3d(2, ch, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return `x` multiplied by its spatial attention map."""
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        descriptor = torch.cat([avg_out, max_out], dim=1)

        # Keep the input under its own name: the previous code reassigned `x`
        # to the convolution output, so the result was conv * sigmoid(conv)
        # and the incoming features never reached the output.
        # NOTE: models trained with the old behaviour share the same
        # parameters but will produce different outputs with this fix.
        attention = self.sigmoid(self.conv1(descriptor))
        refined_feats = x * attention
        return refined_feats
    
    
class cbam(nn.Module):
    """Applies ChannelAttention and then SpatialAttention to a feature map."""
    def __init__(self, channel):
        super().__init__()
        # Both sub-modules are built for the same channel count.
        self.ca = ChannelAttention(channel)
        self.sa = SpatialAttention(channel)

    def forward(self, x):
        """Return `x` refined first by channel, then by spatial attention."""
        channel_refined = self.ca(x)
        return self.sa(channel_refined)
    
# https://idiotdeveloper.com/attention-unet-in-pytorch/
class attention_gate(nn.Module):
    """Additive attention gate combining a gating signal with a skip tensor.

    Both inputs are projected with a 1x1x1 convolution + batch norm, summed,
    passed through ReLU, and turned into sigmoid weights by a further 1x1x1
    convolution. The weights multiply the *projected* skip tensor.

    NOTE(review): the result is ``weights * Ws(s)`` rather than ``weights * s``;
    trained checkpoints depend on this, so it is preserved.

    Args:
        in_c: channel count of both ``g`` and ``s``.
        out_c: channel count of the projections and of the output.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        def pointwise_bn():
            # 1x1x1 projection followed by batch normalisation.
            return nn.Sequential(
                nn.Conv3d(in_c, out_c, kernel_size=1, padding=0),
                nn.BatchNorm3d(out_c),
            )

        self.Wg = pointwise_bn()
        self.Ws = pointwise_bn()
        self.relu = nn.ReLU(inplace=True)
        self.output = nn.Sequential(
            nn.Conv3d(out_c, out_c, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, g, s):
        """Return the projected skip tensor scaled by the attention weights.

        Args:
            g: gating signal.
            s: skip-connection tensor, same shape as ``g``.
        """
        gate_feats = self.Wg(g)   # projection of the gating signal
        skip_feats = self.Ws(s)   # projection of the skip connection
        weights = self.output(self.relu(gate_feats + skip_feats))
        return weights * skip_feats


#### AttUNet architecture

In [None]:


class AttUp(nn.Module):
    """Decoder stage: upsample, attend to the skip tensor, concatenate, convolve.

    Args:
        in_channels: channel count after concatenation; each of the two
            concatenated tensors is expected to carry ``in_channels // 2``.
        out_channels: channel count produced by the final ``DoubleConv``.
        trilinear: use parameter-free trilinear upsampling when true,
            otherwise a stride-2 transposed convolution.
    """

    def __init__(self, in_channels, out_channels, trilinear=True):
        super().__init__()
        half = in_channels // 2

        if trilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose3d(half, half, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

        self.num_channels = half
        self.cbam_module = cbam(self.num_channels)
        self.attention_gate = attention_gate(self.num_channels, self.num_channels)

    def forward(self, x1, x2):
        """Fuse the coarse tensor ``x1`` with the skip tensor ``x2``.

        ``x1`` is upsampled and zero-padded to the spatial size of ``x2``;
        ``x2`` is refined by CBAM and gated by ``x1`` before concatenation.
        """
        x1 = self.up(x1)

        # F.pad expects (left, right) pairs starting from the LAST dimension.
        padding = []
        for dim in (4, 3, 2):
            gap = x2.size(dim) - x1.size(dim)
            padding += [gap // 2, gap - gap // 2]
        x1 = F.pad(x1, padding)

        skip = self.attention_gate(x1, self.cbam_module(x2))
        return self.conv(torch.cat([skip, x1], dim=1))
    
class Down(nn.Module):
    """Encoder stage: halve the spatial size, then apply ``DoubleConv``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool3d(kernel_size=2, stride=2)
        convs = DoubleConv(in_channels, out_channels)
        self.encoder = nn.Sequential(pool, convs)

    def forward(self, x):
        """Run 3-D max pooling followed by the double convolution."""
        return self.encoder(x)
    
class AttUNet3d(nn.Module):
    """3-D U-Net whose decoder stages are attention-augmented (``AttUp``).

    Args:
        in_channels: channels of the input volume.
        n_classes: channels of the output produced by ``Out``.
        n_channels: base width; encoder stages use 1x, 2x, 4x, 8x, 8x of it.
    """

    def __init__(self, in_channels, n_classes, n_channels):
        super().__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.n_channels = n_channels

        base = n_channels

        # Encoder: widths grow 1x -> 2x -> 4x -> 8x, the bridge stays at 8x.
        self.conv = DoubleConv(in_channels, base)
        self.enc1 = Down(base, 2 * base)
        self.enc2 = Down(2 * base, 4 * base)
        self.enc3 = Down(4 * base, 8 * base)
        self.enc4 = Down(8 * base, 8 * base)

        # Decoder: the first argument is the width after skip concatenation.
        self.dec1 = AttUp(16 * base, 4 * base)
        self.dec2 = AttUp(8 * base, 2 * base)
        self.dec3 = AttUp(4 * base, base)
        self.dec4 = AttUp(2 * base, base)
        self.out = Out(base, n_classes)

    def forward(self, x):
        """Return the output of ``Out`` for the input volume ``x``."""
        # Collect skip tensors from shallowest to deepest.
        skips = [self.conv(x)]
        for encoder in (self.enc1, self.enc2, self.enc3):
            skips.append(encoder(skips[-1]))

        # Bridge
        mask = self.enc4(skips[-1])

        # Decoder consumes the skips deepest-first.
        for decoder in (self.dec1, self.dec2, self.dec3, self.dec4):
            mask = decoder(mask, skips.pop())

        return self.out(mask)
    


#### Train the AttUNet model

In [None]:
# Build the attention U-Net (4 input channels, 3 output channels, base width 24)
# and move it to the GPU. The device is hard-coded to 'cuda'.
model3 = AttUNet3d(in_channels=4, n_classes=3, n_channels=24).to('cuda')

# NOTE(review): a checkpoint directory is passed as `model_type`; confirm this is
# what Trainer expects.
trainer = Trainer(net=model3,
                  dataset=BratsDataset,
                  criterion=BCEDiceLoss(),
                  lr=5e-4,
                  accumulation_steps=4,
                  batch_size=1,
                  fold=0,
                  num_epochs=100,
                  path_to_csv = config.path_to_csv,
                  model_type = config.Att_checkpoint_dir
                  )

# Resume from an earlier run when a checkpoint location is configured:
# load the weights, then restore the per-epoch histories from the CSV log so
# the trainer's curves continue from the previous run.
if config.Att_checkpoint_dir is not None:
    trainer.load_pretrain_model(check_exist(config.Att_checkpoint_dir))
    
    # if need - load the logs.      
    train_logs = pd.read_csv(config.AttUNet_train_logs_path)
    trainer.losses["train"] =  train_logs.loc[:, "train_loss"].to_list()
    trainer.losses["valid"] =  train_logs.loc[:, "valid_loss"].to_list()
    trainer.dice_scores["train"] = train_logs.loc[:, "train_dice"].to_list()
    trainer.dice_scores["valid"] = train_logs.loc[:, "valid_dice"].to_list()
    trainer.jaccard_scores["train"] = train_logs.loc[:, "train_jaccard"].to_list()
    trainer.jaccard_scores["valid"] = train_logs.loc[:, "valid_jaccard"].to_list()
    trainer.time["train"] = train_logs.loc[:, "train_time"].to_list()
    trainer.time["valid"] = train_logs.loc[:, "valid_time"].to_list()

# Start (or continue) training.
trainer.run(config.Att_checkpoint_dir)
    

In [None]:
# If you get a "tensors on two different devices" error, instantiate every sub-module in __init__ rather than inside forward(), so that .to(device) also moves its parameters.

# Training, Testing and Validating Evaluation

Load the models and put them into evaluation mode

In [None]:
# Load the three trained networks from disk. The `.eval()` calls in the next
# cells imply these files hold whole model objects rather than state dicts, so
# the model classes must already be defined when this cell runs.
# NOTE(review): no map_location is given; loading on a CPU-only machine may
# need one. torch.load unpickles the file, so only load trusted checkpoints.
UNet = torch.load(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\UNet model\best_model_79.pth')
ResUNet = torch.load(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\ResUNet model\best_model_93.pth')
AttUNet = torch.load(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\AttUNet model\best_model_74.pth')

In [None]:

# Put UNet into evaluation mode; as the cell's last expression the returned object is echoed.
UNet.eval()

In [None]:
# Put ResUNet into evaluation mode; as the cell's last expression the returned object is echoed.
ResUNet.eval()

In [None]:
# Put AttUNet into evaluation mode; as the cell's last expression the returned object is echoed.
AttUNet.eval()

Load in the Validation dataset

In [None]:
# Build the dataloader for the "valid" phase and show how many batches it yields.
val_dataloader = get_dataloader(dataset=BratsDataset, path_to_csv='tumourCSV.csv', phase="valid")
len(val_dataloader)

Obtain the TP, FP, TN and FN scores for the pixel classification evaluation metrics

In [None]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report

gc.collect() 
def compute_metrics(model, dataloader, threshold=0.33):
    """Count pixel-level TP / FP / TN / FN over every batch of ``dataloader``.

    Each batch is a mapping with ``'image'`` and ``'mask'`` tensors. The model
    output is passed through a sigmoid and binarised at ``threshold``; a
    target element counts as positive when it equals 1 and as negative when
    it equals 0 (any other value is ignored).

    Args:
        model: callable mapping an image batch to logits.
        dataloader: iterable of batches as described above.
        threshold: probability cut-off for a positive prediction.

    Returns:
        Tuple ``(true_positives, false_positives, true_negatives,
        false_negatives)`` — note the order.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    true_positives = false_positives = true_negatives = false_negatives = 0

    with torch.no_grad():  # inference only
        for batch in dataloader:
            images = batch['image'].to(device)
            targets = batch['mask'].to(device)

            predicted_pos = torch.sigmoid(model(images)) >= threshold
            predicted_neg = ~predicted_pos
            actual_pos = targets == 1
            actual_neg = targets == 0

            true_positives += (predicted_pos & actual_pos).sum().item()
            false_positives += (predicted_pos & actual_neg).sum().item()
            true_negatives += (predicted_neg & actual_neg).sum().item()
            false_negatives += (predicted_neg & actual_pos).sum().item()

            # Release per-batch tensors before the next (large) volume is loaded.
            del images, targets, predicted_pos, predicted_neg, actual_pos, actual_neg
            torch.cuda.empty_cache()

    return true_positives, false_positives, true_negatives, false_negatives



In [None]:

def plot_confusion_matrix(ax, tp, fp, tn, fn, title):
    """Draw a 2x2 confusion matrix on ``ax``.

    Rows are the true class and columns the predicted class, positive class
    first, so the grid reads ``[[TP, FN], [FP, TN]]`` and agrees with the
    axis captions.

    Args:
        ax: matplotlib Axes to draw on; the colorbar is attached to the
            figure that owns this Axes.
        tp, fp, tn, fn: integer counts.
        title: title shown above the matrix.
    """
    # Row = actual class, column = predicted class.
    matrix = np.array([[tp, fn], [fp, tn]])

    labels = ['True ', 'False ']

    image = ax.matshow(matrix, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set_title(title)

    # Use the Axes' own figure instead of relying on a global `fig`.
    ax.figure.colorbar(image, ax=ax)

    # Switch the text colour on dark cells so the numbers stay readable.
    thresh = matrix.max() / 2.
    for i, j in np.ndindex(matrix.shape):
        ax.text(j, i, format(matrix[i, j], 'd'),
                horizontalalignment='center',
                color='white' if matrix[i, j] > thresh else 'black')

    tick_marks = np.arange(len(labels))
    ax.set_xticks(tick_marks)
    ax.set_yticks(tick_marks)
    ax.set_xticklabels(labels, rotation=45)
    ax.set_yticklabels(labels)

    ax.set_xlabel('Predicted label')
    ax.set_ylabel('True label')



In [None]:
def metric(tp, tn, fp, fn):
    """Derive accuracy, precision, recall and F1 from confusion counts.

    Note the parameter order ``(tp, tn, fp, fn)``. Any ratio whose
    denominator is zero (e.g. precision when nothing was predicted
    positive) is reported as ``0.0`` instead of raising
    ``ZeroDivisionError``.

    Returns:
        Tuple ``(accuracy, precision, recall, f1_score)``.
    """
    def ratio(numerator, denominator):
        # Guard against empty denominators.
        return numerator / denominator if denominator else 0.0

    accuracy = ratio(tp + tn, tp + tn + fp + fn)
    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    f1_score = ratio(2 * (precision * recall), precision + recall)
    return accuracy, precision, recall, f1_score

### Validation Dataset

In [None]:
# Pixel-level confusion counts on the validation set for each model.
# These calls need the (model, dataloader, threshold) version of compute_metrics;
# a later cell redefines that name with an extra `num_entries` parameter.
tp , fp , tn , fn  = compute_metrics(UNet, val_dataloader, threshold=0.33)
resTP, resFP, resTN, resFN = compute_metrics(ResUNet, val_dataloader, threshold=0.33)
attTP, attFP, attTN, attFN = compute_metrics(AttUNet, val_dataloader, threshold=0.33)

In [None]:
# Create a figure with one panel per model, laid out in a single row
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Plot the validation confusion matrices (UNet, ResUNet, AttUNet)
plot_confusion_matrix(axes[0], tp, fp, tn, fn, "Confusion Matrix UNet (Validation)")
plot_confusion_matrix(axes[1], resTP, resFP, resTN, resFN, "Confusion Matrix ResUNet (Validation)")
plot_confusion_matrix(axes[2], attTP, attFP, attTN, attFN, "Confusion Matrix AttUNet (Validation)")

# Adjust layout to prevent overlap
plt.tight_layout()
plt.show()

In [None]:
# metric() takes (tp, tn, fp, fn) — a different order from the
# (tp, fp, tn, fn) tuple returned by compute_metrics.
UNetMetric = metric(tp, tn, fp, fn)
ResUNetMetric = metric(resTP, resTN, resFP, resFN)
AttUNetMetric = metric(attTP, attTN, attFP, attFN)

# One column per model; the second positional argument of DataFrame is the
# row index, matching the order of the tuple returned by metric().
dictionary = {'UNet':UNetMetric,
        'ResUNet':ResUNetMetric,
        'AttUNet':AttUNetMetric
        }
df = pd.DataFrame(dictionary,['Validation Accuracy', 'Validation Precision', 'Validation Recall', 'Validation F1 Score'])
df


In [None]:
# Grouped bar chart of the validation table: one group per metric, one bar per model.
df.plot(kind='bar', figsize=(10, 6))
plt.title('Model Performance Metrics (Validation)')
plt.xlabel('Metrics')
plt.ylabel('Scores')
plt.xticks(rotation=0)
plt.legend(title='Model')
plt.tight_layout()
plt.show()

### Testing Dataset

In [None]:
# Pixel-level confusion counts on the test set for each model (overwrites the
# validation counts held in the same variables).
# NOTE(review): `test_dataloader` is only created in a cell further down
# ("Calculate Scores per Class"); that cell must have been run first.
tp , fp , tn , fn  = compute_metrics(UNet, test_dataloader, threshold=0.33)
resTP, resFP, resTN, resFN = compute_metrics(ResUNet, test_dataloader, threshold=0.33)
attTP, attFP, attTN, attFN = compute_metrics(AttUNet, test_dataloader, threshold=0.33)

In [None]:
# Create a figure with one panel per model, laid out in a single row
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Plot the test-set confusion matrices (UNet, ResUNet, AttUNet)
plot_confusion_matrix(axes[0], tp, fp, tn, fn, "Confusion Matrix UNet (Testing)")
plot_confusion_matrix(axes[1], resTP, resFP, resTN, resFN, "Confusion Matrix ResUNet (Testing)")
plot_confusion_matrix(axes[2], attTP, attFP, attTN, attFN, "Confusion Matrix AttUNet (Testing)")

# Adjust layout to prevent overlap
plt.tight_layout()
plt.show()

In [None]:
# metric() takes (tp, tn, fp, fn) — a different order from the
# (tp, fp, tn, fn) tuple returned by compute_metrics.
UNetMetric = metric(tp, tn, fp, fn)
ResUNetMetric = metric(resTP, resTN, resFP, resFN)
AttUNetMetric = metric(attTP, attTN, attFP, attFN)

# One column per model; rows follow the order of the tuple returned by metric().
dictionary = {'UNet':UNetMetric,
        'ResUNet':ResUNetMetric,
        'AttUNet':AttUNetMetric
        }
df1 = pd.DataFrame(dictionary,['Testing Accuracy', 'Testing Precision', 'Testing Recall', 'Testing F1 Score'])
df1

In [None]:
# Grouped bar chart of the test table: one group per metric, one bar per model.
df1.plot(kind='bar', figsize=(10, 6))
plt.title('Model Performance Metrics (Testing)')
plt.xlabel('Metrics')
plt.ylabel('Scores')
plt.xticks(rotation=0)
plt.legend(title='Model')
plt.tight_layout()
plt.show()

### Calculate Scores per Class

In [None]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report

gc.collect()

def compute_metrics(model, dataloader, num_entries, threshold=0.33):
    """Build a confusion matrix and classification report for the first batches.

    The first ``num_entries`` batches of ``dataloader`` are run through
    ``model``; sigmoid outputs are binarised at ``threshold``. Predictions and
    targets of exactly those batches are flattened to 1-D label vectors, so
    both vectors always have the same length.

    NOTE: this redefines the earlier three-argument ``compute_metrics``.

    Args:
        model: callable mapping an image batch to logits.
        dataloader: iterable of mappings with ``'image'`` and ``'mask'``.
        num_entries: maximum number of batches to evaluate.
        threshold: probability cut-off for a positive prediction.

    Returns:
        Dict with keys ``'Confusion Matrix'`` (2x2 array, labels ``[0, 1]``)
        and ``'Classification Report'`` (text).

    Raises:
        ValueError: if no batch was evaluated.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    true_chunks = []
    pred_chunks = []

    with torch.no_grad():
        for counter, data in enumerate(dataloader):
            if counter >= num_entries:
                break  # Stop once the requested number of batches is reached

            images = data['image'].to(device)
            probabilities = torch.sigmoid(model(images))
            prediction = probabilities >= threshold

            # Keep targets and predictions of the SAME batch side by side.
            # uint8 keeps the flattened volumes small.
            # Assumes masks hold only 0/1 values — TODO confirm.
            pred_chunks.append(prediction.cpu().numpy().astype(np.uint8).ravel())
            true_chunks.append(data['mask'].cpu().numpy().astype(np.uint8).ravel())

            del images, probabilities, prediction
            torch.cuda.empty_cache()

    if not true_chunks:
        raise ValueError("compute_metrics: no batches were evaluated")

    y_true = np.concatenate(true_chunks)
    y_pred = np.concatenate(pred_chunks)

    # Fix the label set so the matrix is 2x2 even if one class is absent.
    labels = [0, 1]
    class_names = ['Background', 'Tumor']

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    report = classification_report(
        y_true, y_pred, labels=labels, target_names=class_names, zero_division=0
    )

    evaluation_results = {
        'Confusion Matrix': cm,
        'Classification Report': report
    }

    return evaluation_results

In [None]:
# (Re)build the validation dataloader and show its number of batches.
val_dataloader = get_dataloader(BratsDataset, 'tumourCSV.csv', phase='valid')
len(val_dataloader)

In [None]:
# Build the test dataloader (used by the "Testing" cells) and show its number of batches.
test_dataloader = get_dataloader(BratsDataset, 'tumourCSV.csv', phase='test')
len(test_dataloader)

# Training Phase Evaluation

#### Obtaining Time taken to train and validate per epoch during the training phase

In [None]:
# Per-epoch training logs written during training, one CSV per model.
base_train_data = pd.read_csv(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\UNet model\train_log.csv')
res_train_data = pd.read_csv(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\ResUNet model\train_log.csv')
att_train_data = pd.read_csv(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\AttUNet model\train_log.csv')

def gettime(train_data, model):
    """Summarise the mean per-epoch train/validation time of one model.

    Reads the ``train_time`` and ``valid_time`` columns of ``train_data`` and
    returns a one-row DataFrame. The column headers embed the lengths of the
    module-level ``dataloader`` and ``val_dataloader`` objects, which must
    exist when this is called.

    Args:
        train_data: DataFrame holding the training log.
        model: label stored in the ``Model`` column.
    """
    mean_train = train_data['train_time'].mean()
    mean_valid = train_data['valid_time'].mean()

    train_header = f'Training Time ({len(dataloader)} instances)'
    valid_header = f'Validation Time ({len(val_dataloader)} instances) '

    return pd.DataFrame({
        'Model': [model],
        train_header: [f"{mean_train: .2f} s"],
        valid_header: [f"{mean_valid: .2f} s"],
    })


# Stack the one-row summaries of the three models into a single table and display it.
df_time = pd.concat([gettime(base_train_data, "UNet"),gettime(res_train_data, "ResUNet"),gettime(att_train_data, "AttUNet")])

df_time.head()

#### Graphical Representations of Training vs Validation

In [None]:

def plotScoresindi(metric, model=None):
    """Plot training vs validation curves of ``metric``, one figure per model.

    For every model name the training log CSV is read from the hard-coded
    thesis directory and the ``train_<metric>`` / ``valid_<metric>`` columns
    are plotted against the epoch index.

    Args:
        metric: log column suffix, e.g. ``'loss'``, ``'dice'`` or ``'jaccard'``.
        model: iterable of model names; defaults to UNet, ResUNet and AttUNet.
    """
    if model is None:
        model = ['UNet', 'ResUNet', 'AttUNet']

    # Label the y-axis after the metric actually plotted (it used to say
    # "Dice Score" for loss and jaccard too).
    y_label = 'Loss' if metric == 'loss' else f'{metric.capitalize()} Score'

    for i in model:
        print(i)
        train_data = pd.read_csv(rf'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\{i} model\train_log.csv')
        plt.figure(figsize=(12, 6))

        # Training curve
        plt.plot(train_data[f'train_{metric}'], label=f'Training {metric}')

        # Validation curve
        plt.plot(train_data[f'valid_{metric}'], label=f'Validation {metric}')

        plt.title(f'Training and Validation {metric} Scores for {i}')
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True)
        plt.show()



def plotScores(metric, models=None):
    """Plot training and validation curves of ``metric`` for all models together.

    Every model contributes two lines (training and validation) to a single
    figure; the number of logged epochs per model is printed as a side effect.

    Args:
        metric: log column suffix, e.g. ``'loss'``, ``'dice'`` or ``'jaccard'``.
        models: iterable of model names; defaults to UNet, ResUNet and AttUNet.
    """
    if models is None:
        models = ['UNet', 'ResUNet', 'AttUNet']

    plt.figure(figsize=(12, 6))

    for model in models:
        train_data = pd.read_csv(rf'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\{model} model\train_log.csv')
        print(model, ' num epochs:', len(train_data))

        # Training curve
        plt.plot(train_data[f'train_{metric}'], label=f'Training {metric} ({model})')

        # Validation curve
        plt.plot(train_data[f'valid_{metric}'], label=f'Validation {metric} ({model})')

    plt.title(f'Training and Validation {metric} Scores for {models}')
    plt.xlabel('Epoch')
    # Label the y-axis after the metric actually plotted (it used to say
    # "Dice Score" for loss and jaccard too).
    plt.ylabel('Loss' if metric == 'loss' else f'{metric.capitalize()} Score')
    plt.legend()
    plt.grid(True)

    plt.show()



### Plot of BCEDice Loss 

In [None]:
# Loss curves: all models on one figure, then one figure per model.
plotScores('loss')
plotScoresindi('loss')

In [None]:
# Dice curves: all models on one figure, then one figure per model.
plotScores('dice')
plotScoresindi('dice')

In [None]:
# Jaccard curves: all models on one figure, then one figure per model.
plotScores('jaccard')
plotScoresindi('jaccard')

# Calculating Jaccard and Dice Scores per Class

### Validation

In [None]:
def compute_scores_per_classes(model, dataloader, classes):
    """Collect Dice and Jaccard scores per class over a whole dataloader.

    Each batch is a mapping with ``'image'`` and ``'mask'`` tensors. The raw
    model output (logits) and the targets are converted to NumPy and handed
    to ``dice_coef_metric_per_classes`` / ``jaccard_coef_metric_per_classes``;
    the values those helpers return are appended per class key. The Hausdorff
    distance is currently not computed.

    Params:
        model: neural net used to make predictions.
        dataloader: iterable of batches as described above.
        classes: list of class keys, e.g. ``['WT', 'TC', 'ET']``.

    Returns:
        Two dictionaries ``(dice, jaccard)`` mapping class key -> list of scores.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dice_scores_per_classes = {name: [] for name in classes}
    iou_scores_per_classes = {name: [] for name in classes}

    with torch.no_grad():
        for batch in dataloader:
            imgs = batch['image'].to(device)
            targets = batch['mask'].to(device)

            logits = model(imgs).detach().cpu().numpy()
            targets = targets.detach().cpu().numpy()

            # Score the raw logits against the target masks, then file the
            # per-class values into the matching accumulator.
            scored = (
                (dice_scores_per_classes, dice_coef_metric_per_classes(logits, targets)),
                (iou_scores_per_classes, jaccard_coef_metric_per_classes(logits, targets)),
            )
            for accumulator, batch_scores in scored:
                for key, values in batch_scores.items():
                    accumulator[key].extend(values)

    return dice_scores_per_classes, iou_scores_per_classes

def compute_scores_per_classes_mean(model, dataloader, classes):
    """Return the mean Dice and mean Jaccard score of every class.

    Thin wrapper around ``compute_scores_per_classes`` that averages each
    per-class score list with ``np.mean``.

    Returns:
        Two dictionaries ``(dice_means, iou_means)`` mapping class key -> mean.
    """
    per_class_dice, per_class_iou = compute_scores_per_classes(model, dataloader, classes)

    dice_means = {}
    for name, scores in per_class_dice.items():
        dice_means[name] = np.mean(scores)

    iou_means = {}
    for name, scores in per_class_iou.items():
        iou_means[name] = np.mean(scores)

    return dice_means, iou_means

In [None]:
def bar(model_metrics, model_name, type, ax=None):
    """Bar chart of the column means of ``model_metrics``.

    Args:
        model_metrics: DataFrame whose columns are the metrics to average.
        model_name: model label used in the title.
        type: dataset label used in the title (e.g. ``'Validation'``).
        ax: Axes to draw on; a new 8x6 figure is created when omitted.
    """
    palette = sns.color_palette(
        ['#35FCFF', '#FF355A', '#96C503', '#C5035B', '#28B463', '#35FFAF'], 6
    )

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 6))

    # Column means drive both the bar heights and the printed annotations.
    means = model_metrics.mean()

    sns.barplot(x=means.index, y=means, palette=palette, ax=ax)
    ax.set_xticklabels(model_metrics.columns, fontsize=14, rotation=15)
    ax.set_title(f"{model_name} Dice and Jaccard Coefficients from {type}", fontsize=20)

    ax.set_xlabel(None, fontsize=16)
    ax.set_ylabel('Dice Score', fontsize=16)

    # Write each mean just above its bar.
    for idx, patch in enumerate(ax.patches):
        text = '{:.3f}'.format(means.values[idx])
        anchor_x = patch.get_x() + patch.get_width() / 2 - 0.15
        anchor_y = patch.get_y() + patch.get_height()
        ax.annotate(text, (anchor_x, anchor_y), fontsize=15, fontweight="bold")

In [None]:
# Per-class Dice/Jaccard of UNet on the validation set.
dice_scores_per_classes, iou_scores_per_classes = compute_scores_per_classes(
    UNet, val_dataloader, ['WT', 'TC', 'ET']
    )
dice_df = pd.DataFrame(dice_scores_per_classes)
dice_df.columns = ['WT dice', 'TC dice', 'ET dice']

iou_df = pd.DataFrame(iou_scores_per_classes)
iou_df.columns = ['WT jaccard', 'TC jaccard', 'ET jaccard']
# Dice and Jaccard columns side by side, one row per collected score.
val_metrics_df = pd.concat([dice_df, iou_df], axis=1, sort=True)



# Display five randomly sampled rows.
base_sample = val_metrics_df.round(6).sample(5)
base_sample

In [None]:
# Per-class Dice/Jaccard of ResUNet on the validation set.
resdice_scores_per_classes, resiou_scores_per_classes = compute_scores_per_classes(
    ResUNet, val_dataloader, ['WT', 'TC', 'ET']
    )

resdice_df = pd.DataFrame(resdice_scores_per_classes)
resdice_df.columns = ['WT dice', 'TC dice', 'ET dice']

resiou_df = pd.DataFrame(resiou_scores_per_classes)
resiou_df.columns = ['WT jaccard', 'TC jaccard', 'ET jaccard']
# Dice and Jaccard columns side by side, one row per collected score.
resval_metrics_df = pd.concat([resdice_df, resiou_df], axis=1, sort=True)


# Display five randomly sampled rows.
res_sample = resval_metrics_df.round(6).sample(5)
res_sample

In [None]:
# Per-class Dice/Jaccard of AttUNet on the validation set.
attdice_scores_per_classes, attiou_scores_per_classes = compute_scores_per_classes(
    AttUNet, val_dataloader, ['WT', 'TC', 'ET']
    )

attdice_df = pd.DataFrame(attdice_scores_per_classes)
attdice_df.columns = ['WT dice', 'TC dice', 'ET dice']

attiou_df = pd.DataFrame(attiou_scores_per_classes)
attiou_df.columns = ['WT jaccard', 'TC jaccard', 'ET jaccard']
# Dice and Jaccard columns side by side, one row per collected score.
attval_metrics_df = pd.concat([attdice_df, attiou_df], axis=1, sort=True)



# Display five randomly sampled rows.
att_sample = attval_metrics_df.round(6).sample(5)
att_sample

In [None]:
# Mean score per class (column-wise mean, rounded to 6 decimals): one row per model.
dice_avg =pd.DataFrame([np.mean(dice_df.values, axis =0).round(6), np.mean(resdice_df.values, axis =0).round(6), np.mean(attdice_df.values, axis =0).round(6)], index = ['UNet', 'ResUNet', 'AttUNet'], columns = dice_df.columns)
jac_avg = pd.DataFrame([np.mean(iou_df.values, axis =0).round(6), np.mean(resiou_df.values, axis =0).round(6), np.mean(attiou_df.values, axis =0).round(6)], index = ['UNet', 'ResUNet', 'AttUNet'], columns = iou_df.columns)

In [None]:
# Two stacked panels, each rendering one summary table as an image
fig, axs = plt.subplots(2, 1, figsize=(5, 3))

# Top panel: mean Dice per class for the three models
axs[0].axis('tight')
axs[0].axis('off')
axs[0].table(cellText=dice_avg.values, colLabels=dice_df.columns, cellLoc='center', loc='center', rowLabels = ['UNet', 'ResUNet', 'AttUNet'])
axs[0].set_title('WT, TC, ET Scores for Validation Dataset')

# Bottom panel: mean Jaccard per class for the three models
axs[1].axis('tight')
axs[1].axis('off')
axs[1].table(cellText=jac_avg.values, colLabels=iou_df.columns, cellLoc='center', loc='center', rowLabels = ['UNet', 'ResUNet', 'AttUNet'])


plt.tight_layout()
plt.show()

In [None]:
# One bar chart per model, stacked vertically (validation set).
fig, axs = plt.subplots(3, 1, figsize=(10, 20))

# Plotting UNet metrics
bar(val_metrics_df, "UNet", type='Validation', ax=axs[0])

# Plotting ResUNet metrics
bar(resval_metrics_df, "ResUNet", type='Validation',ax=axs[1])

# Plotting AttUNet metrics
bar(attval_metrics_df, "AttUNet", type='Validation',ax=axs[2])

plt.tight_layout()
plt.show()

### Testing

In [None]:
# Per-class Dice/Jaccard of UNet on the test set (reuses the variable names of
# the validation cells, so those results are overwritten).
dice_scores_per_classes, iou_scores_per_classes = compute_scores_per_classes(
    UNet, test_dataloader, ['WT', 'TC', 'ET']
    )
dice_df = pd.DataFrame(dice_scores_per_classes)
dice_df.columns = ['WT dice', 'TC dice', 'ET dice']

iou_df = pd.DataFrame(iou_scores_per_classes)
iou_df.columns = ['WT jaccard', 'TC jaccard', 'ET jaccard']
# Concatenate the Dice and Jaccard columns side by side (axis 1)
test_metrics_df = pd.concat([dice_df, iou_df], axis=1, sort=True)                         

# Display the first five rows (the sibling cells show a random sample instead).
base_sample_test = test_metrics_df.round(6).head(5)
base_sample_test


In [None]:
# Per-class Dice/Jaccard of ResUNet on the test set.
resdice_scores_per_classes, resiou_scores_per_classes = compute_scores_per_classes(
    ResUNet, test_dataloader, ['WT', 'TC', 'ET']
    )

resdice_df = pd.DataFrame(resdice_scores_per_classes)
resdice_df.columns = ['WT dice', 'TC dice', 'ET dice']

resiou_df = pd.DataFrame(resiou_scores_per_classes)
resiou_df.columns = ['WT jaccard', 'TC jaccard', 'ET jaccard']
# Concatenate the Dice and Jaccard columns side by side (axis 1)
restest_metrics_df = pd.concat([resdice_df, resiou_df], axis=1, sort=True)

# Display five randomly sampled rows.
res_sample_test = restest_metrics_df.round(6).sample(5)
res_sample_test

In [None]:
# Per-class Dice/Jaccard of AttUNet on the test set.
attdice_scores_per_classes, attiou_scores_per_classes = compute_scores_per_classes(
    AttUNet, test_dataloader, ['WT', 'TC', 'ET']
    )

attdice_df = pd.DataFrame(attdice_scores_per_classes)
attdice_df.columns = ['WT dice', 'TC dice', 'ET dice']

attiou_df = pd.DataFrame(attiou_scores_per_classes)
attiou_df.columns = ['WT jaccard', 'TC jaccard', 'ET jaccard']
# Concatenate the Dice and Jaccard columns side by side (axis 1)
atttest_metrics_df = pd.concat([attdice_df, attiou_df], axis=1, sort=True)



# Display five randomly sampled rows.
att_sample_test = atttest_metrics_df.round(6).sample(5)
att_sample_test

In [None]:

# Mean score per class on the test set (column-wise mean, rounded to 6 decimals): one row per model.
dice_avg =pd.DataFrame([np.mean(dice_df.values, axis =0).round(6), np.mean(resdice_df.values, axis =0).round(6), np.mean(attdice_df.values, axis =0).round(6)], index = ['UNet', 'ResUNet', 'AttUNet'], columns = dice_df.columns)
jac_avg = pd.DataFrame([np.mean(iou_df.values, axis =0).round(6), np.mean(resiou_df.values, axis =0).round(6), np.mean(attiou_df.values, axis =0).round(6)], index = ['UNet', 'ResUNet', 'AttUNet'], columns = iou_df.columns)

In [None]:
# Two stacked panels, each rendering one summary table as an image
fig, axs = plt.subplots(2, 1, figsize=(5, 3))

# Top panel: mean Dice per class for the three models
axs[0].axis('tight')
axs[0].axis('off')
axs[0].table(cellText=dice_avg.values, colLabels=dice_df.columns, cellLoc='center', loc='center', rowLabels = ['UNet', 'ResUNet', 'AttUNet'])
axs[0].set_title('WT, TC, ET Scores for Testing Dataset')

# Bottom panel: mean Jaccard per class for the three models
axs[1].axis('tight')
axs[1].axis('off')
axs[1].table(cellText=jac_avg.values, colLabels=iou_df.columns, cellLoc='center', loc='center', rowLabels = ['UNet', 'ResUNet', 'AttUNet'])


plt.tight_layout()
plt.show()

In [None]:
# One bar chart per model, stacked vertically (test set).
fig, axs = plt.subplots(3, 1, figsize=(10, 20))

# Plotting UNet metrics
bar(test_metrics_df, "UNet", type = 'Testing', ax=axs[0])

# Plotting ResUNet metrics
bar(restest_metrics_df, "ResUNet", type = 'Testing', ax=axs[1])

# Plotting AttUNet metrics
bar(atttest_metrics_df, "AttUNet", type = 'Testing', ax=axs[2])

plt.tight_layout()
plt.show()

#### Visualization for Tumours

In [None]:
def compute_results(model,
                    dataloader,
                    treshold=0.33,
                    max_batches=7):
    """Run ``model`` on the first batches of ``dataloader`` and keep everything.

    For every processed batch the ids, the input images, the target masks and
    the binarised predictions (sigmoid output >= ``treshold``) are stored on
    the CPU.

    Args:
        model: callable mapping an image batch to logits.
        dataloader: iterable of mappings with ``'Id'``, ``'image'`` and ``'mask'``.
        treshold: probability cut-off (spelling kept for existing callers).
        max_batches: number of batches to keep. The default of 7 reproduces
            the previous hard-coded limit; ``None`` processes every batch.

    Returns:
        Dict with the lists ``"Id"``, ``"image"``, ``"GT"`` and
        ``"Prediction"``, one entry per processed batch.
    """
    from itertools import islice

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    results = {"Id": [], "image": [], "GT": [], "Prediction": []}

    with torch.no_grad():
        # islice stops without fetching a batch beyond the limit.
        for data in islice(dataloader, max_batches):
            id_, imgs, targets = data['Id'], data['image'], data['mask']

            probs = torch.sigmoid(model(imgs.to(device)))
            predictions = (probs >= treshold).float().cpu()

            results["Id"].append(id_)
            results["image"].append(imgs.cpu())
            # Targets are only stored, so they never need to visit the device.
            results["GT"].append(targets.cpu())
            results["Prediction"].append(predictions)

    return results

In [None]:


import ipywidgets as widgets
from ipywidgets import interact, fixed


def _show_masked_overlay(mask_slice, cmap, alpha):
    """Overlay the non-zero pixels of a 2-D mask on the current axes."""
    plt.imshow(np.ma.masked_where(mask_slice == False, mask_slice), cmap=cmap, alpha=alpha)


def _label_orientation(ax):
    """Attach the orientation captions used on every slice panel."""
    ax.set_ylabel('Left Side of Brain')
    ax.set_xlabel("Bottom of Brain")


def tumour_graphics(n_slice, img, gt, prediction):
    """Show ground truth and prediction overlays for one slice of a volume.

    Two figures are produced:

    1. A projection of image channel 2 with a marker line for the current
       slice, followed by the WT / TC / ET ground-truth masks and the
       WT / TC / ET predictions drawn over image channel 0.
    2. Combined and per-region overlays, ground truth in the top row and
       prediction in the bottom row.

    Args:
        n_slice: index along the first spatial axis of the slice to display.
        img: 5-D tensor or array indexed as ``[batch, channel, slice, row, col]``;
            only batch element 0 is shown.
        gt: ground-truth masks, same layout, channels 0/1/2 titled WT/TC/ET.
        prediction: predicted masks, same layout as ``gt``.
    """
    print("Image Shape:", img.shape)
    print("GT Shape:", gt.shape)
    print("Prediction Shape:", prediction.shape)

    # Convert to NumPy for visualization
    img_np = img.cpu().numpy() if torch.is_tensor(img) else img
    gt_np = gt.cpu().numpy() if torch.is_tensor(gt) else gt
    prediction_np = prediction.cpu().numpy() if torch.is_tensor(prediction) else prediction

    cmaps = ('summer', 'rainbow', 'Wistia')      # one colour map per region
    background = img_np[0, 0, n_slice, :, :]     # slice shown under the overlays

    # ---- Figure 1: projection + per-region overlays on the image ----------
    # Maximum intensity projection for the side view
    side_view = np.max(img_np[0, 2, :, :, :], axis=2)
    plt.figure(figsize=(10, 7))

    ax = plt.subplot(241)
    plt.title('Side View (SAGITTAL VIEW)')
    plt.imshow(np.rot90(side_view, 2), cmap='bone')
    plt.axhline(y=side_view.shape[0] - n_slice, color='r', linestyle='--')  # Indicate current slice
    plt.axis('off')
    _label_orientation(ax)

    # (first subplot cell, volume, title suffix, overlay alpha)
    rows = ((2, gt_np, 'Ground Truth', 0.6), (6, prediction_np, 'Prediction', 0.5))
    for first_cell, volume, suffix, alpha in rows:
        for channel, (abbr, cmap) in enumerate(zip(('WT', 'TC', 'ET'), cmaps)):
            ax = plt.subplot(2, 4, first_cell + channel)
            plt.title(f'{abbr} {suffix}')
            plt.imshow(background, cmap='bone')
            _show_masked_overlay(volume[0, channel, n_slice], cmap, alpha)
            _label_orientation(ax)

    plt.tight_layout()
    plt.show()

    # ---- Figure 2: combined overlay + isolated regions ---------------------
    plt.figure(figsize=(11, 5))
    region_titles = ("Whole Tumour", "Tumour Core", "Enhancing Tumour")

    for first_cell, volume, combined_title in ((1, gt_np, 'Ground Truth'), (5, prediction_np, "Prediction")):
        # All three regions stacked on top of the image.
        ax = plt.subplot(2, 4, first_cell)
        ax.set_facecolor('black')
        plt.title(combined_title)
        plt.imshow(background, cmap='bone')
        for channel, cmap in enumerate(cmaps):
            _show_masked_overlay(volume[0, channel, n_slice, :, :], cmap, 0.8)
        _label_orientation(ax)

        # Each region on its own, against a black background.
        for channel, (region_title, cmap) in enumerate(zip(region_titles, cmaps)):
            ax = plt.subplot(2, 4, first_cell + 1 + channel)
            ax.set_facecolor('black')
            plt.title(region_title)
            _show_masked_overlay(volume[0, channel, n_slice, :, :], cmap, 0.8)
            _label_orientation(ax)

    plt.tight_layout()
    plt.show()


In [None]:
import plotly.graph_objs as go
import plotly

def _threshold_scatter(volume, threshold, colorscale, opacity, name):
    """Build a ``Scatter3d`` trace from the voxels of ``volume`` above ``threshold``.

    ``volume`` is indexed as a 3D torch tensor; the marker colour of each
    point is the voxel's own value.
    """
    mask = volume > threshold
    coords = mask.nonzero(as_tuple=False)
    return go.Scatter3d(
        # Volume axis 1 goes on the plot's x axis, axis 2 on y and axis 0 on z.
        x=coords[:, 1].numpy(),
        y=coords[:, 2].numpy(),
        z=coords[:, 0].numpy(),
        mode='markers',
        marker=dict(
            size=2,
            color=volume[mask].numpy(),  # Color by the actual value
            colorscale=colorscale,
            opacity=opacity,
        ),
        name=name,
    )


def generate_3d_plotly(img, prediction, text, threshold=0.2, show_image=False):
    """Show an interactive 3D scatter plot of the three tumour channels.

    Args:
        img: 5D tensor; only ``img[0, 0]`` is used, and only when
            ``show_image`` is true.
        prediction: 5D tensor; channels 0, 1 and 2 of the first batch element
            are plotted as 'Whole Tumour', 'Tumour Core' and 'Enhancing Tumour'.
        text: label inserted into the figure title.
        threshold: voxels with a value strictly greater than this are plotted.
        show_image: when true, also plot the thresholded image volume as a
            faint background trace. Off by default, so the image volume is no
            longer thresholded on every call only to be thrown away.

    The figure is displayed with ``plotly.offline.iplot``; nothing is returned.
    """
    yellow_colorscale = [
        [0, 'rgb(255,255,204)'],
        [0.25, 'rgb(255,255,153)'],
        [0.5, 'rgb(255,255,102)'],
        [0.75, 'rgb(255,255,51)'],
        [1, 'rgb(255,255,0)']
    ]

    # (channel index, colorscale, opacity, legend name) for each tumour region.
    channel_styles = (
        (0, 'YlGn', 0.8, 'Whole Tumour'),
        (1, 'Purples', 0.5, 'Tumour Core'),
        (2, yellow_colorscale, 0.4, 'Enhancing Tumour'),
    )

    traces = []
    if show_image:
        traces.append(
            _threshold_scatter(img[0, 0, :, :, :], threshold, 'Viridis', 0.2, 'Tensor 1')
        )
    for channel, colorscale, opacity, name in channel_styles:
        traces.append(
            _threshold_scatter(
                prediction[0, channel, :, :, :], threshold, colorscale, opacity, name
            )
        )

    # Create the layout
    layout = go.Layout(
        title=f'3D Scatter Plot of Brain Tumour ({text})',
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z'
        )
    )

    fig = go.Figure(data=traces, layout=layout)
    # Show the plot
    plotly.offline.iplot(fig)

In [None]:
# Shared state for the visualisation cells below: the plotting helper, the
# index of the case to display, and the number of slices the sliders cover.
show_result = ShowResult()
BRAIN_INDEX, n_slices = 1, 100


### Prediction for UNet


In [None]:
# Loader over the 'test' phase of the cases listed in tumourCSV.csv.
test_dataloader = get_dataloader(
    dataset=BratsDataset,
    path_to_csv='tumourCSV.csv',
    phase='test',
)


In [None]:
# UNet = medcam.inject(UNet, output_dir="UNet_attention_maps", save_maps=True, backend="gcam")
# UNet.medcam_dict['channels'] = 3
# Uncomment the top two lines of code for GradCAM

# Run UNet over the test loader; 0.33 is the third positional argument
# (presumably the binarisation threshold -- confirm against compute_results).
results = compute_results(
    UNet, test_dataloader, 0.33)

print(results['Id'])
# Per-batch copies of the result columns, indexed later with BRAIN_INDEX.
id_list = []
img_list = []
gt_list = []
prediction_list = []


# NOTE(review): the loop variables (id_, img, gt, prediction) stay bound at
# module level after the loop; check later cells for uses of them before
# replacing this loop with list(...) copies.
for id_, img, gt, prediction in zip(results['Id'],
                    results['image'],
                    results['GT'],
                    results['Prediction']
                    ):
    
    id_list.append(id_)
    img_list.append(img)
    gt_list.append(gt)
    prediction_list.append(prediction)
    


In [None]:
# Slice browser for the selected UNet case.
n_slices = 100
interact(
    tumour_graphics,
    n_slice=widgets.IntSlider(min=0, max=n_slices-1, step=1, value=0),
    img=fixed(img_list[BRAIN_INDEX]),
    gt=fixed(gt_list[BRAIN_INDEX]),
    prediction=fixed(prediction_list[BRAIN_INDEX]),
)

In [None]:
# 3D view of the UNet prediction for the selected case.
generate_3d_plotly(
    img=img_list[BRAIN_INDEX],
    prediction=prediction_list[BRAIN_INDEX],
    text="Prediction",
)

In [None]:
# 3D view of the ground-truth mask for the same case (passed in the
# `prediction` slot so it is drawn with the same channel styling).
generate_3d_plotly(
    img=img_list[BRAIN_INDEX],
    prediction=gt_list[BRAIN_INDEX],
    text='GT',
)

In [None]:
# Re-create the plotting helper and plot image, ground truth and UNet
# prediction for the selected case.
show_result = ShowResult()
show_result.plot(img_list[BRAIN_INDEX], gt_list[BRAIN_INDEX], prediction_list[BRAIN_INDEX])

### Prediction for ResUNet

In [None]:
# ResUNet = medcam.inject(ResUNet, output_dir="ResUNet_attention_maps", save_maps=True, backend="gcam")
# ResUNet.medcam_dict['channels'] = 3
# Uncomment the top two lines of code for GradCAM

# Run ResUNet over the same test loader; this rebinds the module-level
# `results` that previously held the UNet output.
results = compute_results(
    ResUNet, test_dataloader, 0.33)

print(results['Id'])

# Per-batch copies of the result columns, indexed later with BRAIN_INDEX.
res_id_list = []
res_img_list = []
res_gt_list = []
res_prediction_list = []
# NOTE(review): the loop variables (resid_, res_img, res_gt, res_prediction)
# stay bound at module level after the loop; check later cells for uses of
# them before replacing this loop with list(...) copies.
for resid_, res_img, res_gt, res_prediction in zip(results['Id'],
                    results['image'],
                    results['GT'],
                    results['Prediction']
                    ):
    
    res_id_list.append(resid_)
    res_img_list.append(res_img)
    res_gt_list.append(res_gt)
    res_prediction_list.append(res_prediction)



In [None]:


# Slice browser for the selected ResUNet case.
n_slices = 100
print()
interact(
    tumour_graphics,
    n_slice=widgets.IntSlider(min=0, max=n_slices-1, step=1, value=0),
    img=fixed(res_img_list[BRAIN_INDEX]),
    gt=fixed(res_gt_list[BRAIN_INDEX]),
    prediction=fixed(res_prediction_list[BRAIN_INDEX]),
)

In [None]:
# 3D view of the ResUNet prediction for the selected case.
generate_3d_plotly(
    img=res_img_list[BRAIN_INDEX],
    prediction=res_prediction_list[BRAIN_INDEX],
    text='Prediction',
)

In [None]:
# 3D view of the ground-truth mask for the same case (passed in the
# `prediction` slot so it is drawn with the same channel styling).
generate_3d_plotly(
    img=res_img_list[BRAIN_INDEX],
    prediction=res_gt_list[BRAIN_INDEX],
    text="GT",
)

In [None]:

# Plot image, ground truth and ResUNet prediction for the selected case.
show_result.plot(res_img_list[BRAIN_INDEX], res_gt_list[BRAIN_INDEX], res_prediction_list[BRAIN_INDEX])

### Prediction for AttUNet

In [None]:
def compute_results_attention(model,
                    dataloader,
                    treshold=0.33,
                    max_batches=7):
    """Run ``model`` over ``dataloader`` and collect predictions and attention maps.

    Args:
        model: callable network that exposes ``dec4.attention_gate.output``;
            element 1 of that attribute is called on the logits to obtain the
            attention map.
            NOTE(review): the shape of that attribute is not visible here --
            confirm against the AttUNet definition.
        dataloader: iterable of dicts with the keys ``'Id'``, ``'image'`` and
            ``'mask'``.
        treshold: probability cut-off applied to ``sigmoid(logits)`` (the
            misspelt name is kept for backward compatibility).
        max_batches: number of batches to process; ``None`` processes the
            whole loader. The default of 7 matches the previous hard-coded
            ``i > 5`` cut-off, which let batches 0..6 through.

    Returns:
        dict with the keys ``'Id'``, ``'image'``, ``'GT'``, ``'Prediction'``
        and ``'Attention'``, each a list with one entry per processed batch.
        All tensors are moved to the CPU.
    """
    from itertools import islice

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    results = {"Id": [],"image": [], "GT": [],"Prediction": [],"Attention": []}

    with torch.no_grad():
        # islice stops after max_batches items without fetching an extra batch.
        for data in islice(dataloader, max_batches):
            id_, imgs, targets = data['Id'], data['image'], data['mask']
            imgs, targets = imgs.to(device), targets.to(device)
            logits = model(imgs)

            # Use the model that was passed in rather than the global AttUNet.
            attention = model.dec4.attention_gate.output[1]
            attention_map = attention(logits)
            probs = torch.sigmoid(logits)

            predictions = (probs >= treshold).float()

            results["Id"].append(id_)
            results["image"].append(imgs.cpu())
            results["GT"].append(targets.cpu())
            results["Prediction"].append(predictions.cpu())
            # Move to the CPU like the other outputs so that the maps do not
            # accumulate on the compute device.
            results["Attention"].append(attention_map.cpu())

            torch.cuda.empty_cache()

    return results
    



In [None]:
# AttUNet = medcam.inject(AttUNet, output_dir="AttUNet_attention_maps", save_maps=True, backend="gcam")
# AttUNet.medcam_dict['channels'] = 3
# Uncomment the top two lines of code for GradCAM

# Run the attention U-Net over the test loader, collecting attention maps too.
results = compute_results_attention(AttUNet, test_dataloader, 0.33)

att_id_list, att_img_list, att_gt_list = [], [], []
att_prediction_list, explain_att = [], []

# Copy each result column into its own list. An explicit loop is kept so the
# loop variables remain bound afterwards, exactly as before.
for attid_, att_img, att_gt, att_prediction, attention in zip(
    results['Id'],
    results['image'],
    results['GT'],
    results['Prediction'],
    results['Attention'],
):
    att_id_list.append(attid_)
    att_img_list.append(att_img)
    att_gt_list.append(att_gt)
    att_prediction_list.append(att_prediction)
    explain_att.append(attention)

print(att_id_list)


In [None]:
# Slice browser for AttUNet; index 6 is the last of the 7 collected batches.
n_slices = 100
BRAIN_INDEX = 6
interact(
    tumour_graphics,
    n_slice=widgets.IntSlider(min=0, max=n_slices-1, step=1, value=0),
    img=fixed(att_img_list[BRAIN_INDEX]),
    gt=fixed(att_gt_list[BRAIN_INDEX]),
    prediction=fixed(att_prediction_list[BRAIN_INDEX]),
)


In [None]:
# 3D view of the AttUNet prediction for the selected case.
generate_3d_plotly(
    img=att_img_list[BRAIN_INDEX],
    prediction=att_prediction_list[BRAIN_INDEX],
    text="Prediction",
)

In [None]:
# 3D view of the ground-truth mask for the same case (passed in the
# `prediction` slot so it is drawn with the same channel styling).
generate_3d_plotly(
    img=att_img_list[BRAIN_INDEX],
    prediction=att_gt_list[BRAIN_INDEX],
    text="GT",
)

In [None]:
# Switch to case 5 and plot image, ground truth and AttUNet prediction.
BRAIN_INDEX = 5
show_result.plot(att_img_list[BRAIN_INDEX], att_gt_list[BRAIN_INDEX], att_prediction_list[BRAIN_INDEX])

In [None]:
%%time
path = r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\GIF/'
title = "Ground Truth_" + id_[0]
filename1 = path + title + "_3d.gif"
print(res_gt.shape)
data_to_3dgif = Image3dToGIF3d(img_dim = (170, 170, 100), binary=False, normalizing=False)
transformed_data = data_to_3dgif.get_transformed_data(res_gt[0,0,:,:,:].numpy())
data_to_3dgif.plot_cube(
    transformed_data,
    title=title,
    make_gif=True,
    path_to_save=filename1
)
#show_gif(filename1, format='png')

title = "Ground Truth_" + id_[0]
filename1 = title + "_3d.gif"



# Explainable Artificial Intelligence

### UNet GradCAM

In [None]:
# NOTE(review): `results` is whatever the most recently executed
# compute_results* cell assigned, which is not necessarily the UNet output.
print(results['Id'])

In [None]:
def GradCAMDisplay(img, gt, prediction, WT_file, TC_file, ET_file ,idx):
    """Draw a 3x4 grid comparing GradCAM maps with ground truth and prediction.

    One row per channel (WT, TC, ET); the columns are the GradCAM overlay,
    the ground-truth overlay, the prediction overlay and the plain image.

    Args:
        img, gt, prediction: 5D arrays/tensors indexed as ``[0, channel, idx, :, :]``.
        WT_file, TC_file, ET_file: paths of the NIfTI GradCAM volumes, which
            are sliced as ``[:, :, idx]``.
        idx: slice index to display.
    """
    def load_volume(nifti_path):
        return np.asanyarray(nib.load(nifti_path).dataobj)

    def draw_panel(axis, title, overlay=None, overlay_cmap=None):
        # Every panel has the grey-scale slice as its background.
        axis.imshow(img[0,0,idx,:,:], cmap ='bone')
        if overlay is not None:
            # Zero-valued overlay pixels are masked so the image shows through.
            axis.imshow(np.ma.masked_where(overlay == False, overlay), cmap=overlay_cmap)
        axis.set_title(title)

    cam_volumes = [load_volume(WT_file), load_volume(TC_file), load_volume(ET_file)]

    _, grid = plt.subplots(3, 4, figsize = (12, 10))

    row_styles = (('WT', 'summer'), ('TC', 'rainbow'), ('ET', 'Wistia'))
    for channel, (tag, region_cmap) in enumerate(row_styles):
        row = grid[channel]
        draw_panel(row[0], f'GradCAM ({tag})', cam_volumes[channel][:,:,idx], 'magma')
        draw_panel(row[1], f'GT ({tag})', gt[0,channel,idx,:,:], region_cmap)
        draw_panel(row[2], f'Prediction ({tag})', prediction[0,channel,idx,:,:], region_cmap)
        draw_panel(row[3], 'Original Image')


In [None]:
# Slice browser comparing the saved UNet GradCAM maps with GT and prediction.
# NOTE(review): the attention-map file names are hard-coded; confirm they
# belong to the case selected by BRAIN_INDEX.
n_slices = 100
BRAIN_INDEX = 5

interact(
    GradCAMDisplay,
    idx=widgets.IntSlider(min=0, max=n_slices-1, step=1, value=0),
    img=fixed(img_list[BRAIN_INDEX]),
    gt=fixed(gt_list[BRAIN_INDEX]),
    prediction=fixed(prediction_list[BRAIN_INDEX]),
    WT_file=fixed(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\UNet_attention_maps\out\attention_map_0_0_0.nii'),
    TC_file=fixed(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\UNet_attention_maps\out\attention_map_0_0_1.nii'),
    ET_file=fixed(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\UNet_attention_maps\out\attention_map_0_0_2.nii'),
)

### ResUNet GradCAM

In [None]:
# Slice browser comparing the saved ResUNet GradCAM maps with GT and prediction.
# NOTE(review): the attention-map file names are hard-coded; confirm they
# belong to the case selected by BRAIN_INDEX.
n_slices = 100
BRAIN_INDEX = 2

interact(
    GradCAMDisplay,
    idx=widgets.IntSlider(min=0, max=n_slices-1, step=1, value=0),
    img=fixed(res_img_list[BRAIN_INDEX]),
    gt=fixed(res_gt_list[BRAIN_INDEX]),
    prediction=fixed(res_prediction_list[BRAIN_INDEX]),
    WT_file=fixed(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\ResUNet_attention_maps\out\attention_map_4_0_0.nii'),
    TC_file=fixed(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\ResUNet_attention_maps\out\attention_map_4_0_1.nii'),
    ET_file=fixed(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\ResUNet_attention_maps\out\attention_map_4_0_2.nii'),
)

## AttUNet

### AttUNet GradCAM

In [None]:
# Slice browser comparing the saved AttUNet GradCAM maps with GT and prediction.
# NOTE(review): the attention-map file names are hard-coded; confirm they
# belong to the case selected by BRAIN_INDEX.
n_slices = 100
BRAIN_INDEX = 2

interact(
    GradCAMDisplay,
    idx=widgets.IntSlider(min=0, max=n_slices-1, step=1, value=0),
    img=fixed(att_img_list[BRAIN_INDEX]),
    gt=fixed(att_gt_list[BRAIN_INDEX]),
    prediction=fixed(att_prediction_list[BRAIN_INDEX]),
    WT_file=fixed(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\AttUNet_attention_maps\out\attention_map_2_0_0.nii'),
    TC_file=fixed(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\AttUNet_attention_maps\out\attention_map_2_0_1.nii'),
    ET_file=fixed(r'C:\Users\ethan\OneDrive\Desktop\Thesis Brain Tumour\AttUNet_attention_maps\out\attention_map_2_0_2.nii'),
)

### Explainable Attention

In [None]:

def tumour_graphics_attention_map(n_slice, img, gt, attention):
    """Plot one slice with its ground truth and the three attention channels.

    The figure is a 2x2 grid: the image with all three ground-truth regions
    overlaid, followed by one panel per attention channel.

    Args:
        n_slice: slice index along axis 2 of the 5D inputs.
        img: 5D image; ``img[0, 0, n_slice]`` is the background of every panel.
        gt: 5D ground-truth mask with three channels.
        attention: 5D torch tensor with three channels; each slice is moved
            to the CPU before plotting.
            NOTE(review): the channel order is assumed to match ``gt``
            (0 = whole tumour, 1 = tumour core, 2 = enhancing tumour), which
            is the order the rest of this notebook uses -- confirm.
    """
    print("Image Shape:", img.shape)
    print("GT Shape:", gt.shape)
    print("Attention Shape:", attention.shape)
    plt.close('all')

    background = img[0,0,n_slice,:,:]

    plt.figure(figsize=(10, 8))
    plt.subplot(221)
    plt.title('Original Image with Ground Truth')
    plt.imshow(background, cmap='bone')
    for channel, region_cmap in enumerate(('summer', 'rainbow', 'Wistia')):
        region = gt[0,channel,n_slice,:,:]
        # Zero-valued pixels are masked so the image shows through.
        plt.imshow(np.ma.masked_where(region == 0, region),
                cmap=region_cmap, alpha=0.8)

    # (subplot position, attention channel, title, masking cut-off).
    # The titles follow the channel convention above; previously the
    # channel-1 and channel-2 panels carried each other's label.
    # NOTE(review): channel 2 uses a lower cut-off (0.01) than the others
    # (0.1); kept as in the original -- confirm this is intentional.
    attention_panels = (
        (222, 0, 'Whole Tumour with Explainable Attention', 0.1),
        (223, 1, 'Tumour Core with Explainable Attention', 0.1),
        (224, 2, 'Enhancing Tumour with Explainable Attention', 0.01),
    )
    for position, channel, title, cutoff in attention_panels:
        plt.subplot(position)
        plt.title(title)
        plt.imshow(background, cmap='bone')
        # Transfer the slice to the CPU once and reuse it for mask and data.
        attention_slice = attention[0,channel, n_slice,:,:].to('cpu')
        plt.imshow(np.ma.masked_where(attention_slice < cutoff, attention_slice),
                cmap='magma', alpha=0.7)

    plt.tight_layout()  # Adjusts spacing between plots to fit them neatly
    plt.show()

# Slice browser for the explainable-attention maps of the selected AttUNet case.
n_slices = 100

interact(
    tumour_graphics_attention_map,
    n_slice=widgets.IntSlider(min=0, max=n_slices-1, step=1, value=0),
    img=fixed(att_img_list[BRAIN_INDEX]),
    gt=fixed(att_gt_list[BRAIN_INDEX]),
    attention=fixed(explain_att[BRAIN_INDEX]),
)
# Tensor layout: (batch size, channels, depth, width, height)

# Explainable attention