In [4]:
# --- Standard Library ---
import os
import sys
import random
import yaml

# --- Third-Party Libraries ---
import numpy as np
import torch
from tqdm import tqdm
import SimpleITK as sitk
from matplotlib import pyplot as plt
from termcolor import colored

# project_root = '/d/hpc/projects/FRI/jf73497/aimi-project/src/segformer3duls/'
# if project_root not in sys.path:
#     sys.path.insert(0, project_root)

# from metrics.competition_metric import ULS23_evaluator

from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
import nnunetv2

In [5]:
image_path = "./MIX_06966_0000.nii.gz"
label_path = "./MIX_06966.nii.gz"


# Read the full image volume (not a single slice) as float32.
image_itk = sitk.ReadImage(image_path)
image_raw = sitk.GetArrayFromImage(image_itk).astype(np.float32)
# SimpleITK reports spacing in (x, y, z) order, while GetArrayFromImage returns
# the voxels in (z, y, x) order. nnU-Net expects the 'spacing' entry to follow
# the array's axis order (its own SimpleITK reader reverses the spacing), so
# reverse it here before it is handed to the predictor.
image_spacings = image_itk.GetSpacing()[::-1]
# Read the full label volume as int64.
label_itk = sitk.ReadImage(label_path)
label = sitk.GetArrayFromImage(label_itk).astype(np.int64)


In [6]:
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

def _extract_border(binary_slice):
    """Return the border pixels of a 2D mask as an array of the same shape/dtype.

    A pixel is marked 1 when its own value is 1 and the 8-neighbour Laplacian
    response (sum of the eight neighbours minus 8x the centre) is non-zero.
    For a 0/1 mask this means at least one of its eight neighbours is 0.
    Positions outside the slice count as 0, so mask pixels touching the image
    edge are always border pixels.
    """
    rows, cols = binary_slice.shape
    # Widen before summing so small integer dtypes (e.g. uint8) cannot wrap.
    padded = np.pad(binary_slice, pad_width=1, mode="constant", constant_values=0).astype(np.int64)

    # Sum of every 3x3 window, built from nine shifted views instead of a
    # per-pixel Python loop.
    window_sum = np.zeros((rows, cols), dtype=np.int64)
    for d_row in range(3):
        for d_col in range(3):
            window_sum += padded[d_row:d_row + rows, d_col:d_col + cols]

    # neighbours - 8 * centre == window_sum - 9 * centre
    laplacian = window_sum - 9 * padded[1:-1, 1:-1]
    is_border = (binary_slice == 1) & (laplacian != 0)
    return is_border.astype(binary_slice.dtype)


def visualize_middle_slice_with_border(voi, prediction, label, pred_os,
                                       save_path='./nnunet_lesion_viz.pdf'):
    """
    Plot the middle axis-0 slice of a volume with border-only mask overlays.

    The label border, the prediction border and the `pred_os` border are drawn
    (in that order) as scatter points on top of the grayscale slice, and the
    figure is written to `save_path`.

    Args:
        voi: 3D image volume, e.g. shape (64, 128, 128). NumPy array or anything
            `np.asarray` can convert.
        prediction: Mask of the same shape as `voi`; cast to uint8.
        label: Ground-truth mask of the same shape as `voi`; cast to uint8.
        pred_os: Second prediction mask of the same shape as `voi`; cast to uint8.
        save_path: Where to save the figure. Defaults to the original hard-coded
            './nnunet_lesion_viz.pdf'; pass None to skip saving.

    Raises:
        AssertionError: If the inputs are not 3D or their shapes differ.
    """
    voi_np = np.asarray(voi)
    pred_np = np.asarray(prediction)
    label_np = np.asarray(label)
    pred_os_np = np.asarray(pred_os)

    assert voi_np.ndim == 3, f"Expected 3D volumes, got shape {voi_np.shape}"
    assert voi_np.shape == pred_np.shape == label_np.shape == pred_os_np.shape, (
        "voi, prediction, label and pred_os must all have the same shape"
    )

    # Middle slice along the first axis.
    mid_slice = voi_np.shape[0] // 2
    base_slice = voi_np[mid_slice]
    pred_border = _extract_border(pred_np[mid_slice].astype(np.uint8))
    label_border = _extract_border(label_np[mid_slice].astype(np.uint8))
    pred_os_border = _extract_border(pred_os_np[mid_slice].astype(np.uint8))

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(base_slice, cmap="gray")

    # Later overlays are drawn on top of earlier ones:
    # label (red), prediction (blue), OS prediction (orange).
    overlays = (
        (label_border, '#FF3F33', 'Label Border'),
        (pred_border, '#3D3BF3', 'Prediction Border'),
        (pred_os_border, '#FF9B17', 'OS Prediction Border'),
    )
    for border, color, name in overlays:
        ys, xs = np.where(border == 1)
        ax.scatter(xs, ys, c=color, s=2, label=name)

    # Explicit legend handles replace the scatter labels above.
    # NOTE(review): the legend entries say "SegFormer3D" while the title and the
    # default file name say nnUNet — confirm which model these overlays show.
    legend_elements = [
        Patch(facecolor='none', edgecolor='#FF3F33', label='Label Border'),
        Patch(facecolor='none', edgecolor='#3D3BF3', label='SegFormer3D Prediction Border'),
        Patch(facecolor='none', edgecolor='#FF9B17', label='SegFormer3D_OS Prediction Border'),
    ]
    ax.legend(handles=legend_elements, loc='lower right')
    ax.set_title("nnUNetv2 Middle Z Slice Predictions")
    ax.axis("off")
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path)


In [7]:


# evaluator = ULS23_evaluator()

##################################################################################################
def seed_everything(sedd) -> None:
    """Seed Python, NumPy and PyTorch RNGs and set the cuDNN reproducibility flags.

    Args:
        sedd: Seed value. It is stored as a string in ``PYTHONHASHSEED`` and
            passed unchanged to each seeding function.
    """
    os.environ["PYTHONHASHSEED"] = str(sedd)

    # Same order as before: stdlib, NumPy, torch CPU, torch current GPU, all GPUs.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(sedd)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def _run_eval(model, data, image_spacings) -> None:
    """_summary_"""
    # Tell wandb to watch the model and optimizer values

    print("[info] -- Starting model evaluation")

    predicted = model.predict_single_npy_array(data, {'spacing': image_spacings})
    # print(logits.shape)
    # predicted = torch.sigmoid(logits)

    return predicted

In [8]:
# Shape of the image array as loaded.
print(image_raw.shape)

# Prepend a length-1 axis for the predictor input (a view, no copy).
model_input_image_raw = np.expand_dims(image_raw, axis=0)
print(model_input_image_raw.shape)

(64, 128, 128)
(1, 64, 128, 128)


In [9]:

# Seed all RNGs before the predictor is built and inference is run in this cell.
seed_everything(42)


# Default location of the trained nnU-Net model folder (unchanged from the
# original hard-coded value). It is a machine-specific absolute Windows path;
# pass `model_folder=` to load_model() when running anywhere else.
DEFAULT_MODEL_FOLDER = r"C:\Users\Lazar\OneDrive\Desktop\RU Courses\AI in Medical Imaging\project\aimi-project\SegFormer3D-main\data\local_data\nnUNetTrainer_ULS_400_QuarterLR__nnUNetResEncUNetMPlans__3d_fullres"


def load_model(model_folder=DEFAULT_MODEL_FOLDER, use_folds=(0,),
               checkpoint_name="checkpoint_best.pth", device=None):
    """Build an nnUNetPredictor and initialise it from a trained model folder.

    Called without arguments it uses the same folder, fold and checkpoint as
    before; the parameters only make those choices overridable.

    Args:
        model_folder: Trained model folder passed to
            ``initialize_from_trained_model_folder``.
        use_folds: Folds to load; forwarded as a list.
        checkpoint_name: Checkpoint file name inside each fold directory.
        device: ``torch.device`` to run on. ``None`` selects ``cuda:0`` when
            CUDA is available and falls back to the CPU otherwise.

    Returns:
        The initialised ``nnUNetPredictor``.
    """
    if device is None:
        if torch.cuda.is_available():
            device = torch.device(type='cuda', index=0)
        else:
            # Previously cuda:0 was requested unconditionally.
            print("[info] -- CUDA not available, running nnU-Net inference on CPU.")
            device = torch.device('cpu')

    # Set up the nnUNetPredictor
    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=False,  # False is faster but less accurate
        device=device,
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=False
    )
    # Initialize the network architecture, loads the checkpoint
    predictor.initialize_from_trained_model_folder(
        model_folder,
        use_folds=list(use_folds),
        checkpoint_name=checkpoint_name,
    )
    return predictor

# Build the predictor with load_model's defaults.
predictor = load_model()


# Run inference on the volume with the leading length-1 axis, passing the
# spacing read from the image header.
print("[info] -- Running evaluation only.")
prediction_raw = _run_eval(predictor, model_input_image_raw, image_spacings)

[info] -- Running evaluation only.
[info] -- Starting model evaluation


In [12]:

# Threshold the model output elementwise at 0.5 to obtain the mask used below.
# NOTE(review): if the predictor already returns an integer label map, this is
# the same as `> 0` for labels {0, 1} — confirm what prediction_raw contains.
prediction_mask_raw = prediction_raw > 0.5
print(prediction_mask_raw.shape)

(64, 128, 128)


In [15]:
# Quick sanity check: save an image of the middle axis-0 slice of the mask.
print(prediction_mask_raw.shape)
# Derive the index from the volume instead of hard-coding 32
# (identical for the 64-slice volume used here).
mid_slice_idx = prediction_mask_raw.shape[0] // 2
fig, ax = plt.subplots()
ax.imshow(prediction_mask_raw[mid_slice_idx])
fig.savefig('./afaafa.pdf')


(64, 128, 128)


In [60]:
# Print the shapes of the image, predicted mask and label volumes.
print(image_raw.shape, prediction_mask_raw.shape, label.shape)


# Aliases handed to the plotting helper.
image_viz = image_raw
# NOTE(review): preD_viz and preD_viz_raw are bound to the same mask; preD_viz
# is only used in the shape print below.
preD_viz = prediction_mask_raw
preD_viz_raw = prediction_mask_raw
label_viz = label
print(image_viz.shape, preD_viz.shape, label_viz.shape)
# NOTE(review): the same mask is passed as both `prediction` and `pred_os`, so
# the two prediction overlays coincide. Presumably a separate OS prediction was
# meant for the last argument — confirm.
visualize_middle_slice_with_border(image_viz, preD_viz_raw, label_viz, preD_viz_raw)

(64, 128, 128) (64, 128, 128) (64, 128, 128)
(64, 128, 128) (64, 128, 128) (64, 128, 128)


In [None]:
# Show the kernel's working directory; the relative paths used in this notebook
# ("./MIX_06966_0000.nii.gz", the saved PDFs, ...) resolve against it.
os.getcwd()

Error in callback <bound method _WandbInit._resume_backend of <wandb.sdk.wandb_init._WandbInit object at 0x1467d80cfd00>> (for pre_run_cell), with arguments args (<ExecutionInfo object at 1467d80763e0, raw_cell="os.getcwd()" store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2Bhpc-login1.arnes.si/d/hpc/home/jf73497/projects/aimi-project/src/prediction_label_viz.ipynb#X14sdnNjb2RlLXJlbW90ZQ%3D%3D>,),kwargs {}:


TypeError: _WandbInit._resume_backend() takes 1 positional argument but 2 were given

'/d/hpc/projects/FRI/jf73497/aimi-project'

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x1467d80cfd00>> (for post_run_cell), with arguments args (<ExecutionResult object at 1467d8076d10, execution_count=127 error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 1467d80763e0, raw_cell="os.getcwd()" store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2Bhpc-login1.arnes.si/d/hpc/home/jf73497/projects/aimi-project/src/prediction_label_viz.ipynb#X14sdnNjb2RlLXJlbW90ZQ%3D%3D> result='/d/hpc/projects/FRI/jf73497/aimi-project'>,),kwargs {}:


TypeError: _WandbInit._pause_backend() takes 1 positional argument but 2 were given

In [16]:
# Persist the thresholded mask with torch.save under a fixed file name
# (a relative path, so it lands in the current working directory).
torch.save(prediction_mask_raw, "pred_non_os_nnunet.pt")