# 1. Import Libraries

In [None]:
import os
import random
import torch
import numpy as np
import pandas as pd
import gc
import albumentations as A
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
from torch.utils.data import Dataset as BaseDataset
from transformers import SegformerImageProcessor
from transformers import SegformerForSemanticSegmentation

# Internal Imports
from config import segformer_inference_config
from helper.utils import visualize_batch_overlay,filter_multi_lesions,calculate_metrics_ind
from helper.data_loader import BUSIDataset,UDIATDataset
# Autoreload jupyter extension
%load_ext autoreload
%autoreload 2

# Global settings shared by the evaluation cells below.
GPU_IDX = 0  # intended CUDA device index for inference
RESULTS_DIR = "results_evaluation"  # root folder; each model's figures go in a sub-folder


# 2. Helper Functions

In [None]:
def load_paths(base_folder: str = "dataset_5folds.npz") -> "np.lib.npyio.NpzFile":
    """Open a ``.npz`` archive of dataset splits and print the keys it contains.

    Args:
        base_folder: Path to the ``.npz`` file. Despite the name, this is a
            file, not a folder.

    Returns:
        The ``NpzFile`` returned by :func:`numpy.load`. Entries are read
        lazily when indexed (e.g. ``archive["fold_0"]`` or
        ``archive["train"]``), so the underlying file stays open while the
        returned object is in use.
    """
    # allow_pickle=True is needed because callers unpack entries with
    # ``.item()`` (0-d object arrays). NOTE: unpickling can execute arbitrary
    # code, so only load archives from a trusted source.
    loaded_folds = np.load(base_folder, allow_pickle=True)
    print("Available folds:", list(loaded_folds.keys()))
    return loaded_folds

def get_udiat_loaders(loaded_folds, BATCH_SIZE=32, num_folds: int = 5):
    """Build train/val/test DataLoaders for each cross-validation fold of UDIAT.

    Args:
        loaded_folds: Mapping (e.g. the ``NpzFile`` from ``load_paths``) with
            keys ``"fold_0"`` ... ``"fold_{num_folds - 1}"``. Each entry is a
            0-d object array whose ``.item()`` is a dict with ``"train"``,
            ``"val"`` and ``"test"`` sequences of samples.
        BATCH_SIZE: Batch size used by all loaders.
        num_folds: Number of folds to build (default 5, as before).

    Returns:
        ``(train_loaders, val_loaders, test_loaders, train_dataset)``. The
        first three are lists indexed by fold number; ``train_dataset`` is
        the training dataset of the *last* fold built.

    Raises:
        ValueError: if ``num_folds`` is smaller than 1.
    """
    if num_folds < 1:
        raise ValueError(f"num_folds must be a positive integer, got {num_folds}")

    # Random augmentation pipeline -- applied to the training split only.
    train_transforms = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15, p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.GaussNoise(p=0.3),
        A.Blur(blur_limit=3, p=0.3),
    ])
    # Val/test used to receive the same random augmentations, so evaluation ran
    # on randomly flipped/rotated/noised images and metrics varied between runs.
    # An empty Compose is an identity transform with the same call signature.
    # NOTE(review): assumes UDIATDataset calls transform(image=..., mask=...)
    # unconditionally -- confirm against helper.data_loader.
    eval_transforms = A.Compose([])
    # Resize to 256x256 only; pixel values are neither rescaled nor normalised.
    # Built once and shared, since its configuration does not depend on the fold.
    image_processor = SegformerImageProcessor(
        reduce_labels=False,
        do_normalize=False,
        do_rescale=False,
        size={"height": 256, "width": 256},
    )

    list_of_train_loaders = []
    list_of_val_loaders = []
    list_of_test_loaders = []
    for fold_idx in range(num_folds):
        fold = loaded_folds[f"fold_{fold_idx}"].item()

        # Each split is handed to the dataset as a list of tuples.
        train_data = [tuple(sample) for sample in fold["train"]]
        val_data = [tuple(sample) for sample in fold["val"]]
        test_data = [tuple(sample) for sample in fold["test"]]

        train_dataset = UDIATDataset(train_data, transform=train_transforms, image_processor=image_processor)
        val_dataset = UDIATDataset(val_data, transform=eval_transforms, image_processor=image_processor)
        test_dataset = UDIATDataset(test_data, transform=eval_transforms, image_processor=image_processor)

        list_of_train_loaders.append(DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True))
        list_of_val_loaders.append(DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False))
        list_of_test_loaders.append(DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False))

    return (
        list_of_train_loaders,
        list_of_val_loaders,
        list_of_test_loaders,
        train_dataset,
    )

def get_busi_loader(loaded_data, BATCH_SIZE=32):
    """Create BUSI DataLoaders for the train, validation and test splits.

    Args:
        loaded_data: Mapping with ``"train"``, ``"val"`` and ``"test"`` keys.
            Each value is an iterable of per-sample records. It is transposed
            with ``zip(*...)`` into the positional arguments of ``BUSIDataset``.
        BATCH_SIZE: Batch size shared by all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader, train_dataset)``. Only the
        training loader shuffles.
    """
    splits = {name: loaded_data[name] for name in ("train", "val", "test")}
    print(f"Train: {len(splits['train'])}, Val: {len(splits['val'])}, Test: {len(splits['test'])}")

    # Resize-only preprocessing: no rescaling or normalisation of pixel values.
    processor = SegformerImageProcessor(
        reduce_labels=False,
        do_normalize=False,
        do_rescale=False,
        size={"height": 256, "width": 256},
    )

    datasets = {
        name: BUSIDataset(*zip(*records), image_processor=processor)
        for name, records in splits.items()
    }
    loaders = {
        name: DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=(name == "train"))
        for name, dataset in datasets.items()
    }
    return loaders["train"], loaders["val"], loaders["test"], datasets["train"]

In [None]:
def test_segformer(
    dataloader=None,
    MODEL_NAME=None,
    SAVE_DIR=None,
    post_processing_fns: dict = None,
    SAVE_FIGURES=True,
    DATASET_NAME="CUSTOM",
    plots_title_model_text: str = None,
    gpu_idx: int = None,
):
    """Run a fine-tuned SegFormer over a dataloader and collect per-sample metrics.

    The checkpoint at ``weights/{MODEL_NAME}`` is loaded with a 2-label head
    (background / foreground). Each batch goes through these steps:

    1. Convert the logits to probabilities.
    2. Upsample them to the ground-truth resolution.
    3. Threshold the foreground channel into a binary mask, optionally
       filtered with ``filter_multi_lesions``.
    4. Score the mask with ``calculate_metrics_ind``.
    5. Plot it with ``visualize_batch_overlay``.

    Args:
        dataloader: Iterable of batch dicts with ``"pixel_values"``,
            ``"labels"`` and ``"metadata"`` (``"image_path"``, ``"mask_path"``,
            ``"label"``).
        MODEL_NAME: Sub-folder of ``weights/`` that holds the checkpoint.
        SAVE_DIR: Output folder forwarded to ``visualize_batch_overlay``.
        post_processing_fns: Required dict with keys ``"activation_fn"``
            (``"sigmoid"`` or ``"softmax"``), ``"threshold"`` and
            ``"REMOVE_SMALL_LESIONS"``.
        SAVE_FIGURES: Forwarded as ``save`` to ``visualize_batch_overlay``.
        DATASET_NAME: Currently unused; kept for API compatibility.
        plots_title_model_text: Model label shown in figure titles.
        gpu_idx: CUDA device index. Defaults to the notebook-level
            ``GPU_IDX``. The CPU is used when CUDA is unavailable.

    Returns:
        A DataFrame built from the metric columns returned by
        ``calculate_metrics_ind``, plus ``image_path``, ``mask_path`` and
        ``tumor_class`` (presumably one row per sample -- depends on that
        helper). It is empty if ``dataloader`` yields no batches.

    Raises:
        ValueError: if ``post_processing_fns`` is missing or names an
            unsupported activation function.
    """
    # Validate the configuration before the (slow) model load rather than
    # failing on the first batch.
    if post_processing_fns is None:
        raise ValueError(
            "post_processing_fns is required (keys: activation_fn, threshold, REMOVE_SMALL_LESIONS)"
        )
    activation_fn = post_processing_fns["activation_fn"]
    if activation_fn not in ("sigmoid", "softmax"):
        raise ValueError("Invalid activation function provided")
    threshold = post_processing_fns["threshold"]
    remove_small_lesions = post_processing_fns["REMOVE_SMALL_LESIONS"]

    print("Initializing model...")
    # Previously hard-coded to "cuda:1", which fails on single-GPU machines
    # and ignored the notebook's GPU_IDX setting.
    if gpu_idx is None:
        gpu_idx = GPU_IDX
    device = torch.device(f"cuda:{gpu_idx}" if torch.cuda.is_available() else "cpu")

    model = SegformerForSemanticSegmentation.from_pretrained(
        f"weights/{MODEL_NAME}",
        return_dict=False,  # model(...) returns a tuple; item 0 is the logits
        num_labels=2,  # binary segmentation: background, foreground
        ignore_mismatched_sizes=True,
    ).to(device)
    model.eval()  # switch to inference mode once, not on every batch

    pd.set_option("display.float_format", "{:.4f}".format)
    list_of_test_metrics_per_batch = []
    for batch in tqdm(dataloader):
        metadata = batch["metadata"]
        image = batch["pixel_values"].to(device)  # (B, 3, H, W)
        targets = batch["labels"].to(device)

        with torch.no_grad():
            logits = model(image)[0]  # (B, 2, h, w)
            if activation_fn == "sigmoid":
                pr_masks = logits.sigmoid()
            else:
                pr_masks = logits.softmax(dim=1)
            # Resize probabilities to the ground-truth resolution (was a
            # hard-coded 256x256, identical to the processor's output size).
            pr_masks = torch.nn.functional.interpolate(
                pr_masks,
                size=tuple(targets.shape[-2:]),
                mode="bilinear",
                align_corners=False,
            )

        # Keep the foreground channel (index 1) and binarise -> (B, 1, H, W).
        mask = (pr_masks[:, 1:2] > threshold).to(torch.uint8)

        if remove_small_lesions:
            # filter_multi_lesions is applied per image to the squeezed (H, W) array.
            filtered = np.array([filter_multi_lesions(m.squeeze()) for m in mask.cpu().numpy()])
            mask = torch.tensor(filtered).unsqueeze(1).to(device)

        batch_metrics = calculate_metrics_ind(
            outputs=mask,
            targets=targets,
            threshold=threshold,
        )
        visualize_batch_overlay(
            batch,
            mask,
            num_imgs=image.shape[0],
            preview=False,
            MODEL_NAME=plots_title_model_text,
            threshold=threshold,
            save=SAVE_FIGURES,
            save_dir=SAVE_DIR,
            batch_metrics=batch_metrics,
        )

        batch_metrics["image_path"] = metadata["image_path"]
        batch_metrics["mask_path"] = list(metadata["mask_path"])
        batch_metrics["tumor_class"] = metadata["label"]
        list_of_test_metrics_per_batch.append(pd.DataFrame(batch_metrics))

    print("Inference done!")
    if not list_of_test_metrics_per_batch:
        # pd.concat([]) raises "No objects to concatenate"; return an empty table instead.
        return pd.DataFrame()
    return pd.concat(list_of_test_metrics_per_batch, ignore_index=True)

## Evaluation of UDIAT

In [None]:
# Load the UDIAT evaluation parameters from config.
selected_dataset = "udiat"
selected_model = "baseline_model"  # select either "physformer" or "baseline_model"
dataset_cfg = segformer_inference_config[selected_dataset]
model_cfg = dataset_cfg[selected_model]
DATA_DIR = dataset_cfg["data_path"]
FOLD_ID_UDIAT = dataset_cfg["fold_id"]
BATCH_SIZE = model_cfg["BATCH_SIZE"]
MODEL_PATH = model_cfg["path"]
POST_PROCESSING_FN = model_cfg["post_processing"]

# Visualization parameters
SAVE_FIGURES = True
PLOTS_PATHS = os.path.join(RESULTS_DIR, MODEL_PATH)
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(PLOTS_PATHS, exist_ok=True)

# Build the loaders for every fold, then evaluate only the configured fold.
loaded_folds = load_paths(DATA_DIR)
_, _, list_of_test_loaders, _ = get_udiat_loaders(loaded_folds, BATCH_SIZE=BATCH_SIZE)
print(f"Testing Fold {FOLD_ID_UDIAT}")
df = test_segformer(
    dataloader=list_of_test_loaders[FOLD_ID_UDIAT],
    MODEL_NAME=MODEL_PATH,
    SAVE_DIR=PLOTS_PATHS,
    post_processing_fns=POST_PROCESSING_FN,
    SAVE_FIGURES=SAVE_FIGURES,
    DATASET_NAME="UDIAT",
    plots_title_model_text=selected_model,
)

# NOTE(review): the file name says ALL_FOLDS but only fold FOLD_ID_UDIAT is
# evaluated above; the name is kept so downstream consumers still find it.
csv_path = f"results/{MODEL_PATH}_inference_metrics_UDIAT_ALL_FOLDS.csv"
# DataFrame.to_csv does not create directories, and this notebook never
# creates "results/" (only RESULTS_DIR).
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
df.to_csv(csv_path, index=False)
df_summary = df[["dice", "hd95", "iou"]].describe()
df_summary

# Evaluation of BUSI

In [None]:
# Load the BUSI evaluation parameters from config.
selected_dataset = "busi"
selected_model = "physformer"  # select either "physformer" or "baseline_model"
dataset_cfg = segformer_inference_config[selected_dataset]
model_cfg = dataset_cfg[selected_model]
DATA_DIR = dataset_cfg["data_path"]
# Read for reference only: get_busi_loader uses the archive's top-level
# train/val/test split rather than per-fold entries.
FOLD_ID_BUSI = dataset_cfg["fold_id"]
BATCH_SIZE = model_cfg["BATCH_SIZE"]
MODEL_PATH = model_cfg["path"]
POST_PROCESSING_FN = model_cfg["post_processing"]

# Visualization parameters
SAVE_FIGURES = True
PLOTS_PATHS = os.path.join(RESULTS_DIR, MODEL_PATH)
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(PLOTS_PATHS, exist_ok=True)

# Load the dataset and evaluate on its test split.
loaded_folds = load_paths(DATA_DIR)
_, _, test_loader, _ = get_busi_loader(loaded_folds, BATCH_SIZE=BATCH_SIZE)
df = test_segformer(
    dataloader=test_loader,
    MODEL_NAME=MODEL_PATH,
    SAVE_DIR=PLOTS_PATHS,
    post_processing_fns=POST_PROCESSING_FN,
    SAVE_FIGURES=SAVE_FIGURES,
    DATASET_NAME="BUSI",
    # Previously omitted, so figure titles received None (unlike the UDIAT cell).
    plots_title_model_text=selected_model,
)

csv_path = f"results/{MODEL_PATH}_inference_metrics_BUSI_ALL_FOLDS.csv"
# DataFrame.to_csv does not create directories, and this notebook never
# creates "results/" (only RESULTS_DIR).
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
df.to_csv(csv_path, index=False)
df_summary = df[["dice", "hd95", "iou"]].describe()
df_summary