In [79]:
import os

import matplotlib.pyplot as plt
import monai.transforms as T
import nibabel as nib
import numpy as np
import torch
from lightning.pytorch import seed_everything
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.transforms import Compose, Activations, AsDiscreted, Invertd, KeepLargestConnectedComponentd, Lambdad, MapTransform
from picai_eval import evaluate
from scipy.ndimage import find_objects, label
from tqdm import tqdm

# Fix the global random seed through Lightning's helper so that any stochastic
# step executed later in this notebook is repeatable from a fresh kernel.
seed_everything(42)

[rank: 0] Seed set to 42


42

In [80]:
class ProcessPredictions(T.MapTransform):
    """Turn a voxel-wise prediction map into a lesion-wise confidence map.

    For every key, the prediction is thresholded, split into connected
    components (26-connectivity), and each component is filled with the
    median of the original prediction values it covers. Voxels outside any
    component are set to 0. The result therefore contains non-overlapping
    lesion candidates, each carrying one confidence value.

    Args:
        keys: Dictionary keys whose values should be processed.
        threshold: Voxels with a value strictly greater than this are
            considered part of a lesion candidate.
    """

    def __init__(self, keys, threshold=0.1):
        super().__init__(keys)
        self.threshold = threshold

    def __call__(self, data):
        """Return a shallow copy of ``data`` with every configured key processed.

        The caller's mapping is left untouched, following the usual MONAI
        dictionary-transform convention.
        """
        d = dict(data)
        for key in self.keys:
            d[key] = self.process_predictions(d[key])
        return d

    def process_predictions(self, pred):
        """Collapse a prediction volume to one median confidence per lesion.

        Args:
            pred: Prediction volume as a NumPy array (or anything
                ``np.asarray`` accepts) or a ``torch.Tensor``. Inputs with
                more than three dimensions are squeezed; the connected
                component labelling uses a 3x3x3 structuring element, so the
                array must be 3D once singleton axes are removed.

        Returns:
            ``np.float32`` array with a new leading axis of size 1 in front
            of the (squeezed) input shape.
        """
        # Tensors have no ``.astype``; move them to host memory as NumPy first.
        if isinstance(pred, torch.Tensor):
            pred = pred.detach().cpu().numpy()
        else:
            pred = np.asarray(pred)

        # Drop singleton axes (e.g. batch/channel) so the volume is 3D.
        if pred.ndim > 3:
            pred = np.squeeze(pred)

        # Binarize, then label connected components using 26-connectivity.
        binary_pred = (pred > self.threshold).astype(np.uint8)
        labeled_pred, _ = label(binary_pred, structure=np.ones((3, 3, 3)))

        processed_pred = np.zeros_like(pred, dtype=np.float32)

        # Visit each lesion inside its bounding box only; this yields the same
        # voxels as a full-volume mask but avoids scanning the whole volume
        # once per component.
        for lesion_id, bbox in enumerate(find_objects(labeled_pred), start=1):
            lesion_mask = labeled_pred[bbox] == lesion_id
            lesion_median = np.median(pred[bbox][lesion_mask])
            # ``processed_pred[bbox]`` is a view, so this writes into the output.
            processed_pred[bbox][lesion_mask] = lesion_median

        # Prepend a size-1 axis so the output keeps a leading (channel-like) dimension.
        return np.expand_dims(processed_pred, axis=0)



In [81]:
# Root folders; each is expected to contain one sub-directory per case.
ground_truth_dir = "./data/PICCAIv2/labels/val"
predictions_dir = "./data/PICCAIv2/predictions/val"

# Prediction files are selected by the prefix "<run_id>_<checkpoint>".
run_id = '8d22yg6m'
checkpoint = 'model-epoch=224-val_dice=0.52'


def _list_nifti(directory, prefix=""):
    """Return the sorted ``.nii.gz`` file names in ``directory`` starting with ``prefix``.

    A missing path, or a path that is not a directory, yields an empty list so
    the "exactly one file" check below reports the case as skipped instead of
    ``os.listdir`` raising.
    """
    if not os.path.isdir(directory):
        return []
    return sorted(
        f for f in os.listdir(directory)
        if f.endswith(".nii.gz") and f.startswith(prefix)
    )


# Get sorted case IDs
case_ids = sorted(os.listdir(ground_truth_dir))

# Parallel lists: index i of both lists refers to the same case.
ground_truths = []
processed_predictions = []

processor = ProcessPredictions(keys=["pred"])

for case_id in tqdm(case_ids, desc="Processing Predictions"):

    case_label_dir = os.path.join(ground_truth_dir, case_id)
    case_pred_dir = os.path.join(predictions_dir, case_id)

    # Find the label file and the prediction file written by the selected run/checkpoint
    gt_files = _list_nifti(case_label_dir)
    pred_files = _list_nifti(case_pred_dir, prefix=f"{run_id}_{checkpoint}")

    # Ensure each case has exactly one label and one prediction file
    if len(gt_files) != 1 or len(pred_files) != 1:
        print(f"Skipping case {case_id} due to missing or multiple files.")
        continue

    gt_path = os.path.join(case_label_dir, gt_files[0])
    pred_path = os.path.join(case_pred_dir, pred_files[0])

    # Load ground truth; the cast only changes the dtype to uint8, it does not
    # binarize (assumes the stored labels are already integer-valued - TODO confirm).
    gt_nifti = nib.load(gt_path)
    ground_truth = gt_nifti.get_fdata().astype(np.uint8)

    # Load prediction as float32
    pred_nifti = nib.load(pred_path)
    pred = pred_nifti.get_fdata().astype(np.float32)

    # Apply lesion-wise post-processing
    processed_pred = processor({"pred": pred})["pred"]

    # Append both only after everything succeeded so the two lists can never
    # get out of step if loading or processing raises for a case.
    ground_truths.append(ground_truth.squeeze())
    processed_predictions.append(processed_pred.squeeze())

print(f"Loaded {len(ground_truths)} ground truth masks and {len(processed_predictions)} processed predictions.")

Processing Predictions:   0%|          | 0/150 [00:00<?, ?it/s]

Processing Predictions: 100%|██████████| 150/150 [01:20<00:00,  1.87it/s]

Loaded 150 ground truth masks and 150 processed predictions.





In [82]:
def visualize_segmentation(t2w_img, ground_truth, prediction, title="Segmentation Visualization"):
    """
    Show ground truth and predicted segmentation side by side for one axial slice.

    The displayed slice is the one (along the last axis) with the largest
    ground-truth sum. If the ground truth is empty, the slice with the largest
    prediction sum is used instead, so predictions on negative cases are still
    visible; if both are empty, slice 0 is shown.

    Parameters:
    - t2w_img (numpy array or None): Grayscale T2-weighted image used as the
      background of both panels. When None, the masks are drawn on their own.
    - ground_truth (numpy array): Ground truth segmentation mask, indexed as [:, :, slice].
    - prediction (numpy array): Predicted confidence map, indexed as [:, :, slice];
      colours are scaled to the fixed range 0-1.
    - title (str): Title for the visualization.
    """
    # Pick the most informative slice: most ground truth, else most prediction.
    slice_scores = np.sum(ground_truth, axis=(0, 1))  # Sum over H, W
    if not np.any(slice_scores):
        slice_scores = np.sum(prediction, axis=(0, 1))
    slice_idx = int(np.argmax(slice_scores))

    has_background = t2w_img is not None
    # Overlays are semi-transparent only when there is an image underneath.
    gt_alpha = 0.8 if has_background else None
    pred_alpha = 0.6 if has_background else None

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    fig.suptitle(title, fontsize=16)

    for ax, panel_title in zip(axes, ("Ground Truth", "Prediction")):
        ax.set_title(panel_title)
        if has_background:
            ax.imshow(t2w_img[:, :, slice_idx], cmap="gray")

    axes[0].imshow(ground_truth[:, :, slice_idx], cmap="Reds", alpha=gt_alpha)
    confidence_map = axes[1].imshow(
        prediction[:, :, slice_idx], cmap="coolwarm", vmin=0, vmax=1, alpha=pred_alpha
    )

    for ax in axes:
        ax.axis("off")

    fig.colorbar(confidence_map, ax=axes[1], fraction=0.046, pad=0.04, label="Confidence (0-1)")

    plt.subplots_adjust(top=0.85)  # Leave room for the figure title
    plt.show()

In [83]:
# Score the lesion-wise confidence maps against the reference masks.
# Both lists were filled in the same loop, so entries are paired by position.
piccai_score = evaluate(y_det=processed_predictions, y_true=ground_truths)
print(piccai_score)

Metrics(auroc=74.13%, AP=35.07%, 150 cases, 42 lesions)
