A notebook for computing recall, precision, Dice coefficient and IoU, and for plotting difference maps, for segmentation models.

Can be used for finetuned models.

In [1]:
from micro_sam.evaluation.evaluation import run_evaluation
import os
import tifffile as tiff
import matplotlib.pyplot as plt
from natsort import natsorted
import numpy as np
import random
import pandas as pd
import seaborn as sns

from utils.stack_manipulation_utils import save_indv_images
from utils.microsam_utils import load_filepaths

# Remove the row limit so DataFrames are displayed in full rather than truncated.
pd.set_option('display.max_rows', None)


In [None]:
# Load the train/val/test image and ground-truth path lists saved during finetuning
finetuned_dir = '/vol/biomedic3/bglocker/mscproj24/nma23/data/testing_directory/multi_model/finetuning'
train_image_paths = load_filepaths(finetuned_dir, 'train_image_paths.pkl')
train_gt_paths = load_filepaths(finetuned_dir, 'train_gt_paths.pkl')
val_image_paths = load_filepaths(finetuned_dir, 'val_image_paths.pkl')
val_gt_paths = load_filepaths(finetuned_dir, 'val_gt_paths.pkl')
test_image_paths = load_filepaths(finetuned_dir, 'test_image_paths.pkl')
test_gt_paths = load_filepaths(finetuned_dir, 'test_gt_paths.pkl')

# Directory holding the normalised ground-truth masks of the test split.
# Prediction masks are not loaded here; they are listed per model further down.
test_gt_normalised_mask_dir = '/vol/biomedic3/bglocker/mscproj24/nma23/data/testing_directory/multi_model/eval_test_data/normalised_test_labels'

# Bare file names inside that directory, in natural sort order. The directory
# is listed only once so both names below are guaranteed to agree.
gt_relative_filepaths = natsorted(os.listdir(test_gt_normalised_mask_dir))

# NOTE: despite its name this holds file names, not absolute paths. It is kept
# (as an independent copy) so existing references to it keep working.
gt_abs_files = list(gt_relative_filepaths)

# Full paths to the ground-truth masks and the naturally sorted test images
gt_files = [os.path.join(test_gt_normalised_mask_dir, f) for f in gt_relative_filepaths]
image_files = natsorted(test_image_paths)

In [2]:
# Functions

def list_files_in_directory(directory_path):
    """Return the paths of the regular files in ``directory_path``.

    Entries are visited in natural sort order. The combined
    ``Mask_Stack.tiff`` is skipped, and anything that is not a regular file
    (for example a sub-directory) is left out.

    Args:
        directory_path: Directory to list.

    Returns:
        List of ``os.path.join(directory_path, name)`` paths.
    """
    candidate_paths = (
        os.path.join(directory_path, name)
        for name in natsorted(os.listdir(directory_path))
        if name != 'Mask_Stack.tiff'  # Skip full mask stack
    )
    return [path for path in candidate_paths if os.path.isfile(path)]

def plot_samples(gt, prediction):
    """Show a ground-truth mask and a predicted mask side by side.

    Args:
        gt: Ground-truth mask, drawn in the left panel.
        prediction: Predicted mask, drawn in the right panel.
    """
    panels = (
        (gt, 'Test Label'),
        (prediction, 'Prediction Segmentation'),
    )

    _, axs = plt.subplots(1, 2, figsize=(12, 6))
    for ax, (mask, title) in zip(axs, panels):
        ax.imshow(mask, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    plt.show()
    plt.close()

def plot_four_images(image, gt, phantast_pred, model_pred, model_name):
    """Show an image, its ground truth and two predictions in one row.

    Args:
        image: Brightfield image.
        gt: Ground-truth mask.
        phantast_pred: Mask predicted by PHANTAST.
        model_pred: Mask predicted by the model being compared.
        model_name: Name used in the title of the last panel.
    """
    panels = (
        (image, 'Brightfield Image'),
        (gt, 'Ground Truth'),
        (phantast_pred, 'Phantast Prediction'),
        (model_pred, f'{model_name} Prediction'),
    )

    _, axs = plt.subplots(1, 4, figsize=(24, 6))
    for ax, (picture, title) in zip(axs, panels):
        ax.imshow(picture, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    plt.show()
    plt.close()

def visualise_difference_map(ground_truth, prediction, image):
    """Plot image, ground truth, prediction and a colour-coded difference map.

    The difference map colours each pixel by comparing the two masks against
    the values 0 and 1: white = true positive, black = true negative,
    blue = false positive, red = false negative. Pixels whose value is neither
    0 nor 1 in either mask stay black.

    Args:
        ground_truth: 2-D ground-truth mask.
        prediction: Predicted mask, compared element-wise with ``ground_truth``.
        image: Original image shown in the first panel.
    """
    height, width = ground_truth.shape

    gt_cell, gt_background = ground_truth == 1, ground_truth == 0
    pred_cell, pred_background = prediction == 1, prediction == 0

    outcome_colours = (
        (gt_cell & pred_cell, [255, 255, 255]),        # White for true positives (cell)
        (gt_background & pred_background, [0, 0, 0]),  # Black for true negatives (background)
        (gt_background & pred_cell, [0, 0, 255]),      # Blue for false positives
        (gt_cell & pred_background, [255, 0, 0]),      # Red for false negatives
    )

    # Start from an all-black RGB canvas and paint each outcome in turn
    difference_rgb = np.zeros((height, width, 3), dtype=np.uint8)
    for outcome_mask, colour in outcome_colours:
        difference_rgb[outcome_mask] = colour

    # (title, data, extra imshow arguments); the RGB map needs no colour map
    panels = (
        ('Original Image', image, {'cmap': 'gray'}),
        ('Ground Truth', ground_truth, {'cmap': 'gray'}),
        ('Prediction', prediction, {'cmap': 'gray'}),
        ('Difference Map', difference_rgb, {}),
    )

    plt.figure(figsize=(18, 6))
    for position, (title, data, imshow_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        plt.imshow(data, **imshow_kwargs)
        plt.title(title, fontsize=25)  # Make font size bigger for report
        plt.axis('off')

    plt.tight_layout()
    plt.show()


# Recall = TP / (TP + FN)
def calc_recall_score(gt_mask, pred_mask):
    """Return the recall of a predicted mask against the ground truth.

    Both masks are binarised with a 0.5 threshold, the same rule ``calc_iou``
    applies, so 0/1, boolean and 0/255 masks all give the same score. Without
    this, multiplying 0/255 ``uint8`` masks wraps around and yields a wrong
    result.

    Args:
        gt_mask: Ground-truth mask (array-like).
        pred_mask: Predicted mask with the same shape as ``gt_mask``.

    Returns:
        Recall rounded to 6 decimal places, or 0 when the ground truth has no
        foreground pixels.
    """
    gt_foreground = np.asarray(gt_mask) > 0.5
    pred_foreground = np.asarray(pred_mask) > 0.5

    total_pixel_truth = np.sum(gt_foreground)
    if total_pixel_truth == 0:
        return 0

    intersect = np.sum(gt_foreground & pred_foreground)
    return round(intersect / total_pixel_truth, 6)

# Precision = TP / (TP + FP)
def calc_precision_score(gt_mask, pred_mask):
    """Return the precision of a predicted mask against the ground truth.

    Both masks are binarised with a 0.5 threshold, the same rule ``calc_iou``
    applies, so 0/1, boolean and 0/255 masks all give the same score. Without
    this, multiplying 0/255 ``uint8`` masks wraps around and yields a wrong
    result.

    Args:
        gt_mask: Ground-truth mask (array-like).
        pred_mask: Predicted mask with the same shape as ``gt_mask``.

    Returns:
        Precision rounded to 6 decimal places, or 0 when the prediction has
        no foreground pixels.
    """
    gt_foreground = np.asarray(gt_mask) > 0.5
    pred_foreground = np.asarray(pred_mask) > 0.5

    total_pixel_pred = np.sum(pred_foreground)
    if total_pixel_pred == 0:
        return 0

    intersect = np.sum(gt_foreground & pred_foreground)
    return round(intersect / total_pixel_pred, 6)

# Dice = 2TP / (2TP + FP + FN)
def calc_dice_coef(gt_mask, pred_mask):
    """Return the Dice coefficient between a predicted and a ground-truth mask.

    Both masks are binarised with a 0.5 threshold, the same rule ``calc_iou``
    applies, so 0/1, boolean and 0/255 masks all give the same score. Without
    this, multiplying 0/255 ``uint8`` masks wraps around and yields a wrong
    result.

    Args:
        gt_mask: Ground-truth mask (array-like).
        pred_mask: Predicted mask with the same shape as ``gt_mask``.

    Returns:
        Dice coefficient rounded to 6 decimal places, or 0 when both masks
        are empty.
    """
    gt_foreground = np.asarray(gt_mask) > 0.5
    pred_foreground = np.asarray(pred_mask) > 0.5

    total_sum = np.sum(pred_foreground) + np.sum(gt_foreground)
    if total_sum == 0:
        return 0

    intersect = np.sum(gt_foreground & pred_foreground)
    return round(2 * intersect / total_sum, 6)
    
def calc_iou(gt_mask, pred_mask):
    """Return the intersection-over-union of two masks.

    Each mask is binarised first: values above 0.5 count as foreground.

    Args:
        gt_mask: Ground-truth mask array.
        pred_mask: Predicted mask array with the same shape.

    Returns:
        Number of pixels that are foreground in both masks divided by the
        number that are foreground in at least one, or 0 when neither mask
        has any foreground.
    """
    gt_foreground = gt_mask > 0.5
    pred_foreground = pred_mask > 0.5

    union_pixels = np.sum(np.logical_or(gt_foreground, pred_foreground))
    if union_pixels == 0:
        return 0

    shared_pixels = np.sum(np.logical_and(gt_foreground, pred_foreground))
    return shared_pixels / union_pixels

def plot_confusion_matrix(tp, fp, tn, fn, model_name):
    """Draw a 2x2 confusion-matrix heatmap for one model.

    Rows are the actual class and columns the predicted class, so the cells
    read ``[[TP, FN], [FP, TN]]``.

    Args:
        tp: True-positive count.
        fp: False-positive count.
        tn: True-negative count.
        fn: False-negative count.
        model_name: Name shown in the plot title.
    """
    # The heatmap uses integer formatting, so round every count first
    tp, fp, tn, fn = (round(count) for count in (tp, fp, tn, fn))

    cell_counts = np.array([
        [tp, fn],
        [fp, tn],
    ])

    sns.heatmap(
        cell_counts,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=['Predicted Positive', 'Predicted Negative'],
        yticklabels=['Actual Positive', 'Actual Negative'],
    )
    plt.title(f'Confusion Matrix for {model_name}')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()


def plot_side_by_side_images(dir1, dir2):
    """Plot the images of two directories pairwise, one figure per pair.

    Image files (``.tif``, ``.tiff``, ``.png``, ``.jpg``) are naturally sorted
    in each directory and paired by position.

    Args:
        dir1: Directory whose images go in the left panel.
        dir2: Directory whose images go in the right panel.

    Raises:
        AssertionError: If the directories hold different numbers of images.
    """
    image_suffixes = ('.tif', '.tiff', '.png', '.jpg')
    images1, images2 = (
        natsorted([name for name in os.listdir(folder) if name.endswith(image_suffixes)])
        for folder in (dir1, dir2)
    )

    assert len(images1) == len(images2), "Directories contain different amounts of images."

    for name1, name2 in zip(images1, images2):
        left_image = tiff.imread(os.path.join(dir1, name1))
        right_image = tiff.imread(os.path.join(dir2, name2))

        # One figure per pair, one panel per directory
        _, axs = plt.subplots(1, 2, figsize=(10, 5))
        for ax, picture, folder in zip(axs, (left_image, right_image), (dir1, dir2)):
            ax.imshow(picture, cmap='gray')
            ax.set_title(f'Image from {folder}')
            ax.axis('off')

        plt.tight_layout()
        plt.show()
        plt.close()

In [None]:
# Get test indexes (needed to line up images from other models during stack separation and renaming).
# The index is the text after the last underscore of the file name, up to the first dot.
indexes = [
    os.path.basename(label).rsplit('_', 1)[-1].partition('.')[0]
    for label in test_gt_paths
]
print(indexes)

In [None]:
# Store model segmentations in 'model_segmentations' dictionary

dir_path = "/vol/biomedic3/bglocker/mscproj24/nma23/data/testing_directory/multi_model/eval_test_data/"

# Models to evaluate. Each one is expected to have a directory
# 'Model_<name>/Segmentation_Output' inside dir_path.
eval_models = [
    'PHANTAST',
    'vit_b',
    'vit_b_lm_amg',
    'vit_b_lm_ais',
    'vit_l',
    'vit_l_lm_amg',
    'vit_l_lm_ais',
    'dilated_vit_l_lm_ais',
    'cellpose3',
    'ensemble_1',
    'ensemble_2',
    'ensemble_3',
    'ensemble_4_w_PHANTAST',
    'ensemble_5_w_PHANTAST',
    'finetuned_vit_b',
    'finetuned_vit_b_lm',
    'finetuned_vit_l',
    'finetuned_vit_l_lm',
]

# Maps each model name to the list of its segmentation file paths
model_segmentations = {}

for model in eval_models:
    model_dir = os.path.join(dir_path, f"Model_{model}/Segmentation_Output")
    segmentation_paths = list_files_in_directory(model_dir)
    model_segmentations[model] = segmentation_paths

    # Report any model that does not have the expected 150 segmentations
    if len(segmentation_paths) != 150:
        print(f"Model {model} has {len(model_segmentations[model])} segmentations")

# Print dictionary keys
print(f"Dictionary 'model_segmentations' keys: {model_segmentations.keys()}")

In [None]:
# Plot normal and dilated predictions side by side
def plot_normal_dilated(normal, dilated):
    """Show a prediction and its dilated version next to each other.

    Args:
        normal: Prediction mask shown in the left panel.
        dilated: Dilated prediction mask shown in the right panel.
    """
    fontsize = 25
    panels = (
        (normal, 'Normal Prediction'),
        (dilated, 'Dilated Prediction'),
    )

    _, axs = plt.subplots(1, 2, figsize=(12, 6))
    for ax, (mask, title) in zip(axs, panels):
        ax.imshow(mask, cmap='gray')
        ax.set_title(title, fontsize=fontsize)
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    plt.close()

# Compare the same test image (position 4 in natural sort order) with and
# without dilation. The entries of model_segmentations are already full paths
# built by list_files_in_directory, so they are used directly: joining them
# onto dir_path again only worked because os.path.join drops every component
# that precedes an absolute one.
sample_position = 4
normal_image_path = model_segmentations['vit_l_lm_ais'][sample_position]
dilated_image_path = model_segmentations['dilated_vit_l_lm_ais'][sample_position]

normal_image = tiff.imread(normal_image_path)
dilated_image = tiff.imread(dilated_image_path)

plot_normal_dilated(normal_image, dilated_image)



In [None]:
# Calculate metrics for all models.
# results_list gets exactly one row per (model, image). Per-model totals are
# derived afterwards with groupby().sum(); appending a separate totals row here
# as well would make that sum count every pixel twice.
results_list = []

for model in eval_models:
    print(f"Running metrics evaluation for {model}")

    # Running confusion-matrix totals for this model (kept available for ad-hoc use)
    tp_total = 0
    fp_total = 0
    tn_total = 0
    fn_total = 0
    count = len(model_segmentations[model])

    # Ground truth and prediction are paired by position in natural sort order
    for i, pred_path in enumerate(model_segmentations[model]):
        test_gt = tiff.imread(os.path.join(test_gt_normalised_mask_dir, gt_relative_filepaths[i]))
        test_pred = tiff.imread(pred_path)

        # Confusion matrix values for this image (masks are compared against 0 and 1)
        tp = np.sum((test_gt == 1) & (test_pred == 1))
        fp = np.sum((test_gt == 0) & (test_pred == 1))
        tn = np.sum((test_gt == 0) & (test_pred == 0))
        fn = np.sum((test_gt == 1) & (test_pred == 0))

        tp_total += tp
        fp_total += fp
        tn_total += tn
        fn_total += fn

        # Calculate recall, precision, dice, and IoU
        recall = calc_recall_score(test_gt, test_pred)
        precision = calc_precision_score(test_gt, test_pred)
        dice = calc_dice_coef(test_gt, test_pred)
        iou = calc_iou(test_gt, test_pred)

        # Store metrics for each image in results list
        results_list.append({
            "Model": model,
            "Recall": recall,
            "Precision": precision,
            "Dice Coefficient": dice,
            "IoU": iou,
            "True Positives": tp,
            "False Positives": fp,
            "True Negatives": tn,
            "False Negatives": fn,
        })

# Convert list to DataFrame (one row per model and image)
results_df = pd.DataFrame(results_list)

# Per-model mean of the per-image Recall, Precision, Dice and IoU
general_metrics_df = results_df[["Model", "Recall", "Precision", "Dice Coefficient", "IoU"]].dropna()
general_metrics_df = general_metrics_df.groupby("Model").mean().reset_index()

# Per-model totals of TP, FP, TN, FN (summed over images, not averaged)
tp_fp_tn_fn_df = results_df[["Model", "True Positives", "False Positives", "True Negatives", "False Negatives"]].dropna()
tp_fp_tn_fn_df = tp_fp_tn_fn_df.groupby("Model").sum().reset_index()

# Sort and display the general metrics DataFrame by Dice Coefficient in descending order
general_metrics_df = general_metrics_df.sort_values(by="Dice Coefficient", ascending=False)
print("Average Metrics for Each Model, Sorted by Dice Coefficient")
display(general_metrics_df)


From this point on, only the models of interest are analysed.

In [5]:
# Create models_of_interest list: the finetuned models first, then the two
# reference methods.
finetuned_models_of_interest = [
    'finetuned_vit_l',
    'finetuned_vit_l_lm',
    'finetuned_vit_b_lm',
    'finetuned_vit_b',
]
reference_models_of_interest = ['PHANTAST', 'cellpose3']

models_of_interest = finetuned_models_of_interest + reference_models_of_interest

In [None]:
# Calculate confusion metrics

# Keep only the models of interest. The explicit copy makes the result an
# independent DataFrame, so adding a column below neither triggers pandas'
# SettingWithCopyWarning nor depends on view-versus-copy behaviour.
filtered_results_df = tp_fp_tn_fn_df[tp_fp_tn_fn_df['Model'].isin(models_of_interest)].copy()

# FP/FN ratio per model. A model with zero false negatives yields inf
# (or NaN when it also has zero false positives).
filtered_results_df["FP/FN Ratio"] = filtered_results_df["False Positives"] / filtered_results_df["False Negatives"]

# Display the TP, FP, TN, FN counts and FP/FN ratios
print("Total TP, FP, TN, FN for Models of Interest")
display(filtered_results_df)

print("False Positives to False Negatives Ratios for Models of Interest")
display(filtered_results_df[["Model", "FP/FN Ratio"]])

In [None]:
# Plot difference maps between ground truth and prediction masks

# Switch between the two lines below to choose the model to inspect.
# pred_dir = '/vol/biomedic3/bglocker/mscproj24/nma23/data/testing_directory/multi_model/eval_test_data/Model_finetuned_vit_l_lm/Segmentation_Output'
pred_dir = "/vol/biomedic3/bglocker/mscproj24/nma23/data/testing_directory/multi_model/eval_test_data/Model_finetuned_vit_b_lm/Segmentation_Output"

# List predictions with the shared helper so the combined 'Mask_Stack.tiff'
# and any sub-directory are skipped. A bare os.listdir would include them and
# could shift the positions relative to gt_files and image_files.
pred_files = list_files_in_directory(pred_dir)
pred_abs_files = [os.path.basename(path) for path in pred_files]  # bare file names

# NOTE: masks loaded from elsewhere may need rescaling to the 0-1 range first,
# e.g. gt_mask = (gt_mask / np.max(gt_mask)).astype(np.uint8), because
# visualise_difference_map compares pixel values against 0 and 1.

# gt_files, pred_files and image_files are paired by position, so they are
# assumed to be in the same order and of the same length.
index = 60

gt_file = gt_files[index]
pred_file = pred_files[index]
image_file = image_files[index]

gt_mask = tiff.imread(gt_file)
pred_mask = tiff.imread(pred_file)
image = tiff.imread(image_file)

# Print file paths for debugging
print(f"GT File: {gt_file}")
print(f"Pred File: {pred_file}")
print(f"Image File: {image_file}")

# Visualise the difference map for the image at position `index`
visualise_difference_map(gt_mask, pred_mask, image)



Find images where PHANTAST outperforms a given model

In [None]:
# Find images where PHANTAST outperforms a given model

dice_comparison_data = []
win_counts_data = []

# PHANTAST is the reference itself and cellpose3 is left out of this
# comparison, so neither is compared against PHANTAST.
excluded_from_comparison = ('PHANTAST', 'cellpose3')
compared_models = [name for name in models_of_interest if name not in excluded_from_comparison]

# Read every ground-truth mask once and reuse it for all models instead of
# re-reading the same files from disk for each model.
gt_masks = [tiff.imread(gt_path) for gt_path in gt_files]
num_images = len(gt_masks)

# Calculate PHANTAST Dice Coefficients for each image
phantast_dice = []
phantast_masks = []
for i, gt_mask in enumerate(gt_masks):
    phantast_mask = tiff.imread(model_segmentations['PHANTAST'][i])
    dice_value = calc_dice_coef(gt_mask, phantast_mask)
    phantast_dice.append(dice_value)
    phantast_masks.append(phantast_mask)

    dice_comparison_data.append({'Model': 'PHANTAST', 'Image_Index': i, 'Dice_Coefficient': dice_value})

# Loop through each model and compare with PHANTAST
for model in compared_models:
    phantast_win = 0
    model_win = 0

    for i, gt_mask in enumerate(gt_masks):
        model_mask = tiff.imread(model_segmentations[model][i])
        model_dice = calc_dice_coef(gt_mask, model_mask)

        # Append model's Dice Coefficient to the comparison data
        dice_comparison_data.append({'Model': model, 'Image_Index': i, 'Dice_Coefficient': model_dice})

        # Compare PHANTAST and the current model for this image
        if phantast_dice[i] > model_dice:
            print(f"PHANTAST outperforms {model} on image {i} with Dice Coefficient {phantast_dice[i]} > {model_dice}")

            # Show the image, ground truth and both predictions for inspection
            plot_four_images(tiff.imread(image_files[i]), gt_mask, phantast_masks[i], model_mask, model)
            phantast_win += 1
        elif model_dice > phantast_dice[i]:
            model_win += 1
        else:
            # Ties count for neither side
            print(f"PHANTAST and {model} have the same Dice Coefficient of {phantast_dice[i]}")

    # Store win counts
    win_counts_data.append({'Model': model, 'PHANTAST_Wins': phantast_win, 'Model_Wins': model_win})

# Convert to DataFrames
dice_comparison_df = pd.DataFrame(dice_comparison_data)
win_counts_df = pd.DataFrame(win_counts_data)

# Display the win counts; the image count is taken from the data instead of
# being hard-coded.
print(f"Win Counts Out of {num_images} Images")
for win_counts in win_counts_data:
    print(f"Win Counts for {win_counts['Model']}:")
    print(f"PHANTAST Wins: {win_counts['PHANTAST_Wins']}")
    print(f"{win_counts['Model']} Wins: {win_counts['Model_Wins']}")
    print()


Plot violin plots of Dice Coefficient for PHANTAST vs models

In [None]:
# Plot violin plots of Dice coefficient for PHANTAST vs models

# NOTE: where the plot is wider, there is more data at that value

plt.figure(figsize=(12, 6))

# violinplot draws on the current axes and returns them
violin_ax = sns.violinplot(x='Model', y='Dice_Coefficient', data=dice_comparison_df)
violin_ax.set_title('Dice Coefficient Comparison Between PHANTAST and Finetuned Models')
violin_ax.set_ylabel('Dice Coefficient')

plt.xticks(rotation=45)
plt.show()

In [36]:
# Disk performance
# Recall/precision quartiles of the predictions in dir2 scored against the
# labels in dir1. NOTE(review): "disk" presumably refers to the 'disk30'
# setting in the prediction directory name; confirm.

dir1 = "/vol/biomedic3/bglocker/mscproj24/nma23/data/testing_directory/multi_model/eval_test_data/inhomogeneous_light_exps/inhomogeneous_light_norm_labels"
dir2 = "/vol/biomedic3/bglocker/mscproj24/nma23/data/testing_directory/multi_model/eval_test_data/inhomogeneous_light_exps/disk30_finetuned_vit_l_lm/Segmentation_Output"

# Naturally sorted image file names in each directory; pairs are formed by position
images1 = natsorted([name for name in os.listdir(dir1) if name.endswith(('.tif', '.tiff', '.png', '.jpg'))])
images2 = natsorted([name for name in os.listdir(dir2) if name.endswith(('.tif', '.tiff', '.png', '.jpg'))])

recall_scores = []
precision_scores = []

# Score every (label, prediction) pair
for img1, img2 in zip(images1, images2):
    image1 = tiff.imread(os.path.join(dir1, img1))
    image2 = tiff.imread(os.path.join(dir2, img2))

    recall = calc_recall_score(image1, image2)
    precision = calc_precision_score(image1, image2)

    recall_scores.append(recall)
    precision_scores.append(precision)

# Convert scores to numpy arrays
recall_scores = np.array(recall_scores)
precision_scores = np.array(precision_scores)

# Q1, median and Q3 of each metric
recall_iqr = np.percentile(recall_scores, [25, 50, 75])
precision_iqr = np.percentile(precision_scores, [25, 50, 75])

# Build one summary per metric: the three quartiles plus the IQR (Q3 - Q1)
quartile_labels = (
    "25th Percentile (Q1)",
    "50th Percentile (Median)",
    "75th Percentile (Q3)",
)
iqr_data = {}
for metric_name, quartiles in (("Recall", recall_iqr), ("Precision", precision_iqr)):
    summary = dict(zip(quartile_labels, quartiles))
    summary["IQR"] = quartiles[2] - quartiles[0]
    iqr_data[metric_name] = summary

# One column per metric, one row per statistic
iqr_df = pd.DataFrame(iqr_data)

print(iqr_df)


                            Recall  Precision
25th Percentile (Q1)      0.249483   0.110745
50th Percentile (Median)  0.999499   0.546912
75th Percentile (Q3)      0.999910   0.700654
IQR                       0.750427   0.589908
