## ViT-Plex Demo

*Licensed under the Apache License, Version 2.0.*

To run this in a public Colab, change the GitHub link: replace github.com with [githubtocolab.com](http://githubtocolab.com)

<a href="https://githubtocolab.com/google/uncertainty-baselines/blob/main/experimental/plex/plex_vit_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook demonstrates how one can utilize the released **ViT-Plex** checkpoints from the *Plex: Towards Reliability using Pretrained Large Model Extensions* paper using [JAX](https://jax.readthedocs.io/). The **General usage** section provides a minimal setup for loading the checkpoints and making predictions, and the ***Uncertainty***, ***Robust Generalization***, and ***Adaptation*** sections delve deeper into the three areas of reliability for which Plex is designed to excel.

For more advanced usage, full training and fine-tuning scripts can be found at https://github.com/google/uncertainty-baselines/tree/main/baselines/jft.

## Imports

In [None]:
# NOTE: Use `tpu-colab` when running on a hosted TPU Colab runtime. Use `tpu`
# when running on a GCP TPU machine.
# This value selects which `jax` package variant the install cell pulls in, and
# whether the Colab TPU setup in the imports cell is run.
backend = "cpu"  #@param ["tpu-colab", "tpu", "gpu", "cpu"]

In [None]:
pip_install = True
if pip_install:
  # NOTE: Set the jax version to >=0.3.14 if Python 3.9+ is available.
  if backend == "cpu" or backend == "tpu-colab":
    !python3 -m pip install "jax~=0.2.27"
  elif backend == "tpu":
    !python3 -m pip install "jax[tpu]~=0.2.27" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  elif backend == "gpu":
    !python3 -m pip install "jax[cuda]~=0.2.27" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  else:
    raise ValueError("Backend must be one of ['cpu', 'tpu', 'gpu']. got "
                     f"backend={backend} instead.")
  !rm -rf uncertainty-baselines
  !git clone https://github.com/google/uncertainty-baselines.git
  !cp -r uncertainty-baselines/baselines/jft/* .
  # NOTE: Remove the explicit tensorflow-federated and tensorflow_probability
  # installs if Python 3.9+ is available.
  !python3 -m pip install "tensorflow-federated==0.20.0" "tensorflow_probability<0.17.0" ./uncertainty-baselines[tensorflow,jax,models,datasets]

In [None]:
# On a hosted TPU Colab runtime, run JAX's Colab TPU setup helper before the
# remaining imports and any JAX computation.
if backend == "tpu-colab":
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()

import functools

from clu import preprocess_spec
import flax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import ml_collections
import sklearn
import tensorflow as tf
import tensorflow_datasets as tfds
import uncertainty_baselines as ub
import checkpoint_utils  # local file import from baselines.jft
import input_utils  # local file import from baselines.jft
import preprocess_utils  # local file import from baselines.jft

In [None]:
# If running with TPUs, the following should output a list of TPU devices.
print(jax.local_devices())

In [None]:
# Set a base seed to use for the notebook.
rng = jax.random.PRNGKey(42)

## General usage

### Load model

In [None]:
def get_finetuned_config():
  """Builds the config for the fine-tuned heteroscedastic BatchEnsemble ViT.

  Returns:
    An `ml_collections.ConfigDict` whose `model` entry holds the keyword
    arguments that are forwarded to the model constructor.
  """
  # From `https://github.com/google/uncertainty-baselines/blob/main/baselines/jft/experiments/vit_l32_hetbe_finetune.py`.
  # TODO(dusenberrymw): Clean up this config.
  transformer = {
      'mlp_dim': 4096,
      'num_heads': 16,
      'num_layers': 24,
      'attention_dropout_rate': 0.,
      'dropout_rate': 0.,
      # BatchEnsemble
      'be_layers': (21, 22, 23),
      'ens_size': 3,
      'random_sign_init': -0.5,
      'ensemble_attention': False,
  }
  model = {
      'patches': {'size': [32, 32]},
      'hidden_size': 1024,
      'transformer': transformer,
      'classifier': 'token',
      'representation_size': None,
      # Heteroscedastic
      'multiclass': True,
      'temperature': 1.25,
      'mc_samples': 1000,
      'num_factors': 15,
      'param_efficient': False,
      # TODO(dusenberrymw): Remove the need to include this GP config.
      # GP
      'use_gp': False,
      'covmat_momentum': .999,
      'ridge_penalty': 1.,
      'mean_field_factor': -1.,
  }
  # Nested dicts are converted into nested ConfigDicts on construction.
  return ml_collections.ConfigDict({'model': model})

In [None]:
# Number of output classes of the fine-tuned model.
num_classes = 1000
config = get_finetuned_config()
# Every entry of `config.model` is forwarded as a keyword argument.
model = ub.models.vision_transformer_het_gp_be(
    num_classes=num_classes, **config.model)

In [None]:
@jax.jit
def predict_fn(params, images, rng):
  """Returns class probabilities averaged over the ensemble members.

  Args:
    params: Model parameters, as loaded from the checkpoint.
    images: A batch of preprocessed images.
    rng: A `jax.random` key; it is split into the three streams that are passed
      to the model.

  Returns:
    An array of shape (batch_size, num_classes) holding the mean of the
    per-member softmax probabilities.
  """
  dropout_key, diag_key, standard_key = jax.random.split(rng, num=3)
  variables = {'params': flax.core.freeze(params)}
  rngs = {
      'dropout': dropout_key,
      'diag_noise_samples': diag_key,
      'standard_norm_noise_samples': standard_key,
  }
  tiled_logits, _ = model.apply(variables, images, train=False, rngs=rngs)
  # The model returns the members' logits tiled along the first axis; split
  # them into one chunk per ensemble member and stack on a new leading axis.
  member_logits = jnp.split(tiled_logits, model.transformer.ens_size)
  member_probs = jax.nn.softmax(jnp.stack(member_logits))
  # Average over ensemble members.
  return jnp.mean(member_probs, axis=0)

In [None]:
# Load the fine-tuned checkpoint from its `gs://` path and pull out the model
# parameters, which are stored under the `opt` -> `target` entries.
checkpoint_path = "gs://plex-paper/plex_vit_large_imagenet21k_to_imagenet.npz"
read_in_parallel = False
checkpoint = checkpoint_utils.load_checkpoint(None, path=checkpoint_path,
                                              read_in_parallel=read_in_parallel)
params = checkpoint["opt"]["target"]

### Make predictions

#### Single image

In [None]:
# Get a single image from https://www.tensorflow.org/datasets/catalog/imagenet_v2.
# Direct URL: https://knowyourdata-tfds.withgoogle.com/#dataset=imagenet_v2&tab=ITEM&select=kyd%2Fimagenet_v2%2Flabel&item=205%2F194ab2af3f5802ad12e1f4327d598743b01489c0.jpeg
# NOTE(review): `--no-check-certificate` turns off TLS certificate verification
# for this download; drop the flag if the host's certificate validates.
!wget --no-check-certificate "https://knowyourdata-tfds.withgoogle.com/serve_image?&id=205%2F194ab2af3f5802ad12e1f4327d598743b01489c0.jpeg&dataset=imagenet_v2" -O image.jpg
# Show the downloaded file inline.
from IPython.display import Image, display
display(Image("image.jpg"))

In [None]:
# Load and preprocess image.
def preprocess_fn(image):
  """Decodes encoded image bytes into a model-ready array.

  Args:
    image: The encoded contents of an image file.

  Returns:
    A `jnp` array with the image decoded to 3 channels, resized to 384x384 and
    rescaled from [0, 255] to [-1, 1].
  """
  # Note: The model was trained with this preprocessing.
  decoded = tf.io.decode_image(
      tf.convert_to_tensor(image), channels=3, expand_animations=False)
  resized = tf.cast(tf.image.resize(decoded, (384, 384)), tf.float32)
  return jnp.asarray(resized / 255. * 2 - 1)

with open("image.jpg", mode='rb') as f:
  image_data = f.read()

image = preprocess_fn(image_data)
image.shape

In [None]:
# Make a prediction.
rng_eval = jax.random.fold_in(rng, 0)
images = jnp.array([image])  # Create a batch of 1 image.
probs = predict_fn(params, images, rng_eval)
probs.shape

In [None]:
# Output top 5 predictions.
all_top_preds = tf.keras.applications.imagenet_utils.decode_predictions(
    probs, top=5)

for top_preds in all_top_preds:
  for _, pred_class_name, prob in top_preds:
    print(f"{float(prob):.6f} : {pred_class_name}")

#### Batch of images w/ multiple devices

Here we demonstrate how to make predictions with the model on a batch of images using multiple devices.

In [None]:
def load_val_ds(dataset, split, batch_size, preprocess_eval_fn):
  """Builds a single-epoch, unshuffled evaluation dataset.

  NOTE: The data loader yields examples of shape
  (num_devices, batch_size/num_devices, ...), i.e., it splits the batch_size
  across the number of local devices, under the assumption that TPUs or
  multiple GPUs are used.

  Args:
    dataset: Name of the dataset to load.
    split: Name of the dataset split to read.
    batch_size: Batch size, passed on as `process_batch_size`.
    preprocess_eval_fn: Preprocessing function applied by the data loader.

  Returns:
    The dataset returned by `input_utils.get_data`.
  """
  return input_utils.get_data(
      dataset=dataset,
      split=split,
      rng=None,
      process_batch_size=batch_size,
      preprocess_fn=preprocess_eval_fn,
      cache=False,
      num_epochs=1,
      repeat_after_batching=True,
      shuffle=False,
      prefetch_size=0,
      drop_remainder=False,
      data_dir=None)

In [None]:
pp_eval = "decode|resize(384)|value_range(-1, 1)|onehot(1000, key='label', key_result='labels')|keep(['image', 'labels'])"
preprocess_eval_fn = preprocess_spec.parse(
    spec=pp_eval, available_ops=preprocess_utils.all_ops())

In [None]:
# https://www.tensorflow.org/datasets/catalog/imagenet_v2
dataset = "imagenet_v2"
tfds.builder(dataset).download_and_prepare()
split = "test"
batch_size = 64 * jax.local_device_count()
val_ds = load_val_ds(dataset, split=split, batch_size=batch_size,
                     preprocess_eval_fn=preprocess_eval_fn)
val_ds.element_spec

In [None]:
# Create a model function that works across multiple TPU devices or across
# multiple GPUs for performance. The value for `in_axes` means that the `params`
# argument for `predict_fn` will be copied to each device, the `images` will be
# split ("sharded") across the devices along the first axis, and the `rng` will
# be copied to each device. Note that this means that `images` should have shape
# `(num_devices, batch_size/num_devices, h, w, c)` so that each device processes
# a `(batch_size/num_devices, h, w, c)` chunk of the images. The `params` and
# `rng` will be the same as in the "Single image" example up above.
pmapped_predict_fn = jax.pmap(predict_fn, in_axes=(None, 0, None))

In [None]:
# Take the first batch from the dataset and predict on all devices at once.
batch = next(val_ds.as_numpy_iterator())
rng_eval = jax.random.fold_in(rng, 0)
probs = pmapped_predict_fn(params, batch["image"], rng_eval)
# Note that probs is of shape
# (num_devices, batch_size / num_devices, num_classes).
probs.shape, probs.device_buffers[0].device()

In [None]:
def get_and_reshape(x):
  """Fetches `x` to the host and merges its two leading axes into one.

  Args:
    x: An array with at least two axes, e.g. (num_devices, per_device, ...).

  Returns:
    An array of shape (num_devices * per_device, ...).
  """
  trailing_shape = x.shape[2:]
  host_x = jax.device_get(x)
  return jnp.reshape(host_x, (-1,) + trailing_shape)

# Flatten the per-device leading axes so that there is one row per example.
images = get_and_reshape(batch["image"])
all_top_preds = tf.keras.applications.imagenet_utils.decode_predictions(
    get_and_reshape(probs), top=5)
# The one-hot labels are decoded the same way as the predictions (with top=1)
# to recover the ground-truth class names.
labels = tf.keras.applications.imagenet_utils.decode_predictions(
    get_and_reshape(batch["labels"]), top=1)

# Only show 10 images.
for _, image, top_preds, label in zip(range(10), images, all_top_preds, labels):
  plt.figure(figsize=(4, 4))
  # Map pixel values from [-1, 1] back to [0, 1] for display.
  plt.imshow(image * .5 + .5)
  plt.axis('off')
  plt.show()

  # Each decoded entry is a tuple whose second element is the class name.
  correct_class_name = label[0][1]
  for _, pred_class_name, prob in top_preds:
    print(f"{float(prob):.6f} : {pred_class_name}")
  print(f"Correct class: {correct_class_name}\n")

## Reliability

### Uncertainty


To be announced!

### Robust Generalization

Here we demonstrate a *covariate shift* problem by adding ImageNet-C-style Gaussian noise ([Hendrycks & Dietterich, 2019](http://arxiv.org/abs/1903.12261)) to an input image and showing the model's predictions as the noise increases. In this type of problem, we view shifted examples as "noisy", but close enough to the distribution of training examples that we desire our model to be robust to the noise and still make strong predictions. Corruption levels 1-5 correspond to those in ImageNet-C, and we add additional levels above those. We see that Plex models are able to make confident predictions under large amounts of noise. Full evaluation results are in the paper.

In [None]:
# Define a Gaussian noise function to form ImageNet-C-style Gaussian noise
# corruptions.
def gaussian_noise(x, severity, rng):
  """Adds clipped Gaussian noise to an image with values on the [0, 255] scale.

  Args:
    x: Image array with pixel values on the [0, 255] scale.
    severity: Integer corruption level from 1 to 7; selects the noise scale.
    rng: A `jax.random` key used to draw the noise.

  Returns:
    The noisy image, with the same shape as `x` and values clipped to
    [0, 255].

  Raises:
    ValueError: If `severity` is not one of the supported levels.
  """
  severity_scales = [.08, .12, 0.18, 0.26, 0.38, 0.6, 1.]
  # Validate explicitly rather than with `assert`: assertions are stripped
  # under `python -O`, and `severity=0` would then silently index the last
  # (largest) scale.
  if severity not in range(1, len(severity_scales) + 1):
    raise ValueError(
        f"severity must be an integer in [1, {len(severity_scales)}], got "
        f"{severity!r}.")
  scale = severity_scales[severity - 1]
  # The noise is added on the [0, 1] scale and the result mapped back.
  x = x / 255.
  x = jnp.clip(x + scale * jax.random.normal(rng, shape=x.shape), 0, 1) * 255
  return x

In [None]:
# Load and preprocess image.
def preprocess_fn(image, severity=None, rng=None):
  """Decodes encoded image bytes, optionally corrupting them with noise.

  Args:
    image: The encoded contents of an image file.
    severity: Optional corruption level passed to `gaussian_noise`. When
      `None`, no noise is added.
    rng: A `jax.random` key passed to `gaussian_noise`; only used when
      `severity` is given.

  Returns:
    A `jnp` array with the image decoded to 3 channels, resized to 384x384 and
    rescaled from [0, 255] to [-1, 1].
  """
  # Note: The model was trained with this preprocessing.
  decoded = tf.io.decode_image(
      tf.convert_to_tensor(image), channels=3, expand_animations=False)
  pixels = jnp.asarray(tf.cast(tf.image.resize(decoded, (384, 384)),
                               tf.float32))
  # The noise is applied on the [0, 255] scale, before rescaling.
  if severity is not None:
    pixels = gaussian_noise(pixels, severity, rng)
  return pixels / 255. * 2 - 1

with open("image.jpg", mode='rb') as f:
  image_data = f.read()

image = preprocess_fn(image_data)
# Every severity level reuses the same PRNG key, so the corrupted images share
# one underlying noise draw and differ only in the noise scale.
corrupted_images = [preprocess_fn(image_data, s, jax.random.fold_in(rng, 0))
                    for s in range(1, 8)]
# Index 0 is the clean image; indices 1-7 hold severity levels 1-7.
images = jnp.array([image] + corrupted_images)
images.shape

In [None]:
# Make predictions.
rng_eval = jax.random.fold_in(rng, 0)
probs = predict_fn(params, images, rng_eval)
probs.shape

In [None]:
# Output top 5 predictions.
all_top_preds = tf.keras.applications.imagenet_utils.decode_predictions(
    probs, top=5)

# Show the clean image (index 0) followed by each corrupted version.
for i, (image, top_preds) in enumerate(zip(images, all_top_preds)):
  plt.figure(figsize=(4, 4))
  # Map pixel values from [-1, 1] back to [0, 1] for display.
  plt.imshow(image * .5 + .5)
  plt.axis('off')
  plt.show()

  # Index 0 is the uncorrupted image, so it has no corruption level.
  if i > 0:
    print(f"Corruption level: {i}")
  for _, pred_class_name, prob in top_preds:
    print(f"{float(prob):.6f} : {pred_class_name}")

### Adaptation

Here we demonstrate zero-shot out-of-distribution (OOD) detection using the upstream pretrained model and the relative Mahalanobis distance metric ([Ren et al., 2021](http://arxiv.org/abs/2106.09022)). In zero-shot OOD detection, the goal is to take a fixed model that was pretrained on dataset A and use it to distinguish between in-distribution samples from dataset B and OOD samples from dataset C, all without training the model further on dataset B or C. We see that pretrained Plex without any finetuning is able to achieve a strong separation between in and out of distribution.

In [None]:
# Free up RAM.
del probs, batch, params, checkpoint

import gc
gc.collect()

In [None]:
def get_pretrained_config():
  """Builds the config for the upstream pretrained BatchEnsemble ViT.

  Returns:
    An `ml_collections.ConfigDict` whose `model` entry holds the keyword
    arguments that are forwarded to the model constructor.
  """
  # From `https://github.com/google/uncertainty-baselines/blob/main/baselines/jft/experiments/vit_be/imagenet21k_be_vit_large_32.py`.
  # TODO(dusenberrymw): Clean up this config.
  transformer = {
      'mlp_dim': 4096,
      'num_heads': 16,
      'num_layers': 24,
      'attention_dropout_rate': 0.,
      'dropout_rate': 0.1,
      # BatchEnsemble
      'be_layers': (21, 22, 23),
      'ens_size': 3,
      'random_sign_init': -0.5,
      'ensemble_attention': False,
  }
  model = {
      'patches': {'size': [32, 32]},
      'hidden_size': 1024,
      'transformer': transformer,
      'classifier': 'token',
      'representation_size': 1024,
  }
  # Nested dicts are converted into nested ConfigDicts on construction.
  return ml_collections.ConfigDict({'model': model})

In [None]:
# Number of output classes of the upstream pretrained model.
num_classes = 21843
config = get_pretrained_config()
# Every entry of `config.model` is forwarded as a keyword argument.
pretrained_model = ub.models.vision_transformer_be(
    num_classes=num_classes, **config.model)

In [None]:
@jax.jit
def representation_fn(params, images, rng):
  """Computes per-ensemble-member pre-logit representations.

  Args:
    params: Parameters of the pretrained model, as loaded from the checkpoint.
    images: A batch of preprocessed images.
    rng: A `jax.random` key; it is split into the three streams that are passed
      to the model.

  Returns:
    An array of shape (batch_size, ens_size, n_dim) with the `pre_logits`
    output of every ensemble member for every example.
  """
  rng_dropout, rng_diag_noise, rng_standard_noise = jax.random.split(rng, num=3)
  _, out = pretrained_model.apply(
      {'params': flax.core.freeze(params)},
      images,
      train=False,
      rngs={
          'dropout': rng_dropout,
          'diag_noise_samples': rng_diag_noise,
          'standard_norm_noise_samples': rng_standard_noise})
  representations = out["pre_logits"]
  # Split by the ensemble size of the model that produced `representations`
  # (`pretrained_model`), not by that of the separate fine-tuned `model`, which
  # may be undefined or configured differently.
  ens_representations = jnp.stack(
      jnp.split(representations, pretrained_model.transformer.ens_size),
      axis=1)
  return ens_representations  # Shape (batch_size, ens_size, n_dim).

# Create a model function that works across multiple TPU devices or across
# multiple GPUs for performance. The value for `in_axes` means that the `params`
# argument for `representation_fn` will be copied to each device, the `images`
# will be split ("sharded") across the devices along the first axis, and the
# `rng` will be copied to each device. Note that this means that `images` should
# have shape `(num_devices, batch_size/num_devices, h, w, c)` so that each
# device processes a `(batch_size/num_devices, h, w, c)` chunk of the images.
# The `rng` will be the same as in the "Single image" example up above.
pmapped_representation_fn = jax.pmap(representation_fn, in_axes=(None, 0, None))

In [None]:
# Load the upstream pretrained checkpoint from its `gs://` path and pull out
# the model parameters, which are stored under the `opt` -> `target` entries.
checkpoint_path = "gs://plex-paper/plex_vit_large_imagenet21k.npz"
read_in_parallel = False
checkpoint = checkpoint_utils.load_checkpoint(None, path=checkpoint_path,
                                              read_in_parallel=read_in_parallel)
pretrained_params = checkpoint["opt"]["target"]

In [None]:
# TODO(dusenberrymw): Upstream this to the codebase.
@jax.jit
def compute_mean_and_cov(embeds, labels, class_ids):
  """Fits one Gaussian mean per class and a single shared covariance.

  Args:
    embeds: A jnp.array of size [n_train_sample, n_dim], where n_train_sample is
      the sample size of training set, n_dim is the dimension of the embedding.
    labels: A jnp.array of size [n_train_sample, ].
    class_ids: A jnp.array of the unique class ids in `labels`.

  Returns:
    means: A jnp.array of size [n_class, n_dim] whose i-th row is the mean of
      the embeddings labelled with the i-th entry of `class_ids`.
    cov: The shared covariance matrix of size [n_dim, n_dim]: the sum of the
      per-class scatter matrices divided by the number of labels.
  """

  def accumulate_class(scatter, class_id):
    # Column vector that is 1 for rows belonging to `class_id`, 0 otherwise.
    member = (labels == class_id)[..., None]
    class_rows = embeds * member
    class_mean = jnp.sum(class_rows, axis=0) / jnp.sum(member)
    # Mask again so that rows of other classes contribute nothing.
    centered = (class_rows - class_mean) * member
    return scatter + centered.T @ centered, class_mean

  n_dim = embeds.shape[1]
  init_scatter = jnp.zeros((n_dim, n_dim))
  # Scan over classes, carrying the running scatter matrix and stacking the
  # per-class means.
  scatter, means = jax.lax.scan(accumulate_class, init_scatter, class_ids)
  return means, scatter / len(labels)

# TODO(dusenberrymw): Upstream this to the codebase.
@jax.jit
def compute_mahalanobis_distance(embeds, means, cov):
  """Computes Mahalanobis distance between the input and the fitted Gaussians.

  The Mahalanobis distance (Mahalanobis, 1936) is defined as

      `distance(x, mu, sigma) = sqrt((x - mu)^T sigma^{-1} (x - mu))`,

  where `x` is a vector, `mu` is the mean vector for a Gaussian, and `sigma` is
  the covariance matrix. We compute the distance for all examples in `embeds`,
  and across all classes in `means`.

  Note that this function technically computes the squared Mahalanobis distance,
  which is consistent with Eq.(2) in <TODO>.

  Args:
    embeds: A matrix size [n_test_sample, n_dim], where n_test_sample is the
      sample size of the test set, and n_dim is the size of the embeddings.
    means: A matrix of size [num_classes, n_dim], where the ith row corresponds
      to the mean of the fitted Gaussian distribution for the i-th class.
    cov: The shared covariance matrix of the size [n_dim, n_dim].

  Returns:
    A matrix of size [n_test_sample, n_class] where the [i, j] element
    corresponds to the squared Mahalanobis distance between i-th sample to the
    j-th class Gaussian.
  """
  # NOTE: It's possible for `cov` to be singular, in part because it is
  # estimated on a sample of data. This can be exacerbated by lower precision,
  # where, for example, the matrix could be non-singular in float64, but
  # singular in float32. For our purposes in computing Mahalanobis distance,
  # using a pseudoinverse is a reasonable approach that will be equivalent to
  # the inverse if `cov` is non-singular.
  cov_inv = jnp.linalg.pinv(cov)

  def maha_dist(x, mean):
    # NOTE: This computes the squared Mahalanobis distance.
    diff = x - mean
    return jnp.einsum("i,ij,j->", diff, cov_inv, diff)

  # Vectorize over all classes means, and map in a fast loop over examples.
  # Given more memory, one could vectorize over examples as well.
  maha_dist_all_classes_fn = jax.vmap(maha_dist, in_axes=(None, 0))
  out = jax.lax.map(lambda x: maha_dist_all_classes_fn(x, means), embeds)
  return out

In [None]:
def get_and_reshape(x):
  """Fetches `x` to the host and merges its two leading axes into one.

  Args:
    x: An array with at least two axes, e.g. (num_devices, per_device, ...).

  Returns:
    An array of shape (num_devices * per_device, ...).
  """
  flat_shape = (-1,) + x.shape[2:]
  return jnp.reshape(jax.device_get(x), flat_shape)

In [None]:
# https://www.tensorflow.org/datasets/catalog/imagenet_v2
dataset = "imagenet_v2"
tfds.builder(dataset).download_and_prepare()
batch_size = 64 * jax.local_device_count()
split = "test"

pp_eval = "decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key='label', key_result='labels')|keep(['image', 'labels'])"
preprocess_eval_fn = preprocess_spec.parse(
    spec=pp_eval, available_ops=preprocess_utils.all_ops())

val_ds = load_val_ds(dataset, split=split, batch_size=batch_size,
                     preprocess_eval_fn=preprocess_eval_fn)

in_dist_representations = []
in_dist_labels = []
masks = []

# NOTE: given more compute, use the entire dataset instead.
# Use about 1024 examples, but always at least one batch: with a very large
# `batch_size` the integer division would otherwise select zero batches and the
# concatenations below would fail on empty lists.
num_batches = max(1, 1024 // batch_size)
val_ds = val_ds.shuffle(256, seed=42).take(num_batches)
# The same evaluation key is used for every batch, so derive it once.
rng_eval = jax.random.fold_in(rng, 0)
for batch in val_ds.as_numpy_iterator():
  representation = pmapped_representation_fn(pretrained_params, batch["image"],
                                             rng_eval)
  in_dist_representations.append(get_and_reshape(representation))
  masks.append(get_and_reshape(batch["mask"]))
  # Convert the one-hot labels back to integer class ids.
  in_dist_labels.append(get_and_reshape(jnp.argmax(batch["labels"], axis=-1)))

# Keep only the rows whose mask entry is 1.
mask = jnp.concatenate(jax.device_get(masks))
in_dist_representations = jnp.concatenate(in_dist_representations)[mask == 1]
in_dist_labels = jnp.concatenate(in_dist_labels)[mask == 1]
in_dist_representations.shape, in_dist_labels.shape

In [None]:
# For every ensemble member, fit (a) class-conditional Gaussians with a shared
# covariance and (b) a single class-agnostic "background" Gaussian, obtained by
# giving every example the same label 0.
ens_means, ens_covs = [], []
ens_means_background, ens_covs_background = [], []
# These inputs do not depend on the ensemble member, so build them once instead
# of on every iteration.
class_ids = jnp.unique(in_dist_labels)
background_labels = jnp.zeros_like(in_dist_labels)
background_class_ids = jnp.array([0])
for m in range(in_dist_representations.shape[1]):
  member_embeds = in_dist_representations[:, m]
  means, cov = compute_mean_and_cov(member_embeds, in_dist_labels,
                                    class_ids=class_ids)
  ens_means.append(means)
  ens_covs.append(cov)

  means_bg, cov_bg = compute_mean_and_cov(member_embeds, background_labels,
                                          class_ids=background_class_ids)
  ens_means_background.append(means_bg)
  ens_covs_background.append(cov_bg)

In [None]:
# Relative Mahalanobis distance of the in-distribution examples, computed per
# ensemble member and then averaged over the members.
ens_in_dist_rmaha_distances = []
for m in range(len(ens_means)):
  member_embeds = in_dist_representations[:, m]
  distances = compute_mahalanobis_distance(
      member_embeds, ens_means[m], ens_covs[m])
  distances_bg = compute_mahalanobis_distance(
      member_embeds, ens_means_background[m], ens_covs_background[m])
  # Distance to the closest class Gaussian, minus the distance to the single
  # background Gaussian.
  rmaha_distances = distances.min(axis=-1) - distances_bg[:, 0]
  ens_in_dist_rmaha_distances.append(rmaha_distances)

in_dist_rmaha_distances = jnp.array(ens_in_dist_rmaha_distances).mean(axis=0)
del ens_in_dist_rmaha_distances
in_dist_rmaha_distances.shape

In [None]:
# https://www.tensorflow.org/datasets/catalog/fashion_mnist
dataset = "fashion_mnist"
tfds.builder(dataset).download_and_prepare()
batch_size = 64 * jax.local_device_count()
split = "test"

pp_eval = "decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|keep(['image'])"
preprocess_eval_fn = preprocess_spec.parse(
    spec=pp_eval, available_ops=preprocess_utils.all_ops())

val_ds = load_val_ds(dataset, split=split, batch_size=batch_size,
                     preprocess_eval_fn=preprocess_eval_fn)

ood_representations = []
masks = []

# NOTE: given more compute, use the entire dataset instead.
# Use about 1024 examples, but always at least one batch: with a very large
# `batch_size` the integer division would otherwise select zero batches and the
# concatenations below would fail on empty lists.
num_batches = max(1, 1024 // batch_size)
val_ds = val_ds.shuffle(256, seed=42).take(num_batches)
# The same evaluation key is used for every batch, so derive it once.
rng_eval = jax.random.fold_in(rng, 0)
for batch in val_ds.as_numpy_iterator():
  representation = pmapped_representation_fn(pretrained_params, batch["image"],
                                             rng_eval)
  ood_representations.append(get_and_reshape(representation))
  masks.append(get_and_reshape(batch["mask"]))

# Keep only the rows whose mask entry is 1.
mask = jnp.concatenate(masks)
ood_representations = jnp.concatenate(ood_representations)[mask == 1]
ood_representations.shape

In [None]:
# Relative Mahalanobis distance of the OOD examples, computed per ensemble
# member against the Gaussians fitted on the in-distribution data, and then
# averaged over the members.
ens_ood_rmaha_distances = []
for m in range(len(ens_means)):
  member_embeds = ood_representations[:, m]
  distances = compute_mahalanobis_distance(
      member_embeds, ens_means[m], ens_covs[m])
  distances_bg = compute_mahalanobis_distance(
      member_embeds, ens_means_background[m], ens_covs_background[m])
  # Distance to the closest class Gaussian, minus the distance to the single
  # background Gaussian.
  rmaha_distances = distances.min(axis=-1) - distances_bg[:, 0]
  ens_ood_rmaha_distances.append(rmaha_distances)

ood_rmaha_distances = jnp.array(ens_ood_rmaha_distances).mean(axis=0)
del ens_ood_rmaha_distances
ood_rmaha_distances.shape

In [None]:
plt.hist([in_dist_rmaha_distances, ood_rmaha_distances], bins=100, density=True,
         label=["in-dist", "ood"])
plt.legend()
plt.show()

In [None]:
# `import sklearn` alone does not guarantee that the `metrics` submodule is
# loaded, so import it explicitly before use.
import sklearn.metrics

# OOD detection is scored as a binary problem: in-distribution examples are the
# negative class (0), OOD examples are the positive class (1), and the relative
# Mahalanobis distance is the score.
labels = jnp.concatenate((jnp.zeros_like(in_dist_rmaha_distances),
                          jnp.ones_like(ood_rmaha_distances)))
scores = jnp.concatenate((in_dist_rmaha_distances, ood_rmaha_distances))
aucroc = sklearn.metrics.roc_auc_score(labels, scores)
aucroc

## Extras

### Export to TensorFlow for serving, embedded devices, TF.js, etc.

To be announced!