# Hyperparameter Ensembles for Robustness and Uncertainty Quantification

*Florian Wenzel, April 8th 2021. Licensed under the Apache License, Version 2.0.*

Recently, we proposed **Hyper-deep Ensembles** ([Wenzel et al., NeurIPS 2020](https://arxiv.org/abs/2006.13570)), a simple yet powerful extension of [deep ensembles](https://arxiv.org/abs/1612.01474). The approach works with any given deep network architecture and can therefore be easily integrated into (and improve) a machine learning system that is already used in production.

Hyper-deep ensembles improve the performance of a given deep network by forming an ensemble over multiple variants of that architecture where each member uses different hyperparameters. In this notebook, we consider a ResNet-20 architecture with block-wise $\ell_2$-regularization parameters and a label smoothing parameter. We construct an ensemble of 4 members where each member uses a different set of hyperparameters. This leads to an ensemble of **diverse members**, i.e., members that are complementary in their predictions. The final ensemble greatly improves the prediction performance and the robustness of the model, e.g., in out-of-distribution settings.
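To make the ensemble prediction concrete, here is a minimal sketch in plain Python (no TensorFlow needed) of how member predictions are combined: each member's logits are passed through a softmax and the resulting probabilities are averaged, mirroring what the helper `_ensemble_accuracy` below does. The function names `softmax` and `ensemble_probs` and the logit values are made up for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_probs(member_logits):
    """Average the per-member class probabilities (not the logits)."""
    member_probs = [softmax(logits) for logits in member_logits]
    num_members = len(member_probs)
    num_classes = len(member_probs[0])
    return [sum(p[c] for p in member_probs) / num_members
            for c in range(num_classes)]


# Toy logits for a single input and three classes (illustrative values).
member_logits = [
    [2.0, 1.0, 0.0],  # member A favors class 0
    [0.0, 3.0, 0.0],  # member B favors class 1
    [0.5, 2.0, 0.0],  # member C mildly favors class 1
]
probs = ensemble_probs(member_logits)
prediction = max(range(len(probs)), key=probs.__getitem__)
print(probs, prediction)
```

The members disagree on this input, and the averaged distribution reflects that disagreement instead of committing to a single member's confident guess.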

Let's start with some boilerplate code for data loading and the model definition.




Requirements:
```bash
!pip install "git+https://github.com/google/uncertainty-baselines.git#egg=uncertainty_baselines"
```

In [None]:
import os

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

import uncertainty_baselines as ub

In [None]:
def _ensemble_accuracy(labels, logits_list):
  """Compute the accuracy resulting from the ensemble prediction."""
  per_probs = tf.nn.softmax(logits_list)
  probs = tf.reduce_mean(per_probs, axis=0)
  acc = tf.keras.metrics.SparseCategoricalAccuracy()
  acc.update_state(labels, probs)
  return acc.result()

def _ensemble_cross_entropy(labels, logits):
  """Compute the negative log-likelihood of the averaged ensemble prediction."""
  logits = tf.convert_to_tensor(logits)
  ensemble_size = float(logits.shape[0])
  labels = tf.cast(labels, tf.int32)
  ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=tf.broadcast_to(labels[tf.newaxis, ...], tf.shape(logits)[:-1]),
      logits=logits)
  nll = -tf.reduce_logsumexp(-ce, axis=0) + tf.math.log(ensemble_size)
  return tf.reduce_mean(nll)


def greedy_selection(val_logits, val_labels, max_ens_size, objective='nll'):
  """Greedy procedure from Caruana et al. 2004, with replacement."""

  assert_msg = 'Unknown objective type (received {}).'.format(objective)
  assert objective in ('nll', 'acc', 'nll-acc'), assert_msg

  # Objective that should be optimized by the ensemble. Arbitrary objectives,
  # e.g., based on nll, acc or calibration error (or combinations of those) can
  # be used.
  if objective == 'nll':
    get_objective = lambda acc, nll: nll
  elif objective == 'acc':
    # The objective is minimized, so accuracy enters with a negative sign.
    get_objective = lambda acc, nll: -acc
  else:
    get_objective = lambda acc, nll: nll-acc

  best_acc = 0.
  best_nll = np.inf
  best_objective = np.inf
  ens = []

  def get_ens_size():
    return len(set(ens))

  while get_ens_size() < max_ens_size:
    current_val_logits = [val_logits[model_id] for model_id in ens]
    best_model_id = None
    for model_id, logits in enumerate(val_logits):
      acc = _ensemble_accuracy(val_labels, current_val_logits + [logits])
      nll = _ensemble_cross_entropy(val_labels, current_val_logits + [logits])
      obj = get_objective(acc, nll)
      if obj < best_objective:
        best_acc = acc
        best_nll = nll
        best_objective = obj
        best_model_id = model_id
    if best_model_id is None:
      print('Ensemble could not be improved: Greedy selection stops.')
      break
    ens.append(best_model_id)
  return ens, best_acc, best_nll


def parse_checkpoint_dir(checkpoint_dir):
  """Parse directory of checkpoints."""
  paths = []
  subdirectories = tf.io.gfile.glob(os.path.join(checkpoint_dir, '*'))
  is_checkpoint = lambda f: ('checkpoint' in f and '.index' in f)
  print('Load checkpoints')
  for subdir in subdirectories:
    for path, _, files in tf.io.gfile.walk(subdir):
      if any(f for f in files if is_checkpoint(f)):
        latest_checkpoint = tf.train.latest_checkpoint(path)
        paths.append(latest_checkpoint)
        print('.', end='')
        break
  print('')
  return paths

In [None]:
DATASET = 'cifar10'
TRAIN_PROPORTION = 0.95
BATCH_SIZE = 64
ENSEMBLE_SIZE = 4
CHECKPOINT_DIR = 'gs://gresearch/reliable-deep-learning/checkpoints/baselines/cifar/hyper_ensemble/'

In [None]:
# Load data.
ds_info = tfds.builder(DATASET).info
num_classes = ds_info.features['label'].num_classes
# Test set.
steps_per_eval = ds_info.splits['test'].num_examples // BATCH_SIZE 
test_dataset = ub.datasets.get(
      DATASET,
      split=tfds.Split.TEST).load(batch_size=BATCH_SIZE)
# Validation set.
validation_percent = 1 - TRAIN_PROPORTION
val_dataset = ub.datasets.get(
    dataset_name=DATASET,
    split=tfds.Split.VALIDATION,
    validation_percent=validation_percent,
    drop_remainder=False).load(batch_size=BATCH_SIZE)
steps_per_val_eval = int(ds_info.splits['train'].num_examples *
                          validation_percent) // BATCH_SIZE

# Let's construct the hyper-deep ensemble over a ResNet-20 architecture


**This is the (simplified) hyper-deep ensembles construction pipeline**
> **1. Random search:** train several models on the train set using different (random) hyperparameters.
> 
> **2. Ensemble construction:** build the ensemble on a validation set using a greedy selection method.

Remark:
*In this notebook we use a slightly simplified version of the pipeline compared to the approach of the original paper (where an additional stratification step is used). Additionally, after selecting the optimal hyperparameters, the ensemble performance can be improved even more by retraining the selected models on the full train set (i.e., this time not reserving a portion for validation). The simplified pipeline in this notebook is slightly less performant but easier to implement. The simplified pipeline is similar to the ones used by Caruana et al., 2004 and Zaidi et al., 2020 in the context of neural architecture search.*


## Step 1: Random Hyperparameter Search

We start by training 100 different versions of the ResNet-20 using different $\ell_2$-regularization parameters and label smoothing parameters. Since this would take some time, we have already trained the models using a standard training script (which can be found [here](https://github.com/google/uncertainty-baselines/blob/main/baselines/cifar/deterministic.py)) and directly load the checkpoints (which can be browsed [here](https://console.cloud.google.com/storage/browser/gresearch/reliable-deep-learning/checkpoints/baselines/cifar/hyper_ensemble/)).
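For intuition, the random search step can be sketched as drawing one hyperparameter configuration per model. The snippet below is only an illustration: the function name `sample_hyperparameters`, the number of blocks, and the sampling ranges (log-uniform for the $\ell_2$ parameters, uniform for label smoothing) are assumptions and not the settings used to train the provided checkpoints.

```python
import random


def sample_hyperparameters(rng, num_blocks=3):
    """Draw one random configuration (hypothetical ranges)."""
    # Assumed log-uniform range [1e-5, 1e-3] for each block-wise l2 parameter.
    l2_per_block = [10.0 ** rng.uniform(-5.0, -3.0) for _ in range(num_blocks)]
    # Assumed uniform range [0, 0.2] for the label smoothing parameter.
    label_smoothing = rng.uniform(0.0, 0.2)
    return {'l2_per_block': l2_per_block, 'label_smoothing': label_smoothing}


rng = random.Random(0)  # fixed seed so that the draw is reproducible
configs = [sample_hyperparameters(rng) for _ in range(100)]
print(configs[0])
```

Each configuration would then be passed to the training script to produce one member of the model pool.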



In [None]:
# The model architecture we want to form the ensemble over.
# Here, we use the original ResNet-20 model by He et al. 2015.
model = ub.models.wide_resnet(
    input_shape=ds_info.features['image'].shape,
    depth=22,
    width_multiplier=1,
    num_classes=num_classes,
    l2=0.,
    version=1)

In [None]:
# Load checkpoints:
# These are 100 checkpoints and loading will take a few minutes.
ensemble_filenames = parse_checkpoint_dir(CHECKPOINT_DIR)
model_pool_size = len(ensemble_filenames)
checkpoint = tf.train.Checkpoint(model=model)
print('Model pool size: {}'.format(model_pool_size))

Model pool size: 100


## Step 2: Construction of the hyperparameter ensemble on the validation set

First we compute the logits of all models in our model pool on the validation set.

In [None]:
# Compute the logits on the validation set.
val_logits, val_labels = [], []
for m, ensemble_filename in enumerate(ensemble_filenames):
  # Enforce memory clean-up.
  tf.keras.backend.clear_session()
  checkpoint.restore(ensemble_filename)
  val_iterator = iter(val_dataset)
  val_logits_m = []
  for _ in range(steps_per_val_eval):
    inputs = next(val_iterator)
    features = inputs['features']
    labels = inputs['labels']
    val_logits_m.append(model(features, training=False))
    if m == 0:
      val_labels.append(labels)

  val_logits.append(tf.concat(val_logits_m, axis=0))
  if m == 0:
    val_labels = tf.concat(val_labels, axis=0)

  if m % 10 == 0 or m == model_pool_size - 1:
    percent = (m + 1.) / model_pool_size
    message = ('{:.1%} completion for prediction on validation set: '
                'model {:d}/{:d}.'.format(percent, m + 1, model_pool_size))
    print(message)

1.0% completion for prediction on validation set: model 1/100.
11.0% completion for prediction on validation set: model 11/100.
21.0% completion for prediction on validation set: model 21/100.
31.0% completion for prediction on validation set: model 31/100.
41.0% completion for prediction on validation set: model 41/100.
51.0% completion for prediction on validation set: model 51/100.
61.0% completion for prediction on validation set: model 61/100.
71.0% completion for prediction on validation set: model 71/100.
81.0% completion for prediction on validation set: model 81/100.
91.0% completion for prediction on validation set: model 91/100.
100.0% completion for prediction on validation set: model 100/100.


Now we are ready to construct the ensemble.
* In the first step, we take the best model (on the validation set) -> `model_1`.
* In the second step, we fix `model_1` and try all models in our model pool and construct the ensemble `[model_1, model_2]`. We select the model `model_2` that leads to the highest performance gain.
* In the third step, we fix `model_1`, `model_2` and choose `model_3` to construct an ensemble `[model_1, model_2, model_3]` that leads to the highest performance gain over step 2.
* ... and so on, until the desired ensemble size is reached or no further performance gain can be achieved.
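The steps above can be traced on a toy model pool. The following plain-Python sketch follows the same logic as `greedy_selection` with the `'nll'` objective, but works on hand-written class probabilities instead of logits; the names `ensemble_nll` and `greedy_select` and all numbers are made up for illustration.

```python
import math


def ensemble_nll(member_probs, labels):
    """Mean negative log-likelihood of the averaged member probabilities."""
    num_members = len(member_probs)
    total = 0.0
    for i, label in enumerate(labels):
        p = sum(probs[i][label] for probs in member_probs) / num_members
        total -= math.log(p)
    return total / len(labels)


def greedy_select(pool_probs, labels, max_size):
    """Greedily add the model (with replacement) that lowers the NLL most."""
    selected, best = [], math.inf
    while len(set(selected)) < max_size:
        current = [pool_probs[j] for j in selected]
        best_id = None
        for j, probs in enumerate(pool_probs):
            nll = ensemble_nll(current + [probs], labels)
            if nll < best:
                best, best_id = nll, j
        if best_id is None:  # no candidate improves the ensemble
            break
        selected.append(best_id)
    return selected, best


labels = [0, 0, 1, 1]
pool_probs = [
    # model 0: confident on the first two examples, weak on the last two
    [[0.9, 0.1], [0.9, 0.1], [0.55, 0.45], [0.55, 0.45]],
    # model 1: undecided on the first two examples, confident on the last two
    [[0.5, 0.5], [0.5, 0.5], [0.1, 0.9], [0.1, 0.9]],
    # model 2: mediocre everywhere
    [[0.6, 0.4], [0.6, 0.4], [0.4, 0.6], [0.4, 0.6]],
]
selected, best_nll = greedy_select(pool_probs, labels, max_size=2)
print(selected, best_nll)
```

In this toy pool, model 1 is the best single model, and adding the complementary model 0 lowers the validation NLL further, whereas adding the uniformly mediocre model 2 would not.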

In [None]:
# Ensemble construction by greedy member selection on the validation set.
selected_members, val_acc, val_nll = greedy_selection(val_logits, val_labels,
                                                        ENSEMBLE_SIZE,
                                                        objective='nll')
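# Remove duplicates while keeping the selection order, so that the first entry
# is the model that was selected first.
unique_selected_members = list(dict.fromkeys(selected_members))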
message = ('Members selected by greedy procedure: model ids = {} (with {} '
            'unique member(s)).').format(
                selected_members, len(unique_selected_members))
print(message)

Members selected by greedy procedure: model ids = [38, 0, 81, 27] (with 4 unique member(s)).


# Evaluation on the test set

Let's see how the **hyper-deep ensemble** performs on the test set.
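For reference, the ensemble negative log-likelihood is that of the averaged member probabilities. The helper `_ensemble_cross_entropy` evaluates it through a log-sum-exp over the per-member cross-entropies. In the notation below, which is introduced here only for illustration, $M$ is the number of members, $p_m$ is the predictive distribution of member $m$, and $\mathrm{ce}_m$ is its cross-entropy on the example $(x, y)$:

$$
p_{\text{ens}}(y \mid x) = \frac{1}{M} \sum_{m=1}^{M} p_m(y \mid x),
\qquad
\mathrm{ce}_m = -\log p_m(y \mid x),
$$

$$
-\log p_{\text{ens}}(y \mid x)
= -\log \sum_{m=1}^{M} \exp(-\mathrm{ce}_m) + \log M .
$$

The second line corresponds term by term to `-tf.reduce_logsumexp(-ce, axis=0) + tf.math.log(ensemble_size)`.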

In [None]:
# Evaluate the following metrics on the test set.
metrics = {
    'ensemble/negative_log_likelihood': tf.keras.metrics.Mean(),
    'ensemble/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
}
metrics_single = {
    'single/negative_log_likelihood': tf.keras.metrics.SparseCategoricalCrossentropy(),
    'single/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
}


# Compute logits for each ensemble member on the test set.
logits_test = []
for m, member_id in enumerate(unique_selected_members):
  ensemble_filename = ensemble_filenames[member_id]
  checkpoint.restore(ensemble_filename)
  logits = []
  test_iterator = iter(test_dataset)
  for _ in range(steps_per_eval):
    features = next(test_iterator)['features']
    logits.append(model(features, training=False))
  logits_test.append(tf.concat(logits, axis=0))
logits_test = tf.convert_to_tensor(logits_test)
print('Completed computation of member logits on the test set.')

Completed computation of member logits on the test set.


In [None]:
# Compute test metrics.
test_iterator = iter(test_dataset)
for step in range(steps_per_eval):
  labels = next(test_iterator)['labels']
  logits = logits_test[:, (step*BATCH_SIZE):((step+1)*BATCH_SIZE)]
  labels = tf.cast(labels, tf.int32)
  negative_log_likelihood = _ensemble_cross_entropy(labels, logits)
  # Per member output probabilities.
  per_probs = tf.nn.softmax(logits)
  # Ensemble output probabilities.
  probs = tf.reduce_mean(per_probs, axis=0)
  metrics['ensemble/negative_log_likelihood'].update_state(
      negative_log_likelihood)
  metrics['ensemble/accuracy'].update_state(labels, probs)

  # For comparison compute performance of the best single model,
  # this is by definition the first model that was selected by the greedy
  # selection method.
  logits_single = logits_test[0, (step*BATCH_SIZE):((step+1)*BATCH_SIZE)]
  probs_single = tf.nn.softmax(logits_single)
  # The cross-entropy metric expects probabilities (from_logits=False default).
  metrics_single['single/negative_log_likelihood'].update_state(labels, probs_single)
  metrics_single['single/accuracy'].update_state(labels, probs_single)

  percent = (step + 1) / steps_per_eval
  if step % 25 == 0 or step == steps_per_eval - 1:
    message = ('{:.1%} completion final test prediction'.format(percent))
    print(message)

ensemble_results = {name: metric.result() for name, metric in metrics.items()}
single_results = {name: metric.result() for name, metric in metrics_single.items()}

0.6% completion final test prediction
16.7% completion final test prediction
32.7% completion final test prediction
48.7% completion final test prediction
64.7% completion final test prediction
80.8% completion final test prediction
96.8% completion final test prediction
100.0% completion final test prediction


## Here is the final ensemble performance

We gained almost 2 percentage points in terms of accuracy over the best single model!

In [None]:
print('Ensemble performance:')
for m, val in ensemble_results.items():
  print('   {}: {}'.format(m, val))

print('\nFor comparison:')
for m, val in single_results.items():
  print('   {}: {}'.format(m, val))

Ensemble performance:
   ensemble/negative_log_likelihood: 0.19807305932044983
   ensemble/accuracy: 0.9358974099159241

For comparison:
   single/negative_log_likelihood: 1.0325815677642822
   single/accuracy: 0.9189703464508057


## Hyper-deep ensembles as a strong baseline

We have seen that **hyper-deep ensembles** can lead to significant performance gains and can be easily implemented in your existing machine learning pipeline. Moreover, we hope that other researchers can benefit from this by using **hyper-deep ensembles** as a competitive, yet simple-to-implement, baseline. Even though **hyper-deep ensembles** might be more expensive than single-model methods, they can show how much can be gained by introducing more diversity in the predictions.


## Hyper-deep ensembles can make your ML pipeline more robust

**Don't throw away your precious models!**

In many settings where we use a standard (single model) deep neural network, we usually start with a hyperparameter search. Typically, we select the model with the best hyperparameters and throw away all the others. Here, we show that you can get a much more performant system by combining multiple models from the hyperparameter search. 

**What's the additional cost?**

In most cases, you already get a significant performance boost if you combine 4 models. The main additional cost (provided you have already done the hyperparameter search) is that your model is now 4x larger (more memory) and 4x slower at prediction time (if not parallelized). Often the performance boost justifies this increased cost. If you can't afford the additional cost, check out **hyper-batch ensembles**. This is an efficient version that amortizes hyper-deep ensembles **within a single model** (see our [paper](https://arxiv.org/abs/2006.13570)).

## Pointers to additional resources

* The full code for the extended **hyper-deep ensembles** pipeline and the code for the experiments in our paper can be found in the [Uncertainty Baselines](https://github.com/google/uncertainty-baselines/blob/main/baselines/cifar/hyperdeepensemble.py) repository.
* Our efficient version, **hyper-batch ensembles**, which amortizes hyper-deep ensembles within a single model, is implemented as a Keras layer and can be found in [Edward2](https://github.com/google/edward2).

## For questions reach out to
Florian Wenzel ([florianwenzel@google.com](mailto:florianwenzel@google.com)) \
Rodolphe Jenatton ([rjenatton@google.com](mailto:rjenatton@google.com))


### Reference

If you use parts of this pipeline we would be happy if you would cite our paper.

> Florian Wenzel, Jasper Snoek, Dustin Tran and Rodolphe Jenatton (2020).
> [Hyperparameter Ensembles for Robustness and Uncertainty Quantification](https://arxiv.org/abs/2006.13570).
> In _Neural Information Processing Systems_.

```none
@inproceedings{wenzel2020good,
  author = {Florian Wenzel and Jasper Snoek and Dustin Tran and Rodolphe Jenatton},
  title = {Hyperparameter Ensembles for Robustness and Uncertainty Quantification},
  booktitle = {Neural Information Processing Systems},
  year = {2020},
}
```


