# RDL Big Paper Plots

*Licensed under the Apache License, Version 2.0.*

To run this in a public Colab, change the GitHub link: replace github.com with [githubtocolab.com](http://githubtocolab.com).

This colab loads raw measurements from disk and analyzes the results.

## Choosing optimal hyperparameters
We automatically detect hyperparameter sweeps by selecting fields that don't correspond to dataset metrics but that have more than one chosen value. We choose the hyperparameters that achieve the best value according to a given metric (see `dataset_metric`) after averaging over random seeds. For example, if the model is trained on CIFAR-10, we use CIFAR-10's validation loss.

## Plots
All plots report the performance of a given model according to its optimal hyperparameters chosen above. When there are runs with multiple seeds, we show the mean and standard deviation.

In [None]:
from typing import Dict
import itertools
import os
import pickle

import colabtools.fileedit
from importlib import reload
from IPython import display
import matplotlib

matplotlib.rcParams['font.sans-serif'] = "Times New Roman"
matplotlib.rcParams['font.family'] = "sans-serif"
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import seaborn as sns
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf

# Start from `None` so that the check below always takes the fetch branch when
# this cell is run.
colab_utils = None

if colab_utils is None:
  # Clone the uncertainty-baselines repository and copy its plotting helpers
  # next to the notebook so they can be imported as a plain module.
  !rm -rf uncertainty-baselines
  !git clone https://github.com/google/uncertainty-baselines.git
  !cp uncertainty-baselines/experimental/plex/colab_utils.py .
  import colab_utils

# Render figures inline and configure the default resolution and line width.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
matplotlib.rcParams['figure.dpi'] = 1000
matplotlib.rcParams['lines.linewidth'] = 1.25

## Functions

In [None]:
#@title Choosing optimal hyperparameters

# The finetuning deterministic jobs use a fixed random seed but different
# upstream checkpoints, which themselves correspond to different random seeds.
# In this case, we thus marginalize over upstream checkpoints
# (`config.model_init`) rather than the random seed.

# Metric used to select the best hyperparameters for each dataset, in the
# format `{dataset: metric}`.
DATASET_METRIC = dict.fromkeys(
    ('cifar10', 'cifar100', 'imagenet2012', 'imagenet21k', 'jft/entity:1.0.0'),
    'val_loss',
)
DATASET_METRIC.update(
    dict.fromkeys(('retina_country', 'retina_severity'),
                  'in_domain_validation/auroc'))
DATASET_METRIC['imagenet_variants'] = 'imagenet/nll'


def get_optimal_results(measurements: Dict[str, pd.DataFrame],
                        dataset_metric: Dict[str, str] = DATASET_METRIC,
                        verbose=True) -> pd.DataFrame:
  """Collects the tuned results of every model on every dataset it was run on.

  Each model's dataframe is split by dataset and handed to
  `colab_utils.get_tuned_results`, which is asked to tune for that dataset's
  metric. The random-seed column is always marginalized over; some models
  additionally marginalize over one extra hyperparameter (see the tables in
  the body).

  Args:
    measurements: Dictionary mapping a model name to its dataframe of raw
      measurements.
    dataset_metric: Each dataset's metric to tune for, in the format
      `{dataset: metric}`.
    verbose: Forwarded as-is to `colab_utils.get_tuned_results`.

  Returns:
    The concatenation of all tuned dataframes. Model/dataset pairs for which a
    `KeyError` is raised while tuning (e.g., the dataset has no entry in
    `dataset_metric`) are reported on stdout and skipped.
  """
  # Models that are also marginalized over the upstream checkpoint.
  init_marginalized = ('Det', 'Det I21K', 'DE', 'DE S/32', 'DE B/32',
                       'DE L/32', 'Det->DE', '[Det]_4', 'Det->[Det]_4',
                       'Det->BE')
  # Models that are also marginalized over `config.dune_experts.xid_wid`.
  expert_marginalized = ('MoE', 'E^3', '[MoE]_4', 'MoE->[MoE]_4')

  extra_hparam = dict.fromkeys(init_marginalized, 'config.model_init')
  extra_hparam.update(
      dict.fromkeys(expert_marginalized, 'config.dune_experts.xid_wid'))

  tuned = []
  for model_name, model_df in measurements.items():
    hparams = (colab_utils.random_seed_col(),)
    if model_name in extra_hparam:
      hparams = hparams + (extra_hparam[model_name],)

    for dataset in model_df[colab_utils.dataset_col()].unique():
      subset = model_df[model_df[colab_utils.dataset_col()] == dataset]
      try:
        best = colab_utils.get_tuned_results(
            subset,
            tuning_metric=dataset_metric[dataset],
            marginalization_hparams=hparams,
            verbose=verbose)
      except KeyError:
        print(f'Could not get optimal results for {model_name}, {dataset}.')
      else:
        tuned.append(best)
    # Blank line between models in the printed log.
    print()
  return pd.concat(tuned)


def get_optimal_fewshot_results(measurements: Dict[str, pd.DataFrame],
                                verbose=True) -> pd.DataFrame:
  """Returns a dataframe, typically with one result per model type.

  A model type may have multiple results that will be averaged over when
  plotting (e.g., random seeds).

  For each model and each value of the dataset column, the best
  `config.l2_reg` is selected separately for every number of shots
  (1, 5, 10, 25). The rows with that `config.l2_reg` are then emitted once per
  few-shot dataset, with `config.dataset` rewritten to
  `'few-shot <fewshot_ds>'` and the metric columns renamed to
  `'<shot>shot_<metric>'`.

  Args:
    measurements: Dictionary of dataframes to obtain best results for.
    verbose: Not referenced in the function body.

  Returns:
    The concatenation of the per-model, per-dataset results. Model/dataset
    pairs for which a `KeyError` is raised are reported on stdout and skipped.
  """
  results = []
  for k, v in measurements.items():
    # Marginalize over both the random seed and the upstream checkpoint.
    marginalization_hparams = (colab_utils.random_seed_col(),)
    marginalization_hparams += ('config.model_init',)
    for ds in v[colab_utils.dataset_col()].unique():
      df = v[v[colab_utils.dataset_col()] == ds]
      try:
        # Gets the model and dataset and standard hps.
        # NOTE(review): `dataset` and `model` are not used below; presumably
        # the calls are kept to check that each value is unique -- confirm.
        dataset = colab_utils.get_unique_value(df, colab_utils._DATASET_COL)
        model = colab_utils.get_unique_value(df, colab_utils._MODEL_COL)
        hps = colab_utils.get_sweeped_hyperparameters(df, marginalization_hparams)

        # Finds the best l2 reg for each shot experiment.
        best_l2 = {}
        # Columns without a '/' are treated as non-metric columns.
        non_metric_cols = [c for c in df.columns if '/' not in c]
        dfs_optimal = []
        for shot in [1, 5, 10, 25]:
          tuning_metrics = [c for c in df.columns if c.endswith(f'_{shot}/test_prec@1')]
          # Average each tuning metric over the marginalized hyperparameters.
          marginalized_df = df.groupby(hps)[tuning_metrics].agg('mean').reset_index()
          reg_ranks = []
          for tuning_metric in tuning_metrics:
            reg_accus = marginalized_df[tuning_metric].to_numpy()
            # Double argsort turns the values into ranks (0 = lowest value).
            reg_ranks.append(np.argsort(np.argsort(reg_accus)))
          # Pick the l2 value of the row with the highest mean rank across all
          # tuning metrics.
          best_l2[shot] = marginalized_df['config.l2_reg'][np.argmax(np.mean(reg_ranks, axis=0))]

          for fewshot_ds in colab_utils.default_fewshot_datasets():
            ds_shot_specific_metric_cols = [c for c in df.columns if f'{fewshot_ds}_{shot}/' in c]
            # Keep only the rows trained with the selected l2 value.
            dfc = df[df['config.l2_reg']==best_l2[shot]][non_metric_cols + ds_shot_specific_metric_cols].copy()
            # E.g. '<fewshot_ds>_<shot>/test_prec@1' -> '<shot>shot_test_prec@1'.
            dfc = dfc.rename(columns={m: str(shot) + 'shot_' + m.split('/')[1] for m in ds_shot_specific_metric_cols})
            dfc['config.dataset'] = f'few-shot {fewshot_ds}'
            dfs_optimal.append(dfc)
        results.append(pd.concat(dfs_optimal))
      except KeyError:
        print(f'Could not get optimal results for {k}, {ds}.')
  return pd.concat(results)

In [None]:
#@title Pretty printing 

def pprint(df, models=None, exclude_models=None):
  """Displays a formatted, highlighted version of a results dataframe.

  Column labels are rewritten into human-readable names, values are formatted
  according to their metric, the column levels are swapped, and the table is
  transposed (so the input's index ends up as the columns) before the best
  cell of every row is highlighted.

  Args:
    df: Dataframe whose columns are two-level labels (the first level is used
      to pick the value formatter).
    models: Optional list of models to only show. Useful for comparing specific
      models to see which performs better (highlighted cells).
    exclude_models: Optional list of models to exclude. Ignored when `models`
      is given.

  Returns:
    Whatever `display.display` returns for the styled table.
  """
  # Applied in order; later rules see the output of earlier ones (e.g. 'ood'
  # is stripped before 'negative log likelih' is matched).
  substitutions = (
      ('cifar_10h', 'cifar10h'),
      ('places365_small', 'places365'),
      ('imagenet_', 'imagenet-'),
      ('/mean', ''),
      ('/', ' '),
      ('_', ' '),
      ('cropped ', ''),
      ('ood', ''),
      ('ece', 'ECE'),
      ('auc', 'AUC'),
      ('auroc', 'AUROC'),
      ('loss', 'NLL'),
      ('negative log likelih', 'NLL'),
      ('nll', 'NLL'),
      ('brier', 'Brier'),
      ('mce', 'mCE'),
      ('pmk', 'p-mk'),
  )
  # Metrics for which the smallest value in a row is highlighted.
  lower_is_better = ('NLL', 'ECE', 'Brier', 'mCE', 'relative mCE',
                     'accuracy drop', 'accuracy pm-k')
  # Compute-cost metrics: formatted with one decimal, never highlighted.
  compute_metrics = ('exaflops', 'tpu days', 'gflops', 'ms step')

  def _mentions(text, needles):
    return any(needle in text for needle in needles)

  def _rename(label):
    for old, new in substitutions:
      label = label.replace(old, new)
    return label

  def _formatter(metric):
    if _mentions(metric, ('AUROC', 'AUC')):
      return '{:.2f}'.format
    if _mentions(metric, ('prec', 'ECE', 'accuracy')):
      return lambda value: '{:.1f}%'.format(value * 100)
    if _mentions(metric, ('score',) + compute_metrics):
      return lambda value: '{:.1f}'.format(value)
    if _mentions(metric, ('NLL', 'Brier')):
      return '{:.3f}'.format
    return lambda value: value

  def _highlight(data, color='#90EE90'):
    css = 'background-color: {}'.format(color)
    # Cells were already formatted as strings; strip '%' to compare as floats.
    values = data.replace('%', '', regex=True).astype(float)
    metric = values.name[1]
    if _mentions(metric, lower_is_better):
      is_best = values == values.min()
    elif _mentions(metric, compute_metrics):
      # Nothing is highlighted for compute-cost rows.
      is_best = pd.Series(False, index=values.index)
    else:
      is_best = values == values.max()
    return [css if flag else '' for flag in is_best]

  table = df.copy().rename(columns=_rename)
  for col in table:
    table[col] = table[col].apply(_formatter(col[0]))

  # Swap order of column's multiindex to be dataset first.
  table.columns = table.columns.swaplevel(0, 1)
  table = table.sort_index(axis=1, level=0).T

  if models is not None:
    table = table[[c for c in table.columns if c in models]]
  elif exclude_models is not None:
    table = table[[c for c in table.columns if c not in exclude_models]]

  return display.display(table.style.apply(_highlight, axis=1))

In [None]:
#@title RETINA
# When True, the RETINA results are re-fetched from Weights & Biases and
# written back to the cache file; this requires the `wandb` package.
REBUILD_RETINA_RESULTS_CACHE = False

if REBUILD_RETINA_RESULTS_CACHE:
  import os
  os.system('pip install wandb')
  import wandb

# TODO(nband): add grid search results (currently random search).
# Maps (shift, uncertainty quantification method) to the Weights & Biases
# project that is queried for its runs.
RETINA_SHIFT_AND_UQ_METHOD_TO_WANDB = {
  ('aptos', 'deterministic'): 'vit32-finetune-aptos-deterministic-focused-3',
  ('aptos', 'batchensemble'): 'vit32-finetune-aptos-batchensemble',
  ('severity', 'deterministic'): 'vit32-finetune-severity-deterministic',
  ('severity', 'batchensemble'): 'vit32-finetune-severity-batchensemble-focused-1'
}

RETINA_SHIFTS = ['aptos', 'severity']
RETINA_UQ_METHODS = ['deterministic', 'batchensemble']
# Row label under which each method's results are stored in the results table.
RETINA_UQ_METHOD_TO_DF_NAME = {
    'deterministic': 'Det I21K',
    'batchensemble': 'BE L/32 (I21K)'
}

# Metrics copied into the results table for each shift. The part after the '.'
# is used as the metric name in the table.
# NOTE(review): the last OOD metric differs between the two shifts
# ('retention_auroc_auc' vs. 'retention_accuracy_auc') -- confirm this is
# intentional.
RETINA_SHIFT_TO_METRICS = {
  'aptos': [
    # In-Domain
    'in_domain_test.in_domain_test/accuracy',
    'in_domain_test.in_domain_test/negative_log_likelihood',
    'in_domain_test.in_domain_test/ece',
    'in_domain_test.in_domain_test/retention_auroc_auc',
    # OOD
    'ood_test.ood_test/accuracy',
    'ood_test.ood_test/negative_log_likelihood',
    'ood_test.ood_test/ece',
    'ood_test.ood_test/retention_auroc_auc'
  ],
  'severity': [
    # In-Domain
    'in_domain_test.in_domain_test/accuracy',
    'in_domain_test.in_domain_test/negative_log_likelihood',
    'in_domain_test.in_domain_test/ece',
    'in_domain_test.in_domain_test/retention_auroc_auc',
    # OOD
    'ood_test.ood_test/accuracy',
    'ood_test.ood_test/negative_log_likelihood',
    'ood_test.ood_test/ece',
    'ood_test.ood_test/retention_accuracy_auc'
  ]
}
# Metric whose maximum selects the best run/step of a project.
RETINA_MODEL_SELECTION_METRIC = 'in_domain_validation.in_domain_validation/auroc'

# Split RETINA results into the two distributional shifts: Country Shift and
# Severity Shift.

# Maps the shift identifier to the suffix used in the `retina_<suffix>`
# dataset names.
SHIFT_MAP = {'aptos': 'country', 'severity': 'severity'}


def select_top_model_from_project(project_name):
  """Returns the best logged step across all runs of a W&B project.

  The full history of every run in the project is downloaded and stacked into
  one dataframe (each row tagged with its run's name in `run_name`). The row
  with the highest `RETINA_MODEL_SELECTION_METRIC` is returned.

  Args:
    project_name: Name of the Weights & Biases project to query.

  Returns:
    A `pd.Series` holding the selected history row.
  """
  def _history_with_name(run):
    history = pd.DataFrame(run._full_history())
    history['run_name'] = run.name
    return history

  runs = wandb.Api(timeout=100000000).runs(project_name)
  print(f'Retrieved run results from Weights & Biases project {project_name}.')

  # Stack all run histories; the fresh index makes row labels equal to row
  # positions, so the `idxmax` label can be used with `iloc`.
  all_steps = pd.concat([_history_with_name(run) for run in runs])
  all_steps = all_steps.reset_index()

  # Best performing step of the best performing model
  best_position = all_steps[RETINA_MODEL_SELECTION_METRIC].idxmax()
  return all_steps.iloc[best_position]


def get_retina_i21k_results_df():
  """Fetches the best W&B result for every (shift, UQ method) combination.

  Returns:
    A dataframe with one row per combination of `RETINA_SHIFTS` and
    `RETINA_UQ_METHODS`, each tagged with `shift` and `uq_method` columns.
  """
  frames = []
  for shift, uq_method in itertools.product(RETINA_SHIFTS, RETINA_UQ_METHODS):
    print(f'Retrieving results from shift {shift}, '
          f'uncertainty quantification method {uq_method}.')
    project = RETINA_SHIFT_AND_UQ_METHOD_TO_WANDB[(shift, uq_method)]
    # The selected row is a Series; turn it into a one-row dataframe.
    top_row = select_top_model_from_project(project).to_frame().T
    top_row['shift'] = shift
    top_row['uq_method'] = uq_method
    frames.append(top_row)

  return pd.concat(frames)


def add_retina_i21k_results(retina_results_df, preprocessed_df, shift_map=SHIFT_MAP):
  """Writes the cached RETINA I21K results into the preprocessed results table.

  For every (shift, UQ method) combination, the single matching row of
  `retina_results_df` is looked up and each metric in
  `RETINA_SHIFT_TO_METRICS[shift]` is written into the
  `(<metric>, 'retina_<shift name>')` column of `preprocessed_df`, at the row
  labelled `RETINA_UQ_METHOD_TO_DF_NAME[uq_method]`.

  Args:
    retina_results_df: Dataframe with `shift` and `uq_method` columns and one
      column per metric.
    preprocessed_df: Results table to update; it is modified in place.
    shift_map: Maps a shift identifier to the suffix of its dataset name.

  Returns:
    `preprocessed_df`, after the update.

  Raises:
    AssertionError: If a combination does not match exactly one row.
  """
  for shift, uq_method in itertools.product(RETINA_SHIFTS, RETINA_UQ_METHODS):
    print(f'Adding results from shift {shift}, '
          f'uncertainty quantification method {uq_method}.')
    selector = ((retina_results_df['shift'] == shift) &
                (retina_results_df['uq_method'] == uq_method))
    matches = retina_results_df[selector]
    n_results = len(matches)
    assert n_results == 1, f'Found {n_results} model results, expected 1.'
    best_run = matches.iloc[0]

    for metric in RETINA_SHIFT_TO_METRICS[shift]:
      # 'a.b/c' -> 'b/c', the metric name used in the results table.
      table_metric = metric.split('.')[1]
      value = best_run[metric]
      column_key = (table_metric, f'retina_{shift_map[shift]}')
      # Read the column, set this model's entry, then write the column back.
      column = preprocessed_df[column_key]
      column[RETINA_UQ_METHOD_TO_DF_NAME[uq_method]] = value
      preprocessed_df[column_key] = column

  return preprocessed_df

# Optionally refresh the cached RETINA results file from Weights & Biases.
if REBUILD_RETINA_RESULTS_CACHE:
  # Retrieve RETINA I21K results from Weights & Biases
  retina_i21k_results_df = get_retina_i21k_results_df()

  # Store RETINA results in gs bucket
  retina_ub_gs_file_path = 'gs://retina-i21k-results-df/retina-i21k-results.tsv'
  with tf.io.gfile.GFile(retina_ub_gs_file_path, 'w') as f:
    # Tab-separated, without the index column.
    retina_i21k_results_df.to_csv(f, sep='\t', index=None)


def add_distribution_shift_to_retina_ds_name(row):
  """Renames a 'retina' row's dataset to 'retina_<shift name>'.

  Args:
    row: Mapping-like row with a 'config.dataset' entry and, for 'retina' rows,
      a 'config.distribution_shift' entry that is a key of `SHIFT_MAP`.

  Returns:
    The same `row` object; it is modified in place only when its dataset is
    'retina'.
  """
  name = str(row['config.dataset'])
  if name != 'retina':
    return row

  suffix = SHIFT_MAP[str(row['config.distribution_shift'])]
  row['config.dataset'] = f'{name}_{suffix}'
  return row

def split_retina_results_by_shifts(raw_dict):
  """Splits each model's 'retina' rows into per-shift dataset names.

  Args:
    raw_dict: Dictionary mapping a model name to its dataframe of raw
      measurements; each dataframe has a 'config.dataset' column.

  Returns:
    The same `raw_dict` object. Entries whose dataframe contains 'retina' rows
    are replaced by a dataframe with the renamed datasets; all other entries
    are left untouched.
  """
  for model, frame in raw_dict.items():
    has_retina_rows = (frame['config.dataset'] == 'retina').any()
    if has_retina_rows:
      print(f'Splitting RETINA results for model {model} by distribution shift.')
      # Row-wise rename; replacing the value of an existing key is safe while
      # iterating over the dictionary.
      raw_dict[model] = frame.apply(
          add_distribution_shift_to_retina_ds_name, axis='columns')

  return raw_dict

## Load and preprocess measurements

In [None]:
# Download the raw measurement files from cloud storage into /tmp.
# NOTE(review): the `*_path` variables are only defined inside this branch, so
# the loading cell below needs them to be set some other way when
# `load_from_cloud` is False.
load_from_cloud = True
if load_from_cloud == True:
  from google.colab import auth
  auth.authenticate_user()

  project_id = 'marginalization-external-xgcp'
  !gcloud config set project {project_id}

  # Main measurements.
  measurements_path = '/tmp/big-paper-raw-measurements.pkl'
  !gsutil cp gs://ub-checkpoints/big-paper-raw-measurements.pkl {measurements_path}

  # Cached RETINA I21K results.
  retina_path = '/tmp/retina-i21k-results.tsv'
  !gsutil cp gs://retina-i21k-results-df/retina-i21k-results.tsv {retina_path}

  # Few-shot measurements.
  fewshot_measurements_path = '/tmp/big-paper-raw-measurements-fewshot.pkl'
  !gsutil cp gs://ub-checkpoints/big-paper-raw-measurements-fewshot.pkl {fewshot_measurements_path}

In [None]:
# Load the downloaded files.
# NOTE(review): `pickle.load` executes arbitrary code from the file; only use
# it on files from a trusted source.
with tf.io.gfile.GFile(measurements_path, 'rb') as f:
  raw_measurements = pickle.load(f)

# The RETINA results are stored as a tab-separated file.
with tf.io.gfile.GFile(retina_path, 'r') as f:
  retina_i21k_results_df = pd.read_csv(f, sep='\t')

with tf.io.gfile.GFile(fewshot_measurements_path, 'rb') as f:
  fewshot_raw_measurements = pickle.load(f)

In [None]:
# Rename 'retina' datasets to 'retina_<shift>' before tuning.
raw_measurements = split_retina_results_by_shifts(raw_measurements)

# These raw entries are not used directly; the ensemble entries are re-added
# below from the L/32 measurements, filtered by ensemble size.
excluded_keys = [
    'DE', 'Det->DE', 'DE S/32', 'Det->DE S/32', 'DE B/32', 'Det->DE B/32',
    'DE L/32', 'Det->DE L/32', 'Det -> BE L/32 (n=2)', 'Det -> BE L/32 (n=4)',
    'Det -> BE L/32 (n=8)'
]
included_measurements = {
    k: v for k, v in raw_measurements.items() if k not in excluded_keys
}
included_measurements['DE'] = raw_measurements['DE L/32'].query(
    'ensemble_size == 3')
included_measurements['Det->DE'] = raw_measurements['Det->DE L/32'].query(
    'ensemble_size == 3')
# We fetch the deep ensembles of size 4 to compare with MoEs also of size 4.
# In that case, we follow the terminology [MoE]_4 and use [Det]_4. We keep DE to
# refer to the deep ensemble used everywhere else in the paper (size 3).
included_measurements['[Det]_4'] = raw_measurements['DE L/32'].query(
    'ensemble_size == 4')
# NOTE(review): this assigns into the result of `query`; pandas may emit a
# SettingWithCopyWarning here -- confirm the assignment takes effect.
included_measurements['[Det]_4'].loc[:, 'model'] = '[Det]_4'

included_measurements['Det->[Det]_4'] = raw_measurements['Det->DE L/32'].query(
    'ensemble_size == 4')
included_measurements['Det->[Det]_4'].loc[:, 'model'] = 'Det->[Det]_4'

# Tune hyperparameters per model and dataset.
measurements = get_optimal_results(included_measurements)

# Build the results table and add the cached RETINA I21K results to it.
df = colab_utils.process_tuned_results(measurements)
df = add_retina_i21k_results(
    retina_results_df=retina_i21k_results_df, preprocessed_df=df)

In [None]:
# Gets tuned fewshot measurements.
fewshot_measurements = get_optimal_fewshot_results(fewshot_raw_measurements)

# Prepares fewshot measurement to inject to df.
# Only the columns whose name contains 'shot' are kept as metrics.
relevant_metrics = [c for c in fewshot_measurements.columns if 'shot' in c]
fewshot_df = colab_utils.process_tuned_results(fewshot_measurements, relevant_metrics)

# Add the fewshot results for the comparison with sparse MoE's.
moe_fewshot_df = colab_utils.process_fewshot_for_moe_comparison(included_measurements)

# Stack the two few-shot tables row-wise.
fewshot_df = pd.concat((fewshot_df, moe_fewshot_df))

# Removes upstream fewshot results.
fewshot_metrics_to_del = [m for m in df.columns.levels[0] if 'shot' in m]
df = df.drop(columns=fewshot_metrics_to_del, level=0)
# Dropping columns leaves stale entries in the column levels; prune them.
df.columns = df.columns.remove_unused_levels()

# Adds fewshot results from fewshot_df.
df = pd.concat([df, fewshot_df], axis=1)

## Compute reliability score and generate table

In [None]:
# Datasets that enter the reliability score; the commented-out entries are
# currently excluded.
datasets = [
    'cifar10',
    'cifar100',
    'imagenet2012',
    # 'imagenet_variants',
    # 'retina_country',
    # 'retina_severity',
]
datasets += [f'few-shot {d}' for d in colab_utils.default_fewshot_datasets()]

# The scores are multiplied by 100 for display.
scores = colab_utils.compute_score(
    df, datasets=datasets, drop_1shot=True,
    drop_incomplete_measurements=False) * 100

score_cols = [
    'score', 'score_prediction', 'score_uncertainty', 'score_adaptation'
]
display.display(scores[score_cols])

In [None]:
# Attach the score columns to a copy of the results table and display it.
df_with_scores = df.copy()
for column in score_cols:
  df_with_scores[column] = scores[column]

pprint(
    df_with_scores,
    # models=['BE L/32', 'Det'],
    # exclude_models=['DE', 'Det->DE'],
)

In [None]:
# Show a subset of the table's metrics + models
metrics = ['score', 'score_prediction', 'score_uncertainty', 'score_adaptation',
           'exaflops', 'test_loss', 'tpu_days']
models = ['BE L/32', 'Det', 'GP', 'Het', 'BE L/32 (I21K)', 'Det I21K',
          'BE->BE+Het', 'E^3', '[Det]_4', '[MoE]_4']
# NOTE(review): the 'compute' label is renamed to 'z/compute', presumably so
# that it sorts last in the printed table -- confirm.
pprint(df_with_scores.loc[models][metrics].rename(
    columns={'compute': 'z/compute'}))

## Plot reliability score

In [None]:
def pareto_plot(df, x, y, ax, filename=None, **kwargs):
  """Scatter-plots `y` against `x` with labels and a dashed Pareto frontier.

  A point is on the frontier when no other point has an x value that is less
  than or equal to its own together with a strictly larger y value.

  Args:
    df: Dataframe with one row per model; the index holds the model names used
      as point labels.
    x: Column label for the x axis (drawn on a log scale).
    y: Column label for the y axis.
    ax: Matplotlib axes to draw on.
    filename: If given, the current figure is saved under this name and
      offered for download.
    **kwargs: Forwarded to `ax.set` (e.g., `xlabel`, `ylabel`, `title`).
  """
  def is_dominated(candidate, points, higher_is_better):
    # A rival dominates when its x is no larger and its y is strictly better.
    if higher_is_better:
      return any(rival[0] <= candidate[0] and rival[1] > candidate[1]
                 for rival in points)
    return any(rival[0] <= candidate[0] and rival[1] < candidate[1]
               for rival in points)

  def get_pareto_points(xs, ys, higher_is_better=True):
    points = list(zip(xs, ys))
    survivors = [
        p for p in points if not is_dominated(p, points, higher_is_better)
    ]
    survivors.sort(key=lambda p: p[0])
    return survivors

  # Label every point with its model name.
  for model, row in df.iterrows():
    ax.annotate('  ' + model, xy=(row[x], row[y]), ha='left', va='bottom')

  sns.scatterplot(x=df[x], y=df[y], ax=ax)
  frontier_x, frontier_y = zip(*get_pareto_points(df[x], df[y]))
  sns.lineplot(x=frontier_x, y=frontier_y, linestyle='--', ax=ax)
  ax.set(xscale='log', **kwargs)

  if filename is None:
    return
  plt.tight_layout()
  plt.savefig(filename)
  colabtools.fileedit.download_file(filename)

# All Pareto plots below only show the models whose name starts with 'BE'.
be_rows = df_with_scores[
    [name.startswith('BE') for name in df_with_scores.index.values]]
compute_col = ('tpu_days', 'compute')

# Overall reliability score against compute.
fig, ax = plt.subplots(figsize=(10.0, 5.0))
pareto_plot(
    be_rows,
    x=compute_col,
    y='score',
    ax=ax,
    xlabel='Compute (TPUv3 core days)',
    ylabel='Reliability Score',
    filename='reliability.png',
)

# One panel per score component, saved together as a single figure.
fig, axes = plt.subplots(1, 3, figsize=(3.5 * 3, 3.5))
component_panels = (
    ('score_prediction', 'Reliability Score (Prediction)'),
    ('score_uncertainty', 'Reliability Score (Uncertainty)'),
    ('score_adaptation', 'Reliability Score (Adaptation)'),
)
for panel, (score_col, panel_title) in zip(axes, component_panels):
  pareto_plot(
      be_rows,
      x=compute_col,
      y=score_col,
      ax=panel,
      xlabel=None,
      ylabel=None,
      title=panel_title,
  )
filename = 'reliability_components.png'
plt.tight_layout()
plt.savefig(filename)
colabtools.fileedit.download_file(filename)

## Analyze correlation of metrics

In [None]:
# Rebuild the results table with the training metrics included, and attach the
# scores computed on complete measurements only.
temp_df = colab_utils.process_tuned_results(
    measurements,
    relevant_metrics=colab_utils.default_selected_metrics() +
    ['training_loss', 'training_prec@1'])
datasets = [
    'cifar10',
    'cifar100',
    'imagenet2012',
]
datasets += [f'few-shot {d}' for d in colab_utils.default_fewshot_datasets()]
temp_scores = colab_utils.compute_score(
    temp_df,
    datasets=datasets,
    drop_1shot=True,
    drop_incomplete_measurements=True)
for column in score_cols:
  temp_df[column] = temp_scores[column]

# scores correlation matrix
columns = ['score', 'score_prediction', 'score_uncertainty', 'score_adaptation']
corr_matrix = temp_df[columns]
# Flatten the two-level column labels into plain strings.
corr_matrix.columns = [''.join(col) for col in corr_matrix.columns.values]
corr_matrix = corr_matrix.corr()
display.display(corr_matrix)

# upstream test metrics
metrics = ['score', 'score_prediction', 'score_uncertainty', 'score_adaptation']
# Keep only the rows for the 'jft/entity:1.0.0' dataset.
corr_matrix = temp_df.corr()[['test_loss', 'test_prec@1']].T.xs(
    'jft/entity:1.0.0', level='dataset')
corr_matrix = corr_matrix[metrics]
corr_matrix.columns = [''.join(col) for col in corr_matrix.columns.values]
display.display(corr_matrix)

# imagenet 10-shot. It doesn't correlate well with reliability, mostly due to
# it not correlating well surprisingly on other few-shot tasks.
corr_matrix = temp_df.corr()[['10shot_prec@1']].T.xs(
    'few-shot imagenet', level='dataset')
corr_matrix = corr_matrix[metrics]
corr_matrix.columns = [''.join(col) for col in corr_matrix.columns.values]
display.display(corr_matrix)

# downstream training loss. The correlation is not nearly as tight as on
# upstream.
corr_matrix = temp_df.corr()[['training_loss']].T
corr_matrix = corr_matrix[metrics + ['test_loss']]
# Drop the datasets that are not part of this comparison.
corr_matrix = corr_matrix.drop(index=('training_loss', 'retina_country'))
corr_matrix = corr_matrix.drop(index=('training_loss', 'retina_severity'))
corr_matrix = corr_matrix.drop(index=('training_loss', 'imagenet21k'))
corr_matrix = corr_matrix.drop(columns=('test_loss', 'imagenet21k'))
# Display test loss only for training loss' same downstream dataset. Looking at
# cifar10's train loss correlation with I1K's test loss isn't meaningful.
# The diagonal pairs each dataset's training loss with its own test loss; this
# relies on the rows and the 'test_loss' columns being in the same dataset
# order.
test_loss = pd.Series(
    np.diag(corr_matrix['test_loss']), index=corr_matrix['test_loss'].index)
corr_matrix = corr_matrix.drop(columns='test_loss')
corr_matrix['test_loss'] = test_loss
corr_matrix.columns = [''.join(col) for col in corr_matrix.columns.values]
display.display(corr_matrix)

# Similar to old plot in go/rdl-big-meeting, even generalization gap decreases.
# And downstream is not very indicative, but upstream is.
temp_df2 = temp_df.copy()
# 'reg_loss' is the gap between test loss and training loss, per dataset.
for d in temp_df2['test_loss'].columns:
  temp_df2['reg_loss',
           d] = temp_df2['test_loss', d] - temp_df2['training_loss', d]

corr_matrix = temp_df2.corr()[['reg_loss']].T
corr_matrix = corr_matrix[metrics + ['training_loss']]
corr_matrix = corr_matrix.drop(index=('reg_loss', 'imagenet21k'))
display.display(corr_matrix)

In [None]:
# Correlation of every metric with the test loss, test precision@1 and
# training loss of the 'jft/entity:1.0.0' dataset.
corr_matrix = temp_df.corr()[['test_loss', 'test_prec@1', 'training_loss']].T.xs('jft/entity:1.0.0', level='dataset')

# Rename certain task metrics to be under their generic metric name. This way,
# we can average values across that metric.
# The columns are first flattened to plain tuples so that whole
# (metric, dataset) pairs can be renamed, then rebuilt into a MultiIndex.
corr_matrix.columns = corr_matrix.columns.values
corr_matrix.columns = pd.MultiIndex.from_tuples(corr_matrix.rename(columns={
    ('imagenet_real_calib_auc', 'imagenet2012'): ('test_calib_auc', 'imagenet_real'),
    ('imagenet_real_ece', 'imagenet2012'): ('test_ece', 'imagenet_real'),
    ('imagenet_real_loss', 'imagenet2012'): ('test_loss', 'imagenet_real'),
    ('imagenet_real_prec@1', 'imagenet2012'): ('test_prec@1', 'imagenet_real'),
    ('cifar_10h_calib_auc', 'cifar10'): ('test_calib_auc', 'cifar_10h'),
    ('cifar_10h_ece', 'cifar10'): ('test_ece', 'cifar_10h'),
    ('cifar_10h_loss', 'cifar10'): ('test_loss', 'cifar_10h'),
    ('cifar_10h_prec@1', 'cifar10'): ('test_prec@1', 'cifar_10h'),
    ('ood_cifar100_msp_auroc', 'cifar10'): ('msp_auroc', 'cifar10->cifar100'),
    ('ood_cifar10_msp_auroc', 'cifar100'): ('msp_auroc', 'cifar100->cifar10'),
    ('ood_places365_small_msp_auroc', 'imagenet2012'): ('msp_auroc', 'imagenet2012->places365'),
    ('ood_svhn_cropped_msp_auroc', 'cifar10'): ('msp_auroc', 'cifar10->svhn'),
    ('ood_svhn_cropped_msp_auroc', 'cifar100'): ('msp_auroc', 'cifar100->svhn'),
}))

corr_matrix = corr_matrix.sort_index(axis=1)
# Average over datasets within each metric name. `DataFrame.mean(level=...)`
# was removed in pandas 2.0; grouping the transposed frame by its first index
# level is the equivalent that works on both old and new pandas versions
# (`sort=False` keeps the groups in order of appearance).
corr_matrix = corr_matrix.T.groupby(level=0, sort=False).mean().T
corr_matrix = abs(corr_matrix)
# Order the metrics by their mean absolute correlation, ascending.
corr_matrix = corr_matrix.reindex(
    corr_matrix.mean().sort_values().index, axis=1)
# Scores and compute-cost metrics are not shown in the bar plot. They are
# collected first and dropped in one step, rather than deleting columns while
# iterating over them.
excluded_metrics = [
    name for name in corr_matrix.columns
    if name.startswith('score') or name in ['exaflops', 'tpu_days', 'gflops', 'ms_step']
]
corr_matrix = corr_matrix.drop(columns=excluded_metrics)
# One row per metric; the metric name ends up in the 'index' column.
corr_matrix = corr_matrix.T.reset_index()

fig, ax = plt.subplots(figsize=(20.0, 5.0))
sns.barplot(x='index', y='test_loss', data=corr_matrix)
ax.set(xlabel=None)
ax.set(ylabel=r'$\rho(\cdot,$ test_loss)')

filename = 'correlation.png'
plt.tight_layout()
plt.savefig(filename)
colabtools.fileedit.download_file(filename)

## Plot Relative Score and Rankings

In [None]:
# Datasets that enter the relative scores and rankings.
datasets = [
    'cifar10',
    'cifar100',
    'imagenet2012',
    # 'imagenet_variants',
]
datasets += [f'few-shot {d}' for d in colab_utils.default_fewshot_datasets()]
# Scores relative to the 'Det' baseline model.
rel_scores = colab_utils.compute_score(
    df,
    drop_1shot=True,
    datasets=datasets,
    baseline_model='Det',
    drop_incomplete_measurements=True)
plt.rc('figure', figsize=(20, 20))

print("Average relative score and ranks across categories")
display.display(rel_scores)

print("==" * 50)
display.display(df_with_scores)

print("Full dataframe")
display.display(df)

# Plot rank distribution
ranks = colab_utils.rank_models(
    df, drop_1shot=True, datasets=datasets, drop_incomplete_measurements=True)
ax = sns.violinplot(data=ranks.T)
ax.set_xticklabels(ax.get_xticklabels(),rotation = 45)
ax.set_ylabel('Ranking')
print("==" * 50)
print("Rankings")
display.display(ranks)

# One violin plot per category, each in its own figure.
# NOTE(review): incomplete measurements are kept here
# (`drop_incomplete_measurements=False`) but dropped for the overall ranking
# above -- confirm this difference is intended.
ranks_by_category = colab_utils.rank_models_by_category(
    df, drop_1shot=True, datasets=datasets, drop_incomplete_measurements=False)
for key, rank_df in ranks_by_category.items():
  plt.figure()
  ax = sns.violinplot(data=rank_df.T)
  ax.set_xticklabels(ax.get_xticklabels(),rotation = 45)
  ax.set_ylabel('Ranking - %s' % key)

In [None]:
#@title Radar plot comparing to SOTA
# Note that the plots in the paper are edited
# posthoc in illustrator to improve text placement.
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from matplotlib.ticker import MaxNLocator

# Preliminaries
fontsize = 24
fontfamily = 'sans-serif'
radar_filename = "vision-plex-radar.pdf"

matplotlib.rcParams['figure.dpi'] = 1000
matplotlib.rcParams['lines.linewidth'] = 1.25
matplotlib.rcParams["mathtext.fontset"] = "cm"
matplotlib.rcParams['font.family'] = fontfamily
matplotlib.rcParams['font.sans-serif'] = 'Times New Roman'
matplotlib.rcParams['font.size'] = fontsize
# Font type 42 embeds TrueType fonts in the PS/PDF output.
matplotlib.rcParams['ps.fonttype'] = 42
matplotlib.rcParams['pdf.fonttype'] = 42

# Methods drawn on the radar plot, and one color per method (by position).
methods = ['Plex L']
colors = ['royalblue', 'orangered', 'tab:blue', 'cornflowerblue', 'r']

# Add these manually for now
# The assignments below write hand-entered values into `df_with_scores`,
# creating rows (e.g. "SOTA") and columns (e.g. "AL Accuracy") as needed.
df_with_scores.loc["BE L/32", ("AL Accuracy", "cifar10")] = .9640 # Margin JFT
df_with_scores.loc["BE L/32", ("AL Accuracy", "cifar100")] = .8739 # Margin JFT
# NOTE(review): same value as the cifar100 entry above -- confirm this is not
# a copy-paste leftover.
df_with_scores.loc["BE L/32", ("AL Accuracy", "places365_small")] = .8739 # Margin JFT
df_with_scores.loc["BE L/32", ("AL Accuracy", "imagenet")] = 0.771687 # Margin JFT

df_with_scores.loc["Det", ("AL Accuracy", "cifar10")] = .95 # Margin JFT RDL AL Meeting Notes - eyeballed
df_with_scores.loc["Det", ("AL Accuracy", "imagenet")] = 0.73 # RDL AL Meeting Notes - eyeballed
df_with_scores.loc["Det", ("AL Accuracy", "cifar100")] = .65 # Margin JFT

df_with_scores.loc["SOTA", ("cifar_10h_loss", "cifar10")] = 0.26
df_with_scores.loc["SOTA", ("ood_cifar10_msp_auroc", "cifar100")] = .9208
df_with_scores.loc["SOTA", ("ood_cifar100_msp_auroc", "cifar10")] = .9775
df_with_scores.loc["SOTA", ("in_domain_test/accuracy", "retina_country")] = .916

# https://arxiv.org/pdf/1911.11132.pdf
df_with_scores.loc["SOTA", ('ood_places365_small_msp_auroc', 'imagenet2012')] = 0.79

# https://arxiv.org/pdf/2201.07459.pdf (Figure 3, 1k examples)
df_with_scores.loc["SOTA", ("AL Accuracy", "cifar10")] = .56
# https://arxiv.org/pdf/2107.14263.pdf (Figure 3, 10k examples)
df_with_scores.loc["SOTA", ("AL Accuracy", "cifar100")] = .40
# # https://arxiv.org/pdf/1911.11132.pdf (Table 5, 10k examples)
# df_with_scores.loc["SOTA", ("AL Accuracy", "places365")] = .40

# NOTE(review): the citation below and the "Retina Selective Prediction"
# heading further down appear to be attached to each other's assignments
# (selective prediction here, ImageNet active learning there) -- confirm.
# https://arxiv.org/pdf/2111.12880.pdf (Figure 3, 30k examples)
df_with_scores.loc["SOTA", ('ood_test/selpred_accuracy_auc', 'retina_country')] = .797
df_with_scores.loc["BE L/32", ('ood_test/selpred_accuracy_auc', 'retina_country')] = .848
df_with_scores.loc["Det", ('ood_test/selpred_accuracy_auc', 'retina_country')] = .795

# Retina Selective Prediction
df_with_scores.loc["SOTA", ("AL Accuracy", "imagenet")] = .54

# Subpopulation shift
# https://arxiv.org/pdf/2110.14216.pdf (Table 2, 25th %'ile)
# CIFAR-10 Plex: .990
# CIFAR-10 SOTA: .815
df_with_scores.loc["BE->BE+Het", ("subpopulation", "cifar10")] = .990
df_with_scores.loc["SOTA", ("subpopulation", "cifar10")] = .815
#CIFAR-100 Plex: .931
#CIFAR-100 SOTA: .528
df_with_scores.loc["BE->BE+Het", ("subpopulation", "cifar100")] = .931
df_with_scores.loc["SOTA", ("subpopulation", "cifar100")] = .528

# This empty frame is replaced by the column selection a few lines below.
radar_df = pd.DataFrame(index=df_with_scores.index.copy())
# Columns shown on the radar plot, one per axis, in plotting order.
cols = [("cifar_10h_loss", "cifar10"),
        ("ood_cifar10_msp_auroc", "cifar100"),
        ("ood_cifar100_msp_auroc", "cifar10"),
        ('ood_places365_small_msp_auroc', 'imagenet2012'),
        ("in_domain_test/accuracy", "retina_country"),
        ('ood_test/selpred_accuracy_auc', 'retina_country'),
        ("AL Accuracy", "cifar10"),
        ("AL Accuracy", "cifar100"),
        ("AL Accuracy", "imagenet"),
        ("subpopulation", "cifar10"),
        ("subpopulation", "cifar100")]
radar_df = df_with_scores.loc[:, cols].copy()
radar_df.rename(index={'SOTA': 'SOTA (specialized)'}, inplace=True)

def add_default_model_results(df, model, default):
  """Fill the missing entries of row `model` with the values of row `default`.

  For example, BE->BE+Het can inherit the numbers of BE for every column in
  which it has no measurement of its own.

  Args:
    df: DataFrame indexed by model name.
    model: Index label of the row whose NaNs should be filled.
    default: Index label of the row providing the fallback values.

  Returns:
    A copy of `df` with the `model` row filled in; `df` is left untouched.
  """
  filled = df.copy()
  fallback_values = filled.loc[default]
  # Double brackets keep the row as a one-row DataFrame so the result can be
  # assigned back in place.
  target_row = filled.loc[[model]]
  filled.loc[[model]] = target_row.fillna(fallback_values, axis=0)
  return filled

# Where the 'BE->BE+Het' row has no value of its own, fall back to 'BE L/32',
# then switch both headline rows to their display names.
radar_df = add_default_model_results(radar_df, 'BE->BE+Het', 'BE L/32')
radar_df = radar_df.rename(index={
    'BE->BE+Het': 'Plex L',
    'Det': 'None L',
})

plt.figure(figsize=(20, 20))
plt.tight_layout()
plt.rc('figure', figsize=(20, 20))
ax = plt.subplot(1, 1, 1, polar=True)

max_val = 1.0
# NOTE(review): this appends to `methods` in place, so re-running only this
# part of the cell would add the entry again -- confirm `methods` is rebuilt
# earlier in the cell.
methods.append('SOTA (specialized)')
# One tick label per entry of `cols`, in the same order. (A former
# `xticklabels = cols.copy()` placeholder was overwritten immediately and has
# been removed.)
xticklabels = ["Negative KL\nCIFAR10H",
               "OOD AUROC \nCIFAR100 vs 10",
               "OOD AUROC  \n    CIFAR10 vs 100",
               "OOD AUROC \nImageNet vs Places365",
               "Accuracy\n    RETINA (Country)",
               "  Selective Prediction\nRETINA\n      (OOD Country Shift)",
               "Active Learning Acc.  \nCIFAR10 @1k    ",
               "Active Learning Acc.\n  CIFAR100 @10k",
               "Active Learning Acc.\n  ImageNet @30k",
               "Subpopulation Acc.   \nCIFAR10",
               "Subpopulation Acc.\nCIFAR100 "]
# Ranges for each y-axis corresponding to each ticklabel above.
yranges = [(0.2, 0.55), # CIFAR10h LL
           (0.8, 1.), (0.9, 1.0), (0.73, 0.9), # OOD
           (0.85, 0.92), (0.7, 0.9), # Retina
           (0.4, 1.), (0.3, 1.), (0.4, 0.9), #AL
           (0.7, 1.), (0.5, 0.95)] # Subpopulation

# Replot for each method
for i, m in enumerate(methods):
  colab_utils.make_radar_plot(radar_df, m, colors[i], max_val, ax,
                              xticklabels, yranges, fontfamily=fontfamily)
# Only needed by the commented-out plt.legend(...) call below.
legend_elements = [Patch(facecolor=colors[i], edgecolor='k',
                         label=m) for i, m in enumerate(methods)]

# Attempt to get labels on top of grid (zorder seems buggy in polar plots)
ax.xaxis.set_zorder(0.1)
ax.yaxis.set_zorder(0.1)
ax.yaxis.grid(True, zorder=1)
ax.xaxis.grid(True, zorder=1)
ax.grid(True, zorder=0.1)

# Create the legend
font = font_manager.FontProperties(family=fontfamily,
                                   weight='normal',
                                   style='normal', size=fontsize)
# Use the legend from the text plot.
# plt.legend(handles=legend_elements, loc='lower right', bbox_to_anchor=(1.26, 0.025), prop=font)
ax.tick_params(axis='x', which='major', pad=120)

if radar_filename is not None:
  plt.savefig(radar_filename, bbox_inches='tight', pad_inches=0)
  colabtools.fileedit.download_file(radar_filename)

# The following comment is potentially a much cleaner way to produce this plot.
#
# radar_df = radar_df.loc[['Plex L', 'SOTA (specialized)']].copy()
# radar_df.columns = ['_'.join(col).strip() for col in radar_df.columns.values]
# cols = ['_'.join(col).strip() for col in cols]
# xtickdict = {col:xticklabels[i] for i, col in enumerate(cols)}

# radar_df['model'] = radar_df.index
# radar_df = radar_df.melt(
#     id_vars=['model'],
#     var_name='metric',
#     value_name='value')
# radar_df = radar_df.rename(columns={('model',''):'model'})
# display.display(radar_df)

# fig = px.line_polar(radar_df, r='value', theta='metric', color='model', line_close=True, labels=xtickdict)
# fig.update_traces(fill='toself')

# fig.show()

# Plotting helpers

In [None]:
#@title Bar plots
def plot_metrics(df, train_dataset, metrics):
  """Draw one bar-plot panel per metric, with one bar per model.

  Args:
    df: Measurements with 'config.dataset', 'model' and the metric columns.
    train_dataset: Only rows whose 'config.dataset' equals this are plotted.
    metrics: List of metric column names; each gets its own panel with an
      independent y-axis.
  """
  selected = df[df['config.dataset'] == train_dataset].copy()
  long_df = selected[['model'] + metrics].melt(
      id_vars='model', var_name='metric', value_name='value')
  grid = sns.catplot(
      data=long_df, x='model', y='value', col='metric', kind='bar',
      sharey=False)
  # Slant the model names on every panel.
  for panel in grid.axes.flat:
    panel.set_xticklabels(
        panel.get_xticklabels(), rotation=40, horizontalalignment="right")


def plot_in_distribution(df, train_dataset, split):
  """Bar-plot loss, prec@1, ECE and calibration AUC of the given data split."""
  suffixes = ('loss', 'prec@1', 'ece', 'calib_auc')
  plot_metrics(df, train_dataset, [f'{split}_{suffix}' for suffix in suffixes])

def pareto_plot_in_distribution_subfigs(df, train_dataset, split, axes, xmetric):
  """Pareto plots of `{split}_prec@1` and `{split}_loss` drawn into `axes`."""
  split_metrics = [f'{split}_prec@1', f'{split}_loss']
  pareto_plot_subfigs(
      df, split_metrics, train_dataset, xmetric=xmetric, axes=axes)

def plot_ood(df, train_dataset):
  """Bar-plot OOD-detection AUROC per scoring method, one panel per OOD dataset.

  Only `ood_{dataset}_{method}_auroc` columns that actually exist in `df` are
  plotted. The panel label is the first underscore-separated token of the
  dataset name (e.g. 'svhn' for 'svhn_cropped').

  Args:
    df: Measurements with 'config.dataset', 'model' and the AUROC columns.
    train_dataset: In-distribution dataset whose runs are plotted.
  """
  if train_dataset == 'imagenet2012':
    ood_datasets = {'places365_small'}
    methods = ['msp', 'entropy', 'mlogit']
  else:
    ood_datasets = {'svhn_cropped', 'cifar100', 'cifar10'} - {train_dataset}
    methods = ['msp', 'entropy', 'mlogit', 'maha', 'rmaha']

  runs = df[df['config.dataset'] == train_dataset].copy()
  wanted = [
      f'ood_{ood_ds}_{method}_auroc'
      for (ood_ds, method) in itertools.product(ood_datasets, methods)
  ]
  available = list(set(wanted).intersection(runs.columns))

  long_df = runs[['model'] + available].melt(
      id_vars='model', var_name='metric', value_name='AUROC')
  # Column names look like ood_<dataset...>_<method>_auroc.
  column_names = list(long_df['metric'])
  long_df['dataset'] = [name.split('_')[1] for name in column_names]
  long_df['metric'] = [name.split('_')[-2] for name in column_names]

  sns.catplot(
      data=long_df, x='metric', y='AUROC', hue='model', col='dataset',
      kind='bar')
  plt.ylim((0.5, 1))


def plot_reclassified(df, train_dataset):
  """Bar-plot metrics on the relabelled eval set (ImageNet ReaL or CIFAR-10H)."""
  if train_dataset == 'imagenet2012':
    prefix = 'imagenet_real'
  else:
    prefix = 'cifar_10h'
  suffixes = ('loss', 'prec@1', 'ece', 'calib_auc')
  plot_metrics(df, train_dataset, [f'{prefix}_{suffix}' for suffix in suffixes])


def _get_imagenet_shifts_metrics(eval_dataset):
  base_metrics = ['accuracy', 'ece', 'nll', 'brier']
  metrics = [f'{eval_dataset}/{m}' for m in base_metrics]
  if eval_dataset == 'imagenet_c':
    metrics = [f'{m}/mean' for m in metrics]
  return metrics


def _get_imagenet_robustness_metrics(eval_dataset):
  base_metrics = ['accuracy_pmk', 'anchor_accuracy', 'accuracy_drop']
  return [f'{eval_dataset}/{m}' for m in base_metrics]


def plot_imagenet_shifts(df, eval_dataset):
  """Bar-plot the shift metrics of `eval_dataset` for 'imagenet_variants' runs."""
  plot_metrics(
      df,
      train_dataset='imagenet_variants',
      metrics=_get_imagenet_shifts_metrics(eval_dataset))


def plot_imagenet_robustness(df, eval_dataset):
  """Bar-plot the robustness metrics of `eval_dataset` for 'imagenet_variants' runs."""
  plot_metrics(
      df,
      train_dataset='imagenet_variants',
      metrics=_get_imagenet_robustness_metrics(eval_dataset))


def pareto_plot_imagenet_shifts(df, eval_dataset):
  """Pareto-plot the shift metrics of `eval_dataset` for 'imagenet_variants' runs."""
  pareto_plot(
      df,
      _get_imagenet_shifts_metrics(eval_dataset),
      train_dataset='imagenet_variants')

def pareto_plot_imagenet_shift_subplots(df, eval_dataset, axes, xmetric):
  """Pareto-plot the shift metrics (every one except ECE) into `axes`."""
  kept = [
      name for name in _get_imagenet_shifts_metrics(eval_dataset)
      if 'ece' not in name
  ]
  pareto_plot_subfigs(
      df, kept, train_dataset='imagenet_variants', xmetric=xmetric, axes=axes)

def pareto_plot_imagenet_robustness(df, eval_dataset):
  """Pareto-plot all robustness metrics of `eval_dataset` as stacked subplots."""
  pareto_plot_subfigs(
      df,
      _get_imagenet_robustness_metrics(eval_dataset),
      train_dataset='imagenet_variants')

def pareto_plot_imagenet_robustness_subplots(df, eval_dataset, axes):
  """Pareto-plot the robustness metrics (except accuracy_drop) into `axes`."""
  kept = [
      name for name in _get_imagenet_robustness_metrics(eval_dataset)
      if 'accuracy_drop' not in name
  ]
  pareto_plot_subfigs(
      df, kept, train_dataset='imagenet_variants', axes=axes)

In [None]:
#@title Pareto plots


def is_on_pareto_front(p, points, higher_is_better):
  """Tell whether the (x, y) point `p` is undominated within `points`.

  `p` is dominated when some point has an x no larger than p's and a strictly
  better y, where "better" means larger if `higher_is_better` and smaller
  otherwise.
  """
  if higher_is_better:
    dominated = any(q[0] <= p[0] and q[1] > p[1] for q in points)
  else:
    dominated = any(q[0] <= p[0] and q[1] < p[1] for q in points)
  return not dominated


def get_pareto_points(x, y, higher_is_better):
  """Return the Pareto-optimal (x, y) pairs, ordered by increasing x."""
  candidates = list(zip(x, y))
  frontier = []
  for candidate in candidates:
    if is_on_pareto_front(candidate, candidates, higher_is_better):
      frontier.append(candidate)
  frontier.sort(key=lambda point: point[0])
  return frontier


def plot_fn(data, x, y, ax=None, annotate_names=False, **kws):
  """Scatter the models and overlay the dashed Pareto frontier.

  Args:
    data: Long-format frame with 'model' and 'metric' columns plus the `x`
      and `y` columns. The metric name is read from the first row.
    x: Name of the column plotted on the x-axis.
    y: Name of the column plotted on the y-axis.
    ax: Axes to draw on; the current axes are used when None.
    annotate_names: If True, write each model's name next to its marker.
    **kws: Accepted and ignored (e.g. extras passed by FacetGrid.map_dataframe).
  """
  target_ax = plt.gca() if ax is None else ax
  sns.scatterplot(
      data=data, x=x, y=y, hue='model', style='model', markers=True,
      s=300, alpha=0.8, ax=target_ax)

  if annotate_names:
    for _, row in data.iterrows():
      target_ax.annotate(
          '  ' + row['model'],
          xy=(row[x], row[y]),
          ha='left',
          va='bottom',
          fontsize=16)

  # Whether larger or smaller values are "better" depends on the metric.
  metric_name = data['metric'].iloc[0]
  frontier = get_pareto_points(
      data[x], data[y],
      higher_is_better=colab_utils.is_higher_better(metric_name))
  frontier_x, frontier_y = zip(*frontier)
  sns.lineplot(x=frontier_x, y=frontier_y, linestyle='--', ax=target_ax)
  target_ax.set_ylabel(metric_name)


def pareto_plot(df,
                metrics,
                train_dataset=None,
                xmetric='num_params',
                xlabel='Log # Params'):
  """Draw one Pareto-frontier panel per metric against `xmetric`.

  Rows are restricted to `train_dataset`, averaged per
  (model, dataset, xmetric) group, and each metric is drawn in its own panel
  with an independent y-axis and a logarithmic x-axis.

  Args:
    df: Measurements with 'model', 'config.dataset', `xmetric` and the metric
      columns.
    metrics: List of metric column names to plot.
    train_dataset: Only rows whose 'config.dataset' equals this are plotted.
    xmetric: Column used for the x-axis.
    xlabel: Label written on the x-axis of every panel.

  Returns:
    The seaborn FacetGrid holding the panels.
  """
  df = df[df['config.dataset'] == train_dataset].copy()
  id_cols = ['model', 'config.dataset', xmetric]
  # Per-column group mean, the same aggregation pareto_plot_subfigs uses.
  # `.apply(np.mean)` depends on how np.mean reduces a whole DataFrame.
  df = df.groupby(id_cols)[metrics].mean().reset_index()
  df = df.melt(id_vars=id_cols, var_name='metric', value_name='value')

  # `height` replaces the `size` keyword that seaborn renamed and later removed.
  g = sns.FacetGrid(data=df, col='metric', sharey=False, height=5)
  g.map_dataframe(plot_fn, x=xmetric, y='value')
  g.set_xlabels(xlabel)
  g.set(xscale='log')
  # Callers bind the result (`g = pareto_plot(...)`), so hand the grid back.
  return g


def pareto_plot_subfigs(df,
                        metrics,
                        train_dataset=None,
                        xmetric='num_params',
                        xlabel='Log # Params',
                        axes=None):
  """Draw a Pareto-frontier plot per metric, optionally into given axes.

  Each entry of `metrics` is plotted on the y-axis against `xmetric` on the
  x-axis, after averaging per (model, dataset, xmetric) group.

  Args:
    df: Measurements with 'model', 'config.dataset', `xmetric` and the metric
      columns.
    metrics: List of metric column names to plot.
    train_dataset: Only rows whose 'config.dataset' equals this are plotted.
    xmetric: Column used for the x-axis.
    xlabel: Not used by this function.
    axes: Optional sequence of axes, indexed in step with `metrics`. When
      None, the panels are stacked vertically in the current figure.
  """
  id_cols = ['model', 'config.dataset', xmetric]
  averaged = df.groupby(id_cols)[metrics].mean().reset_index()
  long_df = averaged.melt(
      id_vars=id_cols, var_name='metric', value_name='value')
  dataset_rows = long_df[long_df['config.dataset'] == train_dataset].copy()

  for i, metric in enumerate(metrics):
    ax = plt.subplot(len(metrics), 1, i + 1) if axes is None else axes[i]
    panel_rows = dataset_rows[dataset_rows['metric'] == metric].copy()
    plot_fn(panel_rows, x=xmetric, y='value', ax=ax)

# Results

In [None]:
#@title Upstream JFT
# Bar plots of the upstream metrics for the JFT-trained models.
plot_metrics(
    measurements,
    'jft/entity:1.0.0',
    ['val_loss', 'val_prec@1', 'a/imagenet_10shot'])
# Pareto frontiers of the same metrics against the default x-axis metric...
pareto_plot(
    measurements,
    ['val_loss', 'val_prec@1', 'a/imagenet_10shot'],
    train_dataset='jft/entity:1.0.0',
)
# ...and against compute.
pareto_plot(
    measurements,
    ['val_loss', 'val_prec@1', 'a/imagenet_10shot'],
    train_dataset='jft/entity:1.0.0',
    xmetric='tpu_days',
    xlabel='Compute (TPUv3 core days)',
)

## Cifar 10

In [None]:
#@title In-distribution
# CIFAR-10 test-split metrics: bar plots, then Pareto frontiers.
plot_in_distribution(measurements, 'cifar10', 'test')
g = pareto_plot(
    measurements,
    ['test_loss', 'test_prec@1', 'test_ece', 'test_calib_auc'],
    train_dataset='cifar10')

In [None]:
#@title Cifar10h
# CIFAR-10H metrics of the CIFAR-10 models: bar plots, then Pareto frontiers.
plot_reclassified(measurements, 'cifar10')
g = pareto_plot(
    measurements,
    ['cifar_10h_loss', 'cifar_10h_prec@1', 'cifar_10h_ece',
     'cifar_10h_calib_auc'],
    train_dataset='cifar10')

In [None]:
#@title OOD
# OOD-detection AUROC of the CIFAR-10 models.
plot_ood(measurements, 'cifar10')

## Cifar100

In [None]:
#@title In-distribution
# CIFAR-100 test-split metrics: bar plots, then Pareto frontiers.
plot_in_distribution(measurements, 'cifar100', 'test')
g = pareto_plot(
    measurements,
    ['test_loss', 'test_prec@1', 'test_ece', 'test_calib_auc'],
    train_dataset='cifar100')

In [None]:
#@title OOD
# OOD-detection AUROC of the CIFAR-100 models.
plot_ood(measurements, 'cifar100')

## Imagenet

In [None]:
#@title In-distribution
# ImageNet test-split metrics: bar plots, then Pareto frontiers.
plot_in_distribution(measurements, 'imagenet2012', 'test')
g = pareto_plot(
    measurements,
    ['test_loss', 'test_prec@1', 'test_ece', 'test_calib_auc'],
    train_dataset='imagenet2012')

In [None]:
#@title Imagenet Real
# ImageNet ReaL metrics of the ImageNet models: bar plots, then Pareto
# frontiers.
plot_reclassified(measurements, 'imagenet2012')
g = pareto_plot(
    measurements,
    ['imagenet_real_loss', 'imagenet_real_prec@1', 'imagenet_real_ece',
     'imagenet_real_calib_auc'],
    train_dataset='imagenet2012')

In [None]:
#@title ImageNet Shifts & Robustness (ImageNet-C, etc.)
# Output files and shared styling for the ImageNet shift/robustness figures.
shifts_filename = 'imagenet_shifts.pdf'
robustness_filename = 'imagenet_yttb_robustness.pdf'
fontsize = 32
fontfamily = 'serif'
xmetric = 'num_params'

sns.reset_orig()
sns.set_theme()
matplotlib.rcParams['figure.dpi'] = 1000
matplotlib.rcParams['lines.linewidth'] = 1.25
sns.set_style('white')
matplotlib.rcParams['font.family'] = fontfamily
matplotlib.rcParams['font.size'] = fontsize
ytickfontparams = {'fontsize': fontsize * .8, 'fontweight': 'normal'}

# Don't keep E^3 or MoEs
sub_df = measurements.copy()
sub_df = sub_df.drop(sub_df[sub_df['model'].str.contains(
    'MoE', case=False)].index)
# Raw string: `\^` is a regex escape for a literal caret, not a valid Python
# string escape.
sub_df = sub_df.drop(sub_df[sub_df['model'].str.contains(r'E\^3')].index)

sub_df['model'] = sub_df['model'].replace({
    'BE->BE+Het': 'Plex L',
    'Det': 'None L',
})

# We're not currently including TPU days since too many numbers are missing
# in the df.  However, the below fills in some of the values, which might be
# revisited.
# # Fix TPU days to match pretraining time
# sub_df.loc[sub_df['model'].str.contains('Det->', case=False),
#            'tpu_days'] = 107.29
# sub_df.loc[sub_df['model'] == 'Plex L', 'tpu_days'] = 119.12

# # Populate all columns for tpu_days
# cols = sub_df[sub_df['tpu_days'].notna()][['model', 'tpu_days']]
# for ind, c in cols.iterrows():
#   sub_df.loc[sub_df['model'] == c['model'], 'tpu_days'] = c['tpu_days']

# Any remaining 'Det' inside a model name is shown as 'None'.
sub_df.loc[:, 'model'] = sub_df['model'].str.replace(r'Det', 'None', regex=True)
# Remembered so that the in-distribution figure can show the same models.
models_in_imagenet_shifts_fig = sub_df['model']

variants = ['imagenet_c', 'imagenet_a', 'imagenet_r', 'imagenet_v2']
fig, axes = plt.subplots(3, 4, figsize=(20, 15))
# Transpose so that axes[i] is the column of panels for variants[i].
axes = np.array(axes).T
for i, ds in enumerate(variants):
  pareto_plot_imagenet_shift_subplots(sub_df, ds, axes[i], xmetric)

# Set titles along columns
axes[0, 0].set_title('ImageNet-C', fontsize=fontsize)
axes[1, 0].set_title('ImageNet-A', fontsize=fontsize)
axes[2, 0].set_title('ImageNet-R', fontsize=fontsize)
axes[3, 0].set_title('ImageNet-V2', fontsize=fontsize)

# Set labels along columns
for ax in axes.flatten():
  ax.set_xlabel('')
for ax in axes[:, 2]:
  ax.set_xlabel('# Params', fontsize=fontsize)

# Set labels along rows
for ax in axes.flatten():
  ax.set_ylabel('')
axes[0, 0].set_ylabel('Accuracy', fontsize=fontsize)
axes[0, 1].set_ylabel('NLL', fontsize=fontsize)
axes[0, 2].set_ylabel('Brier', fontsize=fontsize)

# Remove the per-axes legends and make one for the whole figure. A panel
# without a legend is skipped, as in the in-distribution figure.
for ax in axes.flatten():
  panel_legend = ax.get_legend()
  if panel_legend is not None:
    panel_legend.remove()

# `ax` is the last panel visited above; its handles feed the shared legend.
handles, labels = ax.get_legend_handles_labels()
legend = fig.legend(
    handles,
    labels,
    loc='lower center',
    ncol=len(labels) // 2 + 1,
    labelspacing=0.3,
    handletextpad=0.1,
    borderpad=0.3,
    fontsize=fontsize * .7,
    markerscale=3)
legend.get_frame().set_linewidth(matplotlib.rcParams['axes.linewidth'])
legend.get_frame().set_edgecolor('lightgray')

if shifts_filename is not None:
  plt.savefig(shifts_filename)
  colabtools.fileedit.download_file(shifts_filename)

# Robustness figure: one column per dataset, one row per metric
# (row 0: accuracy_pmk, row 1: anchor_accuracy).
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
axes = np.array(axes)
# Each dataset fills a *column*, matching the column titles and row labels set
# below. Passing axes[0] / axes[1] (rows) would put both metrics of one
# dataset side by side and mislabel the off-diagonal panels.
pareto_plot_imagenet_robustness_subplots(
    sub_df, 'imagenet_vid_robust', axes[:, 0])
pareto_plot_imagenet_robustness_subplots(sub_df, 'ytbb_robust', axes[:, 1])
for ax in axes.flatten():
  ax.set_ylabel('')
  ax.set_xlabel('')
  panel_legend = ax.get_legend()
  if panel_legend is not None:
    panel_legend.remove()
for ax in axes[1, :]:
  ax.set_xlabel('# Params', fontsize=fontsize)
axes[0, 0].set_title('ImageNet Vid Robust', fontsize=fontsize)
# Title spelled after the dataset key 'ytbb_robust' (was 'YTTB').
axes[0, 1].set_title('YTBB Robust', fontsize=fontsize)
axes[0, 0].set_ylabel('Accuracy PMK', fontsize=fontsize)
axes[1, 0].set_ylabel('Anchor Accuracy', fontsize=fontsize)
# `ax` is the last panel visited above; its handles feed the shared legend.
handles, labels = ax.get_legend_handles_labels()
legend = fig.legend(
    handles,
    labels,
    loc='center right',
    ncol=1,
    labelspacing=0.3,
    handletextpad=0.1,
    borderpad=0.3,
    fontsize=fontsize * .7,
    markerscale=3,
    bbox_to_anchor=(1.25, 0.5))
legend.get_frame().set_linewidth(matplotlib.rcParams['axes.linewidth'])
legend.get_frame().set_edgecolor('lightgray')

if robustness_filename is not None:
  plt.savefig(robustness_filename, bbox_inches='tight', pad_inches=0)
  colabtools.fileedit.download_file(robustness_filename)

In [None]:
#@title In distribution figures
# Output file and shared styling for the in-distribution figure.
in_dist_filename = 'in_distribution_robust_generalization.pdf'
fontsize = 32
fontfamily = 'serif'
xmetric = 'num_params'

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

sns.reset_orig()
sns.set_theme()
matplotlib.rcParams['figure.dpi'] = 1000
matplotlib.rcParams['lines.linewidth'] = 1.25
sns.set_style('white')
matplotlib.rcParams['font.family'] = fontfamily
matplotlib.rcParams['font.size'] = fontsize
ytickfontparams = {'fontsize': fontsize * .8, 'fontweight': 'normal'}

# Don't keep E^3 or MoEs
sub_df = measurements.copy()
sub_df = sub_df.drop(sub_df[sub_df['model'].str.contains(
    'MoE', case=False)].index).copy()
# Raw string: `\^` is a regex escape for a literal caret, not a valid Python
# string escape.
sub_df = sub_df.drop(sub_df[sub_df['model'].str.contains(r'E\^3')].index).copy()

sub_df['model'] = sub_df['model'].replace({
    'BE->BE+Het': 'Plex L',
    'Det': 'None L',
})

# Make the set of models shown consistent with the imagenet shifts
# NOTE(review): the ImageNet-shifts cell additionally rewrites every 'Det'
# substring to 'None' before recording models_in_imagenet_shifts_fig. That
# rewrite is not repeated here, so any name still containing 'Det' cannot
# match and is filtered out -- confirm this is intended.
sub_df = sub_df[sub_df['model'].isin(models_in_imagenet_shifts_fig)]

variants = ['cifar10', 'cifar100', 'imagenet2012']
fig, axes = plt.subplots(2, 3, figsize=(20, 15))
# Transpose so that axes[i] is the column of panels for variants[i].
axes = np.array(axes).T
for i, ds in enumerate(variants):
  pareto_plot_in_distribution_subfigs(sub_df, ds, 'test', axes[i], xmetric)

# Set titles along columns
axes[0, 0].set_title('Cifar-10', fontsize=fontsize)
axes[1, 0].set_title('Cifar-100', fontsize=fontsize)
axes[2, 0].set_title('ImageNet', fontsize=fontsize)

# Set labels along columns
for ax in axes.flatten():
  ax.set_xlabel('')
for ax in axes[:, 1]:
  ax.set_xlabel('# Params', fontsize=fontsize)

# Set labels along rows
for ax in axes.flatten():
  ax.set_ylabel('')
axes[0, 0].set_ylabel('Accuracy', fontsize=fontsize)
axes[0, 1].set_ylabel('NLL', fontsize=fontsize)

# Remove the per-axes legends and make one for the whole figure. Panels
# without a legend are skipped.
for ax in axes.flatten():
  panel_legend = ax.get_legend()
  if panel_legend is not None:
    panel_legend.remove()

# `ax` is the last panel visited above; its handles feed the shared legend.
handles, labels = ax.get_legend_handles_labels()
legend = fig.legend(
    handles,
    labels,
    loc='center right',
    ncol=1,
    labelspacing=0.3,
    handletextpad=0.1,
    borderpad=0.3,
    fontsize=fontsize * .7,
    markerscale=3,
    bbox_to_anchor=(1.08, 0.5))
legend.get_frame().set_linewidth(matplotlib.rcParams['axes.linewidth'])
legend.get_frame().set_edgecolor('lightgray')

if in_dist_filename is not None:
  plt.savefig(in_dist_filename, bbox_inches='tight', pad_inches=0.25)
  colabtools.fileedit.download_file(in_dist_filename)

## Deep ensemble analysis

In [None]:
# Use a serif font family for the figures that follow.
matplotlib.rc('font', family='serif')

In [None]:
def get_ensemble_scaling_measurements():
  """Build the tuned results table for the deep-ensemble (DE) scaling study.

  Reads the module-level `raw_measurements`, keeps the DE S/32, B/32 and L/32
  entries and renames each model to `<model>_<ensemble size>`.

  Returns:
    The output of colab_utils.process_tuned_results restricted to the default
    selected metrics plus 'num_params'.
  """
  de_names = ['DE S/32', 'DE B/32', 'DE L/32']
  de_raw = {}
  for name, runs in raw_measurements.items():
    if name in de_names:
      de_raw[name] = runs
  de_measurements = get_optimal_results(de_raw)

  de_measurements = de_measurements[de_measurements['model'].isin(de_names)]
  # Make the ensemble size part of the model name, e.g. 'DE L/32_3'.
  de_measurements['model'] = de_measurements.apply(
      lambda row: f'{row.model}_{int(row.ensemble_size)}', axis=1)
  computed_columns = list(colab_utils.compute_metrics())
  de_measurements = de_measurements.drop(
      columns=computed_columns, errors='ignore')

  return colab_utils.process_tuned_results(
      de_measurements,
      relevant_metrics=colab_utils.default_selected_metrics() + ['num_params'])


de_results = get_ensemble_scaling_measurements()

In [None]:
# Datasets entering the scores: three fine-tuning datasets plus the few-shot
# ones.
datasets = [
    'cifar10',
    'cifar100',
    'imagenet2012',
    # 'imagenet_variants',
]
datasets.extend(
    f'few-shot {d}' for d in colab_utils.default_fewshot_datasets())

# Size-3 ensembles (trained as ensembles, or ensembled downstream only) and
# the deterministic baseline.
ensemble_meas = {
    'DE': raw_measurements['DE L/32'].query('ensemble_size==3'),
    'Det->DE': raw_measurements['Det->DE L/32'].query('ensemble_size==3'),
    'Det': raw_measurements['Det'],
}

ensemble_meas = get_optimal_results(ensemble_meas, verbose=False)
ensemble_meas = ensemble_meas.drop(
    columns=list(colab_utils.compute_metrics()), errors='ignore')
df = colab_utils.process_tuned_results(ensemble_meas)

# Scores of the baseline, computed without passing a baseline model.
display.display(
    colab_utils.compute_score(
        df,
        datasets=datasets,
        drop_1shot=True,
        drop_incomplete_measurements=False,
    ).loc[['Det'],
          ['score_prediction', 'score_uncertainty', 'score_adaptation']])

# Scores computed with 'Det' passed as the baseline model.
ensemble_scores = colab_utils.compute_score(
    df,
    datasets=datasets,
    drop_1shot=True,
    drop_incomplete_measurements=False,
    baseline_model='Det')
ensemble_scores = ensemble_scores[
    ['score_prediction', 'score_uncertainty', 'score_adaptation']]


def get_improvement(value):
  """Format a ratio to the baseline as a signed percentage string.

  Args:
    value: Score relative to the baseline, where 1 means "equal".

  Returns:
    The relative change with two decimals and an explicit sign, e.g.
    '+5.00%' for 1.05 and '-5.00%' for 0.95.
  """
  improvement = value - 1
  # The '+' format flag writes exactly one sign. Prepending '-' by hand to an
  # already negative number produced '--5.00%'.
  return f'{improvement * 100:+.2f}%'


# Add one formatted "relative improvement" column per score.
for col in ('prediction', 'uncertainty', 'adaptation'):
  ensemble_scores[f'Rel. improvement ({col})'] = (
      ensemble_scores[f'score_{col}'].map(get_improvement))

# NOTE(review): the rows are selected as '... L/32' although the measurements
# dict above is keyed 'DE' / 'Det->DE' -- confirm the index labels.
display.display(
    ensemble_scores.loc[
        ['Det->DE L/32', 'DE L/32'],
        [name for name in ensemble_scores.columns if 'Rel' in name]])

In [None]:
# Datasets entering the scores: three fine-tuning datasets plus the few-shot
# ones.
datasets = [
    'cifar10',
    'cifar100',
    'imagenet2012',
    # 'imagenet_variants',
]
datasets.extend(
    f'few-shot {d}' for d in colab_utils.default_fewshot_datasets())
# Score columns kept from the compute_score output.
score_cols = [
    'score',
    'score_prediction',
    'score_uncertainty',
    'score_adaptation',
]


def plot_deep_ensemble_heatmap(scores, col_name):
  """Draw a model-variant x ensemble-size heatmap of one score column.

  Args:
    scores: DataFrame indexed by model names of the form
      'DE <variant>/32_<n>', where <n> is a single-digit ensemble size.
    col_name: Name of the column of `scores` shown in the heatmap.
  """
  fontsize = 18
  tick_fontsize = 16
  is_sized_ensemble = ['DE' in x and x != 'DE' for x in scores.index]
  # Work on a copy so the helper columns are not written into a slice of the
  # caller's frame (which triggers SettingWithCopyWarning).
  de_scores = scores[is_sized_ensemble].copy()
  # 'DE L/32_3' -> variant 'L', ensemble size 3.
  de_scores['model_type'] = [
      x[3:-2].replace('/32', '') for x in de_scores.index
  ]
  de_scores['ensemble_size'] = [int(x[-1:]) for x in de_scores.index]

  # Honour `col_name`; it used to be ignored in favour of a hard-coded 'score'.
  de_table = pd.pivot_table(
      de_scores, values=col_name, index='model_type', columns='ensemble_size')
  de_table = de_table.reindex(['L', 'B', 'S'])
  p = sns.heatmap(
      de_table,
      annot=True,
      fmt='.2f',
      cmap='Blues',
      cbar=False,
      annot_kws={'size': 16})
  p.set_xlabel('Ensemble Size', fontsize=fontsize)
  p.set_ylabel('Model Variant', fontsize=fontsize)
  _ = plt.xticks(fontsize=tick_fontsize)
  _ = plt.yticks(fontsize=tick_fontsize)


# Scores of the DE scaling runs ('num_params' is dropped before scoring),
# expressed in percent.
de_scores = colab_utils.compute_score(
    de_results.drop(columns=['num_params']),
    datasets=datasets,
    drop_1shot=True,
    drop_incomplete_measurements=True,
)
de_scores = de_scores.loc[:, score_cols] * 100

plot_deep_ensemble_heatmap(de_scores, 'score')
plt.tight_layout()

# Save the heatmap and trigger its download.
filename = 'ensemble_tradeoff.pdf'
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
colabtools.fileedit.download_file(filename)

In [None]:
fontsize = 18
tick_fontsize = 16
# Index labels look like 'DE L/32_3': the architecture is the second word
# without '/32', the ensemble size the number after the last underscore.
de_results['architecture'] = [
    name.split(' ')[1].split('_')[0].replace('/32', '')
    for name in de_results.index
]
de_results['ensemble_size'] = [
    int(name.split('_')[-1]) for name in de_results.index
]

# ImageNet accuracy against parameter count; colour encodes the architecture
# and marker size the ensemble size.
_ = sns.scatterplot(
    data=de_results,
    x=('num_params', 'imagenet2012'),
    y=('test_prec@1', 'imagenet2012'),
    hue='architecture',
    size='ensemble_size',
    sizes=(40, 200),
)
_ = plt.ylabel('Accuracy', fontsize=fontsize)
_ = plt.xlabel('# Parameters', fontsize=fontsize)


def clean_legend(ax):
  """Rebuild the scatter-plot legend as two columns without group titles.

  The legend entries labelled 'ensemble_size' and 'architecture' are removed,
  numeric labels are rewritten as 'n = <label>', and two blank entries are
  prepended before the legend is redrawn with two columns.

  Args:
    ax: Axes whose legend entries are read.

  Raises:
    ValueError: If either title entry is missing from the legend.
  """
  handles, labels = ax.get_legend_handles_labels()
  legend_texts = [l.get_text() for l in ax.legend().get_texts()]
  ensemble_idx = legend_texts.index('ensemble_size')
  architecture_idx = legend_texts.index('architecture')

  # Remove titles in legend ('Architecture', 'Ensemble size'). Deleting from
  # the highest index down keeps the remaining index valid whichever title
  # comes first.
  for title_idx in sorted((ensemble_idx, architecture_idx), reverse=True):
    del handles[title_idx]
    del labels[title_idx]

  def _annotate_label(label):
    return f'n = {label}' if label.isnumeric() else label

  labels = [_annotate_label(l) for l in labels]

  # Add empty legends so that the two-column environment breaks at the end of
  # the "architecture" legend.
  empty_handle = matplotlib.collections.PathCollection(
      paths=handles[0].get_paths(), sizes=[0.])
  handles = [empty_handle, empty_handle] + handles
  labels = ['', ''] + labels

  legend = plt.legend(
      handles=handles,
      labels=labels,
      fontsize=15,
      ncol=2,
      frameon=True,
      framealpha=.1)
  legend.get_frame().set_edgecolor('k')


clean_legend(plt.gca())

# Apply the tick font size to both axes and to the x-axis offset text.
_ = plt.xticks(fontsize=tick_fontsize)
_ = plt.yticks(fontsize=tick_fontsize)
plt.gca().xaxis.get_offset_text().set_size(tick_fontsize)
plt.tight_layout()

# Save the figure and trigger its download.
filename = 'imagenet_params_vs_prec.pdf'
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
colabtools.fileedit.download_file(filename)

## Comparison with sparse MoEs

In [None]:
# LaTeX column headers. Raw strings keep the backslashes literal: in a normal
# string '\t' (as in '\textsc') is a TAB and '\r' (as in '\rightarrow') is a
# carriage return.
score_cols = [
    r'\textsc{Score}',
    r'\textsc{Score prediction}',
    r'\textsc{Score uncertainty}',
    r'\textsc{Score adaptation}'
]

moes_df = df_with_scores.reindex([
  'Det', 'MoE', 'E^3', 'Det->[Det]_4', 'MoE->[MoE]_4', '[Det]_4', '[MoE]_4'
])
moes_df = moes_df.rename(index={
    'E^3': r'\textsc{E}$^3$',
    'Det->[Det]_4': r'$\textsc{Det}\rightarrow[\textsc{Det}]_4$',
    'MoE->[MoE]_4': r'$\textsc{MoE}\rightarrow[\textsc{MoE}]_4$',
    '[Det]_4': r'$[\textsc{Det}]_4$',
    '[MoE]_4': r'$[\textsc{MoE}]_4$'
},
columns={
    'score': r'\textsc{Score}',
    'score_prediction': r'\textsc{Score prediction}',
    'score_uncertainty': r'\textsc{Score uncertainty}',
    'score_adaptation': r'\textsc{Score adaptation}'
})

# TODO(rjenatton@): regenerate after the adaption score has been updated.
# In the current state, most of the DE and MoEs do not have adaption scores.
moe_df = moes_df[score_cols].applymap("{0:.2f}".format)
# Missing scores are shown as a dash.
moe_df = moe_df.applymap(lambda s: '$-$' if s == 'nan' else s)
moe_latex_table = moe_df.to_latex(index=True,
                                  index_names=False,
                                  column_format='ccccc',
                                  escape=False)
# Add \midrule at appropriate positions and remove empty line.
# The positions are line indices into the generated table. They rely on every
# table row being a single line, which no longer held when the row labels
# contained a carriage return.
moe_latex_table = moe_latex_table.splitlines()
moe_latex_table = moe_latex_table[:8] + ['\\midrule'] + moe_latex_table[8:]
moe_latex_table = moe_latex_table[:11] + ['\\midrule'] + moe_latex_table[11:]
moe_latex_table = moe_latex_table[:3] + moe_latex_table[4:]
# Just need to copy/paste the result of the print statement.
print('\n'.join(moe_latex_table))

# Upstream vs downstream

In [None]:
# Use a serif font family for the figures that follow.
matplotlib.rc('font', family='serif')

In [None]:
def get_up_vs_down_df():
  """Build the tuned results table for the upstream-vs-downstream comparison.

  Reads the module-level `raw_measurements`, keeps a fixed list of models and
  renames the 'BE L/32' row to 'BE'.
  """
  relevant_models = [
      'Det', 'BE L/32', 'Det->BE', 'DE', 'Det->DE', 'GP', 'Det->GP', 'Het',
      'Det->Het'
  ]
  relevant_measurements = {}
  for name, runs in raw_measurements.items():
    if name in relevant_models:
      relevant_measurements[name] = runs

  tuned = get_optimal_results(relevant_measurements, verbose=False)
  tuned = tuned.drop(
      columns=list(colab_utils.compute_metrics()), errors='ignore')
  processed = colab_utils.process_tuned_results(tuned)
  return processed.rename({'BE L/32': 'BE'})


def _add_up_vs_down_metadata(df):
  df = df.copy()

  def _adaptation_type(model_name):
    if model_name == 'Det':
      return 'baseline'
    if '->' in model_name:
      return 'Downstream only'
    else:
      return 'Upstream & downstream'

  df['base_model'] = df.index.map(lambda x: x.split('->')[-1])
  df['up_vs_down'] = df.index.map(_adaptation_type)

  return df

In [None]:
from matplotlib.lines import Line2D


def up_vs_down_plot(df, dataset, metric):
  """Bar-plot one metric per base model, split by the 'up_vs_down' column.

  The 'Det' row is not drawn as a bar; its value is shown as a horizontal red
  line instead.

  Args:
    df: Frame indexed by model name with 'base_model' and 'up_vs_down'
      columns and a 'Det' row.
    dataset: Second level of the column key, or None when the columns are
      flat.
    metric: Metric name (first level of the column key).
  """
  fontsize = 18
  tick_fontsize = 16

  if dataset is None:
    col = metric
    base_model_col = 'base_model'
    up_vs_down_col = 'up_vs_down'
  else:
    col = (metric, dataset)
    base_model_col = ('base_model', '')
    up_vs_down_col = ('up_vs_down', '')

  # The y-limits are derived from all rows, the baseline included.
  ymin = df[col].min()
  ymax = df[col].max()
  det_baseline = df.loc['Det', col]
  bars_df = df[df.index != 'Det']

  graph = sns.barplot(
      data=bars_df, x=base_model_col, y=col, hue=up_vs_down_col)
  graph.axhline(det_baseline, c='r', linewidth=2)
  plt.ylim(max(0, ymin - .5 * (ymax - ymin)), ymax + .1 * (ymax - ymin))
  plt.legend(title='location', fontsize=fontsize, title_fontsize=fontsize)

  # Turn the column name into a readable axis label.
  ylabel = metric
  for old, new in (('test_', ''), ('loss', 'NLL'), ('ece', 'ECE'),
                   ('prec@1', 'Accuracy')):
    ylabel = ylabel.replace(old, new)

  plt.gca().legend().set_visible(False)
  plt.ylabel(ylabel, fontsize=fontsize)
  plt.xticks(fontsize=tick_fontsize)
  plt.yticks(fontsize=tick_fontsize)
  plt.xlabel('')


plot_datasets = ['cifar10', 'cifar100', 'imagenet2012']
# Deliberately not called `plot_metrics`: that name belongs to the bar-plot
# helper function, and rebinding it to a list breaks every later call to the
# helper in the same notebook session.
up_vs_down_metrics = ['test_prec@1', 'test_ece', 'test_loss']

df = _add_up_vs_down_metadata(get_up_vs_down_df())

# One figure, saved and downloaded as its own PDF, per (dataset, metric) pair.
for dataset in plot_datasets:
  for metric in up_vs_down_metrics:
    up_vs_down_plot(df, dataset, metric)
    filename = f'up_vs_down_{dataset}_{metric}.pdf'
    plt.tight_layout()
    plt.savefig(filename, bbox_inches='tight', pad_inches=0)
    colabtools.fileedit.download_file(filename)
    plt.show()

# Draw any one of the figures again: only its legend entries are needed.
up_vs_down_plot(df, 'cifar10', 'test_prec@1')

# The legend goes into its own wide, flat figure.
ax = plt.gca()
fig_leg = plt.figure(figsize=(15, .5))
ax_leg = fig_leg.add_subplot(1, 1, 1)
# Reuse the bar-plot entries and add one for the red baseline line.
handles, labels = ax.get_legend_handles_labels()
handles.append(Line2D([0], [0], color='r'))
labels.append('Deterministic (upstream & downstream)')
ax_leg.legend(
    handles, labels, loc='center', ncol=3, fontsize=18, frameon=False)
# Hide the axes frame and the x/y labels.
ax_leg.axis('off')
fig_leg.savefig('up_vs_down_legend.pdf', bbox_inches='tight', pad_inches=0)
colabtools.fileedit.download_file('up_vs_down_legend.pdf')

# Open-set recognition

In [None]:
# Tuned results restricted to the OOD-related metrics. The result must be
# bound to `df_ood` (it was misspelled `ddf_ood`): this cell's display and all
# following cells read `df_ood`.
df_ood = colab_utils.process_tuned_results(
    measurements, relevant_metrics=colab_utils.ood_related_metrics())
df_ood

In [None]:
# Show the column labels of the OOD table.
df_ood.columns

In [None]:
#@title Comparing models by fixing on OOD method = MSP
# Metric names (first element of each column key) that mention 'msp'.
ood_msp_metrics = [key[0] for key in df_ood.columns if 'msp' in key[0]]
df_ood.loc[:, pd.IndexSlice[ood_msp_metrics, :]]

In [None]:
#@title Comparing OOD methods on hard near-OOD tasks: (1) ImageNet2012
# Every OOD metric whose column key has 'imagenet2012' as second element.
df_ood.loc[:, pd.IndexSlice[:, 'imagenet2012']]

In [None]:
#@title Comparing OOD methods on hard near-OOD tasks (2) CIFAR-100 vs CIFAR-10
# AUROC columns with CIFAR-10 as the OOD set, one per scoring method.
near_ood_metrics_cifar = [
    'ood_cifar10_msp_auroc',
    'ood_cifar10_entropy_auroc',
    'ood_cifar10_mlogit_auroc',
    'ood_cifar10_maha_auroc',
    'ood_cifar10_rmaha_auroc',
]
df_ood.loc[:, pd.IndexSlice[near_ood_metrics_cifar, 'cifar100']]

## Zero-shot OOD results

In [None]:
#@title load zero-shot data



In [None]:
# Entries of raw_measurements left out of the zero-shot OOD table.
excluded_keys = [
    'DE',
    'Det->DE',
    'DE S/32',
    'Det->DE S/32',
    'DE B/32',
    'Det->DE B/32',
    'DE L/32',
    'Det->DE L/32',
    'Det -> BE L/32 (n=2)',
    'Det -> BE L/32 (n=4)',
    'Det -> BE L/32 (n=8)',
    'E^3',
    'BE scaling',
    'MoE',
    'Det->BE',
    'Det->GP',
    'Det->Het',
    'BE->BE+Het',
]
included_measurements = {
    name: runs
    for name, runs in raw_measurements.items()
    if name not in excluded_keys
}
# NOTE(review): this rebinds the notebook-wide `measurements`, so cells run
# after this one see the reduced model set -- confirm that is intended.
measurements = get_optimal_results(included_measurements)

df_ood_zero_shot = colab_utils.process_tuned_results(
    measurements, relevant_metrics=colab_utils.ood_related_metrics())

In [None]:
# Metric names (first element of each column key) that mention 'maha'.
ood_maha_metrics = [
    key[0] for key in df_ood_zero_shot.columns if 'maha' in key[0]
]
df_ood_zero_shot.loc[:, pd.IndexSlice[ood_maha_metrics, :]]

# CIFAR subpopulation shift plot

In [None]:
# Hand-entered subpopulation-shift numbers, one (text, method) pair per
# method. Each text holds four lines of five '/'-separated values; the lines
# map, in order, to CIFAR10_30, CIFAR10_100, CIFAR100_30 and CIFAR100_100.
subpopl_metrics_raw = [
    (""".984 / .989 / .992 / .996 / .998
.973 / .987 / .993 / 1.0 / 1.0
.892 / .909 / .922 / .933 / .944
.878 / .904 / .920 / .940 / .961""", "None"),
    (""".986 / .990 / .995 / .997 / 1.0
.978 / .990 / 1.0  / 1.0 / 1.0
.912 / .931 / .937 / .945 / .960
.900 / .920 / .940 / .950 / .971""", "Plex"),
    # (""".982 / .987 / .990 / .993 / .998
    # .971 / .985 / .991 / 1.0 / 1.0
    # .901 / .922 / .933 / .943 / .953
    # .895 / .920 / .930 / .950 / .970""", "None I21K"),
    (""".986 / .990 / .994 / .997 / 1.0
.977 / .989 / 1.0 / 1.0 / 1.0
.907 / .923 / .933 / .944 / .959
.899 / .919 / .933 / .950 / .971""", "BE→BE"),
    (""".985 / .990 / .994 / .996 / 1.0
.977 / .987 / .999 / 1.0 / 1.0
.905 / .922 / .931 / .940 / .955
.896 / .915 / .930 / .949 / .970""", "None→BE"),
]


# Per-dataset dictionaries mapping method name -> list of five values.
subpopl_metrics_CIFAR10_30 = {}
subpopl_metrics_CIFAR10_100 = {}
subpopl_metrics_CIFAR100_30 = {}
subpopl_metrics_CIFAR100_100 = {}
for raw_metrics, key in subpopl_metrics_raw:
  flattened = raw_metrics.replace('\n', '/')
  metrics = [float(token.strip()) for token in flattened.split('/')]
  # 4 datasets, 5 metrics for each.
  assert len(metrics) == 20
  subpopl_metrics_CIFAR10_30[key] = metrics[0:5]
  subpopl_metrics_CIFAR10_100[key] = metrics[5:10]
  subpopl_metrics_CIFAR100_30[key] = metrics[10:15]
  subpopl_metrics_CIFAR100_100[key] = metrics[15:20]

subpopl_metrics = {
    'CIFAR10_30': subpopl_metrics_CIFAR10_30,
    'CIFAR10_100': subpopl_metrics_CIFAR10_100,
    'CIFAR100_30': subpopl_metrics_CIFAR100_30,
    'CIFAR100_100': subpopl_metrics_CIFAR100_100,
}

# One row per (dataset, method): the five values, the method, and the two
# halves of the dataset key (e.g. 'CIFAR10' and '30').
subpopl_metrics_list = []
for dataset, per_task in subpopl_metrics.items():
  dataset_base, dataset_tail = dataset.split('_')
  for task, values in per_task.items():
    subpopl_metrics_list.append([values, task, dataset_base, dataset_tail])

df = pd.DataFrame(
    subpopl_metrics_list,
    columns=['values', 'task', 'dataset', 'subpopulations'])

def subpopl_plot_fn(data, color):
  """Draw one box per method for a single facet of the subpopulation grid.

  Args:
    data: Facet rows with 'task', 'values' (a list of numbers per row) and
      'dataset' columns.
    color: Accepted but not used.

  Raises:
    ValueError: If the facet's dataset is neither 'CIFAR10' nor 'CIFAR100'.
  """
  values_by_task = dict(zip(data['task'].tolist(), data['values'].tolist()))
  sns.boxplot(
      data=pd.DataFrame(values_by_task),
      order=['Plex', 'BE→BE', 'None→BE', 'None'])

  # Fixed y-range per dataset.
  ylims = {'CIFAR10': (.97, 1.0), 'CIFAR100': (.87, .98)}
  dataset = data['dataset'].tolist()[0]
  if dataset not in ylims:
    raise ValueError()
  plt.ylim(*ylims[dataset])

  plt.xlabel('Method')
  plt.ylabel('Accuracy')

def plot_subpopl_metrics():
  """Plot the subpopulation ablation grid, save it as a PDF and download it.

  Reads the module-level `df`; rows of the grid are datasets and columns are
  subpopulation settings.

  Returns:
    The seaborn FacetGrid that was drawn.
  """
  matplotlib.rcParams['font.family'] = 'serif'
  matplotlib.rcParams['axes.titlepad'] = 10
  g = sns.FacetGrid(
      df, row="dataset", col='subpopulations', height=3.2, aspect=1.8,
      sharey='row')
  # Show tick labels on every panel.
  for panel in g.axes.flat:
    panel.tick_params(labelbottom=True, labelleft=True)
  g.map_dataframe(subpopl_plot_fn)
  plt.subplots_adjust(hspace=0.4, wspace=0.2)
  g.despine(right=False, top=False)
  # Label the y-axis on the first panel of each of the two rows.
  for row in (0, 1):
    g.axes[row, 0].set_ylabel('Accuracy')

  filename = 'subpopl_ablations.pdf'
  plt.savefig(filename, bbox_inches='tight', pad_inches=0, dpi=1000)
  colabtools.fileedit.download_file(filename)

  return g


plot_subpopl_metrics()