In [9]:
import numpy as np
from scipy.ndimage import label, find_objects
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm
import SimpleITK as sitk
import pandas as pd
import radiomics
from radiomics import featureextractor
import os
from mirp import extract_features


def count_blob_pixels(array):
    """Return the voxel count of every connected component in ``array``.

    Components come from ``scipy.ndimage.label`` with its default
    connectivity, so entry ``k`` of the result is the size of the component
    labelled ``k + 1``. Zero (background) voxels are never counted.
    """
    labeled, n_blobs = label(array)
    sizes = []
    for blob_id in range(1, n_blobs + 1):
        sizes.append((labeled == blob_id).sum())
    return sizes


def remove_small_blobs(array, min_size=100):
    """Zero out connected components smaller than ``min_size`` voxels.

    When ``array`` contains exactly one component it is kept whatever its
    size, so a lone small lesion is never erased.

    Parameters:
    array (np.ndarray): Input mask; its non-zero voxels form the components.
    min_size (int): Minimum component size, in voxels, that is kept.

    Returns:
    np.ndarray: ``array`` multiplied by the keep-mask (kept voxels retain
    their original values).
    """
    labeled, n_blobs = label(array)

    if n_blobs == 1:
        keep = labeled != 0
    else:
        # Index k of the bincount is the size of component k (0 = background).
        sizes = np.bincount(labeled.ravel(), minlength=n_blobs + 1)
        large_ids = [
            blob_id for blob_id in range(1, n_blobs + 1) if sizes[blob_id] >= min_size
        ]
        keep = np.isin(labeled, large_ids)

    return array * keep


def extract_radiomics_features(pet_img, mask_array, output_dir, sub_name, label_type):
    """Extract per-blob features with MIRP and write one CSV per blob.

    Small blobs are first removed with ``remove_small_blobs``. Each remaining
    connected component of ``mask_array`` is then passed to
    ``mirp.extract_features`` together with the PET voxel array. The numeric
    columns of the result are saved to
    ``<output_dir>/<sub_name>_blob_<i>_label_<label_type>.csv``.

    A blob is skipped if its CSV already exists, so re-runs resume where they
    stopped. Blobs with 50 voxels or fewer are also skipped.

    Parameters:
    pet_img (SimpleITK.Image): PET image; its geometry is copied onto each blob mask.
    mask_array (np.ndarray): Mask in SimpleITK array order, same size as ``pet_img``.
    output_dir (str or Path): Directory for the CSV files (created if missing).
    sub_name (str): Subject identifier used in the file names.
    label_type (str): Tag appended to the file names.
    """
    os.makedirs(output_dir, exist_ok=True)
    mask_array = remove_small_blobs(mask_array)
    labeled_array, num_features = label(mask_array)
    # Only the voxel array is handed to MIRP; convert it once, not per blob.
    pet_array = sitk.GetArrayFromImage(pet_img)

    for i in range(1, num_features + 1):
        out_path = Path(output_dir) / f"{sub_name}_blob_{i}_label_{label_type}.csv"
        if out_path.exists():
            continue
        blob_mask = (labeled_array == i).astype(np.uint8)
        if blob_mask.sum() <= 50:
            continue
        # Round-trip through SimpleITK so CopyInformation checks that the
        # mask is compatible with the PET image geometry.
        mask_img = sitk.GetImageFromArray(blob_mask)
        mask_img.CopyInformation(pet_img)
        # NOTE(review): only raw arrays are passed, so MIRP does not receive
        # the PET voxel spacing — confirm whether spacing should be supplied.
        feature_vector = extract_features(
            image=pet_array, mask=sitk.GetArrayFromImage(mask_img)
        )
        # Keep only numeric columns (drops identifiers/metadata strings).
        feature_df = pd.DataFrame([feature_vector]).select_dtypes(include=[np.number])
        feature_df.to_csv(out_path, index=False)


def mean_pred_prob_per_blob(pred_array, gt_array):
    """Average ``pred_array`` over each connected component of ``gt_array``.

    Parameters:
    pred_array (np.ndarray): Per-voxel predicted probabilities, same shape as ``gt_array``.
    gt_array (np.ndarray): Ground-truth mask; its non-zero voxels form the blobs.

    Returns:
    list: One mean probability per ground-truth blob, in label order
    (empty when ``gt_array`` has no foreground).
    """
    blob_ids, n_blobs = label(gt_array)
    # label() numbers the components 1..n_blobs consecutively; 0 is background.
    return [np.mean(pred_array[blob_ids == k]) for k in range(1, n_blobs + 1)]


def calculate_tp_fp_fn(gt, pred):
    """Blob-level detection counts between a ground-truth and a predicted mask.

    Each connected component of ``gt`` is matched to the first component of
    ``pred`` that shares at least one voxel with it. Bounding boxes are used
    only as a cheap pre-filter.

    Parameters:
    gt (np.ndarray): Ground-truth mask (non-zero = foreground), any rank.
    pred (np.ndarray): Predicted mask, same shape as ``gt``.

    Returns:
    tuple: ``(tp_mask, fp_mask, fn_mask, tp, fp, fn)``, where:

    * ``tp_mask`` marks the overlap voxels of each matched GT/prediction pair.
    * ``fp_mask`` marks predicted blobs that overlap no GT blob.
    * ``fn_mask`` marks GT blobs that overlap no predicted blob.
    * ``tp`` is the number of matched GT blobs and ``fn = #GT blobs - tp``.
    * ``fp`` is the number of unmatched predicted blobs. It is never negative,
      even when one prediction covers several GT blobs.
    """
    gt_labels, num_gt = label(gt)
    pred_labels, num_pred = label(pred)

    gt_boxes = find_objects(gt_labels)
    pred_boxes = find_objects(pred_labels)

    tp_mask = np.zeros_like(gt)
    fp_mask = np.zeros_like(gt)
    fn_mask = np.zeros_like(gt)

    def boxes_intersect(box_a, box_b):
        # find_objects slices have exclusive stops: two boxes share an index
        # along an axis only if each one starts before the other stops.
        return all(a.start < b.stop and b.start < a.stop for a, b in zip(box_a, box_b))

    matched_pred = set()
    tp = 0

    for gt_id, gt_box in enumerate(gt_boxes, start=1):
        gt_blob = gt_labels == gt_id
        for pred_id, pred_box in enumerate(pred_boxes, start=1):
            if not boxes_intersect(gt_box, pred_box):
                continue
            overlap_region = gt_blob & (pred_labels == pred_id)
            if overlap_region.any():
                tp_mask[overlap_region] = 1
                tp += 1
                matched_pred.add(pred_id)
                break
        else:
            # No predicted blob shares a voxel with this GT blob.
            fn_mask[gt_blob] = 1

    for pred_id in range(1, num_pred + 1):
        if pred_id not in matched_pred:
            fp_mask[pred_labels == pred_id] = 1

    fn = num_gt - tp
    fp = num_pred - len(matched_pred)

    return tp_mask, fp_mask, fn_mask, tp, fp, fn

import SimpleITK as sitk
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path


def etimate_model_confidence(
    src_dir,
    pred_dir,
    csv_path_metrics,
    csv_path_mp,
    gt_lesion_label=20,
    pred_lesion_label=20,
    gt_dir="labelsTs_combined",
):
    """
    Compute per-subject and per-blob lesion-probability statistics and save
    them to CSV files.

    For every ``*.nii.gz`` prediction in ``src_dir / pred_dir`` the matching
    ``<case>.npz`` is loaded, and the last channel of its ``probabilities``
    array is used as the lesion probability. (This assumes the last channel
    is the lesion class — confirm against the model's label map.) The ground
    truth is ``src_dir / gt_dir / <same file name>``; its voxels equal to
    ``gt_lesion_label`` form the lesion mask.

    Parameters:
    src_dir (str or Path): Dataset root.
    pred_dir (str or Path): Prediction directory, relative to ``src_dir``.
    csv_path_metrics (str or Path): Output CSV with one row per subject. It
        holds the Mean/Max/Median probability inside the GT lesions (over the
        whole volume for lesion-free subjects) and the GT ``Label`` (1 or 0).
    csv_path_mp (str or Path): Output CSV whose ``mean`` column holds the
        mean probability of every GT lesion blob.
    gt_lesion_label (int): Lesion value in the ground-truth label maps.
    pred_lesion_label (int): Currently unused; kept for backward compatibility.
    gt_dir (str or Path): Ground-truth directory, relative to ``src_dir``.
        Defaults to "labelsTs_combined", the directory the previous
        implementation hard-coded.
    """
    src_dir = Path(src_dir)
    pred_dir = Path(pred_dir)

    metrics = []
    mp = []

    for pred in tqdm(sorted((src_dir / pred_dir).glob("*.nii.gz"))):
        fname = pred.name.split(".nii")[0]
        with np.load(src_dir / pred_dir / f"{fname}.npz") as npz:
            pred_proba_array = np.array(npz["probabilities"])[-1, :, :, :]

        gt_path = src_dir / gt_dir / pred.name
        gt_array = sitk.GetArrayFromImage(sitk.ReadImage(str(gt_path)))
        gt_array = (gt_array == gt_lesion_label).astype(int)

        mp.extend(mean_pred_prob_per_blob(pred_proba_array, gt_array))

        # Lesion-positive subjects: statistics inside the GT lesions.
        # Lesion-free subjects: statistics over the whole volume.
        has_lesion = gt_array.max() == 1
        values = pred_proba_array[gt_array != 0] if has_lesion else pred_proba_array
        metrics.append(
            {
                "Subject": pred.name,
                "Mean": values.mean(),
                "Max": values.max(),
                "Median": np.median(values),
                "Label": gt_array.max(),
            }
        )

    pd.DataFrame(metrics).to_csv(csv_path_metrics, index=False)
    pd.DataFrame({"mean": mp}).to_csv(csv_path_mp, index=False)


def overlapping_labels(
    mean_positive_predicitons,
    gt_leasion_label=20,
    src_dir=None,
    pred_dir=None,
    pet_dir="imagesTs",
    gt_dir="labelsTs_combined",
):
    """Write overlay volumes of GT lesions and confident predictions.

    For each ``*.nii.gz`` in ``src_dir / pred_dir``, ground-truth lesion
    voxels (value ``gt_leasion_label``) are written as 1. Every voxel whose
    lesion probability exceeds 0.5 is then overwritten with 2. The lesion
    probability is the last channel of the ``<case>.npz`` ``probabilities``
    array. Results are saved to ``src_dir / "confidence_estimates"`` with the
    geometry of the PET image ``<pet_dir>/<case>_0001.nii.gz``.

    Parameters:
    mean_positive_predicitons (str or Path): Per-blob CSV written by
        ``etimate_model_confidence``. Not needed to build the overlay; kept
        for backward compatibility.
    gt_leasion_label (int): Lesion value in the ground-truth label maps.
    src_dir (str or Path, optional): Dataset root. Defaults to the notebook
        global ``src_dir``, which the previous implementation read implicitly.
    pred_dir (str or Path, optional): Prediction directory relative to
        ``src_dir``. Defaults to the notebook global ``pred_dir``.
    pet_dir (str or Path): PET image directory relative to ``src_dir``.
    gt_dir (str or Path): Ground-truth directory relative to ``src_dir``.
    """
    # Backward compatibility: fall back to the notebook globals the original
    # implementation depended on.
    if src_dir is None:
        src_dir = globals()["src_dir"]
    if pred_dir is None:
        pred_dir = globals()["pred_dir"]
    src_dir = Path(src_dir)

    output_dir = src_dir / "confidence_estimates"
    output_dir.mkdir(exist_ok=True)

    for pred in tqdm(sorted((src_dir / pred_dir).glob("*.nii.gz"))):
        fname = pred.name.split(".nii")[0]
        pet_img = sitk.ReadImage(str(src_dir / pet_dir / f"{fname}_0001.nii.gz"))
        with np.load(src_dir / pred_dir / f"{fname}.npz") as npz:
            pred_proba_array = np.array(npz["probabilities"])[-1, :, :, :]
        gt_array = sitk.GetArrayFromImage(sitk.ReadImage(str(src_dir / gt_dir / pred.name)))
        gt_array = (gt_array == gt_leasion_label).astype(int)
        # 2 marks confidently predicted voxels (drawn over GT and background).
        gt_array[pred_proba_array > 0.50] = 2
        gt_img = sitk.GetImageFromArray(gt_array)
        gt_img.CopyInformation(pet_img)
        sitk.WriteImage(gt_img, str(output_dir / f"{fname}.nii.gz"))


def remove_non_confident_blobs(array, pred_lesion_label=20):
    """Keep only connected components that contain a confident voxel.

    ``array`` is expected to hold 0 (background), 1 (lesion) and 2 (confident
    lesion voxel). Components are built from all non-zero voxels. A component
    is kept when its maximum value is 2.

    Parameters:
    array (np.ndarray): Integer mask as described above.
    pred_lesion_label (int): Value written to every kept voxel (default 20).

    Returns:
    np.ndarray: New array with kept voxels set to ``pred_lesion_label`` and
    all other voxels set to 0.
    """
    labeled_array, num_features = label(array)
    keep = np.zeros_like(array, dtype=bool)

    for blob_id in range(1, num_features + 1):
        blob = labeled_array == blob_id
        if array[blob].max() == 2:
            keep[blob] = True

    filtered_array = array * keep
    filtered_array[filtered_array != 0] = pred_lesion_label
    return filtered_array


def zero_boundary_slices(arr, width=5):
    """
    Set values to zero in the first and last ``width`` slices along all three
    axes of a 3D NumPy array. The array is modified in place.

    Parameters:
    arr (np.ndarray): Input 3D array; it is modified in place.
    width (int): Number of slices cleared at each end of every axis
        (default 5). A ``width`` of 0 or less leaves ``arr`` unchanged.

    Returns:
    np.ndarray: The same ``arr`` object, with its boundary slices set to zero.

    Raises:
    ValueError: If ``arr`` is not 3-dimensional.
    """
    if arr.ndim != 3:
        raise ValueError("Input array must be 3-dimensional.")
    # Guard: arr[-0:] would select the whole axis and wipe the array.
    if width <= 0:
        return arr

    for axis in range(arr.ndim):
        head = [slice(None)] * arr.ndim
        tail = [slice(None)] * arr.ndim
        head[axis] = slice(None, width)
        tail[axis] = slice(-width, None)
        arr[tuple(head)] = 0
        arr[tuple(tail)] = 0

    return arr


def create_confident_labels(
    mean_positive_predicitons,
    gt_leasion_label=20,
    pred_lesion_label=20,
    src_dir=None,
    pred_dir=None,
    pet_dir="imagesTs",
):
    """Write filtered "confident" lesion predictions for every prediction file.

    For each ``*.nii.gz`` in ``src_dir / pred_dir`` the function:

    1. takes the arg-max channel of the ``<case>.npz`` probabilities and marks
       voxels whose channel index equals ``pred_lesion_label`` as 1;
    2. marks every voxel whose last-channel probability exceeds 0.5 as 2;
    3. keeps only blobs that contain a 2 (``remove_non_confident_blobs``,
       which relabels them to 20);
    4. clears 5 boundary slices on every axis (``zero_boundary_slices``).

    The result is written to ``src_dir / "confidence_estimates_pred"`` with
    the geometry of ``<pet_dir>/<case>_0001.nii.gz``.

    Parameters:
    mean_positive_predicitons (str or Path): Per-blob CSV from
        ``etimate_model_confidence``. Not needed here; kept for backward
        compatibility.
    gt_leasion_label (int): Unused; kept for backward compatibility.
    pred_lesion_label (int): Channel index treated as the lesion class.
    src_dir (str or Path, optional): Dataset root. Defaults to the notebook
        global ``src_dir``, which the previous implementation read implicitly.
    pred_dir (str or Path, optional): Prediction directory relative to
        ``src_dir``. Defaults to the notebook global ``pred_dir``.
    pet_dir (str or Path): PET image directory relative to ``src_dir``.
    """
    # Backward compatibility: fall back to the notebook globals the original
    # implementation depended on.
    if src_dir is None:
        src_dir = globals()["src_dir"]
    if pred_dir is None:
        pred_dir = globals()["pred_dir"]
    src_dir = Path(src_dir)

    output_dir = src_dir / "confidence_estimates_pred"
    output_dir.mkdir(exist_ok=True)

    for pred in tqdm(sorted((src_dir / pred_dir).glob("*.nii.gz"))):
        fname = pred.name.split(".nii")[0]
        pet_img = sitk.ReadImage(str(src_dir / pet_dir / f"{fname}_0001.nii.gz"))
        with np.load(src_dir / pred_dir / f"{fname}.npz") as npz:
            probabilities = np.array(npz["probabilities"])
        lesion_proba = probabilities[-1, :, :, :]
        # NOTE(review): compares the arg-max *channel index* with
        # pred_lesion_label — confirm the lesion class sits at that index.
        candidate = (np.argmax(probabilities, axis=0) == pred_lesion_label).astype(int)
        candidate[lesion_proba > 0.5] = 2
        confident = remove_non_confident_blobs(candidate)
        confident = zero_boundary_slices(confident)
        pred_img = sitk.GetImageFromArray(confident)
        pred_img.CopyInformation(pet_img)
        sitk.WriteImage(pred_img, str(output_dir / f"{fname}.nii.gz"))

In [10]:
import numpy as np
from scipy.ndimage import label, find_objects
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm
import SimpleITK as sitk
import pandas as pd
from radiomics import featureextractor
import os
from mirp import extract_features


class BAMF_PET_Processor:
    """Blob-level confidence analysis and relabelling of PET lesion predictions.

    All directories are resolved relative to ``src_dir``:

    * ``pred_dir`` holds ``<case>.nii.gz`` predicted label maps and matching
      ``<case>.npz`` files with a ``probabilities`` array. Its last channel is
      used as the lesion probability (assumed to be the lesion class —
      confirm against the model's label map).
    * ``gt_dir`` holds ``<case>.nii.gz`` ground-truth label maps.
    * ``pet_dir`` holds ``<case>_0001.nii.gz`` images whose geometry is copied
      onto every volume this class writes.

    Per-subject metrics are written to ``<pred_dir>_cle_estimates.csv`` and
    per-blob mean probabilities to ``<pred_dir>_blobwise_estimates.csv``.
    Both paths are relative to the current working directory.
    """

    def __init__(
        self,
        src_dir,
        pet_dir,
        gt_dir,
        pred_dir,
        gt_lesion_label=20,
        pred_lesion_label=20,
        output_dir=None,
    ):
        """
        Parameters:
        src_dir (str or Path): Dataset root.
        pet_dir, gt_dir, pred_dir (str or Path): Sub-directories of ``src_dir``.
        gt_lesion_label (int): Lesion value in the ground-truth label maps.
        pred_lesion_label (int): Lesion value in the predicted label maps.
        output_dir (str or Path, optional): Destination of the per-blob
            feature CSVs written by ``extract_radiomics_features``.
        """
        self.src_dir = Path(src_dir)
        self.gt_dir = Path(gt_dir)
        self.pet_dir = Path(pet_dir)
        self.pred_dir = Path(pred_dir)
        # NOTE(review): no method uses this PyRadiomics extractor (features are
        # computed with mirp); kept only for callers that access the attribute.
        self.extractor = featureextractor.RadiomicsFeatureExtractor()
        self.gt_lesion_label = gt_lesion_label
        self.pred_lesion_label = pred_lesion_label
        self.output_dir = Path(output_dir) if output_dir is not None else None
        self.cle_estimates = f"{self.pred_dir}_cle_estimates.csv"
        self.blobwise_estimate = f"{self.pred_dir}_blobwise_estimates.csv"

    # ------------------------------------------------------------------
    # I/O helpers
    # ------------------------------------------------------------------

    def _prediction_files(self):
        """Return the sorted ``*.nii.gz`` prediction files in ``src_dir / pred_dir``."""
        return sorted((self.src_dir / self.pred_dir).glob("*.nii.gz"))

    @staticmethod
    def _case_name(path):
        """Return the case identifier: the file name without its ``.nii...`` suffix."""
        return path.name.split(".nii")[0]

    def _load_lesion_proba(self, fname):
        """Return the last probability channel stored in ``<pred_dir>/<fname>.npz``."""
        with np.load(self.src_dir / self.pred_dir / f"{fname}.npz") as npz:
            # Copy so the remaining channels can be freed once this returns.
            return npz["probabilities"][-1, :, :, :].copy()

    def _load_gt_lesion_mask(self, pred_path):
        """Return a 0/1 int lesion mask from the GT file with ``pred_path``'s name."""
        gt_path = self.src_dir / self.gt_dir / pred_path.name
        gt_array = sitk.GetArrayFromImage(sitk.ReadImage(str(gt_path)))
        return (gt_array == self.gt_lesion_label).astype(int)

    def _read_pet(self, fname):
        """Read the PET image ``<pet_dir>/<fname>_0001.nii.gz`` (used for its geometry)."""
        return sitk.ReadImage(str(self.src_dir / self.pet_dir / f"{fname}_0001.nii.gz"))

    # ------------------------------------------------------------------
    # Array utilities
    # ------------------------------------------------------------------

    def count_blob_pixels(self, array):
        """Return the voxel count of each connected component, in label order."""
        labeled_array, num_features = label(array)
        return [(labeled_array == i).sum() for i in range(1, num_features + 1)]

    def remove_small_blobs(self, array, min_size=100):
        """Zero components smaller than ``min_size`` voxels.

        A lone component is always kept, whatever its size.
        """
        labeled_array, num_features = label(array)
        blob_sizes = [(labeled_array == i + 1).sum() for i in range(num_features)]
        mask = (
            labeled_array != 0
            if num_features == 1
            else np.isin(
                labeled_array,
                [i + 1 for i, size in enumerate(blob_sizes) if size >= min_size],
            )
        )
        return array * mask

    def extract_radiomics_features(
        self, pet_img, mask_array, sub_name, label_type, output_dir=None
    ):
        """Extract MIRP features for each blob of ``mask_array`` and write one CSV per blob.

        Small blobs are removed first (``remove_small_blobs``). Blobs of 50
        voxels or fewer, and blobs whose CSV already exists, are skipped.

        Parameters:
        pet_img (SimpleITK.Image): PET image; its geometry is copied onto each blob mask.
        mask_array (np.ndarray): Mask in SimpleITK array order.
        sub_name (str): Subject identifier used in the file names.
        label_type (str): Tag appended to the file names.
        output_dir (str or Path, optional): Destination directory. Defaults to
            the ``output_dir`` given to the constructor.

        Raises:
        ValueError: If no output directory was given here or in the constructor.
        """
        if output_dir is None:
            output_dir = getattr(self, "output_dir", None)
        if output_dir is None:
            raise ValueError(
                "No output directory: pass output_dir or set it in the constructor."
            )
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        mask_array = self.remove_small_blobs(mask_array)
        labeled_array, num_features = label(mask_array)
        pet_array = sitk.GetArrayFromImage(pet_img)

        for i in range(1, num_features + 1):
            out_path = output_dir / f"{sub_name}_blob_{i}_label_{label_type}.csv"
            if out_path.exists():
                continue
            blob_mask = (labeled_array == i).astype(np.uint8)
            if blob_mask.sum() <= 50:
                continue
            mask_img = sitk.GetImageFromArray(blob_mask)
            mask_img.CopyInformation(pet_img)
            feature_vector = extract_features(
                image=pet_array,
                mask=sitk.GetArrayFromImage(mask_img),
            )
            feature_df = pd.DataFrame([feature_vector]).select_dtypes(
                include=[np.number]
            )
            feature_df.to_csv(out_path, index=False)

    def mean_pred_prob_per_blob(self, pred_array, gt_array):
        """Return the mean of ``pred_array`` over each connected GT component, in label order."""
        gt_labels, num_gt = label(gt_array)
        # label() numbers the components 1..num_gt consecutively; 0 is background.
        return [
            np.mean(pred_array[gt_labels == blob_id])
            for blob_id in range(1, num_gt + 1)
        ]

    @staticmethod
    def _boxes_intersect(box_a, box_b):
        """Return True if two ``find_objects`` boxes share at least one index.

        Slice stops are exclusive.
        """
        return all(a.start < b.stop and b.start < a.stop for a, b in zip(box_a, box_b))

    def calculate_tp_fp_fn(self, gt, pred):
        """Blob-level detection counts between ``gt`` and ``pred`` masks.

        A GT blob is matched to the first predicted blob that shares at least
        one voxel with it.

        Returns:
        tuple: ``(tp_mask, fp_mask, fn_mask, tp, fp, fn)``, where:

        * ``tp`` is the number of matched GT blobs and ``fn = #GT blobs - tp``.
        * ``fp`` is the number of predicted blobs matched to no GT blob
          (never negative).
        """
        gt_labels, num_gt = label(gt)
        pred_labels, num_pred = label(pred)
        gt_boxes = find_objects(gt_labels)
        pred_boxes = find_objects(pred_labels)

        tp_mask = np.zeros_like(gt)
        fp_mask = np.zeros_like(gt)
        fn_mask = np.zeros_like(gt)

        matched_pred = set()
        tp = 0

        for gt_id, gt_box in enumerate(gt_boxes, start=1):
            gt_blob = gt_labels == gt_id
            for pred_id, pred_box in enumerate(pred_boxes, start=1):
                # Cheap bounding-box pre-filter before the voxel-level check.
                if not self._boxes_intersect(gt_box, pred_box):
                    continue
                overlap_region = gt_blob & (pred_labels == pred_id)
                if overlap_region.any():
                    tp_mask[overlap_region] = 1
                    tp += 1
                    matched_pred.add(pred_id)
                    break
            else:
                fn_mask[gt_blob] = 1

        for pred_id in range(1, num_pred + 1):
            if pred_id not in matched_pred:
                fp_mask[pred_labels == pred_id] = 1

        fn = num_gt - tp
        fp = num_pred - len(matched_pred)
        return tp_mask, fp_mask, fn_mask, tp, fp, fn

    # ------------------------------------------------------------------
    # Pipelines
    # ------------------------------------------------------------------

    def estimate_model_confidence(
        self,
    ):
        """Compute per-subject and per-blob confidence statistics and save them.

        Each subject row holds the mean/max/median lesion probability and its
        blob-level tp/fp/fn counts. The statistics are taken inside the GT
        lesions, or over the whole volume when the subject has none. Rows go
        to ``self.cle_estimates``; the mean probability of every GT blob goes
        to ``self.blobwise_estimate``.
        """
        metrics, mp = [], []
        for pred in tqdm(self._prediction_files()):
            fname = self._case_name(pred)
            pred_proba_array = self._load_lesion_proba(fname)
            gt_array = self._load_gt_lesion_mask(pred)
            pred_array = (
                sitk.GetArrayFromImage(sitk.ReadImage(str(pred)))
                == self.pred_lesion_label
            ).astype(float)

            mp.extend(self.mean_pred_prob_per_blob(pred_proba_array, gt_array))
            _, _, _, tp, fp, fn = self.calculate_tp_fp_fn(gt_array, pred_array)

            has_lesion = gt_array.max() == 1
            values = pred_proba_array[gt_array != 0] if has_lesion else pred_proba_array
            metrics.append(
                {
                    "Subject": pred.name,
                    "Mean": values.mean(),
                    "Max": values.max(),
                    "Median": np.median(values),
                    # Also reported for lesion-free subjects, where every
                    # predicted blob counts as a false positive.
                    "tp": tp,
                    "fp": fp,
                    "fn": fn,
                    "Label": gt_array.max(),
                }
            )

        pd.DataFrame(metrics).to_csv(self.cle_estimates, index=False)
        pd.DataFrame({"mean": mp}).to_csv(self.blobwise_estimate, index=False)

    def overlapping_labels(self):
        """Write overlay volumes of GT lesions (1) and voxels with probability > 0.5 (2).

        Runs ``estimate_model_confidence`` first if its per-blob CSV is
        missing. Results go to ``src_dir / "<pred_dir>_confidence_estimates"``;
        outputs that already exist are skipped.
        """
        if not Path(self.blobwise_estimate).exists():
            self.estimate_model_confidence()

        output_dir = self.src_dir / f"{self.pred_dir}_confidence_estimates"
        output_dir.mkdir(exist_ok=True)
        for pred in tqdm(self._prediction_files()):
            fname = self._case_name(pred)
            out_path = output_dir / f"{fname}.nii.gz"
            if out_path.exists():
                continue
            pet_img = self._read_pet(fname)
            overlay = self._load_gt_lesion_mask(pred)
            overlay[self._load_lesion_proba(fname) > 0.50] = 2
            overlay_img = sitk.GetImageFromArray(overlay)
            overlay_img.CopyInformation(pet_img)
            sitk.WriteImage(overlay_img, str(out_path))

    def remove_non_confident_blobs(self, array, output_label=20):
        """Keep only components whose maximum value is 2.

        Every kept voxel is relabelled to ``output_label``.
        """
        labeled_array, num_features = label(array)
        keep = np.zeros_like(array, dtype=bool)
        for blob_id in range(1, num_features + 1):
            blob = labeled_array == blob_id
            if array[blob].max() == 2:
                keep[blob] = True
        filtered_array = array * keep
        filtered_array[filtered_array != 0] = output_label
        return filtered_array

    def zero_boundary_slices(self, arr, width=5):
        """Zero the first/last ``width`` slices on every axis of a 3-D array, in place.

        A ``width`` of 0 or less leaves ``arr`` unchanged.

        Raises:
        ValueError: If ``arr`` is not 3-dimensional.
        """
        if arr.ndim != 3:
            raise ValueError("Input array must be 3-dimensional.")
        # Guard: arr[-0:] would select the whole axis and wipe the array.
        if width <= 0:
            return arr
        for axis in range(arr.ndim):
            head = [slice(None)] * arr.ndim
            tail = [slice(None)] * arr.ndim
            head[axis] = slice(None, width)
            tail[axis] = slice(-width, None)
            arr[tuple(head)] = 0
            arr[tuple(tail)] = 0
        return arr

    def create_confident_labels(
        self,
    ):
        """Write GT lesion labels restricted to blobs the model is confident about.

        The method first prints the model CLE: the mean of the per-blob mean
        probabilities, computed via ``estimate_model_confidence`` if its CSV
        is missing. Then, for each case:

        * GT lesion voxels are set to 1 and voxels with probability > 0.5 to 2;
        * only components containing a 2 are kept, relabelled to 20;
        * 5 boundary slices are cleared on every axis.

        Results go to ``src_dir / "<pred_dir>_confidence_labels"``; outputs
        that already exist are skipped.
        """
        if not Path(self.blobwise_estimate).exists():
            self.estimate_model_confidence()
        mean_value = pd.read_csv(self.blobwise_estimate)["mean"].mean()
        print(f"Model CLE: {mean_value}")

        output_dir = self.src_dir / f"{self.pred_dir}_confidence_labels"
        output_dir.mkdir(exist_ok=True)
        for pred in tqdm(self._prediction_files()):
            fname = self._case_name(pred)
            out_path = output_dir / f"{fname}.nii.gz"
            if out_path.exists():
                continue
            pet_img = self._read_pet(fname)
            confident = self._load_gt_lesion_mask(pred)
            confident[self._load_lesion_proba(fname) > 0.50] = 2
            confident = self.remove_non_confident_blobs(confident)
            confident = self.zero_boundary_slices(confident)
            confident_img = sitk.GetImageFromArray(confident)
            confident_img.CopyInformation(pet_img)
            sitk.WriteImage(confident_img, str(out_path))

In [11]:
# Dataset root. Override it with the AUTOPET_SRC_DIR environment variable
# instead of editing the cluster path below.
src_dir = Path(
    os.environ.get(
        "AUTOPET_SRC_DIR",
        "/mnt/nfs/slow_ai_team/organ_segmentation/nnunet_liverv0.0/nnUNet_raw_database/nnUNet_raw/nnUNet_raw_data/Dataset019_AutoPET2024/",
    )
)
# Sub-directories of src_dir.
pred_dir = "imagesTr_resXL_pred_proba"  # predicted label maps + .npz probabilities
pet_dir = "imagesTr"  # PET images named <case>_0001.nii.gz
gt_dir = "labelsTr"  # ground-truth label maps

processor = BAMF_PET_Processor(
    src_dir,
    pet_dir,
    gt_dir,
    pred_dir,
    gt_lesion_label=20,
    pred_lesion_label=20,
)

processor.create_confident_labels()

0it [00:00, ?it/s]

Model CLE: 0.5817180254671563


0it [00:00, ?it/s]

In [12]:
# # Step 1 Estimate model confidence
# csv_path_metrics = "CL_estimates.csv"
# csv_path_mp = "Mean_positive_prediction.csv"
# src_dir = Path(
#     "/mnt/nfs/slow_ai_team/organ_segmentation/nnunet_liverv0.0/nnUNet_raw_database/nnUNet_raw/nnUNet_raw_data/Dataset019_AutoPET2024/"
# )
# pred_dir = "imagesTs_resXL_pred_proba"
# gt_dir = "imagesTs"
# etimate_model_confidence(src_dir, pred_dir, csv_path_metrics, csv_path_mp)

# # Step 2 Create Overlapping of confident and predicted labels (optional)

# overlapping_labels(csv_path_mp)

# # Step 3: Create confident labels as final predictions

# create_confident_labels(csv_path_mp)