In [1]:
from deep_radiologist.actions import locate_peaks, init_data
from deep_radiologist.data_loading import _load_point_data
import torchio as tio
import numpy as np
from scipy import spatial
import napari
import yaml
from yaml.loader import SafeLoader

In [2]:
# def _get_acc_metrics(self, y_hat, y, k=3):
#         """Calculates accuracy metrics for a set of predicted and ground truth coordinates.

#         Is a true positive if the distance between the predicted and closest ground truth coordinate
#         is less than the correct_prediction_distance config parameter and that ground truth coordinate
#         doesn't already have a better matching prediction (tested up to k closest matches). Is a false
#         positive if the distance is greater than the correct_prediction_distance parameter or it
#         already has a closer true positive. Is a false negative if the ground truth does not have a
#         corresponding true positive.

#         Args:
#             y_hat (np.ndarray): predicted coordinates
#             y (np.ndarray): ground truth coordinates

#         Returns:
#             tp (float): true positives
#             fp (float): false positives
#             fn (float): false negatives
#             loc_errs (np.ndarray): location errors
#         """

#         correct_prediction_distance = 15

#         tree = spatial.cKDTree(y_hat)
#         closest_dists, closest_nbrs = tree.query(y, k=k)

#         y_match = list()
#         y_hat_match = list()
#         dists = list()

#         for i in range(k):
#             nbrs_k = closest_nbrs[:, i]
#             dists_k = closest_dists[:, i]

#             # sort by closest distance
#             sort_idx = np.argsort(dists_k)
#             nbrs_k = nbrs_k[sort_idx]
#             dists_k = dists_k[sort_idx]

#             for j in range(len(nbrs_k)):
#                 if j not in y_hat_match and y[j, i] not in y_match:
#                     y_hat_match.append(j)
#                     y_match.append(y[j, i])
#                     dists.append(dists_k[j])

#         dists = np.array(dists)

#         tp = len(dists[dists < correct_prediction_distance])
#         fp = len(y_hat) - tp
#         fn = len(y) - tp

#         loc_errors = dists[dists < correct_prediction_distance]

#         if len(loc_errors) == 0:
#             loc_errors = np.array([0])

#         fp_prediction = y_

#         things_to_plot = [y, fp_prediction, fn_groundtruth, tp_groundtruth, tp_prediction]

#         return tp, fp, fn, loc_errors, things_to_plot

# def evaluate(x, y, y_hat, plot=True):
#     mct = tio.ScalarImage(x)
#     prediction = tio.ScalarImage(y_hat)
#     viewer = napari.view_image(mct.to_numpy(), name="mct")
#     viewer.add_image(prediction.to_numpy(), name="prediction")



def get_acc_metrics(y_hat, y):
    """Calculates accuracy metrics for a set of predicted and ground truth coordinates.

    Is a true positive if the distance between the predicted and closest ground truth coordinate
    is less than the correct_prediction_distance config parameter. Is a false positive if the 
    distance is greater than the correct_prediction_distance parameter or it already has a closer
    true positive. Is a false negative if the ground truth does not have a corresponding true
    positive.

    Args:
        y_hat (np.ndarray): predicted coordinates
        y (np.ndarray): ground truth coordinates

    Returns:
        tp (float): true positives
        fp (float): false positives
        fn (float): false negatives
        loc_errs (np.ndarray): location errors
    """

    tree = spatial.cKDTree(y)
    closest_dists, closest_nbrs = tree.query(y_hat, k=1)

    # if predictions are within distance of the same point, only keep the first one
    # this is to avoid repeated counting of true positives that are actually false positives
    # it doesn't matter which one is closer in this case, as we are just making a count
    removed_dup_indx = np.unique(closest_nbrs, return_index=True)[1]
    mask = np.zeros(closest_nbrs.shape, dtype='bool')
    mask[removed_dup_indx] = True

    true_positive = (closest_dists <= 5) & mask

    tp = len(true_positive[true_positive])
    fp = len(true_positive[~true_positive])
    fn = y.shape[0] - tp
    loc_errors = closest_dists[true_positive]

    if len(loc_errors) == 0:
        loc_errors = np.array([0])

    tp_groundtruth = closest_nbrs[true_positive]
    fn_mask = np.ones(y.shape[0], dtype='bool')
    fn_mask[tp_groundtruth] = False

    all_ground_truth = y
    fp_prediction = y_hat[~true_positive]
    fn_to_plot = y[fn_mask]
    tp_prediction = y_hat[true_positive]

    things_to_plot = [all_ground_truth, fp_prediction, fn_to_plot, tp_prediction]

    print(f'True positives: {tp}, False positives: {fp}, False negatives: {fn}, N Real values: {y.shape[0]}, N Predicted values: {y_hat.shape[0]}')
    print(f'Percent correctly predicted {tp / y.shape[0] * 100}%')
    print(f'Mean Localisation error: {loc_errors.mean()}')
    print(f'SD Localisation error: {loc_errors.std()}')

    return tp, fp, fn, loc_errors, things_to_plot

def evaluate(x, y, y_hat, config, plot=True):
    mct = tio.ScalarImage(x)
    prediction = tio.ScalarImage(y_hat)

    with open(config, "r") as f:
        config = yaml.load(f, Loader=SafeLoader)

    data = init_data(config, run_internal_setup_func=False)
    preprocess = data.get_preprocessing_transform()

    print('Preprocessing volume for plotting')
    mct = preprocess(mct)

    prediction_locations = locate_peaks(
        y_hat,
        save=True,
        plot=False,
        peak_min_val=0.5,
    )
    ground_truth_locations=np.loadtxt(
        y,
        delimiter=',',
        dtype=np.float
    ).astype(int).T
    
    # flip axis 0 and 1
    ground_truth_locations[:,0] = mct.shape[1] - ground_truth_locations[:,0]
    ground_truth_locations[:,1] = mct.shape[2] - ground_truth_locations[:,1]

    tp, fp, fn, loc_errors, things_to_plot = get_acc_metrics(prediction_locations, ground_truth_locations)

    viewer = napari.view_points(things_to_plot[0], name='all ground truth', size=6, face_color='pink')
    viewer.add_image(mct.numpy(), name="x")
    viewer.add_image(prediction.numpy(), name="prediction")
    viewer.add_points(things_to_plot[1], name='fp prediction', size=6, face_color='red')
    viewer.add_points(things_to_plot[2], name='fn', size=6, face_color='yellow')
    viewer.add_points(things_to_plot[3], name='tp prediction', size=6, face_color='green')

    return tp, fp, fn, loc_errors

# Evaluation

Fiddler crab corneas

Image 1:

In [4]:
x = './dataset/fiddlercrab_corneas/whole/test_images_11.248750654736355/flammula_20190925_male_left-image.nii'
y = './output/flammula_20190925_male_left-image.zoo_fiddlercrab_corneas_version_4_checkpoints_last_prediction.nii.peaks.csv'
y_hat = './output/flammula_20190925_male_left-image.logs_fiddlercrab_corneas_lightning_logs_version_0_checkpoints_last_prediction.nii'
config = './configs/fiddlercrab_corneas.yaml'
evaluate(x, y, y_hat, config)

TorchIO version: 0.18.88
Preprocessing volume for plotting


TypeError: locate_peaks() got an unexpected keyword argument 'peak_min_dist'

Image 2:

In [None]:
x = './dataset/fiddlercrab_corneas/whole/test_images_10/flammula_20190925_male_left-image.nii.gz'
y = './dataset/fiddlercrab_corneas/whole/test_labels_10/flammula_20190925_male_left-corneas.csv'
y_hat = './output/flammula_20190925_male_left-image.zoo_fiddlercrab_corneas_version_4_checkpoints_last_prediction.nii.gz'
config = './configs/fiddlercrab_corneas.yaml'
evaluate(x, y, y_hat, config)

TorchIO version: 0.18.76
Preprocessing volume for plotting
Locating peaks...
Saving peaks...


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  dtype=np.float


True positives: 4001, False positives: 2205, False negatives: 4131, N Real values: 8132, N Predicted values: 6206
Percent correctly predicted 49.200688637481555%
Mean Localisation error: 3.292389412975271
SD Localisation error: 0.7647324972620112


(4001,
 2205,
 4131,
 array([3.31662479, 4.12310563, 3.74165739, ..., 4.        , 4.12310563,
        4.        ]))

Fiddler crab rhabdoms

In [None]:
x = './dataset/fiddlercrab_rhabdoms/whole/test_images_10/dampieri_male_16-image.nii.gz'
y = './dataset/fiddlercrab_rhabdoms/whole/test_labels_10/dampieri_male_16-rhabdoms.csv'
y_hat = './output/dampieri_male_16-image.zoo_fiddlercrab_rhabdoms_version_4_checkpoints_last_prediction.nii.gz'

evaluate(x, y, y_hat)

FileNotFoundError: File not found: "output/dampieri_male_16-image.zoo_fiddlercrab_rhabdoms_version_4_checkpoints_last_prediction.nii.gz"