In [1]:
import numpy as np
import matplotlib.pyplot as plt
import cv2

from auxiliary import values as v
from auxiliary.utils.colors import bcolors as c
from auxiliary.data import imaging
from auxiliary.utils.timer import LoadingBar

from filtering import cardiac_region as cr
from filtering.run_filter_tissue import filter_by_tissue
from feature_extraction.feature_extractor import filter_by_margin, filter_by_volume
from auxiliary.data.dataset_ht import HtDataset
from nuclei_segmentation.run_cellpose import predict
import json


# NOTE(review): `ds` is not referenced in any of the cells shown here — confirm it is still needed.
ds = HtDataset()

GPU activated: False


In [2]:
# Raw nuclei (DAPI) volume and its ground-truth nuclei mask, both read with XYZ axes.
# NOTE(review): both files belong to acquisition 20190208_E2, which does not obviously
# correspond to `specimen` below — confirm which identifier is intended.
specimen = '0806_E5'
img_path, img_path_gt = (
    v.data_path + 'Gr1/RawImages/Nuclei/20190208_E2_DAPI_decon_0.5_crop.nii.gz',
    v.data_path + 'Gr1/Segmentation/Nuclei/20190208_E2_nuclei_mask_crop_GT.nii.gz',
)

img = imaging.read_image(img_path, axes='XYZ', verbose=1)
img_gt = imaging.read_image(img_path_gt, axes='XYZ', verbose=1)

[94mReading NIfTI[0m: /run/user/1003/gvfs/smb-share:server=tierra.cnic.es,share=sc/LAB_MT/LAB/Ignacio/Gr1/RawImages/Nuclei/20190208_E2_DAPI_decon_0.5_crop.nii.gz
[94mReading NIfTI[0m: /run/user/1003/gvfs/smb-share:server=tierra.cnic.es,share=sc/LAB_MT/LAB/Ignacio/Gr1/Segmentation/Nuclei/20190208_E2_nuclei_mask_crop_GT.nii.gz


In [11]:
import SimpleITK as sitk
import pandas as pd
# from ctrl.algorithm.image_overlap import fast_image_overlap3d
from auxiliary.utils.timer import LoadingBar

import concurrent.futures


def pixel_accuracy(pred, gt):
    """Share of voxels whose foreground/background state agrees between two volumes.

    Both inputs are binarised first (any non-zero value is foreground), so only the
    foreground decision of each voxel is compared, not the instance IDs.

    Args:
        pred: predicted label array.
        gt: ground-truth label array with the same shape as ``pred``.

    Returns:
        Number of agreeing voxels divided by ``pred.size``.
    """
    agreement = (pred > 0).astype(np.uint8) == (gt > 0).astype(np.uint8)
    n_correct = np.sum(agreement)
    return n_correct / pred.size

def iou(pred, gt):
    """Foreground intersection-over-union (Jaccard index) of two label arrays.

    Any non-zero value counts as foreground.

    Args:
        pred: predicted label array.
        gt: ground-truth label array with the same shape as ``pred``.

    Returns:
        |pred ∩ gt| / |pred ∪ gt| over foreground voxels, or 0.0 when both arrays
        are entirely background.
    """
    pred_fg = pred > 0
    gt_fg = gt > 0

    union = np.logical_or(pred_fg, gt_fg).sum()
    if union == 0:
        return 0.0
    return np.logical_and(pred_fg, gt_fg).sum() / union

def precision(pred, gt):
    """Foreground precision TP / (TP + FP) of a predicted label array.

    Any non-zero value counts as foreground.

    Args:
        pred: predicted label array.
        gt: ground-truth label array with the same shape as ``pred``.

    Returns:
        The precision, or 0.0 when the prediction has no foreground voxel.
    """
    pred_fg = pred > 0
    gt_fg = gt > 0

    tp = np.sum(pred_fg & gt_fg)
    fp = np.sum(pred_fg & ~gt_fg)
    predicted = tp + fp
    return tp / predicted if predicted != 0 else 0.0

def recall(pred, gt):
    """Foreground recall TP / (TP + FN) of a predicted label array.

    Any non-zero value counts as foreground.

    Args:
        pred: predicted label array.
        gt: ground-truth label array with the same shape as ``pred``.

    Returns:
        The recall, or 0.0 when the ground truth has no foreground voxel.
    """
    pred_fg = pred > 0
    gt_fg = gt > 0

    tp = np.sum(pred_fg & gt_fg)
    fn = np.sum(~pred_fg & gt_fg)
    actual = tp + fn
    if actual == 0:
        return 0.0
    return tp / actual

def dice_coefficient(pred, gt):
    """Foreground Sørensen–Dice coefficient of two label arrays.

    Any non-zero value counts as foreground.

    Args:
        pred: predicted label array.
        gt: ground-truth label array with the same shape as ``pred``.

    Returns:
        2 |pred ∩ gt| / (|pred| + |gt|), or 0.0 when both arrays are entirely
        background (same convention as ``iou``; previously this divided by zero and
        produced nan with a RuntimeWarning).
    """
    pred_fg = pred > 0
    gt_fg = gt > 0

    denominator = pred_fg.sum() + gt_fg.sum()
    if denominator == 0:
        return 0.0
    return 2 * np.logical_and(pred_fg, gt_fg).sum() / denominator


def f1_score(pred, gt):
    """Harmonic mean of foreground precision and recall.

    For binary masks this is algebraically equal to the Dice coefficient
    (2TP / (2TP + FP + FN)).

    Args:
        pred: predicted label array.
        gt: ground-truth label array with the same shape as ``pred``.

    Returns:
        2PR / (P + R), or 0.0 when precision and recall are both zero.
    """
    p, r = precision(pred, gt), recall(pred, gt)
    denominator = p + r
    return 0.0 if denominator == 0 else 2 * (p * r) / denominator

def hausdorff_distance(pred, gt, spacing=None):
    """Symmetric Hausdorff distance between the foregrounds of two label arrays.

    Any non-zero value counts as foreground. The distance is computed with
    SimpleITK's ``HausdorffDistanceImageFilter``.

    Args:
        pred: predicted label array.
        gt: ground-truth label array with the same shape as ``pred``.
        spacing: optional voxel size per axis, given in the SAME axis order as the
            numpy arrays. When None (the default, the original behaviour) the images
            keep unit spacing, so the distance is in voxel units.

    Returns:
        The Hausdorff distance, in voxels or in the units of ``spacing``.

    Raises:
        ValueError: if ``spacing`` does not have one entry per array dimension.
    """
    pred_image = sitk.GetImageFromArray((pred > 0).astype(np.uint8))
    gt_image = sitk.GetImageFromArray((gt > 0).astype(np.uint8))

    if spacing is not None:
        if len(spacing) != pred_image.GetDimension():
            raise ValueError(
                f'spacing needs {pred_image.GetDimension()} entries, got {len(spacing)}'
            )
        # GetImageFromArray maps numpy axes onto image axes in reverse order
        # (last numpy axis -> image x), so the spacing has to be reversed as well.
        image_spacing = [float(s) for s in reversed(spacing)]
        pred_image.SetSpacing(image_spacing)
        gt_image.SetSpacing(image_spacing)

    hausdorff_filter = sitk.HausdorffDistanceImageFilter()
    hausdorff_filter.Execute(pred_image, gt_image)

    return hausdorff_filter.GetHausdorffDistance()

def jacquard_index(pred, gt):
    """Volume-averaged Jaccard index (VJI) between two instance label volumes.

    The name keeps its historical spelling ("jacquard") so existing callers keep
    working; the metric computed is the Jaccard index.

    Every non-zero label is a cell. Each predicted cell is paired with the
    ground-truth cell that maximises their Jaccard index. Ties go to the lowest
    ground-truth label, so a predicted cell overlapping nothing is paired with the
    lowest ground-truth label at Jaccard 0. Ground-truth cells that no predicted
    cell picked are added with Jaccard 0. Every pair is weighted by the volume
    (voxel count) of its ground-truth cell:

        VJI = sum_i(J_i * V_i) / sum_i(V_i)

    A ground-truth cell picked by several predicted cells contributes its volume
    once per pick.

    Args:
        pred: predicted label volume.
        gt: ground-truth label volume with the same shape as ``pred``.

    Returns:
        The VJI as a float, or 0.0 when ``gt`` contains no cells (previously an
        empty ``pred`` or ``gt`` crashed on the DataFrame column assignment).

    Raises:
        ValueError: if ``pred`` and ``gt`` differ in shape.
    """
    pred = np.asarray(pred)
    gt = np.asarray(gt)
    if pred.shape != gt.shape:
        raise ValueError(f'pred and gt must have the same shape, got {pred.shape} and {gt.shape}')

    # Contingency table: overlap[i, j] = number of voxels labelled pred_labels[i] in
    # pred and gt_labels[j] in gt. One pass over the volume replaces the former
    # full-volume mask comparison for every (pred, gt) pair. Memory is
    # O(n_pred_labels * n_gt_labels).
    pred_labels, pred_idx = np.unique(pred.ravel(), return_inverse=True)
    gt_labels, gt_idx = np.unique(gt.ravel(), return_inverse=True)
    n_pred_labels, n_gt_labels = pred_labels.size, gt_labels.size
    overlap = np.bincount(
        pred_idx.ravel() * n_gt_labels + gt_idx.ravel(),
        minlength=n_pred_labels * n_gt_labels,
    ).reshape(n_pred_labels, n_gt_labels)

    # Per-label volumes (voxel counts) are the marginals of the table.
    pred_volume = overlap.sum(axis=1)
    gt_volume = overlap.sum(axis=0)

    # Discard the background label 0 on both sides.
    pred_is_cell = pred_labels != 0
    gt_is_cell = gt_labels != 0
    overlap = overlap[np.ix_(pred_is_cell, gt_is_cell)]
    pred_volume = pred_volume[pred_is_cell]
    gt_volume = gt_volume[gt_is_cell]

    if gt_volume.size == 0:
        return 0.0

    # |A ∪ B| = |A| + |B| - |A ∩ B|; always > 0 because every kept label has volume > 0.
    union = pred_volume[:, None] + gt_volume[None, :] - overlap
    jaccard = overlap / union

    # For each predicted cell, the ground-truth cell with the highest Jaccard index
    # (argmax returns the first, i.e. lowest-label, maximum on ties).
    best_gt = jaccard.argmax(axis=1)
    best_jaccard = jaccard[np.arange(best_gt.size), best_gt]
    matched_volume = gt_volume[best_gt]

    # Ground-truth cells never picked enter with Jaccard 0 but still add their volume.
    unmatched = np.ones(gt_volume.size, dtype=bool)
    unmatched[best_gt] = False

    total_volume = matched_volume.sum() + gt_volume[unmatched].sum()
    return float((best_jaccard * matched_volume).sum() / total_volume)

def compute_metrics(pred, gt):
    """Compute every whole-volume segmentation metric for one prediction.

    The overlap and distance metrics work on the binarised foreground; 'VJI' works
    on the instance labels.

    Args:
        pred: predicted label volume.
        gt: ground-truth label volume with the same shape as ``pred``.

    Returns:
        dict mapping metric name to value, in a fixed order.
    """
    print('Computing metrics...')

    return {
        'Pixel Accuracy': pixel_accuracy(pred, gt),
        'IoU': iou(pred, gt),
        'Precision': precision(pred, gt),
        'Recall': recall(pred, gt),
        'F1-Score': f1_score(pred, gt),
        'Dice Coefficient': dice_coefficient(pred, gt),
        'Hausdorff Distance': hausdorff_distance(pred, gt),
        'VJI': jacquard_index(pred, gt),
    }

def compute_metrics_for_cell(gt_cell, pred_cells, pred, gt):
    """Score one ground-truth cell against every predicted cell.

    IoU, precision, recall and Dice are maximised, and the Hausdorff distance is
    minimised, independently over ``pred_cells``. The five returned values may
    therefore come from different predicted cells.

    NOTE(review): because precision is maximised on its own, any predicted cell
    lying entirely inside the ground-truth cell scores precision 1.0. Consider
    matching on best IoU first and reporting that pair's metrics — confirm intent.

    Args:
        gt_cell: label of the ground-truth cell in ``gt``.
        pred_cells: predicted labels (in ``pred``) to compare against.
        pred: predicted label volume.
        gt: ground-truth label volume with the same shape as ``pred``.

    Returns:
        ``(best_iou, best_precision, best_recall, best_dice, best_hausdorff)``.
        With no predicted cells the scores stay 0.0 and the distance ``inf``.
    """
    gt_mask = (gt == gt_cell)
    gt_volume = np.count_nonzero(gt_mask)
    # Loop-invariant: build the ground-truth SimpleITK image once, not per pred cell.
    gt_image = sitk.GetImageFromArray(gt_mask.astype(np.uint8))

    best_iou = 0.0
    best_precision = 0.0
    best_recall = 0.0
    best_dice = 0.0
    best_hausdorff = float('inf')

    for pred_cell in pred_cells:
        pred_mask = (pred == pred_cell)
        pred_volume = np.count_nonzero(pred_mask)

        # A single intersection count serves every overlap metric:
        # TP = overlap, TP + FP = pred_volume, TP + FN = gt_volume,
        # |pred ∪ gt| = pred_volume + gt_volume - overlap.
        overlap = np.count_nonzero(np.logical_and(pred_mask, gt_mask))
        union = pred_volume + gt_volume - overlap
        volume_sum = pred_volume + gt_volume

        # `cell_` prefixes keep these locals from shadowing the module-level
        # iou / precision / recall / hausdorff_distance functions.
        cell_iou = overlap / union if union > 0 else 0.0
        cell_precision = overlap / pred_volume if pred_volume > 0 else 0.0
        cell_recall = overlap / gt_volume if gt_volume > 0 else 0.0
        cell_dice = 2 * overlap / volume_sum if volume_sum > 0 else 0.0

        best_iou = max(best_iou, cell_iou)
        best_precision = max(best_precision, cell_precision)
        best_recall = max(best_recall, cell_recall)
        best_dice = max(best_dice, cell_dice)

        pred_image = sitk.GetImageFromArray(pred_mask.astype(np.uint8))
        hausdorff_filter = sitk.HausdorffDistanceImageFilter()
        hausdorff_filter.Execute(pred_image, gt_image)
        best_hausdorff = min(best_hausdorff, hausdorff_filter.GetHausdorffDistance())

    return best_iou, best_precision, best_recall, best_dice, best_hausdorff


def compute_metrics_cell_aware(pred, gt):
    """Average per-cell metrics over all ground-truth cells.

    Every non-zero label of ``gt`` is scored against every non-zero label of
    ``pred`` with ``compute_metrics_for_cell``, one thread-pool task per
    ground-truth cell. The per-cell results are then averaged.

    Args:
        pred: predicted label volume.
        gt: ground-truth label volume with the same shape as ``pred``.

    Returns:
        dict with these keys:
        - 'Pixel Accuracy': foreground/background agreement over the whole volume.
        - 'Mean IoU', 'Mean Precision', 'Mean Recall', 'Mean Dice Coefficient' and
          'Mean Hausdorff Distance': the averaged per-cell results. These means are
          nan when ``gt`` contains no cells.
    """
    print('Computing metrics (cell-aware)...')

    # Keep only non-zero labels. Slicing off the first unique value (`[1:]`) would
    # silently drop a real cell whenever a volume has no background voxel.
    pred_labels = np.unique(pred)
    gt_labels = np.unique(gt)
    pred_cells = pred_labels[pred_labels != 0]
    gt_cells = gt_labels[gt_labels != 0]

    iou_values = []
    precision_values = []
    recall_values = []
    dice_values = []
    hausdorff_values = []

    # Parallelise over the ground-truth cells (nothing to do when there are none).
    if gt_cells.size > 0:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(compute_metrics_for_cell, gt_cell, pred_cells, pred, gt)
                for gt_cell in gt_cells
            ]

            bar = LoadingBar(len(futures))

            for future in concurrent.futures.as_completed(futures):
                best_iou, best_precision, best_recall, best_dice, best_hausdorff = future.result()
                iou_values.append(best_iou)
                precision_values.append(best_precision)
                recall_values.append(best_recall)
                dice_values.append(best_dice)
                hausdorff_values.append(best_hausdorff)

                bar.update()

            bar.end()

    def mean_or_nan(values):
        # np.mean([]) returns nan with a RuntimeWarning; make the empty case explicit.
        return np.mean(values) if values else float('nan')

    # Instance labels in pred and gt are numbered independently, so comparing raw
    # label IDs voxel-by-voxel only counts coincidental matches. Compare the
    # foreground/background decision instead (same definition as pixel_accuracy).
    pixel_acc = np.sum((pred > 0) == (gt > 0)) / pred.size

    return {
        'Pixel Accuracy': pixel_acc,
        'Mean IoU': mean_or_nan(iou_values),
        'Mean Precision': mean_or_nan(precision_values),
        'Mean Recall': mean_or_nan(recall_values),
        'Mean Dice Coefficient': mean_or_nan(dice_values),
        'Mean Hausdorff Distance': mean_or_nan(hausdorff_values),
    }


# NOTE(review): in the cells shown here this empty frame is never read. `df` is
# reassigned from `results` before use, and this column list lacks the per-cell
# 'Mean ...' metrics. `results` is no longer initialised here (the old
# `results = []` line was commented out); it is loaded from 'results.csv' instead.
# Consider removing this line.
df = pd.DataFrame(columns=['Method', 'Pixel Accuracy', 'IoU', 'Precision', 'Recall', 'F1-Score', 'Dice Coefficient', 'Hausdorff Distance', 'NO_cells', 'VJI'])

In [None]:
# Load previously saved rows as a list of dicts. The evaluation cell calls
# `results.append({...})` and reads `results[-1]`, which only work on a list:
# a DataFrame has no list-style `append` (removed in pandas 2.0), and `[-1]` would
# be a column lookup. Start empty when no results file exists yet.
try:
    results = pd.read_csv('results.csv').to_dict('records')
except FileNotFoundError:
    results = []

In [15]:
# Label under which this prediction is stored in `results`. Keep it in sync with
# `pred_path`; the label is not derived from the file.
method_name = '2D_5_6_45_M_BI_AD'
pred_path = v.data_path + 'Gr1/Segmentation/Nuclei/20190208_E2_nuclei_mask_crop_all.nii.gz'
pred = imaging.read_image(pred_path, axes='XYZ', verbose=1)

metric = compute_metrics(pred, img_gt)
metric_cell_aware = compute_metrics_cell_aware(pred, img_gt)

# `.append` and `[-1]` below need a list of row dicts. Convert `results` if it was
# loaded from CSV as a DataFrame.
if isinstance(results, pd.DataFrame):
    results = results.to_dict('records')

# Replace any earlier row for this method so that re-running the cell does not
# add duplicates.
results = [row for row in results if row.get('Method') != method_name]

results.append({
    'Method': method_name,
    'Pixel Accuracy': metric['Pixel Accuracy'],
    'IoU': metric['IoU'],
    'Mean IoU': metric_cell_aware['Mean IoU'],
    'Precision': metric['Precision'],
    'Mean Precision': metric_cell_aware['Mean Precision'],
    'Recall': metric['Recall'],
    'Mean Recall': metric_cell_aware['Mean Recall'],
    # For binary masks F1 equals the Dice coefficient: 2TP / (2TP + FP + FN).
    'F1-Score': metric['F1-Score'],
    'Dice Coefficient': metric['Dice Coefficient'],
    'Mean Dice Coefficient': metric_cell_aware['Mean Dice Coefficient'],
    'Hausdorff Distance': metric['Hausdorff Distance'],
    'Mean Hausdorff Distance': metric_cell_aware['Mean Hausdorff Distance'],
    # Number of predicted cells = distinct non-zero labels. This stays correct even
    # if the volume has no background voxel, unlike `len(np.unique(pred)) - 1`.
    'NO_cells': int(np.count_nonzero(np.unique(pred))),
    'VJI': metric['VJI'],
})

print(json.dumps(results[-1], indent=4))

[94mReading NIfTI[0m: /run/user/1003/gvfs/smb-share:server=tierra.cnic.es,share=sc/LAB_MT/LAB/Ignacio/Gr1/Segmentation/Nuclei/20190208_E2_nuclei_mask_crop_all.nii.gz
Computing metrics...
Computing metrics (cell-aware)...
{
    "Method": "2D_5_6_45_M_EQ_AD_BI",
    "Pixel Accuracy": 0.8124350649350649,
    "IoU": 0.5908118598688218,
    "Mean IoU": 0.36269017034601947,
    "Precision": 0.9617544304673178,
    "Mean Precision": 0.6251479983316262,
    "Recall": 0.605026656511805,
    "Mean Recall": 0.49177659353742464,
    "F1-Score": 0.7427803057962368,
    "Dice Coefficient": 0.7427803057962368,
    "Mean Dice Coefficient": 0.4775583695198937,
    "Hausdorff Distance": 17.69180601295413,
    "Mean Hausdorff Distance": 10.749163107860118,
    "NO_cells": 172,
    "VJI": 0.31265671080708357
}


In [16]:
# Tabulate all evaluated methods and persist them for the next session.
df = pd.DataFrame(results)
df.to_csv('results.csv', index=False)
# Last expression -> rich (HTML) table; print() wrapped this wide frame into
# several plain-text blocks.
df

                 Method  Pixel Accuracy       IoU  Mean IoU  Precision  \
0        2D_5_6_45_M_AD        0.859877  0.696358  0.426008   0.958647   
1        2D_5_6_45_M_BI        0.850231  0.672040  0.449228   0.971335   
2     2D_5_6_45_M_AD_BI        0.859429  0.696451  0.442156   0.954204   
3  2D_5_6_45_M_EQ_AD_BI        0.812435  0.590812  0.362690   0.961754   

   Mean Precision    Recall  Mean Recall  F1-Score  Dice Coefficient  \
0        0.632786  0.717923     0.594904  0.821004          0.821004   
1        0.713063  0.685638     0.582741  0.803857          0.803857   
2        0.687012  0.720535     0.606296  0.821068          0.821068   
3        0.625148  0.605027     0.491777  0.742780          0.742780   

   Mean Dice Coefficient  Hausdorff Distance  Mean Hausdorff Distance  \
0               0.542976           14.035669                10.130524   
1               0.569312           13.038405                 9.464205   
2               0.562793           10.198039     