# Monte Carlo conformal prediction

See `README.md` for installation and usage instructions.

This notebook re-creates some of the examples and figures from [1] on the toy dataset, as well as the results of [1] on the dermatology data.

```
[1] Stutz, D., Roy, A.G., Matejovicova, T., Strachan, P., Cemgil, A.T.,
    & Doucet, A. (2023).
    Conformal prediction under ambiguous ground truth. ArXiv, abs/2307.09302.
```

## Imports and setup

In [None]:
import jax
import jax.numpy as jnp
import matplotlib
from matplotlib import pyplot as plt
import numpy as np
import os
import pickle

In [None]:
import conformal_prediction
import monte_carlo
import p_value_combination
import plausibility_regions
import classification_metrics
import colab_utils

In [None]:
# Apply the shared plotting style from `colab_utils` before any figure is
# drawn (presumably matplotlib defaults -- see `colab_utils.set_style`).
colab_utils.set_style()
# Short alias: most plotting cells below call `plot_hist(...)` directly.
plot_hist = colab_utils.plot_hist

## Toy dataset

In [None]:
# Load the toy dataset. The cells below read the keys 'test_labels',
# 'test_smooth_labels', 'test_examples', 'train_examples' and
# 'train_smooth_labels' from it.
# NOTE: unpickling executes arbitrary code -- only load files you trust.
with open('data/toy_data.pkl', 'rb') as f:
  data = pickle.load(f)

In [None]:
# Load the model predictions for the toy data. Later cells treat this as a
# 2-D array (examples x classes) and permute its rows together with
# data['test_labels'] / data['test_smooth_labels'].
# NOTE: unpickling executes arbitrary code -- only load files you trust.
with open('data/toy_predictions0.pkl', 'rb') as f:
  predictions = pickle.load(f)

### CP with true and top-1 labels

Compare standard conformal prediction calibrated against majority-voted (top-1) labels with conformal prediction calibrated against the true labels -- which we have access to on the toy dataset but not in practice.

In [None]:
def run_trial(
    predictions, data, method='',
    alpha=0.1, trials=100, split=0.5):
  """Repeat a calibrate-then-evaluate conformal prediction experiment.

  Every trial draws a fresh random calibration/test split (all randomness is
  derived from `jax.random.PRNGKey(0)`), calibrates a threshold on the
  calibration part and evaluates the resulting confidence sets on the rest.

  Args:
    predictions: model predictions; rows are permuted together with the
      entries of `data`.
    data: mapping providing 'test_smooth_labels' (plausibilities) and
      'test_labels' (true labels).
    method: 'top1' calibrates against the arg-max of the plausibilities,
      'mccp' uses `monte_carlo.calibrate_mc_threshold` with 10 sampled
      labels, anything else calibrates against the true labels.
    alpha: confidence level handed to the calibration routine.
    trials: number of random splits.
    split: fraction of the examples used for calibration.

  Returns:
    Dict mapping 'inefficiencies', 'true_coverages', 'aggregated_coverages'
    and 'top1_coverages' to a `jnp` array stacking the per-trial metric
    values along the first axis.
  """
  tags = ['inefficiencies', 'true_coverages', 'aggregated_coverages', 'top1_coverages']
  results = {tag: [] for tag in tags}
  num_examples = predictions.shape[0]
  keys = jax.random.split(jax.random.PRNGKey(0), 2 * trials)
  for t in range(trials):
    # Even keys drive the split, odd keys drive Monte Carlo label sampling.
    permutation = jax.random.permutation(keys[2 * t], num_examples)
    val_examples = int(num_examples * split)
    val_idx = permutation[:val_examples]
    test_idx = permutation[val_examples:]
    val_predictions = predictions[val_idx]
    test_predictions = predictions[test_idx]
    val_human_ground_truth = data['test_smooth_labels'][val_idx]
    test_human_ground_truth = data['test_smooth_labels'][test_idx]
    val_ground_truth = data['test_labels'][val_idx]
    test_ground_truth = data['test_labels'][test_idx]

    if method == 'mccp':
      threshold = monte_carlo.calibrate_mc_threshold(
          keys[2 * t + 1], val_predictions, val_human_ground_truth,
          num_samples=10, alpha=alpha)
    else:
      if method == 'top1':
        val_labels = jnp.argmax(val_human_ground_truth, axis=1)
      else:
        val_labels = val_ground_truth
      threshold = conformal_prediction.calibrate_threshold(
          val_predictions, val_labels, alpha)

    confidence_sets = conformal_prediction.predict_threshold(
        test_predictions, threshold)

    # Three reference targets: true label, voted label, full plausibilities.
    true_one_hot = jax.nn.one_hot(test_ground_truth, confidence_sets.shape[1])
    voted_one_hot = jax.nn.one_hot(
        jnp.argmax(test_human_ground_truth, axis=1),
        test_human_ground_truth.shape[1])
    coverage = classification_metrics.aggregated_coverage

    results['inefficiencies'].append(
        classification_metrics.size(confidence_sets))
    results['true_coverages'].append(coverage(confidence_sets, true_one_hot))
    results['top1_coverages'].append(coverage(confidence_sets, voted_one_hot))
    results['aggregated_coverages'].append(
        coverage(confidence_sets, test_human_ground_truth))
  return {tag: jnp.array(values) for tag, values in results.items()}

In [None]:
def plot_coverage_top1_calibration(predictions, data, **kwargs):
  """Histogram coverage obtained when calibrating against voted labels.

  Runs `run_trial` with method='top1' (alpha=0.05, 100 trials) and overlays
  the per-trial mean coverage of the true labels and of the voted labels.

  Args:
    predictions: forwarded to `run_trial`.
    data: forwarded to `run_trial`.
    **kwargs: optional 'width'/'height' (figure size in inches, default 5x3)
      and 'name' (if truthy, the figure is saved as '<name>.pdf').
  """
  results = run_trial(
      predictions, data, method='top1',
      alpha=0.05, trials=100)
  series = (
      ('true_coverages', 'Coverage of true labels'),
      ('top1_coverages', 'Coverage of voted labels'),
  )
  for tag, label in series:
    plot_hist(jnp.mean(results[tag], axis=-1), normalize=True, label=label)
  plt.vlines(0.95, 0, 0.15, color='black', label='Target')
  plt.title('Calibration with voted labels')
  plt.xlabel('Empirical coverage')
  plt.ylabel('Frequency')
  plt.legend()
  figure = plt.gcf()
  figure.set_size_inches(kwargs.get('width', 5), kwargs.get('height', 3))
  name = kwargs.get('name', False)
  if name:
    plt.savefig(name + '.pdf', bbox_inches="tight")
  plt.show()

In [None]:
# Render the voted-label calibration figure; passing `name` additionally
# saves it as 'coverage_top1_calibration.pdf'.
plot_coverage_top1_calibration(
    predictions, data, name='coverage_top1_calibration')

In [None]:
def plot_coverage_mccp_calibration(predictions, data, **kwargs):
  """Histogram coverage obtained with Monte Carlo conformal calibration.

  Runs `run_trial` with method='mccp' (alpha=0.05, 100 trials) and overlays
  the per-trial mean coverage of the true labels and of the voted labels.

  Args:
    predictions: forwarded to `run_trial`.
    data: forwarded to `run_trial`.
    **kwargs: optional 'width'/'height' (figure size in inches, default 5x3)
      and 'name' (if truthy, the figure is saved as '<name>.pdf').
  """
  mc_results = run_trial(
      predictions, data, method='mccp',
      alpha=0.05, trials=100)
  true_means = jnp.mean(mc_results['true_coverages'], axis=-1)
  voted_means = jnp.mean(mc_results['top1_coverages'], axis=-1)
  plot_hist(true_means, normalize=True, label='Coverage of true labels')
  plot_hist(voted_means, normalize=True, label='Coverage of voted labels')
  plt.vlines(0.95, 0, 0.15, color='black', label='Target')
  plt.title('Monte Carlo conformal calibration')
  plt.xlabel('Empirical coverage')
  plt.ylabel('Frequency')
  plt.legend()
  width = kwargs.get('width', 5)
  height = kwargs.get('height', 3)
  plt.gcf().set_size_inches(width, height)
  output_name = kwargs.get('name', False)
  if output_name:
    plt.savefig(output_name + '.pdf', bbox_inches="tight")
  plt.show()

In [None]:
# Render the Monte Carlo calibration figure; passing `name` additionally
# saves it as 'coverage_mccp_calibration.pdf'.
plot_coverage_mccp_calibration(
    predictions, data, name='coverage_mccp_calibration')

### Aggregated coverage

This experiment illustrates aggregated coverage.

In [None]:
def plot_aggregated_coverage(predictions, data, **kwargs):
  """Contrast true-label coverage with aggregated coverage.

  Calibrates against the true labels (`run_trial` with method='', alpha=0.05,
  500 trials) and produces two figures:
    1. histograms of the per-trial mean true and aggregated coverage, saved
       as 'aggregated_coverage_histogram1.pdf';
    2. for the first trial only, the separately sorted per-example values of
       both metrics, saved as 'aggregated_coverage_sorted.pdf'.

  Args:
    predictions: forwarded to `run_trial`.
    data: forwarded to `run_trial`.
    **kwargs: optional 'width'/'height' (figure size in inches, default 5x3).
  """
  true_results = run_trial(
      predictions, data, method='',
      alpha=0.05, trials=500)
  figure_size = (kwargs.get('width', 5), kwargs.get('height', 3))

  # Figure 1: distribution of the average coverage across trials.
  for tag, label in (
      ('true_coverages', 'True coverage'),
      ('aggregated_coverages', 'Aggregated coverage')):
    plot_hist(
        jnp.mean(true_results[tag], axis=1), bins=40,
        normalize=True, label=label, alpha=0.8)
  plt.title('Average coverage across trials')
  plt.xlabel('Empirical coverage')
  plt.ylabel('Frequency')
  plt.xticks([0.94, 0.945, 0.95, 0.955, 0.96])
  plt.xlim(0.94, 0.96)
  plt.legend()
  plt.gcf().set_size_inches(*figure_size)
  plt.savefig('aggregated_coverage_histogram1.pdf', bbox_inches="tight")
  plt.show()

  # Figure 2: per-example values of the first trial, each sorted on its own.
  palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
  curves = (
      ('true_coverages', 'True label covered', 'Realized coverage'),
      ('aggregated_coverages', 'Plausibility mass covered',
       'Realized aggregated coverage'),
  )
  for color_index, (tag, line_label, area_label) in enumerate(curves):
    first_trial = true_results[tag][0]
    count = first_trial.shape[0]
    positions = jnp.arange(count)
    ordered = jnp.sort(first_trial)
    plt.plot(positions, ordered, label=line_label)
    plt.fill_between(
        jnp.arange(count), jnp.zeros(count), jnp.sort(first_trial),
        alpha=0.2, color=palette[color_index],
        label=f'{area_label}({jnp.mean(first_trial):.2f})')
  plt.title('Correctness across examples for single trial')
  plt.xlabel('Sorted examples (separate sorting)')
  plt.ylabel('Realized coverage')
  plt.xlim(0, true_results['aggregated_coverages'][0].shape[0])
  plt.legend()
  plt.gcf().set_size_inches(*figure_size)
  plt.savefig('aggregated_coverage_sorted.pdf', bbox_inches="tight")
  plt.show()

In [None]:
# Produce both aggregated-coverage figures (each is also written to a PDF).
plot_aggregated_coverage(predictions, data)

### Monte Carlo conformal prediction

Running and evaluating Monte Carlo conformal prediction in comparison to calibrating against the majority-voted (top-1) labels.

In [None]:
def run_trial(
    predictions, data,
    alpha=0.05, num_trials=1000, split=0.5, seed=0):
  """Measure the effect of label sampling for one fixed data split.

  The calibration/test split is drawn once from `seed`. Each trial then
  re-samples one Monte Carlo label per calibration example (keys derived from
  `seed + 1`), re-calibrates and evaluates on the fixed test part, so the
  spread of the 'mc_*' results stems from the sampled labels only.

  Top-1 calibration does not depend on the trial, so it is computed once and
  its result is repeated `num_trials` times to keep the output layout
  parallel to the 'mc_*' entries.

  Args:
    predictions: model predictions; rows are permuted together with the
      entries of `data`.
    data: mapping providing 'test_smooth_labels' (plausibilities) and
      'test_labels' (true labels).
    alpha: confidence level handed to `calibrate_threshold`.
    num_trials: number of label re-samplings.
    split: fraction of the examples used for calibration.
    seed: seed for the split; `seed + 1` seeds the label sampling.

  Returns:
    Dict with keys 'top1_true_coverages', 'top1_aggregated_coverages',
    'mc_true_coverages' and 'mc_aggregated_coverages', each a `jnp` array
    stacking the per-trial metric values along the first axis.
  """
  permutation = jax.random.permutation(
      jax.random.PRNGKey(seed), predictions.shape[0])
  val_examples = int(predictions.shape[0]*split)
  val_predictions = predictions[permutation[:val_examples]]
  test_predictions = predictions[permutation[val_examples:]]
  val_human_ground_truth = data['test_smooth_labels'][permutation[:val_examples]]
  test_human_ground_truth = data['test_smooth_labels'][permutation[val_examples:]]
  test_ground_truth = data['test_labels'][permutation[val_examples:]]

  def evaluate_method(confidence_sets):
    """Return (true-label coverage, aggregated coverage) on the test part."""
    test_one_hot_ground_truth = jax.nn.one_hot(
        test_ground_truth, confidence_sets.shape[1])
    true_coverage = classification_metrics.aggregated_coverage(
        confidence_sets, test_one_hot_ground_truth)
    aggregated_coverage = classification_metrics.aggregated_coverage(
        confidence_sets, test_human_ground_truth)
    return true_coverage, aggregated_coverage

  # Loop-invariant: the voted labels and the split are fixed across trials.
  val_top1_labels = jnp.argmax(val_human_ground_truth, axis=1)
  top1_threshold = conformal_prediction.calibrate_threshold(
      val_predictions, val_top1_labels, alpha)
  top1_confidence_sets = conformal_prediction.predict_threshold(
      test_predictions, top1_threshold)
  top1_true, top1_aggregated = evaluate_method(top1_confidence_sets)

  mc_true_coverages = []
  mc_aggregated_coverages = []
  keys = jax.random.split(jax.random.PRNGKey(seed + 1), num_trials)
  for t in range(num_trials):
    val_mc_predictions, mc_labels = monte_carlo.sample_mc_labels(
        keys[t], val_predictions, val_human_ground_truth, num_samples=1)
    # Flatten the sample axis so every sampled label is a calibration example.
    val_mc_predictions = val_mc_predictions.reshape(
        -1, val_mc_predictions.shape[-1])
    val_mc_labels = mc_labels.reshape(-1)
    mc_threshold = conformal_prediction.calibrate_threshold(
        val_mc_predictions, val_mc_labels, alpha)
    mc_confidence_sets = conformal_prediction.predict_threshold(
        test_predictions, mc_threshold)
    mc_true, mc_aggregated = evaluate_method(mc_confidence_sets)
    mc_true_coverages.append(mc_true)
    mc_aggregated_coverages.append(mc_aggregated)

  return {
      'top1_true_coverages': jnp.array([top1_true] * num_trials),
      'top1_aggregated_coverages': jnp.array([top1_aggregated] * num_trials),
      'mc_true_coverages': jnp.array(mc_true_coverages),
      'mc_aggregated_coverages': jnp.array(mc_aggregated_coverages),
  }

In [None]:
def plot_label_randomness(predictions, data, **kwargs):
  """Histogram the aggregated coverage of Monte Carlo CP over label samples.

  Calls `run_trial` with alpha=0.05 and the optional 'seed' and plots the
  per-trial mean of 'mc_aggregated_coverages' next to the 1 - alpha target.

  Args:
    predictions: forwarded to `run_trial`.
    data: forwarded to `run_trial`.
    **kwargs: optional 'seed' (default 0), 'width'/'height' (figure size in
      inches, default 5x3) and 'name' (if truthy, the figure is saved as
      '<name>.pdf').
  """
  alpha = 0.05
  seed = kwargs.get('seed', 0)
  results = run_trial(predictions, data, alpha, seed=seed)
  mean_coverages = jnp.mean(results['mc_aggregated_coverages'], axis=-1)
  hist, _ = plot_hist(
      mean_coverages, normalize=True,
      alpha=0.65, label='Aggregated coverage')
  # The target line spans the tallest bar, but never less than zero height.
  line_top = max(0, jnp.max(hist))
  plt.vlines(1 - alpha, 0, line_top, color='black', label='Target')
  plt.title('Variation in aggregated coverage from sampled labels')
  plt.xlabel('Empirical coverage')
  plt.ylabel('Frequency')
  plt.legend()
  figure = plt.gcf()
  figure.set_size_inches(kwargs.get('width', 5), kwargs.get('height', 3))
  name = kwargs.get('name', False)
  if name:
    plt.savefig(name + '.pdf', bbox_inches="tight")
  plt.show()

In [None]:
# Seed 3 selects the calibration/test split; the figure is also saved as
# 'label_randomness3.pdf'.
plot_label_randomness(predictions, data, seed=3, name=f'label_randomness3')

In [None]:
def run_trial(
    predictions, data, alpha=0.05, num_trials=100, split=0.5):
  """Compare true-label, top-1 and Monte Carlo calibration over random splits.

  Each trial draws a random calibration/test split (all randomness is derived
  from `jax.random.PRNGKey(0)`) and evaluates five calibration variants:
  'true_' (true labels), 'top1_' (arg-max of the plausibilities) and
  'mc{m}_' for m in (1, 5, 10) sampled labels per calibration example. Within
  a trial, all values of m use the same sampling key.

  Args:
    predictions: model predictions; rows are permuted together with the
      entries of `data`.
    data: mapping providing 'test_smooth_labels' (plausibilities) and
      'test_labels' (true labels).
    alpha: confidence level handed to `calibrate_threshold`.
    num_trials: number of random splits.
    split: fraction of the examples used for calibration.

  Returns:
    Dict keyed by '<method><metric>' with metric in 'inefficiencies',
    'true_coverages', 'aggregated_coverages', 'top1_coverages'; each value is
    a `jnp` array stacking the per-trial metric values along the first axis.
  """
  ms = [1, 5, 10]
  metrics = [
      'inefficiencies',
      'true_coverages',
      'aggregated_coverages',
      'top1_coverages',
  ]
  methods = ['true_', 'top1_'] + [f'mc{m}_' for m in ms]
  results = {
      method + metric: [] for method in methods for metric in metrics}

  def record(prefix, confidence_sets, true_labels, plausibilities):
    """Append all four metrics of `confidence_sets` under `prefix`."""
    coverage = classification_metrics.aggregated_coverage
    results[f'{prefix}inefficiencies'].append(
        classification_metrics.size(confidence_sets))
    true_one_hot = jax.nn.one_hot(true_labels, confidence_sets.shape[1])
    results[f'{prefix}true_coverages'].append(
        coverage(confidence_sets, true_one_hot))
    voted_one_hot = jax.nn.one_hot(
        jnp.argmax(plausibilities, axis=1), plausibilities.shape[1])
    results[f'{prefix}top1_coverages'].append(
        coverage(confidence_sets, voted_one_hot))
    results[f'{prefix}aggregated_coverages'].append(
        coverage(confidence_sets, plausibilities))

  num_examples = predictions.shape[0]
  keys = jax.random.split(jax.random.PRNGKey(0), 2 * num_trials)
  for t in range(num_trials):
    permutation = jax.random.permutation(keys[2 * t], num_examples)
    val_examples = int(num_examples * split)
    val_idx = permutation[:val_examples]
    test_idx = permutation[val_examples:]
    val_predictions = predictions[val_idx]
    test_predictions = predictions[test_idx]
    val_human_ground_truth = data['test_smooth_labels'][val_idx]
    test_human_ground_truth = data['test_smooth_labels'][test_idx]
    val_ground_truth = data['test_labels'][val_idx]
    test_ground_truth = data['test_labels'][test_idx]

    # Baselines: calibrate directly on the true and on the voted labels.
    baselines = (
        ('true_', val_ground_truth),
        ('top1_', jnp.argmax(val_human_ground_truth, axis=1)),
    )
    for prefix, val_labels in baselines:
      threshold = conformal_prediction.calibrate_threshold(
          val_predictions, val_labels, alpha)
      confidence_sets = conformal_prediction.predict_threshold(
          test_predictions, threshold)
      record(
          prefix, confidence_sets, test_ground_truth, test_human_ground_truth)

    # Monte Carlo variants: m sampled labels per calibration example.
    for m in ms:
      val_mc_predictions, mc_labels = monte_carlo.sample_mc_labels(
          keys[2 * t + 1], val_predictions, val_human_ground_truth,
          num_samples=m)
      flat_predictions = val_mc_predictions.reshape(
          -1, val_mc_predictions.shape[-1])
      flat_labels = mc_labels.reshape(-1)
      mc_threshold = conformal_prediction.calibrate_threshold(
          flat_predictions, flat_labels, alpha)
      mc_confidence_sets = conformal_prediction.predict_threshold(
          test_predictions, mc_threshold)
      record(
          f'mc{m}_', mc_confidence_sets,
          test_ground_truth, test_human_ground_truth)

  return {key: jnp.array(values) for key, values in results.items()}

In [None]:
def plot_std(alpha=0.05, num_trials=100, ms=(1, 5, 10), **kwargs):
  """Plot the variability of aggregated coverage for Monte Carlo CP.

  Runs `run_trial` on the notebook-level `predictions` and `data` for several
  calibration fractions and produces two figures:
    1. the standard deviation (across trials) of the per-trial mean
       aggregated coverage as a function of the calibration fraction, saved
       as 'mc_std_aggregated_lines_wo.pdf';
    2. histograms of the per-trial mean aggregated coverage for the smallest
       calibration fraction, saved as 'mc_std_aggregated_histogram_wo.pdf'.

  Args:
    alpha: confidence level; used both for calibration and for the target
      line at 1 - alpha.
    num_trials: number of random splits per calibration fraction.
    ms: numbers of sampled labels to plot. Each m needs a matching
      'mc{m}_aggregated_coverages' entry in the `run_trial` results.
    **kwargs: optional 'width'/'height' (figure size in inches, default 5x3).
  """
  splits = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
  # Forward `alpha` so that calibration and the target line below agree.
  results = [
      run_trial(
          predictions, data, alpha=alpha, num_trials=num_trials, split=split)
      for split in splits
  ]

  for m in ms:
    values = [result[f'mc{m}_aggregated_coverages'] for result in results]
    plt.plot(
        splits,
        [jnp.std(jnp.mean(value, axis=-1), axis=-1) for value in values],
        label=r"$m =$" + f'{m} sampled labels')
  plt.title('Standard deviation in aggregated coverage')
  plt.xlabel('Fraction of calibration data')
  plt.ylabel('Std in empirical coverage')
  plt.legend(bbox_to_anchor=[1, 1], loc='upper right')
  plt.gcf().set_size_inches(kwargs.get('width', 5), kwargs.get('height', 3))
  plt.savefig('mc_std_aggregated_lines_wo.pdf', bbox_inches="tight")
  plt.show()

  # The histogram uses the first (smallest) calibration fraction; the title
  # is derived from it so that label and data cannot drift apart.
  hist_index = 0
  vmax = 0
  for m in ms:
    hist, _ = plot_hist(
        jnp.mean(results[hist_index][f'mc{m}_aggregated_coverages'], axis=-1),
        normalize=True, alpha=0.65,
        label=r"$m =$" + f'{m} sampled labels', bins=40, range=(0.92, 0.98))
    vmax = max(vmax, jnp.max(hist))
  plt.vlines(1 - alpha, 0, vmax, color='black', label='Target')
  plt.title(
      f'Aggregated coverage for {splits[hist_index]:.0%} calibration data')
  plt.xlabel('Empirical coverage')
  plt.ylabel('Frequency')
  plt.xlim(0.9, 0.98)
  plt.legend(bbox_to_anchor=[0, 1], loc='upper left')
  plt.gcf().set_size_inches(kwargs.get('width', 5), kwargs.get('height', 3))
  plt.savefig('mc_std_aggregated_histogram_wo.pdf', bbox_inches="tight")
  plt.show()

In [None]:
# 500 random splits per calibration fraction; both figures are saved as PDFs.
plot_std(alpha=0.05, num_trials=500)

### ECDF-corrected Monte Carlo conformal prediction

In [None]:
def plot_ecdf(**kwargs):
  """Illustrate the ECDF correction of averaged p-values on synthetic data.

  Draws uniform p-values for 10 tests, stacks an exact copy of them to
  emulate fully dependent tests, averages across tests and corrects the
  averages with `p_value_combination.combine_ecdf_p_values`, using the first
  half of the examples as calibration data. Two figures are produced:
    1. histograms of averaged vs. corrected p-values ('ecdf_p_values.pdf');
    2. empirical coverage against target coverage for both, with a
       +/- epsilon band around the corrected curve ('ecdf_lines.pdf').

  Args:
    **kwargs: optional 'width'/'height' (figure size in inches, default 5x3).
  """
  num_tests = 10
  num_examples = 10000
  val_examples = num_examples // 2
  all_p_values = jax.random.uniform(
      jax.random.PRNGKey(0), (num_tests, num_examples))
  # Duplicating the rows models perfectly dependent tests.
  dependent_p_values = all_p_values
  all_p_values = jnp.concatenate((all_p_values, dependent_p_values), axis=0)
  combined_p_values = jnp.mean(all_p_values, axis=0)
  val_combined_p_values = combined_p_values[:val_examples]
  test_combined_p_values = combined_p_values[val_examples:]
  test_corrected_p_values = p_value_combination.combine_ecdf_p_values(
      val_combined_p_values, test_combined_p_values)

  plot_hist(
      test_combined_p_values, normalize=True,
      alpha=0.65, label='Averaged p-values')
  plot_hist(
      test_corrected_p_values, normalize=True,
      alpha=0.65, label='ECDF corrected p-values')
  plt.legend()
  plt.title('Distribution of combined p-values')
  plt.xlabel('Value')
  plt.ylabel('Frequency')
  plt.gcf().set_size_inches(kwargs.get('width', 5), kwargs.get('height', 3))
  plt.savefig('ecdf_p_values.pdf', bbox_inches="tight")
  plt.show()

  delta = 0.0001
  alphas = jnp.linspace(0, 1, 26)
  # Half-width sqrt(log(2 / delta) / (2 n)); this matches the form of the
  # Dvoretzky-Kiefer-Wolfowitz bound for an ECDF built from n samples.
  epsilon = np.sqrt(np.log(2. / delta) / (2 * val_examples))
  baseline_coverages = []
  method_coverages = []
  for alpha in alphas:
    baseline_coverages.append(jnp.mean(test_combined_p_values >= alpha))
    method_coverages.append(jnp.mean(test_corrected_p_values >= alpha))
  # Clip the band to the valid coverage range [0, 1].
  max_method_coverages = np.array(
      [min(1, coverage + epsilon) for coverage in method_coverages])
  min_method_coverages = np.array(
      [max(0, coverage - epsilon) for coverage in method_coverages])

  target_coverages = 1 - alphas
  plt.plot(target_coverages, baseline_coverages, label='Baseline')
  plt.plot(target_coverages, method_coverages, label='ECDF', color='green')
  plt.plot(target_coverages, max_method_coverages, color='green', alpha=0.2)
  plt.plot(target_coverages, min_method_coverages, color='green', alpha=0.2)
  # Shade exactly between the two band lines drawn above (same x axis).
  plt.fill_between(
      target_coverages, min_method_coverages, max_method_coverages, alpha=0.1,
      color='green', label=r'ECDF $1 - \delta$ band')
  plt.title('ECDF correction of p-values')
  plt.xlabel(r'Target coverage $1 - \alpha$')
  plt.ylabel('Empirical coverage')
  plt.legend()
  plt.gcf().set_size_inches(kwargs.get('width', 5), kwargs.get('height', 3))
  plt.savefig('ecdf_lines.pdf', bbox_inches="tight")
  plt.show()

In [None]:
# Synthetic illustration of the ECDF correction; writes 'ecdf_p_values.pdf'
# and 'ecdf_lines.pdf'.
plot_ecdf()

In [None]:
def plot_coverage_ecdf_mccp(
    predictions, data,
    num_samples=10, split=0.5, alpha=0.05, num_trials=10, **kwargs):
  """Histogram true-label coverage of plain and ECDF-corrected Monte Carlo CP.

  For each random split, p-values are computed in three ways: against the
  true calibration labels, as the mean over Monte Carlo p-values, and with
  `monte_carlo.compute_mc_ecdf_p_values`. Confidence sets at level `alpha`
  are evaluated against the true test labels. Only the two Monte Carlo
  variants are plotted.

  Args:
    predictions: model predictions; rows are permuted together with the
      entries of `data`.
    data: mapping providing 'test_smooth_labels' and 'test_labels'.
    num_samples: number of sampled labels handed to the Monte Carlo routines.
    split: fraction of the examples used for calibration.
    alpha: confidence level.
    num_trials: number of random splits.
    **kwargs: optional 'width'/'height' (figure size in inches, default 5x3)
      and 'name' (if truthy, the figure is saved as '<name>.pdf').
  """
  num_examples, num_classes = predictions.shape
  val_examples = int(num_examples * split)
  coverages = []
  mc_coverages = []
  corrected_mc_coverages = []
  rng = jax.random.PRNGKey(0)
  for _ in range(num_trials):
    permutation_rng, mc_rng, rng = jax.random.split(rng, 3)
    permutation = jax.random.permutation(permutation_rng, num_examples)
    val_idx = permutation[:val_examples]
    test_idx = permutation[val_examples:]
    val_predictions = predictions[val_idx]
    test_predictions = predictions[test_idx]
    val_human_ground_truth = data['test_smooth_labels'][val_idx]
    val_labels = data['test_labels'][val_idx]
    test_labels = data['test_labels'][test_idx]

    p_values = conformal_prediction.compute_p_values(
        val_predictions, val_labels, test_predictions)
    # Plain Monte Carlo: average the p-values over the sampled labels.
    mc_p_values = jnp.mean(
        monte_carlo.compute_mc_p_values(
            mc_rng, val_predictions,
            val_human_ground_truth, test_predictions, num_samples),
        axis=0)
    # The ECDF variant deliberately receives the same key as the plain one.
    corrected_mc_p_values = monte_carlo.compute_mc_ecdf_p_values(
        mc_rng, val_predictions,
        val_human_ground_truth, test_predictions, num_samples)

    one_hot_labels = jax.nn.one_hot(test_labels, num_classes)
    for collected, variant_p_values in (
        (coverages, p_values),
        (mc_coverages, mc_p_values),
        (corrected_mc_coverages, corrected_mc_p_values)):
      variant_sets = conformal_prediction.predict_p_values(
          variant_p_values, alpha)
      collected.append(classification_metrics.aggregated_coverage(
          variant_sets, one_hot_labels))

  mc_coverages = jnp.array(mc_coverages)
  corrected_mc_coverages = jnp.array(corrected_mc_coverages)

  vmax = 0
  for per_trial, label in (
      (mc_coverages, 'Monte Carlo CP'),
      (corrected_mc_coverages, 'ECDF Monte Carlo CP')):
    hist, _ = colab_utils.plot_hist(
        jnp.mean(per_trial, axis=1),
        normalize=True, alpha=0.65, label=label)
    vmax = max(np.max(hist), vmax)
  plt.vlines(1 - alpha, 0, vmax, label='Target', color='black')
  plt.title('Aggregated coverage of ECDF-based approach')
  plt.xlabel('Empirical coverage')
  plt.ylabel('Frequency')
  plt.xticks([0.94, 0.95, 0.96])
  plt.legend()
  figure = plt.gcf()
  figure.set_size_inches(kwargs.get('width', 5), kwargs.get('height', 3))
  name = kwargs.get('name', False)
  if name:
    plt.savefig(name + '.pdf', bbox_inches="tight")
  plt.show()

In [None]:
# 100 random splits with 10 sampled labels each; the figure is also saved as
# 'ecdf_coverage.pdf'.
plot_coverage_ecdf_mccp(
    predictions, data, alpha=0.05, num_samples=10, num_trials=100,
    name='ecdf_coverage')

### Aggregated conformity scores

This corresponds to an experiment calibrating with so-called aggregated conformity scores (called *expected* conformity scores in the first version of our paper), which leads to plausibility regions.

Please see the first version of our paper on ArXiv: [arxiv.org/abs/2307.09302v1](https://arxiv.org/abs/2307.09302v1)

In [None]:
def plot_conformity_scores(predictions, data, **kwargs):
  """Compare true, voted and aggregated conformity scores.

  For every example three scores are read off `predictions`: the entry of the
  true label, the entry of the arg-max plausibility (voted label) and the
  plausibility-weighted sum over classes (aggregated score). Two figures are
  produced: overlaid histograms ('conformity_scores.pdf') and the sorted
  scores against the example fraction ('conformity_scores_cdf.pdf').

  Args:
    predictions: 2-D array of per-class scores (examples x classes).
    data: mapping providing 'test_labels' and 'test_smooth_labels'.
    **kwargs: optional 'width'/'height' (figure size in inches, default 5x3).
  """
  labels = data['test_labels']
  smooth_labels = data['test_smooth_labels']
  num_examples = predictions.shape[0]
  true_scores = predictions[jnp.arange(num_examples), labels]
  top1_scores = predictions[
      jnp.arange(num_examples), jnp.argmax(smooth_labels, axis=1)]
  aggregated_scores = jnp.sum(predictions * smooth_labels, axis=1)

  # Raw strings keep the mathtext backslashes without invalid-escape warnings;
  # the rendered labels are unchanged.
  true_label = 'True scores $E(x,y)$'
  top1_label = r'Voted scores $E(x,argmax_k\lambda_k)$'
  aggregated_label = r'Aggregated scores $e(x,\lambda)$'

  plot_hist(
      true_scores, label=true_label, range=(0, 1), bins=50,
      normalize=True, alpha=0.6)
  plot_hist(
      top1_scores, label=top1_label, range=(0, 1), bins=50,
      normalize=True, alpha=0.6)
  plot_hist(
      aggregated_scores, label=aggregated_label,
      range=(0, 1), bins=50, normalize=True, alpha=0.6)
  plt.legend()
  plt.ylabel('Frequency')
  plt.xlabel('Conformity score')
  plt.title('Conformity score histograms')
  plt.gcf().set_size_inches(kwargs.get('width', 5), kwargs.get('height', 3))
  plt.savefig('conformity_scores.pdf', bbox_inches="tight")
  plt.show()

  # Sorted scores against the fraction of examples.
  fractions = jnp.arange(num_examples) / num_examples
  plt.plot(fractions, jnp.sort(true_scores), label=true_label)
  plt.plot(fractions, jnp.sort(top1_scores), label=top1_label)
  plt.plot(fractions, jnp.sort(aggregated_scores), label=aggregated_label)
  plt.legend()
  plt.xlabel('Frequency')
  plt.ylabel('Conformity score')
  plt.title('Conformity score CDFs')
  plt.gcf().set_size_inches(kwargs.get('width', 5), kwargs.get('height', 3))
  plt.savefig('conformity_scores_cdf.pdf', bbox_inches="tight")
  plt.show()

In [None]:
# Writes 'conformity_scores.pdf' and 'conformity_scores_cdf.pdf'.
plot_conformity_scores(predictions, data)

#### Reduced plausibility regions

In [None]:
def visualize_confidence_regions(
    rng, predictions, data, indices, alpha=0.05, split=0.5, **kwargs):
  """Visualize plausibility regions for selected test examples.

  Splits the data once using `rng`, calibrates a top-1 baseline threshold and
  a plausibility-region threshold on the calibration part, and then, for each
  requested test example, prints the plausibilities, scores and confidence
  sets, plots plausibilities against conformity scores as bars and plots the
  plausibility region on the simplex.

  Args:
    rng: JAX PRNG key used for the calibration/test permutation.
    predictions: model predictions; rows are permuted together with the
      entries of `data`.
    data: mapping providing 'test_smooth_labels', 'test_labels',
      'test_examples', 'train_examples' and 'train_smooth_labels'.
    indices: positions within the test part of the split to visualize.
    alpha: confidence level for both calibrations.
    split: fraction of the examples used for calibration.
    **kwargs: accepted for signature parity with the other plotting helpers;
      not read by this function.
  """
  permutation = jax.random.permutation(rng, predictions.shape[0])
  val_examples = int(predictions.shape[0]*split)
  num_classes = predictions.shape[1]
  val_predictions = predictions[permutation[:val_examples]]
  test_predictions = predictions[permutation[val_examples:]]
  val_human_ground_truth = data['test_smooth_labels'][permutation[:val_examples]]
  test_human_ground_truth = data['test_smooth_labels'][permutation[val_examples:]]
  test_examples = data['test_examples'][permutation[val_examples:]]

  # Baseline: standard CP calibrated against the voted (arg-max) labels.
  baseline_threshold = conformal_prediction.calibrate_threshold(
      val_predictions, jnp.argmax(val_human_ground_truth, axis=1), alpha)
  baseline_confidence_sets = conformal_prediction.predict_threshold(
      test_predictions, baseline_threshold)

  threshold = plausibility_regions.calibrate_plausibility_regions(
      val_predictions, val_human_ground_truth, alpha)
  distributions, coverages = plausibility_regions.predict_plausibility_regions(
      test_predictions, threshold, num_grid_points=50)
  k = 1
  confidence_sets = plausibility_regions.reduce_plausibilities_to_topk(
      distributions, coverages, k=k)

  colors = np.array([
      [228, 26, 28],
      [55, 126, 184],
      [77, 175, 74],
  ]) / 255.
  colab_utils.plot_smooth_data(
      data['train_examples'], data['train_smooth_labels'],
      highlight_points=test_examples[indices], boundary=True,
      name='data_smooth_marked', colors=colors)
  # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; the pyplot
  # accessor is available across versions.
  cmap = plt.get_cmap('viridis')
  plt.text(
      0.475, 0.8, 'MLP class. boundary', color=cmap(0), fontdict={'fontsize': 14})
  plt.show()

  projected_distributions = colab_utils.project_simplex(distributions)
  for i, n in enumerate(indices):
    print('Case:', n)
    print('Human ground truth:', test_human_ground_truth[n])
    print('Conformity scores:', test_predictions[n])
    print('Baseline confidence sets:', baseline_confidence_sets[n])
    print(f'Top-{k} confidence set:', confidence_sets[n])

    plt.bar(
        jnp.arange(num_classes),
        test_human_ground_truth[n],
        label=r"Plausibility $\lambda_k = p(y=k|x)$", alpha=0.65)
    plt.bar(
        jnp.arange(num_classes),
        test_predictions[n],
        label=r"Conf. scores $E(x, k) = \pi_k(x)$", alpha=0.65)
    plt.xlabel('Class')
    plt.ylabel('Probability')
    plt.legend(loc='lower left', bbox_to_anchor=(-0.5, 1))
    plt.gcf().set_size_inches(2, 1)
    plt.xticks([0, 1, 2])
    plt.grid(True)
    plt.savefig(f'data_smooth_{i + 1}.pdf', bbox_inches="tight")
    plt.show()

    colab_utils.plot_simplex(projected_distributions, coverages[n])
    plt.savefig(f'data_smooth_{i + 1}1.png', bbox_inches="tight")
    plt.show()

In [None]:
# Visualize four hand-picked test examples (positions within the test part of
# the split drawn from PRNGKey(0)).
visualize_confidence_regions(
    jax.random.PRNGKey(0), predictions, data,
    indices=np.array([10, 6, 12, 20]), alpha=0.05)

In [None]:
def run_reduction_trials(
    rng, predictions, data, alpha=0.05, trials=100, split=0.5):
  """Evaluate plausibility regions reduced to top-1 sets over random splits.

  Each trial calibrates plausibility regions on a random calibration part,
  reduces the predicted regions to top-1 confidence sets and records both the
  coverage of the true test labels by the reduced sets and the result of
  `plausibility_regions.check_plausibility_regions` on the test part.

  Args:
    rng: JAX PRNG key; split into one key per trial.
    predictions: model predictions; rows are permuted together with the
      entries of `data`.
    data: mapping providing 'test_smooth_labels' and 'test_labels'.
    alpha: confidence level for calibration.
    trials: number of random splits.
    split: fraction of the examples used for calibration.

  Returns:
    Dict mapping 'true_coverages' and 'plausibility_coverages' to a `jnp`
    array stacking the per-trial values along the first axis.
  """
  tags = ['true_coverages', 'plausibility_coverages']
  results = {tag: [] for tag in tags}
  num_examples = predictions.shape[0]
  for key in jax.random.split(rng, trials):
    permutation = jax.random.permutation(key, num_examples)
    val_examples = int(num_examples * split)
    val_idx = permutation[:val_examples]
    test_idx = permutation[val_examples:]
    val_predictions = predictions[val_idx]
    test_predictions = predictions[test_idx]
    val_human_ground_truth = data['test_smooth_labels'][val_idx]
    test_human_ground_truth = data['test_smooth_labels'][test_idx]
    test_ground_truth = data['test_labels'][test_idx]

    threshold = plausibility_regions.calibrate_plausibility_regions(
        val_predictions, val_human_ground_truth, alpha)
    distributions, coverages = plausibility_regions.predict_plausibility_regions(
        test_predictions, threshold)
    # Collapse each region to a top-1 confidence set before scoring it.
    confidence_sets = plausibility_regions.reduce_plausibilities_to_topk(
        distributions, coverages, k=1)
    one_hot_labels = jax.nn.one_hot(
        test_ground_truth, test_predictions.shape[1])
    results['true_coverages'].append(
        classification_metrics.aggregated_coverage(
            confidence_sets, one_hot_labels))
    results['plausibility_coverages'].append(
        plausibility_regions.check_plausibility_regions(
            test_predictions, test_human_ground_truth, threshold))
  return {tag: jnp.array(values) for tag, values in results.items()}

In [None]:
def plot_reduced_plausibility_regions(predictions, data, alpha=0.05, **kwargs):
  """Histogram plausibility and true-label coverage of reduced regions.

  Runs `run_reduction_trials` with `jax.random.PRNGKey(0)` and overlays the
  per-trial mean plausibility coverage and true-label coverage together with
  the target at 1 - alpha. The figure is saved as
  'reduced_plausibility_regions.pdf'.

  Args:
    predictions: forwarded to `run_reduction_trials`.
    data: forwarded to `run_reduction_trials`.
    alpha: confidence level; determines both calibration and the target line.
    **kwargs: optional 'width'/'height' (figure size in inches, default 5x3).
  """
  plausibility_results = run_reduction_trials(
      jax.random.PRNGKey(0), predictions, data, alpha=alpha)

  hist, _ = plot_hist(
      jnp.mean(plausibility_results['plausibility_coverages'], axis=-1),
      normalize=True, label='Plausibility coverage')
  vmax = np.max(hist)
  hist, _ = plot_hist(
      jnp.mean(plausibility_results['true_coverages'], axis=-1),
      normalize=True, label='True label coverage')
  vmax = max(vmax, np.max(hist))
  # Follow the requested level instead of a fixed 0.95.
  plt.vlines(
      1 - alpha, 0, vmax, color='red', label='Target plausibility coverage')
  plt.xlabel('Empirical coverage')
  plt.ylabel('Frequency')
  plt.legend()
  plt.title('Coverage of reduced plausibility regions')
  plt.gcf().set_size_inches(kwargs.get('width', 5), kwargs.get('height', 3))
  plt.savefig('reduced_plausibility_regions.pdf', bbox_inches="tight")
  plt.show()


In [None]:
# Uses the default alpha=0.05; writes 'reduced_plausibility_regions.pdf'.
plot_reduced_plausibility_regions(predictions, data)

## Dermatology dataset

### Load data

In [None]:
# Replace the toy `data`/`predictions` with the dermatology ones. The cells
# below read data['test_irn'] as the plausibilities.
# NOTE: unpickling executes arbitrary code -- only load files you trust.
with open('data/dermatology_data.pkl', 'rb') as f:
  data = pickle.load(f)
# Predictions are stored as a plain-text matrix readable by `np.loadtxt`.
with open('data/dermatology_predictions0.txt', 'r') as f:
  predictions =  np.loadtxt(f)

### Main results

In [None]:
def plot_conformal_prediction(
    predictions, plausibilities, num_trials=100,
    num_samples=10, alpha=0.27, method=False,**kwargs):
  """Plot coverage and inefficiency of a conformal predictor over random splits.

  Each trial splits the data 50/50 into calibration and test parts (all
  randomness is derived from `jax.random.PRNGKey(0)`), builds confidence sets
  with the selected method and records the coverage of the voted (arg-max)
  test labels, the aggregated coverage against the test plausibilities and
  the set sizes. Two figures are shown: coverage histograms with the
  1 - alpha target, and the inefficiency histogram with its overall average.

  Args:
    predictions: 2-D array of model predictions (examples x classes).
    plausibilities: 2-D array of per-class plausibilities, row-aligned with
      `predictions`.
    num_trials: number of random splits.
    num_samples: number of sampled labels for the Monte Carlo methods.
    alpha: confidence level.
    method: 'mccp' for Monte Carlo CP, 'ecdf' for ECDF-corrected Monte Carlo
      CP; any other value calibrates standard CP against the voted labels.
    **kwargs: optional 'title' for the coverage figure and 'width'/'height'
      (figure size in inches). Other keys, e.g. 'name', are accepted but not
      read by this function.
  """
  _, num_classes = plausibilities.shape
  top1_coverages = []
  aggregated_coverages = []
  sizes = []
  keys = jax.random.split(jax.random.PRNGKey(0), 2 * num_trials)
  for t in range(num_trials):
    # Even keys drive the split, odd keys drive Monte Carlo label sampling.
    permutation = jax.random.permutation(keys[2 * t], predictions.shape[0])
    val_examples = int(predictions.shape[0] * 0.5)
    val_predictions = predictions[permutation[:val_examples]]
    test_predictions = predictions[permutation[val_examples:]]
    val_plausibilities = plausibilities[permutation[:val_examples]]
    test_plausibilities = plausibilities[permutation[val_examples:]]

    if method == 'mccp':
      threshold = monte_carlo.calibrate_mc_threshold(
          keys[2 * t + 1], val_predictions, val_plausibilities,
          alpha, num_samples=num_samples)
      confidence_sets = conformal_prediction.predict_threshold(
          test_predictions, threshold)
    elif method == 'ecdf':
      p_values = monte_carlo.compute_mc_ecdf_p_values(
          keys[2 * t + 1], val_predictions, val_plausibilities,
          test_predictions, num_samples=num_samples, split=0.5)
      confidence_sets = conformal_prediction.predict_p_values(
          p_values, alpha)
    else:
      val_labels = jnp.argmax(val_plausibilities, axis=1)
      threshold = conformal_prediction.calibrate_threshold(
          val_predictions, val_labels, alpha)
      confidence_sets = conformal_prediction.predict_threshold(
          test_predictions, threshold)

    test_labels = jnp.argmax(test_plausibilities, axis=1)
    top1_coverages.append(classification_metrics.aggregated_coverage(
        confidence_sets, jax.nn.one_hot(test_labels, num_classes)))
    aggregated_coverages.append(classification_metrics.aggregated_coverage(
        confidence_sets, test_plausibilities))
    sizes.append(classification_metrics.size(confidence_sets))
  top1_coverages = jnp.array(top1_coverages)
  aggregated_coverages = jnp.array(aggregated_coverages)
  sizes = jnp.array(sizes)

  hist1, _ = colab_utils.plot_hist(
      jnp.mean(top1_coverages, axis=1),
      label='Voted coverage', normalize=True)
  hist2, _ = colab_utils.plot_hist(
      jnp.mean(aggregated_coverages, axis=1),
      label='Aggregated coverage', normalize=True)
  vmax = max(jnp.max(hist1), jnp.max(hist2))
  plt.vlines(
      1 - alpha, 0, vmax,
      color='black', label=r'Target')
  plt.title(kwargs.get(
      'title',
      'Empirical coverage for standard CP with voted IRN labels'
      f'\n({num_trials} random calibration/test splits)'))
  plt.xlabel('Empirical coverage')
  plt.ylabel('Frequency')
  plt.legend()
  plt.gcf().set_size_inches(kwargs.get('width', 4.5), kwargs.get('height', 2.5))
  plt.show()

  size_hist, _ = colab_utils.plot_hist(
      jnp.mean(sizes, axis=1), normalize=True)
  # Scale the marker to this histogram, not to the coverage histograms above.
  vmax = jnp.max(size_hist)
  average = jnp.mean(sizes)
  plt.vlines(
      average, 0, vmax, label=f'Average: {average:.2f}',
      color='black', linewidth=2)
  plt.xlabel('Inefficiency')
  plt.ylabel('Frequency')
  plt.legend()
  plt.gcf().set_size_inches(kwargs.get('width', 4.5), kwargs.get('height', 1.5))
  plt.show()

In [None]:
# Standard CP calibrated against the voted (arg-max) labels, at two
# confidence levels. `method=False` selects the fallback branch.
for alpha in [0.27, 0.1]:
  plot_conformal_prediction(
      predictions, data['test_irn'], method=False,
      title='CP with voted labels', alpha=alpha)

In [None]:
# Monte Carlo CP (default num_samples=10) at the same two confidence levels.
for alpha in [0.27, 0.1]:
  plot_conformal_prediction(
      predictions, data['test_irn'], method='mccp',
      title='Monte Carlo CP', alpha=alpha)

In [None]:
# This runs much longer than the above.
# ECDF-corrected Monte Carlo CP at alpha=0.27.
# NOTE(review): `name` is passed, but plot_conformal_prediction as written
# does not read it, so no file is saved.
plot_conformal_prediction(
    predictions, data['test_irn'], method='ecdf',
    title='ECDF Monte Carlo CP', name='ecdf', alpha=0.27)

In [None]:
# ECDF-corrected Monte Carlo CP at alpha=0.1.
# NOTE(review): `name` is passed, but plot_conformal_prediction as written
# does not read it, so no file is saved.
plot_conformal_prediction(
    predictions, data['test_irn'], method='ecdf',
    title='ECDF Monte Carlo CP', name='ecdf', alpha=0.1)