# Toy dataset and models

See `README.md` for installation and usage instructions.

This notebook creates the reference toy dataset used throughout [1], starting
with Figure 3. In addition, it generates pseudo annotations in the format of our
skin condition case study, namely partial rankings, as discussed in [2].

```
[1] Stutz, D., Roy, A.G., Matejovicova, T., Strachan, P., Cemgil, A.T.,
    & Doucet, A. (2023).
    Conformal prediction under ambiguous ground truth. ArXiv, abs/2307.09302.
[2] Stutz, D., Cemgil, A.T., Roy, A.G., Matejovicova, T., Barsbey, M.,
    Strachan, P., Schaekermann, M., Freyberg, J.V., Rikhye, R.V., Freeman, B.,
    Matos, J.P., Telang, U., Webster, D.R., Liu, Y., Corrado, G.S., Matias, Y.,
    Kohli, P., Liu, Y., Doucet, A., & Karthikesalingam, A. (2023).
    Evaluating AI systems under uncertain ground truth: a case study in
    dermatology. ArXiv, abs/2307.02191.
```

## Imports and setup

In [None]:
import jax
import jax.numpy as jnp
import matplotlib
from matplotlib import pyplot as plt
import numpy as np
import os
import pickle
import sklearn.neural_network

In [None]:
import formats
import irn as aggregation
import gaussian_toy_dataset as gtd
import colab_utils

In [None]:
colab_utils.set_style()
plot_hist = colab_utils.plot_hist

## Data

The Gaussian toy dataset samples examples from multiple overlapping Gaussians; see `gaussian_toy_dataset.py` for details.

Here, we create the 2-dimensional reference examples used in [1]. They serve to illustrate the method and offer an easy way to experiment with this repository.
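To make "overlapping Gaussians" and "true posterior p(y|x)" concrete, here is an independent NumPy sketch; it is not the code in `gaussian_toy_dataset.py`. The class centers, the single shared noise scale `sigma`, and the helper names `sample_mixture` and `mixture_posterior` are assumptions for illustration. Because `sigma` is comparable to the distance between centers, the classes overlap and p(y|x) is genuinely ambiguous near the boundaries.

```python
import numpy as np

def sample_mixture(rng, means, sigma, weights, num_examples):
  """Samples labeled points from an isotropic Gaussian mixture (sketch)."""
  weights = np.asarray(weights, dtype=float)
  weights = weights / weights.sum()
  # Draw a class for each example, then add isotropic Gaussian noise.
  labels = rng.choice(len(weights), size=num_examples, p=weights)
  noise = sigma * rng.standard_normal((num_examples, means.shape[1]))
  return means[labels] + noise, labels

def mixture_posterior(points, means, sigma, weights):
  """Computes p(y|x) with Bayes' rule for classes sharing one sigma."""
  weights = np.asarray(weights, dtype=float)
  weights = weights / weights.sum()
  # Squared distance from every point to every class mean: shape (N, K).
  sq_dist = ((points[:, None, :] - means[None, :, :]) ** 2).sum(axis=-1)
  # log p(x|y) + log p(y), dropping the normalizer shared by all classes.
  logits = -sq_dist / (2.0 * sigma ** 2) + np.log(weights)
  logits -= logits.max(axis=1, keepdims=True)  # numerical stability
  probs = np.exp(logits)
  return probs / probs.sum(axis=1, keepdims=True)

# Hypothetical 2D class centers; gaussian_toy_dataset.py chooses its own.
means = np.array([[0.0, 0.0], [0.6, 0.0], [0.3, 0.5]])
rng = np.random.default_rng(5)
points, labels = sample_mixture(rng, means, 0.3, [1, 1, 1], 1000)
posteriors = mixture_posterior(points, means, 0.3, [1, 1, 1])
```

Each row of `posteriors` plays the role that `human_ground_truths` plays below: a full distribution over classes rather than a single label.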

In [None]:
def get_data(config):
  """Generate data using the config."""
  # Defines a dataset of multiple overlapping Gaussians.
  generator = gtd.GaussianToyDataset(
      config['rng'], jnp.array(config['class_weights']),
      config['class_sigmas'], config['dimensionality'], config['sigma'])
  num_examples = config['train_examples'] + config['test_examples']
  # Sample points x from the overlapping Gaussian distributions.
  examples, ground_truths = generator.sample_points(num_examples)
  # Compute the true posterior distributions p(y|x).
  human_ground_truths = generator.evaluate_points(examples)
  # Sample annotator rankings for all points.
  rankings, groups = generator.sample_rankings(
      human_ground_truths,
      config['reader_sharpness'],
      config['expected_length'],
      config['grouping_threshold'])
  # Convert rankings and compute IRN aggregation.
  selectors = formats.convert_rankings_to_selectors(rankings, groups)
  irn = aggregation.aggregate_irn(rankings, groups)
  return {
      'config': config,
      'train_examples': examples[:config['train_examples']],
      'train_labels': ground_truths[:config['train_examples']],
      'train_smooth_labels': human_ground_truths[:config['train_examples']],
      'train_rankings': rankings[:config['train_examples']],
      'train_groups': groups[:config['train_examples']],
      'train_selectors': selectors[:config['train_examples']],
      'train_irn': irn[:config['train_examples']],
      'test_examples': examples[config['train_examples']:],
      'test_labels': ground_truths[config['train_examples']:],
      'test_smooth_labels': human_ground_truths[config['train_examples']:],
      'test_rankings': rankings[config['train_examples']:],
      'test_groups': groups[config['train_examples']:],
      'test_selectors': selectors[config['train_examples']:],
      'test_irn': irn[config['train_examples']:],
  }
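The `train_`/`test_` slicing in `get_data` repeats the same pattern for every array. As a sketch, a hypothetical helper `split_train_test` (not part of the repository) could generate those entries from one dictionary, assuming every value can be sliced along its first axis:

```python
import numpy as np

def split_train_test(arrays, num_train):
  """Splits each array along axis 0 into 'train_<key>' and 'test_<key>' entries."""
  split = {}
  for key, value in arrays.items():
    split[f'train_{key}'] = value[:num_train]
    split[f'test_{key}'] = value[num_train:]
  return split

# Toy stand-ins for the arrays computed in get_data.
arrays = {'examples': np.zeros((5, 2)), 'labels': np.arange(5)}
split = split_train_test(arrays, num_train=3)
print(sorted(split))
# ['test_examples', 'test_labels', 'train_examples', 'train_labels']
```

Inside `get_data`, one would pass `{'examples': examples, 'labels': ground_truths, 'smooth_labels': human_ground_truths, ...}` together with `config['train_examples']` and then add the `'config'` entry separately.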

In [None]:
config = {}
config['rng'] = gtd.PRNGSequence(5)
config['dimensionality'] = 2
config['sigma'] = 0.3
config['class_weights'] = [1]*3
config['class_sigmas'] = 0.1
config['train_examples'] = 1000
# Note that in the paper we used 20000 test examples.
config['test_examples'] = 1000
config['expected_length'] = 1.5
config['grouping_threshold'] = 0.05
# One sharpness value per reader; the length of this array (10) is the number
# of readers.
config['reader_sharpness'] = jnp.array([
    500000, 100000, 50000, 1000000, 500000,
    150000, 100000, 1000000, 100000, 90000])

In [None]:
data = get_data(config)

In [None]:
colors = np.array([
    [228,26,28],
    [55,126,184],
    [77,175,74],
]) / 255.
colab_utils.plot_data(
    data['train_examples'], data['train_labels'],
    title='Examples with their true labels', name='data', colors=colors)
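`colab_utils.plot_data` is the repository's own plotting helper. If it is unavailable, a plain matplotlib scatter plot gives a similar view. The function name `scatter_labels` and its styling are assumptions, not a reproduction of `colab_utils`:

```python
import numpy as np
from matplotlib import pyplot as plt

def scatter_labels(examples, labels, colors, title):
  """Scatter-plots 2D examples, coloring each point by its integer label."""
  examples = np.asarray(examples)
  labels = np.asarray(labels)
  fig, ax = plt.subplots(figsize=(4, 4))
  # colors has shape (num_classes, 3); indexing by labels gives one RGB row per point.
  ax.scatter(examples[:, 0], examples[:, 1], c=colors[labels], s=4)
  ax.set_title(title)
  ax.set_aspect('equal')
  return fig

colors = np.array([[228, 26, 28], [55, 126, 184], [77, 175, 74]]) / 255.
fig = scatter_labels(np.random.default_rng(0).normal(size=(50, 2)),
                     np.arange(50) % 3, colors, 'Toy scatter')
```

With the notebook's data this would be `scatter_labels(data['train_examples'], data['train_labels'], colors, 'Examples with their true labels')`.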

In [None]:
colab_utils.plot_data(
    data['train_examples'],
    np.argmax(data['train_smooth_labels'], axis=1),
    title='Examples with their voted labels', name='data_top1', colors=colors)
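Because the Gaussians overlap, an example's sampled label need not match the arg max of its posterior, which is why this plot differs from the previous one. A small sketch to quantify how often they agree; the helper name `top1_agreement` is hypothetical:

```python
import numpy as np

def top1_agreement(labels, smooth_labels):
  """Fraction of examples whose sampled label equals the posterior's arg max."""
  labels = np.asarray(labels)
  top1 = np.argmax(np.asarray(smooth_labels), axis=1)
  return float(np.mean(labels == top1))

# Toy input: the second example's label (1) differs from its arg max (0).
agreement = top1_agreement(
    [0, 1, 2], [[0.9, 0.1, 0.0], [0.6, 0.4, 0.0], [0.0, 0.2, 0.8]])
# 2 of 3 labels agree.
```

On the notebook's data: `top1_agreement(data['train_labels'], data['train_smooth_labels'])`.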

In [None]:
colab_utils.plot_smooth_data(
    data['train_examples'], data['train_smooth_labels'],
    name='data_smooth', colors=colors)
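One plausible way to visualize soft labels is to mix the class colors in proportion to p(y|x), so that ambiguous points appear as blends. Whether `colab_utils.plot_smooth_data` does exactly this is an assumption; the sketch below (`blend_colors` is a hypothetical name) only shows the color computation:

```python
import numpy as np

def blend_colors(smooth_labels, colors):
  """Mixes per-class RGB colors, weighted by each example's class probabilities."""
  smooth_labels = np.asarray(smooth_labels, dtype=float)
  colors = np.asarray(colors, dtype=float)
  # (N, K) @ (K, 3) -> (N, 3): a convex combination of the class colors per example.
  return np.clip(smooth_labels @ colors, 0.0, 1.0)

colors = np.array([[228, 26, 28], [55, 126, 184], [77, 175, 74]]) / 255.
mixed = blend_colors([[1.0, 0.0, 0.0], [0.5, 0.5, 0.0]], colors)
```

The result can be passed as `c=` to `plt.scatter`, e.g. `c=blend_colors(data['train_smooth_labels'], colors)`.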

In [None]:
# Create the output directory if it does not exist yet.
os.makedirs('data', exist_ok=True)
with open('data/toy_data.pkl', 'wb') as f:
  pickle.dump(data, f)

## Model

We train four small MLPs that differ in random seed and number of training iterations. Note that in the paper, we trained our own 2-layer MLP using Haiku; for simplicity, this Colab uses `sklearn` instead.

In [None]:
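predictions = []
for seed in range(4):
  # Each model uses a different seed and iteration budget (25, 50, 75, 100),
  # so the models are trained to different degrees.
  classifier = sklearn.neural_network.MLPClassifier(
      alpha=1, max_iter=(seed + 1) * 25, random_state=seed)
  # One-hot targets make sklearn treat this as a multilabel problem with
  # independent logistic outputs per class.
  classifier.fit(
      data['train_examples'],
      jax.nn.one_hot(data['train_labels'], 3))
  predictions_k = classifier.predict_log_proba(data['test_examples'])
  # Softmax over the per-class log-probabilities renormalizes them into a
  # distribution over the 3 classes.
  predictions_k = jax.nn.softmax(predictions_k)
  predictions.append(predictions_k)
# Shape: (num_models, num_test_examples, num_classes).
predictions = jnp.array(predictions)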
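The cell above passes one-hot targets to `fit`, which scikit-learn treats as a multilabel problem: the output layer uses independent logistic units, so the per-class probabilities need not sum to one, and a softmax over their logarithms is equivalent to dividing each row by its sum. A hedged alternative, not what produced the saved predictions, is to pass integer labels; `MLPClassifier` then uses a softmax output layer and `predict_proba` already returns a distribution. Shown here on synthetic stand-in data:

```python
import numpy as np
import sklearn.neural_network

# Synthetic stand-in for the toy data: 2D inputs with labels in {0, 1, 2}.
rng = np.random.default_rng(0)
x = rng.normal(size=(300, 2))
y = (x[:, 0] > 0).astype(int) + (x[:, 1] > 0).astype(int)

classifier = sklearn.neural_network.MLPClassifier(
    alpha=1, max_iter=500, random_state=0)
classifier.fit(x, y)  # integer labels -> multiclass with a softmax output layer
probs = classifier.predict_proba(x)  # shape (300, 3); each row sums to 1
print(classifier.out_activation_)  # softmax
```

The two setups generally yield different probabilities, so switching would change the saved predictions.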

In [None]:
for seed in range(predictions.shape[0]):
  # Top-1 accuracy against the sampled test labels.
  accuracy = jnp.mean(
      data['test_labels'] == jnp.argmax(predictions[seed], axis=1))
  print(f'model {seed}: accuracy {float(accuracy):.3f}')
  with open(f'data/toy_predictions{seed}.pkl', 'wb') as f:
    pickle.dump(predictions[seed], f)
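To reuse the saved outputs later, the per-seed files can be stacked back into one `(num_models, num_test_examples, num_classes)` array. A minimal sketch; the helper name `load_predictions` is hypothetical, and the example writes small NumPy stand-ins to a temporary directory so it runs on its own. The real files hold JAX arrays, so `jax` must be importable when unpickling them.

```python
import os
import pickle
import tempfile

import numpy as np

def load_predictions(directory, num_models):
  """Stacks toy_predictions{seed}.pkl files into one array of shape (M, N, K)."""
  arrays = []
  for seed in range(num_models):
    with open(os.path.join(directory, f'toy_predictions{seed}.pkl'), 'rb') as f:
      # np.asarray converts a pickled JAX array (or NumPy array) to NumPy.
      arrays.append(np.asarray(pickle.load(f)))
  return np.stack(arrays)

with tempfile.TemporaryDirectory() as tmp:
  # Stand-in predictions: 5 examples, uniform over 3 classes, for 4 models.
  for seed in range(4):
    with open(os.path.join(tmp, f'toy_predictions{seed}.pkl'), 'wb') as f:
      pickle.dump(np.full((5, 3), 1 / 3), f)
  stacked = load_predictions(tmp, num_models=4)
print(stacked.shape)  # (4, 5, 3)
```

For the files written above, the call would be `load_predictions('data', num_models=4)`.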