# Partial average overlap

See `README.md` for installation and usage instructions.

This notebook shows an example of using the partial average overlap
implementation of [2] on either the toy dataset of [1] or the dermatology
dataset; the dataset is selected in the "Load data" section below.

```
[1] Stutz, D., Roy, A.G., Matejovicova, T., Strachan, P., Cemgil, A.T.,
    & Doucet, A. (2023).
    Conformal prediction under ambiguous ground truth. ArXiv, abs/2307.09302.
[2] Stutz, D., Cemgil, A.T., Roy, A.G., Matejovicova, T., Barsbey, M.,
    Strachan, P., Schaekermann, M., Freyberg, J.V., Rikhye, R.V., Freeman, B.,
    Matos, J.P., Telang, U., Webster, D.R., Liu, Y., Corrado, G.S., Matias, Y.,
    Kohli, P., Liu, Y., Doucet, A., & Karthikesalingam, A. (2023).
    Evaluating AI systems under uncertain ground truth: a case study in
    dermatology. ArXiv, abs/2307.02191.
```

In [None]:
import jax.numpy as jnp
import matplotlib
import os
import pickle
import numpy as np

In [None]:
import ranking_metrics
import colab_utils
import formats

## Load data

In [None]:
# Name of the dataset loaded by the next cell; the loader only handles the
# options listed in the form annotation ('toy' or 'derm').
dataset = 'derm'  #@param ['toy', 'derm']

In [None]:
# Load the annotated data and one set of model predictions for the selected
# dataset, and choose which examples the cells below will inspect.
# NOTE: pickle.load can execute arbitrary code while deserialising; only load
# files you trust.
if dataset == 'toy':
  with open('data/toy_data.pkl', 'rb') as f:
    data = pickle.load(f)
  with open('data/toy_predictions0.pkl', 'rb') as f:
    model_predictions = pickle.load(f)
  indices = [0, 1, 2]  # Random examples.
elif dataset == 'derm':
  with open('data/dermatology_data.pkl', 'rb') as f:
    data = pickle.load(f)
  # Unlike the toy predictions, these are stored as plain text, not pickle.
  with open('data/dermatology_predictions0.txt', 'rb') as f:
    model_predictions = np.loadtxt(f)
  indices = [1026, 1057, 357]  # Examples from the paper.
else:
  # Fail here with a clear message instead of a NameError on `data`,
  # `model_predictions` or `indices` in a later cell.
  raise ValueError(
      f"Unknown dataset {dataset!r}; expected 'toy' or 'derm'.")

## Compute (partial) average overlap

In [None]:
# Turn the chosen example indices into an integer array so they can be used
# for array indexing in the cells below.
indices = np.array(indices, dtype=int)
num_examples = indices.size
# Per-row ordering of the prediction columns from highest to lowest score
# (argsort of the negated predictions along axis 1).
model_rankings = jnp.argsort(- model_predictions, axis=1)
# One row [0, 1, ..., num_columns - 1] per example, built in a single
# vectorised call instead of a Python loop over every row.
# NOTE(review): presumably a distinct group id per rank position encodes a
# ranking without ties -- confirm against ranking_metrics.
model_groups = jnp.tile(
    jnp.arange(model_rankings.shape[1]), (model_rankings.shape[0], 1))

In [None]:
# Print every entry of data['test_selectors'] for each of the chosen examples.
# `i` is the position within `indices`, `j` the position within that example's
# selectors.
for i, index in enumerate(indices):
  for j, selector in enumerate(data['test_selectors'][index]):
    # Use the loop variable directly rather than re-indexing the same element.
    print('Example', i, ', annotation', j, ':', selector)

In [None]:
# Convert the model's rankings for the chosen examples into selectors.
# Both inputs are reshaped to (num_examples, 1, -1), i.e. a singleton axis is
# inserted after the example axis before conversion.
# NOTE(review): the singleton axis presumably stands in for the annotation
# axis expected by formats.convert_rankings_to_selectors -- confirm.
selected_model_rankings = model_rankings[indices].reshape(num_examples, 1, -1)
selected_model_groups = model_groups[indices].reshape(num_examples, 1, -1)
model_selectors = formats.convert_rankings_to_selectors(
    selected_model_rankings, selected_model_groups)
# Show the single converted entry of each example.
for example, model_selector in enumerate(model_selectors):
  print('Example', example, ', model prediction: ', model_selector[0])

In [None]:
# Average overlap between the model rankings and the entries at index 0 along
# the second axis of the test rankings, for the chosen examples.
# NOTE(review): index 0 presumably selects the first annotation -- confirm.
reference_rankings = data['test_rankings'][indices, 0]
reference_groups = data['test_groups'][indices, 0]
# Per example, count the entries whose group id is >= 0.
# NOTE(review): negative group ids presumably mark padding, making this the
# length of each reference ranking -- confirm against the data format.
reference_lengths = jnp.sum(reference_groups >= 0, axis=1)
ranking_metrics.average_overlap(
    model_rankings[indices], reference_rankings, reference_lengths)

In [None]:
# Partial average overlap for the chosen examples: same inputs as the plain
# average overlap, plus the group assignments of both the model rankings and
# the reference rankings.
# NOTE(review): index 0 presumably selects the first annotation -- confirm.
reference_rankings = data['test_rankings'][indices, 0]
reference_groups = data['test_groups'][indices, 0]
# Per example, count the entries whose group id is >= 0.
# NOTE(review): negative group ids presumably mark padding -- confirm.
reference_lengths = jnp.sum(reference_groups >= 0, axis=1)
ranking_metrics.partial_average_overlap(
    model_rankings[indices],
    model_groups[indices],
    reference_rankings,
    reference_groups,
    reference_lengths)