##### Copyright 2019 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License")

In [0]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Basic setup

**About this Colab**

This Colab accompanies the NeurIPS 2019 paper:
<br/>

*Practical and Consistent Estimation of f-Divergences* \
*Paul K. Rubenstein, Olivier Bousquet, Josip Djolonga, Carlos Riquelme, Ilya Tolstikhin*

The paper can be found at https://arxiv.org/abs/1905.11112


This Colab reproduces Figures 1, 2 and 3 from the paper. By default, the Colab just loads precomputed data and makes the plots from these. Recomputing everything from scratch takes ~5 hours.

In [0]:
#@title Imports { display-mode: "form" }
import tensorflow as tf
tf.enable_eager_execution()
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
from matplotlib import pyplot as plt
import numpy as np
import cvxpy as cp
import time
import os

import matplotlib.cm as cm
from scipy import stats
from scipy.special import logsumexp

import h5py

import seaborn as sns

from matplotlib import rc
# Use a sans-serif (Helvetica) font in all figures and turn off LaTeX text
# rendering in matplotlib.
rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
rc('text', usetex=False)

# Storage location that the loader functions below copy the precomputed
# .hdf5 files from (with `gsutil cp`).
ROOT_PATH = "gs://rammc_data/"

# Figure 1

## Experimental design:

*   We estimate $D_f(Q_z, P_z)$ for different $f$, $Q_z$ and $P_z$;
*   We consider standard normal $P_z$ in $R^d$ (zero mean and identity covariance);
*   Base distribution $P_x$ is standard normal in $R^k$;
*   We always take $Q(Z|X)$ to be $d$-variate normal with mean value $AX + b\in R^d$ and covariance $\epsilon^2 I_d$;
*   In this case $Q_z := \int Q(Z|X) dP_x(X)$ is a $d$-variate Gaussian with mean value $b$ and covariance $A A' + \epsilon^2 I_d$;
*  The distribution $Q_z$ is parametrized with one scalar $\lambda\in [-\Lambda, \Lambda]$ as follows:
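*   > $A = \beta I_{d \times k} + \lambda A_0$, where $I_{d \times k}$ is the $d \times k$ matrix with ones on its main diagonal and $\beta = 0.5$,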
*   > $b = \lambda  b_0$.
*   We consider dimensions: $d=1,4,16$
*  We consider the following divergences: KL, $\chi^2$ and squared Hellinger divergences.
*  We run the proposed RAM-MC estimator with ${n=1}$ and ${n=500}$ points sampled from $P_x$ and $128$ Monte-Carlo samples.

Squared Hellinger divergence is:
$$
H^2(P, Q) := \int (\sqrt{p(z)} - \sqrt{q(z)})^2 dz \in [0, 2],
$$
$\chi^2$ divergence is:
$$
\chi^2(P, Q) = \int \left(\frac{p(z)}{q(z)}\right)^2 q(z) dz - 1.
$$
KL divergence is:
$$
KL(Q, P) = \int \log\left(\frac{q(z)}{p(z)}\right) q(z) dz
$$

In [0]:
#@title Closed form divergence computations { display-mode: "form" }

#@markdown Please see Appendix C.1 of the paper for analytical expressions for
#@markdown the divergences considered here.


def get_dims(A):
  """Returns `(d_latent, d_input)`: the static row and column counts of A.

  A must expose a rank-2 static shape through `get_shape().as_list()`;
  any other rank fails on unpacking.
  """
  static_shape = A.get_shape().as_list()
  n_rows, n_cols = static_shape
  return (n_rows, n_cols)

def get_q_cov(A, b, std):
  """Returns the covariance AA^t + (std**2)I of the Gaussian Qz.

  `b` is accepted for signature symmetry with the other helpers but is not
  used here.
  """
  n_latent = get_dims(A)[0]
  low_rank_part = tf.matmul(A, A, transpose_b=True)
  isotropic_part = std**2 * tf.eye(n_latent)
  return low_rank_part + isotropic_part

def compute_kl(A, b, std):
  """Computes the KL divergence KL(Qz, Pz) between Gaussian Qz with mean b
  and covariance AA^t + (std**2)I and unit Gaussian Pz.

  Args:
    A: (d_latent, d_input) matrix determining the covariance of Qz.
    b: Mean of Qz.
    std: Scalar; std**2 is added on the diagonal of the covariance of Qz.
  Returns:
    The value of `q.kl_divergence(p)` converted to numpy.
  """
  d_latent, _ = get_dims(A)
  p = tfd.MultivariateNormalDiag(loc=tf.zeros(shape=(d_latent,)),
                                 scale_diag=tf.ones(d_latent))
  q_cov = get_q_cov(A, b, std)
  q = tfd.MultivariateNormalFullCovariance(loc=b, covariance_matrix=q_cov)
  return q.kl_divergence(p).numpy()

def compute_hsq(A, b, std):
  """Closed-form squared Hellinger distance between unit Gaussian Pz and
  Gaussian Qz with mean b and covariance AA^t + (std**2)I.

  The value is 2 - 2 * BC, where BC is assembled below from the
  log-determinants of the two covariances and of their average, times an
  exponential of a quadratic form in b.
  """
  d_latent, _ = get_dims(A)
  cov_p = tf.eye(d_latent)
  cov_q = get_q_cov(A, b, std)
  cov_avg = 0.5 * cov_p + 0.5 * cov_q

  # Determinant factor, accumulated in log-space.
  log_det_factor = tf.linalg.logdet(cov_p) / 4. + tf.linalg.logdet(cov_q) / 4.
  log_det_factor -= tf.linalg.logdet(cov_avg) / 2.
  det_factor = tf.exp(log_det_factor)

  # Quadratic form b^t cov_avg^{-1} b, kept as a (1, 1) matrix.
  b_column = tf.reshape(b, (d_latent, -1))
  b_row = tf.reshape(b, (-1, d_latent))
  quad_form = tf.matmul(b_row, tf.matmul(tf.linalg.inv(cov_avg), b_column))

  coefficient = det_factor * tf.exp(- 1. / 8 * quad_form)
  return (2. - 2. * coefficient[0, 0]).numpy()

def compute_chi2(A, b, std):
  """Computes the chi square divergence between unit Gaussian Pz
  and Gaussian Qz with mean b and covariance AA^t + (std**2)I .

  All inputs are cast to float64 before the computation.

  Args:
    A: (d_latent, d_input) matrix determining the covariance of Qz.
    b: Mean of Qz.
    std: Scalar; std**2 is added on the diagonal of the covariance of Qz.
  Returns:
    float('Inf') when the matrix S = 2 * Sigma1^{-1} - Sigma2^{-1} has a
    non-positive eigenvalue; otherwise the closed-form value as a numpy
    scalar.
  Raises:
    ValueError: If the covariance of Qz has a non-positive determinant.
  """
  def quadform(v, M):
    """Returns the scalar v^t M v for a vector v and a square matrix M."""
    d, _ = get_dims(M)
    quad_form = tf.matmul(M, tf.reshape(v, (d, -1)))
    quad_form = tf.matmul(tf.reshape(v, (-1, d)), quad_form)
    return quad_form[0, 0]

  A = tf.cast(A, tf.float64)
  b = tf.cast(b, tf.float64)
  std = tf.cast(std, tf.float64)

  d_latent, d_input = get_dims(A)
  # Sigma1, mu1 describe Qz; Sigma2, mu2 describe the unit Gaussian Pz.
  Sigma1 = (tf.matmul(A, A, transpose_b=True)
            + std**2 * tf.eye(d_latent, dtype=tf.float64))
  Sigma1inv = tf.linalg.inv(Sigma1)
  Sigma2 = tf.eye(d_latent, dtype=tf.float64)
  Sigma2inv = Sigma2  # The inverse of the identity is the identity.
  mu1 = b
  mu2 = tf.zeros(shape=(d_latent,), dtype=tf.float64)

  S = 2. * Sigma1inv - Sigma2inv

  # Check that chi2 is well defined.
  if tf.linalg.det(Sigma1) <= 0:
    raise ValueError("Sigma1 cannot have non-positive determinant.")

  # The divergence is reported as infinite unless every eigenvalue of S is
  # strictly positive.
  eig, _ = tf.linalg.eigh(S)
  if tf.reduce_min(eig) <= 0:
    return float('Inf')

  # Multiplicative factor 1 / sqrt(det(2 * Sigma1 - Sigma1^2)).
  scale = 1.0 / tf.sqrt(tf.linalg.det(2. * Sigma1 - tf.matmul(Sigma1, Sigma1)))
  # Diagnostic output only; does not affect the returned value.
  if scale.numpy() < 1:
    print('Scale ', scale.numpy())
  # Exponent: 0.5 * mu2^t Sigma2^{-1} mu2 - mu1^t Sigma1^{-1} mu1
  #           - 0.25 * v^t M v, with v and M defined below.
  res = 0.5 * quadform(mu2, Sigma2inv) - quadform(mu1, Sigma1inv)
  v = 2. * tf.matmul(Sigma1inv, tf.reshape(mu1, (d_latent, -1)))
  v -= tf.matmul(Sigma2inv, tf.reshape(mu2, (d_latent, -1)))
  M = tf.linalg.inv(- S / 2.)
  res -= 1. / 4 * quadform(v, M)
  # Diagnostic output only; does not affect the returned value.
  if res.numpy() < 0:
    print(res.numpy())
  res = scale * tf.exp(res) - 1.
  return res.numpy()


## Estimators

In [0]:
#@title RAM-MC { display-mode: "form" }
#@markdown Our proposed estimator. See Equation (3) of the paper.

def compute_ram_mc(n, m, A, b, std, f, n_batches):
  """RAM-MC estimates of Df(Qz, Pz) for unit Gaussian Pz and Gaussian Qz.

  Qz has mean b and covariance AA^t + (std**2)I. In every repetition, n
  points are drawn from the standard normal base distribution, each point x
  defines a Gaussian with mean Ax + b and standard deviation std, and the
  uniform mixture of these n Gaussians is used in place of Qz. The
  divergence is then estimated from m samples of that mixture.

  Args:
    n: Number of mixture components to approximate Qz.
    m: Number of MC samples to use.
    A: Parameter determining covariance matrix of Qz
    b: Mean of Qz
    std: Parameter determining covariance matrix of Qz
    f: Which f-divergence to compute. "KL", "Chi2" or "Hsq".
    n_batches: Number of repetitions to perform.
  Returns:
    A numpy array of estimates, one per repetition.
  Raises:
    ValueError: If f is not one of the supported divergences.
  """
  d_latent, d_input = get_dims(A)
  prior = tfd.MultivariateNormalDiag(loc=tf.zeros(shape=(d_latent,)),
                                     scale_diag=tf.ones(d_latent))

  # Draw n points from the standard normal base distribution for every
  # repetition: shape [n_batches, n, d_input].
  base = tfd.MultivariateNormalDiag(loc=tf.zeros(d_input),
                                    scale_diag=tf.ones(d_input))
  base_points = tf.reshape(base.sample(n * n_batches),
                           [n_batches, n, d_input])

  # One copy of A per repetition so that the batched matmul lines up.
  A_per_batch = tf.tile(tf.reshape(A, [1, d_latent, d_input]),
                        [n_batches, 1, 1])
  component_means = tf.matmul(base_points, A_per_batch, transpose_b=True) + b
  components = tfd.MultivariateNormalDiag(
      loc=component_means,
      scale_diag=std * tf.ones(d_latent))

  # Uniform mixture over the n components; the leading dimension of
  # `components` indexes the repetitions.
  mixture = tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(probs=[1. / n] * n),
      components_distribution=components)

  if f not in ('KL', 'Chi2', 'Hsq'):
    raise ValueError("f must be one of 'KL', 'Chi2', 'Hsq'.")

  # All three estimators average a function of the density ratio over
  # samples zi ~ Qn.
  z = mixture.sample(m)
  log_q = mixture.log_prob(z)
  log_p = prior.log_prob(z)

  if f == 'KL':
    # 1/m \sum_i log ( dQn(zi) / dP(zi) ).
    return (tf.reduce_mean(log_q - log_p, axis=0)).numpy()

  if f == 'Chi2':
    # 1/m \sum_i dQn(zi) / dP(zi) - 1.
    log_ratio = log_q - log_p
    return (tf.exp(tf.reduce_logsumexp(log_ratio, axis=0)) / m - 1.).numpy()

  # f == 'Hsq': 2 - 2 / m \sum_i exp(0.5 log (dP(zi) / dQn(zi))).
  log_ratio = -log_q + log_p
  hsq = 2. - 2. * tf.exp(tf.reduce_logsumexp(0.5 * log_ratio, axis=0)) / m
  return hsq.numpy()

In [0]:
#@title Plug-in estimator { display-mode: "form" }

#@markdown Perform kernel density estimation, then do Monte-Carlo sampling by
#@markdown plugging the estimated densities into the divergence formulae
def estimate_plugin(n, m, A, b, std, f, n_batches, eps=1e-8):
  """Estimates Df(Qz, Pz) with the plugin estimator. Pz is unit Gaussian and Qz
  is Gaussian with mean b and covariance AA^t + (std**2)I. First perform kernel
  density estimation of two densities, then plug in.

  Args:
    n: Number of i.i.d. points drawn from each of Pz and Qz to fit the KDEs.
    m: Number of MC samples (drawn from Qz) at which the KDEs are evaluated.
    A: Parameter determining covariance matrix of Qz.
    b: Mean of Qz.
    std: Parameter determining covariance matrix of Qz.
    f: Which f-divergence to compute. "KL", "Chi2" or "Hsq".
    n_batches: Number of repetitions to perform.
    eps: Small constant added to the estimated log-density of Pz.
  Returns:
    A numpy array with one estimate per repetition. Repetitions in which
    fitting or evaluating the KDEs fails are reported as np.nan.
  Raises:
    ValueError: If f is not one of the supported divergences.
  """
  # Reject unsupported divergences before doing any sampling.
  if f not in ('KL', 'Chi2', 'Hsq'):
    raise ValueError("f must be one of 'KL', 'Chi2', 'Hsq'.")

  def numpy_sample(dist, num_points):
    """Draws num_points samples, returned as a (d, num_points) numpy array."""
    # `sample` returns one point per row, whereas gaussian_kde expects one
    # point per column. This has to be a transpose: reshaping the
    # (num_points, d) matrix to (d, -1) would mix coordinates belonging to
    # different samples whenever d > 1.
    return tf.transpose(dist.sample(num_points)).numpy()

  d_latent, _ = get_dims(A)
  p = tfd.MultivariateNormalDiag(loc=tf.zeros(shape=(d_latent,)),
                                 scale_diag=tf.ones(d_latent))
  q_cov = get_q_cov(A, b, std)
  q = tfd.MultivariateNormalFullCovariance(
      loc=b, covariance_matrix=q_cov)

  # Repeat computations n_batches times.
  res = []
  for _ in range(n_batches):

    # I.i.d. points from p and q to estimate their densities.
    p_kde_points = numpy_sample(p, n)
    q_kde_points = numpy_sample(q, n)

    # KDE fitting can fail (for example on degenerate point sets); such a
    # repetition is recorded as NaN rather than aborting the whole run.
    try:
      p_hat = stats.gaussian_kde(p_kde_points)
      q_hat = stats.gaussian_kde(q_kde_points)
    except Exception:
      res.append(np.nan)
      continue

    mc_points = numpy_sample(q, m)
    try:
      log_q_vals = q_hat.logpdf(mc_points)
      # NOTE(review): eps is added to the log-density, not to the density,
      # so it has practically no regularising effect. Kept as in the
      # original implementation.
      log_p_vals = p_hat.logpdf(mc_points) + eps
    except Exception:
      res.append(np.nan)
      continue

    if f == 'KL':
      res.append(np.mean(log_q_vals - log_p_vals))
    elif f == 'Hsq':
      logratio = log_p_vals - log_q_vals
      estimate_val = 2.
      estimate_val -= 2. * np.exp(logsumexp(0.5 * logratio)) / m
      res.append(estimate_val)
    else:  # f == 'Chi2'
      logratio = log_q_vals - log_p_vals
      estimate_val = np.exp(logsumexp(logratio)) / m - 1.
      res.append(estimate_val)
  return np.array(res)


In [0]:
#@title Estimator of Nguyen et al. { display-mode: "form" }
#@markdown The M1 estimator proposed by Nguyen et al., *Estimating divergence
#@markdown functionals and the likelihood ratio by convex risk minimization.*
#@markdown For full reference see [28] in the paper.

def nguyen_estimate_rkhs(n, A, b, std, lmbd, rkhs_sigma2=None, n_exps=1):
  """ Compute estimator of Nguyen et al. of KL(Qz, Pz) for the RKHS family.

  Pz is unit Gaussian and Qz is Gaussian with mean b and covariance
  AA^t + (std**2)I. For each experiment, n points are drawn from each
  distribution and a convex program over n non-negative variables is solved
  with cvxpy.

  Args:
    n: Number of points sampled from each of Pz and Qz per experiment.
    A: Parameter determining covariance matrix of Qz.
    b: Mean of Qz.
    std: Parameter determining covariance matrix of Qz.
    lmbd: positive regularizer
    rkhs_sigma2: width of the Gaussian kernel. If None, the median of the
      squared distances between the two samples is used.
    n_exps: Number of experiments to perform.
  Returns:
    A numpy array of estimates, one per experiment.
  """
  def kernel_matrices(X, Y, sigma2=None, eps=1e-4):
    """Returns Gaussian kernel matrices (Kx, Ky, Kxy) for samples X and Y.

    eps times the identity is added to Kx and Ky (not to Kxy).
    """
    # X.
    norms_x_sq = tf.reduce_sum(tf.square(X), axis=1, keepdims=True)
    dotprods_x = tf.matmul(X, X, transpose_b=True)
    dists_x_sq = norms_x_sq + tf.transpose(norms_x_sq) - 2. * dotprods_x
    # Y.
    norms_y_sq = tf.reduce_sum(tf.square(Y), axis=1, keepdims=True)
    dotprods_y = tf.matmul(Y, Y, transpose_b=True)
    dists_y_sq = norms_y_sq + tf.transpose(norms_y_sq) - 2. * dotprods_y
    # XY.
    dotprods_xy = tf.matmul(X, Y, transpose_b=True)
    dists_xy_sq = norms_x_sq + tf.transpose(norms_y_sq) - 2. * dotprods_xy

    # Default bandwidth: median squared distance between X and Y points.
    if sigma2 is None:
      sigma2 = np.median(dists_xy_sq)

    Kx = tf.exp(- dists_x_sq / 2. / sigma2) + eps * tf.eye(n)
    Ky = tf.exp(- dists_y_sq / 2. / sigma2) + eps * tf.eye(n)
    Kxy = tf.exp(- dists_xy_sq / 2. / sigma2)
    return (Kx, Ky, Kxy)

  # NOTE(review): this helper is not called anywhere in this function.
  def is_pos_def(x):
    """Returns whether all eigenvalues of x are positive; prints them if not."""
    eig = np.linalg.eigvals(x)
    res = np.all(eig > 0)
    if not res:
      print(np.sort(eig))
    return res

  d_latent, _ = get_dims(A)
  prior = tfd.MultivariateNormalDiag(loc=tf.zeros(d_latent),
                                     scale_diag=tf.ones(d_latent))
  q_cov = get_q_cov(A, b, std)
  q = tfd.MultivariateNormalFullCovariance(
      loc=b, covariance_matrix=q_cov)

  Y = q.sample(n * n_exps)  # Samples from Qz, n per experiment.
  Y = tf.reshape(Y, [n_exps, n, d_latent])
  X = prior.sample(n * n_exps)  # Samples from Pz, n per experiment.
  X = tf.reshape(X, [n_exps, n, d_latent])

  # Perform n_exps experiments.
  estimates = []
  for i in range(n_exps):
    Kx, Ky, Kxy = kernel_matrices(X[i], Y[i], sigma2=rkhs_sigma2)
    Kx = Kx.numpy()
    Ky = Ky.numpy()
    Kxy = Kxy.numpy()

    # Get objective of the dual convex program and solve.
    alpha = cp.Variable(n)
    obj = cp.Minimize( -1. - 1. / n * cp.sum(cp.log(n * alpha)) +
                        1. / lmbd / 2. * cp.quad_form(alpha, Ky) +
                        1. / 2. / lmbd / n / n * np.sum(Kx) -
                        1. / lmbd / n * cp.sum(cp.matmul(Kxy, alpha)))
    prob = cp.Problem(obj, [alpha >= 0])  # Constraints.
    prob.solve()
    # The estimate is read off the optimal dual variables.
    estimates.append(- 1. / n *  np.sum(np.log(n * alpha.value)))
  return np.array(estimates)

## Run experiments and make plots

In [0]:
#@title Set experiment parameters { display-mode: "form" }
#@markdown And generate base matrix and vector A_0 and b_0.

N_RANGE = [1, 500]  # Sample sizes.
MC_NUM = 128  # Number of Monte-Carlo samples for RAM-MC.
N_EXP = 10  # Number of times to repeat each experiment.
K = 20  # Base space dimensionality.
STD = 0.5  # Standard deviation passed as `std` to the estimators.
BETA = 0.5  # Scale of the fixed diagonal part of the matrix A.
D_RANGE = [1, 4, 16]  # Latent space dimensionality.
LBD_MAX = 2.  # lambda ranges over [-LBD_MAX, LBD_MAX].

# Fix the TensorFlow seed before drawing the base parameters below.
tf.random.set_random_seed(345)

# Generating A and b parameters for various dimensions.
# For every d, b0 is a random vector and A0 a random (d, K) matrix, each
# divided by its own norm.
BASE_PARAMS = {}
for d in D_RANGE:
  b0 = tf.random.normal(shape=(d,))
  b0 /= np.linalg.norm(b0)
  A0 = tf.random.normal(shape=(d, K))
  A0 /= tf.linalg.norm(A0)
  BASE_PARAMS[d] = {'b0': b0, 'A0': A0}

In [0]:
#@title Run experiments or load precomputed results { display-mode: "form" }

#@markdown Leaving the following boxes unchecked will load precomputed
#@markdown results. <br/>

#@markdown Running the RAM-MC and plugin experiments takes ~5 minutes.
# When False, `ram_mc_plugin_results` is loaded from
# 'ram_mc_plugin_results.hdf5' instead of being recomputed.
RUN_RAM_MC_PLUGIN_EXPERIMENTS = False #@param { type: "boolean"}

#@markdown Running the Nguyen et al. M1 experiments takes ~1.5 hours.
# When False, `nguyen_results` is loaded from 'nguyen_results.hdf5' instead
# of being recomputed.
RUN_NGUYEN_EXPERIMENTS = False #@param { type: "boolean"}

def load_figure1_data(file_name):
  data = {}
  path = os.path.join(ROOT_PATH, file_name)
  !gsutil cp $path $file_name
  with h5py.File(file_name, 'r') as f:
    for i in f:
      data[int(i)] = {}
      for j in f[i]:
        data[int(i)][int(j)] = {}
        for k in f[i][j]:
          data[int(i)][int(j)][k] = list(f[i][j][k])
  return data

# RAM-MC and plug-in experiments. The result is indexed as
# ram_mc_plugin_results[d][n][dvg] and holds one tuple per lambda value:
# (true_kl, true_hsq, true_chi2, batch_ram_mc, batch_plugin).
if RUN_RAM_MC_PLUGIN_EXPERIMENTS:
  ram_mc_plugin_results = {}
  for d in D_RANGE:
    if d not in ram_mc_plugin_results:
      ram_mc_plugin_results[d] = {}
    for n in N_RANGE:
      print(d, n)
      if n not in ram_mc_plugin_results[d]:
        ram_mc_plugin_results[d][n] = {}
      for lbd in np.linspace(-LBD_MAX, LBD_MAX, 51):
        # Create Abase with ones on diagonal
        Abase = np.zeros((d, K))
        np.fill_diagonal(Abase, 1.)
        Abase = tf.convert_to_tensor(Abase, tf.dtypes.float32)
        # Parameters of Qz for this lambda.
        Albd = Abase * BETA + lbd * BASE_PARAMS[d]['A0']
        blbd = lbd * BASE_PARAMS[d]['b0']

        # Compute true closed form values (only once)
        # They are stored only under the first sample size; for the other
        # sample sizes the three slots are None.
        if n == N_RANGE[0]:
          true_kl = compute_kl(Albd, blbd, STD)
          true_hsq = compute_hsq(Albd, blbd, STD)
          true_chi2 = compute_chi2(Albd, blbd, STD)
        else:
          true_kl = None
          true_hsq = None
          true_chi2 = None

        for dvg in ['KL', 'Chi2', 'Hsq']:
          if dvg not in ram_mc_plugin_results[d][n]:
            ram_mc_plugin_results[d][n][dvg] = []

          # N_EXP repetitions of each estimator for this (d, n, lambda, dvg).
          batch_ram_mc = compute_ram_mc(n, MC_NUM, Albd, blbd, STD,
                                  f=dvg, n_batches=N_EXP)

          batch_plugin = estimate_plugin(n, MC_NUM, Albd, blbd, STD,
                                        f=dvg, n_batches=N_EXP)

          ram_mc_plugin_results[d][n][dvg].append(
              (true_kl, true_hsq, true_chi2, batch_ram_mc, batch_plugin))
else:
  ram_mc_plugin_results = load_figure1_data('ram_mc_plugin_results.hdf5')

# M1 (Nguyen et al.) experiments: KL only, run for the largest sample size
# only. The result is indexed as nguyen_results[d][n]['KL'] and holds one
# batch of N_EXP estimates per lambda value.
if RUN_NGUYEN_EXPERIMENTS:
  nguyen_results = {}
  for d in D_RANGE:
    if d not in nguyen_results:
      nguyen_results[d] = {}
    n = N_RANGE[-1]
    print((d, n))
    if n not in nguyen_results[d]:
      nguyen_results[d][n] = {}
    for lbd in np.linspace(-LBD_MAX, LBD_MAX, 51):
      print(lbd)
      # Create Abase with ones on diagonal
      Abase = np.zeros((d, K))
      np.fill_diagonal(Abase, 1.)
      Abase = tf.convert_to_tensor(Abase, tf.dtypes.float32)
      Albd = Abase * BETA + lbd * BASE_PARAMS[d]['A0']
      blbd = lbd * BASE_PARAMS[d]['b0']
      for dvg in ['KL']:
        if dvg not in nguyen_results[d][n]:
          nguyen_results[d][n][dvg] = []
        # The regularizer lmbd is set to 1 / n.
        batch_nguyen = nguyen_estimate_rkhs(n, Albd, blbd, STD,
                                            1. / n, n_exps=N_EXP)
        nguyen_results[d][n][dvg].append(batch_nguyen)
else:
  nguyen_results = load_figure1_data('nguyen_results.hdf5')

In [0]:
#@title Generate plots { display-mode: "form" }

def make_plot_figure1(ram_mc_plugin_results, nguyen_results):
  """Draws Figure 1: a 3x3 grid comparing the estimators with the truth.

  Rows correspond to the divergences KL, Chi2 and Hsq; columns to the
  latent dimensions in D_RANGE. The x-axis of every panel indexes the lambda
  values.

  Args:
    ram_mc_plugin_results: Nested mapping indexed as [d][n][dvg]. Each entry
      is a sequence with one element per lambda value; element positions
      0, 1, 2 hold the true KL, squared Hellinger and chi square values,
      position 3 the RAM-MC estimates and position 4 the plug-in estimates.
    nguyen_results: Nested mapping indexed as [d][n]['KL'] holding, per
      lambda value, the estimates of the M1 estimator.
  """
  sns.set_style("white")
  fig = plt.figure(figsize = (13, 8))
  elinewidth = 0.4  # Width of errorbars
  errorevery = 3  # Set spacing of error bars to avoid crowding of figure.

  def overflow_std(array):
    """Calculates std of array, but if the array contains non-finite values
    (inf or NaN) or values above 1e20, returns a finite number larger than
    the range of any axes used in plots."""
    values = np.asarray(array, dtype=float)
    # `np.nan in array` never matches NaN entries of a numpy array (NaN
    # compares unequal to itself) and `1e20 < array` fails on plain lists,
    # so test finiteness explicitly on an ndarray.
    if not np.all(np.isfinite(values)) or np.any(values > 1e20):
      return 1e20
    return np.std(values)

  def lambda_positions(series):
    """Returns x positions 0, 1, ... with one position per lambda value."""
    return np.arange(len(series))

  for i in range(1, 10):
    sp = plt.subplot(3, 3, i)
    d = D_RANGE[(i - 1) % 3]
    dvg = ['KL', 'Chi2', 'Hsq'][(i - 1) // 3]
    for n in N_RANGE:

      if n == N_RANGE[0]:
        # Plot true values. They are stored alongside the results for the
        # first sample size, at a position that depends on the divergence.
        true_index = {'KL': 0, 'Hsq': 1, 'Chi2': 2}[dvg]
        true_values = np.array(
            [el[true_index] for el in ram_mc_plugin_results[d][n][dvg]])
        plt.plot(true_values, color='blue', linewidth=3, label='Truth')
        if dvg in ('KL', 'Chi2'):
          plt.yscale('log')

      # Plot RAM-MC estimates for N=500.
      if n == 500:
        mean_ram_mc_n500 = np.array(
            [np.mean(el[3]) for el in ram_mc_plugin_results[d][n][dvg]])
        std_ram_mc_n500 = np.array(
            [np.std(el[3]) for el in ram_mc_plugin_results[d][n][dvg]])
        plt.errorbar(lambda_positions(mean_ram_mc_n500),
                     mean_ram_mc_n500,
                     errorevery=errorevery,
                     yerr=std_ram_mc_n500,
                     elinewidth=elinewidth,
                     color='red', label='RAM-MC estimator, N=' + str(n),
                     marker="^", markersize=5, markevery=10)

      # Plot Nguyen estimates
      if n == 500 and dvg == 'KL':
        mean_nguyen = np.array(
            [np.mean(el) for el in nguyen_results[d][n][dvg]])
        std_nguyen = np.array(
            [np.std(el) for el in nguyen_results[d][n][dvg]])
        plt.errorbar(lambda_positions(mean_nguyen),
                     mean_nguyen,
                     errorevery=errorevery,
                     yerr=std_nguyen,
                     elinewidth=elinewidth,
                     color='green', label='M1 estimator, N=' + str(n),
                     marker="v", markersize=5, markevery=10)

      # Plot plug-in estimates
      if n == 500:
        mean_plugin = np.array(
            [np.mean(el[4]) for el in ram_mc_plugin_results[d][n][dvg]])
        std_plugin = np.array(
            [overflow_std(el[4]) for el in ram_mc_plugin_results[d][n][dvg]])
        plt.errorbar(lambda_positions(mean_plugin),
                     mean_plugin,
                     errorevery=errorevery,
                     yerr=std_plugin,
                     elinewidth=elinewidth,
                     color='darkorange',
                     label='Plug-in estimator, N=' + str(n),
                     marker="s", markersize=5, markevery=10)

      # Plot RAM-MC with N=1.
      if n == N_RANGE[0]:
        mean_ram_mc1 = np.array(
            [np.mean(el[3]) for el in ram_mc_plugin_results[d][n][dvg]])
        std_ram_mc1 = np.array(
            [np.std(el[3]) for el in ram_mc_plugin_results[d][n][dvg]])
        # Shifted slightly to the right so that its error bars do not sit
        # exactly on top of those of the other curves.
        plt.errorbar(lambda_positions(mean_ram_mc1) + 0.3,
                     mean_ram_mc1,
                     errorevery=errorevery,
                     yerr=std_ram_mc1,
                     elinewidth=elinewidth,
                     color='black', label='RAM-MC estimator, N=1',
                     marker="o", markersize=5, markevery=10)

      if dvg == 'KL':
        plt.ylim((0.03, 15))
      if dvg == 'Chi2':
        plt.ylim((0.1, 1e6))
      if dvg == 'Hsq':
        plt.ylim((0., 2))

    # Only the d == 1 panels keep their y tick labels.
    sp.axes.get_xaxis().set_ticklabels([])
    if d != 1:
      sp.axes.get_yaxis().set_ticklabels([])
    else:
      sp.axes.tick_params(axis='both', labelsize=15)

    if i < 4:
      plt.title("d = {}".format(d), fontsize=18)
    if i == 1:
      plt.ylabel('KL', fontsize=18)
    if i == 4:
      plt.ylabel(r'$\chi^2$', fontsize=18)
    if i == 7:
      plt.ylabel(r'$H^2$', fontsize=18)


    # Hide the right and top spines.
    ax = plt.gca()
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    # Only show ticks on the left and bottom spines.
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    plt.tick_params(
      axis='x',          # changes apply to the x-axis
      which='both',      # both major and minor ticks are affected
      bottom=False,      # ticks along the bottom edge are off
      top=False,         # ticks along the top edge are off
      labelbottom=False) # labels along the bottom edge are off
    plt.tick_params(
      axis='y',          # changes apply to the y-axis
      which='both',      # both major and minor ticks are affected
      left=False,        # ticks along the left edge are off
      right=False,       # ticks along the right edge are off
      labelbottom=False) # NOTE(review): an x-axis option; y labels unaffected
    ax.yaxis.grid()
    plt.xlim((-2, 51))

    if i > 6:
      plt.xlabel(r"$\lambda$", fontsize=17)

  # Build one shared legend from the handles of the second panel, sorted by
  # label.
  ax = fig.axes[1]
  handles, labels = ax.get_legend_handles_labels()
  labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
  fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.51, 1.0),
            ncol=5, fancybox=True, shadow=True, fontsize=12, frameon=True)
  plt.tight_layout()
  plt.show()

make_plot_figure1(ram_mc_plugin_results, nguyen_results)

# Figures 2 and 3

In [0]:
# The code for Figures 2 and 3 builds TensorFlow graphs and evaluates them in
# a tf.Session (see mc_benchmark), so eager execution is switched off here.
tf.disable_eager_execution()

## Experiment design:

We take the encoders of 6 trained Autoencoder models (a mixture of Variational AEs and Wasserstein AEs). These models were trained on the *CelebA* dataset with ~200K images. The encoders are probabilistic, meaning that an image is mapped to a distribution in the latent space.

We consider the use of RAM-MC as a method to estimate f-divergences between the prior distributions of these models and the *aggregate posteriors*, the mixture of encodings of all of the training data.

In [0]:
#@title Setup for mixture distributions { display-mode: "form" }

def create_qk(means, log_vars, weights):
  """Builds the standard normal prior and a mixture of K diagonal Gaussians.

  Args:
    means: (k, d_latent) array of component means.
    log_vars: Component log-variances; clipped to [-30, 30] before being
      turned into standard deviations.
    weights: Mixture probabilities, handed to `tfd.Categorical(probs=...)`.
  Returns:
    The pair (pz, qk): the prior and the mixture distribution.
  """
  _, d_latent = means.shape
  prior = tfd.MultivariateNormalDiag(loc=tf.zeros(shape=(d_latent,)),
                                     scale_diag=tf.ones(d_latent))

  # std = exp(log_var / 2), with the log-variance clipped first.
  clipped_log_vars = tf.clip_by_value(log_vars, -30, 30)
  component_stds = tf.exp(clipped_log_vars / 2.)
  components = tfd.MultivariateNormalDiag(loc=means,
                                          scale_diag=component_stds)
  selector = tfd.Categorical(probs=weights)
  mixture = tfd.MixtureSameFamily(mixture_distribution=selector,
                                  components_distribution=components)
  return prior, mixture

def sample_qk(qk, k, d_latent, n_samples):
  """Draws n_samples points from the k-component mixture qk.

  When n_samples * d_latent * k exceeds 10**9, the samples are drawn in
  several smaller `qk.sample` calls and concatenated along the first axis;
  otherwise a single call is made. A warning is printed when that product
  exceeds the int32 maximum.
  """
  int32_max = 2147483647
  max_elements_per_call = 1000000000
  total_elements = n_samples * d_latent * k

  if total_elements > int32_max:
    print(('Warning: a very large tensor will be internally instantiated, '
           'this may cause an OOM error.'))

  # Common case: everything fits in one call.
  if not total_elements > max_elements_per_call:
    return qk.sample(n_samples)

  # Split into n_batches equal chunks plus one chunk for the remainder.
  n_batches = total_elements // max_elements_per_call + 1
  chunk_size = n_samples // n_batches
  chunks = [qk.sample(chunk_size) for _ in range(n_batches)]
  leftover = n_samples % n_batches
  if leftover > 0:
    chunks.append(qk.sample(leftover))
  mc_samples = tf.concat(chunks, axis=0)
  assert mc_samples.get_shape()[0] == n_samples
  return mc_samples

## Estimators of Df(Qz, Pz)

In [0]:
#@title Monte-Carlo estimation { display-mode: "form" }

def mc_benchmark(means, log_vars, weights, mc_num, f, num_exps):
  """MC-based estimation of the f-divergence.
    Df(Q, P) = int_z f(q(z) / p(z)) p(z) dz
  Estimate Df(QK, Pz) for standard gaussian Pz and a K-mixture of Gaussians QK.
  Parameters of all the component Gaussians are stored in means and log_vars,
  with both having shapes (K, d_latent).

  For estimating, sample mc_num points from QK and average a function of the
  log density ratios log(dQK/dPz) evaluated at those points.

  Note that this resets the default TensorFlow graph.

  Args:
    means: (K, d_latent) array of component means.
    log_vars: (K, d_latent) array of component log-variances.
    weights: Mixture probabilities of the K components.
    mc_num: Number of MC samples drawn from QK per estimate.
    f: Which f-divergence to compute. "KL" or "Hsq".
    num_exps: Number of independent estimates to compute.
  Returns:
    A numpy array of shape (num_exps, 2). Column 0 holds the estimate and
    column 1 a log-space surrogate: log(estimate) for "KL" and
    log(2 - estimate) for "Hsq".
  Raises:
    ValueError: If f is not one of the supported divergences.
  """
  # Fail early with a clear message; without this check an unsupported f
  # would only surface later as a NameError on `mc_estimate`.
  if f not in ('KL', 'Hsq'):
    raise ValueError("f must be one of 'KL', 'Hsq'.")

  tf.reset_default_graph()
  (k, d_latent) = means.shape

  pz, qk = create_qk(means, log_vars, weights)
  mc_samples = sample_qk(qk, k, d_latent, mc_num)
  log_density_ratios = qk.log_prob(mc_samples) - pz.log_prob(mc_samples)

  if f == 'KL':
    mc_estimate = tf.reduce_mean(log_density_ratios)
    surrogate_estimate = tf.log(mc_estimate)
  else:  # f == 'Hsq'
    mc_estimate = 2.
    mc_estimate -= 2. * tf.exp(tf.reduce_logsumexp(
        -0.5 * log_density_ratios, axis=0)) / mc_num
    # Since estimate may be very close (but below) 2 we also compute
    # log(2 - estimate)
    surrogate_estimate = tf.log(2. / mc_num) + tf.reduce_logsumexp(
        -0.5 * log_density_ratios, axis=0)

  # Every sess.run re-executes the sampling ops, so each iteration yields an
  # estimate based on a fresh set of MC samples.
  estimates = []
  with tf.Session() as sess:
    for _ in range(num_exps):
      mc_estimate_val, surrogate_estimate_val = sess.run(
          [mc_estimate, surrogate_estimate])
      estimates.append((mc_estimate_val, surrogate_estimate_val))
  return np.array(estimates)

In [0]:
#@title RAM-MC { display-mode: "form" }

def estimate_ram_mc(means_all, log_vars_all, n, mc_num, f, num_exps=1):
  """Computes RAM-MC estimates of Df(QK|Pz) from a random subset of encodings.

  Each repetition draws n components uniformly with replacement out of all
  k, weights every distinct component by its relative frequency, and hands
  the resulting mixture to `mc_benchmark`.

  Args:
    means_all, log_vars_all: (k, d_latent) shaped tensors containing
      encodings of all the examples.
    n: Number of encoded examples/mixture components to use in RAM-MC.
    mc_num: Number of MC samples used in RAM-MC.
    f: Which f-divergence to compute; forwarded to `mc_benchmark`.
    num_exps: Number of repetitions.
  Returns:
    A numpy array with one row (the first row returned by `mc_benchmark`)
    per repetition.
  """
  k, _ = means_all.shape
  per_experiment = []
  for _ in range(num_exps):
    # How many times each of the k components was picked among n draws.
    draw_counts = np.random.multinomial(n, [1. / k] * k)
    picked = np.nonzero(draw_counts)[0]
    # Relative frequencies of the picked components become mixture weights.
    mixture_weights = np.float32(draw_counts[picked] / (n + 0.))
    batch = mc_benchmark(
        means_all[picked],
        log_vars_all[picked],
        weights=mixture_weights, mc_num=mc_num, f=f, num_exps=1)
    per_experiment.append(batch[0])
  return np.array(per_experiment)

## Process data

In [0]:
#@title Load precomputed embeddings of CelebA data { display-mode: "form" }

#@markdown Gets means and log-variances of encoding distributions of all data
#@markdown for each encoder.


def load_celebA_embeddings():
  means_logvars = {}
  file_name = 'means_logvars.hdf5'
  path = os.path.join(ROOT_PATH, file_name)
  !gsutil cp $path $file_name

  with h5py.File(file_name, 'r') as f:
    for i in f:
      means_logvars[int(i)] = {}
      for j in f[i]:
        means_logvars[int(i)][j] = f[i][j][:]
  return means_logvars

## Divergence estimates

In [0]:
#@title Calculate divergence estimates or load precomputed results.
#@title { display-mode: "form" }

#@markdown Calculates Monte-Carlo baseline and RAM-MC estimates of KL and
#@markdown Squared Hellinger divergences between mixture of encoded data
#@markdown and prior. Takes ~4 hours to run.
CALCULATE_DIVERGENCE_ESTIMATES = False #@param { type: "boolean"}

# Note: in the paper we used 10,000 MC samples for the baseline.
# In order to speed up computation here, we use a smaller number of samples.
MC_BASELINE_N_SAMPLES = 100
# Number of repetitions of the MC baseline and of RAM-MC, respectively.
MC_BASELINE_N_EXPS = 10
RAM_MC_N_EXPS = 50

# Numbers of mixture components used by RAM-MC: 2**14, 2**13, ..., 2**0.
N_RANGE_REAL_EXPS = [2 ** i for i in range(15)[::-1]]
# Numbers of MC samples used by RAM-MC.
MC_RANGE = [1000, 10]


def load_ram_mc_results():
  results_ram_mc = {}
  file_name = 'results_ram_mc.hdf5'
  path = os.path.join(ROOT_PATH, file_name)
  !gsutil cp $path $file_name

  with h5py.File(file_name, 'r') as f:
    for i in f:
      results_ram_mc[int(i)] = {}
      for j in f[i]:
        results_ram_mc[int(i)][int(j)] = {}
        for k in f[i][j]:
          results_ram_mc[int(i)][int(j)][int(k)] = {}
          for m in f[i][j][k]:
            results_ram_mc[int(i)][int(j)][int(k)][m] = f[i][j][k][m][:]
  return results_ram_mc

def load_mc_benchmark_results():
  results_mc_benchmark = {}
  file_name = 'results_mc_benchmark.hdf5'
  path = os.path.join(ROOT_PATH, file_name)
  !gsutil cp $path $file_name

  with h5py.File(file_name, 'r') as f:
    for i in f:
      results_mc_benchmark[int(i)] = {}
      for j in f[i]:
        results_mc_benchmark[int(i)][j] = f[i][j][:]
  return results_mc_benchmark

# Either recompute all estimates or load the precomputed ones. Results are
# indexed as results_mc_benchmark[model][dvg] and
# results_ram_mc[model][n][mc_num][dvg].
if CALCULATE_DIVERGENCE_ESTIMATES:
  t_init = time.time()

  means_logvars = load_celebA_embeddings()

  results_mc_benchmark = {}
  results_ram_mc = {}

  # Models are keyed 1..6 in the embeddings file.
  for i in range(1,7):
    results_mc_benchmark[i] = {}
    results_ram_mc[i] = {}

    means, log_variances = [means_logvars[i][key]
                            for key in ['means', 'log_variances']]
    (k, d_latent) = means.shape

    print('z_dim=%d' % d_latent)

    # Compute the benchmark MC estimator as a reference value
    # The reference mixture weights all k encodings uniformly.
    for dvg in ['KL', 'Hsq']:
      print('Computing MC estimate benchmark for f={}'.format(dvg))
      mc_vals = mc_benchmark(means, log_variances, [1. / k] * k,
                             MC_BASELINE_N_SAMPLES, dvg, MC_BASELINE_N_EXPS)
      results_mc_benchmark[i][dvg] = mc_vals

    # RAM-MC for every combination of component count and MC sample count.
    for n in N_RANGE_REAL_EXPS:
      results_ram_mc[i][n] = {}
      for mc_num in MC_RANGE:
        print('N=%d, MC=%d' % (n, mc_num))
        results_ram_mc[i][n][mc_num] = {}
        for dvg in ['KL', 'Hsq']:
          print('Evaluating for f=%s' % dvg)
          # Compute our estimate
          ram_mc_vals = estimate_ram_mc(means, log_variances, n,
                                        mc_num, dvg, num_exps=RAM_MC_N_EXPS)
          results_ram_mc[i][n][mc_num][dvg] = ram_mc_vals
  print("It took {} seconds to complete.".format(time.time() - t_init))
else:
  results_ram_mc = load_ram_mc_results()
  results_mc_benchmark = load_mc_benchmark_results()


## Make plots

In [0]:
#@title Plotting function { display-mode: "form" }

def make_plots_real_data(dvg):
  """Plots the RAM-MC estimates next to the MC reference for the six models.

  Creates one 13x8 figure holding a 2x3 grid of subplots, one per model key
  (1..6) of the module-level dicts `results_mc_benchmark` and
  `results_ram_mc`. Each subplot shows the MC reference as a horizontal line
  with error bars, plus one error-bar curve per entry of the module-level
  `MC_RANGE`, evaluated at every entry of `N_RANGE_REAL_EXPS`.

  For dvg == 'KL' the plotted value is the mean of column 0 of the stored
  results and the error bar is their sample standard deviation (ddof=1).
  For dvg == 'Hsq' column 1 is treated as log-estimates: the plotted value is
  the log of the mean estimate and the error bars are the log-scale offsets
  to mean +/- one standard deviation.

  NOTE(review): the x tick labels are 0..len(N_RANGE_REAL_EXPS)-1 and the
  axis is labelled log_2 N, while the curves are drawn for
  N_RANGE_REAL_EXPS[::-1]. This is only consistent if the reversed range is
  2**0, 2**1, ... -- confirm against the definition of N_RANGE_REAL_EXPS.

  Args:
    dvg: Which divergence to plot, either 'KL' or 'Hsq'.

  Raises:
    ValueError: If `dvg` is neither 'KL' nor 'Hsq'. This is checked before
      any figure is created, so an invalid call leaves no empty figure behind.
  """
  if dvg not in ('KL', 'Hsq'):
    raise ValueError(
        "Argument dvg must be 'KL' or 'Hsq', not: {}".format(dvg))

  num_steps = len(N_RANGE_REAL_EXPS)
  errorbar_width = 2
  # Amount by which to shift the curves relatively to each other.
  delta_step = 0.35

  # Keys to the results dict.
  MODELS = [1, 2, 3, 4, 5, 6]
  # Hardcode their corresponding latent space dims.
  MODELS_D = [32, 32, 64, 64, 128, 128]
  # Hardcode the order of the models which we want to plot.
  MODEL_IDS = [0, 2, 4, 1, 3, 5]

  current_palette = sns.color_palette()

  def get_log_error_bars(log_estimates, scale_std=1.):
    """Given a set of log-observations, compute log(mean +/- scale_std * std).

    Computing variance is more tricky. We will compute log(std) which is
    0.5 log(var) = 0.5 log( mean(X^2) - (mean(X))^2 ). Notice that
    log mean(X^2) can be computed using logsumexp. log((mean(X))^2) is
    simply 2 log (mean(X)).

    Args:
      log_estimates: 1-D sequence of log-observations log(X_i).
      scale_std: Number of standard deviations spanned by each error bar.

    Returns:
      A tuple (log_mean_value, error_plus, error_minus) where
      error_plus = log(mean + scale_std * std) - log(mean) and
      error_minus = log(mean) - log(mean - scale_std * std). error_minus is
      0 when mean <= scale_std * std, because the lower end has no logarithm.
      Both offsets are NaN when fewer than two observations are given, which
      mirrors what np.std(..., ddof=1) yields in the KL branch.
    """
    # Coerce to an array so that `2 * log_estimates` below is an elementwise
    # doubling even if a plain list is passed in.
    log_estimates = np.asarray(log_estimates, dtype=float)
    m = len(log_estimates)

    log_mean_value = logsumexp(log_estimates) - np.log(m + 0.)
    if m < 2:
      # The unbiased variance needs m - 1 > 0; avoid dividing by zero.
      return log_mean_value, np.nan, np.nan

    log_term1 = logsumexp(2 * log_estimates) - np.log(m + 0.)
    log_term2 = 2 * log_mean_value
    # log(E[X^2] - E[X]^2), evaluated via the signed weights b.
    log_var = logsumexp(a=[log_term1, log_term2], b=[1., -1.])
    # Bessel correction m / (m - 1), applied in log space.
    log_unb_var = log_var + np.log(m / (m - 1.))
    log_std = log_unb_var / 2.

    error_plus = logsumexp([log_std, log_mean_value],
                           b=[scale_std, 1.]) - log_mean_value

    if log_mean_value > log_std + np.log(scale_std):
      error_minus = log_mean_value - logsumexp([log_mean_value, log_std],
                                               b=[1., -scale_std])
    else:
      error_minus = 0.
    return log_mean_value, error_plus, error_minus

  fig = plt.figure(figsize=(13, 8))
  for plot_id in range(1, 7):
    plt.subplot(2, 3, plot_id)

    model_position = MODEL_IDS[plot_id - 1]
    model = MODELS[model_position]

    # MC benchmark: a constant reference line across all x positions.
    if dvg == 'KL':
      mc_vals = results_mc_benchmark[model][dvg][:, 0]
      mean_value = np.mean(mc_vals)
      reference_yerr = [np.std(mc_vals, ddof=1)] * num_steps
    else:
      log_mc_vals = results_mc_benchmark[model][dvg][:, 1]
      mean_value, error_plus, error_minus = get_log_error_bars(log_mc_vals)
      # Shape (2, num_steps): row 0 holds the lower offsets, row 1 the upper.
      reference_yerr = np.tile([[error_minus], [error_plus]], (1, num_steps))

    plt.errorbar(range(num_steps),
                 [mean_value] * num_steps,
                 yerr=reference_yerr,
                 linewidth=3,
                 color='blue',
                 elinewidth=errorbar_width,
                 label='True MC reference')

    # RAM-MC estimates.
    for (j, mc_num) in enumerate(reversed(MC_RANGE)):
      if dvg == 'KL':
        values = [np.mean(results_ram_mc[model][n][mc_num][dvg][:, 0])
                  for n in N_RANGE_REAL_EXPS[::-1]]
        errors = [np.std(results_ram_mc[model][n][mc_num][dvg][:, 0], ddof=1)
                  for n in N_RANGE_REAL_EXPS[::-1]]
      else:
        values = []
        errors = []
        for n in N_RANGE_REAL_EXPS[::-1]:
          log_mc_vals = results_ram_mc[model][n][mc_num][dvg][:, 1]
          log_error_bars_out = get_log_error_bars(log_mc_vals)
          log_mean_value, error_plus, error_minus = log_error_bars_out
          values.append(log_mean_value)
          errors.append(np.reshape([error_minus, error_plus], (2, 1)))
        errors = np.hstack(errors)

      color_to_show = current_palette[2] if mc_num < 20 else current_palette[3]
      marker_to_show = 's' if mc_num < 20 else 'o'
      # Each curve is shifted right by a multiple of delta_step so that its
      # error bars do not overlap the others. markevery=100 places a marker
      # only on every 100th point.
      plt.errorbar(np.array(range(num_steps)) + (j + 1) * delta_step,
                   values, yerr=errors,
                   linewidth=3,
                   elinewidth=errorbar_width,
                   capthick=2,
                   color=color_to_show,
                   marker=marker_to_show, markersize=10, markevery=100,
                   label='RAM-MC estimator, M=%d' % mc_num)

    # Hand-tuned y ranges for the top row of the KL figure.
    if dvg == 'KL':
      if plot_id == 1:
        plt.ylim((200, 300))
      if plot_id == 2:
        plt.ylim((420, 530))
      if plot_id == 3:
        plt.ylim((740, 900))

    # Titles only on the top row, x labels only on the bottom row, y labels
    # only on the left column.
    if plot_id < 4:
      plt.title('d=%d' % (MODELS_D[model_position]), fontsize=18)
    if plot_id in [1, 4]:
      plt.ylabel('estimate value', fontsize=18)
    plt.xlim((-1, num_steps))
    if plot_id > 3:
      plt.xlabel(r'$\log_2 N$', fontsize=17)
    plt.xticks(range(num_steps), [str(step) for step in range(num_steps)])

    # Hide the right and top spines.
    ax = plt.gca()
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    # Only show ticks on the left and bottom spines.
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    plt.tick_params(axis='x', which='both',
                    bottom=True, top=False, labelbottom=True)
    # Hide the y tick marks but keep their labels. The original call passed
    # labelbottom=False, an x-axis keyword that has no effect for axis='y'.
    plt.tick_params(axis='y', which='both',
                    left=False, right=False, labelleft=True)
    ax.yaxis.grid()

  # Build one shared legend from the second subplot, sorted by label text.
  ax = fig.axes[1]
  handles, labels = ax.get_legend_handles_labels()
  labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
  fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.05),
             ncol=3, fancybox=True, shadow=True, fontsize=14)

  plt.tight_layout()

In [0]:
#@title Make KL divergence plot (Figure 2) { display-mode: "form" }

# Draw the six-panel KL figure from the module-level results_ram_mc and
# results_mc_benchmark dicts.
make_plots_real_data("KL")

In [0]:
#@title Make Squared Hellinger divergence plot (Figure 3) { display-mode: "form" }

# Draw the six-panel squared-Hellinger figure from the module-level
# results_ram_mc and results_mc_benchmark dicts.
make_plots_real_data("Hsq")