# Mutual Information Neural Estimation

$\def\abs#1{\left\lvert #1 \right\rvert}
\def\Set#1{\left\{ #1 \right\}}
\def\mc#1{\mathcal{#1}}
\def\M#1{\boldsymbol{#1}}
\def\R#1{\mathsf{#1}}
\def\RM#1{\boldsymbol{\mathsf{#1}}}
\def\op#1{\operatorname{#1}}
\def\E{\op{E}}
\def\d{\mathrm{\mathstrut d}}$

In [None]:
from dv import *
import matplotlib.pyplot as plt
from IPython import display
import ipywidgets as widgets
import pandas as pd
import tensorboard as tb

%load_ext tensorboard
%matplotlib inline



**How to estimate MI via KL divergence?**

One way is to obtain MI {eq}`MI` from KL divergence {eq}`D` as follows:

$$
\begin{align*}
I(\R{X}\wedge \R{Y}) = D(\underbrace{P_{\R{X},\R{Y}}}_{P_{\R{Z}}}\| \underbrace{P_{\R{X}}\times P_{\R{Y}}}_{P_{\R{Z}'}}).
\end{align*}
$$
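As a small worked example of this identity, the sketch below computes the KL divergence between a joint pmf and the product of its marginals. The 2-by-2 pmf `p_xy` is a made-up illustration and is not data from this notebook.

```python
import numpy as np

# Hypothetical joint pmf of two binary variables (rows: x, columns: y).
p_xy = np.array([[0.4, 0.1],
                 [0.1, 0.4]])

p_x = p_xy.sum(axis=1, keepdims=True)  # marginal of X, shape (2, 1)
p_y = p_xy.sum(axis=0, keepdims=True)  # marginal of Y, shape (1, 2)
p_ref = p_x * p_y                      # product of the marginals

# KL divergence D(P_XY || P_X x P_Y) in nats, which equals the mutual information.
mi = float(np.sum(p_xy * np.log(p_xy / p_ref)))
print(f"I(X;Y) = {mi:.4f} nats")
```

All entries of `p_xy` are positive here, so no special handling of zero probabilities is needed.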

Since both $P_{\R{Z}}$ and $P_{\R{Z}'}$ are unknown, we can apply {eq}`avg-DV` to estimate the divergence.

$$
\begin{align}
I(\R{X}\wedge \R{Y}) \approx \sup_{t: \mc{Z} \to \mathbb{R}} \frac1n \sum_{i\in [n]} t(\R{X}_i,\R{Y}_i) - \log \frac1{n'}\sum_{i\in [n']} e^{t(\R{X}'_i,\R{Y}'_i)}
\end{align}
$$ (MINE)

where $P_{\R{X}',\R{Y}'}:=P_{\R{X}}\times P_{\R{Y}}$.
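To see what such an estimate looks like numerically, here is a minimal sketch that evaluates the Donsker-Varadhan objective (sample mean of $t$ on joint samples minus the log of the sample mean of $e^t$ on reference samples) for one fixed critic. It assumes a standard bivariate Gaussian with correlation `rho`, for which the log density ratio is known in closed form, so no optimization over $t$ is needed. The names `rng_demo`, `n_demo`, `t_opt`, and `dv_estimate` are illustrative and are not part of `dv.py`.

```python
import numpy as np

rng_demo = np.random.default_rng(0)  # separate generator; seed chosen arbitrarily
rho, n_demo = 0.5, 20000

# Joint samples from a standard bivariate Gaussian with correlation rho.
x = rng_demo.standard_normal(n_demo)
y = rho * x + np.sqrt(1 - rho**2) * rng_demo.standard_normal(n_demo)
# Reference samples drawn directly from the product of the marginals.
x_ref = rng_demo.standard_normal(n_demo)
y_ref = rng_demo.standard_normal(n_demo)


def t_opt(x, y):
    """Log density ratio of the joint to the product of marginals (Gaussian case)."""
    return (-0.5 * np.log(1 - rho**2)
            + (2 * rho * x * y - rho**2 * (x**2 + y**2)) / (2 * (1 - rho**2)))


def dv_estimate(t_joint, t_ref):
    """Mean of t on joint samples minus log of the mean of exp(t) on reference samples."""
    return t_joint.mean() - np.log(np.mean(np.exp(t_ref)))


mi_true = -0.5 * np.log(1 - rho**2)
mi_dv = dv_estimate(t_opt(x, y), t_opt(x_ref, y_ref))
print(f"true MI = {mi_true:.4f}, DV estimate = {mi_dv:.4f}")
```

In practice the density ratio is unknown, which is why the supremum over $t$ has to be approximated by training.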

**But how to obtain the reference samples ${\R{Z}'}^{n'}$, i.e., ${\R{X}'}^{n'}$ and ${\R{Y}'}^{n'}$?**

## Resampling trick

We can approximate the i.i.d. sampling of $P_{\R{X}}\times P_{\R{Y}}$ using samples from $P_{\R{X},\R{Y}}$ by a re-sampling trick:

$$
\begin{align}
P_{\R{Z}'^{n'}} &\approx P_{((\R{X}_{\R{J}_i},\R{Y}_{\R{K}_i})\mid i \in [n'])}
\end{align}
$$ (resample)

where $\R{J}_i$ and $\R{K}_i$ for $i\in [n']$ are independent and uniformly random indices

$$
P_{\R{J},\R{K}} = \op{Uniform}_{[n]\times [n]}
$$

and $[n]:=\Set{1,\dots,n}$.
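The index-based description above can be sketched directly with NumPy. This is an illustration on synthetic correlated Gaussian data rather than the notebook's `XY`, and it uses 0-based indices, so $[n]$ becomes `0, ..., n_demo - 1`. The names `rng_demo`, `n_demo`, and `n_ref` are assumptions made for this example.

```python
import numpy as np

rng_demo = np.random.default_rng(0)
n_demo, n_ref = 5000, 5000

# Synthetic joint samples with correlation 0.9.
x = rng_demo.standard_normal(n_demo)
y = 0.9 * x + np.sqrt(1 - 0.9**2) * rng_demo.standard_normal(n_demo)

# Independent uniform indices J_i and K_i, drawn with replacement.
J = rng_demo.integers(0, n_demo, size=n_ref)
K = rng_demo.integers(0, n_demo, size=n_ref)
x_ref, y_ref = x[J], y[K]

corr_joint = np.corrcoef(x, y)[0, 1]
corr_ref = np.corrcoef(x_ref, y_ref)[0, 1]
print(f"correlation: joint = {corr_joint:.3f}, resampled = {corr_ref:.3f}")
```

Pairing `x[J]` with `y[K]` breaks the dependence between the two coordinates, so the resampled correlation should be close to zero while the marginals stay roughly the same.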

MINE {cite}`belghazi2018mine` uses the following implementation, which samples the indices $\R{J}_i$ and $\R{K}_i$ without replacement. You can change $n'$ using the slider for `n_`.

In [18]:
rng = np.random.default_rng(SEED)


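def resample(data, size, replace=False):
    """Return `size` rows of `data`, picked by uniformly random row indices."""
    # Without replacement (the default), `size` cannot exceed the number of rows.
    index = rng.choice(range(data.shape[0]), size=size, replace=replace)
    return data[index]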


@widgets.interact
def plot_resampled_data_without_replacement(n_=(2, n)):
    XY_ = np.block([resample(XY[:, [0]], n_), resample(XY[:, [1]], n_)])
    resampled_data = pd.DataFrame(XY_, columns=["X'", "Y'"])
    p_ = plot_samples_with_kde(resampled_data)
    plt.show()


In the above, we defined the function `resample` that
- uses `choice` to select `size` indices uniformly at random
- from a `range` of integers going from `0` up to (but not including)
- the size of the first dimension of `data`.

Note however that the sampling is without replacement.
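One practical consequence of sampling without replacement can be checked directly: the generator's `choice` method raises a `ValueError` when asked for more indices than there are rows. The sketch below uses a toy array `data_demo` and a separate generator `rng_demo`, both of which are illustrative names.

```python
import numpy as np

rng_demo = np.random.default_rng(0)
data_demo = np.arange(5)

# Asking for more indices than there are rows fails without replacement...
try:
    rng_demo.choice(data_demo.shape[0], size=10, replace=False)
    oversample_ok = True
except ValueError:
    oversample_ok = False

# ...but succeeds with replacement, at the cost of repeated rows.
index = rng_demo.choice(data_demo.shape[0], size=10, replace=True)
has_repeats = np.unique(index).size < index.size
print(oversample_ok, has_repeats)
```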

**Is it a good idea to sample without replacement?**

**Exercise** To allow $n'>n$, we need to sample the indices with replacement. Complete the following code and observe what happens when $n' \gg n$.

In [19]:
@widgets.interact
def plot_resampled_data_with_replacement(
    n_=widgets.IntSlider(n, 2, 10 * n, continuous_update=False)
):
    ### BEGIN SOLUTION
    XY_ = np.block(
        [resample(XY[:, [0]], n_, replace=True), resample(XY[:, [1]], n_, replace=True)]
    )
    ### END SOLUTION
    resampled_data = pd.DataFrame(XY_, columns=["X'", "Y'"])
    p_ = plot_samples_with_kde(resampled_data)
    plt.show()


**Exercise** 

Explain whether the resampling trick gives i.i.d. samples $(\R{X}_{\R{J}_i},\R{Y}_{\R{K}_i})$ when the indices are sampled with replacement and without replacement, respectively.

**Solution** 

The samples are identically distributed. However, they are not independent except in the trivial case $n=1$ or $n'=1$, regardless of whether the sampling is with replacement or not. Consider $n=1$ and $n'=2$ as an example.
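The dependence can also be observed numerically. The following sketch, which is an illustration and not part of the original solution, takes $n=2$ and draws two indices with replacement for each of many independent data sets. The two indices coincide with probability $1/2$, so the two resampled values should have correlation close to $1/2$ instead of $0$. The names `rng_demo`, `n_demo`, and `trials` are assumptions of this example.

```python
import numpy as np

rng_demo = np.random.default_rng(0)
n_demo, trials = 2, 200_000

# Each row is an independent data set X_1, ..., X_n of i.i.d. standard normals.
X = rng_demo.standard_normal((trials, n_demo))
# For each data set, draw two indices J_1, J_2 uniformly with replacement.
J = rng_demo.integers(0, n_demo, size=(trials, 2))
X_res = np.take_along_axis(X, J, axis=1)  # columns hold X_{J_1} and X_{J_2}

# The two resampled values coincide whenever J_1 == J_2, which correlates them.
corr = np.corrcoef(X_res[:, 0], X_res[:, 1])[0, 1]
print(f"corr(X_J1, X_J2) = {corr:.3f}")
```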

## Neural Estimation

We have encapsulated the neural estimation of KL divergence in the previous notebook by defining a class `DVTrainer` in `dv.py`:

In [4]:
DVTrainer??

Init signature: DVTrainer(Z, Z_ref, net, n_iters_per_epoch, writer_params=None, **kwargs)
Docstring:      <no docstring>
Source:
class DVTrainer:
    def __init__(self, Z, Z_ref, net, n_iters_per_epoch, writer_params=None, **kwargs):
        # set optimizer
        self.optimizer = optim.Adam(self.net.parameters(

In [12]:
dir(rng)


In [2]:
trainer = DVTrainer(Z, Z_ref, net, n_iters_per_epoch=10)

In [9]:
trainer.step(1000)

-17986.154296875
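To make the idea of training a critic concrete without relying on `dv.py`, here is a self-contained sketch under simplifying assumptions. It is not the `DVTrainer` implementation: the neural network is replaced by a hypothetical two-parameter quadratic critic, and Adam is replaced by plain gradient ascent with a hand-derived gradient. The data are synthetic Gaussian samples with correlation `rho`, and the reference samples come from index resampling with replacement.

```python
import numpy as np

rng_demo = np.random.default_rng(0)
rho, n_demo = 0.5, 5000
x = rng_demo.standard_normal(n_demo)
y = rho * x + np.sqrt(1 - rho**2) * rng_demo.standard_normal(n_demo)
# Reference samples via independent index resampling with replacement.
x_ref = x[rng_demo.integers(0, n_demo, size=n_demo)]
y_ref = y[rng_demo.integers(0, n_demo, size=n_demo)]


def features(x, y):
    """Features of the hypothetical critic t(x, y) = theta[0]*x*y + theta[1]*(x**2 + y**2)."""
    return np.stack([x * y, x**2 + y**2], axis=1)


F, F_ref = features(x, y), features(x_ref, y_ref)
theta = np.zeros(2)
for _ in range(500):
    t_ref = F_ref @ theta
    w = np.exp(t_ref - t_ref.max())
    w /= w.sum()  # softmax weights: gradient of the log-mean-exp term
    theta += 0.1 * (F.mean(axis=0) - w @ F_ref)  # ascent step on the DV objective

# Evaluate the DV objective at the trained parameters (stable log-mean-exp).
t_ref = F_ref @ theta
log_mean_exp = t_ref.max() + np.log(np.mean(np.exp(t_ref - t_ref.max())))
mi_hat = (F @ theta).mean() - log_mean_exp
mi_true = -0.5 * np.log(1 - rho**2)
print(f"estimate = {mi_hat:.4f}, true MI = {mi_true:.4f}")
```

The quadratic family happens to contain the true log density ratio for Gaussian data, which is why such a small model suffices here. For general data a more flexible critic, such as a neural network, is needed.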

In [10]:
%tensorboard --logdir=runs
