In [7]:
from deep_radiologist.actions import locate_peaks
from deep_radiologist.data_loading import _load_point_data
import torchio as tio
import numpy as np
from scipy import spatial
import napari

In [12]:
def get_acc_metrics(y_hat, y):
    """Calculates accuracy metrics for a set of predicted and ground truth coordinates.

    Is a true positive if the distance between the predicted and closest ground truth coordinate
    is less than the correct_prediction_distance config parameter. Is a false positive if the 
    distance is greater than the correct_prediction_distance parameter or it already has a closer
    true positive. Is a false negative if the ground truth does not have a corresponding true
    positive.

    Args:
        y_hat (np.ndarray): predicted coordinates
        y (np.ndarray): ground truth coordinates

    Returns:
        tp (float): true positives
        fp (float): false positives
        fn (float): false negatives
        loc_errs (np.ndarray): location errors
    """

    tree = spatial.cKDTree(y)
    closest_dists, closest_nbrs = tree.query(y_hat, k=1)

    # if predictions are within distance of the same point, only keep the first one
    # this is to avoid repeated counting of true positives that are actually false positives
    # it doesn't matter which one is closer in this case, as we are just making a count
    removed_dup_indx = np.unique(closest_nbrs, return_index=True)[1]
    mask = np.zeros(closest_nbrs.shape, dtype='bool')
    mask[removed_dup_indx] = True

    true_positive = (closest_dists <= 5) & mask

    tp = len(true_positive[true_positive])
    fp = len(true_positive[~true_positive])
    fn = y.shape[0] - tp
    loc_errors = closest_dists[true_positive]

    if len(loc_errors) == 0:
        loc_errors = np.array([0])

    tp_groundtruth = closest_nbrs[true_positive]
    fn_mask = np.ones(y.shape[0], dtype='bool')
    fn_mask[tp_groundtruth] = False

    all_ground_truth = y
    fp_prediction = y_hat[~true_positive]
    fn_to_plot = y[fn_mask]
    tp_groundtruth = y_hat[closest_nbrs[true_positive]]
    tp_prediction = y_hat[true_positive]

    things_to_plot = [all_ground_truth, fp_prediction, fn_to_plot, tp_groundtruth, tp_prediction]

    print(f'True positives: {tp}, False positives: {fp}, False negatives: {fn}, N Real values: {y.shape[0]}, N Predicted values: {y_hat.shape[0]}')
    print(f'Percent correctly predicted {tp / y.shape[0] * 100}%')
    print(f'Mean Localisation error: {loc_errors.mean()}')

    return tp, fp, fn, loc_errors, things_to_plot

def evaluate(x, y, y_hat, plot=True):
    mct = tio.ScalarImage(x)
    prediction = tio.ScalarImage(y_hat)

    prediction_locations = locate_peaks(
        y_hat,
        save=True,
        plot=False,
        peak_min_dist=4,
        peak_min_val=0.2,
    )
    ground_truth_locations=np.loadtxt(
        y,
        delimiter=',',
        dtype=np.float
    ).astype(int).T

    tp, fp, fn, loc_errors, things_to_plot = get_acc_metrics(prediction_locations, ground_truth_locations)

    viewer = napari.view_points(things_to_plot[0], name='all ground truth', size=2, face_color='pink')
    viewer.add_points(things_to_plot[1], name='fp prediction', size=2, face_color='red')
    viewer.add_points(things_to_plot[2], name='fn', size=2, face_color='yellow')
    viewer.add_points(things_to_plot[3], name='tp groundtruth', size=2, face_color='blue')
    viewer.add_points(things_to_plot[4], name='tp prediction', size=2, face_color='green')

    return tp, fp, fn, loc_errors

True positives: 78, False positives: 5766, False negatives: 5174, N Real values: 5252, N Predicted values: 5844
Percent correctly predicted 1.4851485148514851%
Mean Localisation error: 3.847670550155179


<Points layer 'tp prediction' at 0x7f0545f3cfa0>

# Evaluation

Fiddler crab corneas

In [None]:


x = './dataset/fiddlercrab_corneas/whole/test_images_10/dampieri_male_16-image.nii.gz'
y = './dataset/fiddlercrab_corneas/whole/test_labels_10/dampieri_male_16-rhabdoms.csv'
y_hat = './output/dampieri_male_16-image.logs_fiddlercrab_rhabdoms_lightning_logs_version_4_checkpoints_last_prediction.nii.gz'

evaluate(x, y, y_hat)

Fiddler crab rhabdoms

In [8]:
x = './dataset/fiddlercrab_rhabdoms/whole/test_images_10/dampieri_male_16-image.nii.gz'
y = './dataset/fiddlercrab_rhabdoms/whole/test_labels_10/dampieri_male_16-rhabdoms.csv'
y_hat = './output/dampieri_male_16-image.logs_fiddlercrab_rhabdoms_lightning_logs_version_4_checkpoints_last_prediction.nii.gz'

evaluate(x, y, y_hat)