In [None]:
#@title Copyright 2022 Google LLC, licensed under the Apache License, Version 2.0 (the "License")
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
import pickle
from collections import defaultdict
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
from scipy.interpolate import interp1d

In [None]:
def get_log_likelihoods(vis_dist, mode, root='vae_ood/models'):
  """Loads original and corrected log likelihoods from probs.pkl files.

  For every in-distribution (ID) dataset of the chosen `mode` and every input
  normalization (none, 'pctile-5', and a histogram-equalization variant), this
  reads the pickled results of the matching trained model and collects the
  scores stored there for each dataset of the same mode plus 'noise'.

  Args:
    vis_dist: Visible dist of the model; the sub-directory of `root` holding
      its runs (e.g. 'cont_bernoulli' or 'categorical').
    mode: "grayscale" or "color". Any value other than "grayscale" selects
      the color datasets.
    root: Directory containing one sub-directory per visible distribution.
      Defaults to the path the notebook has always used.
  Returns:
    A nested dictionary log_probs[f'{id_data}-{norm}'][f'{ood_data}-{norm}']
    whose values are dicts with keys 'orig_probs' and 'corr_probs', taken from
    d['orig_probs'][ood_data] / d['corr_probs'][ood_data] of the pickle `d`.
    `norm` is formatted with str(), so the un-normalized variant is 'None'.
  """
  if mode == 'grayscale':
    datasets = [
      'mnist',
      'fashion_mnist',
      'emnist/letters',
      'sign_lang',
    ]
    nf = 32
    cs_hist = 'adhisteq'
  else:
    datasets = [
      'svhn_cropped',
      'cifar10',
      'celeb_a',
      'gtsrb',
      'compcars',
    ]
    nf = 64
    cs_hist = 'histeq'

  log_probs = defaultdict(lambda: defaultdict(dict))
  for id_data in datasets:
    for norm in [None, 'pctile-5', cs_hist]:
      # Run directory names encode the ID dataset ('/' is not allowed in a
      # path component), the normalization and the training hyperparameters.
      path = (f'{root}/{vis_dist}/'
              f'{id_data.replace("/", "_")}-{norm}-zdim_20-lr_0.0005-bs_64-nf_{nf}/'
              'probs.pkl')
      # NOTE: pickle.load can execute arbitrary code; only load probs.pkl
      # files produced locally or from a trusted source.
      with open(path, 'rb') as f:
        d = pickle.load(f)
      for ood_data in datasets + ['noise']:
        log_probs[f'{id_data}-{norm}'][f'{ood_data}-{norm}']['orig_probs'] = d['orig_probs'][ood_data]
        log_probs[f'{id_data}-{norm}'][f'{ood_data}-{norm}']['corr_probs'] = d['corr_probs'][ood_data]
  return log_probs

In [None]:
def _fpr_at_tpr(labels, scores, tpr_level):
  """Returns the FPR at which the ROC curve reaches a TPR of `tpr_level`.

  Linearly interpolates between the last ROC point whose TPR is <= tpr_level
  and the first point whose TPR exceeds it. `tpr_level` must be in [0, 1).
  """
  fpr, tpr, _ = roc_curve(labels, scores, pos_label=1, drop_intermediate=False)
  # First index with TPR strictly above the target. roc_curve's first point
  # has TPR 0 (no sample predicted positive), so ind >= 1 and ind - 1 is valid.
  ind = np.argmax(tpr > tpr_level)
  return float(np.interp(tpr_level, tpr[ind - 1:ind + 1], fpr[ind - 1:ind + 1]))


def _ood_metrics(id_scores, ood_scores, max_samples, tpr_level):
  """Measures how well `id_scores` separate ID samples from OOD samples.

  ID samples are the positive class (label 1) and OOD samples the negative
  class (label 0), so a higher score should mean "more in-distribution".
  At most `max_samples` leading samples of each set are used (None = all).

  Returns:
    (AUROC, AUPRC, FPR at `tpr_level` TPR), each as a percentage.
  """
  ood = np.asarray(ood_scores)[:max_samples]
  idd = np.asarray(id_scores)[:max_samples]
  # Labels are derived from the very arrays being scored, so they always
  # line up with `scores` even if orig/corr arrays differ in length.
  labels = np.concatenate([np.zeros_like(ood), np.ones_like(idd)])
  scores = np.concatenate([ood, idd])
  return (roc_auc_score(labels, scores) * 100,
          average_precision_score(labels, scores) * 100,
          _fpr_at_tpr(labels, scores, tpr_level) * 100)


def get_metrics(log_probs, max_samples=10000, tpr_level=0.8):
  """Computes AUROC, AUPRC and FPR@80 metrics from loaded log likelihoods.

  For every (ID, OOD) pair the ID scores log_probs[id][id] are compared with
  the OOD scores log_probs[id][ood], once using the original ('orig_probs')
  and once using the corrected ('corr_probs') log likelihoods.

  Args:
    log_probs: original and corrected log likelihoods for all ID-OOD
               pairs as returned by get_log_likelihoods()
    max_samples: maximum number of leading samples taken from each of the
                 ID and OOD score arrays; None uses all of them.
    tpr_level: TPR in [0, 1) at which the FPR metric is read off.
  Returns:
    A nested dictionary metrics[id][ood] with keys 'orig_roc', 'orig_prc',
    'orig_fpr', 'corr_roc', 'corr_prc', 'corr_fpr', all percentages.
  """
  metrics = defaultdict(lambda: defaultdict(dict))
  for id_data, ood_entries in log_probs.items():
    for ood_data, ood_probs in ood_entries.items():
      for prefix in ('orig', 'corr'):
        key = f'{prefix}_probs'
        roc, prc, fpr = _ood_metrics(log_probs[id_data][id_data][key],
                                     ood_probs[key], max_samples, tpr_level)
        metrics[id_data][ood_data][f'{prefix}_roc'] = roc
        metrics[id_data][ood_data][f'{prefix}_prc'] = prc
        metrics[id_data][ood_data][f'{prefix}_fpr'] = fpr
  return metrics

In [None]:
def print_metrics(metrics):
  """Returns key metrics (AUROC, in percent) in a dataframe.

  Rows are ID datasets and columns are OOD datasets plus 'noise'. Each cell
  is a two-element list [original AUROC with no normalization, corrected
  AUROC with 'pctile-5' normalization], each rounded to an int.

  Args:
    metrics: metrics dict returned by get_metrics()
  Returns:
    A dataframe containing key metrics, indexed by 'ID Data ↓ OOD Data →'.
  """
  index_name = 'ID Data ↓ OOD Data →'
  # dict.fromkeys keeps first-seen order; a set would give a hash-dependent
  # order that changes between Python sessions.
  id_names = list(dict.fromkeys(dname.split('-')[0] for dname in metrics))
  ood_names = id_names + ['noise']
  rows = []
  for id_data in id_names:
    row = {index_name: id_data}
    for ood_data in ood_names:
      row[ood_data] = [
          int(round(metrics[f'{id_data}-None'][f'{ood_data}-None']['orig_roc'],
                    0)),
          int(round(metrics[f'{id_data}-pctile-5'][f'{ood_data}-pctile-5']['corr_roc'],
                    0))
          ]
    rows.append(row)
  # Build the frame once: DataFrame.append was removed in pandas 2.0 and
  # growing a frame row by row is quadratic.
  df = pd.DataFrame(rows, columns=[index_name] + ood_names)
  return df.set_index(index_name, drop=True)

In [None]:
# Models with the 'cont_bernoulli' visible distribution, grayscale datasets.
cb_grayscale_lls = get_log_likelihoods(vis_dist='cont_bernoulli', mode='grayscale')
cb_grayscale_metrics = get_metrics(log_probs=cb_grayscale_lls)
print_metrics(metrics=cb_grayscale_metrics)

In [None]:
# Models with the 'categorical' visible distribution, grayscale datasets.
# (Loading 'cont_bernoulli' here would just duplicate the CB grayscale results
# under the cat_ names.)
cat_grayscale_lls = get_log_likelihoods('categorical', 'grayscale')
cat_grayscale_metrics = get_metrics(cat_grayscale_lls)
print_metrics(cat_grayscale_metrics)

In [None]:
# Models with the 'cont_bernoulli' visible distribution, color datasets.
cb_color_lls = get_log_likelihoods(vis_dist='cont_bernoulli', mode='color')
cb_color_metrics = get_metrics(log_probs=cb_color_lls)
print_metrics(metrics=cb_color_metrics)

In [None]:
# Models with the 'categorical' visible distribution, color datasets.
cat_color_lls = get_log_likelihoods(vis_dist='categorical', mode='color')
cat_color_metrics = get_metrics(log_probs=cat_color_lls)
print_metrics(metrics=cat_color_metrics)