In [None]:
import re

import itertools
import numpy as np
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

from pathlib import Path
from glob import glob
from matplotlib import ticker
from multiprocessing import Pool
from tqdm.auto import tqdm

tqdm.pandas()

from utils import points

In [None]:
# Prediction folders of every run, keyed by model name. glob() returns paths in
# an order that depends on the file system, so each list is sorted to make the
# concatenation order of the loaded CSVs reproducible across machines.
runs = {
    'S-UNet': sorted(glob('outputs/experiment=unet_*/test_predictions/')),
    'FRCNN': sorted(glob('tmp/detection-*/test_predictions/')),
    'D-UNet': sorted(glob('tmp/density-*/test_predictions/')),
}

In [None]:
def collect(model_name, run, csv_file):
    """Load one per-run CSV and tag its rows with the model name and patch size.

    Parameters
    ----------
    model_name : str
        Value written to the new ``model`` column.
    run : str or os.PathLike
        Run directory. Its path must contain the patch size either as
        ``unet_<N>`` or as ``patch-size-<N>``.
    csv_file : str
        Name of the file inside ``run`` to read (e.g. ``'all_metrics.csv'``).

    Returns
    -------
    pd.DataFrame
        The CSV contents (first column used as index) with ``model`` and
        ``patch_size`` columns added, or an empty DataFrame if the file does
        not exist (a message is printed in that case).

    Raises
    ------
    ValueError
        If no patch size can be parsed from ``run``.
    """
    # Search the textual form so both str and Path runs are accepted.
    run_str = str(run)
    match = re.search(r'(?<=unet_)\d+', run_str) or re.search(r'(?<=patch-size-)\d+', run_str)
    if match is None:
        raise ValueError(f'Cannot infer patch size from run path: {run_str}')
    patch_size = int(match.group())

    csv_path = Path(run) / csv_file
    if not csv_path.exists():
        # Missing files are skipped so one incomplete run does not stop the load.
        print(f'Skipping not found: {csv_path}')
        return pd.DataFrame()

    data = pd.read_csv(csv_path, index_col=0)
    data['model'] = model_name
    data['patch_size'] = patch_size
    return data

# Load the same two CSVs for every (model, run) pair: the per-threshold metric
# table first, then the matched GT/prediction table.
metrics, predictions = (
    pd.concat(
        [collect(model_name, run, csv_name)
         for model_name, model_runs in runs.items()
         for run in model_runs],
        ignore_index=True,
    )
    for csv_name in ('all_metrics.csv', 'all_gt_preds.csv')
)

## Counting Metrics vs Threshold

In [None]:
sns.set_theme(context='notebook', style='ticks', font_scale=1.5)

# Keep thresholds in [0, 1] (inclusive) and reshape to long format so that
# every `count/game*` column becomes one line (hue) in each panel.
data = metrics[metrics.thr.between(0, 1)]
id_vars = ['model', 'patch_size', 'imgName', 'thr']
selected_metrics = [column for column in data.columns if 'count/game' in column]
data = data.melt(id_vars=id_vars, value_vars=selected_metrics, var_name='metric')

# Rows: patch size, columns: model; ci='sd' draws a standard-deviation band.
g = sns.relplot(
    data=data, kind='line',
    x='thr', y='value', hue='metric', ci='sd',  # units='imgName', estimator=None,
    row='patch_size', col='model',
    facet_kws=dict(margin_titles=True),
)
g.set(xlim=(0, 1), ylim=(0, 400))

for ax in g.axes.flat:
    ax.grid(True, which='major')
    ax.grid(True, which='minor', ls='dotted')
    ax.xaxis.set_minor_locator(ticker.AutoMinorLocator(2))
    ax.yaxis.set_minor_locator(ticker.AutoMinorLocator(2))

## F1-Score vs Threshold

In [None]:
sns.set_theme(context='notebook', style='ticks', font_scale=1.5)

# Detection F1 against the threshold: one panel per patch size, one line per
# model, with a standard-deviation band (ci='sd').
data = metrics[metrics.thr.between(0, 1)]

g = sns.relplot(
    data=data, kind='line',
    x='thr', y='pdet/f1_score', hue='model', ci='sd',  # units='imgName', estimator=None,
    col='patch_size',
    facet_kws=dict(margin_titles=True),
)
g.set(xlim=(0, 1), ylim=(0, 1))

for ax in g.axes.flat:
    ax.grid(True, which='major')
    ax.grid(True, which='minor', ls='dotted')
    ax.xaxis.set_minor_locator(ticker.AutoMinorLocator(2))
    ax.yaxis.set_minor_locator(ticker.AutoMinorLocator(2))

## Mean Average Precision Table & PR Curves

In [None]:
def ap(data):
    """Average precision of one PR curve given as rows of a metrics table.

    Rows are ranked by ``pdet/recall`` in decreasing order, and each point's
    precision is weighted by the recall drop to the next point:
    ``sum_i (R_i - R_{i+1}) * P_i``. This is the step-wise sum form used by
    sklearn's ``average_precision_score``.

    NOTE(review): sklearn's precision_recall_curve also appends a recall=0 end
    point. That final segment is not counted here; confirm this is intended.

    Returns a one-element ``pd.Series`` indexed by ``['ap']``, so that
    ``groupby(...).apply(ap)`` yields a DataFrame with an ``ap`` column.
    """
    ranked = data.sort_values('pdet/recall', ascending=False)
    recall_precision = ranked[['pdet/recall', 'pdet/precision']].to_numpy()
    recall, precision = recall_precision[:, 0], recall_precision[:, 1]
    # Recall is non-increasing, so every diff is <= 0; negate to get a positive area.
    area = -np.sum(np.diff(recall) * precision[:-1])
    return pd.Series({'ap': area})

# Per-image AP for every (patch_size, model) pair. Only the two columns `ap`
# reads are handed to it, so the grouping columns are not passed to the applied
# function (deprecated pandas behaviour since 2.2).
aps = metrics.groupby(['patch_size', 'model', 'imgName'])[['pdet/recall', 'pdet/precision']].apply(ap)

# Formatters for "mean ± std" cells of the AP table (values shown as percentages).
# Raw f-strings keep the LaTeX backslashes literal: in plain strings `\c` and `\p`
# are invalid escape sequences (a SyntaxWarning since Python 3.12).
def latex_fmt(x):
    r"""Format a series as ``mean\cf{std}`` percentages for the LaTeX table."""
    return rf'{x.mean():.1%}\cf{{{x.std():.1%}}}'


def notebook_fmt(x):
    r"""Format a series as ``mean$\pm$std`` percentages for display in the notebook."""
    return rf'{x.mean():.1%}$\pm${x.std():.1%}'
# AP table: one row per model, one column per patch size, cells "mean ± std" over images.
aps['ap'].groupby(level=['patch_size', 'model']).aggregate(notebook_fmt).unstack('patch_size')

In [None]:
# PR curves: switch to a white-grid theme at the default font scale.
sns.set_theme(style='whitegrid', context='notebook', font_scale=1)

def plot_pr(data, label, color):
    """Plot the threshold-averaged precision-recall curve of one model.

    Intended for ``FacetGrid.map_dataframe``, which calls it with the rows of
    one facet/hue level together with ``label`` and ``color`` keywords.

    Parameters
    ----------
    data : pd.DataFrame
        Metric rows with at least ``thr``, ``imgName``, ``pdet/recall`` and
        ``pdet/precision`` columns. Other columns, including non-numeric ones
        such as ``model``, are ignored.
    label : str
        Legend label; the mean per-image AP is appended to it.
    color :
        Line colour forwarded to ``plt.plot``.
    """
    pr_cols = ['pdet/recall', 'pdet/precision']

    # Average precision/recall over images at each threshold. The mean is
    # restricted to the PR columns because pandas >= 2.0 raises on string
    # columns in groupby().mean() instead of silently dropping them.
    mean_pr = data.groupby('thr')[pr_cols].mean().sort_values('pdet/recall', ascending=False)
    mean_recalls = mean_pr['pdet/recall'].to_numpy()
    mean_precisions = mean_pr['pdet/precision'].to_numpy()

    # Per-image AP with the step-wise sum sum_i (R_i - R_{i+1}) * P_i over
    # points ranked by decreasing recall (same form as sklearn's AP).
    # NOTE(review): no recall=0 end point is appended, unlike sklearn's PR curve.
    image_aps = []
    for _, img_group in data.groupby('imgName'):
        img_group = img_group.sort_values('pdet/recall', ascending=False)
        recalls = img_group['pdet/recall'].to_numpy()
        precisions = img_group['pdet/precision'].to_numpy()
        image_aps.append(-np.sum(np.diff(recalls) * precisions[:-1]))

    mean_ap = np.mean(image_aps)
    plt.plot(mean_recalls, mean_precisions, label=f'{label} (mAP={mean_ap:.1%})', color=color)
    

# One panel per patch size; `plot_pr` draws one mean PR curve per model (hue),
# receiving each subset together with its `label` and `color`.
data = metrics
grid = sns.FacetGrid(data=data, col='patch_size', hue='model')
grid.map_dataframe(plot_pr).add_legend()

## TPR vs FDR per Agreement

In [None]:
# Classify every row of the matched table: `X` set on the ground-truth side,
# `Xp` set on the prediction side.
#   both present       -> true positive
#   only Xp present    -> false positive
#   only X present     -> false negative
inA = predictions.X.notna()
inB = predictions.Xp.notna()
predictions['tp'] = inA & inB
predictions['fp'] = ~inA & inB
predictions['fn'] = inA & ~inB
# Agreement levels become string labels ('1', ..., '7'; missing values become
# 'nan') so they can serve as categorical keys below. Values that are already
# strings are left untouched, which makes re-running this cell safe.
predictions['agreement'] = predictions.agreement.map(
    lambda value: value if isinstance(value, str) else '{:g}'.format(value)
)

In [None]:
def tpr_fdr_per_agreement(all_gp):
    """Summarise detection quality of one prediction group, split by agreement label.

    Parameters
    ----------
    all_gp : pd.DataFrame
        Rows of the GT/prediction matching table for one group (in this
        notebook, one (patch_size, model, thr) combination). Must contain
        boolean ``tp``/``fp``/``fn`` columns and a string ``agreement`` column
        with labels such as '1'..'7' or 'nan'.

    Returns
    -------
    pd.Series
        Indexed by a (metric, agreement) MultiIndex, with all 'tpr' entries
        first and the 'fdr' entries after them:
        - ('tpr', <label>) for every agreement label present except 'nan',
          plus ('tpr', '>=4') and ('tpr', 'All');
        - ('fdr', 'nan'): fp of the 'nan' label / (that fp + all tp of the
          group). Per-label FDR entries for '1'..'7' are dropped.
    """

    # TPR per agreement level: summing the boolean columns gives tp/fp/fn counts per label.
    by_agree = all_gp.pivot_table(index='agreement', values=['tp','fp','fn'], aggfunc='sum')
    # NaN for labels with no tp/fn rows (0/0).
    by_agree['tpr'] = by_agree.tp / (by_agree.tp + by_agree.fn)
    # The denominator uses the group-wide TP total, so the kept ('fdr', 'nan')
    # entry is computed against every true positive of the group.
    by_agree['fdr'] = by_agree.fp / (by_agree.fp + by_agree.tp.sum())
    
    # Guarantee a 'nan' row so ('fdr', 'nan') always exists; groups without
    # such rows get FDR 0.
    if 'nan' not in by_agree.index:
        by_agree.loc['nan', ['fn', 'fp', 'tp', 'tpr', 'fdr']] = [0, 0, 0, np.nan, 0]

    # remove unused cols; unstack() then yields a Series indexed by (column, agreement)
    by_agree = by_agree.drop(columns=['tp', 'fp', 'fn']).unstack()
    # TPR is not reported for the 'nan' label ...
    by_agree = by_agree.drop(('tpr','nan'), errors='ignore')
    # ... and per-label FDR is not reported for labels '1'..'7'.
    by_agree = by_agree.drop([('fdr', str(i)) for i in range(1,8)], errors='ignore')
    
    # TPR for agreement >= 4 (labels '4'..'7').
    # NOTE(review): 'nan' is also selected; it only changes the result if
    # 'nan'-labelled rows carry tp/fn flags — confirm whether that can happen.
    selector = all_gp.agreement.isin(('4','5','6','7', 'nan'))
    tp, fn = all_gp.loc[selector, ['tp', 'fn']].sum()
    tpr_gt_4 = tp / (tp + fn)
    by_agree[('tpr', '>=4')] = tpr_gt_4

    # TPR for all
    tp, fn = all_gp[['tp', 'fn']].sum()
    tpr_all = tp / (tp + fn)
    by_agree[('tpr', 'All')] = tpr_all

    # Order the entries: every 'tpr' entry first, then 'fdr'.
    by_agree = by_agree.reindex(('tpr', 'fdr'), level=0)    
    return by_agree

# One row per (patch_size, model, thr); tqdm's progress_apply shows a progress bar.
pseudo_roc = (
    predictions
    .groupby(['patch_size', 'model', 'thr'])
    .progress_apply(tpr_fdr_per_agreement)
)
# Collapse the (metric, agreement) column MultiIndex into flat names: TPR
# columns are named after their agreement label, the FDR column is 'fdr'.
pseudo_roc.columns = [
    level if metric == 'tpr' else metric
    for metric, level in pseudo_roc.columns.values
]
pseudo_roc

In [None]:
sns.set_theme(context='notebook', style='whitegrid', font_scale=1)

# Long format: one row per (patch_size, model, thr, agreement column) holding
# that column's TPR, next to the FDR of the same threshold.
data = (
    pseudo_roc
    .reset_index()
    .melt(id_vars=['patch_size', 'model', 'thr', 'fdr'], var_name='agreement', value_name='tpr')
    .sort_values(['patch_size', 'model', 'thr', 'fdr'])
)
# Optional: hide S-UNet rows whose threshold lies outside [0, 0.9).
# data = data[ ~((data.model == 'S-UNet') & ((data.thr < 0) | (data.thr >= 0.9))) ]

grid = sns.FacetGrid(data=data, row='agreement', col='patch_size', hue='model',
                     sharex=True, legend_out=True, margin_titles=True)
grid.map(plt.plot, 'fdr', 'tpr').add_legend()
# Optional: square, unit-range axes.
# for ax in grid.axes.flatten():
#     ax.axis('equal')
# grid.set(ylim=(0, 1), xlim=(0, 1))

## Score vs Agreement

In [None]:
# For every (model, patch_size, image), pick the threshold minimising
# `count/game-3`. idxmin on the `thr`-indexed frame returns the threshold itself.
best_thrs = (
    metrics
    .set_index('thr')
    .groupby(['model', 'patch_size', 'imgName'])['count/game-3']
    .idxmin()
)
# (model, patch_size, imgName, thr) key tuples.
best_thrs = list(best_thrs.reset_index().itertuples(index=False, name=None))
# Keep only the matched GT/prediction rows at those thresholds.
data = (
    predictions
    .set_index(['model', 'patch_size', 'imgName', 'thr'])
    .loc[best_thrs]
    .reset_index()
)

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))

# Score distribution per model, split by agreement label: 'nan' first, then '1'..'7'.
order = ['nan', *map(str, range(1, 8))]
sns.boxenplot(data=data, x='score', y='model', hue='agreement', hue_order=order,
              orient='h', palette='flare', ax=ax)
ax.legend(loc='upper right', ncol=1)