In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
import tifffile as tiff
from scipy.ndimage import center_of_mass
from evaluate import load_roi, roi2masks_from_demix_result, count_matches, calc_f1_scores

In [None]:
# Parse the evaluation settings; the file content is read as JSON.
with open('settings.txt') as settings_file:
    settings = json.load(settings_file)

# Input locations taken from the settings file.
SUMMARY_IMAGE_DIR = settings['summary_image_dir']
CONSENSUS_GT_DIR = settings['consensus_gt_dir']
INDIVIDUAL_GT_DIR = settings['individual_gt_dir']

# Evaluation parameters taken from the settings file.
GT_IDS = settings['individual_gt_ids']
REPRESENTATIVE_IOU = settings['representative_iou']

# NOTE(review): the '@@@...' values look like template placeholders that are
# presumably substituted before the notebook is executed - confirm with the
# tool that generates/runs this notebook. The literals are kept unchanged.
DEMIX_DIR = '@@@DEMIX_DIR'
EVAL_DIR = '@@@EVAL_DIR'
BASENAME = '@@@BASENAME'

In [None]:
def create_contour_overlay(contour_image, color):
    """Build an RGBA overlay image from a contour image.

    Every pixel of the result carries the RGB value given by ``color``; the
    alpha channel is ``contour_image * 255``.

    Args:
        contour_image: 2-D array; presumably 0/1 (or boolean) so that the
            alpha channel ends up 0 or 255 - TODO confirm against the
            functions that produce the contour images.
        color: sequence whose first three items are the R, G and B values.

    Returns:
        uint8 array of shape ``contour_image.shape + (4,)``.
    """
    rgba = np.zeros((*contour_image.shape, 4), dtype=np.uint8)
    # Fill the three color planes with the constant R, G, B values.
    for channel in range(3):
        rgba[:, :, channel] = color[channel]
    # Alpha plane: transparent where the contour image is 0.
    rgba[:, :, 3] = contour_image * 255
    return rgba

In [None]:
# Read the summary image of this dataset and show it on its own.
summary_image = tiff.imread('%s/%s.tif' % (SUMMARY_IMAGE_DIR, BASENAME))
plt.imshow(summary_image, cmap='gray', interpolation='bilinear')
plt.title('Summary image')
plt.axis('off')
plt.show()

In [None]:
# Load the demixing result (JSON) for this dataset.
with open(DEMIX_DIR + '/' + BASENAME + '.json') as f:
    demix_result = json.load(f)

# Convert the result into a stack of mask images plus one contour image,
# both sized to match the summary image.
eval_mask_images, eval_contour_image = roi2masks_from_demix_result(demix_result, summary_image.shape)

# tifffile renamed `imsave` to `imwrite`, and newer releases no longer provide
# `imsave`. Prefer `imwrite` and fall back to `imsave` only on old installs.
write_tiff = getattr(tiff, 'imwrite', None) or tiff.imsave
write_tiff(EVAL_DIR + '/' + BASENAME + '.tif', eval_mask_images, photometric='minisblack')
num_eval_masks = len(eval_mask_images)

# Summary image with the predicted contours drawn in yellow.
plt.axis('off')
plt.imshow(summary_image, interpolation='bilinear', cmap='gray')
plt.imshow(create_contour_overlay(eval_contour_image, [255, 255, 0]), interpolation='bilinear')
plt.title('Predicted masks')
plt.show()

In [None]:
# Load the consensus ground truth for this dataset; the result is unpacked as
# a collection of mask images and one contour image.
gt_mask_images, gt_contour_image = load_roi(CONSENSUS_GT_DIR, BASENAME, summary_image.shape)
num_gt_masks = len(gt_mask_images)

# Summary image with the ground-truth contours drawn in cyan.
consensus_overlay = create_contour_overlay(gt_contour_image, [0, 255, 255])
plt.axis('off')
plt.imshow(summary_image, cmap='gray', interpolation='bilinear')
plt.imshow(consensus_overlay, interpolation='bilinear')
plt.title('GT masks')
plt.show()

In [None]:
# IoU thresholds 0.00, 0.01, ..., 0.99.
thresholds = np.arange(100) / 100

# Match predictions against the consensus ground truth at every threshold.
# `counts` is later tabulated with the columns TruePos / FalsePos / FalseNeg
# and `IoU` with one row per prediction and one column per GT mask.
counts, IoU = count_matches(eval_mask_images, gt_mask_images, thresholds)
f1, precision, recall = calc_f1_scores(counts)

In [None]:
# Three-panel summary figure: contour comparison, IoU matrix, and score curves.
plt.figure(figsize=(17, 5))
plt.suptitle('Accuracy statistics', fontsize=16)

# Panel 1: summary image with GT contours (cyan) and predicted contours
# (yellow) drawn on top, plus the index of every mask as a text label.
plt.subplot(1, 3, 1)
plt.axis('off')
plt.imshow(summary_image, interpolation='bilinear', cmap='gray')
plt.imshow(create_contour_overlay(gt_contour_image, [0, 255, 255]), interpolation='bilinear')
plt.imshow(create_contour_overlay(eval_contour_image, [255, 255, 0]), interpolation='bilinear')
# center_of_mass returns coordinates in array-axis order, so the text is
# placed at x = p[1], y = p[0].
# NOTE(review): assumes each mask image is 2-D (row, column) - confirm.
for i in range(num_eval_masks):
    p = center_of_mass(eval_mask_images[i])
    plt.text(p[1], p[0], str(i), color='yellow')
for i in range(num_gt_masks):
    p = center_of_mass(gt_mask_images[i])
    plt.text(p[1], p[0], str(i), color='cyan')
plt.title('Predicted (yellow) vs GT (cyan)')

# Panel 2: IoU matrix as a heat map (x axis: ground truth, y axis: prediction).
ax = plt.subplot(1, 3, 2)
# Minor ticks at half-integer positions carry the grid lines, so the grid
# falls between matrix cells; major ticks label the integer mask indices.
ax.set_xticks([x - 0.5 for x in range(num_gt_masks)], minor=True)
ax.set_yticks([y - 0.5 for y in range(num_eval_masks)], minor=True)
ax.tick_params(axis='both', which='both', length=0)
plt.grid(which='minor')
plt.xticks(list(range(num_gt_masks)))
plt.yticks(list(range(num_eval_masks)))
# Fixed color range 0..1 so the colors are comparable between datasets.
plt.imshow(IoU, aspect='equal', vmin=0, vmax=1, cmap='inferno')
plt.colorbar()
plt.xlabel('Ground Truth')
plt.ylabel('Prediction')
plt.title('IoU Matrix')

# Panel 3: precision, recall and F1 as functions of the IoU threshold.
plt.subplot(1, 3, 3)
plt.axis('square')
plt.xlim(0, 1)
plt.ylim(0, 1)
plt.plot(thresholds, precision, label='Precision')
plt.plot(thresholds, recall, label='Recall')
plt.plot(thresholds, f1, label='F1 score')
plt.legend(loc='upper right')
plt.ylabel('Score')
plt.xlabel('IoU Threshold')
# Dashed marker at the representative IoU (drawn after the legend call).
plt.vlines(REPRESENTATIVE_IOU, 0, 1, colors='gray', linestyles='dashed') 
# The title reports F1 at the first threshold >= REPRESENTATIVE_IOU.
# NOTE(review): indices[0][0] raises IndexError when no threshold reaches
# REPRESENTATIVE_IOU, and '%.1f' shows the IoU with one decimal only.
indices = np.where(thresholds >= REPRESENTATIVE_IOU)
plt.title('F1 = %.2f at IoU = %.1f' % (f1[indices[0][0]], REPRESENTATIVE_IOU))

plt.show()

In [None]:
# Tabulate the IoU matrix: one row per predicted mask, one column per GT mask,
# with zero-padded (at least two digits) indices in the labels.
row_labels = ['Prediction_{:02d}'.format(n) for n in range(num_eval_masks)]
column_labels = ['GT_{:02d}'.format(n) for n in range(num_gt_masks)]
df = pd.DataFrame(IoU, index=row_labels, columns=column_labels)
df.to_csv('%s/%s_IoU.csv' % (EVAL_DIR, BASENAME))

# Lift the row/column display limits so the whole table is rendered.
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    display(df)

In [None]:
# Per-threshold statistics against the consensus ground truth:
# IoU_Thresh | TruePos | FalsePos | FalseNeg | Precision | Recall | F1
df = pd.DataFrame(counts, columns=['TruePos', 'FalsePos', 'FalseNeg']).assign(
    Precision=precision,
    Recall=recall,
    F1=f1,
)
# The threshold goes in as the first column.
df.insert(0, 'IoU_Thresh', thresholds)
df.to_csv('%s/%s_stats.csv' % (EVAL_DIR, BASENAME), index=False)

# Lift the row/column display limits so the whole table is rendered.
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    display(df)

In [None]:
# Evaluate the predictions against each individual annotator's ground truth.
#
# Everything computed inside the loop uses `indiv_`-prefixed names. The
# earlier cells keep their consensus results in gt_mask_images,
# gt_contour_image, counts, f1, precision, recall and df; reusing those names
# here would overwrite them, and re-running an earlier cell afterwards would
# then mix the last annotator's data with the consensus mask count.
#
# scores maps gt_id -> (f1, precision, recall), each indexed like `thresholds`.
scores = {}
for gt_id in GT_IDS:
    indiv_gt_masks, indiv_gt_contour = load_roi(INDIVIDUAL_GT_DIR + '/' + gt_id, BASENAME, summary_image.shape)

    # Summary image with this annotator's contours (cyan) and the predicted
    # contours (yellow) drawn on top.
    plt.axis('off')
    plt.imshow(summary_image, interpolation='bilinear', cmap='gray')
    plt.imshow(create_contour_overlay(indiv_gt_contour, [0, 255, 255]), interpolation='bilinear')
    plt.imshow(create_contour_overlay(eval_contour_image, [255, 255, 0]), interpolation='bilinear')
    plt.title('Predicted (yellow) vs GT (cyan) by ' + gt_id)
    plt.show()

    # Only the match counts are needed here; the IoU matrix is discarded.
    indiv_counts, _ = count_matches(eval_mask_images, indiv_gt_masks, thresholds)
    indiv_f1, indiv_precision, indiv_recall = calc_f1_scores(indiv_counts)
    scores[gt_id] = (indiv_f1, indiv_precision, indiv_recall)

    # Per-threshold statistics for this annotator, saved next to the other
    # evaluation outputs.
    indiv_df = pd.DataFrame(indiv_counts, columns=['TruePos', 'FalsePos', 'FalseNeg'])
    indiv_df.insert(0, 'IoU_Thresh', thresholds)
    indiv_df['Precision'] = indiv_precision
    indiv_df['Recall'] = indiv_recall
    indiv_df['F1'] = indiv_f1
    indiv_df.to_csv(EVAL_DIR + '/' + BASENAME + '_' + gt_id + '.csv', index=False)

In [None]:
def plot_scores(thresholds, scores, idx, title):
    """Draw one score curve per annotator on the current axes.

    Args:
        thresholds: x values (IoU thresholds) shared by all curves.
        scores: mapping from annotator id to a sequence of score series;
            it must contain an entry for every id in the global GT_IDS.
        idx: position of the series to plot within each scores entry.
        title: text used both as the axes title and as the y-axis label.

    A dashed gray vertical line is drawn at the global REPRESENTATIVE_IOU.
    """
    plt.axis('square')
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    for annotator in GT_IDS:
        curve = scores[annotator][idx]
        plt.plot(thresholds, curve, label=annotator)
    # Build the legend before adding the marker line, as the line is not a curve.
    plt.legend(loc='upper right')
    plt.vlines(REPRESENTATIVE_IOU, 0, 1, colors='gray', linestyles='dashed')
    plt.xlabel('IoU Threshold')
    plt.ylabel(title)
    plt.title(title)

# One panel per metric, each comparing the score curves of all annotators.
# The panel position doubles as the index into each scores[gt_id] tuple,
# which is stored as (f1, precision, recall).
plt.figure(figsize=(17, 5))
plt.suptitle('Accuracy Variance', fontsize=16)

for panel, metric_title in enumerate(['F1 Score', 'Precision', 'Recall']):
    plt.subplot(1, 3, panel + 1)
    plot_scores(thresholds, scores, panel, metric_title)

plt.show()