# PL Sampler

See `README.md` for installation and usage instructions.

This notebook runs the Plackett-Luce Gibbs sampler described in [2] on the toy dataset from [1] as well as on the dermatology dataset.

```
[1] Stutz, D., Roy, A.G., Matejovicova, T., Strachan, P., Cemgil, A.T.,
    & Doucet, A. (2023).
    Conformal prediction under ambiguous ground truth. ArXiv, abs/2307.09302.
[2] Stutz, D., Cemgil, A.T., Roy, A.G., Matejovicova, T., Barsbey, M.,
    Strachan, P., Schaekermann, M., Freyberg, J.V., Rikhye, R.V., Freeman, B.,
    Matos, J.P., Telang, U., Webster, D.R., Liu, Y., Corrado, G.S., Matias, Y.,
    Kohli, P., Liu, Y., Doucet, A., & Karthikesalingam, A. (2023).
    Evaluating AI systems under uncertain ground truth: a case study in
    dermatology. ArXiv, abs/2307.02191.
```

## Imports and setup

In [None]:
import functools
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import numpy as np
import os
import pickle

In [None]:
import selectors_utils
import pl_samplers
import eval_utils
import colab_utils
import gaussian_toy_dataset as gtd

In [None]:
# Apply the shared notebook styling from colab_utils (presumably matplotlib
# style settings — see colab_utils.set_style).
colab_utils.set_style()
# Short alias for colab_utils.plot_hist (not referenced in the cells shown here).
plot_hist = colab_utils.plot_hist

## Utilities

In [None]:
def gibbs_sampler(
    selector,
    sampler,
    shape_lam,
    rate_lam,
    num_classes,
    num_iterations,
    warmup_iterations,
    rng,
):
  """Runs the standard Plackett-Luce Gibbs sampler on a single case.

  Args:
    selector: Rankings for the case, forwarded unchanged to `sampler.sample`.
    sampler: Object with a `sample(key, selector, shape_lam=..., rate_lam=...,
      num_iterations=...)` method.
    shape_lam: Scalar shape hyper-parameter; it is divided evenly over the
      `num_classes` classes.
    rate_lam: Scalar rate hyper-parameter, broadcast to every class.
    num_classes: Number of classes.
    num_iterations: Total number of Gibbs iterations, warm-up included.
    warmup_iterations: Number of leading samples to discard.
    rng: Iterator of PRNG keys; exactly one key is consumed.

  Returns:
    The samples after dropping the first `warmup_iterations` entries, with a
    new leading axis of size 1 so several cases can be concatenated on axis 0.
  """
  per_class = jnp.ones((num_classes,))
  samples = sampler.sample(
      next(rng),
      selector,
      shape_lam=per_class * shape_lam / num_classes,
      rate_lam=per_class * rate_lam,
      num_iterations=num_iterations,
  )
  kept = samples[warmup_iterations:]
  return jnp.expand_dims(kept, axis=0)


def gibbs_sampler_from_ranked_classes(
    selector,
    sampler,
    shape_lam,
    rate_lam,
    num_iterations,
    warmup_iterations,
    represent_unranked_classes,
    normalize_unranked_equally,
    num_classes,
    rng,
):
  """Runs the Gibbs sampler on one case while ignoring unranked classes.

  Args:
    selector: Nested rankings for the case (a pytree); forwarded unchanged to
      `sampler.sample_from_ranked_classes`.
    sampler: Object with a `sample_from_ranked_classes` method.
    shape_lam: Shape hyper-parameter. With `normalize_unranked_equally` it is
      divided by (number of distinct selector leaves + 1); otherwise 1 is
      passed instead.
    rate_lam: Rate hyper-parameter, forwarded as `rate_lam_i`.
    num_iterations: Total number of Gibbs iterations, warm-up included.
    warmup_iterations: Number of leading samples to discard.
    represent_unranked_classes: Forwarded to the sampler.
    normalize_unranked_equally: Forwarded to the sampler; also selects how
      `shape_lam_i` is computed (see `shape_lam`).
    num_classes: Number of classes, forwarded to the sampler.
    rng: Iterator of PRNG keys; exactly one key is consumed.

  Returns:
    The samples after dropping the first `warmup_iterations` entries, with a
    new leading axis of size 1.
  """
  # Distinct leaf values of the selector pytree (presumably class indices).
  distinct_leaves = set(jax.tree_util.tree_leaves(selector))
  if normalize_unranked_equally:
    # The "+ 1" presumably reserves a share for the unranked classes —
    # TODO confirm against pl_samplers.
    per_class_shape = shape_lam / (len(distinct_leaves) + 1)
  else:
    # NOTE(review): the caller-supplied shape_lam is ignored on this path and
    # replaced by 1 — confirm this is intended.
    per_class_shape = 1
  samples = sampler.sample_from_ranked_classes(
      next(rng),
      selector,
      shape_lam_i=per_class_shape,
      rate_lam_i=rate_lam,
      num_classes=num_classes,
      num_iterations=num_iterations,
      represent_unranked_classes=represent_unranked_classes,
      normalize_unranked_equally=normalize_unranked_equally,
  )
  return jnp.expand_dims(samples[warmup_iterations:], axis=0)

In [None]:
def plot_plausabilities(
    plausabilities, irn_reference, label_names,
    limit=10, num_samples=100, **kwargs):
  """Plot PL plausabilities against reference IRN plausabilities.

  Classes are sorted by decreasing reference plausibility and the top `limit`
  are shown:
  - the reference as bars,
  - individual samples as small green dots,
  - the mean over the plotted samples as larger red dots.
  Ends with `plt.show()`.

  Args:
    plausabilities: Sampled plausibilities indexed as `[sample, class]`, or
      None to plot only the reference bars.
    irn_reference: Reference plausibilities, one entry per class.
    label_names: Class names indexed by class id. Only the first 6 characters
      are used as tick labels.
    limit: Maximum number of classes to plot.
    num_samples: Maximum number of samples to scatter and average. It is
      clamped to the number of available samples.
    **kwargs: Optional overrides: 'reference_label', 'title', 'xlabel',
      'ylabel', 'ymin', 'ymax', 'legend' (bool), 'width', 'height'.
  """
  # NumPy indices work for indexing both NumPy and JAX arrays and Python lists.
  indices = np.asarray(jnp.argsort(-irn_reference))[:limit]
  positions = np.arange(indices.shape[0])
  plt.bar(
      positions,
      irn_reference[indices], alpha=0.5,
      label=kwargs.get('reference_label', ''), color=colab_utils.COLORS[0])
  if plausabilities is not None:
    # Clamp so that requesting more samples than exist neither raises an
    # IndexError nor mislabels the legend.
    num_plotted = min(num_samples, plausabilities.shape[0])
    if num_plotted > 0:
      # Rows = plotted samples, columns = displayed classes.
      shown = np.asarray(plausabilities[:num_plotted])[:, indices]
      # One scatter call for all samples instead of one call per sample.
      # Tiling the x positions matches the row-major ravel of `shown`.
      plt.scatter(
          np.tile(positions, num_plotted), shown.ravel(),
          c='g', s=10, alpha=0.75, label=f'{num_plotted} Samples')
      plt.scatter(
          positions, shown.mean(axis=0),
          s=25, c='r', label='Mean')

  plt.xticks(positions, [label_names[i][:6] for i in indices])
  plt.title(kwargs.get('title', 'Plausabilities'))
  plt.xlabel(kwargs.get('xlabel', 'Classes'))
  plt.ylabel(kwargs.get('ylabel', ''))
  # `bottom`/`top` are the primary names of the `ymin`/`ymax` aliases.
  # None leaves that limit unchanged.
  plt.ylim(bottom=kwargs.get('ymin', None), top=kwargs.get('ymax', None))
  if kwargs.get('legend', False):
    plt.legend()
  plt.gcf().set_size_inches(kwargs.get('width', limit), kwargs.get('height', 3))
  plt.show()

## Load data

In [None]:
# Dataset to sample on: 'toy' (toy dataset of [1]) or 'derm' (dermatology).
# The trailing `#@param` comment renders this line as a Colab form field.
dataset = 'derm'  #@param ['toy', 'derm']

In [None]:
# Load the pickled data, the class names and the example case indices for the
# selected dataset.
# NOTE: pickle.load can execute arbitrary code; only load these files if they
# come from a trusted source.
if dataset == 'toy':
  with open('data/toy_data.pkl', 'rb') as f:
    data = pickle.load(f)
  label_names = ['0', '1', '2']
  indices = [0, 1, 2]  # Random examples.
elif dataset == 'derm':
  with open('data/dermatology_data.pkl', 'rb') as f:
    data = pickle.load(f)
  # One condition name per line; the explicit encoding avoids depending on
  # the platform's default locale encoding.
  with open('data/dermatology_conditions.txt', 'r', encoding='utf-8') as f:
    label_names = [condition.strip() for condition in f]
  indices = [1026, 1057, 357]  # Examples from the paper.
else:
  # Fail here rather than with a NameError on `data` in a later cell.
  raise ValueError(f"Unknown dataset {dataset!r}; expected 'toy' or 'derm'.")

## PL Sampling

For the Plackett-Luce Gibbs sampler, we provide several options. Most notably, `sample_from_ranked` can be set to `True` when there are many classes but individual annotations never include all of them. This holds for the Dermatology dataset but not for the toy dataset, so setting it to `True` speeds up sampling on the Dermatology data.

For the three dermatology cases, the cells below should run in a few minutes. Running on the whole dataset can take considerably longer.

In [None]:
# Shape/rate hyper-parameters handed to the sampling helpers (presumably of a
# Gamma prior on the Plackett-Luce weights — TODO confirm in pl_samplers).
shape_lam, rate_lam = 1.0, 1.0

# Chain length. The helpers drop the first `warmup_iterations` of the
# `total_iterations` samples.
total_iterations = 1500
warmup_iterations = 500

# When > 1, each case's selector is repeated this many times via
# selectors_utils.repeat_selectors.
reader_repetitions = 3

# Rank only over observed classes. Worth it when there are many classes and
# the partial rankings never include all of them (dermatology, not toy).
sample_from_ranked = dataset == 'derm'
represent_unranked_classes = True
normalize_unranked_equally = False

# Forwarded to pl_samplers.GibbsSamplerPlackettLuce.
jit_strategy = 'jit_per_reader'

In [None]:
# Run the Gibbs sampler on every selected case and collect normalized
# plausibilities. Cases are stacked along axis 0, in the order of `indices`.
rng = gtd.PRNGSequence(0)
num_classes = data['test_irn'].shape[1]

# The helper choice and the sampler do not depend on the case, so build them
# once. Rebuilding the sampler for every case may throw away per-instance
# state such as compiled functions (depends on pl_samplers — TODO confirm).
if sample_from_ranked:
  worker_fn = functools.partial(
      gibbs_sampler_from_ranked_classes,
      represent_unranked_classes=represent_unranked_classes,
      normalize_unranked_equally=normalize_unranked_equally,
  )
else:
  worker_fn = gibbs_sampler
worker = functools.partial(
    worker_fn,
    sampler=pl_samplers.GibbsSamplerPlackettLuce(jit_strategy),
    shape_lam=shape_lam,
    rate_lam=rate_lam,
    num_iterations=total_iterations,
    warmup_iterations=warmup_iterations,
    num_classes=num_classes,
    rng=rng)

plausibilities = []
for index in indices:
  selector = data['test_selectors'][index]
  if reader_repetitions > 1:
    # repeat_selectors takes a list of selectors; unwrap the single result.
    selector = selectors_utils.repeat_selectors(
        [selector], reader_repetitions
    )[0]
  # Each call consumes one key from the shared `rng` sequence.
  plausibilities.append(worker(selector))
plausibilities = jnp.concatenate(plausibilities, axis=0)
plausibilities = eval_utils.normalize_plausibilities(plausibilities)

In [None]:
# Plot each case's sampled plausibilities against its reference IRN
# plausibilities. Rows of `plausibilities` follow the order of `indices`.
for case_samples, index in zip(plausibilities, indices):
  plot_plausabilities(case_samples, data['test_irn'][index], label_names)