##### Copyright 2019 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License")

In [0]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Basic setup

**About this Colab**<br /> This is a companion Colab for the paper:

*On Mutual Information Maximization for Representation Learning*<br />
Michael Tschannen\*, Josip Djolonga\*, Paul Rubenstein, Sylvain Gelly, Mario Lucic

The Colab can be used to visualize precomputed results or to rerun the experiments reported in the paper.

**Running the experiments**<br />
By default, the precomputed results will be loaded, but individual experiments can be run within the Colab by checking the `RUN_EXPERIMENTS` checkbox below. The batch size used in the paper was 128 and we average over 20 runs. With one run, the entire set of experiments completes in about 2 hours. For multiple runs we suggest copying the code and running a stand-alone version. If you wish to run the experiments within the Colab, make sure you first execute all cells in the "Basic setup" section.

In [0]:
#@title Imports, configurations, and helper functions { display-mode: "form" }
from __future__ import division
from __future__ import print_function

import collections
import copy
import functools
import itertools
import os
import pickle

from matplotlib import pyplot as plt
from matplotlib.ticker import FuncFormatter
import numpy as np
import pandas as pd
from scipy.ndimage import gaussian_filter1d
import seaborn as sns
import tensorflow as tf
from tensorflow.python.ops.parallel_for import gradients
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
import sklearn.linear_model as sk_linear


# Short aliases for the TF / TFP namespaces used throughout the notebook.
# NOTE(review): `slim` and `tfd` are not referenced in the visible cells.
slim = tf.contrib.slim
tfb = tfp.bijectors
tfd = tfp.distributions
tfkl = tf.keras.layers

# Reset Keras' global session state so re-running this cell starts fresh.
tf.keras.backend.clear_session()

# Identifies one (encoder, critic, estimator) configuration of a sweep.
ResultsConfig = collections.namedtuple(
    "ResultsConfig",
    ("nets", "critic", "loss"))

# Evaluation history of a single `train` run.
Results = collections.namedtuple(
    "Results",
    ("iterations", "training_losses", "testing_losses",
     "classification_accuracies", "singular_values"))

# Evaluation history of a single `train_adversarial` run.
ResultsAdversarial = collections.namedtuple(
    "ResultsAdversarial",
    ("losses_e", "losses_c", "classification_accuracies", "iters"))

# Container for the sampling-issues experiment (true MI vs. iid / non-iid
# NCE and NWJ estimates).
ResultsSamplingIssues = collections.namedtuple(
    "ResultsSamplingIssues",
    ("mi_true",
     "nce_estimates_noniid", "nce_estimates_iid",
     "nwj_estimates_noniid", "nwj_estimates_iid"))

def convert_to_data_frame(result, exp_name, nets, critic, loss, seed):
  """Flattens one run's evaluation history into a long-format data frame.

  Args:
    result: Object with `iterations`, `testing_losses` and
      `classification_accuracies` sequences (e.g. a `Results` tuple).
    exp_name: Experiment name stored in the `exp_name` column.
    nets: Encoder name.
    critic: Critic key; mapped to a display name in the `Critic` column.
    loss: Estimator key; mapped to a LaTeX name in the `Estimator` column.
    seed: Run index stored in the `run` column.

  Returns:
    A `pd.DataFrame` with one row per evaluation step. `bound_value` is the
    negated testing loss.
  """
  run_label = "{}, {}, {}".format(nets, critic, loss)
  # The losses are negated MI estimates; flip the sign to report the bound.
  bound_values = [-test_loss for test_loss in result.testing_losses]
  records = [
      (exp_name, nets, critic, loss, seed, step, bound, acc, run_label)
      for step, bound, acc in zip(result.iterations, bound_values,
                                  result.classification_accuracies)
  ]
  frame = pd.DataFrame(
      records,
      columns=("exp_name", "nets", "Critic", "Estimator",
               "run", "iteration", "bound_value", "accuracy", "label"))

  estimator_display_names = {
      "nce": "$I_{NCE}$",
      "nwj": "$I_{NWJ}$"
  }
  critic_display_names = {
      "concat": "MLP",
      "separable": "Separable",
      "innerprod": "Inner product",
      "bilinear": "Bilinear"
  }
  frame["Estimator"] = frame["Estimator"].replace(
      to_replace=estimator_display_names)
  frame["Critic"] = frame["Critic"].replace(to_replace=critic_display_names)
  return frame


def apply_default_style(ax, x_max=20001):
  """Applies the shared training-curve styling to `ax`.

  Sets the x-range to [0, x_max], labels x in thousands of training steps,
  hides tick marks, and draws a lower-right legend without its first entry.
  All calls go through `ax`, so this also works when `ax` is not the current
  pyplot axes (the previous `plt.*` calls styled whatever axes were current).

  Args:
    ax: A matplotlib `Axes`.
    x_max: Upper x-limit; the default 20001 matches the original behavior.
  """
  ax.set_xlim([0, x_max])
  ax.get_xaxis().set_major_formatter(
      FuncFormatter(lambda x, p: format(int(x / 1000), ',')))
  ax.set_xlabel("Training steps (in thousands)")
  ax.tick_params(top=False, right=False, bottom=False, left=False)
  handles, labels = ax.get_legend_handles_labels()
  # Drop the first legend entry. With the seaborn version this was written
  # for, that entry is presumably the hue title rather than a data series.
  # NOTE(review): verify on newer seaborn, where this may drop a real series.
  ax.legend(loc="lower right", handles=handles[1:], labels=labels[1:])

# Plot styling shared by all figures below.
FONTSIZE = 15 
sns.set_style("whitegrid")
plt.rcParams.update({'axes.labelsize': FONTSIZE,
                     'xtick.labelsize': FONTSIZE,
                     'ytick.labelsize': FONTSIZE,
                     'legend.fontsize': FONTSIZE})

# Colab form parameters (the `#@param` annotations render as form controls).
# NRUNS: independent training runs per configuration (the paper uses 20).
NRUNS = 1 #@param { type: "slider", min: 1, max: 20, step: 1}
# TRAIN_BATCH_SIZE: batch size for MI training (the paper uses 128).
TRAIN_BATCH_SIZE = 128 #@param { type: "slider", min: 64, max: 128, step: 64}
# RUN_EXPERIMENTS: if False, the precomputed results are downloaded instead.
RUN_EXPERIMENTS = False #@param { type: "boolean"}
# Number of pixels of a flattened image (also redefined in the dataset cell).
DIMS = 784

def get_testing_loss(x_array, session, loss, data_ph, dims, batch_size=512):
  """Example-weighted mean of `loss` over `x_array`, evaluated batch by batch.

  Each batch holds up to `batch_size` rows of `x_array`, restricted to the
  first `dims` columns, and is fed through `data_ph`. The batch result is
  weighted by the batch's row count, so the final batch counts proportionally.
  """
  n_examples = x_array.shape[0]
  weighted_sum = 0
  start = 0
  while start < n_examples:
    batch = x_array[start:start + batch_size, :dims]
    weighted_sum += batch.shape[0] * session.run(loss,
                                                 feed_dict={data_ph: batch})
    start += batch_size
  return weighted_sum / n_examples

def get_classification_accuracy(session, codes, data_ph, dims):
  """Linear-probe accuracy of the representation `codes` on the test split.

  Encodes the global `x_train` and `x_test` arrays (first `dims` columns) with
  `map_data`, then fits and scores `logistic_fit` against `y_train`/`y_test`.
  """
  train_codes, test_codes = [
      map_data(split, session, codes, data_ph, dims)
      for split in (x_train, x_test)
  ]
  return logistic_fit(train_codes, y_train, test_codes, y_test)

def map_data(x_array, session, codes, data_ph, dims, batch_size=512):
  """Evaluates `codes` on `x_array` in batches and stacks the results.

  Each batch holds up to `batch_size` rows of `x_array`, restricted to the
  first `dims` columns, and is fed through `data_ph`. The per-batch outputs
  are concatenated along axis 0.
  """
  batches = (x_array[start:start + batch_size, :dims]
             for start in range(0, x_array.shape[0], batch_size))
  mapped = [session.run(codes, feed_dict={data_ph: batch})
            for batch in batches]
  return np.concatenate(mapped, axis=0)


In [0]:
#@title Import bounds implemented by Poole et al. (2019) { display-mode: "form" }
# From https://colab.research.google.com/github/google-research/google-research/blob/master/vbmi/vbmi_demo.ipynb 

def reduce_logmeanexp_nodiag(x, axis=None):
  """Log-mean-exp of a square matrix, ignoring its diagonal.

  Args:
    x: Square [batch_size, batch_size] tensor with a statically known first
      dimension (read via `x.shape[0].value`).
    axis: None to average over all off-diagonal entries, or a single axis to
      average over the batch_size - 1 off-diagonal entries of each row/column.
      `axis=0` is now handled like any other axis; previously `if axis:`
      treated 0 as None and used the wrong normalizer.

  Returns:
    log(mean(exp(x))) over the off-diagonal entries, reduced along `axis`.
  """
  batch_size = x.shape[0].value
  # Subtracting +inf on the diagonal sends those entries to -inf, so they
  # contribute exp(-inf) = 0 to the logsumexp.
  logsumexp = tf.reduce_logsumexp(
      x - tf.linalg.tensor_diag(np.inf * tf.ones(batch_size)), axis=axis)
  if axis is not None:
    num_elem = batch_size - 1.
  else:
    num_elem = batch_size * (batch_size - 1.)
  return logsumexp - tf.math.log(num_elem)

def tuba_lower_bound(scores, log_baseline=None):
  """TUBA lower bound on mutual information (Poole et al., 2019).

  Computes 1 + mean(diag(scores)) - mean over off-diagonal of exp(scores).

  Args:
    scores: Square critic matrix. Diagonal entries are treated as samples
      from the joint; off-diagonal entries as samples from the product of
      marginals.
    log_baseline: Optional per-row log-baseline of shape [batch_size]; entry
      i is subtracted from every score in row i.

  Returns:
    Scalar tensor with the bound.
  """
  if log_baseline is not None:
    # Out-of-place subtraction, so the caller's `scores` object is never
    # modified via augmented assignment.
    scores = scores - log_baseline[:, None]
  # First term is an expectation over samples from the joint,
  # which are the diagonal elements of the scores matrix.
  joint_term = tf.reduce_mean(tf.linalg.diag_part(scores))
  # Second term is an expectation over samples from the marginal,
  # which are the off-diagonal elements of the scores matrix.
  marg_term = tf.exp(reduce_logmeanexp_nodiag(scores))
  return 1. + joint_term - marg_term

def nwj_lower_bound(scores):
  """NWJ lower bound on mutual information.

  Implemented as the TUBA bound on scores shifted by -1. This is the same as
  using a constant log-baseline of 1 for every row of the score matrix.
  """
  shifted_scores = scores - 1.
  return tuba_lower_bound(shifted_scores)

def infonce_lower_bound(scores):
  """InfoNCE lower bound from van den Oord et al. (2018).

  Args:
    scores: Square critic matrix with a static first dimension. Entry [i, j]
      presumably scores the pair (x_i, y_j), with positives on the diagonal.

  Returns:
    Scalar tensor: log(batch_size) + mean_i(scores[i, i] - logsumexp_j scores[i, j]).
  """
  # Equivalent formulation:
  # -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=scores, labels=tf.range(batch_size))
  positive_scores = tf.linalg.diag_part(scores)
  row_log_normalizers = tf.reduce_logsumexp(scores, axis=1)
  nll = tf.reduce_mean(positive_scores - row_log_normalizers)
  log_batch_size = tf.math.log(tf.cast(scores.shape[0].value, tf.float32))
  return log_batch_size + nll

In [0]:
#@title Define the linear evaluation protocol { display-mode: "form" }

def logistic_fit(x_train, y_train, x_test, y_test):
  """Linear evaluation protocol: multinomial logistic regression accuracy.

  Features are min-max scaled using statistics from the training set only.
  The classifier is then fit on the training set and scored on the test set.

  Returns:
    Mean test accuracy as returned by `LogisticRegression.score`.
  """
  from sklearn.preprocessing import MinMaxScaler
  scaler = MinMaxScaler()
  train_features = scaler.fit_transform(x_train)
  test_features = scaler.transform(x_test)
  classifier = sk_linear.LogisticRegression(
      solver='saga', multi_class='multinomial', tol=.1, C=10.)
  classifier.fit(train_features, y_train.ravel())
  return classifier.score(test_features, y_test.ravel())

In [0]:
#@title Define and load the dataset, check baseline in pixel space { display-mode: "form" }

# Start from an empty default graph before building the input pipelines.
tf.reset_default_graph()

# TFDS dataset name and the feature keys used to read each example.
TFDS_NAME = "mnist"
FEATURE_INPUT = "image"
FEATURE_LABEL = "label"
N_CLASSES = 10

DIMS = 784  # Total dimensions after flattening (same value as the config cell).

def map_fn(example):
  """Casts the image to float32, divides it by 255 and flattens it.

  Returns a dict with the same two keys, leaving the label untouched.
  """
  pixels = tf.reshape(
      tf.cast(example[FEATURE_INPUT], tf.float32) / 255.0, [-1])
  return {FEATURE_INPUT: pixels, FEATURE_LABEL: example[FEATURE_LABEL]}

def load_data(split):
  """Loads a TFDS split, caches it, applies `map_fn` and shuffles (buffer 1000)."""
  dataset = tfds.load(TFDS_NAME, split=split)
  dataset = dataset.cache()
  dataset = dataset.map(map_func=map_fn)
  return dataset.shuffle(1000)
  
def tfds_to_np(dataset):
  """Iterates `dataset` once and stacks flattened images and labels.

  Returns:
    (images, labels) NumPy arrays with one row per example.
  """
  images, labels = [], []
  for example in tfds.as_numpy(dataset):
    images.append(example[FEATURE_INPUT].ravel())
    labels.append(example[FEATURE_LABEL].ravel())
  return np.stack(images), np.stack(labels)

# Materialize both splits as NumPy arrays. These globals are used by the
# linear evaluation protocol (`get_classification_accuracy`) and by the
# testing-loss computation in `train`.
dataset_train = load_data("train")
dataset_test = load_data("test")
x_train, y_train = tfds_to_np(dataset_train)
x_test, y_test = tfds_to_np(dataset_test)
# Discard the graph holding the pipelines built above.
tf.reset_default_graph()

# Sanity check of the linear evaluation on raw pixels: logistic regression on
# the first half of the (slightly noised) pixels.
# NOTE(review): no random seed is set, so the noise and the printed accuracy
# vary slightly between runs.
x_train_noisy = x_train + 0.05 * np.random.randn(*x_train.shape)
x_test_noisy = x_test + 0.05 * np.random.randn(*x_test.shape)
print("Fit on half the pixels: {}. It should be around 0.835.".format(
    logistic_fit(x_train_noisy[:, :DIMS//2], y_train,
                 x_test_noisy[:, :DIMS//2], y_test)))

def processed_train_data(dims, batch_size):
  """Builds an infinite batched training iterator and splits each image in two.

  Args:
    dims: Number of leading (flattened) pixels that go into the first view.
    batch_size: Batch size; incomplete final batches are dropped.

  Returns:
    (x_1, x_2, labels) tensors: the first `dims` pixels, the remaining
    `DIMS - dims` pixels, and the labels of the next batch.
  """
  batches = load_data("train").repeat().batch(batch_size, drop_remainder=True)
  next_batch = batches.make_one_shot_iterator().get_next()
  x_1, x_2 = tf.split(next_batch[FEATURE_INPUT], [dims, DIMS - dims], axis=-1)
  return x_1, x_2, next_batch[FEATURE_LABEL]

## Encoders

Here we define the encoder architectures, namely MLP and ConvNet.

In [0]:
class MLP(tf.keras.Model):
  """Stack of dense layers with optional residual shortcuts.

  Args:
    layer_dimensions: Output width of each dense layer, in order.
    shortcuts: If True, each layer's input is added to its output. This
      requires matching widths.
    dense_kwargs: Keyword arguments forwarded to every `tfkl.Dense`. The final
      layer always gets `activation=None`. The dict is not modified. Defaults
      to an empty dict; `None` replaces the previous shared mutable `{}`
      default.
  """

  def __init__(self, layer_dimensions, shortcuts, dense_kwargs=None):
    super(MLP, self).__init__()
    dense_kwargs = {} if dense_kwargs is None else dense_kwargs
    self._layers = [tfkl.Dense(dimensions, **dense_kwargs)
                    for dimensions in layer_dimensions[:-1]]
    # The output layer is linear regardless of the requested activation.
    final_kwargs = copy.deepcopy(dense_kwargs)
    final_kwargs["activation"] = None
    self._layers.append(tfkl.Dense(layer_dimensions[-1], **final_kwargs))
    self._shortcuts = shortcuts

  @property
  def layers(self):
    return self._layers

  # NOTE(review): this overrides `__call__` rather than Keras' `call` hook, so
  # Keras' own `__call__` machinery is bypassed. Kept for compatibility.
  def __call__(self, inputs):
    x = inputs
    for layer in self.layers:
      x = layer(x) + x if self._shortcuts else layer(x)
    return x


# LayerNorm implementation copied from
# https://stackoverflow.com/questions/39095252/fail-to-implement-layer-normalization-with-keras
class LayerNorm(tfkl.Layer):
  """Layer Normalization in the style of https://arxiv.org/abs/1607.06450.

  Normalizes `x` to zero mean and unit standard deviation over `axes`, then
  applies a learned scale and bias of size `input_shape[-1]` (last axis).

  Args:
    scale_initializer: Initializer for the scale (default 'ones').
    bias_initializer: Initializer for the bias (default 'zeros').
    axes: Axes to normalize over. Defaults to [1, 2, 3], i.e. all non-batch
      axes of a rank-4 input such as a Conv2D output.
    epsilon: Added to the standard deviation to avoid division by zero.
  """

  def __init__(self, scale_initializer='ones', bias_initializer='zeros',
               axes=None, epsilon=1e-6, **kwargs):
    super(LayerNorm, self).__init__(**kwargs)
    self.epsilon = epsilon
    self.scale_initializer = tf.keras.initializers.get(scale_initializer)
    self.bias_initializer = tf.keras.initializers.get(bias_initializer)
    # A fresh list per instance instead of a shared mutable default argument.
    self.axes = [1, 2, 3] if axes is None else axes

  def build(self, input_shape):
    self.scale = self.add_weight(shape=(input_shape[-1],),
                                 initializer=self.scale_initializer,
                                 trainable=True,
                                 name='{}_scale'.format(self.name))
    self.bias = self.add_weight(shape=(input_shape[-1],),
                                initializer=self.bias_initializer,
                                trainable=True,
                                name='{}_bias'.format(self.name))
    self.built = True

  def call(self, x, mask=None):
    mean = tf.keras.backend.mean(x, axis=self.axes, keepdims=True)
    std = tf.keras.backend.std(x, axis=self.axes, keepdims=True)
    norm = (x - mean) * (1/(std + self.epsilon))
    return norm * self.scale + self.bias

  def compute_output_shape(self, input_shape):
    return input_shape


class ConvNet(tf.keras.Sequential):
  """Two strided convolutions, LayerNorm, global average pooling, linear head.

  Note: this works only for the specific data set considered here. The input
  is reshaped to (14, 28, 1), i.e. the first 392 flattened pixels of a
  28-pixel-wide image, so `input_dim` must equal 14 * 28.
  """

  def __init__(self, channels=64, kernel_size=5, input_dim=DIMS//2,
               output_dim=100, activation=tf.nn.relu):
    conv_kwargs = dict(strides=2, padding="same", activation=activation)
    stack = [
        tfkl.Reshape((14, 28, 1), input_shape=(input_dim,)),
        tfkl.Conv2D(channels, kernel_size, **conv_kwargs),
        tfkl.Conv2D(2 * channels, kernel_size, **conv_kwargs),
        LayerNorm(),
        tfkl.GlobalAveragePooling2D(),
        tfkl.Dense(output_dim),
    ]
    super(ConvNet, self).__init__(stack)

### Custom RealNVP

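We make two small modifications to the standard RealNVP implementation (highlighted with comments marked `**`). Due to numerical instability of the exp, (1) the second output of the coupling network is passed to the affine layer directly as the scale instead of as a log-scale, and (2) that scale is computed with a softplus (plus a small constant) instead of an exp. The exp is normally used because it makes calculation of the log-det-Jacobian simple, which is necessary for many applications of normalizing flows. In our setting, we only care about the architecture being invertible. This is still satisfied with our modification, since the scale stays strictly positive.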

In [0]:
from tensorflow_probability.python.internal import tensorshape_util
import tensorflow.compat.v1 as tf1
from tensorflow_probability.python.bijectors import affine_scalar
from tensorflow_probability.python.bijectors import bijector as bijector_lib

# Modified from tensorflow_probability/python/bijectors/real_nvp.py
class RealNVP(bijector_lib.Bijector):
  """RealNVP affine coupling bijector, modified to use a positive scale.

  The first `num_masked` entries of the last axis pass through unchanged. They
  condition an elementwise affine map of the remaining entries. Unlike the
  standard RealNVP, the second output of `shift_and_log_scale_fn` is used
  directly as the scale instead of being exponentiated (see the `**` comment).
  """

  def __init__(self,
               num_masked,
               shift_and_log_scale_fn=None,
               bijector_fn=None,
               is_constant_jacobian=False,
               validate_args=False,
               name=None):
    """Builds the coupling layer.

    Args:
      num_masked: Number of leading entries left unchanged. Must be >= 0 and
        smaller than the event size (the latter is checked on first use).
      shift_and_log_scale_fn: Callable `(x0, output_units) -> (shift, scale)`.
        Despite the name, its second output is used as the scale directly.
      bijector_fn: Callable `(x0, output_units, **kwargs) -> Bijector`.
        Exactly one of this and `shift_and_log_scale_fn` must be given.
      is_constant_jacobian: Forwarded to `Bijector`.
      validate_args: Must be False; see Raises.
      name: Name of the bijector; defaults to 'real_nvp'.

    Raises:
      ValueError: If `num_masked` is negative, or if not exactly one of
        `shift_and_log_scale_fn` / `bijector_fn` is given.
      NotImplementedError: If `validate_args` is True. The helper that wraps
        `bijector_fn` with argument checks (`_validate_bijector_fn`) is not
        defined in this notebook, so this path previously failed with a
        NameError.
    """
    name = name or 'real_nvp'
    if num_masked < 0:
      raise ValueError('num_masked must be a non-negative integer.')
    self._num_masked = num_masked
    # At construction time, we don't know input_depth.
    self._input_depth = None
    if bool(shift_and_log_scale_fn) == bool(bijector_fn):
      raise ValueError('Exactly one of `shift_and_log_scale_fn` and '
                       '`bijector_fn` should be specified.')
    if shift_and_log_scale_fn:
      def _bijector_fn(x0, input_depth, **condition_kwargs):
        shift, log_scale = shift_and_log_scale_fn(x0, input_depth,
                                                  **condition_kwargs)
        # ** First modification is here: the network output is used as the
        # scale itself rather than as the log of the scale.
        return affine_scalar.AffineScalar(shift=shift, scale=log_scale)

      bijector_fn = _bijector_fn

    if validate_args:
      raise NotImplementedError(
          '`validate_args=True` is not supported by this RealNVP copy: '
          '`_validate_bijector_fn` is not available.')

    # Still do this assignment for variable tracking.
    self._shift_and_log_scale_fn = shift_and_log_scale_fn
    self._bijector_fn = bijector_fn

    super(RealNVP, self).__init__(
        forward_min_event_ndims=1,
        is_constant_jacobian=is_constant_jacobian,
        validate_args=validate_args,
        name=name)

  def _cache_input_depth(self, x):
    """Records the static event size of `x` on first use and validates it."""
    if self._input_depth is None:
      self._input_depth = tf.compat.dimension_value(
          tensorshape_util.with_rank_at_least(x.shape, 1)[-1])
      if self._input_depth is None:
        raise NotImplementedError(
            'Rightmost dimension must be known prior to graph execution.')
      if self._num_masked >= self._input_depth:
        raise ValueError(
            'Number of masked units must be smaller than the event size.')

  def _forward(self, x, **condition_kwargs):
    """Keeps x[..., :num_masked] and transforms the rest conditioned on it."""
    self._cache_input_depth(x)
    x0, x1 = x[..., :self._num_masked], x[..., self._num_masked:]
    y1 = self._bijector_fn(x0, self._input_depth - self._num_masked,
                           **condition_kwargs).forward(x1)
    y = tf.concat([x0, y1], axis=-1)
    return y

  def _inverse(self, y, **condition_kwargs):
    """Inverts `_forward`; the masked part is identical in x and y."""
    self._cache_input_depth(y)
    y0, y1 = y[..., :self._num_masked], y[..., self._num_masked:]
    x1 = self._bijector_fn(y0, self._input_depth - self._num_masked,
                           **condition_kwargs).inverse(y1)
    x = tf.concat([y0, x1], axis=-1)
    return x

  def _forward_log_det_jacobian(self, x, **condition_kwargs):
    """Log-det-Jacobian of the conditioned affine map on the unmasked part."""
    self._cache_input_depth(x)
    x0, x1 = x[..., :self._num_masked], x[..., self._num_masked:]
    return self._bijector_fn(x0, self._input_depth - self._num_masked,
                             **condition_kwargs).forward_log_det_jacobian(
                                 x1, event_ndims=1)

  def _inverse_log_det_jacobian(self, y, **condition_kwargs):
    """Inverse log-det-Jacobian of the conditioned affine map."""
    self._cache_input_depth(y)
    y0, y1 = y[..., :self._num_masked], y[..., self._num_masked:]
    return self._bijector_fn(y0, self._input_depth - self._num_masked,
                             **condition_kwargs).inverse_log_det_jacobian(
                                 y1, event_ndims=1)

def real_nvp_default_template(hidden_layers,
                              shift_only=False,
                              activation=tf.nn.relu,
                              name=None,
                              *args,  # pylint: disable=keyword-arg-before-vararg
                              **kwargs):
  """Builds the fully connected network used inside each `RealNVP` coupling.

  Modified copy of TFP's default template: the scale output is passed through
  a softplus plus 1e-7, so it is strictly positive.

  Args:
    hidden_layers: Widths of the hidden `tf1.layers.dense` layers.
    shift_only: If True, only a shift is produced and the scale is None.
    activation: Activation of the hidden layers (the output layer is linear).
    name: Name scope; defaults to 'real_nvp_default_template'. The template
      itself is always named 'real_nvp_default_template'.
    *args, **kwargs: Forwarded to every `tf1.layers.dense` call (e.g.
      `kernel_initializer`, `bias_initializer`).

  Returns:
    A `tf1.make_template` callable `(x, output_units) -> (shift, scale)`.
  """
  with tf.name_scope(name or 'real_nvp_default_template'):

    def _fn(x, output_units, **condition_kwargs):
      """Maps `x` to (shift, scale) for `output_units` transformed entries."""
      if condition_kwargs:
        raise NotImplementedError(
            'Conditioning not implemented in the default template.')

      # Promote a single unbatched vector to a batch of one, and undo that on
      # the outputs.
      if tensorshape_util.rank(x.shape) == 1:
        x = x[tf.newaxis, ...]
        reshape_output = lambda x: x[0]
      else:
        reshape_output = lambda x: x
      for units in hidden_layers:
        x = tf1.layers.dense(
            inputs=x,
            units=units,
            activation=activation,
            *args,  # pylint: disable=keyword-arg-before-vararg
            **kwargs)
      # Linear output layer: `output_units` shifts, plus as many scales unless
      # `shift_only`.
      x = tf1.layers.dense(
          inputs=x,
          units=(1 if shift_only else 2) * output_units,
          activation=None,
          *args,  # pylint: disable=keyword-arg-before-vararg
          **kwargs)
      if shift_only:
        return reshape_output(x), None
      shift, log_scale = tf.split(x, 2, axis=-1)
      # ** Here is the second modification: softplus(.) + 1e-7 gives a
      # strictly positive scale (used directly as the scale by `RealNVP`).
      return reshape_output(shift), 1e-7 + tf.nn.softplus(reshape_output(log_scale))

    return tf1.make_template('real_nvp_default_template', _fn)

class RealNVPBijector(tf.keras.Model):
  """Invertible encoder: RealNVP couplings interleaved with random permutations.

  Args:
    dimensions: Event size of the input (and of the output).
    n_couplings: Number of (coupling, permutation) pairs in the chain.
    hidden_layers: Hidden widths of each coupling's network.
    dense_kwargs: Forwarded to `real_nvp_default_template`.
  """

  def __init__(self, dimensions, n_couplings, hidden_layers, dense_kwargs):
    super(RealNVPBijector, self).__init__()
    # Draw every permutation before building any coupling.
    shuffles = [np.random.permutation(dimensions) for _ in range(n_couplings)]
    chain = []
    for shuffle in shuffles:
      coupling = RealNVP(
          dimensions // 2,
          real_nvp_default_template(hidden_layers, **dense_kwargs))
      chain.extend([coupling, tfb.Permute(shuffle)])
    self._bijector = tfb.Chain(chain)

  def call(self, inputs):
    return self._bijector.forward(inputs)

## Critics

Here we define the critic architectures, namely inner product, bilinear, concat (MLP), and separable critics.

In [0]:
class InnerProdCritic(tf.keras.Model):
  """Parameter-free critic: scores[i, j] is the dot product of x_i and y_j."""

  def call(self, x, y):
    scores = tf.matmul(x, y, transpose_b=True)
    return scores

class BilinearCritic(tf.keras.Model):
  """Bilinear critic: scores[i, j] = x_i . W(y_j), with W a bias-free dense map.

  Args:
    feature_dim: Output width of W; must match the width of `x`.
  """

  def __init__(self, feature_dim=100, **kwargs):
    super(BilinearCritic, self).__init__(**kwargs)
    self._W = tfkl.Dense(feature_dim, use_bias=False)

  def call(self, x, y):
    projected_y = self._W(y)
    return tf.matmul(x, projected_y, transpose_b=True)

# Copied from
# https://colab.research.google.com/github/google-research/google-research/blob/master/vbmi/vbmi_demo.ipynb
class ConcatCritic(tf.keras.Model):
  """Critic that scores every concatenated (x_i, y_j) pair with an MLP.

  Args:
    hidden_dim: Width of each hidden layer.
    layers: Number of hidden layers.
    activation: Hidden-layer activation. It was previously ignored ("relu"
      was hard-coded); the default is unchanged.
  """

  def __init__(self, hidden_dim=200, layers=1, activation='relu', **kwargs):
    super(ConcatCritic, self).__init__(**kwargs)
    # Output is a scalar score per pair (the MLP's last layer is linear).
    self._f = MLP([hidden_dim for _ in range(layers)] + [1], False,
                  {"activation": activation})

  def call(self, x, y):
    """Returns a [batch, batch] matrix with entry [i, j] = f(concat(x_i, y_j))."""
    batch_size = tf.shape(x)[0]
    # Tile all possible combinations of x and y: at position [a, b] the pair
    # is (x_b, y_a).
    x_tiled = tf.tile(x[None, :], (batch_size, 1, 1))
    y_tiled = tf.tile(y[:, None], (1, batch_size, 1))
    # xy is [batch_size * batch_size, x_dim + y_dim]
    xy_pairs = tf.reshape(tf.concat((x_tiled, y_tiled), axis=2),
                          [batch_size * batch_size, -1])
    # Compute scores for each x_i, y_j pair.
    scores = self._f(xy_pairs)
    # Transpose so that rows index x and columns index y.
    return tf.transpose(tf.reshape(scores, [batch_size, batch_size]))


class SeparableCritic(tf.keras.Model):
  """Separable critic: scores[i, j] = f_x(x_i) . f_y(y_j), f_x and f_y are MLPs.

  Args:
    hidden_dim: Width of each hidden layer of both MLPs.
    output_dim: Embedding width of both MLPs.
    layers: Number of hidden layers.
    activation: Hidden-layer activation.
  """

  def __init__(self, hidden_dim=100, output_dim=100, layers=1,
               activation='relu', **kwargs):
    super(SeparableCritic, self).__init__(**kwargs)
    widths = [hidden_dim] * layers + [output_dim]
    mlp_kwargs = {"activation": activation}
    self._f_x = MLP(widths, False, mlp_kwargs)
    self._f_y = MLP(widths, False, mlp_kwargs)

  def call(self, x, y):
    embedded_x, embedded_y = self._f_x(x), self._f_y(y)
    return tf.matmul(embedded_x, embedded_y, transpose_b=True)

# Experiments

## Training loop for Section 3.1 - 3.3 in the paper

Classic training loop where we update the encoder (and possibly the critic) and evaluate the model on test data.

In [0]:
def train(g1,
          g2,
          critic,
          loss_fn,
          learning_rate=1e-4,
          batch_size=TRAIN_BATCH_SIZE,
          n_iters=15000,
          n_evals=15,
          compute_jacobian=False,
          noise_std=0.0,
          data_dimensions=DIMS//2):
  """Runs the training loop for a fixed model.

  Args:
    g1: Function, maps input1 (the first `data_dimensions` pixels) to a
      representation.
    g2: Function, maps input2 (the remaining pixels) to a representation.
    critic: Function, maps two batches of representations to a score matrix.
    loss_fn: Function, maps the score matrix to a scalar loss (a negated
      mutual information estimator).
    learning_rate: Adam learning rate.
    batch_size: Training batch size.
    n_iters: Number of optimization iterations.
    n_evals: Approximate number of model evaluations. One evaluation runs
      every `max(1, n_iters // n_evals)` steps, starting at step 0.
    compute_jacobian: Whether to record the singular values of the Jacobian
      of g1 (on a training batch) at each evaluation.
    noise_std: Standard deviation of the Gaussian noise. The same noise sample
      is added to both training inputs, and fresh noise is added to the
      evaluation inputs. Default is 0.0.
    data_dimensions: The dimension of input1. By default it's half of the
      original data dimension.

  Returns:
    An instance of the `Results` tuple.
  """
  x_1, x_2, _ = processed_train_data(data_dimensions, batch_size)

  if noise_std > 0.0:
    assert x_1.shape == x_2.shape, "X1 and X2 shapes must agree to add noise!"
    noise = noise_std * tf.random.normal(x_1.shape)
    x_1 += noise
    x_2 += noise

  # Compute the representations.
  code_1, code_2 = g1(x_1), g2(x_2)
  critic_matrix = critic(code_1, code_2)

  # Compute the Jacobian of g1 if needed.
  if compute_jacobian:
    jacobian = gradients.batch_jacobian(code_1, x_1, use_pfor=False)
    singular_values = tf.linalg.svd(jacobian, compute_uv=False)

  # Optimizer setup.
  loss = loss_fn(critic_matrix)
  optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
  optimizer_op = optimizer.minimize(loss)

  # Evaluation interval. The max(1, .) guard avoids a modulo-by-zero when
  # n_evals > n_iters, which previously raised ZeroDivisionError.
  eval_every = max(1, n_iters // n_evals)

  with tf.Session() as session:
    session.run(tf.global_variables_initializer())

    # Subgraph for eval (add noise to input if necessary)
    data_ph = tf.placeholder(tf.float32, shape=[None, data_dimensions])
    data_ph_noisy = data_ph + noise_std * tf.random.normal(tf.shape(data_ph))
    codes = g1(data_ph_noisy)

    training_losses, testing_losses, classification_accuracies, iters, sigmas \
      = [], [], [], [], []
    # Main training loop.
    for iter_n in range(n_iters):
      # Evaluate the model performance.
      if iter_n % eval_every == 0:
        iters.append(iter_n)
        accuracy = get_classification_accuracy(session, codes, data_ph, data_dimensions)
        classification_accuracies.append(accuracy)
        # NOTE(review): `loss` is built from the training iterator, not from
        # `data_ph`. Feeding `x_test` slices therefore has no effect, and this
        # value is a loss averaged over training batches. Left unchanged so
        # that results stay comparable with the precomputed ones.
        testing_losses.append(
            get_testing_loss(x_test, session, loss, data_ph, data_dimensions))
        if compute_jacobian:
          sigmas.append(session.run(singular_values))
        print("Step {:>10d} fit {:>.5f}".format(iter_n, accuracy))
      # Run one optimization step.
      loss_np, _ = session.run([loss, optimizer_op])
      training_losses.append(loss_np)

  return Results(iterations=iters,
                 training_losses=training_losses,
                 testing_losses=testing_losses,
                 classification_accuracies=classification_accuracies,
                 singular_values=sigmas)


def run_sweep(nets, critics, loss_fns, exp_name, **kwargs):
  """Runs the sweep across encoder networks, critics, and the estimators.

  Each (encoder, critic, estimator) combination is trained `NRUNS` times
  inside its own fresh graph. A run that raises is reported and skipped, so
  the rest of the sweep still completes.

  Args:
    nets: Dict name -> zero-argument factory returning an encoder pair
      `(g1, g2)`.
    critics: Dict name -> zero-argument factory returning a critic.
    loss_fns: Dict name -> loss function of the critic score matrix.
    exp_name: Experiment label stored in the `exp_name` column.
    **kwargs: Forwarded to `train`.

  Returns:
    Dict with:
      "df": the concatenated evaluation data frame. If no run succeeded, an
        empty frame with the usual columns is returned instead of letting
        `pd.concat` raise.
      "singular_values": list of `(ResultsConfig, [Results, ...])` pairs,
        filled only when `compute_jacobian=True` is passed.
  """
  data_frames = []
  results_with_singular_values = []
  for nets_name, critic_name, loss_name in itertools.product(
      nets, critics, loss_fns):
    print("[New experiment] encoder: {}, critic: {}, loss: {}".format(
        nets_name, critic_name, loss_name))
    with tf.Graph().as_default():
      g1, g2 = nets[nets_name]()
      critic = critics[critic_name]()
      loss_fn = loss_fns[loss_name]
      results_per_run = []
      for n in range(NRUNS):
        try:
          results_per_run.append(train(g1, g2, critic, loss_fn, **kwargs))
        except Exception as ex:  # pylint: disable=broad-except
          # Deliberate best-effort: report the failure and keep sweeping.
          print("Run {} failed! Error: {}".format(n, ex))
      for i, result in enumerate(results_per_run):
        data_frames.append(convert_to_data_frame(
            result, exp_name, nets_name, critic_name, loss_name, i))
      if kwargs.get('compute_jacobian', False):
        results_with_singular_values.append((
            ResultsConfig(nets_name, critic_name, loss_name), results_per_run
        ))

  if data_frames:
    df = pd.concat(data_frames)
  else:
    df = pd.DataFrame(
        columns=("exp_name", "nets", "Critic", "Estimator",
                 "run", "iteration", "bound_value", "accuracy", "label"))
  return {
      "df": df,
      "singular_values": results_with_singular_values
  }

## Maximized MI and improved downstream performance

Reproduces the first experiment of Section 3.1 and the corresponding Figures 1 (a, b).

In this experiment we use invertible architectures. We show that training to maximize the MI estimators results in improved downstream performance, even though MI is maximized for any parameter setting (due to invertibility).

In [0]:
#@title Run experiment or load precomputed results { display-mode: "form" }
def run_all_experiments():
  """Invertible-encoder sweep: RealNVP encoder pair, bilinear critic, NWJ + NCE."""
  tf.reset_default_graph()
  # Losses are the negated MI lower bounds.
  loss_fcts = {
      "nwj": lambda x: -nwj_lower_bound(x),
      "nce": lambda x: -infonce_lower_bound(x),
  }
  kwargs = dict(
      shift_only=True,
      activation=lambda x: tf.nn.relu(x),
      kernel_initializer=tf.initializers.truncated_normal(stddev=0.0001),
      bias_initializer='zeros')

  def make_encoder():
    return RealNVPBijector(DIMS // 2, n_couplings=30, hidden_layers=[512, 512],
                           dense_kwargs=kwargs)

  nets = {"realnvp": lambda: (make_encoder(), make_encoder())}
  critics = {"bilinear": lambda: BilinearCritic(feature_dim=DIMS//2)}
  return run_sweep(nets, critics, loss_fcts, "invertible",
                   n_iters=21000, n_evals=21)

if RUN_EXPERIMENTS:
  # Rerun the sweep from scratch (slow).
  data_invertible = run_all_experiments()["df"]
else:
  # Download the precomputed results (-N: only if newer than the local copy).
  # NOTE(review): unpickling executes code embedded in the file; only load
  # pickles from a trusted source.
  !wget -q -N https://storage.googleapis.com/mi_for_rl_files/mi_results.pkl
  data_invertible = pd.read_pickle('mi_results.pkl')
  # Keep only the rows belonging to this experiment.
  data_invertible = data_invertible[data_invertible.exp_name == "invertible"]

In [0]:
#@title Downstream accuracy plot { display-mode: "form" }
# Only the bilinear critic is used here. `data` is also used by the MI bound
# plot in the next cell.
data = data_invertible[data_invertible.Critic.isin(["Bilinear"])]

plt.figure()
# Mean accuracy per estimator with a +/- 1 standard deviation band over rows
# sharing the same iteration (i.e. over runs).
# NOTE(review): `ci="sd"` is deprecated in newer seaborn (errorbar="sd").
ax = sns.lineplot(data=data, x="iteration", y="accuracy", hue="Estimator", ci="sd");
apply_default_style(ax)
ax.set_ylabel("Accuracy");

In [0]:
#@title MI lower bound plot { display-mode: "form" }
# Plots `bound_value` (the negated testing loss) for the `data` selected in
# the previous cell.
plt.figure()
ax = sns.lineplot(data=data, x="iteration", y="bound_value", hue="Estimator", ci="sd",);
apply_default_style(ax)
ax.set_ylim(-5, 8)
ax.set_ylabel("$I_{EST}$");

## Maximized MI and worsened downstream performance

Reproduces the second experiment of Section 3.1 and the corresponding Figure 1 (c). 

In this experiment we use invertible architectures. By adversarially training an encoder, we show that it is possible to significantly deteriorate downstream performance, even though MI is maximized for any parameter setting (due to invertibility).

In [0]:
#@title Define training loop { display-mode: "form" }

def train_adversarial(net,
                      learning_rate=1e-3,
                      batch_size=TRAIN_BATCH_SIZE,
                      n_iters=4000,
                      record_every=400,
                      data_dimension=DIMS//2):
  """Runs the adversarial training loop for a fixed model.

  A linear classifier on top of `net` is trained on the true labels. The
  encoder `net` is trained to push the classifier's predictions towards the
  uniform distribution over classes.

  Args:
    net: `RealNVPBijector` (or a subclass), maps input to representation. The
      encoder variables are collected with the "real_nvp" scope prefix.
    learning_rate: Learning rate of the classifier. The encoder uses
      `learning_rate * 0.01`.
    batch_size: Training batch size.
    n_iters: Number of iterations; each runs 10 classifier steps and 1
      encoder step.
    record_every: Evaluate the linear-probe accuracy every `record_every`
      steps.
    data_dimension: The dimension of the data. By default it's half of the
      original data dimension.

  Returns:
    An instance of the `ResultsAdversarial` tuple.

  Raises:
    ValueError: If `net` is not a `RealNVPBijector`.
  """
  # isinstance (instead of an exact class check) also accepts subclasses.
  if not isinstance(net, RealNVPBijector):
    raise ValueError("Only implemented for the RealNVP class.")

  # Get the data and compute the representation.
  x_1, _, labels = processed_train_data(data_dimension, batch_size)
  code = net(x_1)

  with tf.variable_scope("classifier"):
    logits = tf.layers.dense(code, N_CLASSES)

  # True classification loss for linear classifier.
  loss_c = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=logits, labels=labels)
  loss_c = tf.reduce_mean(loss_c)

  # Fake classification loss for the encoder: cross-entropy to uniform labels.
  labels_unif = (1 / N_CLASSES) * tf.ones(logits.shape)
  loss_e = tf.nn.softmax_cross_entropy_with_logits(
      logits=logits, labels=labels_unif)
  loss_e = tf.reduce_mean(loss_e)

  # Setup the optimizers, each restricted to its own variables.
  vars_e = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="real_nvp")
  vars_c = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="classifier")

  optimizer_c = tf.train.AdamOptimizer(learning_rate=learning_rate)
  optimizer_e = tf.train.AdamOptimizer(learning_rate=learning_rate*0.01)

  optimizer_op_c = optimizer_c.minimize(loss_c, var_list=vars_c)
  optimizer_op_e = optimizer_e.minimize(loss_e, var_list=vars_e)

  with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    data_ph = tf.placeholder(tf.float32, shape=[None, data_dimension])
    codes = net(data_ph)
    losses_c, losses_e, classification_accuracies, iters = [], [], [], []

    # Warm-up the linear classifier.
    for _ in range(1000):
      session.run([loss_c, optimizer_op_c])

    # Run main loop.
    for iter_n in range(n_iters):
      if iter_n % record_every == 0:
        iters.append(iter_n)
        accuracy = get_classification_accuracy(
            session, codes, data_ph, data_dimension)
        classification_accuracies.append(accuracy)
        print("Step {:>10d} fit {:>.5f}".format(iter_n, accuracy))

      # Run 10 optimization steps for the classifier.
      for _ in range(10):
        loss_c_np, _ = session.run([loss_c, optimizer_op_c])
      # Run 1 optimization step for the encoder.
      loss_e_np, _ = session.run([loss_e, optimizer_op_e])
      losses_c.append(loss_c_np)
      losses_e.append(loss_e_np)
      if iter_n % 100 == 0:
        print("  loss_e {:>.5f} loss_c {:>.5f}".format(loss_e_np, loss_c_np))

  return ResultsAdversarial(
      losses_e=losses_e,
      losses_c=losses_c,
      classification_accuracies=classification_accuracies,
      iters=iters)


In [0]:
#@title Run experiment or load precomputed results { display-mode: "form" }

def run_all_experiments():
  """Adversarially trains an invertible RealNVP encoder against a classifier.

  Returns:
    The `ResultsAdversarial` produced by `train_adversarial`.
  """
  tf.reset_default_graph()
  half_dims = DIMS // 2
  dense_kwargs = {
      "activation": lambda x: tf.nn.relu(x),
      # Very small initial weights (stddev 1e-4) for the coupling networks.
      "kernel_initializer": tf.initializers.truncated_normal(stddev=0.0001),
      "bias_initializer": 'zeros',
  }
  encoder = RealNVPBijector(half_dims, n_couplings=30, hidden_layers=[512, 512],
                            dense_kwargs=dense_kwargs)
  return train_adversarial(encoder,
                           learning_rate=1e-3,
                           n_iters=4001,
                           record_every=400,
                           data_dimension=half_dims,
                           batch_size=128)

# Either rerun the adversarial experiment or load its precomputed results.
if RUN_EXPERIMENTS:
  data_adversarial = run_all_experiments()
else:
  # Quiet download (-q); -N (timestamping) skips it when the local copy is
  # already up to date.
  !wget -q -N https://storage.googleapis.com/mi_for_rl_files/adversarial_results.pkl
  # encoding='latin1' lets Python 3 read pickles (incl. NumPy arrays) written
  # by Python 2 -- presumably how this file was produced.
  # NOTE(review): pickle.load can execute arbitrary code; only load this file
  # from a trusted source such as the bucket above.
  with tf.gfile.Open('adversarial_results.pkl', 'rb') as f:
    data_adversarial = pickle.load(f, encoding='latin1')

In [0]:
#@title Downstream accuracy plot { display-mode: "form" }
plt.figure()
plt.plot(data_adversarial.iters, data_adversarial.classification_accuracies, linewidth=2)
ax = plt.gca()
ax.set_ylabel("Accuracy")
apply_default_style(ax)
ax.set_xlim([0, 4001])
# Pin the tick positions before relabelling them. Setting labels alone relies
# on matplotlib's automatic locator happening to choose 0, 500, ..., 4000, and
# recent matplotlib versions warn about labels without a fixed locator.
# The labels are in thousands of iterations.
xticks = np.arange(0, 4001, 500)
ax.set_xticks(xticks)
ax.set_xticklabels([0, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4])
leg = plt.legend(["Adversarially trained\ninvertible encoder"], prop={'size':14});

## Bias towards hard-to-invert encoders

Reproduces the third experiment of Section 3.1 and the corresponding Figures 2 (a, b, c).

In this experiment we use encoder architectures that can be either invertible or non-invertible. We show that training them to maximize the MI estimators makes the encoders hard to invert: they become locally ill-conditioned.

In [0]:
#@title Run experiment or load precomputed results { display-mode: "form" }

def run_all_experiments():
  """Sweeps NWJ/InfoNCE on shortcut MLP encoders, recording Jacobian spectra.

  Returns:
    Whatever `run_sweep` returns for the "non_invertible" experiment.
  """
  tf.reset_default_graph()
  # The estimators are maximized, so each loss is the negated bound.
  loss_fcts = {
      "nwj": lambda x: -nwj_lower_bound(x),
      "nce": lambda x: -infonce_lower_bound(x),
  }
  dense_kwargs = {
      "activation": "relu",
      "kernel_initializer": tf.initializers.truncated_normal(stddev=0.0001),
      "bias_initializer": "zeros",
  }

  def make_encoder():
    # Five layers of width DIMS // 2, built with shortcuts=True.
    return MLP([DIMS // 2] * 5, shortcuts=True, dense_kwargs=dense_kwargs)

  nets = {"mlp": lambda: (make_encoder(), make_encoder())}
  critics = {"bilinear": lambda: BilinearCritic(feature_dim=DIMS // 2)}
  return run_sweep(nets, critics, loss_fcts, "non_invertible",
                   n_iters=21000, n_evals=21, compute_jacobian=True,
                   noise_std=0.05)

# Either rerun the non-invertible sweep or load the precomputed results.
if RUN_EXPERIMENTS:
  all_results = run_all_experiments()
  # The sweep result is indexed by "df" (metrics DataFrame) and
  # "singular_values" (Jacobian singular values used by the condition plots).
  data_non_invertible = all_results["df"]
  non_invertible_singular_values = all_results["singular_values"]

else:
  # NOTE(review): `mi_results.pkl` is not downloaded in this cell; it must
  # already exist locally (presumably fetched by an earlier setup cell).
  # pd.read_pickle also unpickles -- only use trusted files.
  data_non_invertible = pd.read_pickle('mi_results.pkl')
  data_non_invertible = data_non_invertible[data_non_invertible.exp_name == "non_invertible"]

  # Quiet download; -N skips it when the local copy is already up to date.
  !wget -q -N https://storage.googleapis.com/mi_for_rl_files/condition_numbers_results.pkl
  # encoding='latin1' lets Python 3 read pickles written by Python 2.
  # NOTE(review): pickle.load can execute arbitrary code; trusted sources only.
  with tf.gfile.Open('condition_numbers_results.pkl', 'rb') as f:
    non_invertible_singular_values = pickle.load(f, encoding='latin1')

In [0]:
#@title Downstream accuracy plot { display-mode: "form" }
# Downstream accuracy of the bilinear-critic runs, one line per estimator.
data = data_non_invertible.loc[data_non_invertible.Critic.isin(["Bilinear"])]
fig, ax = plt.subplots()
sns.lineplot(data=data, x="iteration", y="accuracy", hue="Estimator", ci="sd", ax=ax)
apply_default_style(ax)
ax.set_ylabel("Accuracy")
ax.set_ylim([0.83, 0.9]);

In [0]:
#@title MI lower bound plot { display-mode: "form" }
# Derive the subset here instead of reusing `data` from the previous cell:
# later cells rebind and eventually `del` the global `data`, so relying on it
# breaks (or plots the wrong experiment) when this cell is re-run.
data_non_invertible_bilinear = data_non_invertible[
    data_non_invertible.Critic.isin(["Bilinear"])]
plt.figure()
ax = sns.lineplot(data=data_non_invertible_bilinear, x="iteration",
                  y="bound_value", hue="Estimator", ci="sd");
apply_default_style(ax)
ax.set_ylim(-5, 8)
ax.set_ylabel("$I_{EST}$");

In [0]:
#@title Jacobian condition number plot { display-mode: "form" }

colors = sns.color_palette()

def percentile_plot(log_condition_number, iters, pcs=5):
  """Create percentile plot.

  At every recorded iteration the values are sorted and split into `pcs`
  equally sized percentile bands. Each band is shaded with its own color from
  `colors`. Dashed lines mark the minimum, every interior band boundary and
  the maximum.

  Args:
    log_condition_number: 2-D array with one row per recorded iteration and
      one column per sample.
    iters: Iteration numbers for the rows; only the first
      `log_condition_number.shape[0]` entries are used.
    pcs: Number of percentile bands. Must not exceed the number of entries
      in `colors`.
  """
  sorted_values = np.sort(log_condition_number, axis=1)
  n_iters, n_samples = sorted_values.shape
  xs = iters[:n_iters]
  # Column index of each band boundary: 0%, (100/pcs)%, ..., 100%.
  pc_idx = [int((i / pcs) * (n_samples - 1)) for i in range(pcs + 1)]

  plt.plot(xs, sorted_values[:, pc_idx[0]],
           color="gray", alpha=1, lw=2, label="Minimum", linestyle="--")
  for i in range(pcs):
    lower, upper = pc_idx[i], pc_idx[i + 1]
    plt.fill_between(xs, sorted_values[:, lower], sorted_values[:, upper],
                     color=colors[i], alpha=0.75)
    # The upper edge of the last band is drawn below, labelled "Maximum".
    if i != pcs - 1:
      plt.plot(xs, sorted_values[:, upper],
               color=colors[i], alpha=1, lw=2,
               label="%.0fth perc." % ((100 / pcs) * (i + 1)),
               linestyle="--")
  plt.plot(xs, sorted_values[:, pc_idx[-1]],
           color=colors[pcs - 1], alpha=1, lw=2, label="Maximum",
           linestyle="--")

  apply_default_style(plt.gca())

def aggregate_singular_values(configs_and_results):
  """Collects per-sample log Jacobian condition numbers per configuration.

  The Jacobian singular values are a batch_size x input_dim matrix per
  iteration, so they need their own extraction/aggregation routine rather
  than the scalar-metric DataFrame path.

  Args:
    configs_and_results: Iterable of `(config, results_all_runs)` pairs. Each
      `config` has `nets`, `critic` and `loss` attributes. Each run result has
      `singular_values` (a sequence of per-iteration arrays) and `iterations`.

  Returns:
    Dict mapping "<nets>, <critic>, <loss>" to `(iterations, log_cond)`.
    `iterations` comes from the first run, and `log_cond` places the runs
    side by side along axis 1. Configurations without runs are omitted.
  """
  def log_condition(run):
    # log(largest / smallest singular value) along the last axis.
    ordered = np.sort(np.stack(run.singular_values), axis=-1)
    return np.log(ordered[..., -1]) - np.log(ordered[..., 0])

  aggregated = {}
  for config, runs in configs_and_results:
    label = "{}, {}, {}".format(config.nets, config.critic, config.loss)
    per_run = [log_condition(run) for run in runs]
    if len(runs) == 0:
      continue
    aggregated[label] = (runs[0].iterations, np.concatenate(per_run, axis=1))
  return aggregated

results_dict = aggregate_singular_values(non_invertible_singular_values)

for key, (its, condnbrs) in results_dict.items():
  # Only the bilinear-critic configurations are plotted.
  if "bilinear" not in key:
    continue
  plt.figure()
  percentile_plot(condnbrs, its)
  # Raw string: `\ ` and `\s` are LaTeX escapes, not Python ones; a non-raw
  # literal triggers invalid-escape warnings on recent Python versions.
  plt.ylabel(r"Jacobian, $log\ (\sigma_1\ /\ \sigma_{392})$")
  # Both estimators share the label and y-range; only the legend position
  # differs.
  plt.legend(loc="lower right" if "nce" in key else "upper left", ncol=2)
  plt.ylim([0, 8])

## Looser bounds with simpler critics can lead to better representations

Reproduces the experiment from Section 3.2 and the corresponding Figure 3.

In this experiment we investigate the effect of the critic architecture on downstream performance. We find that simpler critics can result in better performance, despite leading to looser MI bounds.

In [0]:
#@title Run experiment or load precomputed results { display-mode: "form" }

def run_all_experiments():
  """Sweeps NWJ/InfoNCE with concat, bilinear and separable critics.

  Returns:
    Whatever `run_sweep` returns for the "critic_impact" experiment.
  """
  tf.reset_default_graph()
  # The estimators are maximized, so each loss is the negated bound.
  loss_fcts = {
      "nwj": lambda x: -nwj_lower_bound(x),
      "nce": lambda x: -infonce_lower_bound(x),
  }

  def make_mlp():
    # Fresh encoder (and fresh kwargs dict) on every call.
    return MLP([300, 300, 100], False, {"activation": "relu"})

  nets = {"mlp": lambda: (make_mlp(), make_mlp())}
  critics = {
      "concat": lambda: ConcatCritic(),
      "bilinear": lambda: BilinearCritic(),
      "separable": lambda: SeparableCritic(layers=1),
  }
  return run_sweep(nets, critics, loss_fcts, "critic_impact",
                   n_iters=21000, n_evals=21)

if not RUN_EXPERIMENTS:
  # Precomputed sweep results, restricted to this experiment's rows.
  data_critic_impact = pd.read_pickle('mi_results.pkl').loc[
      lambda df: df.exp_name == "critic_impact"]
else:
  data_critic_impact = run_all_experiments()["df"]

In [0]:
#@title Downstream accuracy plot { display-mode: "form" }
data = data_critic_impact[
    data_critic_impact.Critic.isin(["Bilinear", "MLP", "Separable"])]
data_nwj, data_nce = (data[data.Estimator.isin([estimator])]
                      for estimator in ("$I_{NWJ}$", "$I_{NCE}$"))

# One accuracy panel per estimator, with the critic architecture as hue.
for panel_data, panel_ylabel in ((data_nwj, "Accuracy with $I_{NWJ}$"),
                                 (data_nce, "Accuracy with $I_{NCE}$")):
  plt.figure()
  ax = sns.lineplot(data=panel_data, x="iteration", y="accuracy",
                    hue="Critic", ci="sd")
  apply_default_style(ax)
  ax.set_ylim([0.8, 0.9])
  ax.set_ylabel(panel_ylabel)

In [0]:
#@title MI lower bound plot { display-mode: "form" }
# Re-derive the per-estimator subsets from `data_critic_impact` instead of
# reusing `data_nwj` / `data_nce` from the previous cell: the encoder-impact
# cell further down rebinds those names to different data, so re-running this
# cell afterwards would otherwise plot the wrong experiment.
critic_data = data_critic_impact[
    data_critic_impact.Critic.isin(["Bilinear", "MLP", "Separable"])]
critic_nwj = critic_data[critic_data.Estimator.isin(["$I_{NWJ}$"])]
critic_nce = critic_data[critic_data.Estimator.isin(["$I_{NCE}$"])]

plt.figure()
ax = sns.lineplot(data=critic_nwj, x="iteration", y="bound_value", hue="Critic", ci="sd");
apply_default_style(ax)
ax.set_ylim(2)  # Only the lower limit is fixed.
ax.set_ylabel("$I_{NWJ}$")
plt.figure()
ax = sns.lineplot(data=critic_nce, x="iteration", y="bound_value", hue="Critic", ci="sd");
apply_default_style(ax)
ax.set_ylim(2)
ax.set_ylabel("$I_{NCE}$");

## Encoder architecture can be more important than the specific estimator

Reproduces the experiment from Section 3.3 and the corresponding Figures 4 (a, b).

In this experiment we show that the choice of encoder architecture can have more impact on downstream performance than the specific estimator used.

In [0]:
#@title Run experiment or load precomputed results { display-mode: "form" }

def run_all_experiments():
  """Sweeps ConvNet and MLP encoders trained to hit a target MI estimate.

  Each loss is |estimate - t| for t in {4, 2} and estimator in {NWJ,
  InfoNCE}, keyed as "<estimator>-<t>".

  Returns:
    Whatever `run_sweep` returns for the "encoder_impact" experiment.
  """
  tf.reset_default_graph()
  estimators = {'nwj': nwj_lower_bound, 'nce': infonce_lower_bound}

  def loss_target_fn(x, fn, t):
    # Distance between the estimate and the target value t.
    return tf.abs(fn(x) - t)

  loss_fcts = {
      '{}-{}'.format(name, target): functools.partial(
          loss_target_fn, fn=estimator_fn, t=target)
      for target in [4, 2]
      for name, estimator_fn in estimators.items()
  }

  def make_mlp():
    return MLP([300, 300, 100], False, {"activation": "relu"})

  nets = {
      "convnet": lambda: (ConvNet(), ConvNet()),
      "mlp": lambda: (make_mlp(), make_mlp()),
  }
  critics = {"bilinear": lambda: BilinearCritic()}
  return run_sweep(nets, critics, loss_fcts, "encoder_impact",
                   n_iters=21000, n_evals=21)

if not RUN_EXPERIMENTS:
  # Precomputed sweep results, restricted to this experiment's rows.
  data_encoder_impact = pd.read_pickle('mi_results.pkl').loc[
      lambda df: df.exp_name == "encoder_impact"]
else:
  data_encoder_impact = run_all_experiments()["df"]

In [0]:
#@title Downstream accuracy and testing loss plots { display-mode: "form" }
data = data_encoder_impact
data = data[data.Critic == "Bilinear"].copy()

# Map raw "<net>, <critic>, <loss>-<t>" labels to display names. Assign the
# result back rather than calling `replace(..., inplace=True)` on
# `data["label"]`: that in-place update goes through a chained column access,
# which pandas Copy-on-Write does not propagate back to `data`.
data["label"] = data["label"].replace(to_replace={
    "convnet, bilinear, nwj-2": "ConvNet $(I_{NWJ}, t=2)$",
    "convnet, bilinear, nwj-4": "ConvNet $(I_{NWJ}, t=4)$",
    "mlp, bilinear, nwj-2": "MLP $(I_{NWJ}, t=2)$",
    "mlp, bilinear, nwj-4": "MLP $(I_{NWJ}, t=4)$",
    "convnet, bilinear, nce-2": "ConvNet $(I_{NCE}, t=2)$",
    "convnet, bilinear, nce-4": "ConvNet $(I_{NCE}, t=4)$",
    "mlp, bilinear, nce-2": "MLP $(I_{NCE}, t=2)$",
    "mlp, bilinear, nce-4": "MLP $(I_{NCE}, t=4)$",
  })

# We are trying to reach a given bound of t, hence it is minimized
data["bound_value"] *= -1
data_nwj = data[data.Estimator.isin(["nwj-2", "nwj-4"])]
data_nce = data[data.Estimator.isin(["nce-2", "nce-4"])]

del data # Make sure that `data` is not used by accident below.


def _legend_series_only(ax, n_series, loc="upper right"):
  """Redraws the legend of `ax`, keeping only its last `n_series` entries.

  Older seaborn versions prepend an extra entry holding the hue variable name,
  while newer ones show it as the legend title instead. Keeping the last
  `n_series` entries drops that extra entry when it exists, without
  discarding a real series when it does not.
  """
  handles, labels = ax.get_legend_handles_labels()
  ax.legend(loc=loc, handles=handles[-n_series:], labels=labels[-n_series:])


hue_ordering = np.unique(data_nwj.label.values)

plt.rcParams.update({'legend.fontsize': 13})
plt.figure()
ax = sns.lineplot(data=data_nwj, x="iteration", y="accuracy", hue="label", ci="sd", hue_order=hue_ordering)
apply_default_style(ax)
ax.set_ylim([0.78, 0.92])
ax.set_ylabel("Accuracy with $I_{NWJ}$")

hue_ordering = np.unique(data_nce.label.values)

plt.figure()
ax = sns.lineplot(data=data_nce, x="iteration", y="accuracy", hue="label", ci="sd", hue_order=hue_ordering)
apply_default_style(ax)
ax.set_ylim([0.78, 0.92])
ax.set_ylabel("Accuracy with $I_{NCE}$");
plt.rcParams.update({'legend.fontsize': FONTSIZE})

# Loss values
plt.figure()
hue_ordering = np.unique(data_nwj.label.values)

ax = sns.lineplot(data=data_nwj, x="iteration", y="bound_value", hue="label", ci="sd",
                  hue_order=hue_ordering)
apply_default_style(ax)
_legend_series_only(ax, len(hue_ordering))
ax.set_ylabel("$L_t(g_1, g_2), I_{NWJ}$")

plt.figure()
hue_ordering = np.unique(data_nce.label.values)
ax = sns.lineplot(data=data_nce, x="iteration", y="bound_value", hue="label", ci="sd",
                  hue_order=hue_ordering)
apply_default_style(ax)
_legend_series_only(ax, len(hue_ordering))
ax.set_ylabel("$L_t(g_1, g_2), I_{NCE}$");

## InfoNCE and the importance of negative sampling

Reproduces the experiment in Section 4 and the corresponding Figure 4 (c).

In this experiment we show empirically that neither $I_{NCE}$ nor $I_{NWJ}$ is, in general, a lower bound on MI when the samples are not drawn i.i.d.

The InfoNCE objective is only provably a lower bound on the true mutual information if all of the samples $(X_i, Y_i)$ are drawn i.i.d. from the joint distribution $p(x,y)$. Here we demonstrate with a simple synthetic example that when the $(X_i, Y_i)$ are drawn in a dependent fashion, InfoNCE can actually be larger than the true mutual information.

We will draw a batch $(X_i, Y_i)$ as follows.

First sample $Z \sim \mathcal{N}\left(0, \begin{bmatrix}1 & -0.5\\ -0.5 & 1\end{bmatrix}\right)$.

Then sample $\epsilon_i \sim \mathcal{N}\left(0, \begin{bmatrix}1 & 0.9\\ 0.9 & 1\end{bmatrix}\right)$ iid, and set $(X_i, Y_i) = Z + \epsilon_i$.



Then each $(X_i, Y_i)$ has marginal distribution $\mathcal{N}\left(0, \begin{bmatrix}2 & 0.4\\ 0.4 & 2\end{bmatrix}\right)$, but note that the samples within a batch are dependent.

For a bivariate Gaussian $(X,Y) \sim \mathcal{N}\left(0, \Sigma\right)$ we have $I(X,Y) = -0.5\log(1-\rho^2)$, where $\rho^2 = \Sigma_{12}\Sigma_{21} / (\Sigma_{11}\Sigma_{22})$.

In [0]:
#@title Define training loop { display-mode: "form" }

def train_bound(xy,
                estimator="nce",
                hidden_dim=10,
                layers=5,
                learning_rate=1e-4,
                n_iters=20000):
  """Trains a concat critic to maximize an MI lower bound on `xy`.

  Args:
    xy: Tensor of paired samples; `xy[:, 0, None]` is fed to the critic as X
      and `xy[:, 1, None]` as Y.
    estimator: "nce" (InfoNCE) or "nwj".
    hidden_dim: Hidden width of the `ConcatCritic`.
    layers: Number of layers of the `ConcatCritic`.
    learning_rate: Adam learning rate.
    n_iters: Number of optimization steps.

  Returns:
    List holding the bound value fetched at every training step.

  Raises:
    ValueError: If `estimator` is neither "nce" nor "nwj".
  """
  if estimator not in ("nce", "nwj"):
    raise ValueError(
        "estimator must be one of 'nce', 'nwj', not: {}".format(estimator))

  critic = ConcatCritic(hidden_dim=hidden_dim, layers=layers, activation='relu')
  x_col = xy[:, 0, None]
  y_col = xy[:, 1, None]
  scores = critic(x_col, y_col)

  bound_fn = infonce_lower_bound if estimator == "nce" else nwj_lower_bound
  bound = bound_fn(scores)

  # Maximize the bound by minimizing its negation.
  train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(-bound)

  with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    bound_estimates = []
    for step in range(n_iters):
      bound_value, _ = session.run([bound, train_op])
      bound_estimates.append(bound_value)
      if step % 1000 == 0:
        print("Step {:>10d} {} {:>.5f}".format(step, estimator, bound_value))
  return bound_estimates

def mi_from_sigma(sigma):
  """Returns I(X; Y) in nats for a bivariate Gaussian with covariance `sigma`.

  Uses I(X; Y) = -0.5 * log(1 - rho^2), where
  rho^2 = sigma_12 * sigma_21 / (sigma_11 * sigma_22).

  Args:
    sigma: 2x2 covariance matrix, as a NumPy array or a nested sequence.

  Returns:
    The mutual information (natural log, hence nats).
  """
  # Accept nested lists/tuples as well as arrays (`[0, 1]` indexing needs an
  # array).
  sigma = np.asarray(sigma)
  rho_sq = sigma[0, 1] * sigma[1, 0] / (sigma[0, 0] * sigma[1, 1])
  return -0.5 * np.log(1 - rho_sq)

def run_all_experiments():
  """Compares InfoNCE/NWJ estimates on non-i.i.d. versus i.i.d. batches.

  In a non-i.i.d. batch, all `TRAIN_BATCH_SIZE` elements share a single draw
  of Z, to which independent noise eps_i is added. An i.i.d. batch is drawn
  directly from the marginal N(0, sigma_z + sigma_eps).

  Returns:
    A `ResultsSamplingIssues` holding the true MI and the per-step estimates
    of both estimators under both sampling schemes.
  """
  # Start from a fresh graph like the other experiments. Otherwise every
  # earlier experiment's ops and variables remain in the default graph and
  # are re-initialized by each `train_bound` call.
  tf.reset_default_graph()

  bs = TRAIN_BATCH_SIZE
  sigma_z = np.array([[1.0, -0.5], [-0.5, 1.0]])
  sigma_eps = np.array([[1.0, 0.9], [0.9, 1.0]], dtype=np.float64)
  # Covariance of each (X_i, Y_i) = Z + eps_i, with Z and eps_i independent.
  sigma_xy = sigma_z + sigma_eps

  z = tfd.MultivariateNormalFullCovariance(loc=0, covariance_matrix=sigma_z)
  eps = tfd.MultivariateNormalFullCovariance(loc=0, covariance_matrix=sigma_eps)

  # One Z broadcast against `bs` noise draws: the batch elements are dependent
  # through the shared Z.
  z_sample = tf.cast(z.sample(1), tf.float32)
  eps_sample = tf.cast(eps.sample(bs), tf.float32)

  xy = z_sample + eps_sample

  mi_true = mi_from_sigma(sigma_xy)

  # Estimate the MI with both estimators on the non-i.i.d. samples.
  nce_estimates = train_bound(xy, 'nce')
  nwj_estimates = train_bound(xy, 'nwj')

  # Sanity check: the same estimators on proper i.i.d. samples.
  xy_iid = tfd.MultivariateNormalFullCovariance(
      loc=0, covariance_matrix=sigma_xy).sample(bs)
  xy_iid = tf.cast(xy_iid, tf.float32)

  nce_estimates_iid = train_bound(xy_iid, 'nce')
  nwj_estimates_iid = train_bound(xy_iid, 'nwj')

  return ResultsSamplingIssues(
      mi_true=mi_true,
      nce_estimates_noniid=nce_estimates,
      nce_estimates_iid=nce_estimates_iid,
      nwj_estimates_noniid=nwj_estimates,
      nwj_estimates_iid=nwj_estimates_iid)

In [0]:
#@title Run experiment or load precomputed results { display-mode: "form" }
# Either rerun the sampling experiment or load its precomputed results.
if RUN_EXPERIMENTS:
  data_noniid_sampling = run_all_experiments()
else:
  # Quiet download (-q); -N skips it when the local copy is already up to date.
  !wget -q -N https://storage.googleapis.com/mi_for_rl_files/noniid_results.pkl
  # encoding='latin1' lets Python 3 read pickles written by Python 2.
  # NOTE(review): pickle.load can execute arbitrary code; only load this file
  # from a trusted source such as the bucket above.
  with tf.gfile.Open('noniid_results.pkl', 'rb') as f:
    data_noniid_sampling = pickle.load(f, encoding='latin1')

In [0]:
#@title i.i.d. vs non-i.i.d. sampling plot { display-mode: "form" }
results = data_noniid_sampling

plt.rcParams.update({'axes.labelsize': FONTSIZE,
                     'xtick.labelsize': FONTSIZE,
                     'ytick.labelsize': FONTSIZE,
                     'legend.fontsize': FONTSIZE,
                     'lines.linewidth': 2})
plt.figure()
ax = plt.gca()
plt.axhline(y=results.mi_true, ls='--', color='k', label="True MI")

# Each trace is smoothed with a Gaussian kernel (sigma = 100 steps) and plotted
# against its own step indices, so traces of different lengths stay aligned.
estimate_traces = [
    (results.nce_estimates_noniid, "$I_{NCE}$, non-i.i.d. samples"),
    (results.nce_estimates_iid, "$I_{NCE}$, i.i.d. samples"),
    (results.nwj_estimates_noniid, "$I_{NWJ}$, non-i.i.d. samples"),
    (results.nwj_estimates_iid, "$I_{NWJ}$, i.i.d. samples"),
]
for estimates, label in estimate_traces:
  ax.plot(np.arange(len(estimates)), gaussian_filter1d(estimates, 100),
          label=label)

apply_default_style(ax)
ax.set_ylabel('$I_{EST}$')
ax.set_ylim(-0.3, 0.25)
plt.legend(loc="lower right", prop={'size':13});