# Final BME Project: 3D Brain Segmentation

Natalie McKenzie - ncm2165
Eva Melendrez - emm2355

Welcome to our Final Biomedical Engineering Project! In this Kaggle notebook, we will walk you through the process of building a model that segments the different tissues of the brain in a 3D T1-weighted MRI scan. First, let's import our libraries and our raw data!

In [None]:
#Libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.utils.data import TensorDataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score
import warnings

import nibabel as nib
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras import layers, models
import os
import tensorflow as tf
from scipy.ndimage import zoom
import skimage.transform as skTrans
#1: import packages
import sys, os
import glob
import math
import cv2
import scipy.interpolate as scInterp
from torch.optim import Adam, SGD

# Ask PyTorch once whether a CUDA device is usable and report the answer
cuda = torch.cuda.is_available()
print("GPU available:", cuda)

In [None]:
# Channel order used for every mask array: background, cerebrospinal fluid, grey matter, white matter
masks = ["background", "csf", "gm", "wm"]

def plot_image(mask_array, image_array, idx, axis, save_root_folder=None):
    """Draw one image slice followed by one panel per entry of ``masks``.

    Args:
        mask_array: indexable by channel; ``mask_array[i]`` is drawn for ``masks[i]``.
        image_array: indexable; only ``image_array[0]`` is drawn.
        idx: identifier shown in the panel titles and used in the saved file name.
        axis: axis label shown in the panel titles and used in the saved file name.
        save_root_folder: when truthy, the figure is also saved there as a PNG.
    """
    n_panels = 5  # one image panel + four mask panels

    plt.figure(figsize=(22, 22))

    # Panel 1: the image itself
    plt.subplot(1, n_panels, 1)
    plt.title(f"{idx} image on Axis {axis}")
    plt.imshow(image_array[0], cmap="gray")

    # Panels 2..5: one per mask channel, in the order given by `masks`
    for channel, tissue in enumerate(masks):
        plt.subplot(1, n_panels, channel + 2)
        plt.title(f"{idx} {tissue} on Axis {axis}")
        plt.imshow(mask_array[channel], cmap="gray")
        plt.axis("off")

    plt.tight_layout()

    if save_root_folder:
        out_name = f"Axis_{axis}_{idx}_Images_And_Masks.png"
        plt.savefig(os.path.join(save_root_folder, out_name))

    plt.show()


In [None]:

def niftyToNumpy(root_folder, folder, output_folder):
    """Convert every NIfTI subject of one split into 2D ``.npy`` slices.

    For each ``<sbj>_img.nii`` found in ``root_folder/folder/image`` the image and
    its three probability masks (csf, grey matter, white matter) from
    ``root_folder/folder/mask`` are loaded, a background channel is derived, both
    volumes are resized to 112 x 112 x 112, and every 7th slice along the first
    spatial axis is saved to ``output_folder/folder`` as
    ``<sbj>_slc<idx>_img.npy`` (shape 1 x 112 x 112) and
    ``<sbj>_slc<idx>_mask.npy`` (shape 4 x 112 x 112).

    Args:
        root_folder: root of the original dataset.
        folder: split name (sub-folder of ``root_folder``), also used as the
            output sub-folder name.
        output_folder: root under which the converted slices are written.
    """
    folder_path = os.path.join(root_folder, folder)
    img_folder_path = os.path.join(folder_path, 'image')
    mask_folder_path = os.path.join(folder_path, 'mask')
    output_folder_path = os.path.join(output_folder, folder)

    os.makedirs(output_folder_path, exist_ok=True)

    # Sorted so the processing order does not depend on the file system's listing order
    sbj_list = sorted(
        name.split("_img.nii")[0]
        for name in os.listdir(img_folder_path)
        if name.endswith(".nii")
    )

    total_images = 0

    # enumerate gives the progress counter directly (no list.index lookup per subject)
    for position, sbj in enumerate(sbj_list, start=1):
        print(f"Currently processing {sbj} in {folder}, {position} of {len(sbj_list)}", end="\r")

        img = nib.load(os.path.join(img_folder_path, sbj + "_img.nii")).get_fdata()

        # Min-max normalise the intensities to [0, 1]; a constant volume has no
        # range to scale, so it becomes all zeros instead of dividing by zero
        lowest, highest = img.min(), img.max()
        if highest > lowest:
            img = (img - lowest) / (highest - lowest)
        else:
            img = np.zeros_like(img)

        mask_csf = nib.load(os.path.join(mask_folder_path, sbj + "_probmask_csf.nii")).get_fdata()
        mask_gm = nib.load(os.path.join(mask_folder_path, sbj + "_probmask_graymatter.nii")).get_fdata()
        mask_wm = nib.load(os.path.join(mask_folder_path, sbj + "_probmask_whitematter.nii")).get_fdata()

        # Create combined mask array: channel 0 background, 1 csf, 2 gm, 3 wm
        mask = np.zeros((4,) + img.shape)
        mask[1] = mask_csf
        mask[2] = mask_gm
        mask[3] = mask_wm
        # Background is wherever all three tissue probabilities are exactly zero
        mask[0] = np.logical_and(mask[1] == 0, np.logical_and(mask[2] == 0, mask[3] == 0)).astype(np.float32)

        # Resize image and mask (order=1: linear interpolation, original value range kept)
        img = skTrans.resize(np.expand_dims(img, axis=0), (1, 112, 112, 112), order=1, preserve_range=True)
        mask = skTrans.resize(mask, (4, 112, 112, 112), order=1, preserve_range=True)

        # Take every 7th slice along the first spatial axis, starting at index 1
        for idx in range(1, img.shape[1], 7):  # Adjust as necessary
            img_slc = img[:, idx, :, :]
            mask_slc = mask[:, idx, :, :]

            total_images += 1

            # If the current subject is "sald_031318" in the "test" folder, visualize the image and masks
            if sbj == "sald_031318" and folder == "test":
                plot_image(mask_slc, img_slc, idx, axis="Y", save_root_folder=None)

            # Save the image and mask slices
            np.save(os.path.join(output_folder_path, f"{sbj}_slc{idx:03}_img.npy"), img_slc)
            np.save(os.path.join(output_folder_path, f"{sbj}_slc{idx:03}_mask.npy"), mask_slc)

    print(f"\nTotal images processed: {total_images}")

# Input: the original NIfTI dataset. Output: where the converted .npy slices go (any writable directory works)
og_folder = '/kaggle/input/3dbraintissuesegmentation'
np_folder = './processed_data'

# The output root must exist before the per-split folders are created inside it
os.makedirs(np_folder, exist_ok=True)

# Convert the three dataset splits one after another
folders = ['test', 'train', 'valid']
for curr_folder in folders:
    niftyToNumpy(og_folder, curr_folder, np_folder)

In [None]:
class BasicDataset(torch.utils.data.Dataset):
    """Dataset of pre-computed ``.npy`` image/mask slice pairs for one split.

    Files are expected as ``<id>_img.npy`` / ``<id>_mask.npy``. Two layouts are
    accepted: ``folder/split/image`` + ``folder/split/mask`` sub-folders, or,
    when those sub-folders do not exist, both kinds of file directly inside
    ``folder/split`` (the layout ``niftyToNumpy`` writes).

    Args:
        folder: root folder holding one sub-folder per split.
        split: split name (sub-folder of ``folder``).
        percent_samples: percentage (0-100) of the available samples to expose;
            the first ``percent_samples`` % of the sorted ids are used.
    """

    IMAGE_SUFFIX = "_img.npy"
    MASK_SUFFIX = "_mask.npy"

    def __init__(self, folder, split, percent_samples=100):
        self.folder = folder
        self.split = split
        self.ids = self.load_ids(folder, split)
        # Optionally limit the number of samples if needed
        self.num_samples = int(len(self.ids) * percent_samples / 100)

    def _kind_dir(self, folder, split, kind):
        """Return ``folder/split/kind`` if that directory exists, else ``folder/split``."""
        split_dir = os.path.join(folder, split)
        nested_dir = os.path.join(split_dir, kind)
        return nested_dir if os.path.isdir(nested_dir) else split_dir

    def load_ids(self, folder, split):
        """Return the sorted sample ids (file names without ``_img.npy``) of a split."""
        image_dir = self._kind_dir(folder, split, "image")
        suffix_len = len(self.IMAGE_SUFFIX)
        return sorted(
            name[:-suffix_len]
            for name in os.listdir(image_dir)
            if name.endswith(self.IMAGE_SUFFIX)
        )

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Index relative to the exposed subset, so negative indices and plain
        # iteration never reach samples beyond num_samples
        if idx < 0:
            idx += self.num_samples
        if not 0 <= idx < self.num_samples:
            raise IndexError(f"index {idx} out of range for {self.num_samples} samples")

        sample_id = self.ids[idx]
        image_file = os.path.join(
            self._kind_dir(self.folder, self.split, "image"), sample_id + self.IMAGE_SUFFIX
        )
        mask_file = os.path.join(
            self._kind_dir(self.folder, self.split, "mask"), sample_id + self.MASK_SUFFIX
        )

        # float32 matches the default dtype of the network weights
        image = np.load(image_file).astype(np.float32)
        mask = np.load(mask_file).astype(np.float32)

        # 'sbj_id' lets callers label predictions with the slice they came from
        return {'image': image, 'mask': mask, 'sbj_id': sample_id}


Above, you can see what the MRI scans look like after we convert the 3D images into 2D arrays. We used images from the test set so you can clearly see what we expect from our model's performance. It should be able to determine where the background, cerebrospinal fluid, grey matter, and white matter are in an MRI scan, as shown in the separate panels above.

We included a counter so we can check that the number of images coming out of the transformation is the number we expect. Next, we're going to implement the U-Net structure, which upsamples our data with bilinear interpolation followed by a 1x1 convolution, to segment it.

In [None]:
import torch.nn as nn
#Conv 3x3, ReLU
class ReLUConv(nn.Module):
    """(3x3 2d convolution, batch norm, relu) x 2

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels of both convolutions.
        padding: zero padding of each 3x3 convolution. The default of 1 keeps
            the spatial size unchanged, which the un-cropped skip connections
            (torch.cat in UpConv) need in order to line up; 0 gives unpadded
            convolutions that shrink each side by 2 pixels per convolution.
    """
    def __init__(self, ch_in, ch_out, padding = 1):
        super().__init__()
        self.double_conv = nn.Sequential(
            #First 3x3 convolution: changes the channel count from ch_in to ch_out
            nn.Conv2d(ch_in, ch_out, kernel_size = 3, padding = padding),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace = True),
            #Second 3x3 convolution: keeps ch_out channels
            nn.Conv2d(ch_out, ch_out, kernel_size = 3, padding = padding),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace = True)
        )
    def forward(self, x):
        return self.double_conv(x)

#Copy and Crop: TO STUDY LATER

#Max pooling/Down-Conv 2x2
class DownConv(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a ReLUConv block"""
    def __init__(self, ch_in, ch_out):
        super().__init__()
        #MaxPool2d(2) reduces the spatial resolution; ReLUConv maps ch_in to ch_out channels
        pool = nn.MaxPool2d(2)
        convs = ReLUConv(ch_in, ch_out)
        self.down_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        pooled_then_convolved = self.down_conv(x)
        return pooled_then_convolved

#Up-Conv 2x2
class UpConv(nn.Module):
    """upsample, 1x1 conv, concatenate with earlier layer, ReLUConv

    Args:
        ch_in: channels of the feature map coming from the deeper layer.
        ch_out: channels after the 1x1 convolution and of the stage output; the
            skip connection is expected to carry ch_out channels as well, so
            the concatenation has ch_out * 2 channels.
    """
    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.up_conv = nn.Sequential(
            #Double the spatial size, then reduce channels from ch_in to ch_out
            nn.Upsample(scale_factor = 2, mode = "bilinear", align_corners = True),
            nn.Conv2d(ch_in, ch_out, kernel_size = 1)
        )
        #ReLUConv is the double-convolution block defined in this file
        self.conv = ReLUConv(ch_out * 2, ch_out)
    def forward(self, x1, x2):
        """x1: deeper feature map to upsample; x2: skip connection to concatenate with."""
        x1 = self.up_conv(x1)
        #Concatenate along the channel dimension; spatial sizes of x1 and x2 must match
        x = torch.cat([x1, x2], dim = 1)
        x = self.conv(x)
        return x

#Conv 1x1
class OutConv(nn.Module):
    """Output head: 1x1 convolution to ch_out channels, then softmax over the channel dimension"""
    def __init__(self, ch_in, ch_out):
        super().__init__()
        to_classes = nn.Conv2d(ch_in, ch_out, kernel_size = 1)
        #dim = 1 is the channel dimension, so the ch_out values at each position sum to 1
        to_probabilities = nn.Softmax(dim = 1)
        self.conv_final = nn.Sequential(to_classes, to_probabilities)

    def forward(self, x):
        probabilities = self.conv_final(x)
        return probabilities
   

In [None]:
class ThisCNN(nn.Module):
    """U-Net style network: an input block, four down stages, four up stages, an output head.

    Args:
        name: model name, stored on the instance as ``self.name``.
        n_channels: number of channels of the input image.
        n_classes: number of output channels produced by the head.
    """
    def __init__(self, name, n_channels, n_classes):
        super().__init__()
        self.name = name
        self.n_channels = n_channels
        self.n_classes = n_classes

        #Channel widths of the five resolution levels (structure copied from class slides)
        w0, w1, w2, w3, w4 = 64, 128, 256, 512, 1024

        #Encoder
        self.inputL = ReLUConv(n_channels, w0)
        self.down1 = DownConv(w0, w1)
        self.down2 = DownConv(w1, w2)
        self.down3 = DownConv(w2, w3)
        self.down4 = DownConv(w3, w4)

        #Decoder
        self.up1 = UpConv(w4, w3)
        self.up2 = UpConv(w3, w2)
        self.up3 = UpConv(w2, w1)
        self.up4 = UpConv(w1, w0)

        #Head
        self.outputL = OutConv(w0, n_classes)

    def forward(self, x):
        #Encoder: keep every level's output for the skip connections
        skip0 = self.inputL(x)
        skip1 = self.down1(skip0)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)
        bottom = self.down4(skip3)

        #Decoder: each up stage is fed the previous result plus the matching encoder output
        decoded = self.up1(bottom, skip3)
        decoded = self.up2(decoded, skip2)
        decoded = self.up3(decoded, skip1)
        decoded = self.up4(decoded, skip0)

        return self.outputL(decoded)

In [None]:
def train_net(net, epochs, train_dataloader, valid_dataloader, optimizer, loss_function):
    """Train ``net``, validating and saving a checkpoint after every epoch.

    Args:
        net: model to train; must expose ``net.name``, which is used as the
            checkpoint folder name.
        epochs: number of epochs to run.
        train_dataloader: yields batches as dicts with 'image' and 'mask' tensors.
        valid_dataloader: same batch format, used for validation only.
        optimizer: optimizer already bound to ``net``'s parameters.
        loss_function: called as ``loss_function(criterion, imgs, pred, masks, mode)``
            and expected to return ``(loss, isolated_images, stacked_brain_map)``.

    Returns:
        ``(train_loss, train_pearson, valid_loss, valid_pearson)``: four lists
        with one per-epoch average each.

    Note:
        ``criterion`` is not a parameter; it is read from module scope, and
        ``pearson_coeff`` is called from module scope as well.
    """

    #create folder to save model in (if not already existing)
    os.makedirs(f'{net.name}', exist_ok=True)

    #move batches to wherever the model lives instead of assuming CUDA
    first_param = next(net.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    n_train = len(train_dataloader)
    n_valid = len(valid_dataloader)

    train_loss = []
    train_pearson = []

    valid_loss = []
    valid_pearson = []

    #TRAINING
    for epoch in range(epochs):
        net.train()

        train_batch_loss = []
        train_batch_pearson = []

        for i, batch in enumerate(train_dataloader):

            #load image and gt mask for the batch and move them to the model's device
            imgs = batch['image'].to(device)
            masks = batch['mask'].to(device)

            #model prediction w/ current weights
            pred = net(imgs)

            #compute batch loss and append to overall training loss
            loss, isolated_images, stacked_brain_map = loss_function(criterion, imgs, pred, masks, 'train')
            batch_loss = loss.item() #plain python number instead of a tensor
            train_batch_loss.append(batch_loss)

            #compute batch's pearson coeff and append to overall
            batch_pearson = pearson_coeff(isolated_images, masks, stacked_brain_map)
            train_batch_pearson.append(batch_pearson)

            optimizer.zero_grad() #reset gradient values
            loss.backward() #backpropagate the loss
            optimizer.step() #update weights

            #update progress (bc the end = '\r' rewrites over the last!)
            print(f'EPOCH {epoch + 1}/{epochs} - Training Batch {i+1}/{n_train} - Loss: {batch_loss}, Pearson Coefficient: {batch_pearson}', end='\r')

        #calculate averages and append to overall
        average_training_loss = np.array(train_batch_loss).mean()
        average_training_pearson = np.array(train_batch_pearson).mean()

        train_loss.append(average_training_loss)
        train_pearson.append(average_training_pearson)

        #VALIDATION
        net.eval()
        valid_batch_loss = list()
        valid_batch_pearson = list()

        #don't calculate gradient since you're not trying to train here, just evaluate
        with torch.no_grad():
            for i, batch in enumerate(valid_dataloader):

                #load image and mask for batch and move them to the model's device
                imgs = batch['image'].to(device)
                masks = batch['mask'].to(device)

                #model prediction w/ current weights
                pred = net(imgs)

                #compute batch loss and append to overall validation loss
                loss, isolated_images, stacked_brain_map = loss_function(criterion, imgs, pred, masks, 'val')
                batch_loss = loss.item()
                valid_batch_loss.append(batch_loss)

                #compute batch's pearson coeff and append to overall validation pearson
                batch_pearson = pearson_coeff(isolated_images, masks, stacked_brain_map)
                valid_batch_pearson.append(batch_pearson)

                #update progress (bc the end = '\r' rewrites over the last!)
                print(f'EPOCH {epoch + 1}/{epochs} - Validation Batch {i+1}/{n_valid} - Loss: {batch_loss}, Pearson Coefficient: {batch_pearson}', end='\r')

        average_validation_loss = np.array(valid_batch_loss).mean()
        average_validation_pearson = np.array(valid_batch_pearson).mean()
        valid_loss.append(average_validation_loss)
        valid_pearson.append(average_validation_pearson)

        #print final evaluation of epoch
        print(f'EPOCH {epoch + 1}/{epochs} - Training Loss: {average_training_loss}, Training Pearson score: {average_training_pearson}, Validation Loss: {average_validation_loss}, Validation Pearson Coefficient: {average_validation_pearson}')

        #save model
        #note: since there's only 50 epochs you can change to {:02}
        torch.save(net.state_dict(), f'{net.name}/epoch_{epoch+1:03}.pth')

    return train_loss, train_pearson, valid_loss, valid_pearson

In [None]:
def loss_function(criterion, images, model_output, gt_masks, mode):
    """
    calculate loss for the three tissue type channels only over the brain area
    compute loss for background channel over the whole area

    criterion: loss callable applied as criterion(masked_prediction, gt_masks); it should
               use reduction = "sum" because the result is divided by the mask size here
    images: input batch with a singleton dimension 1 (removed here); values > 0 count as brain
    model_output: network prediction with 4 channels along dimension 1
    gt_masks: ground-truth masks, same shape as model_output
    mode: "val" reseeds the torch and numpy RNGs with 1102; other values have no effect

    returns (loss, isolated_images, stacked_brain_map)
    """

    #NOTE(review): this reseeds the global RNGs on every validation batch, which also
    #affects anything random that runs afterwards - confirm this is intended
    if mode == "val":
        torch.manual_seed(1102)
        np.random.seed(1102)

    #remove extra dimension
    images_squeezed = torch.squeeze(images, dim = 1)

    #turn this into binary brain map (values are 1 where brain, 0 where background)
    brain_map = (images_squeezed > 0).float()

    #stack the brain maps along a new dimension 1 so there is one map per output channel
    #computing loss for background channel over the whole area, so that will just be filled with ones
    #ones_like keeps the tensor on the same device as the input (no hard-coded .cuda())
    ones = torch.ones_like(brain_map)
    stacked_brain_map = torch.stack([ones, brain_map, brain_map, brain_map], dim = 1)

    #only consider the values inside the brain (zero out others using 0s in bg of stacked brain map)
    isolated_images = torch.mul(stacked_brain_map, model_output)

    #see how the probabilities at each (prediction vs ground truth) compare using whatever criterion you prefer
    #divide the summed loss by the number of ones in the stacked map to get an average
    loss = criterion(isolated_images, gt_masks) #make sure reduction = "sum" so you can do average over entire brain area
    num_brain_voxels = stacked_brain_map.sum()
    loss = loss / num_brain_voxels

    #return the isolated_images and the stacked_brain_map too because they're helpful for other functions
    return loss, isolated_images, stacked_brain_map
    
def pearson_coeff(isolated_image, target, stacked_brain_map):
    """
    pearson correlation coefficient between prediction and ground truth,
    using only the positions where stacked_brain_map is non-zero
    """
    #positions to keep (discarded entirely elsewhere, not just zeroed)
    keep = torch.flatten(stacked_brain_map) != 0

    #flatten, select the kept positions, then hand over to numpy
    predicted = torch.flatten(isolated_image)[keep].detach().cpu().numpy()
    reference = torch.flatten(target)[keep].detach().cpu().numpy()

    #entry [0][1] of the 2x2 correlation matrix is the prediction/ground-truth correlation
    correlation_matrix = np.corrcoef(predicted, reference)
    return correlation_matrix[0][1]

In [None]:
def dice_coeff_CM(isolated_image, target, stacked_brain_map):
    """
    calculate dice coefficient to evaluate model in testing stage and also return confusion matrix
    convenient to do them together since they need the same inputs and are only needed for the testing stage
    background is not included in the dice calculation because there is so much of it that it would inflate the score

    isolated_image: masked model output, class channels along dimension 1
    target: ground-truth masks, same layout
    stacked_brain_map: map whose channel 1 is non-zero inside the brain

    returns (avg_dice, CM): mean dice over classes 1..3 computed inside the brain, and the
    4x4 confusion matrix (rows = ground truth, columns = prediction) over all positions
    """

    #create maps with values of 0, 1, 2, 3, depending on which channel has the max value
    #reminder: 0 is background here, 1 csf, 2 gm, 3 wm
    full_map_model = torch.argmax(isolated_image, 1)
    full_map_gt = torch.argmax(target, 1)
    #the channel dimension is gone: shape of these is batch_size x height x width

    #confusion matrix over every position; fixing the labels keeps it 4x4 even when a
    #class does not occur in this batch, so per-batch matrices can be averaged later
    CM_full_map_model = torch.flatten(full_map_model).detach().cpu().numpy()
    CM_full_map_gt = torch.flatten(full_map_gt).detach().cpu().numpy()
    CM = confusion_matrix(CM_full_map_gt, CM_full_map_model, labels = [0, 1, 2, 3])

    #dice only uses positions inside the brain; select them once for all classes
    in_brain = torch.flatten(stacked_brain_map[:,1,:,:]) != 0
    gt_labels = torch.flatten(full_map_gt)[in_brain].detach().cpu().numpy()
    model_labels = torch.flatten(full_map_model)[in_brain].detach().cpu().numpy()

    eps = 0.00001 #so there's no dividing by zero
    dice_scores = []

    #calculate one dice score for each tissue type
    for i in range (1, 4):
        #1 where the position has this tissue type, 0 otherwise (background or other tissue)
        gt = (gt_labels == i).astype(np.float32)
        model_output = (model_labels == i).astype(np.float32)

        #dice = 2 * overlap / (predicted count + ground-truth count)
        dice = (np.sum(model_output[gt == 1]) * 2.0 + eps) / (np.sum(model_output) + np.sum(gt) + eps)
        dice_scores.append(dice)

    #calculate and return avg dice scores and confusion matrix info
    avg_dice = sum(dice_scores)/len(dice_scores)
    return avg_dice, CM

#TESTING
def test_net(net, test_dataloader, loss_function):
    """Evaluate ``net`` on the test set and collect its predictions.

    Args:
        net: trained model.
        test_dataloader: yields batches as dicts with 'image', 'mask' and
            'sbj_id' entries.
        loss_function: called as ``loss_function(criterion, imgs, pred, masks, mode)``
            and expected to return ``(loss, isolated_images, stacked_brain_map)``.

    Returns:
        ``(all_imgs, all_labels, all_preds, test_loss, test_dice, test_pearson, test_CM)``:
        per-batch image arrays, the flat list of sample ids, per-batch prediction
        arrays, and the batch-averaged loss, dice score, pearson coefficient and
        confusion matrix.

    Note:
        ``criterion`` is read from module scope, as are ``pearson_coeff`` and
        ``dice_coeff_CM``.
    """
    os.makedirs('/kaggle/working/pred_mask', exist_ok=True)
    net.eval()

    #move batches to wherever the model lives instead of assuming CUDA
    first_param = next(net.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    n_test = len(test_dataloader)
    test_batch_loss = []
    test_batch_pearson = []
    test_batch_dice = []
    test_batch_CM = []
    all_imgs = []
    all_labels = list()
    all_preds = list()

    #no need to update gradient
    with torch.no_grad():
        for i, batch in enumerate(test_dataloader):

            #load image and mask for batch and move them to the model's device
            imgs = batch['image'].to(device)
            masks = batch['mask'].to(device)

            #model prediction w/ "best" weights
            pred = net(imgs)

            #compute batch loss and append to overall test loss
            loss, isolated_images, stacked_brain_map = loss_function(criterion, imgs, pred, masks, 'test')
            batch_loss = loss.item()
            test_batch_loss.append(batch_loss)

            #compute batch's pearson coeff and append to overall test pearson
            batch_pearson = pearson_coeff(isolated_images, masks, stacked_brain_map)
            test_batch_pearson.append(batch_pearson)

            #compute batch's dice score and confusion matrix and append to overall lists
            batch_dice, batch_CM = dice_coeff_CM(isolated_images, masks, stacked_brain_map)
            test_batch_dice.append(batch_dice)
            test_batch_CM.append(batch_CM)

            #keep predictions/images per batch as numpy arrays; ids go into one flat list
            pred = pred.cpu().detach().numpy()
            imgs = imgs.cpu().detach().numpy()
            labels = batch['sbj_id']
            all_labels.extend(labels)
            all_preds.append(pred)
            all_imgs.append(imgs)

            #update progress
            print(f'Test Batch {i+1}/{n_test} - Loss: {batch_loss}, Pearson Corr: {batch_pearson}, DICE score: {batch_dice}', end='\r')

        #find averages of all metrics and return them
        test_loss = np.array(test_batch_loss).mean()
        test_dice = np.array(test_batch_dice).mean()
        test_pearson = np.array(test_batch_pearson).mean()
        #element-wise mean of the per-batch matrices (they must all have the same shape)
        test_CM = np.array(test_batch_CM).mean(axis = 0)

    return all_imgs, all_labels, all_preds, test_loss, test_dice, test_pearson, test_CM

In [None]:
def CM(test_CM, save_root_folder, label):
    """
    plot the confusion matrix as an annotated heatmap, without the background
    row and column (there's too much background), and save it as a PNG

    test_CM: 4x4 confusion matrix values in the order Background, CSF, GM, WM
    save_root_folder: folder the PNG is written to
    label: prefix of the saved file name
    """
    tissue_names = ["Background", "CSF", "GM", "WM"]

    #a labelled dataframe gives the heatmap its tick labels
    labelled = pd.DataFrame(test_CM, index = tissue_names, columns = tissue_names)

    #leave out the background column and row
    df_cm = labelled.drop(columns = "Background").drop(index = "Background")

    #plot
    plt.figure(figsize = (12,10))
    plt.title('Confusion Matrix')
    sns.heatmap(df_cm, annot = True, annot_kws = {"size": 15})
    plt.ylabel('True labels')
    plt.xlabel('predicted labels')

    #save plot
    out_path = os.path.join(save_root_folder, f"{label}_Confusion_Matrix.png")
    plt.savefig(out_path)
    plt.show()

def learning_curve(best_epoch, train_loss, valid_loss, train_pearson, valid_pearson, save_root_folder):
    """
    plot how loss and pearson coefficient developed over the epochs (learning curve)
    and save the figure as Learning_Curve.png in save_root_folder

    best_epoch: 1-based epoch number highlighted with a vertical line
    train_loss, valid_loss, train_pearson, valid_pearson: one value per epoch
    save_root_folder: folder the PNG is written to
    """
    #take the number of epochs from the data itself rather than from a global constant
    n_epochs = len(train_loss)
    epoch_numbers = np.arange(n_epochs) + 1
    epoch_ticks = np.arange(0, n_epochs, 10) + 1

    fig, (ax1, ax2) = plt.subplots(figsize = (15, 8), ncols = 2)
    fig.suptitle("Learning Curve", fontsize = 18)

    #plot losses on first graph
    ax1.set_ylabel("Loss", fontsize = 15)
    ax1.set_xlabel("Epoch", fontsize = 15)
    ax1.set_xticks(epoch_ticks)
    ax1.plot(epoch_numbers, train_loss, '-o', label = "Training")
    ax1.plot(epoch_numbers, valid_loss, '-o', label = "Validation")
    ax1.axvline(best_epoch, color = 'm', lw = 4, alpha = 0.5, label = "Best Epoch") #highlight the best epoch

    #plot pearson coefficients on second graph
    ax2.set_ylabel("Pearson Coeff", fontsize = 15)
    ax2.set_xlabel("Epoch", fontsize = 15)
    ax2.set_xticks(epoch_ticks)
    ax2.plot(epoch_numbers, train_pearson, '-o', label = "Training")
    ax2.plot(epoch_numbers, valid_pearson, '-o', label = "Validation ")
    ax2.axvline(best_epoch, color = 'm', lw = 4, alpha = 0.5, label = "Best Epoch") #highlight the best epoch

    #one legend per graph (a single plt.legend() would only label the last one)
    ax1.legend()
    ax2.legend()
    plt.tight_layout()

    #save plot
    plt.savefig(os.path.join(save_root_folder, "Learning_Curve.png"))
    plt.show()

In [None]:
import shutil  #cell-local import: used below to remove the per-epoch checkpoints

#all figures are written here; create the folder up front so saving cannot fail on it
save_root_folder = "/kaggle/working/figures"
os.makedirs(save_root_folder, exist_ok=True)

final_predictions = list()
final_labels = list()
final_images = list()


torch.manual_seed(1102)
np.random.seed(1102) #Q: should i reset the seeds each time?

#axis label used in the model name, the saved file names and the plot titles
axis = "Y"

#the .npy slices written by niftyToNumpy live under np_folder, one sub-folder per split
root_folder = np_folder

#2: datasets and dataloaders
#NOTE(review): percent is a percentage, so 1 keeps only 1% of each split - confirm this is intended
percent = 1
train_dataset = BasicDataset(root_folder, 'train', percent)
valid_dataset = BasicDataset(root_folder, 'valid', percent)
test_dataset = BasicDataset(root_folder, 'test', percent)

batch_size = 85
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
valid_loader = DataLoader(valid_dataset, batch_size = batch_size)
test_loader = DataLoader(test_dataset, batch_size = batch_size)

print (f"Batches in Train DataLoader: {len(train_loader)}")
print (f"Batches in Validation DataLoader: {len(valid_loader)}")
print (f"Batches in Test DataLoader: {len(test_loader)}")

#display the first five image slices
# for slc in range (5):
#     plot_image(train_dataset[slc]['mask'], train_dataset[slc]['image'], train_dataset[slc]['sbj_id'], axis, save_root_folder)

#3: create instance of u-net architecture (1 input channel, 4 output classes)
model = ThisCNN(f"Model_Axis_{axis}", 1, 4)
model = model.cuda()

#4: define loss function and optimization method
#define optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)

#define loss function (summed here; loss_function divides by the mask size afterwards)
criterion = nn.MSELoss(reduction = "sum")

#5: train the model for at least 50 epochs
#6: validate and save the model at the end of each epoch
EPOCHS = 50
train_loss, train_pearson, valid_loss, valid_pearson = train_net(model, EPOCHS, train_loader, valid_loader, optimizer, loss_function)

#7: choose the best epoch and load the weights at that specific epoch
best_epoch = int(np.argmax(valid_pearson)) + 1 #just add one bc epochs start at 1
print (f"Best Epoch: {best_epoch}")

#keep only the best checkpoint: copy it out, then delete the per-epoch folder
state_dict = torch.load(f'./{model.name}/epoch_{best_epoch:03}.pth')
torch.save(state_dict, f"/kaggle/working/Axis_{axis}_model_best_epoch.pth")
shutil.rmtree(f'./{model.name}')

#load weights to model
model.load_state_dict(state_dict)
model.cuda()

#8: test the model on the test data
test_imgs, test_labels, test_predictions, test_loss, test_dice, test_pearson, test_CM = test_net(model, test_loader, loss_function)
final_predictions.append(test_predictions)
final_labels.append(test_labels)
final_images.append(test_imgs)
print(f'Test Loss: {test_loss}, Test DICE score: {test_dice}, Test Pearson Correlation: {test_pearson}')

#show two example predictions from the first test batch (if it has that many samples)
#test_labels is one flat list, so sample slc of the first batch is test_labels[slc]
if test_imgs:
    for slc in range (3, min(5, len(test_imgs[0]))):
        plot_image(test_predictions[0][slc], test_imgs[0][slc], test_labels[slc], axis)

#9: generate and save all the figures you might need for the presentation (dataset distribution, learning curve, confusion matrix, etc.)
learning_curve(best_epoch, train_loss, valid_loss, train_pearson, valid_pearson, save_root_folder)
CM(test_CM, save_root_folder, f"Axis {axis}")

model = model.cpu() #remove old model from gpu

Next, we're going to normalize the MRI scans so that the voxel values are scaled from 0 to 1. This standardizes every image's intensity range, which makes it easier for the model to learn.
