# In [37]:
####### Inputs #######
# Run-wide settings and the architecture/encoder grid that the training loop sweeps.

BATCH_SIZE = 1                    # batch size for both the train and test DataLoaders
DEFAULT_LEARNING_RATE = 1e-4      # fallback when no tuned LR is found for a model/backbone
NUM_EPOCHS = 200

# Architectures (keys of ``model_classes``) and encoder backbones to train.
model_list = ["UnetPlusPlus"]
backbone_list = ["vgg19"]

# Other options for networks (model) and backbone:
# model_list = ["UnetPlusPlus", "Unet", "MAnet", "Linknet", "PSPNet", "FPN"]
# backbone_list = ["vgg16", "vgg19", "resnet50", "resnet101", "resnet152", "mobilenet_v2", "efficientnet-b4"]

# In [None]:
from __future__ import absolute_import, division, print_function
from IPython.display import clear_output
import time
from datetime import timedelta
from sklearn.model_selection import KFold
import re
import pandas as pd
from PIL import Image, ImageDraw, ImageFont
import os
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import cv2
import numpy as np
import numpy.ma as ma
from numpy import ndarray
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms
from torchvision.transforms.functional import pad

import random
import csv

from torchinfo import summary
import segmentation_models_pytorch as smp


# Takes in a .png and returns a 2D numpy array of 0's and 1's

def image2nparray(image_file):
    """Load an image as a binary 2-D mask.

    The file is read as single-channel grayscale. Every non-zero pixel
    becomes 1 and every zero pixel becomes 0.

    Args:
        image_file: Path to the image (e.g. a ``.png`` label).

    Returns:
        ``np.ndarray`` of dtype ``int`` with the image's height and width.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_file``.
    """
    image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_file}")
    # shrink=False keeps a full-size mask even when no pixel is set. With the
    # default shrink=True an all-background label collapses to the scalar
    # ``nomask`` and the result would be a 0-d array instead of 2-D.
    image_mask = ma.make_mask(image, copy=True, shrink=False)
    return np.array(image_mask, dtype=int)

# Takes a numpy array prediction from the AI and returns png

def nparray2image(nparray, filename, directory):
    """Save a binary prediction as ``<stem>_pred.png`` and read it back.

    Args:
        nparray: 2-D array of 0/1 values (the thresholded model output).
        filename: Source image name. Only its stem (final extension removed)
            is used, so names containing several dots no longer collide.
        directory: Folder the PNG is written into.

    Returns:
        The written image as loaded by ``cv2.imread`` with default flags
        (3-channel BGR, uint8).

    Raises:
        OSError: If OpenCV fails to write or re-read the file.
    """
    image_name = os.path.splitext(filename)[0] + "_pred.png"
    path = os.path.join(directory, image_name)

    # Scale {0,1} -> {0,255} and store as 8-bit explicitly. Rounding plus
    # clipping mirrors OpenCV's own saturating conversion for non-8-bit input.
    scaled = np.clip(np.rint(np.asarray(nparray) * 255), 0, 255).astype(np.uint8)
    if not cv2.imwrite(path, scaled):
        raise OSError(f"Failed to write prediction image: {path}")

    image_out = cv2.imread(path)
    if image_out is None:
        raise OSError(f"Failed to read back prediction image: {path}")
    return image_out

# Concatenates and saves a three-panel comparison image (left to right):
#   1. the raw image
#   2. the manually labeled mask overlaid on the raw image
#   3. the prediction overlaid on the raw image

def pred2comp(pred, raw, labeled, filename, directory):
    """Build, save and return a side-by-side comparison strip.

    Panels, left to right, separated by 5-px white bars:
      1. ``raw``;
      2. ``raw`` blended with ``labeled``;
      3. ``raw`` blended with ``pred``.

    Blending uses ``cv2.addWeighted`` with weight 0.8 for both inputs.
    OpenCV therefore requires ``raw``, ``pred`` and ``labeled`` to have the
    same shape and dtype.

    Args:
        pred: Prediction image.
        raw: Raw input image.
        labeled: Manually labeled image.
        filename: Source image name. Its stem (final extension removed)
            names the output ``<stem>_comp.png``.
        directory: Folder the comparison PNG is written into.

    Returns:
        The combined array that was written to disk.

    Raises:
        OSError: If OpenCV fails to write the file.
    """
    pred_overlay = cv2.addWeighted(raw, 0.8, pred, 0.8, 0.0)
    orig_overlay = cv2.addWeighted(raw, 0.8, labeled, 0.8, 0.0)

    # White separator with raw's height, channel layout and dtype. hstack then
    # keeps raw's dtype, and single-channel (H, W) inputs work too.
    separator_shape = (raw.shape[0], 5) + tuple(raw.shape[2:])
    separator = np.full(separator_shape, 255, dtype=raw.dtype)
    combined = np.hstack((raw, separator, orig_overlay, separator, pred_overlay))

    image_name = os.path.splitext(filename)[0] + "_comp.png"
    path = os.path.join(directory, image_name)
    if not cv2.imwrite(path, combined):
        raise OSError(f"Failed to write comparison image: {path}")

    return combined


# Creates the training and testing data sets for the three augmentation methods 'Control', 'Random', and 'Triple'.
# The method is selected by the global `model_to_run`, which is set in the training loop below.

class TrainingDataset(Dataset):
    """Paired raw-image / label-mask dataset for binary segmentation.

    Raw images and labels are matched by filename. For every file listed in
    ``raw_folder``, the label with the same name is read from ``label_folder``.

    The augmentation method is ``mode`` when given. Otherwise the module-level
    ``model_to_run`` is looked up each time the dataset is used.

    * ``'Control'``: no augmentation; one item per raw image.
    * ``'Triple'``: every image appears three times, as the original, with
      ColorJitter, and horizontally flipped; ``len`` is 3x the image count.
    * ``'Random'``: each access picks one of original, ColorJitter,
      horizontal flip, or both, each with probability 0.25.

    Each item is a dict with these keys:

    * ``filename``: source filename.
    * ``categoryname``: output name with augmentation suffix (Triple/Random only).
    * ``rawimg``: raw image tensor from ``cv2.imread`` + ``ToTensor``, unpadded.
    * ``raw``: the same (possibly augmented) tensor, zero-padded so height
      and width are multiples of 32.
    * ``labelimg``: label image tensor, unpadded.
    * ``labeled``: 0/1 mask, zero-padded, shape ``(1, H_pad, W_pad)``.
    * ``pad``: padding that was applied (``top``, ``bottom``, ``left``,
      ``right``) plus ``orig_hw``.
    """

    _MODES = ('Triple', 'Random', 'Control')
    # smp encoders require spatial sizes divisible by 32.
    _PAD_MULTIPLE = 32

    def __init__(self, raw_folder, label_folder, mode=None):
        """
        Args:
            raw_folder: Directory of raw input images.
            label_folder: Directory of label images with matching filenames.
            mode: 'Control', 'Triple' or 'Random'. ``None`` (default) defers to
                the global ``model_to_run`` every time the dataset is used.

        Raises:
            ValueError: If the resolved mode is not a known method.
        """
        self._mode = mode
        active_mode = self.mode
        if active_mode not in self._MODES:
            raise ValueError(
                f"Unknown dataset mode {active_mode!r}; expected one of {self._MODES}")

        self.raw_images_list = os.listdir(raw_folder)
        self.raw_images_dir = raw_folder
        self.labeled_images = os.listdir(label_folder)
        self.labeled_images_dir = label_folder

        # Normalization was tried and disabled; kept for reference:
        # transforms.Normalize(mean=[0.0839, 0.0857, 0.0868], std=[0.1734, 0.1740, 0.1746])
        # (ImageNet values would be mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.to_tensor = transforms.ToTensor()
        if active_mode in ('Triple', 'Random'):
            self.jitter = transforms.ColorJitter(brightness=0.25, contrast=0.4)
            self.flip = transforms.RandomHorizontalFlip(p=1.0)
        if active_mode == 'Random':
            self.combined = transforms.Compose([
                transforms.ColorJitter(brightness=0.25, contrast=0.4),
                transforms.RandomHorizontalFlip(p=1.0)])

    @property
    def mode(self):
        """Active method: the explicit ``mode`` or, if None, global ``model_to_run``."""
        return self._mode if self._mode is not None else model_to_run

    def __len__(self):
        """Number of items. Triple mode yields three variants per raw image."""
        n_raw = len(self.raw_images_list)
        return n_raw * 3 if self.mode == 'Triple' else n_raw

    def __getitem__(self, index):
        mode = self.mode
        if mode == 'Triple':
            item_data = self._triple_item(index)
        elif mode == 'Random':
            item_data = self._random_item(index)
        elif mode == 'Control':
            item_data = self._control_item(index)
        else:
            raise ValueError(
                f"Unknown dataset mode {mode!r}; expected one of {self._MODES}")

        # Padding previously ran only in the 'Control' branch. Every mode needs
        # it: the network expects multiples of 32, and the prediction loop
        # reads item['pad'] to crop outputs back.
        self._pad_to_multiple(item_data)
        return item_data

    # ------------------------------------------------------------------ #
    #  Loading helpers                                                    #
    # ------------------------------------------------------------------ #
    def _load_raw(self, filename):
        """Read a raw image with OpenCV and convert it with ToTensor."""
        return self.to_tensor(cv2.imread(os.path.join(self.raw_images_dir, filename)))

    def _label_path(self, filename):
        """Path of the label file that pairs with raw image ``filename``."""
        return os.path.join(self.labeled_images_dir, filename)

    def _load_label_image(self, filename):
        """Read the label image with OpenCV and convert it with ToTensor."""
        return self.to_tensor(cv2.imread(self._label_path(filename)))

    def _flip_mask(self, mask):
        """Horizontally flip a 2-D NumPy mask, returning a squeezed tensor."""
        return torch.squeeze(self.flip(self.to_tensor(mask)))

    @staticmethod
    def _suffixed_name(filename, suffix):
        """Output name for an augmented image, e.g. ('a.png', '_hflip') -> 'a_hflip.png'."""
        return os.path.splitext(filename)[0] + suffix + ".png"

    @staticmethod
    def _build_item(filename, raw, label_img, label_mask, category_name=None):
        """Assemble the (unpadded) item dict. 'raw' starts as the same tensor as 'rawimg'."""
        item = {'filename': filename}
        if category_name is not None:
            item['categoryname'] = category_name
        item['rawimg'] = raw
        item['raw'] = raw
        item['labelimg'] = label_img
        item['labeled'] = label_mask
        return item

    # ------------------------------------------------------------------ #
    #  Per-mode item builders                                             #
    # ------------------------------------------------------------------ #
    def _triple_item(self, index):
        """Item for 'Triple': index // n_raw picks original / jitter / flip."""
        n_raw = len(self.raw_images_list)
        category = index // n_raw
        filename = self.raw_images_list[index % n_raw]

        raw = self._load_raw(filename)
        label_img = self._load_label_image(filename)
        label_mask = image2nparray(self._label_path(filename))

        if category == 0:
            name = filename
        elif category == 1:
            # ColorJitter changes intensities only, so the label is unchanged.
            raw = self.jitter(raw)
            name = self._suffixed_name(filename, "_coljit")
        else:
            # A geometric flip must be applied to image and labels alike.
            raw = self.flip(raw)
            label_img = self.flip(label_img)
            label_mask = self._flip_mask(label_mask)
            name = self._suffixed_name(filename, "_hflip")
        return self._build_item(filename, raw, label_img, label_mask, name)

    def _random_item(self, index):
        """Item for 'Random': one of four augmentations chosen uniformly per access."""
        draw = random.random()
        filename = self.raw_images_list[index]

        raw = self._load_raw(filename)
        label_img = self._load_label_image(filename)
        label_mask = image2nparray(self._label_path(filename))

        if draw < 0.25:
            name = filename
        elif draw < 0.50:
            raw = self.jitter(raw)
            name = self._suffixed_name(filename, "_coljit")
        elif draw < 0.75:
            raw = self.flip(raw)
            label_img = self.flip(label_img)
            label_mask = self._flip_mask(label_mask)
            name = self._suffixed_name(filename, "_hflip")
        else:
            # Jitter + flip on the image; only the flip applies to the labels.
            raw = self.combined(raw)
            label_img = self.flip(label_img)
            label_mask = self._flip_mask(label_mask)
            name = self._suffixed_name(filename, "_both")
        return self._build_item(filename, raw, label_img, label_mask, name)

    def _control_item(self, index):
        """Item for 'Control': no augmentation and no 'categoryname' key."""
        filename = self.raw_images_list[index]
        raw = self._load_raw(filename)
        label_img = self._load_label_image(filename)
        label_mask = image2nparray(self._label_path(filename))
        return self._build_item(filename, raw, label_img, label_mask)

    # ------------------------------------------------------------------ #
    #  Padding                                                            #
    # ------------------------------------------------------------------ #
    def _pad_to_multiple(self, item_data):
        """Zero-pad 'raw' and 'labeled' in place so H and W are multiples of 32.

        Padding is split as evenly as possible, with the odd pixel going
        bottom/right. It is recorded in ``item_data['pad']`` so predictions can
        be cropped back. 'labeled' becomes a NumPy array of shape (1, H, W).
        'rawimg' and 'labelimg' are left unpadded.
        """
        # Flipped masks are tensors; asarray unifies them with NumPy masks.
        labeled = np.asarray(item_data['labeled'])
        h, w = labeled.shape

        multiple = self._PAD_MULTIPLE
        pad_h = (multiple - (h % multiple)) % multiple
        pad_w = (multiple - (w % multiple)) % multiple
        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left

        item_data['pad'] = {
            'top': pad_top,
            'bottom': pad_bottom,
            'left': pad_left,
            'right': pad_right,
            'orig_hw': (h, w),  # handy for sanity checks
        }

        labeled = np.pad(labeled,
                         ((pad_top, pad_bottom), (pad_left, pad_right)),
                         mode='constant')
        item_data['labeled'] = np.expand_dims(labeled, axis=0)

        # torchvision's functional pad takes (left, top, right, bottom).
        item_data['raw'] = pad(item_data['raw'],
                               (pad_left, pad_top, pad_right, pad_bottom),
                               padding_mode='constant')
    
def reset_all_weights(model: nn.Module) -> None:
    """Re-initialise every sub-module of ``model`` that defines ``reset_parameters``.

    The reset is dispatched through ``nn.Module.apply``, which visits the
    children first and then the module itself. It runs under
    ``torch.no_grad()``, so autograd does not record it. Modules without a
    callable ``reset_parameters`` are left untouched.

    refs:
        - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
        - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def _reinit(module: nn.Module) -> None:
        reinit = getattr(module, "reset_parameters", None)
        if not callable(reinit):
            return
        reinit()

    model.apply(fn=_reinit)


# Lookup from the names used in ``model_list`` to the
# segmentation_models_pytorch architecture classes of the same name.
model_classes = {
    arch_name: getattr(smp, arch_name)
    for arch_name in ("FPN", "Unet", "MAnet", "Linknet", "PSPNet", "UnetPlusPlus")
}

# ===================== #
#   Training helpers    #
# ===================== #

# One CSV that run-level results are both read from and appended to.
# (Previously rows were read from "training_results.csv" but written to
# "training_results7.csv", so with several model/backbone combinations only
# the last combination survived in the output file.)
RESULTS_CSV = "training_results7.csv"
RESULT_COLUMNS = [
    "model", "backbone", "fold", "params", "lr", "f1", "accuracy",
    "precision", "recall", "sensitivity", "specificity", "iou", "iou_imagewise",
    "dice", "dice_imagewise", "train_s",
]
# Metric names, in the order they are reported per epoch.
METRIC_KEYS = (
    "accuracy", "precision", "recall", "f1", "sensitivity", "specificity",
    "iou", "iou_imagewise", "dice", "dice_imagewise",
)
EPOCH_COLUMNS = (
    ["epoch", "epoch_train_loss", "epoch_val_loss"]
    + [f"{key}_val" for key in METRIC_KEYS]
    + ["tp", "fp", "fn", "tn"]
)


def _lookup_learning_rate(nn_name, backbone, csv_path="tuning_results.csv"):
    """Return the tuned 'lr' for (nn_name, backbone) from csv_path.

    Falls back to DEFAULT_LEARNING_RATE when the file does not exist, has no
    matching row, or lacks the expected columns.
    """
    if not os.path.exists(csv_path):
        return DEFAULT_LEARNING_RATE
    tuning_df = pd.read_csv(csv_path)
    try:
        return tuning_df.loc[
            (tuning_df['model'] == nn_name) &
            (tuning_df['backbone'] == backbone)
        ].iloc[0]['lr']
    except (KeyError, IndexError):
        # KeyError: missing column; IndexError: no matching row.
        return DEFAULT_LEARNING_RATE


def _segmentation_metrics(tp, fp, fn, tn):
    """Compute the reported metric set from smp confusion statistics.

    'micro' pools the counts over the whole set. 'micro-imagewise' scores each
    image separately and averages the scores. For binary masks Dice equals F1,
    so f1_score is used for both.
    """
    m = smp.metrics
    return {
        "accuracy": m.accuracy(tp, fp, fn, tn, reduction="micro").item(),
        "precision": m.precision(tp, fp, fn, tn, reduction="micro").item(),
        "recall": m.recall(tp, fp, fn, tn, reduction="micro").item(),
        "f1": m.f1_score(tp, fp, fn, tn, reduction="micro").item(),
        "sensitivity": m.sensitivity(tp, fp, fn, tn, reduction="micro").item(),
        "specificity": m.specificity(tp, fp, fn, tn, reduction="micro").item(),
        "iou": m.iou_score(tp, fp, fn, tn, reduction="micro").item(),
        "iou_imagewise": m.iou_score(tp, fp, fn, tn, reduction="micro-imagewise").item(),
        "dice": m.f1_score(tp, fp, fn, tn, reduction="micro").item(),
        "dice_imagewise": m.f1_score(tp, fp, fn, tn, reduction="micro-imagewise").item(),
    }


def _evaluate(net, loader, device, criterion=None):
    """Run ``net`` in eval mode over ``loader`` and collect confusion statistics.

    Predictions are sigmoid(logits) > 0.5.

    Returns:
        (mean_loss, tp, fp, fn, tn). mean_loss is None when no criterion is given.
    """
    # eval() matters: without it BatchNorm uses per-batch statistics and keeps
    # updating its running stats from the evaluation data.
    net.eval()
    losses, all_outputs, all_labels = [], [], []
    with torch.no_grad():
        for batch in loader:
            inputs = batch['raw'].to(device)
            labels = batch['labeled'].to(device)

            preds = net(inputs)
            if criterion is not None:
                losses.append(criterion(preds, labels).item())

            all_outputs.append((preds.sigmoid() > 0.5).long())
            all_labels.append(labels.long())

    tp, fp, fn, tn = smp.metrics.get_stats(
        torch.cat(all_outputs, dim=0), torch.cat(all_labels, dim=0), mode="binary")
    mean_loss = float(np.mean(losses)) if losses else None
    return mean_loss, tp, fp, fn, tn


def _save_predictions(net, loader, device, pred_path, comp_path, mode):
    """Write a cropped prediction PNG and a 3-panel comparison PNG per test item.

    Assumes batch size 1: only element [0] of each batch's names is used.
    """
    os.makedirs(pred_path, exist_ok=True)
    os.makedirs(comp_path, exist_ok=True)

    net.eval()
    with torch.no_grad():
        for data_dict in loader:
            # Augmented modes carry a suffixed name; Control uses the filename.
            if mode in ['Triple', 'Random']:
                image_name = data_dict['categoryname'][0]
            else:  # 'Control'
                image_name = data_dict['filename'][0]

            outputs = net(data_dict['raw'].to(device))
            outputs = (outputs.sigmoid() > 0.5).float()
            p = outputs.to('cpu').numpy().squeeze()

            # Undo the dataset's zero-padding. The explicit end indices also
            # work when a side's padding is 0 (an empty [-0:] slice would not).
            pad_info = data_dict['pad']
            pt = int(pad_info['top'])
            pb = int(pad_info['bottom'])
            pl = int(pad_info['left'])
            pr = int(pad_info['right'])
            p_cropped = p[pt:p.shape[0] - pb, pl:p.shape[1] - pr].astype('f4')

            p_as_image = nparray2image(p_cropped, image_name, pred_path).astype('f4')

            # NOTE(review): tens2numpy is not defined in this section, presumably
            # in an earlier notebook cell. pred2comp's cv2.addWeighted also needs
            # its output to match p_as_image's dtype (float32). Confirm both.
            img_raw = tens2numpy(data_dict['rawimg'])
            img_lab = tens2numpy(data_dict['labelimg'])
            pred2comp(p_as_image, img_raw, img_lab, image_name, comp_path)


# ===================== #
#   Run (single fold)   #
# ===================== #
# Only fold 1 is run. `i` is recorded as the results' "fold" column; the
# dataset paths below do not depend on it.
for i in [1]:

    # Store losses for plotting
    training_loss_plot = []
    validation_loss_plot = []

    # ===================== #
    #   Main Training Loop  #
    # ===================== #
    for model in model_list:
        for backbone in backbone_list:
            NNModel = model_classes[model]
            nn_name = NNModel.__name__
            # Module-level global read by TrainingDataset to pick augmentation.
            model_to_run = 'Control'

            print(f"Training {nn_name} with backbone {backbone}")

            learning_rate = _lookup_learning_rate(nn_name, backbone)

            # ===================== #
            #   Load Datasets       #
            # ===================== #
            train_dataset = TrainingDataset(
                raw_folder="Images/Train/Raw/",
                label_folder="Images/Train/Labels/"
            )
            test_dataset = TrainingDataset(
                raw_folder="Images/Test/Raw/",
                label_folder="Images/Test/Labels/"
            )

            # NOTE(review): the test set is used both for model selection and
            # for the final metrics, so the final numbers are optimistic.
            train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
            test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            # ===================== #
            #   Build Model/Paths   #
            # ===================== #
            num_classes = 2
            base_model_path = f"Models/{nn_name.lower()}_{backbone}/"
            os.makedirs(base_model_path, exist_ok=True)
            best_model_file = os.path.join(base_model_path, "best_model.pth")

            if os.path.exists(RESULTS_CSV):
                training_results = pd.read_csv(RESULTS_CSV)
            else:
                training_results = pd.DataFrame(columns=RESULT_COLUMNS)

            # Binary segmentation -> a single output channel of logits.
            net = NNModel(backbone, classes=num_classes - 1)
            total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)

            if torch.cuda.device_count() > 1:
                net = nn.DataParallel(net)
            net.to(device)

            criterion = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
            optimizer = optim.Adam(net.parameters(), lr=learning_rate)

            # Model selection is by image-wise IoU on the test loader.
            best_val_iou = 0.00
            best_state = None

            # ===================== #
            #      Train Loop       #
            # ===================== #
            training_loss_epoch = []
            validation_loss_epoch = []
            results_list = []
            t_start = time.time()

            for epoch in range(NUM_EPOCHS):
                print(f"Epoch [{epoch+1}/{NUM_EPOCHS}]")

                # --- TRAIN ---
                net.train()
                batch_losses = []
                for batch in train_loader:
                    inputs = batch['raw'].to(device)
                    labels = batch['labeled'].to(device)

                    optimizer.zero_grad()
                    loss = criterion(net(inputs), labels)
                    loss.backward()
                    optimizer.step()

                    batch_losses.append(loss.item())

                epoch_train_loss = np.mean(batch_losses)
                training_loss_epoch.append(epoch_train_loss)

                # --- VALIDATION (eval mode) ---
                epoch_val_loss, tp, fp, fn, tn = _evaluate(net, test_loader, device, criterion)
                validation_loss_epoch.append(epoch_val_loss)
                epoch_metrics = _segmentation_metrics(tp, fp, fn, tn)

                results_list.append(
                    [epoch + 1, epoch_train_loss, epoch_val_loss]
                    + [epoch_metrics[key] for key in METRIC_KEYS]
                    + [tp.sum().item(), fp.sum().item(), fn.sum().item(), tn.sum().item()]
                )

                print(f"Train Loss: {epoch_train_loss:.4f} | "
                      f"Val Loss: {epoch_val_loss:.4f} | "
                      f"F1: {epoch_metrics['f1']:.3f} | IoU: {epoch_metrics['iou']:.3f} | "
                      f"DC: {epoch_metrics['dice']:.3f} | "
                      f"Image-wise IoU: {epoch_metrics['iou_imagewise']:.3f} | "
                      f"Image-wise DC: {epoch_metrics['dice_imagewise']:.3f}")

                # Save best model. A CPU copy of the weights is also kept so the
                # final evaluation does not need to torch.load the pickle.
                if epoch_metrics['iou_imagewise'] > best_val_iou:
                    best_val_iou = epoch_metrics['iou_imagewise']
                    torch.save(net, best_model_file)
                    best_state = {k: v.detach().cpu().clone()
                                  for k, v in net.state_dict().items()}

            pd.DataFrame(results_list, columns=EPOCH_COLUMNS).to_csv(
                os.path.join(base_model_path, "epochs_report.csv"), index=False)

            print("Finished Training")

            # ===================== #
            #   Final Test Metrics  #
            # ===================== #
            # Restore the best weights. The original reloaded "net.pth", a file
            # that is never written, so this step always crashed.
            if best_state is not None:
                net.load_state_dict(best_state)
            else:
                print("Warning: image-wise IoU never exceeded 0; using final-epoch weights.")

            _, tp, fp, fn, tn = _evaluate(net, test_loader, device)
            final = _segmentation_metrics(tp, fp, fn, tn)

            total_seconds = time.time() - t_start

            print("\n--- Final Test Metrics (Best Model) ---")
            print(f"Accuracy:     {final['accuracy']:.4f}")
            print(f"Precision:    {final['precision']:.4f}")
            print(f"Recall:       {final['recall']:.4f}")
            print(f"F1:           {final['f1']:.4f}")
            print(f"Sensitivity:  {final['sensitivity']:.4f}")
            print(f"Specificity:  {final['specificity']:.4f}")
            print(f"IoU:          {final['iou']:.4f}")
            print(f"IoU-Imagewise:{final['iou_imagewise']:.4f}")
            print(f"Dice:         {final['dice']:.4f}")
            print(f"Dice-Imagewise:{final['dice_imagewise']:.4f}")

            # ===================== #
            #  Save Results to CSV  #
            # ===================== #
            new_row = pd.DataFrame({
                "model": [nn_name],
                "backbone": [backbone],
                "fold": [i],
                "params": [total_params],
                "lr": [learning_rate],
                "accuracy": [final['accuracy']],
                "precision": [final['precision']],
                "recall": [final['recall']],
                "f1": [final['f1']],
                "sensitivity": [final['sensitivity']],
                "specificity": [final['specificity']],
                "iou": [final['iou']],
                "iou_imagewise": [final['iou_imagewise']],
                "dice": [final['dice']],
                "dice_imagewise": [final['dice_imagewise']],
                # Matches the declared "train_s" column (was "Train_s", which
                # created a second, mostly-empty column).
                "train_s": [total_seconds],
            })
            training_results = pd.concat([training_results, new_row], ignore_index=True)
            training_results.to_csv(RESULTS_CSV, index=False)

            # ===================== #
            #  Final Prediction Gen #
            # ===================== #
            _save_predictions(
                net, test_loader, device,
                pred_path=os.path.join(base_model_path, "Predictions"),
                comp_path=os.path.join(base_model_path, "Comparisons"),
                mode=model_to_run,
            )

            # ===================== #
            #  Append for plotting  #
            # ===================== #
            training_loss_plot.append(training_loss_epoch)
            validation_loss_plot.append(validation_loss_epoch)

            # ===================== #
            #   Plot Loss Curves    #
            # ===================== #
            plt.plot(training_loss_epoch, label="Training Loss")
            plt.plot(validation_loss_epoch, label="Validation Loss")
            plt.legend(loc='best')
            plt.savefig(os.path.join(base_model_path, f"{nn_name.lower()}_{backbone}_loss.png"))
            plt.clf()

print("All training done.")


