In [None]:
## Import Libraries

In [None]:
import glob
import itertools
import math
import os
import random
import shutil
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image  # Image processing
from skimage.metrics import structural_similarity as ssim
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Lambda
from torchvision.transforms.functional import to_pil_image, to_tensor
from tqdm import tqdm  # Progress bar

from dataset import *
from detectors import *
from mpl_toolkits.axes_grid1 import ImageGrid  # Image grid layout
from torch.autograd import Variable
from torch.optim.lr_scheduler import LambdaLR, StepLR, CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.utils import make_grid  # Create image grids
from warmup_scheduler import GradualWarmupScheduler

from sklearn.model_selection import train_test_split
# from unet import UNet
from ViT import ViT_Encoder_Decoder

%matplotlib inline
%load_ext autoreload
%autoreload 2


In [None]:
# Experiment identity; FINAL_MODEL_NAME is also part of the checkpoint file name.
MODEL_NAME = "ViT"
DATASET_NAME = "MRI"
FINAL_MODEL_NAME = f"{DATASET_NAME}_{MODEL_NAME}"

# Image geometry.
IMAGE_SIZE = 128
RESIZED_IMAGE_SIZE = 128
NUM_CHANNELS = 3

# Memorization-code layout: each code digit is drawn as a SYMBOL_SIZE x SYMBOL_SIZE
# square, so one image row holds GRID_SYMBOL symbols.
GRAY_CODE_BASE = 2
SYMBOL_SIZE = 8
GRID_SYMBOL = IMAGE_SIZE // SYMBOL_SIZE

# Sampling / training knobs.
SAMPLE_PERCENTAGE = 0.33278
BATCH_SIZE = 16
PATCH_SIZE = 10
GRID_SIZE = 3
CLASSIFICATION_LOSS_WEIGHT = 0.1
transform_config = 3
NUM_BLOCKS = 12

# Bookkeeping. NOTE(review): optimizer_name / scheduler_name are not read by the
# optimizer cell, which builds AdamW + ReduceLROnPlateau directly.
CHECKPOINTS = "CHECKPOINTS"
optimizer_name = 'RMSprop'
scheduler_name = 'StepLR'

MODEL_PATH = f"{SAMPLE_PERCENTAGE}_{FINAL_MODEL_NAME}.pth"
print(MODEL_PATH)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
def set_seed(seed=0):
    """Seed every RNG the notebook uses and force deterministic cuDNN kernels.

    Args:
        seed: Integer seed applied to PyTorch (CPU and all CUDA devices),
            NumPy's global RNG and Python's ``random`` module.
    """
    # Deterministic cuDNN trades speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


set_seed()

In [None]:
from pathlib import Path
import os, re
import cv2
import numpy as np
import pandas as pd

# Data root. Set the MRI_DATA_ROOT environment variable to override this
# hard-coded absolute path without editing the notebook.
ROOT_PATH = os.environ.get("MRI_DATA_ROOT", "/Notebooks/training/Image Segmentation/mri_data/")

root = Path(ROOT_PATH)
# Explicit check rather than `assert`, because asserts are stripped under `python -O`.
if not root.is_dir():
    raise FileNotFoundError(f"Path not found: {ROOT_PATH}")

# Masks are files in the top-level directory whose name contains "_mask".
mask_exts = (".tif", ".tiff", ".png", ".jpg", ".jpeg", ".bmp")
mask_files = sorted(str(p) for p in root.glob("*_mask*") if p.suffix.lower() in mask_exts)

def find_image_for_mask(mask_path: str):
    """Locate the image file that a segmentation mask belongs to.

    Tries, in order:
      1. the ``<stem>_mask.tif`` -> ``<stem>.tif`` naming convention;
      2. stripping mask tokens (``_mask``, ``-mask``, ``mask``, ``_seg``, ``-seg``,
         ``label``) from the stem and probing each supported extension.

    The mask file itself is never accepted as its own image. Previously, a mask
    whose name did not contain ``_mask.tif`` (e.g. ``x_mask.png``) matched itself
    in step 1, because ``str.replace`` left the name unchanged.

    Args:
        mask_path: Path to the mask file.

    Returns:
        The image path as a string, or ``None`` if no sibling image exists.
    """
    p = Path(mask_path)

    direct = p.with_name(p.name.replace("_mask.tif", ".tif"))
    if direct != p and direct.exists():
        return str(direct)

    stem_no_mask = (
        p.stem
         .replace("_mask", "")
         .replace("-mask", "")
         .replace("mask", "")
         .replace("_seg", "")
         .replace("-seg", "")
         .replace("label", "")
    )
    for ext in (".tif", ".tiff", ".png", ".jpg", ".jpeg", ".bmp"):
        cand = p.with_name(stem_no_mask + ext)
        # Guard against the candidate being the mask file itself.
        if cand != p and cand.exists():
            return str(cand)
    return None

# Pair every mask with its image; drop masks whose image (or the mask itself) is missing.
image_files_raw = [find_image_for_mask(m) for m in mask_files]

pairs = []
for im, m in zip(image_files_raw, mask_files):
    if im is None:
        continue
    if not (os.path.isfile(im) and os.path.isfile(m)):
        continue
    pairs.append((im, m))

image_files = [im for im, _ in pairs]
mask_files = [m for _, m in pairs]

def diagnosis(mask_path: str):
    """Binary tumour label for a mask file.

    Returns 1 if any pixel of the grayscale mask is non-zero, 0 if the mask is
    all zeros, and ``pd.NA`` when OpenCV cannot read the file.
    """
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        return pd.NA
    has_tumour = bool((mask > 0).any())
    return int(has_tumour)

# One row per (image, mask) pair with a nullable-integer tumour label.
files_df = pd.DataFrame({
    "image_path": image_files,
    "mask_path": mask_files,
    "diagnosis": list(map(diagnosis, mask_files)),
}).astype({"diagnosis": "Int64"})

counts = files_df["diagnosis"].value_counts().reindex([0, 1], fill_value=0)
no_tumor_total = int(counts.loc[0])
tumor_total = int(counts.loc[1])
print("Total of No Tumor:", no_tumor_total)
print("Total of Tumor:", tumor_total)


In [None]:
# Stratified two-stage split on `diagnosis` (random_state=0 for reproducibility):
#   stage 1 holds out 10% of all files as the validation set,
#   stage 2 holds out 15% of the remainder as the test set.
train_df, val_df = train_test_split(
    files_df, stratify=files_df['diagnosis'], test_size=0.1, random_state=0
)
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

train_df, test_df = train_test_split(
    train_df, stratify=train_df['diagnosis'], test_size=0.15, random_state=0
)
train_df, test_df = (part.reset_index(drop=True) for part in (train_df, test_df))

print("Train: {}\nVal: {}\nTest: {}".format(train_df.shape, val_df.shape, test_df.shape))

In [None]:
## Viewing the dataset

In [None]:
# Preview five random tumour-positive training slices, their masks, and an overlay.
set_seed()
df_positive = train_df[train_df['diagnosis'] == 1].sample(5).values

images, masks = [], []
for sample in df_positive:
    # Columns 0/1 are image_path / mask_path. OpenCV decodes images as BGR,
    # while matplotlib expects RGB, so convert before plotting.
    img = cv2.cvtColor(cv2.imread(sample[0]), cv2.COLOR_BGR2RGB)
    mask = cv2.imread(sample[1])
    images.append(img)
    masks.append(mask)

# Reverse the order of images and masks, then tile them side by side.
images = np.array(images[::-1])
masks = np.array(masks[::-1])
images_concat = np.hstack(images)
masks_concat = np.hstack(masks)

# Rows: images, masks, images with the mask blended on top.
fig = plt.figure(figsize=(15, 10))
grid = ImageGrid(fig, 111, nrows_ncols=(3, 1), axes_pad=0.6)

grid[0].imshow(images_concat)
grid[0].set_title('Images', fontsize=15)
grid[0].axis('off')

grid[1].imshow(masks_concat)
grid[1].set_title('Masks', fontsize=15)
grid[1].axis('off')

grid[2].imshow(images_concat)
grid[2].imshow(masks_concat, alpha=0.6)
grid[2].set_title('Brain MRI with mask', fontsize=15)
grid[2].axis('off')

plt.show()

In [None]:
## Initialize the datasets

In [None]:
# Build the segmentation datasets. The train/val/test transforms are not defined
# in this notebook; they come from the star imports at the top. Make sure
# `train_transform` still refers to that object if this cell is re-run.
set_seed()

# Shapes recorded in an earlier run: images [N, 3, 128, 128] with
# N = 3005 / 393 / 531 for train / val / test; masks [1, 128, 128].
train_ds = BrainDataset(train_df, train_transform)
val_ds = BrainDataset(val_df, val_transform)
test_ds = BrainDataset(test_df, test_transform)

TRAIN_DATASET_LEN = len(train_ds)

In [None]:
def dataset_info(dataset):
    """Print the dataset size and the shapes of one randomly chosen sample.

    Args:
        dataset: Indexable dataset whose items are ``(image, mask)`` pairs of
            objects exposing ``.shape`` (e.g. tensors).
    """
    print(f'Size of dataset: {len(dataset)}')
    if len(dataset) == 0:
        return
    # Sample from the whole dataset. The former randint(1, 40) never picked
    # index 0 and raised IndexError on datasets with 40 or fewer items.
    index = random.randrange(len(dataset))
    img, label = dataset[index]
    print(f'Sample-{index} Image size: {img.shape}, Mask: {label.shape}\n')

# Summarise each split: size plus one sample's tensor shapes.
for heading, ds in (('Train dataset:', train_ds),
                    ('Validation dataset:', val_ds),
                    ('Test dataset:', test_ds)):
    print(heading)
    dataset_info(ds)

In [None]:
## Creating Dataloaders

In [None]:
# Loader settings shared by all splits; only the training loader shuffles.
loader_kwargs = dict(num_workers=6, pin_memory=True)

set_seed()
train_dl = DataLoader(train_ds, BATCH_SIZE, shuffle=True, **loader_kwargs)

set_seed()
val_dl = DataLoader(val_ds, BATCH_SIZE, **loader_kwargs)
test_dl = DataLoader(test_ds, BATCH_SIZE, **loader_kwargs)

In [None]:
## Memorization dataset (`Mem_Dataset`)

In [None]:
def grayN(base, digits, value):
    """Encode ``value`` as a ``digits``-long base-``base`` modular Gray code.

    Args:
        base: Radix of the code (2 gives the ordinary binary reflected Gray code).
        digits: Number of output digits; higher-order digits of ``value`` are dropped.
        value: Integer to encode.

    Returns:
        A float tensor of shape ``(digits,)`` holding the Gray digits,
        least-significant digit first (index 0).
    """
    baseN = torch.zeros(digits)
    gray = torch.zeros(digits)
    # Plain base-N digits, least-significant first.
    for i in range(digits):
        baseN[i] = value % base
        value = value // base
    # Convert from the most-significant digit downwards; `shift` accumulates the
    # offset contributed by the higher Gray digits. Iterating explicitly (instead
    # of reusing the loop index leaked from above) makes digits == 0 safe.
    shift = 0
    for i in reversed(range(digits)):
        gray[i] = (baseN[i] + shift) % base
        shift = shift + base - gray[i]
    return gray

In [None]:
def load_dataset(percentage, image_size=128):
    """Load the first ``percentage`` of the training split into memory.

    Reads images (OpenCV, BGR channel order) and grayscale masks for the first
    ``int(len(train_ds) * percentage)`` rows of ``train_df`` and resizes both to
    ``image_size`` x ``image_size``. Rows whose files cannot be read are reported
    and skipped, so fewer items than requested may be returned.

    Args:
        percentage: Fraction of the training dataset to load.
        image_size: Side length of the resized images and masks (default 128).

    Returns:
        Tuple ``(images, masks, labels)`` of NumPy arrays; when at least one row
        loads, their shapes are ``(N, image_size, image_size, 3)``,
        ``(N, image_size, image_size)`` and ``(N,)``.
    """
    # Never index past the end of train_df.
    n_samples = min(int(len(train_ds) * percentage), len(train_df))
    print(f"#Samples: {n_samples}")

    size = (image_size, image_size)
    images, masks, labels = [], [], []
    for idx in range(n_samples):
        row = train_df.iloc[idx]
        img = cv2.imread(row['image_path'])
        mask = cv2.imread(row['mask_path'], cv2.IMREAD_GRAYSCALE)

        if img is None or mask is None:
            print(f"Failed to load images at index {idx}")
            continue

        images.append(cv2.resize(img, size))
        # Nearest-neighbour keeps the mask binary. The default bilinear
        # interpolation blends 0/255 into intermediate values at edges.
        masks.append(cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST))
        labels.append(row['diagnosis'])

    return np.array(images), np.array(masks), np.array(labels)


In [None]:
def display_combined_image(image_tensor, device='cpu'):
    """
    Show a 3-channel ``[C, H, W]`` image tensor with matplotlib.

    Values are clamped to ``[0, 1]`` before display. A tensor that is not 3-D
    with exactly three leading channels is rejected with a printed message.

    Args:
    image_tensor (torch.Tensor): The image tensor to display.
    device (str): Device the tensor is moved to before clamping.
    """
    is_chw_rgb = image_tensor.dim() == 3 and image_tensor.shape[0] == 3
    if not is_chw_rgb:
        print("The input tensor should have three channels.")
        return

    # matplotlib wants channels last: [H, W, C].
    hwc = torch.clamp(image_tensor.permute(1, 2, 0).to(device), 0, 1)

    plt.figure(figsize=(6, 6))
    plt.imshow(hwc.cpu().numpy())
    plt.axis('off')
    plt.show()

In [None]:
class Mem_Dataset(Dataset):
    """In-memory memorization dataset mapping synthetic code images to training images.

    Every loaded training image gets an input "code image" of shape
    ``[NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE]``:

    * channel 0: the sample index as a Gray code, with one ``SYMBOL_SIZE`` square
      per digit equal to 1 (slots laid out row-major, ``GRID_SYMBOL`` per row);
    * channel 1: one square at the slot given by the class label;
    * the bottom four rows of the requested colour channel carry a marker value.

    Each (sample, colour channel) pair is a separate item, so the dataset has
    ``num_images * NUM_CHANNELS`` items. Targets are the stored images / 255.
    """

    # Marker intensity written into the bottom rows to select the output channel.
    MARK_VALUES = {0: 0.333, 1: 0.666, 2: 0.999}
    # Number of Gray-code digits drawn into channel 0. Only digits equal to 1 are
    # drawn, so at most 2**CODE_DIGITS (4096) samples get distinct codes.
    CODE_DIGITS = 12

    def __init__(self, percentage, device):
        """Load ``percentage`` of the training split and pre-compute all code planes.

        Args:
            percentage: Fraction of the training split passed to ``load_dataset``.
            device: Device holding the image tensors and the returned items.

        Raises:
            ValueError: If more images are loaded than ``CODE_DIGITS`` can address.
        """
        self.device = device
        data, _masks, labels = load_dataset(percentage)

        # [N, H, W, C] -> [N, C, H, W] (PyTorch convention); values stay on 0..255.
        self.data = torch.tensor(data, dtype=torch.float).permute(0, 3, 1, 2).to(self.device)
        self.target_images = F.interpolate(
            self.data, size=(RESIZED_IMAGE_SIZE, RESIZED_IMAGE_SIZE), mode='bicubic'
        )

        num_images = len(data)
        if num_images > 2 ** self.CODE_DIGITS:
            raise ValueError(
                f"{num_images} images exceed the {self.CODE_DIGITS}-digit code space; "
                "several samples would share the same code"
            )
        self.gray_codes = torch.zeros((num_images, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE))
        # Enumerate the images actually loaded. The old version used
        # int(SAMPLE_PERCENTAGE * len(train_ds)), which ignored `percentage` and
        # over-counted when load_dataset skipped unreadable files (IndexError).
        self.channel_patch_combinations = list(
            itertools.product(range(num_images), range(NUM_CHANNELS))
        )

        with torch.no_grad():
            symbol = torch.ones((SYMBOL_SIZE, SYMBOL_SIZE))
            for idx in range(num_images):
                index_gray_code = grayN(GRAY_CODE_BASE, IMAGE_SIZE, idx)
                for digit in range(self.CODE_DIGITS):
                    if index_gray_code[digit] == 1:
                        self._stamp(idx, 0, digit, symbol)
                self._stamp(idx, 1, labels[idx], symbol)

    def _stamp(self, idx, channel, slot, symbol):
        """Paint ``symbol`` into grid slot ``slot`` of ``gray_codes[idx, channel]``."""
        row = (slot // GRID_SYMBOL) * SYMBOL_SIZE
        col = (slot % GRID_SYMBOL) * SYMBOL_SIZE
        self.gray_codes[idx, channel, row:row + SYMBOL_SIZE, col:col + SYMBOL_SIZE] = symbol

    def __len__(self):
        """Number of (sample, colour channel) items."""
        return len(self.channel_patch_combinations)

    def __getitem__(self, index):
        """Return ``(code_image, target_image, channel)`` for item ``index``.

        ``code_image`` is the ``[NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE]`` input with
        the channel marker set; ``target_image`` is the full stored image scaled
        by 1/255; ``channel`` is the colour channel the model should reproduce.
        """
        sample_idx, channel = self.channel_patch_combinations[index]
        with torch.no_grad():
            code_image = torch.zeros(NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
            code_image[0, :, :] = self.gray_codes[sample_idx, 0, :, :]
            code_image[1, :, :] = self.gray_codes[sample_idx, 1, :, :]
            # The bottom four rows tell the model which colour channel is requested.
            code_image[channel, -4:, 0:IMAGE_SIZE] = self.MARK_VALUES[channel]

            img = code_image.float().to(self.device)
            target = (self.target_images[sample_idx].float() / 255).to(self.device)
        return img, target, channel

In [None]:
# Identity transform for the memorization pipeline (Mem_Dataset applies no
# transform). It is bound to its own name because re-binding `train_transform`
# here would replace the transform the segmentation datasets are built with
# whenever the dataset cell is re-run.
mem_transform = transforms.Compose([])

In [None]:
# The memorization set is kept on the CPU; train_loop moves each batch to `device`.
mem_dataset = Mem_Dataset(SAMPLE_PERCENTAGE, torch.device('cpu'))
mem_dl = DataLoader(
    mem_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=False,
)

In [None]:
# Inspect one memorization item: (code_image, target_image, channel).
mem_dataset[10]

In [None]:
## Training

In [None]:
# Worked example:
# pred  = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]])   # pred.sum()  = 5
# label = torch.tensor([[1, 1, 1], [0, 0, 0], [1, 0, 1]])   # label.sum() = 5
# (pred * label).sum() = 4

## dice_coef_metric:
# intersection = 2*4 = 8
# union = 5+5 = 10
# dice = 8/10 = 0.8

## dice_coef_loss (smooth = 1):
# intersection = 2*4 + 1 = 9
# union = 5+5+1 = 11
# dice_loss = 1 - 9/11 = 2/11 ≈ 0.182

## bce_dice_loss:
# sum of binary cross entropy and dice_coef_loss, both computed on sigmoid(pred)

# Function to calculate the Dice coefficient metric between prediction and ground truth.

def dice_coef_metric(pred, label):
    """Dice coefficient ``2 * sum(pred * label) / (sum(pred) + sum(label))``.

    Works on NumPy arrays or tensors of 0/1 values. Two empty inputs count as a
    perfect match and score 1.0.
    """
    overlap = (pred * label).sum()
    pred_total, label_total = pred.sum(), label.sum()
    if pred_total == 0 and label_total == 0:
        return 1.
    return (2.0 * overlap) / (pred_total + label_total)

def dice_coef_loss(pred, label):
    """Soft Dice loss ``1 - (2*sum(pred*label) + 1) / (sum(pred) + sum(label) + 1)``.

    The +1 smoothing keeps the loss defined (and equal to 0) for two empty inputs.
    """
    smooth = 1.0
    numerator = 2.0 * (pred * label).sum() + smooth
    denominator = pred.sum() + label.sum() + smooth
    return 1 - numerator / denominator

def bce_dice_loss(pred, label):
    """Soft Dice loss plus binary cross-entropy, both on ``sigmoid(pred)``.

    ``pred`` is passed through a sigmoid, i.e. treated as logits; ``label``
    holds 0/1 targets of the same shape.

    NOTE(review): the Dice metric in the training/eval loops thresholds the raw
    model output at 0.5 rather than ``sigmoid(output)``. Confirm whether the
    model emits logits or probabilities.
    """
    probs = torch.sigmoid(pred)
    return dice_coef_loss(probs, label) + nn.BCELoss()(probs, label)

In [None]:
def get_image(idx, model):
    """Reconstruct memorized sample ``idx`` by querying ``model`` once per colour channel.

    Builds three code images from the global ``mem_dataset`` (index plane in
    channel 0, class plane in channel 1, plus the per-channel bottom-row marker)
    and runs each through ``model`` on the global ``device``. Each output fills
    one colour channel of the reconstruction; this presumably requires a
    single-channel model output (TODO confirm).

    Returns:
        ``(reconstruction, target)``: a ``[3, RESIZED_IMAGE_SIZE, RESIZED_IMAGE_SIZE]``
        CPU tensor, and the stored target image with ``squeeze(0)`` applied
        (still on the 0..255 scale).
    """
    marker = {0: 0.333, 1: 0.666, 2: 0.999}
    code_planes = mem_dataset.gray_codes[idx]
    target_image = mem_dataset.target_images[idx]

    queries = []
    for ch in range(3):
        query = torch.zeros(1, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
        query[0, 0, :, :] = code_planes[0]
        query[0, 1, :, :] = code_planes[1]
        query[0, ch, -4:, :] = marker[ch]
        queries.append(query)

    with torch.no_grad():
        predictions = [model(q.float().to(device))[0, :, :, :] for q in queries]

    stacked = torch.zeros(3, IMAGE_SIZE, IMAGE_SIZE)
    for ch, prediction in enumerate(predictions):
        stacked[ch, :, :] = prediction

    resized = F.interpolate(
        stacked.unsqueeze(0),
        size=(RESIZED_IMAGE_SIZE, RESIZED_IMAGE_SIZE),
        mode='bilinear',
        align_corners=False,
    )
    return resized.squeeze(0), target_image.squeeze(0)

In [None]:
def visualize_results(model, device, num_images=10):
    """Plot stored images (top row) against their memorized reconstructions (bottom row).

    The model is switched to eval mode for plotting and afterwards returned to
    the mode it was in before. The old version always forced ``model.train()``.

    Args:
        model: Network queried through ``get_image``.
        device: Unused here (``get_image`` uses the global ``device``); kept for
            call-site compatibility.
        num_images: Number of leading samples to show, capped at the dataset size.
    """
    was_training = model.training
    model.eval()
    try:
        n = min(num_images, mem_dataset.target_images.size(0))
        if n == 0:
            return
        # squeeze=False keeps `axs` 2-D even with a single column.
        fig, axs = plt.subplots(2, n, figsize=(2.5 * n, 5), squeeze=False)
        for idx in range(n):
            output, target_image = get_image(idx, model)
            target_hwc = (target_image / 255.0).permute(1, 2, 0).cpu().numpy()
            output_hwc = torch.clamp(output, 0, 1).permute(1, 2, 0).cpu().numpy()
            # Images were decoded by OpenCV (BGR order); flip to RGB for matplotlib.
            for row, picture, title in ((0, target_hwc, f'Original {idx+1}'),
                                        (1, output_hwc, f'Memorized {idx+1}')):
                axs[row, idx].imshow(np.ascontiguousarray(picture[..., ::-1]))
                axs[row, idx].set_title(title)
                axs[row, idx].axis('off')
        plt.show()
    finally:
        model.train(was_training)

In [None]:
def train_loop(model, seg_loader, mem_loader, loss_func, CLS_LOSS_WEIGHT):
    """Run one epoch of joint segmentation + memorization training.

    Each segmentation batch is paired with the next memorization batch; the
    memorization loader is restarted whenever it runs out. The optimized loss is
    ``CLS_LOSS_WEIGHT * loss_func(seg_out, mask) + MSE(mem_out, t) + L1(mem_out, t)``,
    where ``t`` is the requested colour channel of the memorization target.
    Uses the global ``optimizer`` and ``device``.

    Args:
        model: Network used for both tasks.
        seg_loader: Yields ``(image, mask)`` batches.
        mem_loader: Yields ``(code_image, target_image, channel)`` batches.
        loss_func: Segmentation loss (e.g. ``bce_dice_loss``).
        CLS_LOSS_WEIGHT: Weight of the segmentation term.

    Returns:
        ``(train_dices, train_losses, mem_losses, epoch_duration)``: per-batch lists
        plus the wall-clock seconds spent in the epoch.
    """
    model.train()
    train_losses, train_dices, mem_losses = [], [], []

    mse_loss = nn.MSELoss()
    mae_loss = nn.L1Loss()

    mem_iterator = iter(mem_loader)
    start_time = time.time()

    for image, mask in seg_loader:
        try:
            codes, mem_targets, channels = next(mem_iterator)
        except StopIteration:
            mem_iterator = iter(mem_loader)
            codes, mem_targets, channels = next(mem_iterator)

        optimizer.zero_grad()
        image, mask = image.to(device), mask.to(device)
        codes, mem_targets = codes.to(device), mem_targets.to(device)
        channels = torch.as_tensor(channels, device=device)

        # Keep only the colour channel each item asks for: [B, C, H, W] -> [B, 1, H, W].
        batch_indices = torch.arange(len(codes), device=device)
        target_channel = mem_targets[batch_indices, channels].unsqueeze(1)

        seg_outputs = model(image)

        # Hard 0/1 prediction for the Dice metric (threshold 0.5 on the raw output).
        # NOTE(review): loss_func may apply a sigmoid (bce_dice_loss does); if the
        # model emits logits, this threshold corresponds to p ~= 0.62, not 0.5.
        out_cut = (seg_outputs.detach().cpu().numpy() >= 0.5).astype(np.float32)
        dice = dice_coef_metric(out_cut, mask.detach().cpu().numpy())
        seg_loss = CLS_LOSS_WEIGHT * loss_func(seg_outputs, mask)

        mem_outputs = model(codes)
        mem_loss = mse_loss(mem_outputs, target_channel) + mae_loss(mem_outputs, target_channel)

        loss = seg_loss + mem_loss
        loss.backward()
        optimizer.step()

        train_losses.append(seg_loss.item())
        mem_losses.append(mem_loss.item())
        train_dices.append(dice)

    # Measured after the loop so an empty `seg_loader` still yields a duration.
    epoch_duration = time.time() - start_time
    return train_dices, train_losses, mem_losses, epoch_duration

In [None]:
def eval_loop(model, loader, loss_func, training=True):
    """Evaluate ``model`` on ``loader``; return per-batch mean Dice and mean loss.

    Args:
        model: Segmentation network (switched to eval mode).
        loader: Yields ``(image, mask)`` batches.
        loss_func: Loss applied to ``(outputs, mask)`` (e.g. ``bce_dice_loss``).
        training: Unused; kept for call-site compatibility.

    Returns:
        ``(mean_dice, mean_loss)`` averaged over batches; ``mean_loss`` is a tensor.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0
    total_dice = 0
    num_batches = 0
    with torch.no_grad():
        for image, mask in loader:
            image = image.to(device)
            mask = mask.to(device)

            outputs = model(image)
            # Previously hard-coded to bce_dice_loss; the argument is now honoured.
            total_loss += loss_func(outputs, mask)

            # Same hard thresholding as in training (0.5 on the raw output).
            out_cut = (outputs.cpu().numpy() >= 0.5).astype(np.float32)
            total_dice += dice_coef_metric(out_cut, mask.cpu().numpy())
            num_batches += 1

    if num_batches == 0:
        raise ValueError("eval_loop: loader yielded no batches")
    # Divide by the batch count. The old code divided by the last enumerate index
    # (num_batches - 1), which inflated both means and raised ZeroDivisionError
    # for a single-batch loader.
    return total_dice / num_batches, total_loss / num_batches

In [None]:
### Train Function

In [None]:
def train_model(seg_loader, mem_loader, val_loader, loss_func, optimizer, scheduler, num_epochs,
                warmup_epochs=40, cls_loss_weight=0.01, visualize_every=1):
    """Train the global ``model`` on segmentation + memorization for ``num_epochs`` epochs.

    The segmentation loss is disabled for the first ``warmup_epochs`` epochs and
    weighted by ``cls_loss_weight`` afterwards. After each epoch the model is
    validated and the scheduler is stepped. Whenever the mean memorization loss
    improves, model + optimizer state is saved to ``MODEL_PATH``.

    Args:
        seg_loader, mem_loader, val_loader: Training, memorization and validation loaders.
        loss_func: Segmentation loss.
        optimizer: Optimizer whose state is checkpointed. NOTE(review): train_loop
            steps the *global* ``optimizer``; pass the same object.
        scheduler: LR scheduler or ``None``. ``ReduceLROnPlateau`` is stepped with
            the mean memorization loss (the same criterion used for checkpointing);
            other schedulers are stepped once per epoch.
        num_epochs: Number of epochs.
        warmup_epochs: Epochs with the segmentation loss switched off.
        cls_loss_weight: Segmentation-loss weight after warm-up.
        visualize_every: Plot reconstructions every this many epochs (<= 0 disables).

    Returns:
        ``(train_loss_history, train_dice_history, val_loss_history, val_dice_history)``.
    """
    train_loss_history, train_dice_history = [], []
    val_loss_history, val_dice_history = [], []

    best_mem_loss = float('inf')
    training_start_time = time.time()
    for epoch in range(num_epochs):
        CLS_LOSS_WEIGHT = 0 if epoch < warmup_epochs else cls_loss_weight

        train_dices, train_losses, mem_losses, epoch_duration = train_loop(
            model, seg_loader, mem_loader, loss_func, CLS_LOSS_WEIGHT
        )
        train_mean_dice = np.array(train_dices).mean()
        train_mean_loss = np.array(train_losses).mean()
        mem_mean_loss = np.array(mem_losses).mean()

        val_mean_dice, val_mean_loss = eval_loop(model, val_loader, loss_func)
        val_mean_loss = float(val_mean_loss)

        train_loss_history.append(train_mean_loss)
        train_dice_history.append(train_mean_dice)
        val_loss_history.append(val_mean_loss)
        val_dice_history.append(val_mean_dice)

        print('Epoch: {}/{} |  Train Loss: {:.3f}, Val Loss: {:.3f}, Train DICE: {:.3f}, Val DICE: {:.3f} , Mem Loss {:.3f} , Duration: {:.2f}'.format(
            epoch + 1, num_epochs, train_mean_loss, val_mean_loss, train_mean_dice, val_mean_dice, mem_mean_loss, epoch_duration))

        # The scheduler used to be accepted but never stepped, so the LR never changed.
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(mem_mean_loss)
        elif scheduler is not None:
            scheduler.step()

        if mem_mean_loss < best_mem_loss:
            best_mem_loss = mem_mean_loss
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, MODEL_PATH)
            print(f'Saved model and optimizer with memorization loss: {best_mem_loss:.4f}')

        if visualize_every > 0 and (epoch + 1) % visualize_every == 0:
            print(f'Visualizing results at Epoch {epoch + 1}')
            visualize_results(model, device)

    total_training_time = time.time() - training_start_time
    hours, remainder = divmod(total_training_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f'Total training time: {int(hours):02}:{int(minutes):02}:{int(seconds):02} (hh:mm:ss)')

    return train_loss_history, train_dice_history, val_loss_history, val_dice_history

In [None]:
### Model and hyperparameters

In [None]:
# Build the ViT encoder-decoder used for both segmentation and memorization.
model = ViT_Encoder_Decoder().to(device)

# Smoke test: push one random 3x128x128 input through the model and show the output shape.
dummy_input = torch.randn(1, 3, 128, 128).to(device)
out = model(dummy_input)
print(out.shape)


In [None]:
# NOTE(review): these choices differ from the config cell's optimizer_name /
# scheduler_name ('RMSprop' / 'StepLR'), which are not used here.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,   # multiply the LR by 0.1 on a plateau
    patience=10,  # scheduler steps without improvement before reducing
)
epochs = 500

In [None]:
train_loss_history, train_dice_history, val_loss_history, val_dice_history = train_model(
    seg_loader=train_dl,
    mem_loader=mem_dl,
    val_loader=val_dl,
    loss_func=bce_dice_loss,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=epochs,
)

In [None]:
def plot_dice_history(model_name, train_dice_history, val_dice_history, num_epochs):
    """Line plot of per-epoch training and validation Dice scores.

    Args:
        model_name: Figure title.
        train_dice_history, val_dice_history: ``num_epochs`` values each.
        num_epochs: Length of the x-axis (epochs 0 .. num_epochs - 1).
    """
    epochs_axis = np.arange(num_epochs)
    plt.figure(figsize=(10, 6))
    for values, label, colour in ((train_dice_history, 'Train DICE Score', "r"),
                                  (val_dice_history, 'Validation DICE Score', "c")):
        plt.plot(epochs_axis, values, label=label, lw=3, c=colour)

    plt.title(f"{model_name}", fontsize=20)
    plt.legend(fontsize=12)
    plt.xlabel("Epoch", fontsize=15)
    plt.ylabel("DICE", fontsize=15)
    plt.show()

# Plot the Dice history of the trained model. The title follows MODEL_NAME instead of
# a hard-coded "U-NET" label (the model here is ViT_Encoder_Decoder).
plot_dice_history(f'{MODEL_NAME} DICE Coefficient History', train_dice_history, val_dice_history, epochs)

In [None]:
def plot_loss_history(model_name, train_loss_history, val_loss_history, num_epochs):
    """Line plot of per-epoch training and validation loss.

    Args:
        model_name: Figure title.
        train_loss_history, val_loss_history: ``num_epochs`` values each.
        num_epochs: Length of the x-axis (epochs 0 .. num_epochs - 1).
    """
    epochs_axis = np.arange(num_epochs)
    plt.figure(figsize=(10, 6))
    for values, label, colour in ((train_loss_history, 'Train Loss', "r"),
                                  (val_loss_history, 'Validation Loss', "c")):
        plt.plot(epochs_axis, values, label=label, lw=3, c=colour)

    plt.title(f"{model_name}", fontsize=20)
    plt.legend(fontsize=12)
    plt.xlabel("Epoch", fontsize=15)
    plt.ylabel("Loss", fontsize=15)
    plt.show()

# Plot the loss history of the trained model. The title follows MODEL_NAME instead of
# a hard-coded "U-NET" label (the model here is ViT_Encoder_Decoder).
plot_loss_history(f'{MODEL_NAME} Loss', train_loss_history, val_loss_history, epochs)

In [None]:
def load_model(model_path, model, optimizer, device):
    """Restore model and optimizer state from a checkpoint written by ``train_model``.

    Args:
        model_path: Path to a ``.pth`` file with ``model_state_dict`` and
            ``optimizer_state_dict`` entries. Only load trusted files:
            ``torch.load`` unpickles data.
        model: Module that receives the weights (modified in place).
        optimizer: Optimizer that receives the saved state (modified in place).
        device: Target device. Tensors are mapped there while loading, so a
            checkpoint saved on GPU can be opened on a CPU-only machine.

    Returns:
        ``(model, optimizer)``: the same objects that were passed in.
    """
    checkpoint = torch.load(model_path, map_location=device)
    # Move the model first so the optimizer state is cast to the parameters' final device.
    model.to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer

In [None]:
# Reload the best checkpoint. load_model restores into, and returns, the same
# `model` / `optimizer` objects, so `best_model is model`.
best_model, _ = load_model(
    MODEL_PATH, model, optimizer, device
)

In [None]:
## Prediction on Test set

In [None]:
%%time
test_dice, test_loss = eval_loop(best_model, test_dl, bce_dice_loss, training=False)
dice_percent = 100 * test_dice
print("Mean DICE: {:.3f}%, Loss: {:.3f}".format(dice_percent, test_loss))

In [None]:
# Predict on one random tumour-positive test slice and visualise the result.
# sample(1) avoids the former sample(24)[0] failing when fewer than 24 positives exist.
test_sample = test_df[test_df["diagnosis"] == 1].sample(1).values[0]
image = cv2.resize(cv2.imread(test_sample[0]), (128, 128))
# Nearest-neighbour keeps the displayed mask binary after resizing.
mask = cv2.resize(cv2.imread(test_sample[1]), (128, 128), interpolation=cv2.INTER_NEAREST)

# Same input encoding as before: OpenCV channel order, scaled to [0, 1], NCHW.
input_image = torch.tensor(image.astype(np.float32) / 255.).unsqueeze(0).permute(0, 3, 1, 2)

# Inference in eval mode without autograd. The model may still be in train mode
# after training (e.g. active dropout), so restore its previous mode afterwards.
was_training = best_model.training
best_model.eval()
with torch.no_grad():
    pred = best_model(input_image.to(device))
best_model.train(was_training)
# NOTE(review): the training loss applies a sigmoid, so `pred` may be logits
# rather than probabilities.
pred = pred.cpu().numpy()[0, 0, :, :]

# OpenCV decodes as BGR; convert for matplotlib.
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(20, 5))

ax[0].imshow(image_rgb)
ax[0].set_title("Image")
ax[0].axis('off')

ax[1].imshow(mask)
ax[1].set_title("Mask")
ax[1].axis('off')

ax[2].imshow(pred, cmap='gray')
ax[2].set_title("Prediction")
ax[2].axis('off')

ax[3].imshow(image_rgb)
ax[3].imshow(pred, cmap='jet', alpha=0.5)
ax[3].set_title("Image + Prediction Overlay")
ax[3].axis('off')

plt.show()

In [None]:
## SSIM

In [None]:
def decode_class_label(class_plane, symbol_size, grid_symbol):
    """Recover the class index encoded in a Mem_Dataset class-code plane.

    The class is drawn as a ``symbol_size`` square at grid slot ``class_idx``
    (row-major, ``grid_symbol`` slots per row). ``argmax`` returns the flat
    index of the first lit *pixel* (e.g. 8 for class 1), which is mapped back
    to its grid slot here.
    """
    flat = int(class_plane.argmax())
    row, col = divmod(flat, class_plane.shape[-1])
    return (row // symbol_size) * grid_symbol + col // symbol_size


def generate_reconstructed_images_and_labels(model, device, mem_dataset):
    """Reconstruct every memorized sample and decode its class label.

    Args:
        model: Memorization model. It runs in eval mode, and its previous mode
            is restored afterwards.
        device: Device the stacked results are moved to.
        mem_dataset: Dataset whose class-code planes are decoded.
            NOTE(review): get_image reads the *global* ``mem_dataset``; pass the same object.

    Returns:
        ``(reconstructed, original, labels)``: reconstructions clamped to [0, 1],
        originals scaled to [0, 1], and integer class labels, all on ``device``.
    """
    # Number of images actually held by the dataset. The old
    # SAMPLE_PERCENTAGE * TRAIN_DATASET_LEN over-counts if some files failed to load.
    num_samples = mem_dataset.gray_codes.size(0)

    reconstructed_images, original_images, labels = [], [], []
    was_training = model.training
    model.eval()  # reconstructions should not depend on train-mode layers
    try:
        for idx in range(num_samples):
            output, target_image = get_image(idx, model)
            reconstructed_images.append(output.to(device))
            original_images.append(target_image.to(device))

            class_plane = mem_dataset.gray_codes[idx][1][:-3, :]
            labels.append(decode_class_label(class_plane, SYMBOL_SIZE, GRID_SYMBOL))
    finally:
        model.train(was_training)

    reconstructed_images = torch.stack(reconstructed_images)
    original_images = torch.stack(original_images)

    # Originals are stored on 0..255; reconstructions are clamped to the valid range.
    original_images = original_images.float() / 255.0
    reconstructed_images = torch.clamp(reconstructed_images, 0, 1)
    labels = torch.tensor(labels, device=device)

    return reconstructed_images, original_images, labels

In [None]:
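# SSIM implementation adapted from pytorch-ssim on GitHub - https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
# NOTE: the `ssim` defined in this cell shadows skimage's `structural_similarity` imported as `ssim` at the top.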
import torch
import torch.nn.functional as F

WINDOW_SIZE = 3  # side length of the SSIM Gaussian window


def gaussian(window_size, sigma):
    """Normalised 1-D Gaussian of length ``window_size``, centred on ``window_size // 2``."""
    centre = window_size // 2
    denom = float(2 * sigma ** 2)
    exponents = [-(x - centre) ** 2 / denom for x in range(window_size)]
    weights = torch.tensor(exponents).exp()
    return weights / weights.sum()


def create_window(window_size, channel):
    """Depthwise 2-D Gaussian kernel (sigma 1.5) of shape ``[channel, 1, window_size, window_size]``."""
    column = gaussian(window_size, 1.5).unsqueeze(1)        # [w, 1]
    kernel_2d = (column @ column.t()).float()[None, None]   # [1, 1, w, w]
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()

def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

def ssim(img1, img2, window_size=WINDOW_SIZE, size_average=True):
    """SSIM between two ``[B, C, H, W]`` batches using a Gaussian window.

    NOTE: this definition shadows skimage's ``structural_similarity``, which the
    import cell binds to the same name ``ssim``.
    """
    _, channel, _, _ = img1.size()
    window = create_window(window_size, channel).to(img1.device)
    return _ssim(img1, img2, window, window_size, channel, size_average)

def calculate_ssim_for_batches(reconstructed_images, original_images, window_size=WINDOW_SIZE):
    """Per-image SSIM between two batches, plus the batch mean.

    3-D inputs ``[B, H, W]`` are treated as single-channel and get a channel axis.

    Returns:
        ``(scores, mean)``: a ``[B]`` tensor of SSIM values and its mean as a float.
    """
    def as_nchw(batch):
        return batch.unsqueeze(1) if batch.dim() == 3 else batch

    reconstructed_images = as_nchw(reconstructed_images)
    original_images = as_nchw(original_images)

    channel = original_images.size(1)
    window = create_window(window_size, channel).to(original_images.device)

    scores = _ssim(reconstructed_images, original_images, window, window_size, channel,
                   size_average=False)
    return scores, scores.mean().item()


In [None]:
# Reconstruct every memorized image and score it against the stored original.
reconstructed_images, original_images, labels = generate_reconstructed_images_and_labels(
    best_model, device, mem_dataset
)
ssim_values, average_ssim = calculate_ssim_for_batches(reconstructed_images, original_images)

# Per-image SSIM, its shape, and the overall mean.
print(f'SSIM values: {ssim_values}')
print(ssim_values.shape)
print(f'Average SSIM: {average_ssim}')

In [None]:
from torch.nn.functional import mse_loss
# Pixel-wise mean squared error between reconstructions and originals (both on [0, 1]).
mse_value = F.mse_loss(reconstructed_images, original_images)
mse_value