**To install packages**

In [None]:
!pip install nibabel -q
!pip install scikit-learn -q
!pip install tqdm -q
!pip install split-folders -q
!pip install torchinfo -q
!pip install segmentation-models-pytorch-3d -q
!pip install livelossplot -q
!pip install torchmetrics -q
!pip install tensorboard -q
!apt-get install tree

In [173]:
import os
import random
import splitfolders
from tqdm import tqdm
import nibabel as nib
import glob
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import shutil
import time

from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision.transforms as transforms
from torch.cuda import amp

from torchmetrics import MeanMetric
from torchmetrics.classification import MulticlassAccuracy

from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
import gc

import segmentation_models_pytorch_3d as smp

from livelossplot import PlotLosses
from livelossplot.outputs import MatplotlibPlot, ExtremaPrinter
from torch.optim import AdamW
import matplotlib.patches as mpatches
from matplotlib.colorbar import Colorbar

**Mount Drive to access the dataset**

In [None]:
# Mount Google Drive at /content/drive; every dataset and checkpoint path
# used later in this notebook lives under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [175]:
def seed_everything(SEED):
    """Seed every random number generator this notebook draws from.

    Seeds Python's ``random`` module (used for the random slice selection),
    NumPy, and PyTorch (CPU and all CUDA devices), then switches cuDNN to
    deterministic mode with autotuning disabled.

    Args:
        SEED: Integer seed applied to all generators.
    """
    # The notebook picks slices with random.randint(), so the stdlib
    # generator has to be seeded as well for runs to be repeatable.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_default_device():
    """Pick the compute device for this session.

    Returns:
        A ``(device, gpu_available)`` tuple: ``torch.device('cuda')`` when
        CUDA is available, otherwise ``torch.device('cpu')``, together with
        the availability flag itself.
    """
    has_cuda = torch.cuda.is_available()
    device_name = 'cpu'
    if has_cuda:
        device_name = 'cuda'
    return torch.device(device_name), has_cuda

In [176]:
@dataclass(frozen=True)
class TrainingConfig:
    """Immutable training hyper-parameters.

    The values are read as class attributes (e.g. ``TrainingConfig.EPOCHS``);
    instances are frozen, so a field cannot be reassigned after creation.
    """
    # Samples per mini-batch.
    BATCH_SIZE:      int = 5
    # Number of passes over the training set.
    EPOCHS:          int = 50
    # Initial optimizer learning rate.
    LEARNING_RATE: float = 5e-5
    # Root folder under which versioned checkpoint directories are created.
    CHECKPOINT_DIR:  str = os.path.join('/content/drive/MyDrive/MSc_Project', '3D_Brain_Images')
    # DataLoader worker processes.
    NUM_WORKERS:     int = 4

# **Sample Preprocessing**

In [None]:
# Root folder holding one sub-directory per BraTS2021 case.
Dataset_path = '/content/drive/MyDrive/MSc_Project/BraTS2021_Training_Data'

# Scaler object reused (re-fitted) by the preprocessing cells below.
scaler = MinMaxScaler()

num_entries = len(os.listdir(Dataset_path))
print('Total Files: ', num_entries)


**To check file structure**

In [None]:
!tree -L 2 "/content/drive/MyDrive/MSc_Project/BraTS2021_Training_Data/BraTS2021_00000"

**Load NIfTI images using nib.load()**

`nib.load()` returns an image object; calling `.get_fdata()` on it returns the voxel data as a NumPy array.

In [None]:
# Read the FLAIR volume of the sample case as a NumPy array.
flair_path = os.path.join(Dataset_path, 'BraTS2021_00000', 'BraTS2021_00000_flair.nii.gz')
sample_image_flair_3d = nib.load(flair_path).get_fdata()
print('Original max value:', sample_image_flair_3d.max())

# Flatten the 3D volume into a single column so the scaler sees one feature.
sample_image_flair_2d = sample_image_flair_3d.reshape(-1, 1)

**One sample taken for understanding how preprocessing would work on the whole dataset**

In [None]:
# Fit the scaler on the flattened volume and restore the original 3D layout.
flair_scaled_flat = scaler.fit_transform(sample_image_flair_2d)
sample_image_flair_scaled = flair_scaled_flat.reshape(sample_image_flair_3d.shape)

print('Scaled max value:', sample_image_flair_scaled.max())
print('Shape of scaled image:', sample_image_flair_scaled.shape)

**Use of seg MRI modality to see the classes**

In [None]:
# Read the segmentation volume and store the labels as unsigned 8-bit integers.
seg_path = os.path.join(Dataset_path, 'BraTS2021_00000', 'BraTS2021_00000_seg.nii.gz')
sample_image_seg_3d = nib.load(seg_path).get_fdata().astype(np.uint8)

print('Unique class in the mask', np.unique(sample_image_seg_3d))

**Choose a random slice from a manually selected patient ID and visualize it per modality for better insight**

In [182]:
# Remaining modalities of the same sample case, each read as a NumPy array.
sample_dir = os.path.join(Dataset_path, 'BraTS2021_00000')
sample_image_t1_3d = nib.load(os.path.join(sample_dir, 'BraTS2021_00000_t1.nii.gz')).get_fdata()
sample_image_t1ce_3d = nib.load(os.path.join(sample_dir, 'BraTS2021_00000_t1ce.nii.gz')).get_fdata()
sample_image_t2_3d = nib.load(os.path.join(sample_dir, 'BraTS2021_00000_t2.nii.gz')).get_fdata()

In [None]:
# Draw one random index along the last axis of the segmentation volume.
n_slice = random.randint(0, sample_image_seg_3d.shape[2] - 1)
print('random slice number:', n_slice)
plt.figure(figsize=(12, 8))

# (subplot position, volume, title, colour map); None keeps matplotlib's default map.
panels = [
    (231, sample_image_flair_3d, 'Flair Image', 'gray'),
    (232, sample_image_t1_3d, 'T1 Image', 'gray'),
    (233, sample_image_t1ce_3d, 'T1ce Image', 'gray'),
    (234, sample_image_t2_3d, 'T2 Image', 'gray'),
    (235, sample_image_seg_3d, 'Seg Image', None),
    (236, sample_image_seg_3d, 'Mask Gray', 'gray'),
]
for position, vol, panel_title, colour_map in panels:
    plt.subplot(position)
    plt.imshow(vol[:, :, n_slice], cmap=colour_map)
    plt.title(panel_title)

plt.show()

**To see all modalities in loop, from all 3 axes**

In [184]:
# An all-zero placeholder volume with the 240 x 240 x 155 layout.
volume = np.zeros(shape=(240, 240, 155))
# Random index along the last axis (both bounds inclusive).
n_slice = random.randint(0, 154)
# 2D slice of the placeholder at that index.
slice_2d = volume[..., n_slice]


In [None]:
# Choose the middle index of each axis of the volume that is actually plotted.
# x indexes axis 0, y axis 1 and z axis 2, matching how they are used below
# (previously the indices were derived from an unrelated placeholder array and
# x / y were taken from each other's axes).
x, y, z = (dim // 2 for dim in sample_image_flair_3d.shape)

plt.figure(figsize=(18, 5))

# Axial (top-down): fix the last axis
plt.subplot(1, 3, 1)
plt.imshow(sample_image_flair_3d[:, :, z], cmap='gray')
plt.title(f'Axial (Z={z}), X-Y')
plt.axis('off')

# Coronal (front-back): fix the middle axis
plt.subplot(1, 3, 2)
plt.imshow(sample_image_flair_3d[:, y, :], cmap='gray')
plt.title(f'Coronal (Y={y}), X-Z')
plt.axis('off')

# Sagittal (side): fix the first axis
plt.subplot(1, 3, 3)
plt.imshow(sample_image_flair_3d[x, :, :], cmap='gray')
plt.title(f'Sagittal (X={x}), Y-Z')
plt.axis('off')

plt.show()


In [None]:
# Stack the FLAIR, T1ce and T2 volumes along a new axis 3, so the last axis
# indexes the modality in that order. Note these are the raw volumes loaded
# above, not the scaled FLAIR array.
combined_x = np.stack([sample_image_flair_3d,sample_image_t1ce_3d,sample_image_t2_3d],axis=3)
print("Shape of Combined x ", combined_x.shape)

In [None]:
# Crop to a 128 x 128 x 128 window (rows/cols 56..183, depth 13..140).
# The crop is rebuilt from the raw volumes instead of re-slicing combined_x,
# so running this cell more than once always gives the same result.
combined_x = np.stack(
    [sample_image_flair_3d, sample_image_t1ce_3d, sample_image_t2_3d], axis=3
)[56:184, 56:184, 13:141]
print("Shape after cropping: ", combined_x.shape)

# .copy() detaches the cropped mask from sample_image_seg_3d, so later edits
# to the crop cannot alter the full-size segmentation volume.
sample_mask_c = sample_image_seg_3d[56:184, 56:184, 13:141].copy()
print("Mask shape after cropping: ", sample_mask_c.shape)

In [None]:
# Remap class value 4 to 3 so the labels are the contiguous set {0, 1, 2, 3}.
# np.where builds a new array; an in-place write could propagate to
# sample_image_seg_3d when sample_mask_c is a slice (view) of it.
sample_mask_c = np.where(sample_mask_c == 4, 3, sample_mask_c)

# Now apply one_hot encoding with num_classes=4 (adds a trailing class axis)
sample_mask_cat  = F.one_hot(torch.tensor(sample_mask_c, dtype = torch.long), num_classes = 4)
print("Shape after one-hot encoding: ", sample_mask_cat.shape)

In [None]:
# One sorted path list per file suffix, searched one directory level below
# Dataset_path.
t1ce_list, t2_list, flair_list, mask_list = (
    sorted(glob.glob(f"{Dataset_path}/*/*{suffix}.nii.gz"))
    for suffix in ("t1ce", "t2", "flair", "seg")
)

for list_label, path_list in (
    ("t1ce list: ", t1ce_list),
    ("t2 list: ", t2_list),
    ("flair list: ", flair_list),
    ("Mask list: ", mask_list),
):
    print(list_label, len(path_list))

# **Data Preprocessing**

In [None]:
# to show progress bar
for idx in tqdm(range(len(t2_list)), desc='Preparing to stack, crop and save', unit='file'):
  temp_image_t1ce = nib.load(t1ce_list[idx]).get_fdata()
  temp_image_t1ce = scaler.fit_transform(temp_image_t1ce.reshape(-1, temp_image_t1ce.shape[-1])).reshape(temp_image_t1ce.shape)

  temp_image_t2 = nib.load(t2_list[idx]).get_fdata()
  temp_image_t2 = scaler.fit_transform(temp_image_t2.reshape(-1, temp_image_t2.shape[-1])).reshape(temp_image_t2.shape)

  temp_image_flair = nib.load(flair_list[idx]).get_fdata()
  temp_image_flair = scaler.fit_transform(temp_image_flair.reshape(-1, temp_image_flair.shape[-1])).reshape(temp_image_flair.shape)

  temp_seg_mask = nib.load(mask_list[idx]).get_fdata()

  temp_comb_images = np.stack([temp_image_flair, temp_image_t1ce, temp_image_t2], axis=3)
  temp_comb_images = temp_comb_images[56:184, 56:184, 13:141]
  temp_seg_mask = temp_seg_mask[56:184, 56:184, 13:141]

  temp_seg_mask[temp_seg_mask == 4] = 3

  val, counts = np.unique(temp_seg_mask, return_counts=True)

# if a volume has less than 1% of mask, simplw ignore to reduce the computation.
  if(1 - (counts[0] / counts.sum())) > 0.01:
    temp_seg_mask = F.one_hot(torch.tensor(temp_seg_mask, dtype=torch.long), num_classes=4)
    out_dir_images = "/content/drive/MyDrive/MSc_Project/BraTS2021_Preprocessed/input_data_3channels/images"
    out_dir_masks  = "/content/drive/MyDrive/MSc_Project/BraTS2021_Preprocessed/input_data_3channels/masks"
    os.makedirs(out_dir_images, exist_ok=True)
    os.makedirs(out_dir_masks, exist_ok=True)
    np.save(
        os.path.join(out_dir_images, f"image_{idx}.npy"),
        temp_comb_images,
    )
    np.save(
        os.path.join(out_dir_masks, f"mask_{idx}.npy"),
        temp_seg_mask,
    )

  else:
    pass


In [None]:
# Number of entries written to the preprocessed image and mask folders.
images_folder = "/content/drive/MyDrive/MSc_Project/BraTS2021_Preprocessed/input_data_3channels/images"
masks_folder = "/content/drive/MyDrive/MSc_Project/BraTS2021_Preprocessed/input_data_3channels/masks"

for counted_folder in (images_folder, masks_folder):
    print(len(os.listdir(counted_folder)))

In [None]:
#split the data into folders
input_folder = "/content/drive/MyDrive/MSc_Project/BraTS2021_Preprocessed/input_data_3channels/"

output_folder = "/content/drive/MyDrive/MSc_Project/BraTS2021_Preprocessed/input_data_split/"

splitfolders.ratio(
    input_folder, output_folder, seed=42, ratio=(0.75, 0.15, 0.10), group_prefix=None
)
print('Sucess!')

# DataLoader

In [None]:
class Dataset_Brats21(Dataset):
    """Dataset of preprocessed volumes stored as paired ``.npy`` files.

    Images and masks are paired by position in the sorted file listings of
    ``img_dir`` and ``mask_dir``. Each item is returned as
    ``(image, mask)`` float32 tensors with the last array axis moved to the
    front (channels first).

    Args:
        img_dir: Folder containing the image ``.npy`` files.
        mask_dir: Folder containing the mask ``.npy`` files.
        normalization: When True, each image channel is normalised as
            ``(x - 0.5) / 0.5``.

    Raises:
        ValueError: If the two folders hold a different number of ``.npy`` files.
    """

    def __init__(self, img_dir, mask_dir, normalization=True):
        super().__init__()

        self.img_dir = img_dir
        self.mask_dir = mask_dir

        # Keep only .npy files: notebook tooling can leave extra entries
        # (for example an ".ipynb_checkpoints" folder) that would otherwise
        # break the image/mask pairing or fail to load.
        self.img_list = sorted(f for f in os.listdir(img_dir) if f.endswith(".npy"))
        self.mask_list = sorted(f for f in os.listdir(mask_dir) if f.endswith(".npy"))

        if len(self.img_list) != len(self.mask_list):
            raise ValueError(f"Number of images ({len(self.img_list)}) and masks ({len(self.mask_list)}) do not match in {img_dir} and {mask_dir}")

        self.normalization = normalization

        if self.normalization:
            # One mean/std entry per image channel (three channels).
            self.mean = torch.tensor([0.5]*3)
            self.std = torch.tensor([0.5]*3)

    def load_file(self, filepath, retries=5, delay=1):
        """Load one ``.npy`` file, retrying on transient errors.

        Args:
            filepath: Path of the file to load.
            retries: Maximum number of attempts (must be at least 1).
            delay: Seconds to wait between attempts.

        Returns:
            The loaded NumPy array.

        Raises:
            ValueError: If ``retries`` is below 1, or the file holds an empty
                array on the final attempt.
            OSError: If the file cannot be read on the final attempt.
        """
        if retries < 1:
            # Without this guard the loop body never runs and None is returned.
            raise ValueError(f"retries must be at least 1, got {retries}")

        for attempt in range(1, retries + 1):
            try:
                data = np.load(filepath, allow_pickle=False)
                if data.size == 0:
                    raise ValueError(f"Loaded empty data from {filepath}")
                return data
            except (OSError, ValueError) as e:
                print(f"Attempt {attempt}/{retries} failed to load {filepath}: {e}")
                if attempt == retries:
                    raise
                time.sleep(delay)

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` tensor pair at position ``idx``.

        Any loading or conversion error is reported and then re-raised.
        """
        image_path = os.path.join(self.img_dir, self.img_list[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_list[idx])

        try:
            image = self.load_file(image_path)
            mask = self.load_file(mask_path)

            # Move the last array axis (channels / classes) to the front.
            image = torch.from_numpy(image).permute(3, 0, 1, 2).float()
            mask = torch.from_numpy(mask).permute(3, 0, 1, 2).float()

            if self.normalization:
                # Reshape to (C, 1, 1, 1) so the statistics broadcast per channel.
                mean = self.mean.view(-1, 1, 1, 1)
                std = self.std.view(-1, 1, 1, 1)
                image = (image - mean) / std

            return image, mask

        except Exception as e:
            # The sample is not skipped: the error is logged and propagated.
            print(f"Error loading sample {self.img_list[idx]}: {e}")
            raise

# Root of the train / val / test split written by the splitting cell.
SPLIT_ROOT = "/content/drive/MyDrive/MSc_Project/BraTS2021_Preprocessed/input_data_split"

train_img_dir = os.path.join(SPLIT_ROOT, "train", "images")
train_mask_dir = os.path.join(SPLIT_ROOT, "train", "masks")

val_img_dir = os.path.join(SPLIT_ROOT, "val", "images")
val_mask_dir = os.path.join(SPLIT_ROOT, "val", "masks")

test_img_dir = os.path.join(SPLIT_ROOT, "test", "images")
test_mask_dir = os.path.join(SPLIT_ROOT, "test", "masks")

train_dataset = Dataset_Brats21(train_img_dir, train_mask_dir, normalization=True)
val_dataset = Dataset_Brats21(val_img_dir, val_mask_dir, normalization=True)
test_dataset = Dataset_Brats21(test_img_dir, test_mask_dir, normalization=True)

# dataset statistics
print("Total Training Samples: ", len(train_dataset))
print("Total Val Samples: ", len(val_dataset))
print("Total Test Samples: ", len(test_dataset))

# Batch size and worker count come from TrainingConfig so the loaders cannot
# drift away from the configured values. Only the training loader shuffles.
train_loader = DataLoader(
    train_dataset,
    batch_size=TrainingConfig.BATCH_SIZE,
    shuffle=True,
    num_workers=TrainingConfig.NUM_WORKERS,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=TrainingConfig.BATCH_SIZE,
    shuffle=False,
    num_workers=TrainingConfig.NUM_WORKERS,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=TrainingConfig.BATCH_SIZE,
    shuffle=False,
    num_workers=TrainingConfig.NUM_WORKERS,
)

# Sanity check: pull one batch from every loader and report its shape.
for split_name, split_loader in (
    ("Train", train_loader),
    ("Val", val_loader),
    ("Test", test_loader),
):
    images, masks = next(iter(split_loader))
    print(f"{split_name} Image batch shape: {images.shape}")
    print(f"{split_name} Mask batch shape: {masks.shape}")


# **Building a 3D U-Net Model**

In [191]:
def DoubleConv(in_channels, out_channels):
    """Build a Conv3d-BatchNorm-ReLU-Dropout-Conv3d-BatchNorm-ReLU stack.

    Both convolutions use a 3x3x3 kernel with padding 1. The dropout
    probability depends on the block width: 0.1 up to 32 output channels,
    0.2 up to 128, and 0.3 above that.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` holding the seven layers in the order above.
    """
    if out_channels <= 32:
        drop_p = 0.1
    elif out_channels <= 128:
        drop_p = 0.2
    else:
        drop_p = 0.3

    layers = [
        nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm3d(out_channels),
        nn.ReLU(inplace=False),
        nn.Dropout(drop_p),
        nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm3d(out_channels),
        nn.ReLU(inplace=False),
    ]
    return nn.Sequential(*layers)

In [192]:
class Unet(nn.Module):
    """3D U-Net with four encoder stages, a bottleneck and four decoder stages.

    Encoder widths are 16, 32, 64 and 128 channels with a 256-channel
    bottleneck. Every decoder stage upsamples with a stride-2 transposed
    convolution, concatenates the matching encoder output on the channel axis
    and applies a ``DoubleConv``. A final 1x1x1 convolution maps to
    ``out_channels``; the output is returned without an activation.

    Args:
        in_channels: Channels of the input volume.
        out_channels: Channels of the output map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Contracting path: DoubleConv followed by 2x max-pooling.
        self.conv1 = DoubleConv(in_channels, 16)
        self.pool1 = nn.MaxPool3d(2)

        self.conv2 = DoubleConv(16, 32)
        self.pool2 = nn.MaxPool3d(2)

        self.conv3 = DoubleConv(32, 64)
        self.pool3 = nn.MaxPool3d(2)

        self.conv4 = DoubleConv(64, 128)
        self.pool4 = nn.MaxPool3d(2)

        # Bottleneck at the lowest resolution.
        self.conv5 = DoubleConv(128, 256)

        # Expanding path: each DoubleConv takes twice the upsampled width
        # because the skip connection is concatenated before it.
        self.upconv6 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)
        self.conv6 = DoubleConv(256, 128)

        self.upconv7 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
        self.conv7 = DoubleConv(128, 64)

        self.upconv8 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
        self.conv8 = DoubleConv(64, 32)

        self.upconv9 = nn.ConvTranspose3d(32, 16, kernel_size=2, stride=2)
        self.conv9 = DoubleConv(32, 16)

        # 1x1x1 projection to the requested number of output channels.
        self.out_conv = nn.Conv3d(16, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` and return the output of the final 1x1x1 conv."""
        # Encoder: keep every pre-pooling output for the skip connections.
        enc1 = self.conv1(x)
        enc2 = self.conv2(self.pool1(enc1))
        enc3 = self.conv3(self.pool2(enc2))
        enc4 = self.conv4(self.pool3(enc3))

        bottleneck = self.conv5(self.pool4(enc4))

        # Decoder: upsample, concatenate the skip on the channel axis, convolve.
        dec = self.upconv6(bottleneck)
        dec = self.conv6(torch.cat([dec, enc4], dim=1))

        dec = self.upconv7(dec)
        dec = self.conv7(torch.cat([dec, enc3], dim=1))

        dec = self.upconv8(dec)
        dec = self.conv8(torch.cat([dec, enc2], dim=1))

        dec = self.upconv9(dec)
        dec = self.conv9(torch.cat([dec, enc1], dim=1))

        return self.out_conv(dec)

**Defining Losses and Optimizers from Segmentation**

In [193]:
# Dice Loss
# Configured in "multiclass" mode with from_logits=True, i.e. it is given the
# raw network outputs; classes=None and ignore_index=None leave those options
# unset.
dice_loss = smp.losses.DiceLoss(
   mode="multiclass",
   classes=None,
   log_loss=False,
   from_logits=True,
   smooth=1e-5,
   ignore_index=None,
   eps=1e-7
)

# Focal Loss
# Also in "multiclass" mode, with alpha=0.25 and gamma=2.0.
focal_loss = smp.losses.FocalLoss(
   mode="multiclass",
   alpha=0.25,
   gamma=2.0
)

def combined_loss(output, target):
    """Return the unweighted sum of the Dice loss and the Focal loss.

    Args:
        output: Model output handed to both loss objects.
        target: Target handed to both loss objects.
    """
    return dice_loss(output, target) + focal_loss(output, target)

**Checkpoint based on validation loss**

In [None]:
def create_checkpoint_dir(checkpoint_dir):
    """Create and return the next ``version_<N>`` folder inside ``checkpoint_dir``.

    ``N`` is one more than the highest existing ``version_<number>`` entry,
    or 0 when there is none. Entries that do not match that exact pattern
    (for example ``version_final``) are ignored rather than aborting the scan.

    Args:
        checkpoint_dir: Parent folder; created if it does not exist.

    Returns:
        Path of the newly created version folder.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    existing_versions = []
    for entry in os.listdir(checkpoint_dir):
        prefix, _, suffix = entry.rpartition("_")
        # Only "version_<digits>" counts; anything else is skipped so one
        # oddly named entry cannot reset the counter to 0.
        if prefix == "version" and suffix.isdecimal():
            existing_versions.append(int(suffix))

    version_num = max(existing_versions) + 1 if existing_versions else 0

    version_dir = os.path.join(checkpoint_dir, "version_" + str(version_num))
    os.makedirs(version_dir)

    print(f"Checkpoint directory: {version_dir}")
    return version_dir

# Seed first so the model's initial weights are reproducible.
seed_everything(SEED = 42)

DEVICE, GPU_AVAILABLE  = get_default_device()
print(DEVICE)

CKPT_DIR = create_checkpoint_dir(TrainingConfig.CHECKPOINT_DIR)

# Build the network the optimizer and training loop operate on. Without this
# the cell fails on a fresh kernel because `model` was never defined.
# 3 input channels / 4 output classes is the same configuration the
# evaluation cell uses when it reloads the checkpoint.
model = Unet(in_channels=3, out_channels=4)

optimizer = AdamW(
    model.parameters(),
    lr=TrainingConfig.LEARNING_RATE,
    weight_decay=1e-2,                # Regularization to avoid overfitting
    amsgrad=True                      # Optional AMSGrad variant
)

# **Train the model**

In [195]:
from torch.optim.lr_scheduler import CosineAnnealingLR

def logits_to_probs(logits):
    """Apply a softmax over dimension 1 of ``logits`` and return the result."""
    return logits.softmax(dim=1)

# Train function
def TrainModel(
    model,
    loader,
    optimizer,
    num_classes,
    device="cuda",
    epoch_index=800,
    total_epochs=50):
    """Train ``model`` for one pass over ``loader``.

    Targets are expected one-hot along dimension 1; they are converted to
    class indices with ``argmax`` before the loss and metrics are computed.

    Args:
        model: Network to train (switched to train mode here).
        loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer updating ``model``'s parameters.
        num_classes: Number of segmentation classes.
        device: Device the batches are moved to. The default is "cuda"
            (the previous default "gpu" is not a valid torch device string).
        epoch_index: Epoch number shown in the progress bar.
        total_epochs: Total epoch count shown in the progress bar.

    Returns:
        ``(epoch_loss, iou_per_class, dice_per_class)`` where the last two
        are dicts mapping class index to the mean score over all samples.
    """
    model.train()

    loss_record = MeanMetric()
    iou_per_class = {i: MeanMetric() for i in range(num_classes)}
    dice_per_class = {i: MeanMetric() for i in range(num_classes)}

    with tqdm(total=len(loader), ncols=122) as tq:
        tq.set_description(f"Train ::  Epoch: {epoch_index}/{total_epochs}")

        for data, target in loader:
            data, target = data.to(device).float(), target.to(device).float()

            optimizer.zero_grad()

            logits = model(data)
            # One-hot target -> class-index target.
            target_indexed = target.argmax(dim=1)

            loss = combined_loss(logits, target_indexed)
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                pred_idx = logits.argmax(dim=1)

                tp, fp, fn, tn = smp.metrics.get_stats(pred_idx, target_indexed, mode='multiclass', num_classes=num_classes)

                iou_scores = smp.metrics.iou_score(tp, fp, fn, tn, reduction='none')
                dice_scores = smp.metrics.f1_score(tp, fp, fn, tn, reduction='none')

                # With reduction='none' the scores are shaped (batch, classes):
                # column c holds class c for every sample. Iterating over the
                # first axis would hand sample i's scores to "class i".
                for cls in range(num_classes):
                    iou_per_class[cls].update(iou_scores[:, cls].cpu())
                    dice_per_class[cls].update(dice_scores[:, cls].cpu())

                loss_record.update(loss.detach().cpu(), weight=data.shape[0])

            tq.set_postfix_str(s=f"Loss: {loss_record.compute():.4f}")
            tq.update(1)

    epoch_loss = loss_record.compute()
    epoch_iou_per_class = {i: iou_per_class[i].compute() for i in range(num_classes)}
    epoch_dice_per_class = {i: dice_per_class[i].compute() for i in range(num_classes)}

    return epoch_loss, epoch_iou_per_class, epoch_dice_per_class


def McDropout(model, x, mc_runs=20):
    """Estimate mean and variance of the softmax output with Monte-Carlo Dropout.

    The model is put in eval mode with only its ``nn.Dropout`` /
    ``nn.Dropout3d`` layers switched back to train mode, ``mc_runs``
    stochastic forward passes are made without gradients, and the softmax
    outputs (over dimension 1) are aggregated.

    The train/eval flag of every module is restored before returning, so the
    caller's model is left exactly as it was (previously the dropout layers
    stayed active after the call).

    Args:
        model: Network containing dropout layers.
        x: Input batch passed to ``model``.
        mc_runs: Number of stochastic forward passes. With a single run the
            variance is undefined.

    Returns:
        ``(mean, var)`` computed across the runs, each shaped like one
        softmax output.
    """
    # Remember each module's mode so it can be restored afterwards.
    saved_modes = [(module, module.training) for module in model.modules()]

    model.eval()
    for module in model.modules():
        if isinstance(module, (nn.Dropout, nn.Dropout3d)):
            module.train()

    try:
        with torch.no_grad():
            probs = [torch.softmax(model(x), dim=1) for _ in range(mc_runs)]
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    stack = torch.stack(probs, dim=0)
    mean = stack.mean(dim=0)
    var = stack.var(dim=0)
    return mean, var

# Validation function
def ValidateModel(
    model,
    loader,
    device,
    num_classes,
    epoch_index,
    total_epochs,
    writer=None
):
    """Evaluate ``model`` on ``loader`` for one epoch.

    Loss, IoU and Dice come from a deterministic forward pass (dropout off).
    The macro accuracy comes from the Monte-Carlo-Dropout mean prediction.
    When ``writer`` is given, an uncertainty heat-map for class 1 of the first
    sample of the first batch is also logged to TensorBoard.

    Args:
        model: Network to evaluate.
        loader: Iterable yielding ``(data, target)`` batches with one-hot targets.
        device: Device the batches are moved to.
        num_classes: Number of segmentation classes.
        epoch_index: Epoch number used for the progress bar and image step.
        total_epochs: Total epoch count shown in the progress bar.
        writer: Optional TensorBoard ``SummaryWriter``.

    Returns:
        ``(loss, iou_per_class, dice_per_class, mc_accuracy)``.
    """
    model.eval()

    loss_record = MeanMetric()
    iou_per_class = {i: MeanMetric() for i in range(num_classes)}
    dice_per_class = {i: MeanMetric() for i in range(num_classes)}
    mlc_acc = MulticlassAccuracy(num_classes=num_classes, average='macro')

    if writer is not None:
        # Best-effort logging: a failure here must not abort validation.
        try:
            # Use the loader that was passed in, not a notebook-level global.
            data, _ = next(iter(loader))
            data = data.to(device).float()

            _, var_prob_sample = McDropout(model, data[:1], mc_runs=20)

            # Middle index along the first spatial axis of the variance map.
            mid = var_prob_sample.shape[2] // 2
            heat = var_prob_sample[0, 1, mid, :, :]

            # Rescale to [0, 1]; the epsilon avoids dividing by zero on a flat map.
            heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)

            writer.add_image("uncertainty/class1",
                             heat.cpu().numpy(),
                             global_step=epoch_index,
                             dataformats="HW")
            writer.flush()
        except StopIteration:
            print("Validation loader is empty.")
        except Exception as e:
            print(f"Error: uncertainty map to TensorBoard: {e}")
        finally:
            # Make sure dropout is off again even if the logging failed midway.
            model.eval()

    with tqdm(total=len(loader), ncols=122) as tq:
        tq.set_description(f"Valid :: Epoch: {epoch_index}/{total_epochs}")

        for data, target in loader:
            data, target = data.to(device).float(), target.to(device).float()
            target_indexed = target.argmax(dim=1)

            with torch.no_grad():
                logits = model(data)

                mean_prob, _ = McDropout(model, data, mc_runs=20)
                # Switch dropout off again; otherwise the deterministic pass of
                # the next batch would run with dropout still active.
                model.eval()
                mc_pred_idx = mean_prob.argmax(dim=1)

                loss = combined_loss(logits, target_indexed)

            pred_idx = logits.argmax(dim=1)

            tp, fp, fn, tn = smp.metrics.get_stats(pred_idx, target_indexed, mode='multiclass', num_classes=num_classes)

            iou_scores = smp.metrics.iou_score(tp, fp, fn, tn, reduction='none')
            dice_scores = smp.metrics.f1_score(tp, fp, fn, tn, reduction='none')

            # Scores are shaped (batch, classes): column c belongs to class c.
            for cls in range(num_classes):
                iou_per_class[cls].update(iou_scores[:, cls].cpu())
                dice_per_class[cls].update(dice_scores[:, cls].cpu())

            mlc_acc.update(mc_pred_idx.cpu(), target_indexed.cpu())
            loss_record.update(loss.cpu(), weight=data.shape[0])
            tq.set_postfix_str(s=f"Loss: {loss_record.compute():.4f}")
            tq.update(1)

    valid_epoch_loss = loss_record.compute()
    valid_epoch_iou_per_class = {i: iou_per_class[i].compute() for i in range(num_classes)}
    valid_epoch_dice_per_class = {i: dice_per_class[i].compute() for i in range(num_classes)}
    valid_mc_acc = mlc_acc.compute()

    return valid_epoch_loss, valid_epoch_iou_per_class, valid_epoch_dice_per_class, valid_mc_acc


# Main function
def Main(*, model, optimizer, ckpt_dir, pin_memory=True, device="cuda"):
    """Run the full training loop and checkpoint the best model.

    Trains for ``TrainingConfig.EPOCHS`` epochs using the notebook-level
    ``train_loader`` and ``val_loader``, logs metrics to livelossplot and
    TensorBoard, steps a cosine-annealing scheduler once per epoch, and writes
    ``ckpt.tar`` into ``ckpt_dir`` whenever the validation loss improves.

    Args:
        model: Network to train.
        optimizer: Optimizer bound to ``model``'s parameters.
        ckpt_dir: Folder receiving ``ckpt.tar`` and the ``tboard_logs`` folder.
        pin_memory: Accepted for backward compatibility; not used here because
            the data loaders are created outside this function.
        device: Device to train on. The default is "cuda" (the previous
            default "gpu" is not a valid torch device string).
    """
    total_epochs = TrainingConfig.EPOCHS
    num_classes = 4

    model.to(device, non_blocking=True)

    writer = SummaryWriter(log_dir=os.path.join(ckpt_dir, "tboard_logs"))
    best_loss = float("inf")

    live_plot = PlotLosses(outputs=[MatplotlibPlot(cell_size=(8, 3)), ExtremaPrinter()])

    scheduler = CosineAnnealingLR(optimizer, T_max=total_epochs)

    # try/finally guarantees the TensorBoard writer is flushed and closed even
    # when training is interrupted or fails.
    try:
        for epoch in range(total_epochs):
            current_epoch = epoch + 1

            torch.cuda.empty_cache()
            gc.collect()

            train_loss, train_iou_per_class, train_dice_per_class = TrainModel(
                model=model,
                loader=train_loader,
                optimizer=optimizer,
                num_classes=num_classes,
                device=device,
                epoch_index=current_epoch,
                total_epochs=total_epochs,
            )

            valid_loss, valid_iou_per_class, valid_dice_per_class, valid_mc_acc = ValidateModel(
                model=model,
                loader=val_loader,
                device=device,
                num_classes=num_classes,
                epoch_index=current_epoch,
                total_epochs=total_epochs,
                writer=writer
            )

            # Plain Python floats keep tensors out of the plotting/logging code.
            train_loss = float(train_loss)
            valid_loss = float(valid_loss)
            valid_mc_acc = float(valid_mc_acc)

            # create plots
            plot_metrics = {
                "loss": train_loss,
                "val_loss": valid_loss,
                "val_mc_accuracy": valid_mc_acc,
            }
            for i in range(num_classes):
                plot_metrics[f"IoU_Class_{i}"] = train_iou_per_class[i].item()
                plot_metrics[f"val_IoU_Class_{i}"] = valid_iou_per_class[i].item()
                plot_metrics[f"Dice_Class_{i}"] = train_dice_per_class[i].item()
                plot_metrics[f"val_Dice_Class_{i}"] = valid_dice_per_class[i].item()

            live_plot.update(plot_metrics)
            live_plot.send()

            writer.add_scalar("Loss/train", train_loss, current_epoch)
            writer.add_scalar("Loss/valid", valid_loss, current_epoch)
            writer.add_scalar("MC_Accuracy/valid", valid_mc_acc, current_epoch)

            for i in range(num_classes):
                writer.add_scalar(f"IoU_Class_{i}/train", train_iou_per_class[i].item(), current_epoch)
                writer.add_scalar(f"IoU_Class_{i}/valid", valid_iou_per_class[i].item(), current_epoch)
                writer.add_scalar(f"Dice_Class_{i}/train", train_dice_per_class[i].item(), current_epoch)
                writer.add_scalar(f"Dice_Class_{i}/valid", valid_dice_per_class[i].item(), current_epoch)

            scheduler.step()

            # Save the model if validation loss improves
            if valid_loss < best_loss:
                best_loss = valid_loss
                print("Model Improved. Saving...", end="")

                # "opt" and "model" are the keys existing loading code reads;
                # the extra entries only record where the checkpoint came from.
                checkpoint_dict = {
                    "opt": optimizer.state_dict(),
                    "model": model.state_dict(),
                    "epoch": current_epoch,
                    "valid_loss": valid_loss,
                }
                torch.save(checkpoint_dict, os.path.join(ckpt_dir, "ckpt.tar"))
                del checkpoint_dict
                print("Done.\n")
    finally:
        writer.close()

    return

In [None]:
# to train the model
Main(
   model = model,
   optimizer = optimizer,
   ckpt_dir = CKPT_DIR,
   device  = DEVICE,
   pin_memory = GPU_AVAILABLE
)

# **Test the model**

In [197]:
# Test function
def TestModel(
   model,
   loader,
   device,
   num_classes
):
    """Evaluate ``model`` on ``loader`` and print a per-class report.

    Loss, IoU and Dice come from a deterministic forward pass (dropout off).
    The macro accuracy comes from the Monte-Carlo-Dropout mean prediction.
    Class labels in the report are taken from a notebook-level
    ``class_names`` list when one exists, otherwise ``"Class <i>"`` is used.

    Args:
        model: Network to evaluate.
        loader: Iterable yielding ``(data, target)`` batches with one-hot targets.
        device: Device the batches are moved to.
        num_classes: Number of segmentation classes.

    Returns:
        ``(test_loss, iou_per_class, dice_per_class, mc_accuracy)``.
    """
    model.eval()

    loss_record = MeanMetric()
    iou_per_class = {i: MeanMetric() for i in range(num_classes)}
    dice_per_class = {i: MeanMetric() for i in range(num_classes)}
    mlc_acc = MulticlassAccuracy(num_classes=num_classes, average='macro')

    with torch.no_grad():
        with tqdm(total=len(loader), ncols=122) as tq:
            tq.set_description(f"Test :: Evaluation")

            for data, target in loader:
                data, target = data.to(device).float(), target.to(device).float()
                target_indexed = target.argmax(dim=1)

                logits = model(data)

                mean_prob, _ = McDropout(model, data, mc_runs=20)
                # Switch dropout off again; otherwise the deterministic pass of
                # the next batch would run with dropout still active.
                model.eval()
                mc_pred_idx = mean_prob.argmax(dim=1)

                loss = combined_loss(logits, target_indexed)
                pred_idx = logits.argmax(dim=1)

                tp, fp, fn, tn = smp.metrics.get_stats(pred_idx, target_indexed, mode='multiclass', num_classes=num_classes)

                iou_scores = smp.metrics.iou_score(tp, fp, fn, tn, reduction='none')
                dice_scores = smp.metrics.f1_score(tp, fp, fn, tn, reduction='none')

                # Scores are shaped (batch, classes): column c belongs to class c.
                for cls in range(num_classes):
                    iou_per_class[cls].update(iou_scores[:, cls].cpu())
                    dice_per_class[cls].update(dice_scores[:, cls].cpu())

                # update Monte Carlo accuracy
                mlc_acc.update(mc_pred_idx.cpu(), target_indexed.cpu())
                loss_record.update(loss.cpu(), weight=data.shape[0])
                tq.update(1)

    test_loss = loss_record.compute()
    test_iou_per_class = {i: iou_per_class[i].compute() for i in range(num_classes)}
    test_dice_per_class = {i: dice_per_class[i].compute() for i in range(num_classes)}
    test_mc_acc = mlc_acc.compute()

    # Resolve the display label of every class once.
    labels = [
        class_names[i] if 'class_names' in globals() and i < len(class_names) else f"Class {i}"
        for i in range(num_classes)
    ]

    print("Test Set Evaluation Results")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Monte Carlo Accuracy: {test_mc_acc:.4f}")
    print("Test IoU per class:")
    for i in range(num_classes):
        print(f"  {labels[i]}: {test_iou_per_class[i].item():.4f}")
    print("Test Dice per class:")
    for i in range(num_classes):
        print(f"  {labels[i]}: {test_dice_per_class[i].item():.4f}")

    print("Completed!")

    return test_loss, test_iou_per_class, test_dice_per_class, test_mc_acc


# Number of segmentation classes used by the evaluation cells.
num_classes = 4

# Provide generic labels only when no class_names list has been defined yet.
try:
    class_names
except NameError:
    class_names = ['Class 0', 'Class 1', 'Class 2', 'Class 3'] # Default names

# Display

In [None]:
# get_default_device() is already defined in the setup cell near the top of the notebook,
# so it is reused here instead of being declared a second time.
DEVICE, GPU_AVAILABLE = get_default_device()

# Checkpoint file whose weights are evaluated and visualised in the cells below.
TRAINED_CKPT_PATH = "/content/drive/MyDrive/MSc_Project/3D_Brain_Images/version_26/ckpt.tar"

# map_location remaps the stored tensors onto DEVICE, so the file also loads on a CPU-only runtime.
checkpoint = torch.load(TRAINED_CKPT_PATH, map_location=DEVICE)

# The architecture arguments must match the ones used when the checkpoint was written.
trained_model = Unet(in_channels=3, out_channels=4)
trained_model.load_state_dict(checkpoint['model'])
del checkpoint  # only the 'model' entry is needed from here on

trained_model.to(DEVICE)
trained_model.eval()

In [None]:
# Score the reloaded checkpoint on the held-out test loader; the four returned values are
# unpacked in the order TestModel returns them (loss, per-class IoU, per-class Dice, accuracy).
test_loss, test_iou, test_dice, test_mc_acc = TestModel(
    loader=test_loader, model=trained_model, num_classes=num_classes, device=DEVICE,
)

# **Visualization**


## Set up the directory structure



In [None]:

# Keep the first sample of each of the first `num_examples_to_select` test batches; the
# visualisation cells further down iterate over this list.
num_examples_to_select = 5
qualitative_examples = []

# Pairing the loader with range() ends the loop as soon as enough examples are collected,
# without asking the loader for a batch that would only be thrown away.
for i, (images, masks) in zip(range(num_examples_to_select), test_loader):
    # Only sample 0 of the batch is kept, so only that sample is moved to the device.
    # (Indexing an already-transferred batch would keep the whole batch alive as a view.)
    qualitative_examples.append({
        'image': images[0].to(DEVICE).float(),
        'mask': masks[0].to(DEVICE).float(),
    })
    print(f"Selected example {i+1}")


# Output tree: <base_save_dir>/3D_brain_vis/patient_<k>/<plane>/
base_save_dir = "/content/drive/MyDrive/MSc_Project/3D_Brain_Images/version_31"
visualization_dir = os.path.join(base_save_dir, "3D_brain_vis")

os.makedirs(visualization_dir, exist_ok=True)

# Create subfolders for each patient and plane
print(f"Creating patient and plane directories inside: {visualization_dir}")

planes = ['Axial', 'Coronal', 'Sagittal']

if qualitative_examples:
    for example_idx in range(len(qualitative_examples)):

        patient_dir = os.path.join(visualization_dir, f"patient_{example_idx + 1}")
        os.makedirs(patient_dir, exist_ok=True)
        print(f"Patient directory complete: {patient_dir}")

        for plane_name in planes:
            plane_dir = os.path.join(patient_dir, plane_name)
            os.makedirs(plane_dir, exist_ok=True)
            print(f"Plane directorycomplete: {plane_dir}")
else:
    # Nothing was collected (the loader yielded no batches), so there is nothing to lay out.
    print("Error occured")


In [201]:
# Display names used by the plotting cells below.
# NOTE(review): the modality order is assumed to match the channel order of the input volumes,
# and the class order the label indices of the masks — confirm against the dataset code.
class_names = ['Background', 'Necrotic Core', 'Edema', 'Enhancing Tumor']
modality_names = ['T1ce', 'FLAIR', 'T2']
num_total_classes = len(class_names)

# Label indices that get their own row in the uncertainty / Grad-CAM figures.
classes_visualize_indices = [1, 2, 3]
classes_visualize_names = list(map(class_names.__getitem__, classes_visualize_indices))
num_classes_to_visualize = len(classes_visualize_indices)

# Requested number of slices (figure columns) per plane.
temp_num_slices_visualize = 5

# Fall back to a default output folder when the directory-setup cell has not been run.
if 'visualization_dir' not in globals():
    print("Warning: visualization directory not defined.")
    print('Using the default.')
    visualization_dir = "/content/drive/MyDrive/MSc_Project/3D_Brain_Images/version_18/3D_brain_vis"
    os.makedirs(visualization_dir, exist_ok=True)


# Per plane: the dimension of the batched 5-D tensor that is sliced, the title used in figures,
# and a callable (dimension length, slice count) -> evenly spaced integer slice indices.
planes_to_visualize = {
    plane: {
        'dim': axis,
        'title': plane,
        'slice_indices_func': lambda shape, n_slices: np.linspace(0, shape - 1, n_slices, dtype=int),
    }
    for axis, plane in enumerate(('Axial', 'Coronal', 'Sagittal'), start=2)
}


**Generate combined MRI / ground-truth mask / predicted-mask figures**



In [None]:
# Legend swatches, one per class. Swatch i is sampled at i / (num_total_classes - 1), which is
# exactly where imshow places label i when it is given vmin=0 and vmax=num_total_classes - 1.
legend_patches = [mpatches.Patch(color=plt.cm.viridis(i / (num_total_classes - 1)), label=class_names[i])
                  for i in range(num_total_classes)]

# Shared imshow settings for label maps. Without a fixed vmin/vmax each slice is colour-scaled
# to the labels it happens to contain and no longer agrees with the legend; 'nearest' keeps
# resampling from blending neighbouring labels into colours that belong to no class.
label_map_style = dict(cmap='viridis', vmin=0, vmax=num_total_classes - 1, interpolation='nearest')


if 'qualitative_examples' not in locals() or not qualitative_examples:
    print("Error")
else:
    trained_model.eval()

    for example_idx, example in enumerate(qualitative_examples):
        # Re-add the batch dimension that was dropped when the example was stored.
        image = example['image'].unsqueeze(0).to(DEVICE).float()
        mask = example['mask'].unsqueeze(0).to(DEVICE).float()

        print(f"Generating Visualization for Example {example_idx + 1}/{len(qualitative_examples)}")

        with torch.no_grad():
            predicted_logits = trained_model(image)
            predicted_mask = predicted_logits.argmax(dim=1)

        for plane_name, plane_info in planes_to_visualize.items():
            slice_dim_tensor = plane_info['dim']
            plane_title = plane_info['title']

            total_slices_in_dim = image.shape[slice_dim_tensor]
            slice_indices = plane_info['slice_indices_func'](total_slices_in_dim, temp_num_slices_visualize)

            # One figure column per returned index, however many the index function produced.
            num_slices_visualize = len(slice_indices)
            if num_slices_visualize < temp_num_slices_visualize:
                print(f"Not enough slices in {plane_name}.")

            print(f"Generating {plane_name} view.")

            patient_plane_dir = os.path.join(visualization_dir, f"patient_{example_idx + 1}", plane_name)
            os.makedirs(patient_plane_dir, exist_ok=True)

            # 'dim' refers to the batched 5-D tensor. With the batch dimension dropped the same
            # axis sits one position lower, and one lower again for the channel-free prediction.
            volume_dim = slice_dim_tensor - 1
            image_slices_combined = torch.stack(
                [image[0].select(volume_dim, int(idx)) for idx in slice_indices], dim=1)
            mask_slices_combined = torch.stack(
                [mask[0].select(volume_dim, int(idx)) for idx in slice_indices], dim=1)
            predicted_mask_slices_combined = torch.stack(
                [predicted_mask[0].select(volume_dim - 1, int(idx)) for idx in slice_indices], dim=0)
            # Collapse the channel-first mask to one label index per pixel.
            mask_slices_indexed_combined = torch.argmax(mask_slices_combined, dim=0)

            # Rows: one per input channel, then ground truth, then prediction.
            num_modalities = image_slices_combined.shape[0]
            ground_truth_row = num_modalities
            prediction_row = num_modalities + 1
            num_rows_combined = num_modalities + 2
            num_cols_combined = num_slices_visualize

            # squeeze=False keeps the axes grid 2-D even when there is a single column.
            fig_combined, axes_combined = plt.subplots(
                num_rows_combined, num_cols_combined,
                figsize=(num_cols_combined * 3.5, num_rows_combined * 3),
                squeeze=False,
            )

            for s_plot_idx, s_orig_idx in enumerate(slice_indices):
                # Plot Modalities
                for m_idx in range(num_modalities):
                    ax = axes_combined[m_idx, s_plot_idx]
                    ax.imshow(image_slices_combined[m_idx, s_plot_idx].cpu().numpy(), cmap='gray')
                    if s_plot_idx == 0:
                        ax.set_ylabel(modality_names[m_idx], rotation=90, size='large')
                    if m_idx == 0:
                        ax.set_title(f'{plane_title} Slice {s_orig_idx}', size='large')

                # Plot Ground Truth Mask
                ax = axes_combined[ground_truth_row, s_plot_idx]
                ax.imshow(mask_slices_indexed_combined[s_plot_idx].cpu().numpy(), **label_map_style)
                if s_plot_idx == 0:
                    ax.set_ylabel('Ground Truth', rotation=90, size='large')

                # Plot Predicted Mask
                ax = axes_combined[prediction_row, s_plot_idx]
                ax.imshow(predicted_mask_slices_combined[s_plot_idx].cpu().numpy(), **label_map_style)
                if s_plot_idx == 0:
                    ax.set_ylabel('Prediction', rotation=90, size='large')

            # Hide ticks and frames only. ax.axis('off') would also hide the row labels set above.
            for ax in axes_combined.flat:
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_frame_on(False)

            plt.tight_layout()
            fig_combined.suptitle(f'Patient {example_idx + 1} - {plane_title} View - Combined', y=1.02, fontsize=16)

            fig_combined.legend(handles=legend_patches, loc='lower center', bbox_to_anchor=(0.5, -0.05),
                                fancybox=True, shadow=True, ncol=num_total_classes)

            fig_combined_save_path = os.path.join(patient_plane_dir, f"combined_visualization_{plane_name.lower()}.png")
            fig_combined.savefig(fig_combined_save_path, bbox_inches='tight')
            plt.close(fig_combined)  # figures are created in a loop; close each one to free memory
            print(f"Saved combined visualization")


    print("Completed!")

**Softmax probability map**

In [None]:
# Per-class softmax probability maps: one row per class, one column per slice, one figure
# per patient and anatomical plane.
if 'qualitative_examples' not in locals() or not qualitative_examples:
    print("Error")
else:
    trained_model.eval()

    for example_idx, example in enumerate(qualitative_examples):
        # Re-add the batch dimension that was dropped when the example was stored.
        # (The ground-truth mask is not needed for this figure, so it is not transferred.)
        image = example['image'].unsqueeze(0).to(DEVICE).float()

        print("Generating Visualization.")

        with torch.no_grad():
            predicted_logits = trained_model(image)
            # Normalise over the class dimension so every voxel's class scores sum to 1.
            softmax_probs = F.softmax(predicted_logits, dim=1)

        for plane_name, plane_info in planes_to_visualize.items():
            slice_dim_tensor = plane_info['dim']
            plane_title = plane_info['title']

            total_slices_in_dim = image.shape[slice_dim_tensor]
            slice_indices = plane_info['slice_indices_func'](total_slices_in_dim, temp_num_slices_visualize)

            # One figure column per returned index, however many the index function produced.
            num_slices_visualize = len(slice_indices)
            if num_slices_visualize < temp_num_slices_visualize:
                print(f"Not enough slices in {plane_name}.")

            print(f"Generating {plane_name} view.")

            patient_plane_dir = os.path.join(visualization_dir, f"patient_{example_idx + 1}", plane_name)
            os.makedirs(patient_plane_dir, exist_ok=True)

            # 'dim' refers to the batched 5-D tensor; once the batch dimension is dropped the
            # same axis sits one position lower. Stacking on dim=1 gives (class, slice, row, col).
            volume_dim = slice_dim_tensor - 1
            softmax_slices = torch.stack(
                [softmax_probs[0].select(volume_dim, int(idx)) for idx in slice_indices], dim=1)

            # Plotting for Softmax Probability Maps
            num_rows_softmax = num_total_classes
            num_cols_softmax = num_slices_visualize

            # squeeze=False keeps the axes grid 2-D even when there is a single column.
            fig_softmax, axes_softmax = plt.subplots(
                num_rows_softmax, num_cols_softmax,
                figsize=(num_cols_softmax * 3.5, num_rows_softmax * 3),
                squeeze=False,
            )

            for s_plot_idx, s_orig_idx in enumerate(slice_indices):
                for c_idx in range(num_total_classes):
                    ax = axes_softmax[c_idx, s_plot_idx]
                    # Fixed 0..1 range so every panel shares the colour scale of the colourbar.
                    im = ax.imshow(softmax_slices[c_idx, s_plot_idx].cpu().numpy(), cmap='hot', vmin=0, vmax=1)
                    if s_plot_idx == 0:
                        ax.set_ylabel(f'Softmax ({class_names[c_idx]})', rotation=90, size='large')
                    if c_idx == 0:
                        ax.set_title(f'{plane_title} Slice {s_orig_idx}', size='large')

            # Hide ticks and frames only. ax.axis('off') would also hide the row labels set above.
            for ax in axes_softmax.flat:
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_frame_on(False)

            plt.tight_layout()
            fig_softmax.suptitle(f'Patient {example_idx + 1} - {plane_title} View - Softmax Probabilities', y=1.02, fontsize=16)

            # Dedicated axes just right of the grid; bbox_inches='tight' pulls it into the saved image.
            cbar_ax = fig_softmax.add_axes([1.0, 0.15, 0.02, 0.7])
            cbar = Colorbar(ax = cbar_ax, mappable = im)
            cbar.set_label('Probability')

            fig_softmax_save_path = os.path.join(patient_plane_dir, f"softmax_probabilities_{plane_name.lower()}.png")
            fig_softmax.savefig(fig_softmax_save_path, bbox_inches='tight')
            plt.close(fig_softmax)  # figures are created in a loop; close each one to free memory
            print("Saved softmax probabilities")


    print("Completed!")

**Monte Carlo uncertainty map**

In [None]:
# Monte Carlo dropout uncertainty: the per-class variance returned by McDropout is overlaid on
# the first image channel, one row per visualised class and one column per slice.
if 'qualitative_examples' not in locals() or not qualitative_examples:
    print("Error")
else:
    for example_idx, example in enumerate(qualitative_examples):
        trained_model.eval()

        # Re-add the batch dimension that was dropped when the example was stored.
        # (The ground-truth mask is not needed for this figure, so it is not transferred.)
        image = example['image'].unsqueeze(0).to(DEVICE).float()

        print("Generating Visualization")

        mean_prob, var_prob = McDropout(trained_model, image, mc_runs=20)
        # NOTE(review): McDropout presumably enables dropout layers for sampling; eval mode is
        # re-applied afterwards so later cells run deterministically — confirm against McDropout.
        trained_model.eval()

        # Drop the batch dimension; the leading dimension that remains is indexed by class below.
        uncertainty_maps = var_prob.squeeze(0)

        for plane_name, plane_info in planes_to_visualize.items():
            slice_dim_tensor = plane_info['dim']
            plane_title = plane_info['title']

            total_slices_in_dim = image.shape[slice_dim_tensor]
            slice_indices = plane_info['slice_indices_func'](total_slices_in_dim, temp_num_slices_visualize)

            # One figure column per returned index, however many the index function produced.
            num_slices_visualize = len(slice_indices)
            if num_slices_visualize < temp_num_slices_visualize:
                print(f"Not enough slices in {plane_name}.")

            print(f"Generating {plane_name} view")

            patient_plane_dir = os.path.join(visualization_dir, f"patient_{example_idx + 1}", plane_name)
            os.makedirs(patient_plane_dir, exist_ok=True)

            # 'dim' refers to the batched 5-D tensor; with batch and channel/class dimensions
            # both indexed away, the same axis sits two positions lower.
            spatial_dim = slice_dim_tensor - 2

            # Background for the overlay: channel 0 of the input volume.
            image_slices_base = torch.stack(
                [image[0, 0].select(spatial_dim, int(idx)) for idx in slice_indices], dim=0)
            uncertainty_slices_dict_plane = {
                c_idx: torch.stack(
                    [uncertainty_maps[c_idx].select(spatial_dim, int(idx)) for idx in slice_indices], dim=0)
                for c_idx in classes_visualize_indices
            }

            # Plots for Uncertainty Heatmaps
            num_rows_uncertainty = num_classes_to_visualize
            num_cols_uncertainty = num_slices_visualize

            # squeeze=False keeps the axes grid 2-D even when there is a single row or column.
            fig_uncertainty, axes_uncertainty = plt.subplots(
                num_rows_uncertainty, num_cols_uncertainty,
                figsize=(num_cols_uncertainty * 3.5, num_rows_uncertainty * 3),
                squeeze=False,
            )

            # A single min/max over every displayed class and slice, so that panels of one
            # figure are comparable with each other.
            all_uncertainty_slices = torch.cat(list(uncertainty_slices_dict_plane.values()), dim=0)
            global_uncertainty_min = all_uncertainty_slices.min()
            global_uncertainty_max = all_uncertainty_slices.max()

            for s_plot_idx, s_orig_idx in enumerate(slice_indices):
                for viz_class_plot_idx, class_index in enumerate(classes_visualize_indices):
                    ax = axes_uncertainty[viz_class_plot_idx, s_plot_idx]

                    ax.imshow(image_slices_base[s_plot_idx].cpu().numpy(), cmap='gray')
                    uncertainty_slice = uncertainty_slices_dict_plane[class_index][s_plot_idx]
                    # The epsilon guards against division by zero when the variance is constant.
                    uncertainty_slice_normalized = (uncertainty_slice - global_uncertainty_min) / (global_uncertainty_max - global_uncertainty_min + 1e-8)
                    im = ax.imshow(uncertainty_slice_normalized.cpu().numpy(), cmap='hot', alpha=0.5, vmin=0, vmax=1)

                    if s_plot_idx == 0:
                        ax.set_ylabel(f'Uncertainty ({classes_visualize_names[viz_class_plot_idx]})', rotation=90, size='large')
                    if viz_class_plot_idx == 0:
                        ax.set_title(f'{plane_title} Slice {s_orig_idx}', size='large')

            # Hide ticks and frames only. ax.axis('off') would also hide the row labels set above.
            for ax in axes_uncertainty.flat:
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_frame_on(False)

            plt.tight_layout()
            fig_uncertainty.suptitle(f'Patient {example_idx + 1} - {plane_title} View - Uncertainty Heatmaps', y=1.02, fontsize=16)

            # Dedicated axes just right of the grid; bbox_inches='tight' pulls it into the saved image.
            cbar_ax = fig_uncertainty.add_axes([1.0, 0.15, 0.02, 0.7])
            cbar = Colorbar(ax = cbar_ax, mappable = im)
            cbar.set_label('Normalized Variance')

            uncertainty_fig_path = os.path.join(patient_plane_dir, f"uncertainty_heatmaps_{plane_name.lower()}.png")
            fig_uncertainty.savefig(uncertainty_fig_path, bbox_inches='tight')
            plt.close(fig_uncertainty)  # figures are created in a loop; close each one to free memory
            print("Saved Uncertainty heatmaps")


    print("Completed!")

**Grad-CAM heatmap**

In [None]:
def ComputeGradcam(model, input_image, target_layer, target_class):
    """Compute a Grad-CAM heatmap for one class of a 3D segmentation model.

    The class score is the sum of ``output[0, target_class, ...]``. Its gradient with respect
    to the output of ``target_layer`` is averaged over the three spatial dimensions to weight
    the layer's activation channels; the weighted sum is passed through ReLU, resized to the
    spatial size of ``input_image`` with trilinear interpolation, and min-max normalised.

    Args:
        model: Network to explain. It is switched to eval mode, and its parameter gradients
            are zeroed before the backward pass and left populated by it afterwards.
        input_image: 5-D input tensor; only batch element 0 contributes to the score.
        target_layer: Sub-module of ``model`` whose output activations are used.
        target_class: Index into dimension 1 of the model output.

    Returns:
        A tensor detached from the autograd graph with values in [0, 1] (all zeros when the
        map is constant). Size-1 dimensions are squeezed out, which for a batch of one gives
        the spatial shape of ``input_image``. If the hooks captured nothing, "Failed" is
        printed and a zero tensor with the input's spatial shape is returned.
    """
    model.eval()
    activations = []
    gradients = []

    # Only the values are needed afterwards, so the captures are detached: the returned heatmap
    # then does not keep the whole forward graph of the model alive.
    def forward_hook(module, input, output):
        activations.append(output.detach())

    def backward_hook(module, grad_input, grad_output):
        gradients.append(grad_output[0].detach())

    hook_forward = target_layer.register_forward_hook(forward_hook)
    hook_backward = target_layer.register_full_backward_hook(backward_hook)

    try:
        # Grad-CAM needs gradients even if the caller is inside a torch.no_grad() block.
        with torch.enable_grad():
            output = model(input_image)
            model.zero_grad()
            score = output[0, target_class, ...].sum()
            # One backward pass per forward pass, so the graph does not have to be retained.
            score.backward()
    finally:
        # Remove the hooks even when the forward/backward pass raises; otherwise they would
        # stay attached to the layer and fire on every later call of the model.
        hook_forward.remove()
        hook_backward.remove()

    if not activations or not gradients:
        print("Failed")
        return torch.zeros(input_image.shape[2], input_image.shape[3], input_image.shape[4], device=input_image.device)

    capturedactivations = activations[0]
    capturedgradients = gradients[0]
    # Channel weights: gradient averaged over the three spatial dimensions.
    weights = torch.mean(capturedgradients, dim=[2, 3, 4], keepdim=True)
    heatmap = torch.sum(weights * capturedactivations, dim=1, keepdim=True)
    heatmap = F.relu(heatmap)
    heatmap = F.interpolate(heatmap, size=input_image.shape[2:], mode='trilinear', align_corners=False)
    heatmap_normalized = heatmap.squeeze()
    heatmap_min, heatmap_max = heatmap_normalized.min(), heatmap_normalized.max()
    if heatmap_max - heatmap_min > 1e-6:
        heatmap_normalized = (heatmap_normalized - heatmap_min) / (heatmap_max - heatmap_min)
    else:
        # A constant map carries no localisation information; avoid dividing by ~0.
        heatmap_normalized = torch.zeros_like(heatmap_normalized)

    return heatmap_normalized

class GradCAM:
    """Grad-CAM helper that keeps hooks attached to one layer across several calls.

    A forward hook stores the output of ``target_layer`` and a full backward hook stores the
    gradient flowing into that output. ``compute_heatmap`` combines the two into a heatmap
    normalised to [0, 1].

    The hooks hold a reference to this object, so it is not garbage-collected (and ``__del__``
    does not run) while the layer is alive. Call ``remove_hooks()`` when finished, or use the
    object as a context manager::

        with GradCAM(model, layer) as cam:
            heatmap = cam.compute_heatmap(volume, target_class=1)
    """

    def __init__(self, model, target_layer):
        """Store the model and layer and attach the capture hooks to ``target_layer``."""
        self.model = model
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None
        self.hook_handles_list = []
        self.hook_handles_list.append(self.target_layer.register_forward_hook(self.save_activation))
        self.hook_handles_list.append(self.target_layer.register_full_backward_hook(
            lambda module, grad_input, grad_output: self.save_gradient(grad_output[0])))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.remove_hooks()
        return False  # never suppress an exception raised inside the with-block

    def save_activation(self, module, input, output):
        """Forward hook: keep a detached copy of the layer output."""
        self.activations = output.clone().detach()

    def save_gradient(self, grad):
        """Keep a detached copy of the gradient with respect to the layer output."""
        self.gradients = grad.clone().detach()

    def remove_hooks(self):
        """Detach every hook registered by this object; safe to call more than once."""
        # getattr covers the case where __init__ failed before the list was created.
        handles = getattr(self, 'hook_handles_list', [])
        for handle in handles:
            handle.remove()
        handles.clear()

    def compute_heatmap(self, input_image, target_class):
        """Return the Grad-CAM heatmap of ``target_class`` for batch element 0.

        Args:
            input_image: 5-D input tensor. It is flagged in place as requiring gradients.
            target_class: Index into dimension 1 of the model output.

        Returns:
            A tensor with the three spatial dimensions of ``input_image`` and values in [0, 1]
            (all zeros when the map is constant). If no activations or gradients were captured
            (for example after ``remove_hooks()``), a message is printed and zeros are returned.
        """
        # Clear anything left from an earlier call so that a failed capture is detected below
        # instead of silently reusing old activations or gradients.
        self.activations = None
        self.gradients = None

        input_image.requires_grad_(True)
        # Grad-CAM needs gradients even if the caller is inside a torch.no_grad() block.
        with torch.enable_grad():
            output = self.model(input_image)
            self.model.zero_grad()
            score = output[0, target_class, ...].sum()
            # One backward pass per forward pass, so the graph does not have to be retained.
            score.backward()

        if self.activations is not None and self.gradients is not None:
            # Channel weights: gradient averaged over the three spatial dimensions.
            weights = torch.mean(self.gradients, dim=[2, 3, 4], keepdim=True)
            grad_cam = torch.sum(weights * self.activations, dim=1, keepdim=True)
            grad_cam = F.relu(grad_cam)
            target_size = (input_image.shape[2], input_image.shape[3], input_image.shape[4])
            grad_cam_resized = F.interpolate(grad_cam, size=target_size, mode='trilinear', align_corners=False)
            grad_cam_normalized = grad_cam_resized.squeeze(0).squeeze(0)
            grad_cam_min = grad_cam_normalized.min()
            grad_cam_max = grad_cam_normalized.max()
            if grad_cam_max - grad_cam_min > 1e-6:
                grad_cam_normalized = (grad_cam_normalized - grad_cam_min) / (grad_cam_max - grad_cam_min)
            else:
                # A constant map carries no localisation information; avoid dividing by ~0.
                grad_cam_normalized = torch.zeros_like(grad_cam_normalized)
            # Release the captured tensors; they can be as large as the layer's feature maps.
            self.activations = None
            self.gradients = None
            return grad_cam_normalized
        else:
             print(f"Cannot compute Grad-CAM.")
             return torch.zeros(input_image.shape[2], input_image.shape[3], input_image.shape[4], device=input_image.device)

    def __del__(self):
        # Last-resort cleanup only; see the class docstring for why this rarely runs in time.
        self.remove_hooks()



# Resolve the layer whose activations drive Grad-CAM.
# The guard checks `trained_model`, the object that is actually dereferenced below (an
# unrelated `model` variable being absent must not disable the visualisation).
if 'trained_model' not in globals() or trained_model is None:
    print("Error")
    target_layer_for_gradcam = None
    target_layer_name_for_gradcam = "Unknown"
else:
    try:
        target_layer_for_gradcam = trained_model.conv4[4]
        target_layer_name_for_gradcam = 'conv4.4'
        print(f"Using target layer for Grad-CAM: {target_layer_name_for_gradcam}")
    except (AttributeError, IndexError, TypeError) as e:
        # AttributeError: the model has no `conv4`; IndexError / TypeError: `conv4` has no
        # element 4 or cannot be indexed. In each case the Grad-CAM figures are skipped.
        print(f"Error: {e}")
        target_layer_for_gradcam = None
        target_layer_name_for_gradcam = "Unknown"


# Grad-CAM heatmaps overlaid on the first image channel: one row per visualised class, one
# column per slice, one figure per patient and anatomical plane.
if target_layer_for_gradcam is None:
    print("Skipping visualization due to invalid target layer.")
else:
    # ComputeGradcam attaches and removes its own hooks on every call, so no GradCAM analyzer
    # is instantiated here: it would only leave a second, unused pair of hooks on the layer.
    if 'qualitative_examples' not in locals() or not qualitative_examples:
        print("Error")
    else:
        for example_idx, example in enumerate(qualitative_examples):
            trained_model.eval()

            # Re-add the batch dimension that was dropped when the example was stored.
            # (The ground-truth mask is not needed for this figure, so it is not transferred.)
            image = example['image'].unsqueeze(0).to(DEVICE).float()

            print("Generating Visualization.")

            # Compute Grad-CAM on a private copy so `image` itself stays free of autograd state.
            image_for_gradcam = image.clone().requires_grad_(True)

            # detach() drops the autograd history so the stored maps do not pin forward graphs.
            grad_cam_heatmaps = {
                class_index: ComputeGradcam(
                    trained_model,
                    image_for_gradcam,
                    target_layer_for_gradcam,
                    class_index
                ).detach()
                for class_index in classes_visualize_indices
            }
            # The backward passes leave gradients on the model parameters; clear them.
            trained_model.zero_grad()

            for plane_name, plane_info in planes_to_visualize.items():
                slice_dim_tensor = plane_info['dim']
                plane_title = plane_info['title']

                total_slices_in_dim = image.shape[slice_dim_tensor]
                slice_indices = plane_info['slice_indices_func'](total_slices_in_dim, temp_num_slices_visualize)

                # One figure column per returned index, however many the index function produced.
                num_slices_visualize = len(slice_indices)
                if num_slices_visualize < temp_num_slices_visualize:
                    print(f"Not enough slices in {plane_name}.")

                print(f"Generating {plane_name} view.")

                patient_plane_dir = os.path.join(visualization_dir, f"patient_{example_idx + 1}", plane_name)
                os.makedirs(patient_plane_dir, exist_ok=True)

                # 'dim' refers to the batched 5-D tensor; for tensors without batch and channel
                # dimensions the same axis sits two positions lower.
                spatial_dim = slice_dim_tensor - 2

                # Background for the overlay: channel 0 of the input volume.
                image_slices_base = torch.stack(
                    [image[0, 0].select(spatial_dim, int(idx)) for idx in slice_indices], dim=0)
                grad_cam_slices_dict_plane = {
                    c_idx: torch.stack(
                        [grad_cam_heatmaps[c_idx].select(spatial_dim, int(idx)) for idx in slice_indices], dim=0)
                    for c_idx in classes_visualize_indices
                }

                # Plots for Grad-CAM Heatmaps
                num_rows_gradcam = num_classes_to_visualize
                num_cols_gradcam = num_slices_visualize

                # squeeze=False keeps the axes grid 2-D even when there is a single row or column.
                fig_gradcam, axes_gradcam = plt.subplots(
                    num_rows_gradcam, num_cols_gradcam,
                    figsize=(num_cols_gradcam * 3.5, num_rows_gradcam * 3),
                    squeeze=False,
                )

                # A single min/max over every displayed class and slice, so that panels of one
                # figure are comparable with each other.
                all_gradcam_slices = torch.cat(list(grad_cam_slices_dict_plane.values()), dim=0)
                global_gradcam_min = all_gradcam_slices.min()
                global_gradcam_max = all_gradcam_slices.max()

                for s_plot_idx, s_orig_idx in enumerate(slice_indices):
                    for viz_class_plot_idx, class_index in enumerate(classes_visualize_indices):
                        ax = axes_gradcam[viz_class_plot_idx, s_plot_idx]
                        ax.imshow(image_slices_base[s_plot_idx].cpu().numpy(), cmap='gray')
                        grad_cam_slice = grad_cam_slices_dict_plane[class_index][s_plot_idx]
                        # The epsilon guards against division by zero when all maps are constant.
                        grad_cam_slice_normalized = (grad_cam_slice - global_gradcam_min) / (global_gradcam_max - global_gradcam_min + 1e-8)
                        im = ax.imshow(grad_cam_slice_normalized.cpu().detach().numpy(), cmap='hot', alpha=0.5, vmin=0, vmax=1)

                        if s_plot_idx == 0:
                            ax.set_ylabel(f'Grad-CAM ({classes_visualize_names[viz_class_plot_idx]})', rotation=90, size='large')
                        if viz_class_plot_idx == 0:
                            ax.set_title(f'{plane_title} Slice {s_orig_idx}', size='large')

                # Hide ticks and frames only. ax.axis('off') would also hide the row labels set above.
                for ax in axes_gradcam.flat:
                    ax.set_xticks([])
                    ax.set_yticks([])
                    ax.set_frame_on(False)

                plt.tight_layout()
                fig_gradcam.suptitle(f'Patient {example_idx + 1} - {plane_title} View - Grad-CAM Heatmaps', y=1.02, fontsize=16)

                # Dedicated axes just right of the grid; bbox_inches='tight' pulls it into the saved image.
                cbar_ax = fig_gradcam.add_axes([1.0, 0.15, 0.02, 0.7])
                cbar = Colorbar(ax = cbar_ax, mappable = im)
                cbar.set_label('Importance')

                gradcam_fig_path = os.path.join(patient_plane_dir, f"gradcam_heatmaps_{plane_name.lower()}.png")
                fig_gradcam.savefig(gradcam_fig_path, bbox_inches='tight')
                plt.close(fig_gradcam)  # figures are created in a loop; close each one to free memory
                print("Saved Grad-CAM heatmaps")


    print("Completed!")