<div align="center">
    <h1>Prediction and Evaluation: CIS-UNet: Multi-Class Segmentation of the Aorta in Computed Tomography Angiography via Context-Aware Shifted Window Self-Attention</h1>    
This notebook walks you through predicting aortic segmentation masks with the trained CIS-UNet models and evaluating them (Dice and mean surface distance) on each cross-validation fold.
    

**This notebook assumes your CIS_UNet models are already trained and saved in `saved_models/CIS_UNet/`, with one `Fold{k}_best_metric_model.pth` checkpoint per cross-validation fold.**
    

</div>

## Table of Contents

1. [Importing Libraries](#1-importing-libraries) 
2. [Helper Functions](#HelperFunctions)
3. [Setting Up Parameters](#2-setting-up-parameters)
4. [Loading Data](#3-loading-data)
5. [Defining the Model](#4-defining-the-model)
6. [Prediction and Evaluation (Cross-Validation)](#5-cross-validation-loop)

<hr>

## 1. Importing Libraries <a id='1-importing-libraries'></a>

In [None]:
import os
import glob
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from collections import OrderedDict
from torch.utils.data import DataLoader
from monai.data import CacheDataset
from monai.inferers import sliding_window_inference
from monai.transforms import (
    Compose, LoadImaged, ScaleIntensityRanged, CropForegroundd,
    Orientationd, Spacingd, AsDiscrete
)
from monai.metrics import DiceMetric, SurfaceDistanceMetric
import SimpleITK as sitk
import seg_metrics.seg_metrics as sg


<hr>

<a id='HelperFunctions'></a>
## 2. Helper Function

<hr>

In [None]:
# Direction cosines written into every saved volume: diag(-1, -1, 1) flips the
# first two axes. NOTE(review): presumably converts the RAS orientation produced
# by Orientationd in `test_transforms` to ITK's LPS convention — confirm.
_LPS_FLIP_DIRECTION = (-1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0)


def _write_nifti(tensor, path, spacing=None):
    """Write one volume tensor to ``path`` as a NIfTI file via SimpleITK.

    Args:
        tensor: A tensor that is 3-D after squeezing its size-1 dimensions
            (assumed to be a (1, 1, X, Y, Z) batch — confirm for new callers).
        path (str or Path): Destination file.
        spacing (sequence of 3 floats, optional): Voxel spacing in the tensor's
            spatial axis order. ``None`` keeps SimpleITK's default (1.0 per axis).
    """
    # SimpleITK interprets numpy arrays in reverse axis order, so reverse the three
    # spatial axes to keep the tensor's first spatial axis as the image x axis.
    array = tensor.detach().cpu().squeeze().permute(2, 1, 0).numpy()
    image = sitk.GetImageFromArray(array)
    image.SetDirection(_LPS_FLIP_DIRECTION)
    if spacing is not None:
        image.SetSpacing([float(s) for s in spacing])
    sitk.WriteImage(image, str(path))


def save_volumes(test_img, test_label, test_outputs, vol_name, results_dir, spacing=None):
    """
    Save the test image, label, and predicted output as NIfTI files.

    Writes ``{vol_name}.nii.gz`` (image), ``{vol_name}_original.nii.gz`` (ground
    truth) and ``{vol_name}_predicted.nii.gz`` (argmax of ``test_outputs`` over
    dim 1) into ``results_dir``, creating the directory if needed.

    Args:
        test_img : The test image tensor.
        test_label : The ground truth label tensor.
        test_outputs : The predicted output tensor (class scores along dim 1).
        vol_name (str): The volume name used for saving the files.
        results_dir (str or Path): The directory where the results will be saved.
        spacing (sequence of 3 floats, optional): Voxel spacing to store in all
            three files, in the tensors' spatial axis order (e.g. the ``pixdim``
            given to ``Spacingd``). Defaults to None, which leaves SimpleITK's
            default spacing of 1.0, so distance metrics computed from the files
            are then in voxels rather than physical units.
    """
    results_dir = Path(results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)

    # Collapse the per-class channel axis (dim 1) into a label map.
    pred_label = torch.argmax(test_outputs, dim=1)

    outputs = (
        (test_img, f"{vol_name}.nii.gz"),
        (test_label, f"{vol_name}_original.nii.gz"),
        (pred_label, f"{vol_name}_predicted.nii.gz"),
    )
    for tensor, filename in outputs:
        _write_nifti(tensor, results_dir / filename, spacing)

    print(f"Results for {vol_name} saved successfully!")

## 3. Setting Up Parameters <a id='2-setting-up-parameters'></a>

In [None]:
# Define the root and data directories
data_dir = Path("../data")
root_dir = Path("./")
saved_model_dir = root_dir / "saved_models" / "CIS_UNet"
results_dir = root_dir / "results" / "CIS_UNet"

# Create directories if they do not exist
saved_model_dir.mkdir(parents=True, exist_ok=True)
results_dir.mkdir(parents=True, exist_ok=True)

# Get CPU and GPU details
num_gpus = torch.cuda.device_count()
num_cpus = torch.get_num_threads()

# Cross-validation parameters
num_folds = 4

# Model and data parameters
spatial_dims = 3
in_channels = 1
num_classes = 15
encoder_channels = [64, 128, 256, 512]
feature_size = 48
norm_name = 'instance'
patch_size = 128
num_samples = 4

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define test transformations
test_transforms = Compose([
    LoadImaged(keys=["image", "label"], ensure_channel_first=True, image_only=False),
    ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 1.5), mode=("bilinear", "nearest")),
])


## 4. Loading Data <a id='3-loading-data'></a>

In [None]:
# Load image and label file paths
images = sorted(glob.glob(os.path.join(data_dir, "Volumes", "*.nii.gz")))
labels = sorted(glob.glob(os.path.join(data_dir, "Labels", "*.nii.gz")))
files = [{"image": img, "label": lbl} for img, lbl in zip(images, labels)]


## 5. Defining the Model <a id='4-defining-the-model'></a>

In [None]:
# Model class definition
from utils.CIS_UNet import CIS_UNet


## 6. Prediction and Evaluation (Cross-Validation) <a id='5-cross-validation-loop'></a>

In [None]:
from sklearn.model_selection import KFold

# Label ids evaluated by seg_metrics; label 0 is excluded (presumably background).
label_ids = list(range(1, num_classes))

# seg_metrics returns one score per entry of `label_ids`, in order, so the metric
# tables below have integer columns 0..len(label_ids)-1; map them to display names.
segment_names = {
    0: "Aorta", 1: "Left Subclavian Artery", 2: "Celiac Artery",
    3: "SMA", 4: "Left Renal Artery", 5: "Right Renal Artery",
    6: "Left Common Iliac Artery", 7: "Right Common Iliac Artery",
    8: "Innominate Artery", 9: "Left Common Carotid", 10: "Right External Iliac Artery",
    11: "Right Internal Iliac Artery", 12: "Left External Iliac Artery",
    13: "Left Internal Iliac Artery"
}


def _load_fold_model(fold):
    """Build a CIS_UNet, load the best checkpoint saved for ``fold``, and move it to ``device``.

    The model is wrapped in DataParallel when more than one GPU is visible.
    """
    model = CIS_UNet(
        spatial_dims=spatial_dims,
        in_channels=in_channels,
        num_classes=num_classes,
        encoder_channels=encoder_channels,
    )

    model_path = os.path.join(saved_model_dir, f'Fold{fold}_best_metric_model.pth')
    print(f"Loading Model: {model_path}")
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only machines;
    # the weights follow the model to `device` below.
    state_dict = torch.load(model_path, map_location="cpu")
    # Checkpoints saved from a DataParallel model prefix every key with 'module.'.
    prefix = 'module.'
    state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
    # Load into the bare model *before* wrapping it: a DataParallel wrapper expects
    # 'module.'-prefixed keys and would reject the stripped state dict.
    model.load_state_dict(state_dict)

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    return model.to(device)


def _predict_fold(model, test_files, result_dir):
    """Run sliding-window inference on ``test_files`` and save NIfTI outputs to ``result_dir``.

    Returns the list of volume names that were written, in loader order.
    """
    # Dataset and loader are local so the cached volumes are released when the fold
    # finishes instead of lingering at module scope.
    test_ds = CacheDataset(data=test_files, transform=test_transforms, cache_num=len(test_files), cache_rate=1.0, num_workers=num_cpus)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=num_cpus, pin_memory=True)

    model.eval()
    vol_names = []
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            test_inputs, test_labels = batch["image"].to(device), batch["label"].to(device)
            # num_samples is passed positionally as sliding_window_inference's sw_batch_size.
            test_outputs = sliding_window_inference(test_inputs, (patch_size, patch_size, patch_size), num_samples, model)

            # batch_size=1 and shuffle=False, so loader step i corresponds to dataset item i.
            file_path = test_ds[i]['image_meta_dict']['filename_or_obj']
            vol_name = os.path.basename(file_path).split('.')[0]
            print(f'Processing Volume: {vol_name}')

            save_volumes(
                test_img=test_inputs,
                test_label=test_labels,
                test_outputs=test_outputs,
                vol_name=vol_name,
                results_dir=result_dir
            )
            vol_names.append(vol_name)
    return vol_names


def _compute_fold_metrics(result_dir, vol_names):
    """Compute per-label Dice and MSD for every saved (ground truth, prediction) pair.

    Returns two dicts (dice, msd) keyed by volume name; each value is the list of
    scores seg_metrics reports for ``label_ids``, in that order.
    """
    dice_results, msd_results = {}, {}
    for vol_name in vol_names:
        # Pair files by exact volume name. Splitting the file name on '_' would
        # truncate names that contain underscores and merge distinct volumes.
        gdth_fpath = result_dir / f"{vol_name}_original.nii.gz"
        pred_fpath = result_dir / f"{vol_name}_predicted.nii.gz"

        gdth_np = sitk.GetArrayFromImage(sitk.ReadImage(str(gdth_fpath)))
        pred_img = sitk.ReadImage(str(pred_fpath))
        pred_np = sitk.GetArrayFromImage(pred_img)
        # SimpleITK spacing is (x, y, z) while GetArrayFromImage arrays are (z, y, x).
        # NOTE(review): save_volumes is called above without a spacing, so unless it
        # sets one the files carry SimpleITK's default 1.0 spacing and MSD is in
        # voxels, not physical units (Spacingd resamples to pixdim 1.5). TODO confirm.
        spacing = np.array(list(reversed(pred_img.GetSpacing())))

        print(f"Processing {vol_name} for metrics computation ...")
        metrics = sg.write_metrics(
            labels=label_ids,
            gdth_img=gdth_np,
            pred_img=pred_np,
            csv_file=None,
            spacing=spacing,
            metrics=['msd', 'dice']
        )
        dice_results[vol_name] = metrics[0]['dice']
        msd_results[vol_name] = metrics[0]['msd']
    return dice_results, msd_results


def _summarize_metric(results, csv_path):
    """Tabulate per-volume, per-segment scores, append averages, write to ``csv_path``.

    Rows are volumes plus a final 'Volume Avg' row; columns are segments plus a
    "Labels' Avg" column. Returns the DataFrame.
    """
    table = pd.DataFrame(results).T  # rows: volumes, columns: label positions
    table["Labels' Avg"] = table.mean(axis=1)
    table.loc['Volume Avg'] = table.mean(axis=0)
    # Label positions are the *columns*; renaming the index (volume names) would be a no-op.
    table = table.rename(columns=segment_names)
    table.index.name = 'Volume'
    table.to_csv(csv_path)
    return table


# NOTE(review): shuffle/random_state presumably must match the training split so each
# fold's model is scored on its own held-out volumes — confirm against the training code.
skf = KFold(n_splits=num_folds, shuffle=True, random_state=92)

for fold, (_, val_indices) in enumerate(skf.split(files)):
    print(f"Fold {fold}/{num_folds-1}")

    test_files = [files[i] for i in val_indices]
    print(f"Len: {len(val_indices)} | Test: index={val_indices}")

    test_model = _load_fold_model(fold)

    result_dir = Path(results_dir) / f'Fold{fold}'
    result_dir.mkdir(parents=True, exist_ok=True)

    vol_names = _predict_fold(test_model, test_files, result_dir)

    dice_results, msd_results = _compute_fold_metrics(result_dir, vol_names)
    df_msd = _summarize_metric(msd_results, result_dir / "test_msd.csv")
    df_dice = _summarize_metric(dice_results, result_dir / "test_dice.csv")

    # Release the model and any cached GPU memory before the next fold.
    del test_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()