In [28]:
import torch
from system import System
from mamba_mic.data_modules.pi_caiv2 import PICAIV2DataModule
from monai.inferers import sliding_window_inference
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.transforms import (
    Compose,
    Activations,
    AsDiscreted,
    Invertd,
    KeepLargestConnectedComponentd,
    Lambdad,
    MapTransform,
)
import monai.transforms as T
from lightning.pytorch import seed_everything
from tqdm import tqdm
import re
import nibabel as nib
import numpy as np
import os


seed_everything(42)

[rank: 0] Seed set to 42


42

In [29]:
run_id = "8d22yg6m"
checkpoint = "model-epoch=224-val_dice=0.52"
checkpoint_path = f"lightning_logs/{run_id}/checkpoints/{checkpoint}.ckpt"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = System.load_from_checkpoint(checkpoint_path=checkpoint_path)
model.eval()
model.to(device)

Attribute 'net' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['net'])`.


System(
  (net): DynUNet(
    (input_block): UnetBasicBlock(
      (conv1): Convolution(
        (conv): Conv3d(3, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
      )
      (conv2): Convolution(
        (conv): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
      )
      (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
      (norm1): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (norm2): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    )
    (downsamples): ModuleList(
      (0): UnetBasicBlock(
        (conv1): Convolution(
          (conv): Conv3d(32, 64, kernel_size=(3, 3, 1), stride=(2, 2, 1), padding=(1, 1, 0), bias=False)
        )
        (conv2): Convolution(
          (conv): Conv3d(64, 64, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0), bias=False)
        )
        (lrelu): LeakyReLU(negative_slope=0.01, in

In [30]:
data_module = PICAIV2DataModule(batch_size=1, include_empty_eval=True)
data_module.prepare_data()
data_module.setup()
val_set = data_module.val_set
test_set = data_module.test_set

Number of images: 1499
Training subjects: 1200
Validation subjects: 150
Test subjects: 149


In [31]:
class ConvertToBinaryLabeld(T.MapTransform):
    def __init__(self, keys: list, invertible=True, allow_missing_keys=True):
        self.keys = keys
        self.invertible = invertible
        self.allow_missing_keys = allow_missing_keys

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            if key in data:
                label = d[key]  # Extract label tensor

                if self.invertible:
                    # Store the original label tensor for later inversion
                    d[f"original_{key}"] = label.clone()

                # Convert to binary: 0 for ISUP ≤1, 1 for ISUP ≥2
                d[key] = (label >= 1).float()

        return d


def post_transforms(val_data):
    transform = Compose(
        [
            Invertd(
                keys="pred",
                transform=data_module.preprocess,
                orig_keys="label",
                meta_keys="pred_meta_dict",
                orig_meta_keys="image_meta_dict",
                meta_key_postfix="meta_dict",
                nearest_interp=True,
                to_tensor=True,
                device="cpu",
            ),
            Invertd(
                keys="label",
                transform=val_set.transform,
                orig_keys="label",
                meta_keys="pred_meta_dict",
                orig_meta_keys="image_meta_dict",
                meta_key_postfix="meta_dict",
                nearest_interp=False,
                to_tensor=True,
                device="cpu",
            ),
            ConvertToBinaryLabeld(keys=["label"]),
        ]
    )

    # Apply transformations
    val_data = transform(val_data)
    sigmoid = Activations(sigmoid=True)
    val_data["pred"] = sigmoid(val_data["pred"])
    return val_data

In [32]:
def visualize_segmentation(
    t2w_img, ground_truth, prediction, title="Segmentation Visualization"
):
    """
    Function to visualize ground truth and predicted segmentation overlays on the T2W image.

    Parameters:
    - t2w_img (numpy array): The grayscale T2-weighted image.
    - ground_truth (numpy array): Ground truth segmentation mask.
    - prediction (numpy array): Predicted segmentation mask.
    - title (str): Title for the visualization.
    """
    # Find the slice with the most segmentation in the ground truth
    label_slices = np.sum(ground_truth, axis=(0, 1))  # Sum over H, W
    slice_idx = np.argmax(label_slices)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    fig.suptitle(title, fontsize=16)

    axes[0].set_title("Ground Truth")
    if t2w_img is not None:
        axes[0].imshow(t2w_img[:, :, slice_idx], cmap="gray")
        axes[0].imshow(ground_truth[:, :, slice_idx], cmap="Reds", alpha=0.8)
        axes[0].axis("off")
    else:
        axes[0].imshow(ground_truth[:, :, slice_idx], cmap="Reds")
        axes[0].axis("off")

    axes[1].set_title("Prediction")
    if t2w_img is not None:
        axes[1].imshow(t2w_img[:, :, slice_idx], cmap="gray")
        confidence_map = axes[1].imshow(
            prediction[:, :, slice_idx], cmap="coolwarm", vmin=0, vmax=1, alpha=0.6
        )
        axes[1].axis("off")
    else:
        confidence_map = axes[1].imshow(
            prediction[:, :, slice_idx], cmap="coolwarm", vmin=0, vmax=1
        )
        axes[1].axis("off")

    fig.colorbar(
        confidence_map, ax=axes[1], fraction=0.046, pad=0.04, label="Confidence (0-1)"
    )

    plt.subplots_adjust(top=0.85)
    plt.show()

In [33]:
import matplotlib.pyplot as pl

show_seg = False


model.eval()
with torch.no_grad():
    for val_data in tqdm(val_set):
        path = val_data["label"].meta["filename_or_obj"]
        filename = os.path.basename(path)  # Get filename from full path
        case_id = "_".join(filename.split("_")[:2])

        case_label_dir = f"./data/PICCAIv2/labels/val/{case_id}"
        case_pred_dir = f"./data/PICCAIv2/predictions/val/{case_id}"

        # Define paths to label and prediction files
        label_path = f"{case_label_dir}/{run_id}_{checkpoint}_val.nii.gz"
        pred_path = f"{case_pred_dir}/{run_id}_{checkpoint}_val.nii.gz"

        # Check if the files already exist to avoid reprocessing
        # if os.path.exists(label_path) and os.path.exists(pred_path):
        #    print(f"Skipping case {case_id} as predictions already exist.")
        #    continue

        x, y = (
            val_data["image"].to(device).unsqueeze(0),
            val_data["label"].to(device).unsqueeze(0),
        )
        val_data["pred"] = sliding_window_inference(
            x, roi_size=[256, 256, 32], overlap=0.5, sw_batch_size=3, predictor=model
        ).squeeze(0)

        postprocessed = post_transforms(val_data)

        y_pred = torch.tensor(postprocessed["pred"])
        y = torch.tensor(postprocessed["label"])

        if show_seg:
            y_pred_np, y_true_np = postprocessed["pred"], postprocessed["label"]

            y_true_np = y_true_np.squeeze()
            y_pred_np = y_pred_np.squeeze()

            visualize_segmentation(
                None, y_true_np, y_pred_np, title="Post-Processed Segmentation"
            )

        os.makedirs(case_label_dir, exist_ok=True)
        os.makedirs(case_pred_dir, exist_ok=True)

        nib.save(
            nib.Nifti1Image(
                y_pred.type(torch.float).numpy(),
                affine=val_data["image"].meta["original_affine"],
            ),
            f"{case_pred_dir}/{run_id}_{checkpoint}_val.nii.gz",
        )

        # TODO doesnt need to be saved for each model, add check
        nib.save(
            nib.Nifti1Image(
                y.type(torch.float).numpy(),
                affine=val_data["image"].meta["original_affine"],
            ),
            f"{case_label_dir}/{run_id}_{checkpoint}_val.nii.gz",
        )


  0%|          | 0/150 [00:00<?, ?it/s]

To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
100%|██████████| 150/150 [5:17:59<00:00, 127.20s/it]  


In [34]:
import matplotlib.pyplot as pl

show_seg = False

model.eval()
with torch.no_grad():
    for test_data in tqdm(test_set):
        path = test_data["label"].meta["filename_or_obj"]
        filename = os.path.basename(path)  # Get filename from full path
        case_id = "_".join(filename.split("_")[:2])

        case_label_dir = f"./data/PICCAIv2/labels/test/{case_id}"
        case_pred_dir = f"./data/PICCAIv2/predictions/test/{case_id}"

        # Define paths to label and prediction files
        label_path = f"{case_label_dir}/{run_id}_{checkpoint}_test.nii.gz"
        pred_path = f"{case_pred_dir}/{run_id}_{checkpoint}_test.nii.gz"

        # Check if the files already exist to avoid reprocessing
        # if os.path.exists(label_path) and os.path.exists(pred_path):
        #    print(f"Skipping case {case_id} as predictions already exist.")
        #    continue

        x, y = (
            test_data["image"].to(device).unsqueeze(0),
            test_data["label"].to(device).unsqueeze(0),
        )
        test_data["pred"] = sliding_window_inference(
            x, roi_size=[256, 256, 32], overlap=0.5, sw_batch_size=3, predictor=model
        ).squeeze(0)

        postprocessed = post_transforms(test_data)
        y_pred = torch.tensor(postprocessed["pred"])
        y = torch.tensor(postprocessed["label"])

        if show_seg:
            y_pred_np, y_true_np = postprocessed["pred"], postprocessed["label"]

            y_true_np = y_true_np.squeeze()
            y_pred_np = y_pred_np.squeeze()

            visualize_segmentation(
                None, y_true_np, y_pred_np, title="Post-Processed Segmentation"
            )

        os.makedirs(case_label_dir, exist_ok=True)
        os.makedirs(case_pred_dir, exist_ok=True)

        nib.save(
            nib.Nifti1Image(
                y_pred.type(torch.float).numpy(),
                affine=test_data["image"].meta["original_affine"],
            ),
            f"{case_pred_dir}/{run_id}_{checkpoint}_test.nii.gz",
        )

        nib.save(
            nib.Nifti1Image(
                y.type(torch.float).numpy(),
                affine=test_data["image"].meta["original_affine"],
            ),
            f"{case_label_dir}/{run_id}_{checkpoint}_test.nii.gz",
        )


  0%|          | 0/10 [00:00<?, ?it/s]To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
100%|██████████| 10/10 [20:58<00:00, 125.86s/it]
