In [73]:
# --- Standard Library ---
import os
import sys
import random
import time
import yaml

# --- Third-Party Libraries ---
import numpy as np
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
from termcolor import colored
from accelerate import Accelerator
from scipy.ndimage import (
    binary_closing,
    binary_opening,
    binary_dilation,
    binary_erosion,
    generate_binary_structure
)

# --- Local Imports ---
sys.path.append("..")
sys.path.append("SegFormer3D-main")

from dataloaders.build_dataset import build_dataset, build_dataloader
from architectures.build_architecture import build_architecture
from metrics.competition_metric import ULS23_evaluator

print(os.getcwd())
sys.path.append("../../../nnUNet")
!pip install acvl_utils
!pip install dynamic_network_architectures
!pip install -e "C:\Users\Lazar\OneDrive\Desktop\RU Courses\AI in Medical Imaging\project\nnUNet"
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

import nnunetv2


import os
import torch
from torch.utils.data import Dataset, DataLoader
import SimpleITK as sitk
import numpy as np
from glob import glob

from matplotlib import pyplot as plt
import matplotlib
matplotlib.use('TkAgg')

c:\Users\Lazar\OneDrive\Desktop\RU Courses\AI in Medical Imaging\project\aimi-project\SegFormer3D-main\tta_aug
Collecting argparse (from unittest2->batchgenerators->acvl_utils)
  Using cached argparse-1.4.0-py2.py3-none-any.whl.metadata (2.8 kB)
Using cached argparse-1.4.0-py2.py3-none-any.whl (23 kB)
Installing collected packages: argparse
Successfully installed argparse-1.4.0
Obtaining file:///C:/Users/Lazar/OneDrive/Desktop/RU%20Courses/AI%20in%20Medical%20Imaging/project/nnUNet
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Checking if build backend supports build_editable: started
  Checking if build backend supports build_editable: finished with status 'done'
  Getting requirements to build editable: started
  Getting requirements to build editable: finished with status 'done'
  Preparing editable metadata (pyproject.toml): started
  Preparing editable metadata (pyproject.toml): finished with status 'done'
Collecting argpar

In [74]:
# Report whether PyTorch can see a CUDA device; load_model() later in this notebook targets cuda:0.
print(torch.cuda.is_available())

True


In [75]:
def test_time_dilation(input_tensor):
    # Convert to NumPy and squeeze batch and channel dimensions
    volume = input_tensor.squeeze().cpu().numpy()
    
    # Apply binary dilation
    dilated = binary_dilation(volume, structure=np.ones((3, 3, 3)))

    return dilated.astype(np.uint8)

def test_time_2xdilation(input_tensor):
    volume = input_tensor.squeeze().cpu().numpy()
    
    # Apply binary dilation twice
    dilated = binary_dilation(volume, structure=np.ones((3, 3, 3)))
    dilated = binary_dilation(dilated, structure=np.ones((3, 3, 3)))

    return dilated.astype(np.uint8)

def test_time_opening(input_tensor):
    volume = input_tensor.squeeze().cpu().numpy()
    
    # Apply binary opening: erosion followed by dilation
    opened = binary_erosion(volume, structure=np.ones((3, 3, 3)))
    opened = binary_dilation(opened, structure=np.ones((3, 3, 3)))

    return opened.astype(np.uint8)

def test_time_closing(input_tensor):
    volume = input_tensor.squeeze().cpu().numpy()
    
    # Apply binary closing: dilation followed by erosion
    closed = binary_dilation(volume, structure=np.ones((3, 3, 3)))
    closed = binary_erosion(closed, structure=np.ones((3, 3, 3)))

    return closed.astype(np.uint8)

def test_time_shift(model, input_tensor, threshold=0.5):
    predictions = []

    device = model.device.type
    input_tensor = input_tensor.to(device)

    shifts = [
        (0, 0, 0),
        (5, 0, 0), 
        # (-5, 0, 0),
        # (0, 5, 0), 
        (0, -5, 0),
        # (0, 0, 5), 
        (0, 0, -5),
    ]

    with torch.no_grad():
        for dz, dx, dy in shifts:
            augmented = input_tensor.clone()

            # Shift
            augmented = torch.roll(augmented, shifts=(dz, dx, dy), dims=(1, 2, 3))

            # Model inference
            logits = model.predict_logits_from_preprocessed_data(input_tensor).unsqueeze(0)[:, 1:, ...]
            predicted = torch.sigmoid(logits)
            pred = predicted > threshold

            # Undo shift
            pred = torch.roll(pred, shifts=(-dz, -dx, -dy), dims=(1, 2, 3))
            predictions.append(pred.float())

    avg_prediction = torch.stack(predictions).mean(dim=0, keepdims=True)

    # --- Morphological smoothing ---
    # Convert to NumPy for processing
    pred_np = avg_prediction.squeeze().cpu().numpy()

    # Binarize
    binary_pred = pred_np > threshold
    return binary_pred

In [76]:
class NiiSliceDataset(Dataset):
    """Dataset pairing ``MIX_*_*.nii.gz`` images with their ``<case id>.nii.gz`` labels.

    The case id is the first two underscore-separated tokens of the image file
    name (``MIX_00001_0000.nii.gz`` -> ``MIX_00001``); the label is looked up
    under that id in ``labels_dir``.
    """

    def __init__(self, images_dir, labels_dir, transform=None):
        """Store the directories and collect the sorted list of image files.

        Args:
            images_dir: directory searched for ``MIX_*_*.nii.gz`` files.
            labels_dir: directory holding the matching label files.
            transform: optional callable ``(image, label) -> (image, label)``
                applied to the NumPy arrays before tensor conversion.
        """
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.transform = transform

        # Sorted so that indices map to files in a stable order
        pattern = os.path.join(images_dir, "MIX_*_*.nii.gz")
        self.image_paths = sorted(glob(pattern))

    def __len__(self):
        """Number of image files found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load one image/label pair as torch tensors (float32 image, int64 label)."""
        img_file = self.image_paths[idx]

        # MIX_00001_0000.nii.gz -> MIX_00001
        name_tokens = os.path.basename(img_file).split("_")
        case_id = "_".join(name_tokens[:2])
        lbl_file = os.path.join(self.labels_dir, f"{case_id}.nii.gz")

        # Image first, then label, each converted straight to a NumPy array
        image = sitk.GetArrayFromImage(sitk.ReadImage(img_file)).astype(np.float32)
        label = sitk.GetArrayFromImage(sitk.ReadImage(lbl_file)).astype(np.int64)

        if self.transform:
            image, label = self.transform(image, label)

        return torch.from_numpy(image), torch.from_numpy(label)

In [77]:
# Shared metric object; _run_eval below reads this module-level name to compute ULS scores.
evaluator = ULS23_evaluator()

##################################################################################################
def seed_everything(seed) -> None:
    """Seed every random number generator used here and pin cuDNN to deterministic mode.

    Args:
        seed: value written to ``PYTHONHASHSEED`` (as a string) and passed to
            the Python, NumPy and PyTorch (CPU and CUDA) seeding functions.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Same seeding calls, in the same order, for Python, NumPy and PyTorch
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic kernels on, autotuner off
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def _run_eval(model, val_dataloader) -> np.ndarray:
    """Sweep binarisation thresholds, then score post-processing variants at the best one.

    Pass 1 runs the model on every batch and accumulates the ULS score of the
    sigmoid foreground prediction at nine thresholds (0.1 ... 0.9). The
    threshold with the highest mean score is selected.

    Pass 2 runs the model again and, at that threshold, accumulates the ULS
    score of the raw mask and of five post-processed variants: dilation,
    double dilation, opening, closing and shift-based test-time augmentation.
    The mean scores are printed.

    Scores come from the module-level ``evaluator``.

    Args:
        model: object exposing ``predict_logits_from_preprocessed_data`` (and
            ``device``, which ``test_time_shift`` reads).
        val_dataloader: iterable yielding ``(data, labels)`` batches. It is
            iterated twice. NOTE(review): the tensor plumbing below appears to
            assume batch size 1 -- confirm against the caller.

    Returns:
        np.ndarray with the mean ULS score for each of the nine thresholds.

    Raises:
        ValueError: if ``val_dataloader`` yields no batches.
    """
    print("[info] -- Starting model evaluation")

    thresholds = np.linspace(0.1, 0.9, 9)
    threshold_totals = [0.0] * len(thresholds)

    # ---- Pass 1: mean ULS score per threshold --------------------------------
    num_batches = 0
    with torch.no_grad():
        for data, labels in tqdm(val_dataloader):
            # Drop the first (background) channel and squash logits to (0, 1)
            logits = model.predict_logits_from_preprocessed_data(data).unsqueeze(0)[:, 1:, ...]
            predicted = torch.sigmoid(logits)
            y_true = labels.unsqueeze(0)

            for i, threshold in enumerate(thresholds):
                threshold_totals[i] += evaluator.ULS_score_metric(predicted > threshold, y_true)
            num_batches += 1

    if num_batches == 0:
        # Without this guard the averaging below would divide by zero.
        raise ValueError("val_dataloader yielded no batches; nothing to evaluate")

    ULS_per_threshold = np.array(threshold_totals) / float(num_batches)
    best_threshold = thresholds[np.argmax(ULS_per_threshold)]

    print(ULS_per_threshold, best_threshold)

    # ---- Pass 2: post-processing variants at the best threshold --------------
    morphology_variants = (
        ("Dilated", test_time_dilation),
        ("2x Dilated", test_time_2xdilation),
        ("Opening", test_time_opening),
        ("Closing", test_time_closing),
    )
    # Insertion order is the order in which the scores are printed.
    totals = {"Original ULS": 0.0}
    totals.update((label, 0.0) for label, _ in morphology_variants)
    totals["Shifted"] = 0.0

    def score_numpy_mask(mask, y_true):
        """Score a NumPy mask after restoring the two leading singleton axes."""
        as_tensor = torch.tensor(mask).unsqueeze(0).unsqueeze(0).to(y_true.device)
        return evaluator.ULS_score_metric(as_tensor, y_true)

    num_batches = 0
    with torch.no_grad():
        for data, labels in tqdm(val_dataloader):
            logits = model.predict_logits_from_preprocessed_data(data).unsqueeze(0)[:, 1:, ...]
            y_pred = torch.sigmoid(logits) > best_threshold
            y_true = labels.unsqueeze(0)

            totals["Original ULS"] += evaluator.ULS_score_metric(y_pred, y_true)

            # Morphological post-processing of the thresholded mask
            for label, postprocess in morphology_variants:
                totals[label] += score_numpy_mask(postprocess(y_pred.float()), y_true)

            # Shift-based test-time augmentation re-runs the model on the raw input
            shifted_pred = test_time_shift(model, data, threshold=best_threshold)
            totals["Shifted"] += score_numpy_mask(shifted_pred, y_true)

            num_batches += 1

    # Average across batches and print one coloured line
    divisor = float(num_batches)
    report_parts = []
    for label, total in totals.items():
        colour = "green" if label == "Original ULS" else "cyan"
        report_parts.append(f"{label}: {colored(f'{total / divisor:.5f}', colour)}")
    print(" | ".join(report_parts))

    return ULS_per_threshold



In [78]:
# NOTE(review): `predictor` is only assigned in the following cell (predictor = load_model()),
# so on a fresh kernel this cell raises NameError unless that cell has been run first.
predictor.device.type

'cuda'

In [79]:
# Fix the Python, NumPy and PyTorch RNG seeds (see seed_everything) before the model and data are loaded
seed_everything(42)


# Machine-specific location of the trained nnU-Net model folder; pass
# load_model(model_folder=...) to use a different one.
DEFAULT_MODEL_FOLDER = r"C:\Users\Lazar\OneDrive\Desktop\RU Courses\AI in Medical Imaging\project\aimi-project\SegFormer3D-main\data\local_data\nnUNetTrainer_ULS_400_QuarterLR__nnUNetResEncUNetMPlans__3d_fullres"


def load_model(
    model_folder=DEFAULT_MODEL_FOLDER,
    use_folds=(0,),
    checkpoint_name="checkpoint_best.pth",
    device=None,
):
    """Build an nnUNetPredictor and load a trained checkpoint into it.

    Called without arguments it behaves as before: fold 0 of
    ``DEFAULT_MODEL_FOLDER``, checkpoint ``checkpoint_best.pth``, on ``cuda:0``.

    Args:
        model_folder: trained-model folder handed to
            ``initialize_from_trained_model_folder``.
        use_folds: folds whose checkpoints are loaded.
        checkpoint_name: checkpoint file name inside each fold folder.
        device: ``torch.device`` to run on; ``None`` selects ``cuda:0``.

    Returns:
        The initialised ``nnUNetPredictor``.
    """
    if device is None:
        device = torch.device(type='cuda', index=0)

    start_model_load_time = time.time()

    # Set up the nnUNetPredictor
    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=False,  # False is faster but less accurate
        device=device,
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=False
    )
    # Initialize the network architecture and load the checkpoint
    predictor.initialize_from_trained_model_folder(
        model_folder,
        use_folds=list(use_folds),
        checkpoint_name=checkpoint_name,
    )
    end_model_load_time = time.time()
    print(f"Model loading runtime: {end_model_load_time - start_model_load_time}s")
    return predictor


predictor = load_model()


# NOTE(review): machine-specific absolute paths; point these at the local copy of the validation data.
images_dir = r"C:\Users\Lazar\OneDrive\Desktop\RU Courses\AI in Medical Imaging\project\aimi-project\SegFormer3D-main\data\local_data\nnunet_data\imagesVal"
labels_dir = r"C:\Users\Lazar\OneDrive\Desktop\RU Courses\AI in Medical Imaging\project\aimi-project\SegFormer3D-main\data\local_data\nnunet_data\labelsVal"

dataset = NiiSliceDataset(images_dir, labels_dir)
# _run_eval only averages per-case scores, so shuffling adds nothing; a fixed
# order keeps the case sequence identical between runs.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

print("[info] -- Running evaluation only.")
ULS_per_threshold = _run_eval(predictor, dataloader)

Model loading runtime: 3.185537099838257s
[info] -- Running evaluation only.
[info] -- Starting model evaluation


100%|██████████| 1155/1155 [55:00<00:00,  2.86s/it] 


[0.54725477 0.55873257 0.56346872 0.56544667 0.56639801 0.56686143
 0.56664409 0.56499098 0.55830482] 0.6


100%|██████████| 1155/1155 [1:27:55<00:00,  4.57s/it] 

Original ULS: [32m0.56686[0m | Dilated: [36m0.44240[0m | 2x Dilated: [36m0.33654[0m | Opening: [36m0.51330[0m | Closing: [36m0.56719[0m | Shifted: [36m0.39253[0m





In [80]:
# Mean ULS score at each of the nine thresholds (0.1 ... 0.9), as returned by _run_eval
print(ULS_per_threshold)

[0.54725477 0.55873257 0.56346872 0.56544667 0.56639801 0.56686143
 0.56664409 0.56499098 0.55830482]


In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Thresholds on the x-axis: the same nine values that _run_eval sweeps
thresholds = np.linspace(0.1, 0.9, 9)

# Hard-coded SegFormer3D scores per threshold (their provenance is not shown in this notebook)
segformer_scores = [0.55682154, 0.55853287, 0.5591937,  0.55975669, 0.56031297, 0.56064882, 0.56097877, 0.5611234, 0.56077012]

# Seaborn look for the figure
sns.set_theme(style="whitegrid", context="talk")

# One figure, one axes; both curves are drawn onto the same axes
fig, ax = plt.subplots(figsize=(8, 5))
sns.lineplot(x=thresholds, y=segformer_scores,
             marker="o", linewidth=2.5, color="#F95454", label="SegFormer3D", ax=ax)
sns.lineplot(x=thresholds, y=ULS_per_threshold,
             marker="o", linewidth=2.5, color="#77CDFF", label="nnUNetv2", ax=ax)

# Axis labels, title and ticks
ax.set_xlabel("Threshold", fontsize=14)
ax.set_ylabel("ULS Score", fontsize=14)
ax.set_title("ULS Score vs. Threshold", fontsize=16)
ax.set_xticks(thresholds)
plt.setp(ax.get_xticklabels(), rotation=45)
ax.grid(True)
ax.legend()

# Write the figure to disk
fig.tight_layout()
fig.savefig("./threshold_optimization.pdf")