# Uncertainty-adjusted accuracy evaluation

See `README.md` for installation and usage instructions.

This notebook re-creates some of the figures of [2], both on the toy dataset
introduced in [1] and on the dermatology dataset.

```
[1] Stutz, D., Roy, A.G., Matejovicova, T., Strachan, P., Cemgil, A.T.,
    & Doucet, A. (2023).
    Conformal prediction under ambiguous ground truth. ArXiv, abs/2307.09302.
[2] Stutz, D., Cemgil, A.T., Roy, A.G., Matejovicova, T., Barsbey, M.,
    Strachan, P., Schaekermann, M., Freyberg, J.V., Rikhye, R.V., Freeman, B.,
    Matos, J.P., Telang, U., Webster, D.R., Liu, Y., Corrado, G.S., Matias, Y.,
    Kohli, P., Liu, Y., Doucet, A., & Karthikesalingam, A. (2023).
    Evaluating AI systems under uncertain ground truth: a case study in
    dermatology. ArXiv, abs/2307.02191.
```

# Imports and setup

In [None]:
import functools
import jax
import jax.numpy as jnp
import matplotlib
from matplotlib import pyplot as plt
import numpy as np
import os
import pickle

In [None]:
import agreement
import classification_metrics
import eval_utils
import irn as aggregation
import colab_utils

In [None]:
# JIT-compiled wrapper around `eval_utils.rankk_certainties`. It is called
# below as `compute_rank1_certainties(plausibilities, jnp.arange(num_classes))`.
# NOTE(review): despite the name, this wraps the generic rank-k helper and
# presumably yields rank-1 certainties with its default settings - TODO confirm
# against eval_utils.
compute_rank1_certainties = jax.jit(eval_utils.rankk_certainties)

## Load data

In [None]:
# Dataset to evaluate: 'toy' or 'derm'. The loading cell below branches on it.
# The trailing `#@param` comment is a Colab form annotation and must be kept.
dataset = 'derm'  #@param ['toy', 'derm']

In [None]:
model_names = ['A', 'B', 'C', 'D']
# NOTE(review): `pickle.load` can execute arbitrary code. Only load these
# files from a trusted source (e.g. the data shipped with this repository).
if dataset == 'toy':
  with open('data/toy_data.pkl', 'rb') as f:
    data = pickle.load(f)
  model_predictions = []
  # One predictions file per model, so derive the count from `model_names`.
  for i in range(len(model_names)):
    with open(f'data/toy_predictions{i}.pkl', 'rb') as f:
      model_predictions.append(pickle.load(f))
  model_predictions = jnp.array(model_predictions)
  num_readers = 3
  # Aggregate the rankings of the first `num_readers` readers into IRN
  # plausibilities.
  irn_plausibilities = aggregation.aggregate_irn(
      data['test_rankings'][:, :num_readers],
      data['test_groups'][:, :num_readers])
  num_classes = 3
elif dataset == 'derm':
  with open('data/dermatology_data.pkl', 'rb') as f:
    data = pickle.load(f)
  model_predictions = []
  for i in range(len(model_names)):
    with open(f'data/dermatology_predictions{i}.txt', 'r') as f:
      model_predictions.append(np.loadtxt(f))
  model_predictions = jnp.array(model_predictions)
  num_readers = 10
  num_classes = 419
  # IRN plausibilities are read from the data file rather than recomputed.
  irn_plausibilities = data['test_irn']
else:
  # Fail here with a clear message instead of a NameError on `data` later.
  raise ValueError(f'Unknown dataset {dataset!r}; expected "toy" or "derm".')

In [None]:
temperatures = [1, 3, 5, 10, 20, 30, 50]
num_samples = 100  # We used 1000 in the paper.
# Draw PrIRN plausibility samples once per temperature and stack them along a
# new leading axis, one entry per temperature. The same PRNG key is used for
# every temperature.
prirn_plausibilities = []
for temperature in temperatures:
  samples = aggregation.sample_prirn(
      jax.random.PRNGKey(0),
      irn_plausibilities,
      num_samples=num_samples,
      temperature=temperature,
      alpha=0.01,
  )
  prirn_plausibilities.append(samples)
  print(f'Computed PrIRN for temperature {temperature}.')
prirn_plausibilities = jnp.array(prirn_plausibilities)

# Certainty analysis and reader agreement

In [None]:
def compare_rank1_certainties(
    plausibilities, names, **kwargs):
  """Plot sorted top-1 certainties for several stacks of plausibilities.

  One curve is drawn per entry along the leading axis of `plausibilities`
  (e.g. one per PrIRN temperature). Each curve shows the per-example maximum
  of `compute_rank1_certainties`, sorted in ascending order.

  Args:
    plausibilities: array whose shape unpacks as
      (num_models, num_examples, _, num_classes).
    names: legend labels, one per leading-axis entry.
    **kwargs: optional `title`, `width` (default 5) and `height` (default 2).
  """
  num_models, num_examples, _, num_classes = plausibilities.shape
  classes = jnp.arange(num_classes)  # Loop-invariant, build once.
  examples = jnp.arange(num_examples)
  for m in range(num_models):
    certainties = jnp.max(
        compute_rank1_certainties(plausibilities[m], classes), axis=-1)
    plt.plot(examples, jnp.sort(certainties), label=names[m])
  plt.title(kwargs.get('title', 'Certainties across trust'))
  plt.ylabel('Certainty')
  plt.xlabel('Sorted examples')
  plt.xlim(0, num_examples)
  plt.ylim(0, 1)
  plt.legend()
  # Set the figure size once. The earlier 12x2.5 default was always overridden.
  plt.gcf().set_size_inches(kwargs.get('width', 5), kwargs.get('height', 2))
  plt.show()

In [None]:
compare_rank1_certainties(
    prirn_plausibilities, names=[f'temperature {m}' for m in temperatures],
    width=7, title='PrIRN top-1 annotation certainties across temperatures')

In [None]:
def compute_coverage_agreement(rankings, groups, num_readers=10):
  """Compute mean agreement using coverage against top-1 conditions.

  Args:
    rankings: reader rankings. Only `rankings.shape[0]` (the number of
      examples) is used directly here.
    groups: reader groups, passed through to the agreement computation.
    num_readers: value passed once per example to
      `agreement.leave_one_reader_out_coverage_agreement` and used as the
      normalizer. Presumably the number of readers per example; TODO confirm
      against the `agreement` module. Defaults to 10, the previously
      hard-coded value.

  Returns:
    Agreements summed over axis 1 and divided by `num_readers`.
  """
  agreements = agreement.leave_one_reader_out_coverage_agreement(
      rankings, groups, jnp.array([num_readers] * rankings.shape[0]))
  # The count and the normalizer come from one parameter so they stay in sync.
  return jnp.sum(agreements, axis=1) / float(num_readers)

In [None]:
def plot_rank1_certainties_with_agreement(
    agreements, plausibilities, **kwargs):
  """Plot rank-1 certainties together with per-example reader agreement.

  Examples are sorted by top-1 certainty. The certainties are drawn as a line
  and the agreements (in the same order) as a scatter with a least-squares
  regression line. The default title reports the Pearson correlation between
  certainties and agreements.

  Args:
    agreements: one agreement value per example.
    plausibilities: array whose shape unpacks as
      (num_examples, _, num_classes).
    **kwargs: optional `title`, `ylabel`, `width` and `height`.
  """
  num_examples, _, num_classes = plausibilities.shape
  xs = np.arange(num_examples)
  top1 = jnp.max(
      compute_rank1_certainties(plausibilities, jnp.arange(num_classes)),
      axis=1)
  order = np.argsort(top1)
  corr = np.corrcoef(top1, agreements)[0, 1]
  sorted_agreements = agreements[order]
  slope, intercept = np.polyfit(xs, sorted_agreements, 1)

  plt.plot(xs, top1[order], label='Top-1 certainties')
  plt.scatter(
      xs, sorted_agreements,
      label='Agreements', s=4, c=colab_utils.COLOR_RED)
  plt.plot(xs, slope * xs + intercept, color='gray', label='Regression line')
  default_title = f'Top-1 certainty and reader agreement (corr. {corr:.2f})'
  plt.title(kwargs.get('title', default_title))
  plt.ylabel(kwargs.get('ylabel', 'Certainty / agreement'))
  plt.xlabel('Sorted examples')
  plt.legend()
  plt.xlim(0, num_examples)
  plt.ylim(0, 1)
  plt.gcf().set_size_inches(kwargs.get('width', 5), kwargs.get('height', 2))
  plt.show()

In [None]:
# Reader agreement does not depend on the PrIRN temperature, so compute it
# once instead of once per plot.
coverage_agreements = compute_coverage_agreement(
    data['test_rankings'], data['test_groups'])
for i in range(len(temperatures)):
  plot_rank1_certainties_with_agreement(
      coverage_agreements, prirn_plausibilities[i], width=7)

# Model comparison

In [None]:
def compute_ua_topk_accuracies(
    predictions, plausibilities, k, break_ties=False):
  """Compute uncertainty-adjusted (UA) top-k accuracies.

  Every plausibility vector is reduced to a top-1 label set. The predictions
  are then scored with top-k accuracy against each of these label sets via
  `eval_utils.map_across_plausibilities`.

  Args:
    predictions: predictions of a single model.
    plausibilities: array whose shape unpacks as
      (num_examples, _, num_classes). The middle axis is presumably the
      plausibility samples.
    k: the k of the top-k accuracy applied to `predictions`.
    break_ties: if True, add a small deterministic uniform perturbation
      (within +/-5e-5, fixed `PRNGKey(0)`) so that ties are broken when
      selecting the top-1 labels.

  Returns:
    The result of `eval_utils.map_across_plausibilities`. Elsewhere in this
    notebook it is treated as (num_examples, num_samples); TODO confirm.
  """
  num_examples, _, num_classes = plausibilities.shape
  if break_ties:
    noise = jax.random.uniform(jax.random.PRNGKey(0), plausibilities.shape)
    # Rebind instead of `+=`: with a NumPy input, `+=` would silently modify
    # the caller's array in place.
    plausibilities = plausibilities + (noise - 0.5) * 1e-4
  # Labels are always top-1 sets of the plausibilities. `k` applies to the
  # predictions only.
  labels = classification_metrics.topk_sets(
      plausibilities.reshape(-1, num_classes),
      k=1).reshape(num_examples, -1, num_classes)
  return eval_utils.map_across_plausibilities(
      predictions, labels,
      functools.partial(classification_metrics.aggregated_topk_accuracy, k=k))

In [None]:
def compare_rank1_certainties_with_ua_accuracies(
    predictions, plausibilities, model_names, k=3, **kwargs):
  """Plot sorted per-example UA top-k accuracies next to top-1 certainties.

  Args:
    predictions: stacked model predictions whose shape unpacks as
      (num_models, _, _). `predictions[m]` is scored with
      `compute_ua_topk_accuracies`.
    plausibilities: array whose shape unpacks as
      (num_examples, _, num_classes).
    model_names: legend labels, one per model.
    k: the k of the UA top-k accuracy.
    **kwargs: optional `title` (default f'UA top-{k} accuracy and certainty'),
      `ylabel`, `width` and `height`. This matches the other plotting helpers.
  """
  num_models, _, _ = predictions.shape
  num_examples, _, num_classes = plausibilities.shape
  examples = jnp.arange(num_examples)
  certainties = jnp.max(compute_rank1_certainties(
      plausibilities, jnp.arange(num_classes)), axis=1)
  for m in range(num_models):
    accuracies = compute_ua_topk_accuracies(
        predictions[m], plausibilities, k)
    # Average over axis 1 (presumably the plausibility samples), giving one
    # UA accuracy per example.
    accuracies = jnp.mean(accuracies, axis=1)
    plt.plot(examples, jnp.sort(accuracies), label=model_names[m])
  plt.plot(
      examples, jnp.sort(certainties),
      label='Top-1 certainties', color='gray', linestyle='dashed')
  plt.title(kwargs.get('title', f'UA top-{k} accuracy and certainty'))
  plt.ylabel(kwargs.get('ylabel', 'Certainty / correct'))
  plt.xlabel('Sorted examples')
  plt.legend()
  plt.gcf().set_size_inches(kwargs.get('width', 5), kwargs.get('height', 2))
  plt.show()

In [None]:
models_to_compare = jnp.array([0, 2])
# Compare models A and C on the PrIRN samples drawn at temperature 3.
compare_rank1_certainties_with_ua_accuracies(
    model_predictions[models_to_compare],
    prirn_plausibilities[temperatures.index(3)],
    [model_names[int(idx)] for idx in models_to_compare],
    k=1)

In [None]:
def plot_ua_accuracies(
    predictions,
    irn_plausibilities, prirn_plausibilities,
    k=3, **kwargs):
  """Histogram PrIRN UA top-k accuracies and mark PrIRN and IRN means.

  The IRN baseline scores `predictions` against the one-hot argmax of the IRN
  plausibilities. The PrIRN accuracies come from
  `compute_ua_topk_accuracies`, are averaged over axis 0 and histogrammed.
  Two vertical lines mark the overall PrIRN mean (solid) and the IRN mean
  (dotted).

  Args:
    predictions: predictions of a single model.
    irn_plausibilities: 2-D array. Axis 1 is the class axis used for one-hot.
    prirn_plausibilities: PrIRN samples for one temperature.
    k: the k of the top-k accuracy.
    **kwargs: optional `title`, `width` and `height`.
  """
  num_classes = irn_plausibilities.shape[1]
  irn_one_hot = jax.nn.one_hot(jnp.argmax(irn_plausibilities, 1), num_classes)
  irn_accuracies = classification_metrics.aggregated_topk_accuracy(
      predictions, irn_one_hot, k)
  prirn_accuracies = compute_ua_topk_accuracies(
      predictions, prirn_plausibilities, k)

  color = colab_utils.COLORS[0]
  counts, _ = colab_utils.plot_hist(
      jnp.mean(prirn_accuracies, axis=0),
      alpha=0.5, label='PrIRN accuracies',
      color=color)
  top = jnp.max(counts)
  plt.vlines(
      jnp.mean(prirn_accuracies), 0, top,
      label='PrIRN UA accuracy', color=color)
  plt.vlines(
      jnp.mean(irn_accuracies), 0, top,
      label='IRN accuracy', color=color, linestyle='dotted')
  plt.legend(loc='upper left')
  plt.ylabel('Counts')
  plt.xlabel(f'Top-{k} accuracy' if k != 1 else 'Accuracy')
  # NOTE(review): the default title mentions certainty, but no certainty is
  # plotted here.
  plt.title(kwargs.get('title', 'UA accuracy and certainty'))
  plt.gcf().set_size_inches(kwargs.get('width', 5), kwargs.get('height', 2))
  plt.show()

In [None]:
models_to_compare = jnp.array([0, 2])
# Histograms for models A and C on the PrIRN samples drawn at temperature 10.
for m in models_to_compare:
  plot_ua_accuracies(
      model_predictions[m],
      irn_plausibilities,
      prirn_plausibilities[temperatures.index(10)],
      k=1,
      title=model_names[m])

In [None]:
def plot_model_comparison_with_certainty(
    accuracy_fn, predictions, plausibilities,
    temperatures, model_names, k=None,
    num_samples=1000, **kwargs):
  """Plot each model's mean accuracy as a function of temperature.

  For every model and every entry along the leading axis of `plausibilities`,
  `accuracy_fn` is evaluated on the first `num_samples` samples. All entries
  except the last are drawn as a line with a +/- one standard deviation band.
  The last entry is drawn as a single 'x' marker, and a dotted vertical line
  at `temperatures[-2]` separates it from the rest.

  Args:
    accuracy_fn: callable(model_predictions, plausibilities) returning
      accuracies, assumed to be (num_examples, num_samples).
    predictions: stacked predictions, one entry per model along axis 0.
    plausibilities: 4-D array with one entry per temperature along axis 0.
    temperatures: x-positions, one per entry of `plausibilities`.
    model_names: legend labels, one per model.
    k: only used in the default title.
    num_samples: number of plausibility samples to evaluate.
    **kwargs: optional `title`, `xlabel`, `ylabel`, `ymin`, `ymax`, `xmin`,
      `xmax`, `xticks`, `xticklabels`, `yticks`, `yticklabels`, `width` and
      `height`.
  """
  num_temperatures, _, _, _ = plausibilities.shape
  assert len(temperatures) == num_temperatures
  num_models = predictions.shape[0]
  assert len(model_names) == num_models

  highest = 0
  lowest = 1
  ax = plt.gca()
  for m in range(num_models):
    color = colab_utils.COLORS[m]
    # Shape: num_temperatures x num_examples x num_samples.
    stacked = jnp.array([
        accuracy_fn(predictions[m], plausibilities[idx, :, :num_samples])
        for idx in range(len(temperatures))
    ])
    # Average over examples, giving one accuracy per temperature and sample.
    per_sample = jnp.mean(stacked, axis=1)
    mean_acc = jnp.mean(per_sample, axis=1)
    std_acc = jnp.std(per_sample, axis=1)
    upper = mean_acc + std_acc
    lower = mean_acc - std_acc
    ax.plot(
        temperatures[:-1], mean_acc[:-1],
        label=model_names[m], color=color)
    ax.fill_between(
        temperatures[:-1], lower[:-1], upper[:-1],
        alpha=0.1, color=color)
    ax.scatter(
        temperatures[-1], mean_acc[-1],
        s=25, marker='x', color=color)
    highest = max(highest, jnp.max(upper))
    lowest = min(lowest, jnp.min(lower))

  y_low = kwargs.get('ymin', lowest)
  y_high = kwargs.get('ymax', highest + 0.005)
  ax.vlines(temperatures[-2], y_low, y_high, color='gray', linestyle='dotted')

  ax.legend(loc='lower right', bbox_to_anchor=(0.85, 0.025))
  plt.title(kwargs.get('title', f'Certainty and top-{k} accuracy'))
  ax.set_xlabel(kwargs.get('xlabel', 'Repeated readers'))
  ax.set_ylabel(kwargs.get('ylabel', 'Accuracy'))
  ax.set_ylim(y_low, y_high)
  ax.set_xlim(
      kwargs.get('xmin', min(temperatures)),
      kwargs.get('xmax', max(temperatures)))
  ax.set_xticks(kwargs.get('xticks', []), kwargs.get('xticklabels', []))
  ax.set_yticks(kwargs.get('yticks', []), kwargs.get('yticklabels', None))
  plt.gcf().set_size_inches(kwargs.get('width', 3.75), kwargs.get('height', 4))
  plt.grid()
  plt.show()

In [None]:
# x-positions: the PrIRN temperatures plus 55, where the IRN entry (appended
# last below) is drawn and labelled 'ML'.
x_positions = temperatures + [55]
kwargs = dict(
    temperatures=x_positions,
    num_samples=10,
    xlabel=None,
    ylabel=None,
    xticks=list(x_positions),
    xticklabels=['Low', '', '', 'Med', '', '', 'High', 'ML'],
    model_names=model_names,
    yticks=[0.5, 0.6, 0.7, 0.8, 0.9, 1],
    ymin=0.5,
    ymax=1,
)
# Repeat the IRN plausibilities `num_samples` times along the sample axis so
# they stack with the PrIRN samples as one extra leading-axis entry.
irn_as_samples = jnp.repeat(
    irn_plausibilities.reshape(1, -1, 1, num_classes), num_samples, axis=2)
prirn_plausibilities_with_irn = jnp.concatenate(
    (prirn_plausibilities, irn_as_samples), axis=0)

In [None]:
# With only 3 classes (toy data), top-3 would cover every class, so use top-1.
if num_classes == 3:
  top_k = 1
else:
  top_k = 3
ua_accuracy_fn = functools.partial(compute_ua_topk_accuracies, k=top_k)
plot_model_comparison_with_certainty(
    ua_accuracy_fn,
    model_predictions,
    prirn_plausibilities_with_irn,
    title=f'Top-{top_k} UA accuracy',
    **kwargs)