In [2]:
import os
import nibabel as nib
from monai.metrics import DiceMetric, HausdorffDistanceMetric, ConfusionMatrixMetric
import torch
import matplotlib.pyplot as plt
import pandas as pd



In [15]:

    
def extract_prefix(filename):
    base_name = os.path.basename(filename)
    if '_' in base_name:
        return base_name.split('_')[0]
    else:
        return os.path.splitext(base_name)[0]

    
# Function to determine model and modality from pred_folder
def get_model_modality(pred_folder):
    if 'aschoplex' in pred_folder:
        model = 'aschoplex'
    elif 'phusegplex' in pred_folder:
        model = 'phusegplex'
    elif 'umamba' in pred_folder:
        model = 'umamba'
    else:
        model = 'unknown'

    if 'T1/' in pred_folder:
        modality = 'T1'
    elif '_FLAIR/' in pred_folder:
        modality = 'FLAIR'
    elif 'T1xFLAIR/' in pred_folder:
        modality = 'T1xFLAIR'
    elif 'T1_FLAIR_T1xFLAIRmask/' in pred_folder:
        modality = 'T1_FLAIR_T1xFLAIRmask'
    else:
        modality = 'unknown'

    return f"{model}_{modality}"

# List of prediction folders
pred_folders = [
    "/home/linuxlia/Lia_Masterthesis/phuse_thesis_2024/thesis_experiments/01_aschoplex_from_scratch/working_directory_01_T1/ensemble_output/image_Ts",
    "/home/linuxlia/Lia_Masterthesis/phuse_thesis_2024/thesis_experiments/01_aschoplex_from_scratch/working_directory_01_FLAIR/ensemble_output/image_Ts",
    "/home/linuxlia/Lia_Masterthesis/phuse_thesis_2024/thesis_experiments/01_aschoplex_from_scratch/working_directory_01_T1xFLAIR/ensemble_output/image_Ts",
    "/home/linuxlia/Lia_Masterthesis/phuse_thesis_2024/thesis_experiments/01_aschoplex_from_scratch/working_directory_01_T1_FLAIR_T1xFLAIRmask/ensemble_output/image_Ts",
    "/home/linuxlia/Lia_Masterthesis/phuse_thesis_2024/thesis_experiments/02_phusegplex/working_directory_02_T1/ensemble_output/image_Ts", 
    "/home/linuxlia/Lia_Masterthesis/phuse_thesis_2024/thesis_experiments/02_phusegplex/working_directory_02_FLAIR/ensemble_output/image_Ts", 
    "/home/linuxlia/Lia_Masterthesis/phuse_thesis_2024/thesis_experiments/02_phusegplex/working_directory_02_T1xFLAIR/ensemble_output/image_Ts", 
    "/home/linuxlia/Lia_Masterthesis/phuse_thesis_2024/thesis_experiments/umamba_predictions/working_directory_T1/pred_pp",
    "/home/linuxlia/Lia_Masterthesis/phuse_thesis_2024/thesis_experiments/umamba_predictions/working_directory_FLAIR/pred_pp",
    "/home/linuxlia/Lia_Masterthesis/phuse_thesis_2024/thesis_experiments/umamba_predictions/working_directory_T1xFLAIR/pred_pp",
    "/home/linuxlia/Lia_Masterthesis/phuse_thesis_2024/thesis_experiments/umamba_predictions/working_directory_T1_FLAIR_T1xFLAIRmask/pred_pp"
]

gt_folder = "/home/linuxlia/Lia_Masterthesis/data/reference_labels_T1/ref_labelTs"
gt_files = sorted([f for f in os.listdir(gt_folder) if f.endswith('.nii')])
ground_truths = [nib.load(os.path.join(gt_folder, f)).get_fdata() for f in gt_files]

# Initialize the metrics
dice_metric = DiceMetric(include_background=False, reduction="mean")
hausdorff_metric = HausdorffDistanceMetric(include_background=False, percentile=95)
confusion_matrix_metric = ConfusionMatrixMetric(include_background=False, metric_name=["precision", "recall", "f1_score"], reduction="mean")

# List to store the results
results_list = []

# Loop over each prediction folder
for pred_folder in pred_folders:
    # Get model and modality
    model_modality = get_model_modality(pred_folder)
    print(f"Evaluating {model_modality}...")
    pred_files = sorted([f for f in os.listdir(pred_folder) if f.endswith('.nii.gz')])
    predictions = [nib.load(os.path.join(pred_folder, f)).get_fdata() for f in pred_files]

    dice_scores = []
    hd_distances = []
    f1_scores = []
    precisions = []
    recalls = []

    # Initialize variables to store best and worst segmentations
    best_dice_score = -1
    worst_dice_score = float('inf')
    best_hd_distance = float('inf')
    worst_hd_distance = -1

    best_dice_filename = None
    worst_dice_filename = None
    best_hd_filename = None
    worst_hd_filename = None

    # Compute metrics for each pair of prediction and ground truth
    for pred, gt, filename in zip(predictions, ground_truths, pred_files):
        pred_tensor = torch.tensor(pred, dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # Add batch and channel dimensions
        gt_tensor = torch.tensor(gt, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        
        # Check for empty tensors
        if torch.sum(pred_tensor) == 0 or torch.sum(gt_tensor) == 0:
            print("Warning: Empty prediction or ground truth tensor detected.")
            continue
        
        # Ensure matching shapes
        if pred_tensor.shape != gt_tensor.shape:
            print(f"Error: Shape mismatch - pred: {pred_tensor.shape}, gt: {gt_tensor.shape}")
            continue
        
        # Compute Dice score
        dice_score = dice_metric(y_pred=pred_tensor, y=gt_tensor)
        mean_dice_score = dice_score.mean().item()
        dice_scores.append(mean_dice_score)

        # Update best and worst Dice segmentations
        if mean_dice_score > best_dice_score:
            best_dice_score = mean_dice_score
            best_dice_prefix = extract_prefix(filename)
        if mean_dice_score < worst_dice_score:
            worst_dice_score = mean_dice_score
            worst_dice_prefix = extract_prefix(filename)
        
        # Compute Hausdorff distance
        hd_distance = hausdorff_metric(y_pred=pred_tensor, y=gt_tensor)
        hd_distances.append(hd_distance.item())

        # Update best and worst Hausdorff segmentations
        if hd_distance < best_hd_distance:
            best_hd_distance = hd_distance.item()
            best_hd_prefix = extract_prefix(filename)
        if hd_distance > worst_hd_distance:
            worst_hd_distance = hd_distance.item()
            worst_hd_prefix = extract_prefix(filename)
        
        # Accumulate confusion matrix results
        confusion_matrix_metric(y_pred=pred_tensor, y=gt_tensor)
        precision, recall, f1_score = confusion_matrix_metric.aggregate()
        
        f1_scores.append(f1_score.item())
        precisions.append(precision.item())
        recalls.append(recall.item())

    # Aggregate metrics to get summary statistics
    mean_dice = sum(dice_scores) / len(dice_scores)
    mean_hd_distance = sum(hd_distances) / len(hd_distances)
    mean_precision = sum(precisions) / len(precisions)
    mean_recall = sum(recalls) / len(recalls)
    mean_f1_score = sum(f1_scores) / len(f1_scores)

    # Append the results to the list
    results_list.append({
        "Model": model_modality,
        "Mean Dice": mean_dice,
        "Mean HD": mean_hd_distance,
        "Precision": mean_precision,
        "Recall": mean_recall,
        "F1 Score": mean_f1_score,
        "F1 Score": mean_f1_score,
        "Best Dice": best_dice_prefix,
        "Worst Dice": worst_dice_prefix,
        "Best HD": best_hd_prefix,
        "Worst HD": worst_hd_prefix
    })

# Convert the results list to a DataFrame
results = pd.DataFrame(results_list)

# Print the results table
print(results)

Evaluating aschoplex_T1...
Evaluating aschoplex_FLAIR...
Evaluating aschoplex_T1xFLAIR...
Evaluating aschoplex_T1_FLAIR_T1xFLAIRmask...
Evaluating phusegplex_T1...
Evaluating phusegplex_FLAIR...
Evaluating phusegplex_T1xFLAIR...
Evaluating umamba_T1...
Evaluating umamba_FLAIR...
Evaluating umamba_T1xFLAIR...
Evaluating umamba_T1_FLAIR_T1xFLAIRmask...
                              Model  Mean Dice   Mean HD  Precision    Recall  \
0                      aschoplex_T1   0.915969  1.072893   0.933723  0.889248   
1                   aschoplex_FLAIR   0.762849  2.201648   0.839537  0.890027   
2                aschoplex_T1xFLAIR   0.810770  1.832458   0.786519  0.878894   
3   aschoplex_T1_FLAIR_T1xFLAIRmask   0.826060  1.804438   0.781045  0.879016   
4                     phusegplex_T1   0.920008  1.067720   0.793595  0.882736   
5                  phusegplex_FLAIR   0.765002  2.182925   0.796140  0.883409   
6               phusegplex_T1xFLAIR   0.792209  2.052917   0.784549  0.876540   

In [16]:
# Save the DataFrame to a CSV file
file_path = "/home/linuxlia/Lia_Masterthesis/phuse_thesis_2024/thesis_experiments/segmentation_metrics_t1_gt.csv"
results.to_csv(file_path, index=False)

#import ace_tools as tools; tools.display_dataframe_to_user(name="Segmentation Metrics", dataframe=results)

In [None]:

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Prediction")
plt.imshow(predictions[0][..., int(predictions[0].shape[-1] / 2)], cmap="gray")

plt.subplot(1, 2, 2)
plt.title("Ground Truth")
plt.imshow(ground_truths[0][..., int(ground_truths[0].shape[-1] / 2)], cmap="gray")

plt.show()
