In [None]:
import re

import itertools
import numpy as np
import pandas as pd
import random

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

from pathlib import Path
from matplotlib import ticker
import matplotlib.patheffects as pe
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from multiprocessing import Pool
from tqdm.auto import tqdm
from omegaconf import OmegaConf

import scipy
from scipy.spatial.distance import pdist, squareform

from skimage import exposure, transform
from skimage.draw import line_aa
from skimage.color import gray2rgb, rgb2gray
from skimage.io import imread, imsave
from skimage.transform import resize
from skimage.util import montage

from sklearn.metrics import jaccard_score
from sklearn.preprocessing import StandardScaler

from datasets import PerineuronalNetsDataset, CellsDataset, PerineuronalNetsRankDataset

from methods.points.match import match
from methods.points.metrics import detection_and_counting
from methods.points.utils import draw_points, draw_groundtruth_and_predictions

# Register tqdm's `progress_apply` on pandas objects (used by
# `compute_metrics_by_agreement` when called with parallel=False).
tqdm.pandas()

# VGG Dataset

In [None]:
common = {'root': 'data/vgg-cells'}

# Per-target configuration passed to the dataset constructor: detection boxes,
# density maps and (weighted) segmentation maps.
detection_kws = dict(
    target_='detection',
    target_params=dict(side=12),
)
density_kws = dict(
    target_='density',
    target_params=dict(mode='reflect', k_size=51, sigma=5),
)
segmentation_kws = dict(
    target_='segmentation',
    target_params=dict(
        radius=5, radius_ignore=6,
        v_bal=0.1, sigma_bal=3,
        sep_width=1, sigma_sep=3, lambda_sep=50,
    ),
)

# One dataset view per target type over the same image folder.
detection_dataset, density_dataset, segmentation_dataset = (
    CellsDataset(**kws, **common)
    for kws in (detection_kws, density_kws, segmentation_kws)
)

In [None]:
# Show one VGG image next to its three training targets (boxes, density map,
# segmentation map) and save each panel under figures/.
sample_idx = 0

# As used here: density/segmentation samples stack the image in channel 0 and the
# target map(s) in the following channels; element 1 of the detection sample holds the boxes.
sample = density_dataset[sample_idx][0][:, :, 0]
detections = detection_dataset[sample_idx][0][1]
density_map = density_dataset[sample_idx][0][:, :, 1]
segmentation_map = segmentation_dataset[sample_idx][0][:, :, 1]
# NOTE(review): `weights_map` is not used in this cell.
weights_map = segmentation_dataset[sample_idx][0][:, :, 2]

# Integer pixel coordinates inside the image; both axes are clipped to shape[0],
# which assumes a square image.
detections = np.clip(detections.astype(int), 0, sample.shape[0] - 1)

# Colour-map / normalise everything to float RGB(A) in [0, 1] for display and saving.
sample = matplotlib.cm.jet(sample)
density_map = gray2rgb(density_map / density_map.max())
segmentation_map = gray2rgb(segmentation_map)

# Draw each box as four anti-aliased segments written into the red channel.
sample_with_boxes = sample.copy()
for y0, x0, y1, x1 in detections:
    rect = ((y0, x0, y0, x1),
            (y0, x1, y1, x1),
            (y1, x1, y1, x0),
            (y1, x0, y0, x0))
    for r0, c0, r1, c1 in rect:
        rr, cc, val = line_aa(r0, c0, r1, c1)
        sample_with_boxes[rr, cc, 0] = val

        
fig, axes = plt.subplots(1, 4, figsize=(20, 20))
axes = axes.flatten()
axes[0].imshow(sample)
axes[1].imshow(sample_with_boxes)
axes[2].imshow(density_map)
axes[3].imshow(segmentation_map)

for ax in axes:
    ax.set_axis_off()

plt.imsave('figures/vgg-sample.png', sample)
plt.imsave('figures/vgg-boxes.png', sample_with_boxes)
plt.imsave('figures/vgg-density.png', density_map)
plt.imsave('figures/vgg-segmentation.png', segmentation_map)

In [None]:
# Total number of annotated cells and the per-image mean ± std, formatted for LaTeX.
# Element 1 of a detection sample holds the boxes, one per cell.
ncells = [len(detection_dataset[i][0][1]) for i in range(len(detection_dataset))]
# Raw f-string: `\pm` is LaTeX, not a Python escape (a plain string triggers an
# invalid-escape warning).
rf'{sum(ncells)} cells  {np.mean(ncells):.2f}$\pm${np.std(ncells):.2f} cells/image'

In [None]:
def collect_runs(model_name, run, csv_file):
    """Load one run's test-prediction CSV and tag it with run metadata.

    Args:
        model_name: label stored in the ``model`` column.
        run: run directory (str or Path) containing ``.hydra/config.yaml``.
        csv_file: file name inside ``<run>/test_predictions``.

    Returns:
        The CSV as a DataFrame with extra ``model``, ``num_samples`` and ``split_seed``
        columns, or an empty DataFrame (after printing a notice) if the CSV is missing.
    """
    run_dir = Path(run)
    val_cfg = OmegaConf.load(run_dir / '.hydra' / 'config.yaml')['data']['validation']
    # `num_samples` is stored as a sequence in the config; only its first entry is used.
    num_samples = val_cfg['num_samples'][0]
    split_seed = val_cfg['split_seed']

    csv_path = run_dir / 'test_predictions' / csv_file
    if not csv_path.exists():
        print(f'Skipping not found: {csv_path}')
        return pd.DataFrame()

    frame = pd.read_csv(csv_path, index_col=0)
    frame['model'] = model_name
    frame['num_samples'] = num_samples
    frame['split_seed'] = split_seed
    return frame

In [None]:
vgg_run_root = Path('runs/experiment=vgg-cells')
runs = {
    'S-UNet': (vgg_run_root / 'segmentation').glob('unet_*'),
    'FRCNN': (vgg_run_root / 'detection').glob('fasterrcnn_*'),
    'D-CSRNet': (vgg_run_root / 'density').glob('csrnet_*'),
}

# One frame per (model, run directory), stacked into a single long table.
metrics = pd.concat(
    [collect_runs(model, run_dir, 'all_metrics.csv.gz')
     for model, run_dirs in runs.items()
     for run_dir in run_dirs],
    ignore_index=True,
)

In [None]:
# For every (model, num_samples, split seed) keep the threshold with the lowest mean count
# MAE, then summarise those best MAEs over split seeds as "mean ± std" (LaTeX).
mean_metrics = metrics.groupby(['model', 'num_samples', 'split_seed', 'thr'])['count/mae'].mean()
best_configs = mean_metrics.groupby(['model', 'num_samples', 'split_seed']).idxmin()
table = mean_metrics.loc[best_configs].groupby(['model', 'num_samples']).apply(lambda x: pd.Series({'mean': x.mean(), 'std': x.std()}))
# Raw f-string so the LaTeX `\pm` is not parsed as an (invalid) Python escape.
table = table.unstack(2).apply(lambda x: rf'{x["mean"]:.1f} $\pm$ {x["std"]:.1f}', axis=1).unstack(1)
print(table.to_latex(escape=False))
table

# MBM Dataset

In [None]:
common = {'root': 'data/mbm-cells'}

# Per-target configuration passed to the dataset constructor: detection boxes,
# density maps and (weighted) segmentation maps.
detection_kws = dict(
    target_='detection',
    target_params=dict(side=20),
)
density_kws = dict(
    target_='density',
    target_params=dict(mode='reflect', k_size=51, sigma=10),
)
segmentation_kws = dict(
    target_='segmentation',
    target_params=dict(
        radius=12, radius_ignore=15,
        v_bal=0.1, sigma_bal=5,
        sep_width=1, sigma_sep=4, lambda_sep=50,
    ),
)

# One dataset view per target type over the same image folder.
detection_dataset, density_dataset, segmentation_dataset = (
    CellsDataset(**kws, **common)
    for kws in (detection_kws, density_kws, segmentation_kws)
)

In [None]:
# Show one MBM image next to its three training targets (boxes, density map,
# segmentation map) and save each panel under figures/.
sample_idx = 5

# NOTE(review): the displayed image is read from a fixed file instead of
# `density_dataset[sample_idx]` (commented out below); confirm it is the image of
# sample `sample_idx`, otherwise boxes/maps do not correspond to it.
# sample = density_dataset[sample_idx][0][:, :, 0]
sample = imread('data/mbm-cells/BM_GRAZ_HE_0001_02_001_cell.png')
# Work in float [0, 1]: the anti-aliased line intensities drawn below are in [0, 1]
# and would be truncated to 0/1 if written into an integer (e.g. uint8) image.
if np.issubdtype(sample.dtype, np.integer):
    sample = sample / np.iinfo(sample.dtype).max
detections = detection_dataset[sample_idx][0][1]
density_map = density_dataset[sample_idx][0][:, :, 1]
segmentation_map = segmentation_dataset[sample_idx][0][:, :, 1]
weights_map = segmentation_dataset[sample_idx][0][:, :, 2]

# Integer pixel coordinates inside the image; both axes are clipped to shape[0],
# which assumes a square image.
detections = np.clip(detections.astype(int), 0, sample.shape[0] - 1)

density_map = gray2rgb(density_map / density_map.max())
segmentation_map = gray2rgb(segmentation_map)

# Draw each box as four anti-aliased segments written into the red channel.
sample_with_boxes = sample.copy()
for y0, x0, y1, x1 in detections:
    rect = ((y0, x0, y0, x1),
            (y0, x1, y1, x1),
            (y1, x1, y1, x0),
            (y1, x0, y0, x0))
    for r0, c0, r1, c1 in rect:
        rr, cc, val = line_aa(r0, c0, r1, c1)
        sample_with_boxes[rr, cc, 0] = val


fig, axes = plt.subplots(1, 4, figsize=(20, 20))
axes = axes.flatten()
axes[0].imshow(sample)
axes[1].imshow(sample_with_boxes)
axes[2].imshow(density_map)
axes[3].imshow(segmentation_map)

for ax in axes:
    ax.set_axis_off()

plt.subplots_adjust(wspace=0)

plt.imsave('figures/mbm-sample.png', sample)
plt.imsave('figures/mbm-boxes.png', sample_with_boxes)
plt.imsave('figures/mbm-density.png', density_map)
plt.imsave('figures/mbm-segmentation.png', segmentation_map)

In [None]:
mbm_run_root = Path('runs/experiment=mbm-cells')
runs = {
    'S-UNet': (mbm_run_root / 'segmentation').glob('unet_*'),
    'FRCNN': (mbm_run_root / 'detection').glob('fasterrcnn_*'),
    'D-CSRNet': (mbm_run_root / 'density').glob('csrnet_*'),
}

# One frame per (model, run directory), stacked into a single long table.
metrics = pd.concat(
    [collect_runs(model, run_dir, 'all_metrics.csv.gz')
     for model, run_dirs in runs.items()
     for run_dir in run_dirs],
    ignore_index=True,
)

In [None]:
# For every (model, num_samples, split seed) keep the threshold with the lowest mean count
# MAE, then summarise those best MAEs over split seeds as "mean ± std" (LaTeX).
mean_metrics = metrics.groupby(['model', 'num_samples', 'split_seed', 'thr'])['count/mae'].mean()
best_configs = mean_metrics.groupby(['model', 'num_samples', 'split_seed']).idxmin()
table = mean_metrics.loc[best_configs].groupby(['model', 'num_samples']).apply(lambda x: pd.Series({'mean': x.mean(), 'std': x.std()}))
# Raw f-string so the LaTeX `\pm` is not parsed as an (invalid) Python escape.
table = table.unstack(2).apply(lambda x: rf'{x["mean"]:.1f} $\pm$ {x["std"]:.1f}', axis=1).unstack(1)
print(table.to_latex(escape=False))
table

# PNN Dataset

## Examples

In [None]:
common = {'split': 'train-half1', 'random_offset': 0, 'patch_size': 640}

# Per-target configuration passed to the dataset constructor: detection boxes,
# density maps and (weighted) segmentation maps.
detection_kws = dict(
    target_='detection',
    target_params=dict(side=45),
)
density_kws = dict(
    target_='density',
    target_params=dict(mode='reflect', k_size=151, sigma=15),
)
segmentation_kws = dict(
    target_='segmentation',
    target_params=dict(
        radius=20, radius_ignore=25,
        v_bal=0.1, sigma_bal=10,
        sep_width=1, sigma_sep=6, lambda_sep=50,
    ),
)

# One dataset view per target type over the same patches.
detection_dataset, density_dataset, segmentation_dataset = (
    PerineuronalNetsDataset(**kws, **common)
    for kws in (detection_kws, density_kws, segmentation_kws)
)

In [None]:
# Draw 10k random patch indices and rank them by number of annotated cells (most
# crowded first). Seeded so the example patch picked below, and the figures saved
# from it, are reproducible across kernel restarts.
rng = np.random.default_rng(42)
random_samples = rng.integers(0, len(detection_dataset), 10000)

cells_per_sample = [len(detection_dataset[x][0][1]) for x in tqdm(random_samples)]
best_samples = np.argsort(cells_per_sample)[::-1]

In [None]:
# Same four-panel figure for one crowded PNN patch (the 32nd most crowded of the random draw).
sample_idx = random_samples[best_samples[31]]

# As used here: density/segmentation samples stack the image in channel 0 and the
# target map(s) in the following channels; element 1 of the detection sample holds the boxes.
sample = density_dataset[sample_idx][0][:, :, 0]
detections = detection_dataset[sample_idx][0][1]
density_map = density_dataset[sample_idx][0][:, :, 1]
segmentation_map = segmentation_dataset[sample_idx][0][:, :, 1]
weights_map = segmentation_dataset[sample_idx][0][:, :, 2]


# Integer pixel coordinates inside the patch; both axes are clipped to shape[0]
# (patches are square, patch_size=640).
detections = np.clip(detections.astype(int), 0, sample.shape[0] - 1)
# NOTE(review): `sw_map` (segmentation/weights overlay) is not used afterwards.
sw_map = np.stack((segmentation_map, weights_map, np.zeros_like(sample)), axis=-1)

# Colour-map / normalise everything to float RGB(A) in [0, 1] for display and saving.
sample = matplotlib.cm.viridis(sample)
density_map = gray2rgb(density_map / density_map.max())
segmentation_map = gray2rgb(segmentation_map)

# Draw each box as four anti-aliased segments written into the red channel.
sample_with_boxes = sample.copy()
for y0, x0, y1, x1 in detections:
    rect = ((y0, x0, y0, x1),
            (y0, x1, y1, x1),
            (y1, x1, y1, x0),
            (y1, x0, y0, x0))
    for r0, c0, r1, c1 in rect:
        rr, cc, val = line_aa(r0, c0, r1, c1)
        sample_with_boxes[rr, cc, 0] = val

        
fig, axes = plt.subplots(1, 4, figsize=(20, 20))
axes = axes.flatten()
axes[0].imshow(sample)
axes[1].imshow(sample_with_boxes)
axes[2].imshow(density_map)
axes[3].imshow(segmentation_map)

for ax in axes:
    ax.set_axis_off()

plt.imsave('figures/pnn-sample.png', sample)
plt.imsave('figures/pnn-boxes.png', sample_with_boxes)
plt.imsave('figures/pnn-density.png', density_map)
plt.imsave('figures/pnn-segmentation.png', segmentation_map)

## Groundtruth Properties

In [None]:
gt = pd.read_csv('data/perineuronal-nets/test/annotations.csv')
# Agreement = sum over the rater columns AV..VT (presumably one 0/1 flag per rater).
gt = gt.assign(agreement=gt.loc[:, 'AV':'VT'].sum(axis=1))

# Number of annotated points per test image.
gt.groupby('imgName')['X'].count()

### Distribution of Agreement in the (Multi-Rater) Test Set

In [None]:
# Pie chart: share of test-set annotations at each agreement level.
sns.set_theme(context='notebook', style='ticks')
data = gt.agreement.value_counts().sort_index()
_, _, autotexts = plt.pie(data.values, labels=data.index,
                          autopct='{:.2g}%'.format, pctdistance=0.75,
                          # colors=sns.color_palette('rocket', 7)
                          # 7 colours sampled from the 'rocket' colormap, skipping position 0.
                          colors=sns.color_palette('rocket', as_cmap=True)(np.linspace(0, 1, 8)[1:])
                         )
plt.ylabel('agreement')

plt.setp(autotexts, size=12)
# White percentage labels on the first four wedges (presumably the darker colours).
for t in autotexts[:4]:
    t.set_color('white')
    

In [None]:
# Mean patch per agreement level, plus the patches most similar to each mean. The
# printed lists presumably were used to hand-pick `sample_indices` further below.
pnn_cells = PerineuronalNetsRankDataset(mode='patches')

# Group patch indices by agreement (sorting first: itertools.groupby only groups runs).
x = enumerate(pnn_cells.annot.agreement.values)
x = sorted(x, key=lambda x: x[1])
x = itertools.groupby(x, key=lambda x: x[1])

means = []
for agreement, group in x:
    samples = [i for i, _ in group]
    # assumes pnn_cells[i][0] is an 8-bit patch (hence / 255) — TODO confirm
    images = [pnn_cells[i][0].astype(np.float32) / 255. for i in samples]
    mean = np.mean(images, axis=0)
    means.append(mean)
    
    # Rank patches by their dot product with the mean patch, highest first.
    sorted_samples = np.sum((images * mean), axis=(1,2)).argsort()
    sorted_samples = np.array(samples)[sorted_samples][::-1]
    
    print(f'{agreement}:', sorted_samples[:12].tolist(), ',')

# Min-max normalise all means jointly, then map to viridis RGB (alpha dropped).
pnn_means = np.stack(means)
pnn_means = (pnn_means - np.min(pnn_means)) / (np.max(pnn_means) - np.min(pnn_means))
pnn_means = matplotlib.cm.viridis(pnn_means)[:,:,:,:3]

In [None]:
# All test-set agreement values, sorted and laid out row-major in a 112-row heatmap
# (a compact "bar" of agreement levels). The appended 100 pads the count so it reshapes
# into 112 rows; being > vmax it renders with the top colour — TODO confirm intent.
a = sorted(gt.agreement.values) + [100]
a = np.array(a).reshape((112, -1))

fig, ax = plt.subplots(figsize=(8, 5))
# NOTE: cbar_kws has no effect while cbar=False.
ax = sns.heatmap(a, vmin=0, vmax=7, square=True,
            linewidths=0,
            antialiased=True,
            rasterized=True,
            cbar=False, cbar_kws={"orientation": "horizontal", 'pad': 0.05, 'ticks': range(0, 8), 'drawedges': True})
plt.xticks([])
plt.ylabel('Rater\'s Agreement')

def find_pos(a):
    pos = np.unique(a, return_counts=True)[1].cumsum()
    pos = np.insert(pos, 0, 1)
    pos = (pos[:-1] + pos[1:]) / 2
    return pos

# Row positions of each agreement level's run in the left and right heatmap columns.
l_pos = find_pos(a[:, 0])
r_pos = find_pos(a[:, 1])

# Label the agreement levels 1..7 at the centre of their runs.
_ = plt.yticks(l_pos, range(1, 8), rotation=0)

# Axis extents, used below to convert row positions to axes fractions. The heatmap's y
# axis is inverted, so the first ylim value is the bottom (presumably the row count).
_, x_limit = plt.xlim()
y_limit, _ = plt.ylim()

def shuf(l):
    """Return a shuffled slice-copy of `l`; `l` itself is left untouched."""
    shuffled = l[:]
    random.shuffle(shuffled)
    return shuffled

# Hand-picked patch indices (into `pnn_cells`) shown as examples per agreement level.
sample_indices = {
    1: [1565, 2300, 1913, 311, 1799, 763, 72],
    2: [2309, 386, 983, 56, 286, 951, 1774],
    3: [198, 1874, 392, 872, 78, 390, 1103],
    4: [777, 219, 1944, 1066, 217, 1115, 96],
    5: [220, 2174, 945, 389, 385, 1633, 593],
    6: [2218, 2058, 1436, 2212, 2034, 1433, 207],
    7: [453, 7, 6, 4, 12, 644, 20]
}

# Fraction of test annotations at each agreement level, ordered by level.
pct = gt.agreement.value_counts()
pct = pct / pct.sum()
pct = pct.sort_index().values

cell_x = 1.51  # left edge (axes fraction) of the image columns, right of the heatmap
nr = 1  # montage rows
for i, (ry, ly) in enumerate(zip(r_pos, l_pos), start=1):
    # `i` is the agreement level; image rows go from level 1 (top) to 7 (bottom).
    cell_y = 1 - (i / len(sample_indices))

    # percentage
    ax.annotate(f'{pct[i - 1]:.0%}', (0.5, 1 - (ry+ly) / (2*y_limit)), xycoords='axes fraction',
                color='white' if i < 4 else 'black', ha='center', va='center')
    
    # connector: from the heatmap's right edge at this level's run to its image row
    ax.annotate('', xy=(1, 1 - ry / y_limit), xycoords='axes fraction',
            xytext=(cell_x, cell_y + 0.06), textcoords='axes fraction',
            arrowprops=dict(arrowstyle='-', color='0.2', connectionstyle='arc,angleA=0,angleB=0,armA=-7,armB=7,rad=0'))
    
    # mean image
    imagebox = OffsetImage(pnn_means[i - 1], zoom=0.5, origin='upper')
    ab = AnnotationBbox(imagebox, (cell_x, cell_y), xycoords='axes fraction', frameon=False, box_alignment=(0,0))
    ax.add_artist(ab)
    
    # samples: one-row montage of the hand-picked patches in viridis, with the
    # white padding rows above and below trimmed
    cell_images = [pnn_cells[j][0] for j in sample_indices[i]]
    cell_images = [matplotlib.cm.viridis(c) for c in cell_images]
    cell_images = np.stack(cell_images)[:,:,:,:3]
    image = montage(cell_images, grid_shape=(nr, len(cell_images) / nr), padding_width=5, fill=(1, 1, 1), multichannel=True)
    image = image[5:-5, ...]
    
    imagebox = OffsetImage(image, zoom=0.5, origin='upper')
    ab = AnnotationBbox(imagebox, (cell_x + 0.8, cell_y), xycoords='axes fraction', frameon=False, box_alignment=(0,0))
    ax.add_artist(ab)
    
ax.annotate('mean', (cell_x, 1), xycoords='axes fraction')
ax.annotate('samples', (cell_x + 2, 1), xycoords='axes fraction')
plt.savefig('figures/pnn-mr-breakdown.pdf', bbox_inches='tight')

In [None]:
# One multi-rater test image with its annotations as circles coloured by agreement.
sample_id = gt.imgName.unique()[2]
img = plt.imread('data/perineuronal-nets/test/fullFrames/' + sample_id)
# Contrast stretch between the 0.1th and 99.9th percentiles (despite the p2/p98 names).
p2, p98 = np.percentile(img, (0.1, 99.9))
img = exposure.rescale_intensity(img, in_range=(p2, p98))
img = transform.resize(img, (1024, 1024))
img = matplotlib.cm.viridis(img)[:,:,:3]

# Annotation coordinates are in original pixels; assumes 2000x2000 frames — TODO confirm.
scale_f = img.shape[0] / 2000

plt.figure(figsize=(8, 8))
plt.imshow(img)
ax = plt.gca()
ax.set_axis_off()

# Same 7 'rocket' colours as the agreement pie chart; level k uses colours[k - 1].
colors = sns.color_palette('rocket', as_cmap=True)(np.linspace(0, 1, 8)[1:])
for agreement, group in gt.set_index('imgName').loc[sample_id].groupby('agreement'):
    color = colors[agreement - 1]
    xs, ys = (group[['X', 'Y']].values * scale_f).astype(int).T
    ax.plot(xs, ys, 'o', ms=25*scale_f, mec=color, mfc='none', mew=0.8)

plt.savefig('figures/pnn-mr-sample.pdf', bbox_inches='tight')

In [None]:
# One single-rater training image (downscaled 10x) with its annotations as red circles.
# NOTE(review): the train CSV is indexed by 'imageName' while the test CSV uses 'imgName'.
gt_sr = pd.read_csv('data/perineuronal-nets/train/annotations.csv')

sample_id = '034_B4_s06_C1.tif'
img = imread('data/perineuronal-nets/train/fullFrames/' + sample_id)

new_shape = (np.array(img.shape) / 10).astype(int)
scale_f = new_shape[0] / img.shape[0]

img = transform.resize(img, new_shape)

# Contrast stretch between the 0.1th and 99.9th percentiles (despite the p2/p98 names).
p2, p98 = np.percentile(img, (0.1, 99.9))
img = exposure.rescale_intensity(img, in_range=(p2, p98))

img = matplotlib.cm.viridis(img)[:,:,:3]

plt.figure(figsize=(14, 14))
plt.imshow(img)
ax = plt.gca()
ax.set_axis_off()

# Scale annotation coordinates to the downscaled image.
xy = gt_sr.set_index('imageName').loc[sample_id, ['X', 'Y']].values * scale_f
xs, ys = xy.astype(int).T
ax.plot(xs, ys, 'o', ms=35 * scale_f, mec='red', mfc='none', mew=0.8)

plt.savefig('figures/pnn-sr-sample.pdf', bbox_inches='tight')

### Agreement between Raters in the Test Set

In [None]:
# Rater column names (AV..VT) as a column vector, so `pdist` compares them pairwise.
raters = np.array(gt.loc[0, 'AV':'VT'].index.values).reshape(-1, 1)

def agree(r1, r2):
    """Jaccard index between the annotation columns of two raters in `gt`.

    `pdist` passes each rater as a 1-element row of `raters`, so `gt[r1]` selects a
    single-column frame (columns presumably 0/1 flags — TODO confirm).
    """
    return jaccard_score(gt[r1], gt[r2])

# Pairwise Jaccard agreement between raters, as a lower-triangular heatmap.
raters_agreement = pdist(raters, agree)
raters_agreement = squareform(raters_agreement)

# Mask the diagonal and upper triangle (np.tri with k=-1 is 1 strictly below the diagonal).
mask = 1 - np.tri(len(raters), k=-1)

# Drop the first row and last column, which are fully masked.
raters_agreement = raters_agreement[1:, :-1]
mask = mask[1:, :-1]

# After the crop, rows correspond to raters 2..n and columns to raters 1..n-1.
ylabels = [f'R{i+2}' for i in range(len(raters)-1)]
xlabels = [f'R{i+1}' for i in range(len(raters)-1)]
sns.heatmap(raters_agreement, mask=mask, annot=True, square=True,
            xticklabels=xlabels, yticklabels=ylabels, cmap='viridis',
            cbar_kws=dict(location='left', label='Jaccard Index'))
plt.savefig('figures/raters-agreement.pdf', bbox_inches='tight')

### Total Cells counted by each Rater

In [None]:
# Number of cells marked by each rater (R1..Rn), with the mean over raters as a dashed line.
counts = gt.loc[:, 'AV':'VT'].sum(axis=0)
counts.index = [f'R{i+1}' for i in range(len(counts))]
counts = counts.to_frame(name='count')
counts = counts.reset_index().rename({'index': 'rater'}, axis=1)
ax = sns.barplot(data=counts, x='rater', y='count')

# Average only the numeric `count` column. `counts.mean()` would also try to average
# the string `rater` column, which raises a TypeError in pandas >= 2.
mean = counts['count'].mean()
ax.axhline(mean, c='k', ls='--', lw=1.5)
# Fixed y-range, zoomed on the spread between raters.
ax.set_ylim(1200, 1650)
sns.despine()

# Stage 1: Localization/Counting Models

## Model Evaluation

In [None]:
# Display order of metrics, stage-1 models and stage-2 scorers in tables and figures.
metric_order = (
    'count/mare',
    'count/game-3',
    'pdet/f1_score',
)
model_order = (
    'S-UNet',
    'FRCNN',
    'D-CSRNet',
)
scorer_order = (
    'simple_regression',
    'simple_classification',
    'ordinal_regression',
    'pairwise_balanced',
)

In [None]:
# Stage-1 PNN run directories per model.
# Materialised as sorted lists: `Path.glob` returns a one-shot generator, and these run
# lists are iterated several times below (metrics, predictions, density metrics).
# Sorting also makes the concatenation order deterministic.
runs = {
    'S-UNet': sorted(Path('runs/experiment=perineuronal-nets/segmentation/').glob('unet_*')),
    'FRCNN': sorted(Path('runs/experiment=perineuronal-nets/detection/').glob('fasterrcnn_*')),
    'D-CSRNet': sorted(Path('runs/experiment=perineuronal-nets/density/').glob('csrnet_*')),
}

In [None]:
def collect(model_name, run, csv_file):
    """Load one PNN run's test-prediction CSV and tag it with model name and patch size.

    Args:
        model_name: label stored in the ``model`` column.
        run: run directory (str or Path) containing ``.hydra/config.yaml``.
        csv_file: file name inside ``<run>/test_predictions``.

    Returns:
        The CSV as a DataFrame with extra ``model`` and ``patch_size`` columns, or an
        empty DataFrame (after printing a notice) if the CSV is missing.
    """
    run_dir = Path(run)
    cfg = OmegaConf.load(run_dir / '.hydra' / 'config.yaml')
    patch_size = cfg['data']['validation']['patch_size']

    csv_path = run_dir / 'test_predictions' / csv_file
    if not csv_path.exists():
        print(f'Skipping not found: {csv_path}')
        return pd.DataFrame()

    return pd.read_csv(csv_path, index_col=0).assign(model=model_name, patch_size=patch_size)

# `runs` may hold one-shot `Path.glob` generators: materialise them once so both passes
# below (and the density-metrics cell later) see every run directory. Without this,
# the second pass iterates exhausted generators and `pd.concat([])` fails.
runs = {model: list(run_dirs) for model, run_dirs in runs.items()}

metrics = pd.concat([collect(k, r, 'all_metrics.csv.gz') for k, v in runs.items() for r in v], ignore_index=True)
predictions = pd.concat([collect(k, r, 'all_gt_preds.csv.gz') for k, v in runs.items() for r in v], ignore_index=True)

# Rows without an agreement value (presumably unmatched predictions) count as 0 raters.
predictions['agreement'] = predictions['agreement'].fillna(0)

## What's the best patch size?
Show the trade-off between patch size and detection/counting performance.

In [None]:
sns.set_theme(context='poster', style='ticks')#, font_scale=1.5)

def compare_patch_sizes_plot(data, metric, metric_label, mode, fmt='.3f', ylim=(0,1), legend_bbta=(1,1)):
    """Plot `metric` against threshold: one facet per model, one line per patch size.

    For each model, the best mean value over all (patch size, threshold) pairs is
    marked with an 'X', annotated on the plot and printed.

    Args:
        data: metrics with 'model', 'patch_size', 'thr' and `metric` columns.
        metric: column to plot.
        metric_label: y-axis label.
        mode: 'min' if lower is better; any other value means higher is better.
        fmt: format spec of the annotated best value.
        ylim: y-axis limits.
        legend_bbta: legend `bbox_to_anchor`.

    Returns:
        The seaborn FacetGrid.
    """
    data = data.rename({'patch_size': 'Patch Size'}, axis=1)
    g = sns.relplot(data=data, kind='line', col='model',
                    x='thr', y=metric, hue='Patch Size', ci=None,
                    facet_kws=dict(margin_titles=True, legend_out=True),
                    aspect=1.2, height=4.5)

    # Mean over images per (model, patch size, threshold). numeric_only skips string
    # columns such as `imgName`, which pandas >= 2 refuses to average.
    data = data.groupby(['model', 'Patch Size', 'thr']).mean(numeric_only=True)
    best_points = data.groupby('model')[metric]
    best_points = best_points.idxmin() if mode == 'min' else best_points.idxmax()

    g.set(ylim=ylim, xlim=(0, 1))
    g.set_titles(col_template="{col_name}")
    g.set_axis_labels(x_var='threshold', y_var=metric_label)
    for model, ax in g.axes_dict.items():
        ax.grid(True, which='major')
        ax.grid(True, which='minor', ls='dotted')
        ax.get_xaxis().set_minor_locator(ticker.AutoMinorLocator(2))
        ax.get_yaxis().set_minor_locator(ticker.AutoMinorLocator(2))
        ax.xaxis.set_major_formatter('{x:g}')

        # Best (patch size, threshold) of this model.
        best_point = data.loc[best_points[model]]
        _, patch_size, thr = best_point.name
        value = best_point[metric]
        print(f'[{metric}] {model} ps={patch_size} thr={thr} value={value:.2f}')

        ax.plot([thr], [value], 'X', c='k', ms=9, mec='k', mfc='w')
        # Put the label below the marker for minima and above it for maxima.
        xytext = (5, -3) if mode == 'min' else (5, 3)
        va='top' if mode == 'min' else 'bottom'
        ax.annotate(f'{value:{fmt}}', xy=(thr, value), xytext=xytext,
                    textcoords='offset points', fontsize='small',
                    va=va, ha='left')
    
    sns.move_legend(g, "center right", bbox_to_anchor=legend_bbta, frameon=False,
                    labelspacing=0.25, fontsize='x-small', title_fontsize='x-small')
    return g


data = metrics[metrics.thr.between(0, 1)]

# (metric column, y-axis label, optimisation mode, plot options, output file)
patch_size_plots = [
    ('count/mae', 'MAE', 'min',
     dict(fmt='.2f', ylim=(0, 200), legend_bbta=(.85, .5)), 'figures/pnn-mae.pdf'),
    ('count/mare', 'MARE', 'min',
     dict(fmt='.1%', legend_bbta=(.85, .50)), 'figures/pnn-mare.pdf'),
    ('count/game-3', 'GAME(3)', 'min',
     dict(fmt='.1f', ylim=(45, 200), legend_bbta=(.85, .5)), 'figures/pnn-game3.pdf'),
    ('pdet/f1_score', r'$F_1$-score', 'max',
     dict(fmt='.1%', legend_bbta=(.85, .63)), 'figures/pnn-f1-score.pdf'),
]
for metric_col, metric_label, opt_mode, plot_kws, out_path in patch_size_plots:
    facet = compare_patch_sizes_plot(data, metric_col, metric_label, opt_mode, **plot_kws)
    facet.savefig(out_path, bbox_inches='tight')

In [None]:
# PR Curves
sns.set_theme(context='poster', style='ticks')

def plot_pr(data, label, color):
    """Plot the threshold-averaged precision-recall curve of one facet/hue subset.

    Meant for ``FacetGrid.map_dataframe``, which supplies ``label`` (hue level) and
    ``color``. The legend entry also reports the mean over images of the per-image
    average precision.
    """
    # Mean precision/recall per threshold. numeric_only skips string columns such as
    # `model` and `imgName`, which pandas >= 2 refuses to average.
    mean_pr = data.groupby('thr').mean(numeric_only=True).reset_index().sort_values('pdet/recall', ascending=False)
    mean_recalls = mean_pr['pdet/recall'].values
    mean_precisions = mean_pr['pdet/precision'].values
    
    aps = []
    for group_key, img_group in data.groupby('imgName'):
        img_group = img_group.reset_index().sort_values('pdet/recall', ascending=False)
        recalls = img_group['pdet/recall'].values
        precisions = img_group['pdet/precision'].values
        # Step-wise AP: sum of recall increments times precision (recall is descending,
        # hence the minus sign).
        average_precision = - np.sum(np.diff(recalls) * precisions[:-1])  # sklearn's ap
        aps.append(average_precision)
    
    mean_ap = np.mean(aps)
    plt.plot(mean_recalls, mean_precisions, label=f'{label} ({mean_ap:.1%})', color=color)


data = metrics.copy()
# Where recall is 0, precision is forced to 1 (presumably because it is undefined or 0
# there and would pull the curve down) — TODO confirm intent.
data.loc[data['pdet/recall'] == 0, 'pdet/precision'] = 1.0
# One facet per model, one curve per patch size.
grid = sns.FacetGrid(data=data, hue='patch_size', col='model', height=4, xlim=(0,1), ylim=(0,1.05), aspect=1.2)
grid.map_dataframe(plot_pr)
grid.set_xlabels('Recall')
grid.set_ylabels('Precision')
grid.set_titles(col_template="{col_name}")

# Iso-F1 reference curves: precision = F * R / (2R - F), solved from F = 2PR / (P + R).
f_scores = np.linspace(0.1, 0.9, num=9)
for ax in grid.axes.flatten():
    ax.legend(title='Patch Size', loc='lower left', ncol=1, fontsize='xx-small', title_fontsize='xx-small')
    
    for i, f_score in enumerate(f_scores):
        # Every other curve (F = 0.2, 0.4, ...) is solid and labelled.
        label_it = i % 2 != 0
        ls = '-' if label_it else '--'
        lw = 1 if label_it else 0.8
        x = np.linspace(0.01, 1)
        y = f_score * x / (2 * x - f_score)
        l, = ax.plot(x[y >= 0], y[y >= 0], color='gray', alpha=0.2, ls=ls, lw=lw)
        if label_it:
            # Label height taken at x[45] (~0.92 of the 50-point grid), placed at x=0.85.
            ax.annotate(r'$F_1$={0:0.1f}'.format(f_score), xy=(0.85, y[45] + 0.02), fontsize='xx-small')

grid.savefig('figures/pnn-pr-curves.pdf', bbox_inches='tight')

### Density-based Metrics

In [None]:
# Density-map metrics of every run, averaged per (model, patch size).
density_metrics = pd.concat([collect(k, r, 'dmap_metrics.csv.gz') for k, v in runs.items() for r in v], ignore_index=True)
# numeric_only: skip non-numeric columns (if any, e.g. image names), which pandas >= 2
# refuses to average.
density_metrics.groupby(['model', 'patch_size']).mean(numeric_only=True)

## How's performance on different agreement levels?

Show the best counting metrics obtained when practitioners choose different ground truths based on the minimum rater agreement.

In [None]:
# common funcs
from multiprocessing import Pool, cpu_count

def _temp_func(args):
    """Pool worker: apply `func` to one group and return ``(result, group_key)``.

    Kept at module level so `multiprocessing` can pickle it.
    """
    func, name, group = args
    return func(group), name

def applyParallel(dfGrouped, func):
    """Apply `func` to every group of `dfGrouped` in a process pool.

    Args:
        dfGrouped: a pandas GroupBy object.
        func: picklable (module-level) callable taking one group DataFrame and
            returning a Series/DataFrame.

    Returns:
        The per-group results concatenated with the group keys as outer index level(s).
    """
    tasks = [(func, name, group) for name, group in dfGrouped]
    n_workers = cpu_count()
    # Same chunking heuristic as Pool.map: about 4 chunks per worker.
    chunksize = max(1, -(-len(tasks) // (4 * n_workers)))

    with Pool(n_workers) as p:
        # `imap` yields results as they complete (in submission order), so the bar tracks
        # real progress; wrapping the *input* of `map` only tracked task submission.
        ret_list = list(tqdm(p.imap(_temp_func, tasks, chunksize=chunksize), total=len(tasks)))

    retLst, top_index = zip(*ret_list)
    return pd.concat(retLst, keys=top_index)

def _compute(x):
    """Detection and counting metrics for one group of GT/prediction rows, as a Series.

    The image size is fixed to 2000x2000 (presumably the full PNN test frames).
    """
    return pd.Series(detection_and_counting(x, image_hw=(2000, 2000)))


def drop_empty_gp(x):
    """Keep only rows having a ground-truth point (`X`) and/or a prediction (`Xp`)."""
    has_point = x.X.notna() | x.Xp.notna()
    return x[has_point]
    

def compute_metrics_by_agreement(data, grouping, parallel=True, max_raters=7):
    """Compute detection/counting metrics for every minimum-agreement level.

    For each level ``i`` in ``1..max_raters``, ground-truth points marked by fewer than
    ``i`` raters have their ``X`` blanked, rows left with neither GT nor prediction are
    dropped, and the rows are tagged ``min_raters=i``. Metrics are then computed per
    group of ``grouping`` (which should include ``'min_raters'``).

    Args:
        data: matched GT/prediction rows with ``X``, ``Xp`` and ``agreement`` columns.
        grouping: column names to group by before computing metrics.
        parallel: use a process pool (``applyParallel``) instead of ``progress_apply``.
        max_raters: highest agreement level to evaluate (default 7, the number of
            levels used throughout this notebook).

    Returns:
        Metrics per group, indexed by ``grouping``.
    """
    filtered = []
    for i in range(1, max_raters + 1):
        tmp = data.copy()
        # GT points marked by fewer than `i` raters do not count as ground truth here.
        tmp.loc[(tmp.agreement < i), 'X'] = None
        tmp = drop_empty_gp(tmp)
        tmp = tmp.assign(min_raters=i)
        filtered.append(tmp)

    data = pd.concat(filtered, ignore_index=True)
    data = data.groupby(grouping)
    if parallel:
        data = applyParallel(data, _compute)
        # Long (group keys..., metric) Series -> one column per metric.
        data = data.unstack()
        data.index.names = grouping
        return data

    return data.progress_apply(_compute)

def compute_ap(data):
    """Average precision of one PR curve given as rows of 'pdet/recall'/'pdet/precision'.

    Points are ordered by decreasing recall; AP is the sum of recall increments, each
    weighted by the precision at the higher-recall end (sklearn-style step AP).
    """
    ordered = data.sort_values('pdet/recall', ascending=False)
    recall = ordered['pdet/recall'].to_numpy()
    precision = ordered['pdet/precision'].to_numpy()
    return -(np.diff(recall) * precision[:-1]).sum()

def build_map_table(data):
    """Mean (over images) average precision per model configuration and agreement level.

    Args:
        data: per-image metrics with 'model', 'patch_size', 'min_raters', 'imgName' and
            PR columns (as columns or index levels); optional 'seed'/'scorer' columns
            are added to the grouping.

    Returns:
        DataFrame indexed by the model grouping + 'min_raters' with a 'mean_ap' column.
    """
    data = data.copy().reset_index()
    model_grouper = ['model', 'patch_size']
    if 'seed' in data.columns:
        model_grouper.append('seed')
        
    if 'scorer' in data.columns:
        model_grouper.append('scorer')
        # NOTE(review): `thr` is not used below; kept for parity with build_metrics_table.
        data['thr'] = data['re_thr_quantile']
    
    aps = data.groupby(model_grouper + ['min_raters', 'imgName']).apply(compute_ap)
    # numeric_only: after reset_index the string `imgName` column is present, and
    # pandas >= 2 refuses to average it.
    mean_aps = aps.reset_index().groupby(model_grouper + ['min_raters']).mean(numeric_only=True)
    mean_aps = mean_aps.rename(columns={0: 'mean_ap'})
    return mean_aps

def build_metrics_table(
    data, 
    metric=['count/mare', 'count/game-3', 'pdet/f1_score'],
    best_metric=None,
    mode='min',
    ci=False,
    return_config=False
):
    """Best-threshold metric table per model configuration and agreement level.

    Per-image metrics are averaged over images for each (model config, ``thr``,
    ``min_raters``). Then, for every requested metric, the threshold that optimises the
    corresponding ``best_metric`` (min or max per ``mode``) is selected.

    Args:
        data: per-image metrics with 'model', 'patch_size', 'thr' and 'min_raters' as
            columns or index levels; optional 'seed'/'scorer' columns are added to the
            grouping (with 'scorer', 're_thr_quantile' is used as threshold).
        metric: metric column name, or a list/tuple of names, to report.
        best_metric: metric(s) used to select the threshold; defaults to ``metric``.
            A single name is broadcast to every metric.
        mode: 'min'/'max', or a list/tuple with one entry per metric.
        ci: if True, report ``(mean, std)`` tuples instead of means.
        return_config: if True, also return the per-metric selected index tuples.

    Returns:
        Long DataFrame with the grouping columns, 'thr', 'min_raters', 'metric' and
        'value'; plus the list of selected configurations if ``return_config``.
    """
    metric = list(metric) if isinstance(metric, (list, tuple)) else [metric]
    best_metric = metric if best_metric is None else best_metric
    best_metric = list(best_metric) if isinstance(best_metric, (list, tuple)) else [best_metric] * len(metric)
    mode = list(mode) if isinstance(mode, (list, tuple)) else [mode] * len(metric)
    
    assert len(metric) == len(best_metric), 'best_metric must be 1 or of the same size of metric'
    assert len(metric) == len(mode), 'mode must be 1 or of the same size of metric'
    
    data = data.copy().reset_index()
    model_grouper = ['model', 'patch_size']
    if 'seed' in data.columns:
        model_grouper.append('seed')
        
    if 'scorer' in data.columns:
        model_grouper.append('scorer')
        data['thr'] = data['re_thr_quantile']

    keys = model_grouper + ['thr', 'min_raters']
    # Aggregate only numeric value columns. Non-key string columns such as `imgName`
    # cannot be averaged: TypeError in pandas >= 2, silently dropped in older pandas.
    value_cols = data.drop(columns=keys).select_dtypes(include=['number', 'bool']).columns.tolist()
    grouped = data.groupby(keys)[value_cols]
    m, s = grouped.mean(), grouped.std()
    
    tables = []
    configs = []
    for metr, best_metr, mod in zip(metric, best_metric, mode):
        # Index tuple (model config..., thr, min_raters) with the best mean `best_metr`.
        best_points = m.groupby(model_grouper + ['min_raters'])[best_metr]
        best_points = best_points.idxmin() if mod == 'min' else best_points.idxmax()

        table = m.loc[best_points, [metr]]
        if ci:
            # Pair every mean with its std: each cell becomes a (mean, std) tuple.
            table = table.combine(s.loc[best_points, [metr]], lambda x, y: x.combine(y, lambda w, z: (w, z)))
        
        table = table.reset_index().melt(id_vars=keys, var_name='metric')
        tables.append(table)
        configs.append(best_points)

    table = pd.concat(tables)
    
    if return_config:
        return table, configs
    
    return table

In [None]:
# let's compute metrics
# Per-image stage-1 metrics for every configuration, threshold and agreement level.
p1_grouping = ['model', 'patch_size', 'thr', 'imgName', 'min_raters']
p1_metrics = compute_metrics_by_agreement(predictions, p1_grouping)

In [None]:
def cimax(args):
    """Element of `args` (a Series of (mean, std) tuples) with the highest score.

    The score is ``mean - 0 * std``: the std weight is currently 0, so this ranks by
    mean (a NaN std still gives a NaN score, which idxmax skips).
    """
    score = args.map(lambda mean_std: mean_std[0] - mean_std[1] * 0)
    return args.loc[score.idxmax()]

def cimin(args):
    """Element of `args` (a Series of (mean, std) tuples) with the lowest score.

    The score is ``mean + 0 * std`` (std weight currently 0; see `cimax`).
    """
    score = args.map(lambda mean_std: mean_std[0] + mean_std[1] * 0)
    return args.loc[score.idxmin()]

# Stage-1 results table: for each metric and minimum-agreement level, the best
# configuration per model (best threshold, then best patch size), reported as
# "mean;std" strings when ci=True.
metr = ['count/mae', 'count/mare', 'count/game-3', 'pdet/f1_score']
ci = True
modes = ['min', 'min', 'min', 'max'] 
aggr = [cimin, cimin, cimin, cimax] if ci else modes

p1_table, configs = build_metrics_table(p1_metrics, metric=metr, mode=modes, ci=ci, return_config=True)

prec = 1
# Per-metric cell formatters; MARE and F1 are shown as percentages.
xfm = {
    'count/mae': lambda x: f'{x[0]:.1f};{x[1]:.1f}',
    'count/mare': lambda x: f'{100*x[0]:.{prec}f};{100*x[1]:.{prec}f}',
    'count/game-3': lambda x: f'{x[0]:.1f};{x[1]:.1f}',
    'pdet/f1_score': lambda x: f'{100*x[0]:.{prec}f};{100*x[1]:.{prec}f}',
} if ci else {
    'count/mae': '{:.1f}'.format,
    'count/mare': lambda x: f'{100*x:.1f}',
    'count/game-3': '{:.1f}'.format,
    'pdet/f1_score': lambda x: f'{100*x:.1f}', # '{:.0%}'.format,
}

aggr_per_metric = {k: v for k, v in zip(metr, aggr)}

def take_best(a):
    # `a.name[-1]` is the metric of the (model, min_raters, metric) group; pick the best
    # value across patch sizes with that metric's aggregator.
    return a['value'].aggregate(aggr_per_metric[a.name[-1]])

# Best per (model, min_raters, metric) -> formatted -> rows (metric, model) x columns min_raters.
p1_table = p1_table.groupby(['model', 'min_raters', 'metric']).apply(take_best).rename('value')
p1_table = p1_table.unstack('metric').transform(xfm).rename_axis('metric', axis=1).stack().rename('value')
p1_table = p1_table.reset_index().pivot(index=['metric', 'model'], columns='min_raters', values='value')
p1_table = p1_table.reindex(metr, level=0).reindex(model_order, level=1)

print(p1_table.to_latex(escape=False, multirow=True))
display(p1_table)

# Compact version restricted to a subset of agreement levels.
p1_table = p1_table[[1,4,5,7]]

print(p1_table.to_latex(escape=False, multirow=True))
display(p1_table)

In [None]:
def find_best_worst_image(x, metric, mode):
    """Names of the images with the best and worst `metric` value in `x`.

    Args:
        x: per-image rows with 'imgName' as a column or index level.
        metric: column to rank by.
        mode: 'min' if lower is better; any other value means higher is better.

    Returns:
        Series with keys 'best' and 'worst' holding image names.
    """
    flat = x.reset_index()
    scores = flat[metric]
    lower_is_better = mode == 'min'
    best_row = scores.idxmin() if lower_is_better else scores.idxmax()
    worst_row = scores.idxmax() if lower_is_better else scores.idxmin()
    return pd.Series({
        'best' : flat.loc[best_row, 'imgName'],
        'worst': flat.loc[worst_row, 'imgName'],
    })

# Pick the best/worst test image by per-image F1 under the best-F1 configurations, and
# the stage-1 configurations to visualise on them.
best_metric = 'pdet/f1_score'
best_metric_configs = configs[-1]  # selected configs of the last metric in `metr` (F1)
mode = 'max'

# Best/worst image per configuration, then the most frequent one across configurations
# (`.mode()`; its row 0 is used when rendering).
best_worst_images = p1_metrics \
    .reset_index().set_index(['model', 'patch_size', 'thr', 'min_raters']) \
    .loc[best_metric_configs] \
    .groupby(['model', 'patch_size', 'thr', 'min_raters']).apply(find_best_worst_image, best_metric, mode) \
    .mode()

tmp = best_metric_configs.reset_index()
# One patch size per model, at the loosest (1) and strictest (7) agreement levels.
selector = (((tmp.model == 'S-UNet')   & (tmp.patch_size == 320)) |
            ((tmp.model == 'FRCNN')    & (tmp.patch_size == 640)) | 
            ((tmp.model == 'D-CSRNet') & (tmp.patch_size == 640)) )
selector = selector & tmp.min_raters.isin([1, 7])
# Each value is a (model, patch_size, thr, min_raters) tuple.
best_configs = tmp[selector][best_metric].values
best_configs

indexed_preds = predictions.set_index(['model', 'patch_size', 'thr', 'imgName'])

# For the best and worst test image, save a crop of the image alone, with GT and
# predictions of each selected configuration, and with GT only per agreement level.
# Frames are resized to 500x500 (coordinates divided by 4, so presumably 2000x2000
# originals) and cropped to rows [0, 250) and columns [125, 375).
for img in ('best', 'worst'):
    imgName = best_worst_images.loc[0, img]
    image = imread('data/perineuronal-nets/test/fullFrames/' + imgName)
    image = matplotlib.cm.viridis(image)[:,:,:3]
    image = resize(image, (500, 500))
    image = (255 * image).astype(np.uint8)
    image = image[:250, 125:375, :]
    
    imsave(f'figures/{img}_clean.png', image)
    
    for model, patch_size, thr, min_raters in best_configs:
        preds = indexed_preds.loc[(model, patch_size, thr, imgName)].reset_index().copy()
        # Blank GT points below the agreement level; drop rows left without any point.
        preds.loc[(preds.agreement < min_raters), ['X', 'Y']] = None
        preds = drop_empty_gp(preds)
        # Scale coordinates to the 500x500 image.
        preds.loc[:, ['X', 'Y']] /= 4
        preds.loc[:, ['Xp', 'Yp']] /= 4
        
        # Keep a row only if one of its existing points (GT and/or prediction) lies in
        # the crop. NaN coordinates make `between` False, so a missing point never
        # qualifies. (`isna() | inside` would keep every unmatched GT/prediction of the
        # whole frame, drawing points outside the crop.)
        gt_in_crop = preds.X.between(125, 375) & preds.Y.between(0, 250)
        pred_in_crop = preds.Xp.between(125, 375) & preds.Yp.between(0, 250)
        preds = preds[gt_in_crop | pred_in_crop]
        preds = drop_empty_gp(preds).copy()
        # Shift x into crop coordinates (the crop starts at row 0, so y is unchanged).
        preds.loc[:, 'X'] -= 125
        preds.loc[:, 'Xp'] -= 125
        
        drawn = draw_groundtruth_and_predictions(image, preds, radius=5)
        fname = f'figures/{img}_img_{model.lower()}_{patch_size}_raters_{min_raters}.png'
        imsave(fname, drawn)
        
        # GT-only rendering, shared by all models at the same agreement level.
        gt_only_img = f'figures/{img}_gt_raters_{min_raters}.png'
        if not Path(gt_only_img).exists():
            gt_sel = ( (predictions.imgName == imgName)
                     & (~predictions.agreement.isna())
                     & (predictions.agreement >= min_raters)
                     )
            
            gt_yx = predictions[gt_sel][['Y', 'X']].drop_duplicates().dropna()
            gt_yx.loc[:, 'X'] = (gt_yx.X / 4) - 125
            gt_yx.loc[:, 'Y'] = (gt_yx.Y / 4)
            gt_yx = gt_yx[gt_yx.X.between(0, 250) & gt_yx.Y.between(0, 250)].values
            
            gt_only = draw_points(image, gt_yx, radius=5, marker='square', color=[255,255,0]) # YELLOW
            imsave(gt_only_img, gt_only)
    

## How much does rescoring (stage 2) increase performance?
Compare the counting and detection metrics of stage-1-only models with those obtained after stage-2 rescoring.

In [None]:
runs_score_path = Path('runs_score')
# Stage-2 rescoring runs, one directory per (scoring method, seed), keyed by short name.
runs_score = {
    short_name: runs_score_path.glob(f'method={method},seed=*')
    for short_name, method in (
        ('AR', 'simple_regression'),
        ('AC', 'simple_classification'),
        ('OR', 'ordinal_regression'),
        ('RL', 'pairwise_balanced'),
    )
}

def collect_scores(model_name, run):
    """Load the test-set predictions of one scorer run.

    Reads ``<run>/test_predictions/all_gt_preds.csv.gz`` (first column used as index)
    and tags every row with ``model`` (= ``model_name``) and ``seed`` (the integer
    after the last '=' in the run directory name, e.g. ``...,seed=3`` -> 3).
    """
    run_dir = Path(run)
    frame = pd.read_csv(run_dir.joinpath('test_predictions', 'all_gt_preds.csv.gz'), index_col=0)
    frame['model'] = model_name
    frame['seed'] = int(run_dir.name.rsplit('=', 1)[-1])
    return frame

# All scorer test predictions (every scorer, every seed) stacked into one frame.
score_data = pd.concat(
    [collect_scores(label, run_dir) for label, run_dirs in runs_score.items() for run_dir in run_dirs],
    ignore_index=True,
)

# seed -> array of image names in that seed's test split
test_images = score_data.groupby('seed')['imgName'].unique().to_dict()

In [None]:
# best configs for maximum recall

def max_recall(data):
    """Return (as a length-<=1 array) the index label of the row with the highest recall.

    Ties on ``pdet/recall`` are broken by the higher ``pdet/precision``.
    """
    ranked = data.sort_values(by=['pdet/recall', 'pdet/precision'], ascending=False)
    return ranked.index.values[:1]

# Per stage-1 model: the (model, patch_size, thr) configuration whose metrics, averaged over
# the remaining index levels at min_raters=1, have the highest recall (precision breaks ties).
p1_metrics.xs(1, level='min_raters').groupby(['model', 'patch_size', 'thr']).mean().groupby('model').apply(max_recall)

### Samples per Scorer

In [None]:
# Patch-level dataset; the next cell indexes it (dataset[idx][0]) to fetch the image patch of each ranked sample.
dataset = PerineuronalNetsRankDataset(mode='patches')

In [None]:
sns.set_theme(context='notebook', style='ticks', font_scale=1)

# (imgName, X, Y) -> dataset row, used to map ranked predictions back to dataset indices.
p2i = dataset.annot.reset_index().set_index(['imgName','X','Y'])

so = ('Pair-wise Regression', 'Ordinal Regression', 'Agreement Classification', 'Agreement Regression')

# For every (scorer, agreement) pair keep the 10 highest-scored samples and look up their
# dataset index. (Parenthesised chain: the original ended with a stray line-continuation.)
sample_idx = (
    rank_data.groupby(['model', 'agreement'])
    .apply(lambda x: x.nlargest(10, 'score'))
    .droplevel(-1)
    .apply(lambda x: p2i.loc[tuple(x[['imgName', 'X', 'Y']].values), 'index'], axis=1)
)

nr = 1  # montage rows per cell
agreements = range(7, 0, -1)
fig, axes = plt.subplots(len(agreements), len(so), figsize=(17,4))
for col, scorer in enumerate(so):
    axes[0, col].set_title(scorer)
    for row, agreement in enumerate(agreements):
        samples = sample_idx.loc[(scorer, agreement)]
        cell_images = [matplotlib.cm.viridis(dataset[k][0]) for k in samples]
        cell_images = np.stack(cell_images)[:, :, :, :3]
        # grid_shape must be integral; ceil so every image fits when nr does not divide the count
        n_cols = int(np.ceil(len(cell_images) / nr))
        image = montage(cell_images, grid_shape=(nr, n_cols), padding_width=5, fill=(1, 1, 1), multichannel=True)
        ax = axes[row, col]
        ax.imshow(image)
        # Hide ticks and frame but keep the axis on: set_axis_off() would also hide the
        # agreement y-labels set below, so they never appeared in the figure.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

for row, agreement in enumerate(agreements):
    axes[row, 0].set_ylabel(str(agreement))

plt.subplots_adjust(wspace=0.1, hspace=0)

### Stage-1 Only Metrics

In [None]:
# get best patch_size per method, all thresholds
# Best patch size per stage-1 method (every detection threshold is kept).
tmp = p1_metrics.reset_index()

selector = (
    (tmp.model.eq('S-UNet') & tmp.patch_size.eq(320))
    | (tmp.model.eq('FRCNN') & tmp.patch_size.eq(640))
    | (tmp.model.eq('D-CSRNet') & tmp.patch_size.eq(640))
)
tmp = tmp.loc[selector]

# Keep only each seed's test images and tag those rows with the seed.
tmp = pd.concat(
    [tmp.loc[tmp.imgName.isin(images)].assign(seed=seed) for seed, images in test_images.items()],
    ignore_index=True,
)
p1_test_metrics = tmp.set_index(p1_metrics.index.names + ['seed'])

In [None]:
# Stage-1 test-split metrics as "mean ± std" over seeds, one column per min_raters.
# NOTE(review): `metr` / `modes` are assigned in the Stage-2 Metrics cell further down; unless
# they are also defined earlier in the notebook, this cell only works after that cell has
# run — confirm on a fresh kernel.
p1t_table = build_metrics_table(p1_test_metrics, metric=metr, mode=modes)
p1t_table = p1t_table.groupby(['model', 'patch_size', 'min_raters', 'metric']).value.aggregate(['mean', 'std'])
p1t_table = p1t_table.unstack('metric')

pct_f = lambda x: f'{100*x:.1f}'  # fraction -> percentage string
flo_f = '{:.1f}'.format
# Formatter per metric; applied to both the 'mean' and the 'std' column of that metric.
metric_formatters = {
    'count/mae': flo_f,
    'count/mare': pct_f,
    'count/game-3': flo_f,
    'pdet/f1_score': pct_f,
}
p1t_table = p1t_table.transform({
    (stat, metric): formatter
    for metric, formatter in metric_formatters.items()
    for stat in ('mean', 'std')
})

# Raw string: '\p' is not a valid escape sequence (SyntaxWarning on Python >= 3.12).
p1t_table = p1t_table['mean'] + r' $\pm$ ' + p1t_table['std']
p1t_table = p1t_table.rename_axis('metric', axis=1).stack().rename('value').reset_index()
p1t_table = p1t_table.pivot(index=['metric', 'model'], columns='min_raters', values='value')
p1t_table = p1t_table.reindex(metr, level=0).reindex(model_order, level=1)
p1t_table

In [None]:
# Side by side: p1_table (built earlier in the notebook, not in this section) and the
# test-split-only "mean ± std over seeds" table from the previous cell.
display(p1_table, p1t_table)

In [None]:
# Mean-AP tables (as percentages): p1_map_table from all of p1_metrics, and
# p1t_map_table from the test splits only, as "mean ± std" over seeds.
p1_map_table = (
    build_map_table(p1_metrics)
    .unstack('min_raters')
    .applymap(lambda x: f'{100*x:.1f}')
)

p1t_map_table = (
    build_map_table(p1_test_metrics)
    .reset_index()
    .groupby(['model', 'patch_size', 'min_raters'])
    .mean_ap.aggregate(['mean', 'std'])
    .applymap(lambda x: f'{100*x:.1f}')
)

# Raw string: '\p' is not a valid escape sequence (SyntaxWarning on Python >= 3.12).
p1t_map_table = p1t_map_table['mean'] + r' $\pm$ ' + p1t_map_table['std']
p1t_map_table = (
    p1t_map_table
    .rename('value')
    .reset_index()
    .pivot(index=['model', 'patch_size'], columns='min_raters', values='value')
)

display(p1_map_table, p1t_map_table)

### Score vs Agreement Correlation

In [None]:
# get best config per model, maximum recall
# Max-recall stage-1 configuration per model: (patch size, detection threshold).
selector = (
    (predictions.model.eq('S-UNet') & predictions.patch_size.eq(320) & predictions.thr.eq(0.1))
    | (predictions.model.eq('FRCNN') & predictions.patch_size.eq(640) & predictions.thr.eq(0.0))
    | (predictions.model.eq('D-CSRNet') & predictions.patch_size.eq(640) & predictions.thr.eq(0.0))
)

# Restrict to images that belong to the test split of at least one seed.
keep = np.unique(np.concatenate(list(test_images.values()))).tolist()
selector = selector & predictions.imgName.isin(keep)

# Stage-1 rows: missing agreement -> 0, and a fixed placeholder seed.
# NOTE(review): seed 23 presumably just gives stage-1 rows a seed value for grouping; confirm.
p1_data = predictions.loc[selector].assign(agreement=lambda d: d['agreement'].fillna(0), seed=23)
# Stage-2 scorers have no patch size; -1 is a placeholder so they form their own groups.
p2_data = score_data.assign(patch_size=-1)

rdata = pd.concat([p1_data, p2_data], ignore_index=True)

def normalize_scores(data):
    """Standardize the ``score`` column of one group to zero mean and unit variance.

    Meant for ``groupby(...).apply``. Returns a new frame instead of writing into the
    group pandas passes in (pandas documents mutating that object as unsupported).
    NaN scores stay NaN.
    """
    scaled = StandardScaler().fit_transform(data['score'].values.reshape(-1, 1))
    # optional alternative scaling kept from the original: scaled * 0.5 + 0.5
    return data.assign(score=scaled.ravel())

# Standardize scores within each (model, patch_size, seed) group.
# group_keys=False keeps rdata's flat index: with pandas >= 2.0 the group keys would otherwise
# be prepended as index levels, making 'model' both an index level and a column and breaking
# later groupby('model', ...) calls with an ambiguity error.
rdata = rdata.groupby(['model', 'patch_size', 'seed'], group_keys=False).apply(normalize_scores)
rdata.head()

In [None]:
# Score vs. rater agreement: one boxen per (model, agreement level) on the per-group
# standardized scores in rdata.
sns.set_theme(context='talk', style='ticks', font_scale=1.5)

# Keep scored points with agreement >= 1 (NaN agreement and the 0 filled in for stage-1 rows are excluded).
plot_data = rdata[~rdata.score.isna() & (rdata.agreement > 0)].copy()
plot_data['agreement'] = plot_data.agreement.astype(int)
plot_data = plot_data[['score', 'agreement', 'model']]

# x-axis order: stage-1 models, then the stage-2 scorers (keys of runs_score).
order = [
    'S-UNet',
    'FRCNN',
    'D-CSRNet',
    'AR', #'Agreement Regression',
    'AC', #'Agreement Classification',
    'OR', #'Ordinal Regression',
    'RL', #'Rank Learning',
]

fig, ax = plt.subplots(figsize=(16, 7))
width = 0.8  # width of each model's group of hue boxes; reused below to place the fit lines
sns.boxenplot(data=plot_data, y='score', x='model', hue='agreement', order=order, palette='rocket', ax=ax, width=width, showfliers=False)
# ticks at integer standard deviations of the standardized scores
ax.set_yticks(range(-3, 4))
ax.set_yticklabels(range(-3, 4))
# reference line at score 0 (axhline's default y)
ax.axhline(xmax=.95, c='k', zorder=-10, lw=1.5)

def corr_coeff(data, **kws):
    """Pearson correlation between ``score`` and ``agreement`` over rows where both are present.

    Extra keyword arguments are unused. Only the coefficient is returned (p-value dropped).
    """
    valid = data.score.notna() & data.agreement.notna()
    coefficient, _ = scipy.stats.pearsonr(data.loc[valid, 'score'], data.loc[valid, 'agreement'])
    return coefficient

def lin_fit(data, n_boot=50, random_state=0, **kws):
    """Class-balanced linear fit of ``score`` against ``agreement``.

    Agreement levels have different sizes, so each of ``n_boot`` rounds draws the same
    number of rows (the size of the smallest level) from every agreement level, fits a
    degree-1 polynomial score ~ agreement, and the coefficients are averaged over rounds.

    Args:
        data: frame with ``score`` and ``agreement`` columns; rows with NaN in either
            are ignored.
        n_boot: number of resampling rounds (50, as originally hard-coded).
        random_state: seed for the resampling so the fitted lines (and the saved figure)
            are reproducible; the original sampled without a seed.
        **kws: unused.

    Returns:
        ``np.poly1d`` with the averaged [slope, intercept].
    """
    sel = data.score.notna() & data.agreement.notna()
    grouped = data[sel].groupby('agreement')
    # size() counts rows directly, without depending on a 'model' column being present
    min_num = grouped.size().min()
    rng = np.random.RandomState(random_state)

    coeffs = []
    for _ in range(n_boot):
        y, x = grouped.sample(min_num, random_state=rng)[['score', 'agreement']].values.T
        coeffs.append(np.polyfit(x, y, 1))

    return np.poly1d(np.mean(coeffs, axis=0))

# Per-model Pearson r and balanced linear fit of score vs. agreement.
grouped = plot_data.groupby('model')
corrs = grouped.apply(corr_coeff)
linfits = grouped.apply(lin_fit)

display(corrs)

# x tick labels: model name (spaces -> line breaks) plus its correlation coefficient.
labels = [l.get_text() for l in ax.get_xticklabels()]
labels = ['{}\n$r$={:.2f}'.format(l.replace(' ', '\n'), corrs[l]) for l in labels]

ax.set_xticklabels(labels)
ax.set_xlabel(None)

ax.grid(which='major', axis='y', ls='-', lw=.75)
ax.tick_params(axis='x', color='white')  # white x tick marks, i.e. visually hidden

handles, labels = ax.get_legend_handles_labels()

# Identity ordering (agreement 1..7); kept as a hook to reorder legend entries.
legend_order = (1, 2, 3, 4, 5, 6, 7)
handles = [handles[i-1] for i in legend_order]
labels = [labels[i-1] for i in legend_order]

ax.legend(handles, labels, title='agreement',
          ncol=7, loc='upper center', bbox_to_anchor=(.12,.12,.77,1),
          fontsize='x-small', title_fontsize='x-small',
          labelspacing=0.2, columnspacing=1, framealpha=0)

# Overlay each model's fit line across its group of boxes. Assuming seaborn dodges the 7 hue
# levels evenly over `width`, agreement a is centred at i + (a - 4) * width / 7, so the
# x endpoints below correspond to agreement 0 and 8, matching line([0, 8]).
for i, line in enumerate(linfits[order]):
    x = [i - 4 * width / 7, i + 4 * width / 7]
    y = line([0, 8])
    ax.plot(x, y, c='w', ls='--', path_effects=[pe.Stroke(linewidth=4, foreground='k'), pe.Normal()])

sns.despine(bottom=True)
plt.savefig('figures/score-vs-agreement.pdf', bbox_inches='tight')

In [None]:
sns.set_theme(context='notebook', style='ticks', font_scale=1.5)

# Average each (model, X, Y, agreement) point over seeds. Missing agreement -> 0; missing
# score -> a large negative sentinel so those points sort last.
plot_data = (
    rdata.fillna({'agreement': 0, 'score': -100000})
    .groupby(['model', 'X', 'Y', 'agreement']).score.mean()
    .reset_index()
)
# Per model: agreement values ordered by decreasing score (ties broken by Y, then X).
sorted_samples = plot_data.groupby('model').apply(
    lambda x: x.sort_values(['score', 'Y', 'X'], ascending=[False, True, True]).agreement
)
sorted_samples = sorted_samples.droplevel(-1).reset_index()
sorted_samples['model'] = sorted_samples['model'].str.replace(' ', '\n')

# Facet order must use the model labels actually present in rdata. The stage-2 scorers are
# labelled 'AR', 'AC', 'OR', 'RL' (keys of runs_score); the previous long names
# ('Agreement\nRegression', ...) matched nothing, leaving four blank facets and dropping
# the stage-2 scorers from the figure.
order = [
    'S-UNet',
    'FRCNN',
    'D-CSRNet',
    'AR',  # Agreement Regression
    'AC',  # Agreement Classification
    'OR',  # Ordinal Regression
    'RL',  # Rank Learning
]

def heatmap_plot(data, color, **kws):
    """Draw one facet's ``agreement`` sequence as a heatmap with 32 values per heatmap row.

    The sequence is padded up to a multiple of 32; the padding cells are masked so they are
    not drawn. ``color`` (passed by FacetGrid.map_dataframe) is unused; ``kws`` go to sns.heatmap.
    """
    values = data.agreement.values
    per_row = 32
    n_pad = -values.size % per_row

    grid = np.pad(values, (0, n_pad), constant_values=100).reshape(-1, per_row)
    hidden = np.pad(np.zeros_like(values), (0, n_pad), constant_values=1).reshape(-1, per_row)

    sns.heatmap(ax=plt.gca(), data=grid, mask=hidden, **kws)
    
    
# One facet per model (in `order`): its agreement values sorted by decreasing score, drawn
# 32 per heatmap row by heatmap_plot, with one shared colorbar on the right.
g = sns.FacetGrid(data=sorted_samples, col='model', aspect=.45, height=5, col_order=order)
cbar_ax = g.fig.add_axes([.99, .2, .01, .60])  # create a colorbar axes

g = g.map_dataframe(heatmap_plot, vmin=0, vmax=7, square=True, antialiased=True, rasterized=True,
                    cbar_ax=cbar_ax, cbar_kws=dict(
                        ticks=range(8),
                        ticklocation='right', orientation='vertical',
                        label='agreement',
                    ))
for ax in g.axes.flatten():
    ax.axis('off')
    
g.set_titles(col_template="{col_name}")
g.tight_layout()
g.fig.subplots_adjust(wspace=.05)#, hspace=0.05)
g.savefig('figures/score-gradient.pdf', bbox_inches='tight')

### Stage-2 Metrics

In [None]:
# Stage-1 runs evaluated after rescoring: model label -> [(run directory, stage-1 threshold)].
# Patch sizes and thresholds match the max-recall stage-1 configurations selected above.
rescored_runs = dict([
    ('S-UNet', [('runs/experiment=perineuronal-nets/segmentation/unet_320', 0.1)]),
    ('FRCNN', [('runs/experiment=perineuronal-nets/detection/fasterrcnn_640', 0.0)]),
    ('D-CSRNet', [('runs/experiment=perineuronal-nets/density/csrnet_640', 0.0)]),
])

In [None]:
def collect_rescored(model_name, run, thr):
    """Load the stage-2 (rescored) test predictions of one stage-1 run.

    Reads every ``test_predictions/all_gt_preds_rescored_<scorer>-seed<N>_imgsplit.csv.gz``
    under ``run``. It keeps rows at stage-1 threshold ``thr`` whose image is in that seed's
    test split (``test_images[seed]``), then tags them with model, patch size, scorer and seed.

    Args:
        model_name: label stored in the ``model`` column.
        run: stage-1 run directory (must contain ``.hydra/config.yaml``).
        thr: stage-1 threshold to keep.

    Returns:
        DataFrame with the matching rows of all files concatenated.

    Raises:
        FileNotFoundError: if ``run`` contains no rescored prediction file.
    """
    run = Path(run)
    cfg = OmegaConf.load(run / '.hydra' / 'config.yaml')
    patch_size = cfg['data']['validation']['patch_size']

    prefix, suffix = 'all_gt_preds_rescored_', '_imgsplit.csv.gz'
    pred_dir = run / 'test_predictions'
    # sorted: deterministic row order regardless of filesystem listing order
    csv_paths = sorted(pred_dir.glob(f'{prefix}*seed*{suffix}'))
    if not csv_paths:
        raise FileNotFoundError(f'no {prefix}*seed*{suffix} files in {pred_dir}')

    preds = []
    for csv_path in csv_paths:
        method_and_seed = csv_path.name[len(prefix):-len(suffix)]
        # split on the LAST '-' so scorer names that contain '-' are still parsed correctly
        rescore_method, seed = method_and_seed.rsplit('-', 1)
        seed = int(seed[len('seed'):])

        data = pd.read_csv(csv_path)
        # .copy(): the columns added below must not be written into a slice of the read frame
        data = data[(data.thr == thr) & (data.imgName.isin(test_images[seed]))].copy()
        data['model'] = model_name
        data['patch_size'] = patch_size
        data['scorer'] = rescore_method
        data['seed'] = seed

        preds.append(data)

    # (A disabled 'no_rescore' baseline branch was removed here; it was guarded by `if False`
    # and would only have copied the last file read.)
    return pd.concat(preds, ignore_index=True)

# Rescored predictions of every configured stage-1 run; missing agreement is set to 0.
rescored_predictions = pd.concat(
    [collect_rescored(label, run_dir, stage1_thr)
     for label, run_list in rescored_runs.items()
     for run_dir, stage1_thr in run_list],
    ignore_index=True,
)
rescored_predictions['agreement'] = rescored_predictions['agreement'].fillna(0)


def apply_percentile_thresholds(gp):
    """Sweep a stage-2 threshold over one group of rescored predictions.

    The thresholds are the 0%, 0.5%, ..., 100% quantiles of ``gp.rescore``, or the raw
    values 0..1 for the ``'no_rescore'`` scorer. One extra threshold above the largest is
    added and reported as quantile 2.0. For each threshold, rows whose ``rescore`` is below
    it (or missing) get ``Xp`` set to missing. Rows left with neither ``X`` nor ``Xp`` are
    dropped.

    Returns:
        All thresholded copies stacked, with added ``re_thr`` and ``re_thr_quantile`` columns.
    """
    levels = np.linspace(0, 1, 201)
    if gp.scorer.iloc[0] == 'no_rescore':
        cutoffs = levels.tolist()
    else:
        cutoffs = gp.rescore.quantile(levels).tolist()

    # sentinel step strictly above the largest cutoff
    cutoffs = cutoffs + [cutoffs[-1] + 1]
    levels = levels.tolist() + [2.]

    unscored = gp.rescore.isna()

    def _at_cutoff(cutoff, level):
        out = gp.copy()
        out.loc[unscored | (gp.rescore < cutoff), 'Xp'] = None
        out = out[out.X.notna() | out.Xp.notna()]
        out['re_thr'] = cutoff
        out['re_thr_quantile'] = level
        return out

    return pd.concat([_at_cutoff(c, q) for c, q in zip(cutoffs, levels)], ignore_index=True)

# Expand every (patch_size, model, seed, thr, scorer) group into its threshold sweep.
rescored_predictions = (
    rescored_predictions
    .groupby(['patch_size', 'model', 'seed', 'thr', 'scorer'])
    .progress_apply(apply_percentile_thresholds)
    .reset_index(drop=True)
)

In [None]:
# Stage-2 metrics per (model, patch_size, scorer, threshold quantile, image, min_raters, seed).
# compute_metrics_by_agreement is presumably defined earlier in the notebook (not in this section).
p2_metrics = compute_metrics_by_agreement(
    rescored_predictions,
    ['model', 'patch_size', 'scorer', 're_thr_quantile', 'imgName', 'min_raters', 'seed']
)

In [None]:
p1t_mean_ap = build_map_table(p1_test_metrics)
p2_mean_ap = build_map_table(p2_metrics)

# Tag stage-1-only results with scorer '-' and give them the stage-2 index layout.
tmp_p1 = (
    pd.concat({'-': p1t_mean_ap}, names=['scorer'])
    .reset_index()
    .set_index(p2_mean_ap.index.names)
)
combined = (
    pd.concat((p2_mean_ap, tmp_p1))
    .groupby(['model', 'patch_size', 'scorer', 'min_raters'])
    .mean()
)

# Change of every scorer with respect to the stage-1-only ('-') row.
diff = combined - combined.xs('-', level='scorer')

def fmt(absolute, difference):
    """Format a value and its change as 'abs (diff)', both with two decimals."""
    return '{:.2f} ({:.2f})'.format(absolute, difference)

def styling(x):
    """CSS background for an 'abs (diff)' cell: green if diff > 0, red if diff < 0, none otherwise."""
    delta = float(x.split(' ')[1].strip('()'))
    if delta > 0:
        color = '#ADFFAD'
    elif delta < 0:
        color = '#ffadad'
    else:
        color = 'none'
    return f'background-color: {color}'

# Mean AP per scorer as 'abs (diff vs stage-1)', coloured by the sign of the difference.
(
    combined.combine(diff, lambda x, y: x.combine(y, fmt))
    .reindex(model_order, level=0)
    .reindex(('-',) + scorer_order, level=2)
    .unstack('min_raters')
    .style.applymap(styling)
)

In [None]:
# Metrics compared below, with the optimisation direction passed to build_metrics_table for each.
metr = ['count/mae', 'count/mare', 'count/game-3', 'pdet/f1_score']
modes = ['min', 'min', 'min', 'max']

# Stage-1-only rows get scorer '-' so they can be compared against every stage-2 scorer.
p1t_table = build_metrics_table(p1_test_metrics, metric=metr, mode=modes).assign(scorer='-')
p2_table = build_metrics_table(p2_metrics, metric=metr, mode=modes)

combined = (
    pd.concat((p1t_table, p2_table), ignore_index=True)
    .groupby(['model', 'patch_size', 'scorer', 'min_raters', 'metric'])['value'].mean()
    .rename('value')
    .reset_index()
    .set_index(['metric', 'model', 'patch_size', 'scorer', 'min_raters'])
)

# Difference of every row from the stage-1-only ('-') row with the same other index values.
diff = combined - combined.xs('-', level='scorer')

def fmt(absolute, difference):
    """Format a value and its change as 'abs (diff)', both with two decimals."""
    return '{:.2f} ({:.2f})'.format(absolute, difference)

def _diff_from_cell(cell):
    """Parse the parenthesised difference out of an 'abs (diff)' cell string."""
    return float(cell.split(' ')[1].strip('()'))

def styling_up(x):
    """Background for higher-is-better metrics: green if diff > 0, red if diff < 0."""
    delta = _diff_from_cell(x)
    color = 'none'
    if delta > 0:
        color = '#ADFFAD'
    elif delta < 0:
        color = '#ffadad'
    return f'background-color: {color}'

def styling_down(x):
    """Background for lower-is-better metrics: green if diff < 0, red if diff > 0."""
    delta = _diff_from_cell(x)
    color = 'none'
    if delta < 0:
        color = '#ADFFAD'
    elif delta > 0:
        color = '#ffadad'
    return f'background-color: {color}'

# Per-metric cell styler (counting errors: lower is better; F1: higher is better).
styles = dict([
    ('count/mae', styling_down),
    ('count/mare', styling_down),
    ('count/game-3', styling_down),
    ('pdet/f1_score', styling_up),
])

def styling(x):
    """Row styler for Styler.apply(axis=1): pick the styler from the row's metric (first index level)."""
    row_style = styles[x.name[0]]
    return list(map(row_style, x.values))

# Styled comparison: 'abs (diff vs stage-1)' per metric, coloured by whether the change is an improvement.
table = (
    combined.combine(diff, lambda x, y: x.combine(y, fmt))
    .reindex(model_order, level='model')
    .reindex(('-',) + scorer_order, level='scorer')
    .unstack('min_raters')
    .style.apply(styling, axis=1)
)

display(table)


def latex_fmt(a, d):
    """Format value ``a`` and its change ``d`` as ``a\\diff{d}`` (two decimals) for LaTeX.

    ``\\diff`` is presumably a macro defined in the paper's preamble.
    """
    # rf-string: the original non-raw '\d' was an invalid escape sequence (SyntaxWarning on
    # Python >= 3.12); the produced text is unchanged.
    return rf'{a:.2f}\diff{{{d:.2f}}}'

# LaTeX table for the paper: count/mae only, at min_raters 1, 4, 5 and 7, with cells formatted
# as "absolute\diff{difference}" and short scorer labels.
# NOTE(review): 'pairwise_balanced' is labelled 'PW' here but 'RL' in runs_score and the
# score-vs-agreement figure; confirm which label the paper uses.
table = combined.combine(diff, lambda x, y: x.combine(y, latex_fmt))\
    .reindex(model_order, level='model')\
    .reindex(('-',) + scorer_order, level='scorer') \
    .unstack('min_raters')\
    .droplevel('patch_size', axis=0)\
    .droplevel(0, axis=1)\
    .loc['count/mae', [1, 4, 5,7]]\
    .rename({
        '-': '',
        'simple_regression': 'AR', #'Agreement Regression',
        'simple_classification': 'AC', #'Agreement Classification',
        'ordinal_regression': 'OR', #'Ordinal Regression',
        'pairwise_balanced': 'PW', #'Pair-wise Regression',
    }, axis=0, level='scorer')
    

# Flatten the (model, scorer) index into "model + scorer" labels; stage-1-only rows
# (scorer renamed to '') keep just the model name.
table = table.set_index(table.index.map(lambda x: x[0] + (' + ' + x[1] if x[1] else '')))
print(table.to_latex(escape=False))
table