<div align="center">
    <h1>Training: CIS-UNet: Multi-Class Segmentation of the Aorta in Computed Tomography Angiography via Context-Aware Shifted Window Self-Attention</h1>    
This notebook walks you through the steps required to train the CIS-UNet model.
    

</div>

## Table of Contents

1. [Importing Libraries](#1-importing-libraries) 
2. [Helper Functions](#HelperFunctions)
3. [Define Directories and Parameters](#4-Define-Directories-and-Parameters)
4. [Data Preparation](#3-Data-Preparation)
5. [Data Transformations](#Data-Transformations)
6. [Model Training (K-fold Cross Validation)](#Model-Training-Cross-Validation)
7. [Count Model Parameters](#Count-Model-Parameters)

<hr>

## 1. Importing Libraries <a id='1-importing-libraries'></a>

Importing all the required packages.

---

In [None]:
import os
import shutil
import torch
import glob
import sys
import monai
import sklearn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete, EnsureChannelFirstd, Compose, CropForegroundd,
    LoadImaged, Orientationd, RandFlipd, RandCropByPosNegLabeld,
    RandShiftIntensityd, ScaleIntensityRanged, RandRotate90d,
    Spacingd, RandAffined
)
from monai.metrics import DiceMetric
from monai.data import DataLoader, CacheDataset, decollate_batch
from monai.networks.layers import Norm
from pathlib import Path
import SimpleITK as sitk
import pandas as pd
from tabulate import tabulate
from sklearn.model_selection import KFold
from utils.CIS_UNet import CIS_UNet

<hr>

## 2. Helper Functions <a id='HelperFunctions'></a>

<hr>

### Print Setup Details

In [None]:
# Function to print setup details
def print_setup_details():
    """Print package versions, hardware counts, and the run's parameters as tables.

    Reads the notebook-level configuration variables (``num_folds``,
    ``num_samples``, ``patch_size``, ``spatial_dims``, ``block_inplanes``,
    ``layers``, ``in_channels``, ``num_classes``, ``encoder_channels``,
    ``feature_size``, ``norm_name``), so the configuration cell must have been
    executed before this function is called.

    Returns:
        None. Everything is written to stdout.
    """
    print("#" * 40, "Setup Details", "#" * 40)
    # Keep each label next to its version so the two columns can never end up
    # with different lengths (pandas refuses to build a frame from unequal
    # columns). The first row reports the Python interpreter version, which is
    # what ships os/shutil/glob.
    package_versions = [
        ("OS, Shutil, and Glob", sys.version),
        ("Numpy", np.__version__),
        ("Monai", monai.__version__),
        ("Torch", torch.__version__),
        ("SimpleITK", sitk.__version__),
        ("Scikit-learn", sklearn.__version__),
    ]
    versions_df = pd.DataFrame(package_versions, columns=["Package", "Version"])
    print(tabulate(versions_df, headers="keys", tablefmt="grid"))

    print("#" * 40, "Hardware Details", "#" * 40)
    num_gpus = torch.cuda.device_count()
    # NOTE: this is torch's intra-op thread count, reported here as "CPUs".
    num_cpus = torch.get_num_threads()
    gpu_cpu_details = {
        "Component": ["GPUs", "CPUs"],
        "Count": [num_gpus, num_cpus]
    }
    gpu_cpu_df = pd.DataFrame(gpu_cpu_details)
    print(tabulate(gpu_cpu_df, headers="keys", tablefmt="grid"))

    if num_gpus > 0:
        gpu_ids = {"GPU ID": [f"GPU {gpu_id}" for gpu_id in range(num_gpus)]}
        gpu_ids_df = pd.DataFrame(gpu_ids)
        print(tabulate(gpu_ids_df, headers="keys", tablefmt="grid"))

    print("#" * 40, "Parameters Details", "#" * 40)
    setup_details = {
        "Parameter": [
            "Number of Folds", "Number of Samples", "Patch Size",
            "Spatial Dimensions", "Block Inplanes", "Layers",
            "In Channels", "Number of Classes", "Encoder Channels",
            "Feature Size", "Normalization Name"
        ],
        "Value": [
            num_folds, num_samples, patch_size, spatial_dims, block_inplanes, layers, in_channels,
            num_classes, encoder_channels, feature_size, norm_name
        ]
    }
    setup_df = pd.DataFrame(setup_details)
    print(tabulate(setup_df, headers="keys", tablefmt="grid"))
    print('#' * 100)

### KFold Cross-Validation 

In [None]:
# Split `files` fold by fold and record each fold's train/validation file lists
def perform_cross_validation(files, skf, train_test_files):
    """Report every fold produced by ``skf`` and store its file lists.

    Args:
        files: Indexable collection of samples to split.
        skf: Splitter whose ``split(files)`` yields ``(train_index, test_index)``
            pairs.
        train_test_files: Dict updated in place; for fold ``k`` (1-based) the
            keys ``Fold_k_train_files`` and ``Fold_k_test_files`` are set.

    Returns:
        None. Results are delivered through ``train_test_files`` and stdout.
    """
    for fold_idx, (train_index, test_index) in enumerate(skf.split(files)):
        fold_no = fold_idx + 1
        print(f"Fold {fold_no}:")
        for title, indices in (("Training", train_index), ("Validation", test_index)):
            print(f"  {title} set:")
            print(f"    Number of samples: {len(indices)}")
            print(f"    Indices: {indices}\n")

        train_test_files[f'Fold_{fold_no}_train_files'] = [files[k] for k in train_index]
        train_test_files[f'Fold_{fold_no}_test_files'] = [files[k] for k in test_index]

### Model Validation 

In [None]:
# Function for model validation
def validation(epoch_iterator_val):
    """Evaluate the global ``model`` on the validation loader and return the mean Dice.

    Relies on notebook-level names: ``model``, ``patch_size``, ``num_samples``,
    ``post_label``, ``post_pred``, ``dice_metric``, ``global_step`` and
    ``max_iterations``.

    Args:
        epoch_iterator_val: Progress-bar wrapped validation loader (it must
            support ``set_description``); each batch is indexed with
            ``"image"`` and ``"label"``.

    Returns:
        The value of ``dice_metric.aggregate().item()`` over all batches. The
        metric is reset before returning. The model is left in eval mode.
    """
    model.eval()
    # Follow the model's own device instead of hard-coding CUDA, so this also
    # works when the training cell fell back to the CPU.
    device = next(model.parameters()).device
    with torch.no_grad():
        for batch in epoch_iterator_val:
            val_inputs, val_labels = batch["image"].to(device), batch["label"].to(device)
            # num_samples is passed as the third positional argument
            # (the sliding-window batch size in MONAI's signature).
            val_outputs = sliding_window_inference(val_inputs, (patch_size, patch_size, patch_size), num_samples, model)
            val_labels_list = decollate_batch(val_labels)
            val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list]
            val_outputs_list = decollate_batch(val_outputs)
            val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list]
            # Accumulates per-batch scores inside dice_metric until aggregate().
            dice_metric(y_pred=val_output_convert, y=val_labels_convert)
            epoch_iterator_val.set_description(f"Validate ({global_step} / {max_iterations} Steps)")
        mean_dice_val = dice_metric.aggregate().item()
        dice_metric.reset()
    return mean_dice_val

### Model Training 

In [None]:
# Function for model training
def train(global_step, train_loader, val_loader, dice_val_best, global_step_best, fold):
    """Run one pass over ``train_loader``, validating and checkpointing on schedule.

    Relies on notebook-level names: ``model``, ``optimizer``, ``loss_function``,
    ``max_iterations``, ``eval_num``, ``epoch_loss_values``, ``metric_values``,
    ``saved_model_dir`` and the ``validation`` function.

    Validation runs when ``global_step`` is a non-zero multiple of ``eval_num``
    or equals ``max_iterations``. The pass always covers the whole loader; it
    does not stop when ``max_iterations`` is reached.

    Args:
        global_step: Number of optimisation steps completed before this pass.
        train_loader: Iterable of batches indexed with ``"image"`` and ``"label"``.
        val_loader: Loader handed to ``validation`` when evaluation is due.
        dice_val_best: Best validation Dice seen so far.
        global_step_best: Step at which ``dice_val_best`` was obtained.
        fold: Fold identifier embedded in the checkpoint file name.

    Returns:
        Tuple ``(global_step, dice_val_best, global_step_best)`` after the pass.
    """
    model.train()
    # Follow the model's own device instead of hard-coding CUDA, so this also
    # works when the training cell fell back to the CPU.
    device = next(model.parameters()).device
    epoch_loss = 0
    epoch_iterator = tqdm(train_loader, desc=f"Training ({global_step} / {max_iterations} Steps) (loss=X.X)", dynamic_ncols=True)
    # `step` counts batches of this pass starting at 1, so it can be used
    # directly as the divisor for the running mean loss.
    for step, batch in enumerate(epoch_iterator, start=1):
        x, y = batch["image"].to(device), batch["label"].to(device)
        logit_map = model(x)
        loss = loss_function(logit_map, y)
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        epoch_iterator.set_description(f"Training ({global_step} / {max_iterations} Steps) (loss={loss.item():.5f})")
        if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
            epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
            dice_val = validation(epoch_iterator_val)
            # validation() switches the model to eval mode; switch back so the
            # remaining batches of this pass are trained in training mode.
            model.train()
            # Record the mean without overwriting the running sum; dividing the
            # accumulator in place would corrupt any later average in this pass.
            epoch_loss_values.append(epoch_loss / step)
            metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save(model.state_dict(), os.path.join(saved_model_dir, f'Fold{fold}_best_metric_model.pth'))
                print(f"Model Was Saved! Current Best Avg. Dice: {dice_val_best} | Current Avg. Dice: {dice_val}")
            else:
                print(f"Model Was Not Saved! Current Best Avg. Dice: {dice_val_best} | Current Avg. Dice: {dice_val}")
        global_step += 1
    return global_step, dice_val_best, global_step_best

<hr>

## 3. Define Directories and Parameters <a id='4-Define-Directories-and-Parameters'></a>
 
<hr>

In [None]:
# Experiment name and the directories derived from it
ver = "CIS_UNet"
root_dir = Path("./")
data_dir = Path("../data")
saved_model_dir = root_dir / "saved_models" / ver
results_dir = root_dir / "results" / ver
# Create the output folders up front; existing folders are left untouched.
results_dir.mkdir(parents=True, exist_ok=True)
saved_model_dir.mkdir(parents=True, exist_ok=True)

# Hardware counts as reported by torch (num_cpus is torch's thread count)
num_cpus = torch.get_num_threads()
num_gpus = torch.cuda.device_count()

# Cross-validation and patch sampling settings
num_folds, num_samples, patch_size = 4, 4, 128

# Network settings (consumed by the model constructor and the setup report)
spatial_dims, in_channels, num_classes = 3, 1, 15
layers = (3, 4, 6, 3)
block_inplanes = (64, 128, 256, 512)
# 64 followed by the first three entries of block_inplanes
encoder_channels = [64, *block_inplanes[:3]]
feature_size = 48
norm_name = 'instance'

<hr>

## 4. Data Preparation <a id='3-Data-Preparation'></a>

<hr>

In [None]:
# Collect the image and label volumes. Both lists are sorted so that the
# i-th image is paired with the i-th label by file-name order.
images = sorted(glob.glob(os.path.join(data_dir, "Volumes", "*.nii.gz")))
labels = sorted(glob.glob(os.path.join(data_dir, "Labels", "*.nii.gz")))

# Fail early with a clear message: zip() below would silently drop unmatched
# files, and an empty list would only surface later as a splitter error.
if not images:
    raise FileNotFoundError(f"No '*.nii.gz' volumes found in {os.path.join(data_dir, 'Volumes')}")
if len(images) != len(labels):
    raise ValueError(
        f"Found {len(images)} volumes but {len(labels)} labels; every volume needs exactly one label."
    )
files = [{"image": image_name, "label": label_name} for image_name, label_name in zip(images, labels)]

# Fixed random_state keeps the fold assignment identical between runs.
skf = KFold(n_splits=num_folds, shuffle=True, random_state=92)
train_test_files = {}
perform_cross_validation(files, skf, train_test_files)

<hr>

## 5. Data Transformations <a id='Data-Transformations'></a>

<hr>

In [None]:
# Define data transformations
def _deterministic_preprocessing():
    """Return fresh instances of the loading/preprocessing steps used by both pipelines."""
    both = ["image", "label"]
    return [
        LoadImaged(keys=both, ensure_channel_first=True, image_only=False),
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=both, source_key="image", allow_smaller=True),
        Orientationd(keys=both, axcodes="RAS"),
        Spacingd(keys=both, pixdim=(1.5, 1.5, 1.5), mode=("bilinear", "nearest")),
    ]


# Training pipeline: shared preprocessing, then random patch sampling and augmentation.
train_transforms = Compose(
    _deterministic_preprocessing()
    + [
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(patch_size, patch_size, patch_size),
            pos=1,
            neg=1,
            num_samples=num_samples,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(keys=["image", "label"], spatial_axis=[0, 1, 2], prob=0.10),
        RandRotate90d(keys=["image", "label"], prob=0.10, max_k=3),
        RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.50),
    ]
)

# Validation pipeline: shared preprocessing only, no random transforms.
val_transforms = Compose(_deterministic_preprocessing())

<hr>

## 6. Model Training (K-fold Cross Validation) <a id='Model-Training-Cross-Validation'></a>

<hr>

In [None]:
# Post-processing and metric objects shared by every fold (read by validation()).
post_label = AsDiscrete(to_onehot=num_classes)
post_pred = AsDiscrete(argmax=True, to_onehot=num_classes)
dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)

# Schedule and device do not change between folds, so they are set once.
max_iterations = 5   # the training loop stops once global_step reaches this value
eval_num = 35        # validate whenever global_step is a non-zero multiple of this
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One entry per fold, so the cross-validation can be summarised afterwards
# instead of only keeping the last fold's numbers.
fold_results = []

# Run the k-fold cross-validation and training
for fold, (train_indices, val_indices) in enumerate(skf.split(files)):
    print(f"Processing Fold {fold+1}")
    train_files = [files[i] for i in train_indices]
    val_files = [files[i] for i in val_indices]

    train_ds = CacheDataset(data=train_files, transform=train_transforms,
                            cache_num=len(train_files), cache_rate=1.0, num_workers=8)
    train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=num_cpus//2, pin_memory=True)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_num=len(val_files), cache_rate=1.0, num_workers=num_cpus//2)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=num_cpus//2, pin_memory=True)

    # A new model, loss and optimizer per fold so no weights leak between folds.
    model = CIS_UNet(spatial_dims=spatial_dims, in_channels=in_channels, num_classes=num_classes, encoder_channels=encoder_channels)
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.to(device)

    loss_function = DiceCELoss(to_onehot_y=True, softmax=True, include_background=False)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

    # Per-fold training state. train() and validation() read these names from
    # the notebook namespace, so they must be (re)bound here for every fold.
    global_step = 0
    dice_val_best = 0.0
    global_step_best = 0
    epoch_loss_values = []
    metric_values = []

    while global_step < max_iterations:
        global_step, dice_val_best, global_step_best = train(global_step, train_loader, val_loader, dice_val_best, global_step_best, fold=fold)

    print(f"Global Step: {global_step} | Best Dice: {dice_val_best} | Global Best: {global_step_best}")

    fold_results.append({
        "fold": fold + 1,
        "best_dice": dice_val_best,
        "best_step": global_step_best,
        "epoch_loss_values": epoch_loss_values,
        "metric_values": metric_values,
    })

if fold_results:
    mean_best_dice = float(np.mean([result["best_dice"] for result in fold_results]))
    print(f"Mean Best Dice over {len(fold_results)} folds: {mean_best_dice:.4f}")


<hr>

## 7. Count Model Parameters <a id='Count-Model-Parameters'></a>

<hr>


In [None]:
def count_parameters(model):
    """Return the total element count of the model's parameters that require gradients.

    Args:
        model: Object whose ``parameters()`` yields items exposing
            ``requires_grad`` and ``numel()``.

    Returns:
        Sum of ``numel()`` over the parameters with ``requires_grad`` set;
        0 when there are none.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total



# Count and report the trainable parameters of `model`. At this point `model`
# is whatever the training cell last bound to that name, i.e. the network
# built for the final fold, so the training cell must have been run first.
num_params = count_parameters(model)
print("Number of parameters: ", num_params)