# Uncertainty-adjusted accuracy evaluation

See `README.md` for installation and usage instructions.

This notebook shows how to use the average overlap and partial average overlap
implementations of [2] to compare model rankings against annotator rankings on
the dataset of [1] (loaded from `toy_data.pkl`).

```
[1] Stutz, D., Roy, A.G., Matejovicova, T., Strachan, P., Cemgil, A.T.,
    & Doucet, A. (2023).
    Conformal prediction under ambiguous ground truth. ArXiv, abs/2307.09302.
[2] Stutz, D., Cemgil, A.T., Roy, A.G., Matejovicova, T., Barsbey, M.,
    Strachan, P., Schaekermann, M., Freyberg, J.V., Rikhye, R.V., Freeman, B.,
    Matos, J.P., Telang, U., Webster, D.R., Liu, Y., Corrado, G.S., Matias, Y.,
    Kohli, P., Liu, Y., Doucet, A., & Karthikesalingam, A. (2023).
    Evaluating AI systems under uncertain ground truth: a case study in
    dermatology. ArXiv, abs/2307.02191.
```

In [None]:
import pickle

import jax.numpy as jnp

In [None]:
import ranking_metrics
import colab_utils
import formats

In [None]:
with open('toy_data.pkl', 'rb') as f:
  data = pickle.load(f)

In [None]:
with open('toy_predictions0.pkl', 'rb') as f:
  model_predictions = pickle.load(f)

In [None]:
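# argsort sorts in ascending order; negate the scores so that rank 0 is the
# most likely class (assumes a higher prediction means more likely).
model_rankings = jnp.argsort(-model_predictions, axis=1)
# Every rank position is its own group, i.e. a strict ranking without ties.
model_groups = jnp.broadcast_to(
    jnp.arange(model_rankings.shape[1]), model_rankings.shape)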
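
The cell above turns per-class scores into rankings. As a small sanity check of the ordering convention, the sketch below uses hypothetical scores and NumPy in place of `jax.numpy` (whose `argsort` also sorts in ascending order). It assumes that a higher score means a more likely class: plain `argsort` then puts the most likely class last, while sorting the negated scores puts it first. It also shows the group array for a strict ranking, where each rank position forms its own group.

```python
import numpy as np

# Hypothetical scores for 2 examples over 4 classes (higher = more likely).
scores = np.array([[0.1, 0.6, 0.2, 0.1],
                   [0.5, 0.1, 0.1, 0.3]])

# Ascending argsort: the most likely class ends up in the last column.
ascending = np.argsort(scores, axis=1)
# Sorting the negated scores puts the most likely class first; a stable sort
# keeps tied classes in index order.
rankings = np.argsort(-scores, axis=1, kind='stable')

# One group per rank position: a strict ranking without ties.
groups = np.broadcast_to(np.arange(scores.shape[1]), scores.shape)

print(rankings)  # rows: [1 2 0 3] and [0 3 1 2]
print(groups)    # rows: [0 1 2 3] and [0 1 2 3]
```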

In [None]:
# Only the first few test examples are inspected and evaluated below.
num_examples = 10
assert num_examples > 1

In [None]:
# Show the first annotator's selector for each example.
for i, selector in enumerate(data['test_selectors'][:num_examples]):
  print(f'Example {i}, annotation 0: {selector[0]}')

In [None]:
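# Convert the model rankings into the same selector format. The reshape adds a
# singleton axis matching the per-annotator layout of data['test_rankings'].
model_selectors = formats.convert_rankings_to_selectors(
    model_rankings[:num_examples].reshape(num_examples, 1, -1),
    model_groups[:num_examples].reshape(num_examples, 1, -1))
for i, selector in enumerate(model_selectors):
  print(f'Example {i}, model prediction: {selector[0]}')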

In [None]:
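# Average overlap between the model ranking and the first annotator's ranking.
# The last argument counts the classes the annotator actually ranked, i.e. the
# entries whose group is >= 0.
ranking_metrics.average_overlap(
    model_rankings[:num_examples],
    data['test_rankings'][:num_examples, 0],
    jnp.sum(data['test_groups'][:num_examples, 0] >= 0, axis=1))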
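
For intuition about what the cell above measures, here is a minimal pure-Python sketch of the textbook average overlap at depth d, AO = (1/d) * sum over k = 1..d of |A[:k] ∩ B[:k]| / k, i.e. the mean fraction of classes shared by the top-k of both rankings. This is a simplified illustration, not the `ranking_metrics` implementation: the helper name, the toy rankings, and treating the third argument of the call above as the depth d are assumptions, and ties (groups) are ignored.

```python
def average_overlap(ranking_a, ranking_b, depth):
  """Mean top-k overlap of two rankings for k = 1..depth.

  Simplified, hypothetical helper: both rankings are lists of class indices,
  most likely first, with at least `depth` entries. Ties are not supported.
  """
  if depth < 1:
    raise ValueError('depth must be at least 1')
  total = 0.0
  for k in range(1, depth + 1):
    # Fraction of classes shared by the top-k of both rankings.
    shared = set(ranking_a[:k]) & set(ranking_b[:k])
    total += len(shared) / k
  return total / depth


# Hypothetical model and annotator rankings over 4 classes.
model = [1, 2, 0, 3]
annotator = [2, 1, 3, 0]
print(average_overlap(model, annotator, depth=2))  # prints 0.5
```

Identical rankings score 1.0 and rankings with disjoint top-k sets score 0.0; the partial variant in the next cell additionally receives the group arrays of both rankings.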

In [None]:
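# Partial average overlap also takes the groups of both rankings, so it can
# handle partial rankings in which several classes share a rank.
ranking_metrics.partial_average_overlap(
    model_rankings[:num_examples],
    model_groups[:num_examples],
    data['test_rankings'][:num_examples, 0],
    data['test_groups'][:num_examples, 0],
    jnp.sum(data['test_groups'][:num_examples, 0] >= 0, axis=1))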