# MNIST Multi-label MCCP experiments

See `README.md` for installation and usage instructions.

This notebook applies Monte Carlo conformal prediction [1] to a synthetic
multi-label dataset derived from MNIST.

```
[1] Stutz, D., Roy, A.G., Matejovicova, T., Strachan, P., Cemgil, A.T.,
    & Doucet, A. (2023).
    Conformal prediction under ambiguous ground truth. ArXiv, abs/2307.09302.
```

## Imports and setup

In [None]:
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import sklearn.neural_network
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
import tensorflow_datasets as tfds

In [None]:
import conformal_prediction
import monte_carlo
import colab_utils
import gaussian_toy_dataset as gtd

In [None]:
# Configure figure styling via the project helper (presumably sets matplotlib
# defaults; see colab_utils.set_style).
colab_utils.set_style()
# Short alias for the project's histogram helper.
plot_hist = colab_utils.plot_hist

## Data

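We create a synthetic multi-label dataset by overlaying pairs of MNIST digits in
a single image. The label of each image is the set of digits it contains; if
both digits of a pair have the same label, only the first digit is kept.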

In [None]:
# Load the first `num_examples` MNIST training examples and shuffle them with a
# fixed seed so the digit pairing built below is reproducible across runs.
# reshuffle_each_iteration=False keeps the order stable if the next cell is
# re-run without re-running this one.
num_examples = 10000
ds = tfds.load('mnist', split=f'train[:{num_examples}]').shuffle(
    num_examples, seed=0, reshuffle_each_iteration=False).batch(1000)

In [None]:
# Materialise the batched dataset into two aligned JAX arrays: `images` and
# `labels` (the digit class of each image). The dataset is iterated once so
# every image stays paired with its own label.
image_batches = []
label_batches = []
for batch in ds:
  image_batches.append(jnp.array(batch['image'].numpy()))
  label_batches.append(jnp.array(batch['label'].numpy()))
images = jnp.concatenate(image_batches)
labels = jnp.concatenate(label_batches)

In [None]:
# Pair example n with example split + n and render each pair into one
# 3-channel image. The first digit is placed in a randomly chosen channel r1.
# If the second digit has a different label, it is placed in the next channel
# r2 = (r1 + 1) % 3. The label is the multi-hot union of the digits used.
# Assumes each entry of `images` is (height, width, 1) -- TODO confirm.
split = num_examples//2
combined_images = []
combined_labels = []
channel_draws = jax.random.uniform(jax.random.PRNGKey(0), (split,))
for n in range(split):
  r1 = int(channel_draws[n] * 3)  # channel index in {0, 1, 2}
  combined_image = jnp.repeat(images[n], 3, axis=2)
  combined_image = combined_image.at[:, :, jnp.arange(3) != r1].set(0)
  combined_label = jax.nn.one_hot(labels[n], 10)
  if labels[n] != labels[split + n]:
    r2 = (r1 + 1) % 3
    other_image = jnp.repeat(images[split + n], 3, axis=2)
    other_image = other_image.at[:, :, jnp.arange(3) != r2].set(0)
    # Add the single-channel version of the second digit. Because r2 != r1,
    # the two digits occupy disjoint channels, so no pixel receives two
    # non-zero values (which matters if images are uint8, as tfds MNIST
    # presumably returns).
    combined_image += other_image
    combined_label += jax.nn.one_hot(labels[split + n], 10)
  combined_images.append(combined_image)
  combined_labels.append(combined_label)
combined_images = jnp.array(combined_images)
combined_labels = jnp.array(combined_labels)

In [None]:
# The first 3/5 of the combined examples train the per-digit models; the
# remaining 2/5 are held out for conformal calibration and testing.
split = int((num_examples // 2) * 3/5.)
train_images, held_out_images = combined_images[:split], combined_images[split:]
train_labels, held_out_labels = combined_labels[:split], combined_labels[split:]

## Models

We train 10 binary classifiers, one per digit, each predicting whether that
digit appears in an image. Their predicted probabilities on the held-out
examples are used as the per-digit conformity scores.

In [None]:
# Train one binary MLP per digit k that predicts whether digit k is present,
# and keep its positive-class probability for every held-out example.
flat_train_images = train_images.reshape(train_images.shape[0], -1)
flat_held_out_images = held_out_images.reshape(held_out_images.shape[0], -1)
predictions = []
for k in range(10):
  # A fixed random_state makes weight initialisation and batch sampling (and
  # therefore the scores) reproducible across runs.
  classifier = sklearn.neural_network.MLPClassifier(
      alpha=1, max_iter=100, random_state=k)
  classifier.fit(flat_train_images, train_labels[:, k])
  # Column 1 is the probability of classes_[1], i.e. the "digit present"
  # target (assumes both 0 and 1 occur in train_labels[:, k]).
  predictions_k = classifier.predict_proba(flat_held_out_images)[:, 1]
  predictions.append(predictions_k)
# One column per digit: shape (num_held_out, 10).
predictions = jnp.array(predictions).T

## Experiments

In [None]:
def plot_trials(alpha=0.1, num_trials=10, **kwargs):
  """Run conformal prediction trials."""
  rng = jax.random.PRNGKey(0)
  coverages = []
  sizes = []
  for t in range(num_trials):
    permutation_rng, mc_rng, rng = jax.random.split(rng, 3)
    split = int((num_examples // 2) * 2/5.) // 2
    permutation = jax.random.permutation(permutation_rng, 2 * split)
    val_labels = held_out_labels[permutation[:split]]
    val_predictions = predictions[permutation[:split]]
    test_images = held_out_images[permutation[split:]]
    test_labels = held_out_labels[permutation[split:]]
    test_predictions = predictions[permutation[split:]]

    num_classes = val_predictions.shape[1]
    mc_val_predictions, mc_val_labels = monte_carlo.sample_mc_labels(
        mc_rng, val_predictions, val_labels, 10)
    mc_val_predictions = mc_val_predictions.reshape(-1, num_classes)
    mc_val_labels = mc_val_labels.reshape(-1)

    p_values = conformal_prediction.compute_p_values(
        mc_val_predictions, mc_val_labels, test_predictions)
    confidence_sets = (p_values >= alpha).astype(int)
    coverages.append(jnp.sum(
        test_labels * confidence_sets, axis=1) / jnp.sum(test_labels, axis=1))
    sizes.append(jnp.sum(confidence_sets, axis=1))
  coverages = jnp.array(coverages)
  sizes = jnp.array(sizes)

  hist, _ = colab_utils.plot_hist(jnp.mean(coverages, axis=1), normalize=True)
  plt.vlines(1 - alpha, 0, jnp.max(hist), color='black', label='Target')
  plt.title('Aggregated coverage for multi-label classification')
  plt.xlabel('Empirical coverage')
  plt.ylabel('Frequency')
  plt.legend()
  plt.gcf().set_size_inches(kwargs.get('width', 5), kwargs.get('height', 3))
  plt.savefig('mnist_mccp_coverage.pdf', bbox_inches="tight")
  plt.show()
  %download_file mnist_mccp_coverage.pdf

  hist, _ = colab_utils.plot_hist(
      jnp.mean(sizes, axis=1), normalize=True, label='Inefficiency')
  plt.vlines(
      jnp.mean(sizes), 0, jnp.max(hist),
      label=f'Average: {jnp.mean(sizes):.2f}', color='black')
  plt.title('Inefficiency histogram')
  plt.xlabel('Inefficiency')
  plt.ylabel('Frequency')
  plt.legend()
  plt.gcf().set_size_inches(kwargs.get('width', 5), kwargs.get('height', 3))
  plt.savefig('mnist_mccp_ineff.pdf', bbox_inches="tight")
  plt.show()
  %download_file mnist_mccp_ineff.pdf

  alpha = 0.1
  for n in range(3):
    plt.bar(jnp.arange(10), test_labels[n], alpha=0.5, label='Labels')
    plt.bar(jnp.arange(10), p_values[n], alpha=0.5, label='p-values')
    plt.hlines(alpha, -0.5, 9.5, label='Confidence level', color='red')
    plt.legend(loc='upper left', bbox_to_anchor=(1.02, 1))
    plt.xlabel('Class')
    plt.xticks([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    plt.gcf().set_size_inches(
        kwargs.get('width', 3), kwargs.get('height', 1.25))
    plt.savefig(f'mnist_example{n}.pdf', bbox_inches="tight")
    plt.show()
    %download_file mnist_example{n}.pdf

    plt.imshow(test_images[n] / 255.)
    plt.axis('off')
    plt.gcf().set_size_inches(kwargs.get('width', 2), kwargs.get('height', 2))
    plt.savefig(f'mnist_sets{n}.pdf', bbox_inches="tight")
    plt.show()
    %download_file mnist_sets{n}.pdf

In [None]:
# 500 random calibration/test splits at significance level 0.1 (90% target).
plot_trials(num_trials=500, alpha=0.1)