# **Understanding Deep Generative Models with Generalized Empirical Likelihoods**


Copyright 2023 DeepMind Technologies Limited

All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0

All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY). You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode

Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.

This is not an official Google product.

# Imports

In [None]:
!pip install ml_collections

import functools
import io
import os

from absl import flags
from absl import logging
from ml_collections import config_dict
import numpy as np
from scipy.optimize import linprog
from scipy.stats import entropy
import timeit

from sklearn.metrics import pairwise_distances
from enum import Enum
from google.colab import auth as google_auth

import matplotlib.pyplot as plt

# Show INFO-level messages; the GEL solvers below report progress via logging.info.
logging.set_verbosity(logging.INFO)

# Copy Data from Google Cloud Bucket

Authenticate the user and list the data in the GCP bucket.

In [None]:
# Authenticate this Colab session with Google Cloud, then list the bucket contents.
google_auth.authenticate_user()
!gsutil ls gs://dm_gel_metric/cifar10_mode_drop_data/

Copy the data to the local disk and list the copied files.

In [None]:
# Create the local data directory, copy the .npz files into it, and list them.
!mkdir -p cifar10_mode_drop_data
!gsutil cp gs://dm_gel_metric/cifar10_mode_drop_data/*.npz cifar10_mode_drop_data/
!ls cifar10_mode_drop_data

# GEL Code

## Helper Functions and Classes

In [None]:
def approx_in_cvx_hull(
    hull_points: np.ndarray, test_point: np.ndarray, eps: float = 0.01):
  """Triangle alg. to see if test_point is in the convex hull of hull_points.

  Implementation of Kalantari et al.,
  "Randomized triangle algorithms for convex hull membership".

  The mean of `hull_points` is appended to the candidate points; it lies
  inside the hull, so the hull is unchanged. Starting from the candidate
  nearest to `test_point`, a pivot point inside the hull is moved towards
  `test_point`. The loop stops in one of two ways:
    * the pivot comes within `eps * radius` of `test_point`, and membership
      is accepted;
    * the pivot is strictly closer than `test_point` to every candidate. This
      is a witness that `test_point` lies outside the hull.

  Args:
    hull_points: (n_points, n_dim) matrix of hull points.
    test_point: (n_dim,) vector which is the test point.
    eps: relative tolerance. The pivot must come within `eps * radius` of
      `test_point`, where `radius` is the largest distance from `test_point`
      to a candidate point.
  Returns:
    in_hull: True if the test point is (approximately) in the convex hull of
      hull_points; False if a witness proving otherwise was found.
    pivot_point: final pivot point of the algorithm.
  """
  mean_hull_point = np.mean(hull_points, axis=0)
  aug_hull_points = np.vstack([hull_points, mean_hull_point])

  def calc_dists(hull_points, point):
    return np.sqrt(np.sum((hull_points - point) ** 2, axis=1))

  def nearest_point(test_point, pivot_point, hull_point):
    # Project test_point onto the segment [pivot_point, hull_point].
    alpha = (np.dot(test_point - pivot_point, hull_point - pivot_point) /
             np.sum((hull_point - pivot_point) ** 2))
    if alpha >= 0.0 and alpha < 1.0:
      return (1.0 - alpha) * pivot_point + alpha * hull_point
    return hull_point

  def get_new_hull_point(sub_hull_points, pivot_point, test_point):
    # Choose the candidate whose direction from the pivot is best aligned
    # with the direction from the pivot to the test point.
    norm_vec = ((test_point - pivot_point) /
                np.linalg.norm(test_point - pivot_point))
    a_mat = sub_hull_points - pivot_point
    a_mat /= np.linalg.norm(a_mat, axis=1)[:, np.newaxis]
    return sub_hull_points[np.argmax(np.dot(a_mat, norm_vec)), :]

  dist_vec = calc_dists(aug_hull_points, test_point)
  radius = np.max(dist_vec)
  pivot_point = aug_hull_points[np.argmin(dist_vec), :]
  i = 0
  while np.linalg.norm(pivot_point - test_point) > eps * radius:
    dist_vec_pivot = calc_dists(aug_hull_points, pivot_point)
    if np.all(dist_vec_pivot < dist_vec):
      # convex hull condition failed: the pivot is a witness.
      logging.info('Witness found at iter %d', i)
      return False, pivot_point
    # Candidate pivots v satisfy d(pivot, v) >= d(test_point, v). This is
    # exactly the complement of the witness test above, so the set is never
    # empty here. A strict `>` left it empty on ties, and argmax then raised.
    sub_hull_points = aug_hull_points[dist_vec_pivot >= dist_vec, :]
    hull_point = get_new_hull_point(sub_hull_points, pivot_point, test_point)
    pivot_point = nearest_point(test_point, pivot_point, hull_point)
    i += 1

  logging.info('It is probably in the convex hull. Took %d iterations', i)
  logging.info('Distance for iter %d is %.5f, radius %.5f', i,
               np.linalg.norm(pivot_point - test_point) / radius, radius)
  return True, pivot_point


def is_in_cvx_hull(hull_points: np.ndarray, test_point: np.ndarray):
  """Checks to see if test_point is in the convex hull of hull_points.

  Solves the linear feasibility problem
    find x >= 0  s.t.  hull_points.T @ x = test_point,  sum(x) = 1
  with `scipy.optimize.linprog` and a zero objective.

  Args:
    hull_points: (n_points, n_dim) matrix of hull points.
    test_point: (n_dim,) vector which is the test point.
  Returns:
    in_hull: True if test point is in the convex hull of hull_points.
  """
  n_points, n_dims = hull_points.shape
  c = np.zeros((n_points,))

  # The first n_dims equality rows reproduce test_point as a combination of
  # the hull points. The last row forces the weights to sum to one.
  matrix_a_eq = np.ones((n_dims + 1, n_points))
  matrix_a_eq[:n_dims, :] = hull_points.T

  b_eq = np.ones((n_dims + 1,))
  b_eq[:n_dims] = test_point

  # Express non-negativity through `bounds` rather than a dense -I inequality
  # matrix. That matrix needs n_points**2 floats (800MB for 10k points).
  res = linprog(c, A_eq=matrix_a_eq, b_eq=b_eq, bounds=(0, None))
  logging.info(res.status)
  if res.success:
    # Vertex solutions usually contain exact zeros; log2(0) = -inf is
    # expected here, so suppress the floating-point warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
      log_prob = np.sum(np.log2(res.x))
    logging.info('Initial log prob is %.5f', log_prob)
  return bool(res.success)


def convert_space_to_dims(features: np.array):
  """Collapses every non-leading axis of `features` into one feature axis.

  Args:
    features: array whose first axis indexes examples.
  Returns:
    2-D array of shape (features.shape[0], prod(features.shape[1:])).
  """
  num_examples = features.shape[0]
  return np.reshape(features, (num_examples, -1))


def do_pca(features: np.ndarray, pca_dim: int):
  """Projects features onto their leading principal directions.

  The features are rotated onto the eigenvectors of their covariance with
  the largest eigenvalues. They are not centered first, so the origin maps
  to the origin. The output is also not rescaled to unit variance, despite
  the "Whitening" log message.

  Args:
    features: (n_examples, n_dims) matrix of features.
    pca_dim: number of leading principal directions to keep.
  Returns:
    (n_examples, min(pca_dim, n_dims)) matrix of projected features.
  """
  logging.info('Whitening...')
  # atleast_2d: np.cov returns a 0-d array for one-dimensional features.
  cov_mat = np.atleast_2d(np.cov(features, rowvar=False))
  # eigh is the solver for symmetric matrices. It returns real eigenpairs in
  # *ascending* eigenvalue order, so reorder them to put the highest-variance
  # directions first. np.linalg.eig guarantees no order, so slicing its
  # output did not select the principal components.
  eig_vals, eig_vecs = np.linalg.eigh(cov_mat)
  order = np.argsort(eig_vals)[::-1]
  features = np.dot(features, eig_vecs[:, order[:pca_dim]])
  logging.info('Done')

  return features

def one_sample_emp_lik_iteration(feature_diffs: np.ndarray, params: np.ndarray):
  """Perform newton iteration for empirical likelihood.

  Takes one Newton ascent step on sum_i log*(1 + feature_diffs[i] . params).
  Here log* is the modified logarithm: log(z) for z >= 1/n, and a quadratic
  extension below 1/n (n = number of examples).

  Args:
    feature_diffs: (n_examples, n_dims) per-sample features for moment
      conditions.
    params: (n_dims,) current parameters for calculating empirical
      likelihood. The array is not modified.
  Returns:
    params: new parameter array after a Newton step.
    output_stats: dictionary of output statistics:
      'probs': 1 / (n * (1 + feature_diffs @ params)) for the new params;
      'obj': sum of log probs;
      'n_out_of_domain': number of samples with 1 + x . params < 1/n before
        the step;
      'log_grad_norm': norm of the gradient before the step.
  Raises:
    np.linalg.LinAlgError: if the Newton system is singular.
  """
  num_examples = feature_diffs.shape[0]
  z = 1.0 + np.dot(feature_diffs, params)
  inv_n = 1.0 / num_examples
  pos_mask = z >= inv_n
  neg_mask = z < inv_n

  # positive part of the modified logarithm: log(z), derivative 1 / z.
  w_pos = 1.0 / z[pos_mask]
  f_diff_pos = feature_diffs[pos_mask, :] * w_pos[:, np.newaxis]

  # negative part of the modified logarithm: derivative n * (2 - n * z),
  # constant curvature -n**2.
  w_neg = (2.0 - num_examples * z[neg_mask]) * num_examples
  f_diff_neg = feature_diffs[neg_mask, :]
  num_egs2 = num_examples ** 2

  neg_hess = (np.dot(f_diff_pos.T, f_diff_pos)
              + np.dot(f_diff_neg.T, f_diff_neg) * num_egs2)
  sc_f_diff_neg = f_diff_neg * w_neg[:, np.newaxis]
  log_grad = np.sum(f_diff_pos, axis=0) + np.sum(sc_f_diff_neg, axis=0)
  log_grad_norm = np.linalg.norm(log_grad)

  direction = np.linalg.solve(neg_hess, log_grad)
  # Build a new array; `params += ...` silently mutated the caller's array.
  params = params + direction
  n_out_of_domain = f_diff_neg.shape[0]

  probs = 1.0 / (num_examples * (1.0 + np.dot(feature_diffs, params)))
  log_lik = np.sum(np.log(probs))
  output_stats = dict(
      probs=probs, obj=log_lik,
      n_out_of_domain=n_out_of_domain, log_grad_norm=log_grad_norm,)
  return params, output_stats


def hellinger_dist(p: np.ndarray, q: np.ndarray):
  """Hellinger distance between two discrete distributions.

  Both inputs are normalized to sum to one before comparison.

  Args:
    p: non-negative weights of the first distribution.
    q: non-negative weights of the second distribution (same shape as p).
  Returns:
    sqrt(1 - sum_i sqrt(p_i * q_i)) for the normalized inputs, in [0, 1].
  """
  assert np.all(p >= 0.0)
  assert np.all(q >= 0.0)
  norm_p = p / np.sum(p)
  norm_q = q / np.sum(q)
  bhattacharyya = np.sum(np.sqrt(norm_p * norm_q))
  # Rounding can push the Bhattacharyya coefficient slightly above 1 for
  # (near-)identical inputs. Clip so the distance is 0 instead of NaN.
  return np.sqrt(np.maximum(1.0 - bhattacharyya, 0.0))


class GELStatus(Enum):
  """Termination status of a GEL optimization."""
  SOLVED = "solved"
  NOT_IN_CONVEX_HULL = "not_in_convex_hull"
  BOUNDARY = "boundary"
  OPTIMIZATION_FAILURE = "optimization_failure"
  # Alias with the same value. `GELStatus.OPTIMIZATION_FAILED is
  # GELStatus.OPTIMIZATION_FAILURE`, so code using either spelling works, and
  # the alias does not add a member when iterating the enum.
  OPTIMIZATION_FAILED = "optimization_failure"
  RUNNING = "running"


class GELObjective(Enum):
  """Objective (divergence from uniform weights) minimized by GEL."""
  EMPIRICAL_LIKELIHOOD = 1
  EXPONENTIAL_TILTING = 2
  EUCLIDEAN_LIKELIHOOD = 3

## One-Sample GEL code

In [None]:
class OneSampleGEL(object):
  """One-sample generalized empirical likelihood (GEL) of test vs. model.

  The moment conditions are the per-sample differences between the test
  features and the mean of the model features (see `_preprocess_features`).
  The solver aims to find weights on the test samples that satisfy these
  moment conditions while staying close to uniform. Closeness is measured by
  the objective selected with `config.obj_type`.

  Config fields read: obj_type, num_model_examples, cut_dim, whiten, pca_dim,
  tol, grad_norm_tol, norm_param_tol, num_iterations, plus `to_dict()`.
  """

  def __init__(self, config, model_unprocessed_feats: np.ndarray,
               test_unprocessed_feats: np.ndarray):
    self.config = config
    feature_diffs, test_features, model_features = self._preprocess_features(
        model_unprocessed_feats, test_unprocessed_feats)
    self.feature_diffs = feature_diffs
    self.test_features = test_features
    self.model_features = model_features

    self._current_loss = np.inf
    iter_func, output_stats, gel_status, params = self._init_optimizer()
    self._iter_func = iter_func
    self._output_stats = output_stats
    self._status = gel_status  # termination conditions
    self._params = params

  def _init_optimizer(self):
    """Selects the iteration function and builds the initial solver state.

    Returns:
      iter_func: bound method performing one optimization step.
      output_stats: initial statistics (NaN probabilities, infinite loss).
      gel_status: RUNNING, or NOT_IN_CONVEX_HULL if the origin is
        (approximately) outside the convex hull of the moment conditions.
      params: zero-initialized dual parameters.
    Raises:
      ValueError: if `config.obj_type` is not a known objective.
    """
    num_examples, ndims = self.feature_diffs.shape
    if self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD:
      iter_func = self.emp_lik_iteration
    elif self.config.obj_type == GELObjective.EXPONENTIAL_TILTING:
      iter_func = self.exp_tilted_iteration
    elif self.config.obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:
      iter_func = self.euc_lik_iteration
    else:
      raise ValueError(f'Objective type {self.config.obj_type} not valid')
    probs = np.full((num_examples,), np.nan)
    # 'loss' (not 'obj') is the key that _print_stats and the solver loop read.
    output_stats = dict(probs=probs, loss=np.inf,
                        n_out_of_domain=0, log_grad_norm=np.inf)
    params = np.zeros((ndims,))
    gel_status = GELStatus.RUNNING
    # Check if the convex hull condition is satisfied.
    # This condition does not apply to the Euc. Likelihood
    if not self.config.obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:
      if not approx_in_cvx_hull(self.feature_diffs, np.zeros((ndims,)))[0]:
        gel_status = GELStatus.NOT_IN_CONVEX_HULL
    return iter_func, output_stats, gel_status, params

  def euc_lik_iteration(self):
    """Perform newton step for euclidean likelihood (which solves the problem).

    Sets probs_i = (1 - (x_i - mean(x)) . params) / n with
    params = S^{-1} mean(feature_diffs), where S is the biased covariance of
    the test features. The weighted test mean then equals the model mean.
    """
    num_examples = self.feature_diffs.shape[0]
    # atleast_2d: np.cov returns a 0-d array for one-dimensional features,
    # which np.linalg.solve rejects.
    sample_cov = np.atleast_2d(
        np.cov(self.test_features, rowvar=False, bias=True))
    self._params = np.linalg.solve(
        sample_cov, np.mean(self.feature_diffs, axis=0))
    demeaned_test_features = self.test_features - np.mean(
        self.test_features, axis=0)
    probs = (1 - np.dot(demeaned_test_features, self._params)) / num_examples

    # square is probably "more" correct but easier for per point
    euc_dist = np.sum((probs - 1.0 / num_examples) ** 2) * num_examples
    log_grad_norm, n_out_of_domain = 0.0, 0
    self._output_stats = dict(
        probs=probs, loss=euc_dist,
        n_out_of_domain=n_out_of_domain, log_grad_norm=log_grad_norm,)

  def exp_tilted_iteration(self):
    """Perform half newton step for exponential tilting objective.

    The loss recorded is the negative base-2 entropy of the tilted weights.
    """
    num_examples = self.feature_diffs.shape[0]
    w_exp_tilt = np.exp(np.dot(self.feature_diffs, self._params)) / num_examples
    sc_f_diff = self.feature_diffs * w_exp_tilt[:, np.newaxis]

    hess = np.dot(sc_f_diff.T, self.feature_diffs)
    log_grad = np.sum(sc_f_diff, axis=0)
    newton_step = np.linalg.solve(hess, log_grad)
    log_grad_norm = np.linalg.norm(log_grad)

    self._params -= 0.5 * newton_step
    exp_weights = np.exp(np.dot(self.feature_diffs, self._params))
    probs = exp_weights / np.sum(exp_weights)
    n_out_of_domain = 0
    ent = entropy(probs, base=2)
    self._output_stats = dict(
        probs=probs, loss=-ent,
        n_out_of_domain=n_out_of_domain, log_grad_norm=log_grad_norm,)

  def emp_lik_iteration(self):
    """Perform newton iteration for empirical likelihood objective.

    Updates `self._params` and `self._output_stats`. The loss is the
    negative mean log-probability of the weights produced by this step.
    """
    # The empirical likelihood iteration is in the helper function section
    # since we use it for both the one-sample and two-sample versions
    self._params, output_stats = one_sample_emp_lik_iteration(
        self.feature_diffs, self._params)
    del output_stats['obj']
    # Use this step's probabilities. Reading self._output_stats here picked
    # up the previous iteration's (initially NaN) probabilities.
    output_stats['loss'] = -np.mean(np.log(output_stats['probs']))
    self._output_stats = output_stats

  def _print_stats(self, iter_i: int):
    """Prints statistics at the given iteration."""
    logging.info('Loss is at iter %d is %.8f',
                 iter_i, self._output_stats['loss'])
    logging.info('minimum probability is %.8f',
                 np.min(self._output_stats['probs']))
    logging.info('maximum probability is %.8f',
                 np.max(self._output_stats['probs']))
    logging.info('sum of probability is %.8f',
                 np.sum(self._output_stats['probs']))
    logging.info('no. not in domain is %d',
                 self._output_stats['n_out_of_domain'])
    logging.info('log grad norm is %.15f',
                 self._output_stats['log_grad_norm'])

  def _preprocess_features(self, pre_model_features: np.ndarray,
                           pre_test_features: np.ndarray):
    """Loads and converts features for calculation gen. empirical likelihood.

    Args:
      pre_model_features: unpreprocessed model features.
      pre_test_features: unpreprocessed test features.
    Returns:
      feature_diffs: per-sample moment conditions.
      test_features: per-sample test features.
      model_features: per-sample model features.
    """
    if self.config.num_model_examples > 0:
      pre_model_features = pre_model_features[:self.config.num_model_examples]
    assert len(pre_model_features.shape) == len(pre_test_features.shape)

    # convert features of size [bs, h, w, c] -> [bs, dim_new]
    pre_model_features = convert_space_to_dims(pre_model_features)
    pre_test_features = convert_space_to_dims(pre_test_features)

    # Use fewer dimensions if flags ask us to.
    cut_dim = min(self.config.cut_dim, np.prod(pre_test_features.shape[1:]))
    ndims = pre_model_features.shape[1]
    assert cut_dim <= ndims, f'cut_dim {cut_dim} and ndims {ndims}'
    model_features = pre_model_features[:, :cut_dim]
    test_features = pre_test_features[:, :cut_dim]
    assert test_features.shape[1] == model_features.shape[1]

    model_means = np.mean(model_features, axis=0)
    feature_diffs = test_features - model_means[np.newaxis, :]

    if self.config.whiten:
      assert self.config.obj_type != GELObjective.EUCLIDEAN_LIKELIHOOD
      feature_diffs = do_pca(feature_diffs, self.config.pca_dim)

    return feature_diffs, test_features, model_features

  def _num_out_of_domain(self) -> bool:
    """Logs invalid probabilities; reports whether samples left the domain.

    Returns:
      True if the last step had samples outside the domain of the log.
    """
    probs = self._output_stats['probs']
    # Compare element-wise *before* reducing. `np.any(probs) < 0.0` compared
    # a boolean with zero, so it could never be True.
    if np.any(probs < 0.0):
      logging.info(
          'Encountered negative probability, likely due to optimization. '
          'This will likely be fixed in an iteration or two.')
    if np.any(probs > 1.0):
      logging.info(
          'Encountered probability > 1.0, likely due to optimization. '
          'This will likely be fixed in an iteration or two.')
    return self._output_stats['n_out_of_domain'] > 0

  def _check_norm_param_condition(self):
    """Flags BOUNDARY for empirical likelihood when the params norm blows up."""
    norm_params = np.linalg.norm(self._params)
    if (self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD
        and norm_params > self.config.norm_param_tol):
      logging.info('Mean likely near boundary since norm is high '
                   '%.5f, so likelihood is -Inf (<%.5f)',
                   norm_params, np.sum(np.log(self._output_stats['probs'])))
      self._status = GELStatus.BOUNDARY

  def _check_if_solved(self, iteration: int):
    """Marks SOLVED when the loss change or gradient norm is below tolerance."""
    if iteration % 10 == 1: self._print_stats(iteration)
    if not np.isfinite(self._output_stats['loss']):
      return
    elif not np.isfinite(self._current_loss):
      return
    if (np.abs(self._output_stats['loss']-self._current_loss) < self.config.tol
        or self._output_stats['log_grad_norm'] < self.config.grad_norm_tol):
      logging.info('GEL calculation converged at iteration %d. Terminating.',
                   iteration)
      self._print_stats(iteration)
      self._status = GELStatus.SOLVED

  def _calculate_objective(self):
    """Returns (objective, best_objective, worst_objective) for the probs."""
    probs = self._output_stats['probs']
    num_examples = probs.shape[0]
    if self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD:
      objective = np.mean(np.log(probs))
      best_objective = -np.log(num_examples)
      worst_objective = -np.inf
    elif self.config.obj_type == GELObjective.EXPONENTIAL_TILTING:
      objective = entropy(probs, base=2)
      best_objective = np.log2(num_examples)
      worst_objective = 0.0
    elif self.config.obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:
      objective = np.sum((probs - 1. / num_examples) ** 2) * num_examples
      best_objective = 0.0
      worst_objective = np.inf
    else:
      raise ValueError(f'Objective type {self.config.obj_type} not valid')

    return objective, best_objective, worst_objective

  def calculate_divergence(self):
    """Divergence of the current probabilities from uniform weights."""
    probs = self._output_stats['probs']
    num_examples = probs.shape[0]
    if self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD:
      divergence = -np.mean(np.log(num_examples * probs))
    elif self.config.obj_type == GELObjective.EXPONENTIAL_TILTING:
      divergence = np.sum(probs * np.log(probs * num_examples))
    elif self.config.obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:
      divergence = np.sum((probs - 1. / num_examples) ** 2) * num_examples
    else:
      raise ValueError(f'Objective type {self.config.obj_type} not valid')

    return divergence

  def _evaluate_gel_solution(self, elapsed_time: float = 0.0):
    """Finalizes the status and packs the results into a dictionary.

    If the run did not end in SOLVED, the status is refined to
    NOT_IN_CONVEX_HULL or OPTIMIZATION_FAILURE. The objective is then
    replaced by its worst value and the probabilities are set to NaN.

    Args:
      elapsed_time: seconds spent in the optimization loop.
    Returns:
      Dict with the config fields plus 'termination_status', 'probs',
      'objective', 'elapsed_time', 'model_feat_dims' and 'test_feat_dims'.
    """
    objective, best_objective, worst_objective = self._calculate_objective()
    if self._status == GELStatus.SOLVED:
      logging.info('solved... in %f seconds', elapsed_time)
      logging.info('Final objective is %.5f', objective)
      logging.info('Best objective is %.5f', best_objective)
    elif self._status == GELStatus.BOUNDARY:
      assert self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD
      logging.info('final log lik is -Inf_boundary')
    elif self._status == GELStatus.NOT_IN_CONVEX_HULL:
      logging.info('Convex Hull condition not satisfied')
      err = ('model mean not in convex hull of features... distributions not '
             'close enough to GEL')
      logging.info(err)
    else:
      logging.info('failed to solve... checking to see if model mean '
                   'is in convex hull')
      logging.info('calculating convex hull')
      # Test whether the origin is in the convex hull of the (possibly
      # PCA-projected) moment conditions. Without projection this is the same
      # as the model mean lying in the hull of the test features. The
      # previous code referenced undefined names and raised NameError.
      origin = np.zeros((self.feature_diffs.shape[1],))
      if not is_in_cvx_hull(self.feature_diffs, origin):
        err = ('model mean not in convex hull of features... distributions not '
               'close enough to GEL')
        logging.info(err)
        self._status = GELStatus.NOT_IN_CONVEX_HULL
      else:
        logging.info('optimization failed')
        self._status = GELStatus.OPTIMIZATION_FAILURE

    if self._status != GELStatus.SOLVED:
      objective = worst_objective
      self._output_stats['probs'][:] = np.nan
    out_dict = dict()
    out_dict.update(self.config.to_dict())
    out_dict['termination_status'] = self._status
    out_dict['probs'] = self._output_stats['probs']
    out_dict['objective'] = objective
    out_dict['elapsed_time'] = elapsed_time
    out_dict['model_feat_dims'] = self.model_features.shape
    out_dict['test_feat_dims'] = self.test_features.shape

    return out_dict

  def calculate_gel(self):
    """Calculates empirical likelihood and output results to a dictionary."""
    if self._status != GELStatus.RUNNING:
      return self._evaluate_gel_solution()

    start_time = timeit.default_timer()
    for i in range(self.config.num_iterations):
      try:
        self._iter_func()
      except np.linalg.LinAlgError:  # This error indicates EL hit the boundary
        logging.info('Encountered LinAlg Error, means we hit boundary cond')
        self._status = GELStatus.BOUNDARY
        break
      self._check_norm_param_condition()
      if self._status == GELStatus.BOUNDARY:
        break
      if self._num_out_of_domain(): continue
      self._check_if_solved(i)
      if self._status == GELStatus.SOLVED:
        break
      self._current_loss = self._output_stats['loss']

    elapsed_time = timeit.default_timer() - start_time
    return self._evaluate_gel_solution(elapsed_time)

## Two-Sample GEL Code

In [None]:
class TwoSampleGEL(object):
  """Two-sample generalized empirical likelihood (GEL) of model vs. test.

  Looks for weights p on the model samples and q on the test samples,
  each summing to one, with equal weighted means. The weights should stay
  close to uniform under the objective selected with `config.obj_type`.

  Config fields read: obj_type, whiten, pca_dim, tol, grad_norm_tol,
  norm_param_tol, num_iterations, plus `to_dict()`.
  """

  def __init__(self, config, model_unprocessed_feats: np.ndarray,
               test_unprocessed_feats: np.ndarray):
    self.config = config
    res = self._preprocess_features(
        model_unprocessed_feats, test_unprocessed_feats)
    model_feats, test_feats, aug_model_feats, aug_test_feats, aug_feats = res
    self.model_features = model_feats
    self.test_features = test_feats
    self.aug_features = aug_feats
    self.aug_model_features = aug_model_feats
    self.aug_test_features = aug_test_feats

    self._current_loss = np.inf
    iter_func, output_stats, gel_status, params = self._init_optimizer()
    self._iter_func = iter_func
    self._output_stats = output_stats
    self._status = gel_status  # termination conditions
    self._params = params

  def _init_optimizer(self):
    """Selects the iteration function and builds the initial solver state.

    Returns:
      iter_func: bound method performing one optimization step.
      output_stats: initial statistics (NaN probabilities, infinite losses).
      gel_status: RUNNING, or NOT_IN_CONVEX_HULL if the origin is
        (approximately) outside the convex hull of the augmented features.
      params: zero-initialized dual parameters.
    Raises:
      ValueError: if `config.obj_type` is not a known objective.
    """
    num_model_examples = self.model_features.shape[0]
    num_test_examples = self.test_features.shape[0]
    if self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD:
      iter_func = self.emp_lik_iteration
      ndims = self.aug_features.shape[1]
    elif self.config.obj_type == GELObjective.EXPONENTIAL_TILTING:
      iter_func = self.exp_tilted_iteration
      ndims = self.aug_features.shape[1]
    elif self.config.obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:
      iter_func = self.euc_likelihood_iteration
      ndims = self.model_features.shape[1]
    else:
      raise ValueError(f'Objective type {self.config.obj_type} not valid')
    model_probs = np.full((num_model_examples,), np.nan)
    test_probs = np.full((num_test_examples,), np.nan)
    output_stats = dict(
        model_probs=model_probs, test_probs=test_probs,
        model_loss=np.inf, test_loss=np.inf,
        n_out_of_domain=0, log_grad_norm=np.nan)

    params = np.zeros((ndims,))

    gel_status = GELStatus.RUNNING
    # Check if the convex hull condition is satisfied.
    # This condition does not apply to the Euc. Likelihood
    if not self.config.obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:
      if not approx_in_cvx_hull(
          self.aug_features, np.zeros((self.aug_features.shape[1],)))[0]:
        gel_status = GELStatus.NOT_IN_CONVEX_HULL
    return iter_func, output_stats, gel_status, params

  def euc_likelihood_iteration(self):
    """Perform one (and only) newton iteration for two-sample
       euclidean likelihood.

    Closed-form solution with model weights
    p_i = 1/N - (x_i - mean(x)) . params and test weights
    q_j = 1/M + (y_j - mean(y)) . params. The choice
    params = S^{-1} (mean(x) - mean(y)) / (N + M), with S the pooled biased
    covariance, makes sum_i p_i x_i equal to sum_j q_j y_j.

    Raises:
      ValueError: if the numbers of model and test examples differ.
    """
    model_egs = self.model_features.shape[0]
    test_egs = self.test_features.shape[0]
    if model_egs != test_egs:
      raise ValueError(
          "Different number of model and test examples currently not supported")
    model_mean = np.mean(self.model_features, axis=0)
    test_mean = np.mean(self.test_features, axis=0)
    # atleast_2d: np.cov returns a 0-d array for one-dimensional features.
    model_cov = np.atleast_2d(
        np.cov(self.model_features, rowvar=False, bias=True))
    test_cov = np.atleast_2d(
        np.cov(self.test_features, rowvar=False, bias=True))
    sample_cov = (model_egs * model_cov + test_egs * test_cov) / (
        model_egs + test_egs)
    # sum_i p_i x_i - sum_j q_j y_j
    #   = (mean(x) - mean(y)) - (N * S_x + M * S_y) params
    #   = (mean(x) - mean(y)) - (N + M) * S params,
    # so the 1 / (N + M) factor is required for the means to match.
    self._params = np.linalg.solve(
        sample_cov, model_mean - test_mean) / (model_egs + test_egs)
    model_probs = 1. / model_egs - np.dot(
        self.model_features - model_mean, self._params)
    test_probs = 1. / test_egs + np.dot(
        self.test_features - test_mean, self._params)

    model_obj = 0.5 * np.mean((model_probs - 1.0 / model_egs) ** 2)
    test_obj = 0.5 * np.mean((test_probs - 1.0 / test_egs) ** 2)

    self._output_stats = dict(
        model_probs=model_probs,
        test_probs=test_probs,
        model_loss=model_obj,
        test_loss=test_obj,
        n_out_of_domain=0,
        log_grad_norm=0.0,  # optimized with one step
        )

  def emp_lik_iteration(self):
    """Perform newton step for two-sample empirical likelihood.

    Standard two-sample GEL methods are in the form:
    sum_i log(p_i) + sum_j log(q_j)
    s.t.
    sum_i p_i = 1
    sum_j q_j = 1
    sum_i p_i X_i = sum_j q_j Y_j

    This can be recast in the form:
    sum_k log(w_k)
    s.t.
    sum_k w_k = 1
    sum_k w_k Z_k = 0

    where w_k = 0.5 * p_k for k=1,...,N
    and w_{k+N} = 0.5 * q_k for k=1,...,M

    Z_k = [X_k]  for k=1,...,N
          [1  ]
    and
    Z_{k+N} = [-Y_k]  for k=1,...,M
              [-1  ]
    """
    self._params, output_stats = one_sample_emp_lik_iteration(
        self.aug_features, self._params)
    # Use the instance attribute; the bare name `model_features` was
    # undefined here (or silently picked up an unrelated global).
    model_egs = self.model_features.shape[0]
    model_probs = 2.0 * output_stats['probs'][:model_egs]
    test_probs = 2.0 * output_stats['probs'][model_egs:]
    model_log_lik = np.mean(np.log(model_probs))
    test_log_lik = np.mean(np.log(test_probs))

    self._output_stats = dict(
        model_probs=model_probs,
        test_probs=test_probs,
        model_loss=-model_log_lik,
        test_loss=-test_log_lik,
        n_out_of_domain=output_stats['n_out_of_domain'],
        log_grad_norm=output_stats['log_grad_norm'],
        )

  def exp_tilted_iteration(self):
    """Performs half Newton step on two-sample exp. tilting objective.

    See Appendix C.2 for details of the implementation.

    Raises:
      ValueError: if the numbers of model and test examples differ.
    """
    model_egs = self.model_features.shape[0]
    test_egs = self.test_features.shape[0]
    if model_egs != test_egs:
      raise ValueError(
          "Different number of model and test examples currently not supported")
    num_examples = model_egs + test_egs

    w_exp_tilt = np.exp(np.dot(self.aug_features, self._params)) / num_examples
    sc_f_diff = self.aug_features * w_exp_tilt[:, np.newaxis]

    hess = np.dot(sc_f_diff.T, self.aug_features)
    log_grad = np.sum(sc_f_diff, axis=0)
    log_grad_norm = np.linalg.norm(log_grad)
    newton_step = np.linalg.solve(hess, log_grad)
    self._params -= 0.5 * newton_step
    exp_weights = np.concatenate(
        [np.exp(np.dot(self.aug_model_features, self._params)) * test_egs,
         np.exp(np.dot(-self.aug_test_features, self._params)) * model_egs])
    probs = exp_weights / np.sum(exp_weights)
    model_probs = 2 * probs[:model_egs]
    test_probs = 2 * probs[model_egs:]
    self._output_stats = dict(
        model_probs=model_probs,
        test_probs=test_probs,
        model_loss=-entropy(model_probs, base=2),
        test_loss=-entropy(test_probs, base=2),
        n_out_of_domain=0,
        log_grad_norm=log_grad_norm,
        )

  def _preprocess_features(self, model_features, test_features):
    """Loads and converts features for calculate_two_sample_gel.

    Optionally projects the stacked features with `do_pca`. The test
    features are negated before stacking and negated back afterwards.
    """
    model_egs = model_features.shape[0]
    test_egs = test_features.shape[0]

    two_sample_features = np.concatenate(
        [model_features, -test_features], axis=0)
    if self.config.whiten:
      assert self.config.obj_type != GELObjective.EUCLIDEAN_LIKELIHOOD
      two_sample_features = do_pca(two_sample_features, self.config.pca_dim)

    out_model_feats = two_sample_features[:model_egs]
    out_test_feats = -two_sample_features[-test_egs:]
    aug_model_feats, aug_test_feats, aug_feats = self._make_augmented_features(
        out_model_feats, out_test_feats)

    return (out_model_feats, out_test_feats, aug_model_feats, aug_test_feats,
            aug_feats)

  def _make_augmented_features(self, model_features, test_features):
    """Make two-sample features for use with empirical likelihood iteration.

    Appends a constant 1 column to both feature sets and stacks the model
    rows on top of the negated test rows.
    """
    model_egs = model_features.shape[0]
    test_egs = test_features.shape[0]

    aug_model_feats = np.concatenate(
        [model_features, np.ones((model_egs, 1))], axis=1)
    aug_test_feats = np.concatenate(
        [test_features, np.ones((test_egs, 1))], axis=1)

    aug_features = np.concatenate([aug_model_feats, -aug_test_feats], axis=0)

    return aug_model_feats, aug_test_feats, aug_features

  def _print_stats(self, iter_i):
    """Print statistics during optimization."""
    logging.info('model loss is at iter %d is %.8f',
                 iter_i, self._output_stats['model_loss'])
    logging.info('test loss is at iter %d is %.8f',
                 iter_i, self._output_stats['test_loss'])
    logging.info('minimum model probability is %.8f',
                 np.min(self._output_stats['model_probs']))
    logging.info('maximum model probability is %.8f',
                 np.max(self._output_stats['model_probs']))
    logging.info('sum of model probabilities is %.8f',
                 np.sum(self._output_stats['model_probs']))
    logging.info('minimum test probability is %.8f',
                 np.min(self._output_stats['test_probs']))
    logging.info('maximum test probability is %.8f',
                 np.max(self._output_stats['test_probs']))
    logging.info('sum of test probabilities is %.8f',
                 np.sum(self._output_stats['test_probs']))
    logging.info(
        'number not in domain is %d', self._output_stats['n_out_of_domain'])
    logging.info(
        'log grad norm is %.15f', self._output_stats['log_grad_norm'])

  def _num_out_of_domain(self) -> bool:
    """Logs invalid probabilities; reports whether samples left the domain.

    Returns:
      True if the last step had samples outside the domain of the log.
    """
    model_probs = self._output_stats['model_probs']
    test_probs = self._output_stats['test_probs']
    # Compare element-wise *before* reducing. `np.any(x) < 0.0` compared a
    # boolean with zero, so it could never be True.
    if np.any(model_probs < 0.0):
      logging.info(
          'Encountered negative model probability, likely due to '
          'optimization. This will likely be fixed in an iteration or two.')
    if np.any(model_probs > 1.0):
      logging.info(
          'Encountered model probability > 1.0, likely due to optimization.'
          ' This will likely be fixed in an iteration or two.')
    if np.any(test_probs < 0.0):
      logging.info(
          'Encountered negative test probability, likely due to '
          'optimization. This will likely be fixed in an iteration or two.')
    if np.any(test_probs > 1.0):
      logging.info(
          'Encountered test probability > 1.0, likely due to optimization. '
          'This will likely be fixed in an iteration or two.')

    return self._output_stats['n_out_of_domain'] > 0

  def _check_norm_param_condition(self):
    """Flags BOUNDARY for empirical likelihood when the params norm blows up."""
    norm_params = np.linalg.norm(self._params)
    if (self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD
        and norm_params > self.config.norm_param_tol):
      model_ll = np.sum(np.log(self._output_stats['model_probs']))
      test_ll = np.sum(np.log(self._output_stats['test_probs']))
      logging.info(
          'Mean likely near boundary since norm is high %.5f, '
          'so model likelihood is -Inf, <%.5f, as is test likelihood, <%.5f.',
          norm_params, model_ll, test_ll)
      self._status = GELStatus.BOUNDARY

  def _check_if_solved(self, iteration: int):
    """Marks SOLVED when the loss change or gradient norm is below tolerance."""
    if iteration % 10 == 1: self._print_stats(iteration)
    obj = self._output_stats['model_loss'] + self._output_stats['test_loss']
    if not (np.isfinite(obj) and np.isfinite(self._current_loss)): return
    if (np.abs(obj-self._current_loss) < self.config.tol
        or self._output_stats['log_grad_norm'] < self.config.grad_norm_tol):
      logging.info('GEL calculation converged '
                   'at iteration %d... terminating', iteration)
      self._print_stats(iteration)
      self._status = GELStatus.SOLVED

  def _calculate_objective(self):
    """Returns (objective, best_objective, worst_objective) for the probs."""
    model_probs = self._output_stats['model_probs']
    test_probs = self._output_stats['test_probs']
    num_model_examples = model_probs.shape[0]
    num_test_examples = test_probs.shape[0]
    if self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD:
      model_objective = np.mean(np.log(model_probs))
      test_objective = np.mean(np.log(test_probs))
      best_objective = -np.log(num_model_examples) - np.log(num_test_examples)
      worst_objective = -np.inf
    elif self.config.obj_type == GELObjective.EXPONENTIAL_TILTING:
      model_objective = entropy(model_probs, base=2)
      test_objective = entropy(test_probs, base=2)
      best_objective = np.log2(num_test_examples) + np.log2(num_model_examples)
      worst_objective = 0.0
    elif self.config.obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:
      model_objective = self._output_stats['model_loss']
      test_objective = self._output_stats['test_loss']
      best_objective = 0.0
      worst_objective = np.inf
    else:
      raise ValueError(f'Objective type {self.config.obj_type} not valid')

    objective = model_objective + test_objective
    return objective, best_objective, worst_objective

  def calculate_divergence(self):
    """Returns (total, model, test) divergences from uniform weights."""
    model_probs = self._output_stats['model_probs']
    test_probs = self._output_stats['test_probs']
    num_model_examples = model_probs.shape[0]
    num_test_examples = test_probs.shape[0]
    if self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD:
      model_divergence = -np.mean(np.log(num_model_examples * model_probs))
      test_divergence = -np.mean(np.log(num_test_examples * test_probs))
    elif self.config.obj_type == GELObjective.EXPONENTIAL_TILTING:
      model_divergence = np.sum(
          model_probs * np.log(model_probs * num_model_examples))
      test_divergence = np.sum(
          test_probs * np.log(test_probs * num_test_examples))
    elif self.config.obj_type == GELObjective.EUCLIDEAN_LIKELIHOOD:
      model_divergence = 0.5 * np.mean(
          (1. / num_model_examples - model_probs) ** 2)
      test_divergence = 0.5 * np.mean(
          (1. / num_test_examples - test_probs) ** 2)
    else:
      raise ValueError(f'Objective type {self.config.obj_type} not valid')

    divergence = model_divergence + test_divergence
    return divergence, model_divergence, test_divergence

  def _evaluate_gel_solution(self, elapsed_time: float = 0.0):
    """Finalizes the status and packs the results into a dictionary.

    If the run did not end in SOLVED, the status is refined to
    NOT_IN_CONVEX_HULL or OPTIMIZATION_FAILURE. The objective is then
    replaced by its worst value and the probabilities are set to NaN.

    Args:
      elapsed_time: seconds spent in the optimization loop.
    Returns:
      Dict with the config fields plus 'termination_status', 'model_probs',
      'test_probs', 'objective', 'elapsed_time', 'model_feat_dims' and
      'test_feat_dims'.
    """
    objective, _, worst_objective = self._calculate_objective()
    if self._status == GELStatus.SOLVED:
      logging.info('solved... in %f seconds', elapsed_time)
    elif self._status == GELStatus.BOUNDARY:
      assert self.config.obj_type == GELObjective.EMPIRICAL_LIKELIHOOD
      logging.info('final log lik is -Inf_boundary')
    elif self._status == GELStatus.NOT_IN_CONVEX_HULL:
      logging.info('Convex Hull condition broken')
      err = ('model mean not in convex hull of features... distributions not '
             'close enough to GEL')
      logging.info(err)
    else:
      logging.info('failed to solve... checking to see if model mean '
                   'is in convex hull')
      logging.info('calculating convex hull')
      if not is_in_cvx_hull(self.aug_features,
                            np.zeros((self.aug_features.shape[1],))):
        err = ('model mean not in convex hull of features... distributions not '
               'close enough to GEL')
        logging.info(err)
        self._status = GELStatus.NOT_IN_CONVEX_HULL
      else:
        logging.info('optimization failed')
        self._status = GELStatus.OPTIMIZATION_FAILURE

    if self._status != GELStatus.SOLVED:
      # Report the worst objective for unsolved problems. It was previously
      # overwritten by re-evaluating the objective on NaN probabilities.
      objective = worst_objective
      self._output_stats['model_probs'][:] = np.nan
      self._output_stats['test_probs'][:] = np.nan

    out_dict = dict()
    out_dict.update(self.config.to_dict())
    out_dict['termination_status'] = self._status
    out_dict['model_probs'] = self._output_stats['model_probs']
    out_dict['test_probs'] = self._output_stats['test_probs']
    out_dict['objective'] = objective
    out_dict['elapsed_time'] = elapsed_time
    out_dict['model_feat_dims'] = self.model_features.shape
    out_dict['test_feat_dims'] = self.test_features.shape

    return out_dict

  def calculate_gel(self):
    """Calculates empirical likelihood and output results to a dictionary."""
    if self._status != GELStatus.RUNNING:
      return self._evaluate_gel_solution()

    start_time = timeit.default_timer()
    for i in range(self.config.num_iterations):
      try:
        self._iter_func()
      except np.linalg.LinAlgError:  # This error indicates EL hit the boundary
        logging.info('Encountered LinAlg Error, means we hit boundary cond')
        self._status = GELStatus.BOUNDARY
        break
      self._check_norm_param_condition()
      if self._status == GELStatus.BOUNDARY:
        break
      if self._num_out_of_domain(): continue
      self._check_if_solved(i)
      if self._status == GELStatus.SOLVED:
        break
      self._current_loss = (self._output_stats['model_loss'] +
                            self._output_stats['test_loss'])

    elapsed_time = timeit.default_timer() - start_time
    return self._evaluate_gel_solution(elapsed_time)

## Kernel Code

In [None]:
def get_kernel_config():
  """Build the default kernel settings consumed by `kernel_matrix`.

  `kernel_type` picks the kernel; each `<kernel>_params` entry holds the
  hyperparameters read by the matching branch of `kernel_matrix`.

  Returns:
    config: ConfigDict of kernel settings.
  """
  per_kernel_params = {
      'polynomial_params': {'order': 3, 'const': 1.0},
      'exponential_params': {'sigma': 1.0},
      'laplacian_params': {'sigma': 1.0},
      'rbf_params': {'sigma': 1.0},
      'rational_quadratic_params': {'order': 2.0, 'const': 1.0},
  }

  cfg = config_dict.ConfigDict()
  cfg.kernel_type = 'exponential'
  for field_name, params in per_kernel_params.items():
    setattr(cfg, field_name, config_dict.create(**params))
  return cfg


def kernel_matrix(feat1: np.ndarray, feat2: np.ndarray,
                  config: config_dict.ConfigDict):
  """Generate kernel matrix from two sets of features.

  Args:
    feat1: Feature array of shape (n1, d).
    feat2: Feature array of shape (n2, d), with the same d as `feat1`.
    config: Kernel ConfigDict (see `get_kernel_config`). `config.kernel_type`
      (case-insensitive) is one of 'linear', 'exponential', 'polynomial',
      'laplacian' / 'laplace', 'rbf' / 'gaussian' or 'rational_quadratic';
      the matching `*_params` sub-config supplies its hyperparameters.

  Returns:
    Array of shape (n1, n2) whose (i, j) entry is k(feat1[i], feat2[j]).

  Raises:
    ValueError: If `config.kernel_type` is not a supported kernel.
  """
  kernel_type = config.kernel_type.lower()
  ndim = feat1.shape[1]

  if kernel_type == 'linear':
    kernel_mat = np.dot(feat1, feat2.T) / ndim
  elif kernel_type == 'exponential':
    sigma = config.exponential_params.sigma
    assert sigma > 0.0
    kernel_mat = np.exp(sigma * np.dot(feat1, feat2.T) / ndim)
  elif kernel_type == 'polynomial':
    order = config.polynomial_params.order
    const = config.polynomial_params.const
    assert order > 0.0 and const > 0.0
    kernel_mat = (np.dot(feat1, feat2.T) / ndim + const) ** order
  # Aliases must be matched with `in`: `x == 'a' or 'b'` is always truthy, which
  # previously routed every remaining kernel_type into the Laplacian branch.
  elif kernel_type in ('laplacian', 'laplace'):
    sigma = config.laplacian_params.sigma
    assert sigma > 0.0
    dist_mat = pairwise_distances(feat1, feat2, metric='l1')
    kernel_mat = np.exp(-sigma * dist_mat)
  elif kernel_type in ('rbf', 'gaussian'):
    sigma = config.rbf_params.sigma
    assert sigma > 0.0
    dist_mat = pairwise_distances(feat1, feat2, metric='l2')
    kernel_mat = np.exp(-sigma * (dist_mat ** 2))
  elif kernel_type == 'rational_quadratic':
    const = config.rational_quadratic_params.const
    order = config.rational_quadratic_params.order
    assert order > 0.0 and const > 0.0
    squared_dist = pairwise_distances(feat1, feat2, metric='l2') ** 2
    # k = (1 + const^2 * ||x - y||^2) ** -order, i.e. the standard rational
    # quadratic kernel with const^2 = 1 / (2 * order * lengthscale^2). The
    # leading 1 keeps k finite (== 1) at zero distance instead of dividing by
    # zero.
    kernel_mat = (1.0 + squared_dist * (const ** 2)) ** -order
  else:
    raise ValueError(f'kernel_type {kernel_type} not supported')

  return kernel_mat

## Config Flags

In [None]:
def get_gel_config():
  """Default settings for a GEL calculation.

  `num_iterations` bounds the optimization loop in `calculate_gel`.

  Returns:
    config: ConfigDict for Flags
  """
  defaults = dict(
      cut_dim=1024,
      pca_dim=1024,
      whiten=True,
      num_model_examples=0,
      obj_type=GELObjective.EXPONENTIAL_TILTING,
      num_iterations=10000,
      norm_param_tol=1E8,
      tol=1E-8,
      grad_norm_tol=1E-8,
  )
  return config_dict.ConfigDict(defaults)

# Unit Tests

## Helper Functions

In [None]:
def calculate_gel_unit_tests(
    shifts, obj_type, is_one_sample_test: bool = True, *,
    model_feats=None, test_feats=None):
  """Fit one GEL test per mean shift applied to the model features.

  Args:
    shifts: Sequence of scalar offsets; each one is added to every entry of
      the model features before fitting.
    obj_type: GELObjective used for every fit. Whitening is disabled for
      GELObjective.EUCLIDEAN_LIKELIHOOD and enabled for all other objectives.
    is_one_sample_test: Fit OneSampleGEL if True, otherwise TwoSampleGEL.
    model_feats: Optional model feature array. Defaults to the notebook-level
      `model_features` global (the previous, implicit behavior).
    test_feats: Optional test feature array. Defaults to the notebook-level
      `test_features` global.

  Returns:
    A tuple `(gel_tests, config_flags)`: the fitted GEL objects, one per
    shift in the order of `shifts`, and the single ConfigDict they all share.
  """
  # Resolve the notebook globals at call time so the defaults track any
  # reassignment of `model_features` / `test_features`.
  if model_feats is None:
    model_feats = model_features
  if test_feats is None:
    test_feats = test_features

  config_flags = get_gel_config()
  config_flags.obj_type = obj_type
  config_flags.whiten = obj_type != GELObjective.EUCLIDEAN_LIKELIHOOD
  test_class = OneSampleGEL if is_one_sample_test else TwoSampleGEL

  gel_tests = []
  for shift in shifts:
    gel_test = test_class(config_flags, model_feats + shift, test_feats)
    gel_test.calculate_gel()
    gel_tests.append(gel_test)

  return gel_tests, config_flags


def print_one_sample_unit_test_stats(
    shifts, gel_one_sample_tests, config: config_dict.ConfigDict):
  """Log, per shift, the solver status and (if solved) the divergence.

  Args:
    shifts: Shifts used to build the tests, in the same order as the tests.
    gel_one_sample_tests: Fitted one-sample GEL objects.
    config: ConfigDict whose `obj_type` is named in the header line.
  """
  logging.info('Divergences for %s objective:', config.obj_type)
  separator = '---------------------------------------------'
  for shift, test in zip(shifts, gel_one_sample_tests):
    logging.info('Shift is %f', shift)
    status = test._status
    logging.info('Solution Status: %s', status)
    if status == GELStatus.SOLVED:
      divergence = test.calculate_divergence()
      logging.info('Divergence is %f', divergence)
    logging.info(separator)


def print_two_sample_unit_test_stats(
    shifts, gel_two_sample_tests, config: config_dict.ConfigDict):
  """Log, per shift, the solver status and (if solved) the divergences.

  Args:
    shifts: Shifts used to build the tests, in the same order as the tests.
    gel_two_sample_tests: Fitted two-sample GEL objects whose
      `calculate_divergence()` returns `(divergence, model_div, test_div)`.
    config: ConfigDict whose `obj_type` is named in the header line.
  """
  logging.info('Divergences for %s objective:', config.obj_type)
  for shift, gel_two_sample_test in zip(shifts, gel_two_sample_tests):
    logging.info('Shift is %f', shift)
    solution_status = gel_two_sample_test._status
    logging.info('Solution Status: %s', solution_status)
    # Only compute divergences for solved problems, as the one-sample printer
    # does: they are never logged otherwise, and unsolved runs carry NaN
    # probabilities.
    if solution_status == GELStatus.SOLVED:
      divergence, model_divergence, test_divergence = (
          gel_two_sample_test.calculate_divergence())
      logging.info('Divergence is %f', divergence)
      logging.info('Model divergence is %f', model_divergence)
      logging.info('Test divergence is %f', test_divergence)
    logging.info('---------------------------------------------')

## Load features and mean shift hyperparameters

In [None]:
# Synthetic unit-test features: independent standard-normal draws for the
# "model" and "test" sets. The fixed seed makes the unit tests reproducible.
rng = np.random.default_rng(seed=0)
unit_test_feat_shape = (50000, 128)
model_features = rng.standard_normal(unit_test_feat_shape)
test_features = rng.standard_normal(unit_test_feat_shape)

# Offsets added to the model features; 0.0 is the matched-distribution case.
shifts = [0.0, 0.1, 0.3]

## One-Sample GEL

Empirical Likelihood

In [None]:
# One-sample GEL with the empirical-likelihood objective.
gel_one_sample_tests, config_flags = calculate_gel_unit_tests(
    shifts, obj_type=GELObjective.EMPIRICAL_LIKELIHOOD)

In [None]:
print_one_sample_unit_test_stats(
    shifts, gel_one_sample_tests, config=config_flags)

Exponential Tilting

In [None]:
# One-sample GEL with the exponential-tilting objective.
gel_one_sample_tests, config_flags = calculate_gel_unit_tests(
    shifts, obj_type=GELObjective.EXPONENTIAL_TILTING)

In [None]:
print_one_sample_unit_test_stats(
    shifts, gel_one_sample_tests, config=config_flags)

Euclidean Likelihood

In [None]:
# One-sample GEL with the Euclidean-likelihood objective.
gel_one_sample_tests, config_flags = calculate_gel_unit_tests(
    shifts, obj_type=GELObjective.EUCLIDEAN_LIKELIHOOD)

In [None]:
print_one_sample_unit_test_stats(
    shifts, gel_one_sample_tests, config=config_flags)

## Two-Sample GEL

Empirical Likelihood

In [None]:
# Two-sample GEL with the empirical-likelihood objective.
gel_two_sample_tests, config_flags = calculate_gel_unit_tests(
    shifts, GELObjective.EMPIRICAL_LIKELIHOOD, is_one_sample_test=False)

In [None]:
print_two_sample_unit_test_stats(
    shifts, gel_two_sample_tests, config=config_flags)

Exponential Tilting

In [None]:
# Two-sample GEL with the exponential-tilting objective.
gel_two_sample_tests, config_flags = calculate_gel_unit_tests(
    shifts, GELObjective.EXPONENTIAL_TILTING, is_one_sample_test=False)

In [None]:
print_two_sample_unit_test_stats(
    shifts, gel_two_sample_tests, config=config_flags)

Euclidean Likelihood

In [None]:
# Two-sample GEL with the Euclidean-likelihood objective.
gel_two_sample_tests, config_flags = calculate_gel_unit_tests(
    shifts, GELObjective.EUCLIDEAN_LIKELIHOOD, is_one_sample_test=False)

In [None]:
print_two_sample_unit_test_stats(
    shifts, gel_two_sample_tests, config=config_flags)

# A couple of motivating examples

## Helper Functions

In [None]:
def make_mode_probs(probs, test_labels, num_classes=10):
  """Total the per-example probabilities of each class label.

  Args:
    probs: Per-example probabilities, aligned with `test_labels`.
    test_labels: Integer class label of each example.
    num_classes: Number of classes; labels 0..num_classes-1 are summed.

  Returns:
    float64 array of length `num_classes` whose entry c is the sum of
    `probs` over the examples labelled c.
  """
  per_class_totals = [
      np.sum(probs[test_labels == label]) for label in range(num_classes)]
  return np.array(per_class_totals, dtype=np.float64)

## Evaluating Mode Dropping with one-sample exponential tilting

Here, we recreate Figure 3(a) of the paper.

The "generative model" here is 40k examples from the CIFAR10 training set, with up to 8 classes missing.

For the mode-dropping experiments, we remove examples from the last n classes. For example, if two modes are dropped, we remove labels 8 and 9 (labels are 0-indexed). We use pool3 features.

### Load features

In [None]:
data_dirn = 'cifar10_mode_drop_data/'
cifar10_dropped_mode_data = dict()
cifar10_mode_drop_gold_probs = dict()
gel_one_sample_tests = dict()

test_data = np.load(os.path.join(data_dirn, 'cifar10_test_pool3.npz'))
test_feats = test_data['features']
test_labels = test_data['labels']

witness_data = np.load(os.path.join(
    data_dirn, 'cifar10_train_valid_10k_pool3.npz'))
witness_feats = witness_data['features'][:1024]

all_train_data = np.load(
    os.path.join(data_dirn, 'cifar10_train_valid_40k_pool3.npz'))
# Every key lookup on an NpzFile re-reads the array from the archive, so
# materialize the training arrays once rather than on every loop iteration.
train_feats = all_train_data['features']
train_labels = all_train_data['labels']

num_classes = 10
for num_present_modes in range(2, num_classes + 1, 2):
  # Keep classes 0..num_present_modes-1, i.e. drop the highest labels.
  cifar10_dropped_mode_data[num_present_modes] = (
      train_feats[train_labels < num_present_modes])
  # Ideal mode distribution: uniform over the classes that are present.
  cifar10_mode_drop_gold_probs[num_present_modes] = np.zeros((num_classes,))
  cifar10_mode_drop_gold_probs[num_present_modes][:num_present_modes] = (
      1. / num_present_modes)

### Calculate kernel features

In [None]:
kernel_config_flags = get_kernel_config()
# Kernel features: kernel values of each example against the witness points.
test_kernel_feats = kernel_matrix(
    test_feats, witness_feats, kernel_config_flags)
kernel_features = {
    present: kernel_matrix(
        cifar10_dropped_mode_data[present], witness_feats,
        kernel_config_flags)
    for present in range(2, 11, 2)
}

### Calculate KGEL

This code block recreates the results of Figure 3(a).

In [None]:
config_flags = get_gel_config()
hellinger_distances = dict()
for num_present_modes in range(2, 11, 2):
  gel_test = OneSampleGEL(
      config_flags, kernel_features[num_present_modes], test_kernel_feats)
  gel_one_sample_tests[num_present_modes] = gel_test
  outputs = gel_test.calculate_gel()
  # NOTE(review): the `_evaluate_gel_solution` visible in this notebook emits
  # 'model_probs'/'test_probs', not 'probs'; confirm OneSampleGEL's output
  # dictionary provides a 'probs' entry.
  per_sample_probs = outputs['probs']
  mode_probs = make_mode_probs(per_sample_probs, test_labels)
  hellinger_distances[10 - num_present_modes] = hellinger_dist(
      mode_probs, cifar10_mode_drop_gold_probs[num_present_modes])

### Calculate Hellinger Distances

The Hellinger distances calculated here differ *slightly* from those reported in the paper, likely due to numerical precision.

In [None]:
# Report from fewest to most missing modes.
for num_present_modes in reversed(range(2, 11, 2)):
  num_missing_modes = 10 - num_present_modes
  distance = hellinger_distances[num_missing_modes]
  print("%d missing modes, Hellinger distance: %.4f"
        % (num_missing_modes, distance))

In this example, we perform an experiment where the model distribution (the first half of each class in the CIFAR10 training+validation sets) only has samples from class labels 0 and 1, while the test distribution (the second half of each class) only has samples from class labels 0 and 2. In the ideal scenario, the model and test probabilities for the shared class 0 should each sum to 1.0, while the probabilities for the disjoint classes (1 for the model, 2 for the test) should sum to 0.0.

The witness features are from the CIFAR10 test set.

### Helper Functions

In [None]:
def make_two_modes(
    model_data: dict, test_data: dict,
    common_mode: int = 0, disjoint_mode1: int = 1, disjoint_mode2: int = 2,
    num_feats_per_class=None):
  """Build two class-balanced feature sets that share exactly one class.

  The first output has examples of `common_mode` (class A) and
  `disjoint_mode1` (class B) taken from `model_data`. The second output has
  examples of `common_mode` (class A) and `disjoint_mode2` (class C) taken
  from `test_data`. Each class contributes its first `num_feats_per_class`
  examples (in input order), with the class A rows first.

  Args:
    model_data: dict with a 'features' array (one example per row) and a
      'labels' array of the same length.
    test_data: dict with the same structure as `model_data`.
    common_mode: Label present in both outputs.
    disjoint_mode1: Label present only in the first output.
    disjoint_mode2: Label present only in the second output.
    num_feats_per_class: Number of examples to keep per class. If None, the
      largest count available for all four (dataset, class) selections.

  Returns:
    (out1_data, out2_data): dicts with 'features' (selected rows concatenated
    along axis 0) and int32 'labels'.

  Raises:
    ValueError: If the two disjoint modes are equal, `num_feats_per_class` is
      negative, or a selected class has fewer examples than requested.
  """
  if disjoint_mode1 == disjoint_mode2:
    raise ValueError(
        f'disjoint_mode1 and disjoint_mode2 must differ; both are '
        f'{disjoint_mode1}.')

  selections = (
      ('model_data', model_data, common_mode),
      ('model_data', model_data, disjoint_mode1),
      ('test_data', test_data, common_mode),
      ('test_data', test_data, disjoint_mode2),
  )
  per_class_feats = [
      data['features'][data['labels'] == mode]
      for _, data, mode in selections]

  if num_feats_per_class is None:
    # A None default used to always fail; use the largest balanced size.
    num_feats_per_class = min(feats.shape[0] for feats in per_class_feats)
  if num_feats_per_class < 0:
    raise ValueError(
        f'num_feats_per_class must be >= 0, got {num_feats_per_class}.')
  for (source, _, mode), feats in zip(selections, per_class_feats):
    if feats.shape[0] < num_feats_per_class:
      raise ValueError(
          f'{source} has {feats.shape[0]} examples of class {mode}, fewer '
          f'than num_feats_per_class={num_feats_per_class}.')

  feats_common, feats_disjoint1, feats_common2, feats_disjoint2 = (
      feats[:num_feats_per_class] for feats in per_class_feats)

  def _pack(first_feats, second_feats, first_label, second_label):
    # Concatenate the two per-class blocks and label them block-wise.
    return dict(
        features=np.concatenate([first_feats, second_feats], axis=0),
        labels=np.repeat(
            np.array([first_label, second_label], dtype=np.int32),
            num_feats_per_class))

  out1_data = _pack(feats_common, feats_disjoint1, common_mode, disjoint_mode1)
  out2_data = _pack(feats_common2, feats_disjoint2, common_mode, disjoint_mode2)
  return out1_data, out2_data

def make_label_balanced_model_and_test_data(
    data_dirn='cifar10_mode_drop_data/',
    num_examples_per_sample_set=2500,
    num_classes=10,
    filename='cifar10_train_valid_50k_pool3.npz'):
  """Split a labelled feature file into class-balanced "model"/"test" sets.

  For each label in `range(num_classes)`, the first
  `num_examples_per_sample_set` examples of that class (in file order) go to
  the model set, and the remaining examples of that class go to the test set.
  With the defaults (CIFAR10 50k pool3 features, 5000 examples per class)
  each set gets 2500 examples per class.

  Args:
    data_dirn: Directory containing `filename`.
    num_examples_per_sample_set: Examples per class placed in the model set.
    num_classes: Labels 0..num_classes-1 are split; other labels are dropped.
    filename: .npz file with 'features' and 'labels' arrays.

  Returns:
    (out_model_data, out_test_data): dicts with 'features' and 'labels',
    grouped by class in increasing label order.
  """
  # Use a context manager so the underlying .npz file handle is closed.
  with np.load(os.path.join(data_dirn, filename)) as all_data:
    all_data_feats = all_data['features']
    all_data_labels = all_data['labels']

  out_feats_model = []
  out_labels_model = []
  out_feats_test = []
  out_labels_test = []
  for label in range(num_classes):
    in_class = all_data_labels == label
    per_class_feats = all_data_feats[in_class]
    per_class_labels = all_data_labels[in_class]
    out_feats_model.append(per_class_feats[:num_examples_per_sample_set])
    out_labels_model.append(per_class_labels[:num_examples_per_sample_set])
    out_feats_test.append(per_class_feats[num_examples_per_sample_set:])
    out_labels_test.append(per_class_labels[num_examples_per_sample_set:])

  out_model_data = dict(
      features=np.concatenate(out_feats_model, axis=0),
      labels=np.concatenate(out_labels_model))
  out_test_data = dict(
      features=np.concatenate(out_feats_test, axis=0),
      labels=np.concatenate(out_labels_test))

  return out_model_data, out_test_data

### Load features

Make two sets of features

1. For the "model", keep features from classes 0 and 1
2. For the "test", keep features from classes 0 and 2

In [None]:
num_feats_per_class = 2500  # using 50k set, 2500 egs/class
model_data, test_data = make_label_balanced_model_and_test_data()
# "Model" keeps classes 0 and 1; "test" keeps classes 0 and 2.
model_two_classes, test_two_classes = make_two_modes(
    model_data, test_data, num_feats_per_class=num_feats_per_class)

### Calculate kernel features

In [None]:
kernel_config_flags = get_kernel_config()
witness_data = np.load(os.path.join(data_dirn, 'cifar10_test_pool3.npz'))
witness_feats = witness_data['features'][:1024]
# Replace raw features by their kernel values against the witness points,
# keeping the class labels alongside.
model_kgel_data = dict(
    labels=model_two_classes['labels'],
    features=kernel_matrix(
        model_two_classes['features'], witness_feats, kernel_config_flags))
test_kgel_data = dict(
    labels=test_two_classes['labels'],
    features=kernel_matrix(
        test_two_classes['features'], witness_feats, kernel_config_flags))

### Calculate KGEL

In [None]:
config_flags = get_gel_config()
model_kernel_feats = model_kgel_data['features']
test_kernel_feats_two_modes = test_kgel_data['features']
gel_two_sample_test = TwoSampleGEL(
    config_flags, model_kernel_feats, test_kernel_feats_two_modes)
out_dict = gel_two_sample_test.calculate_gel()

### Extract Results

In [None]:
# Probability mass on each mode: in both sets the first `num_feats_per_class`
# rows are the common mode and the remaining rows are the disjoint mode.
model_probs, test_probs = (
    np.array([out_dict[key][:num_feats_per_class].sum(),
              out_dict[key][num_feats_per_class:].sum()])
    for key in ('model_probs', 'test_probs'))

print('Ideal probability of the common mode is 1.0')
print('Ideal probability of the disjoint mode is 0.0')

print('Model probability of the common mode is', model_probs[0])
print('Model probability of the disjoint mode is', model_probs[1])

print('Test probability of the common mode is', test_probs[0])
print('Test probability of the disjoint mode is', test_probs[1])