# Information Bottleneck theory for Deep Learning


This is a demonstration of the information bottleneck theory for deep learning, developed by Naftali Tishby and colleagues. Here I try to reproduce the main results of the paper by Ravid Shwartz-Ziv and Naftali Tishby, [Opening the black box of Deep Neural Networks via Information](https://arxiv.org/pdf/1703.00810.pdf).


## Data generation


First, we generate a very simple dataset for the demonstration. Each input is a vector of 10 binary digits, and each output is a single binary label. The inputs can therefore be represented by integers from 0 to 1023 ($=2^{10}-1$). The 1024 possible inputs are divided into 16 groups of 64 integers each: an integer input $n\in[0,1023]$ belongs to group $i \in [0,15]$ if $n\equiv i \pmod{16}$. Each group is then assigned a random binary label (the output), with exactly 8 of the 16 groups labeled 0 and the other 8 labeled 1, so the two classes are balanced. In the code, the label is stored as a two-column one-hot vector.


In [2]:
!pip install equinox jax numpy optax matplotlib


In [None]:
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import optax  # https://github.com/deepmind/optax

%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

In [None]:
seed = 123456
groups = np.append(np.zeros(8),np.ones(8)) # 16 groups: 8 labeled 0, 8 labeled 1
np.random.seed(seed)
np.random.shuffle(groups)

n_train_samples = 50000
n_test_samples = 10000

def generate_samples(n_samples):
    x_data = np.zeros((n_samples, 10)) # inputs
    x_int = np.zeros(n_samples) # integers representing the inputs
    y_data = np.zeros((n_samples, 2)) # outputs, one-hot: [label, 1 - label]
    
    for i in range(n_samples):
        # Draw from NumPy's seeded global generator so the dataset is reproducible
        random_int = np.random.randint(0, 1024) # upper bound is exclusive: 0..1023
        x_data[i,:] = [int(b) for b in list("{0:b}".format(random_int).zfill(10))]
        x_int[i] = random_int
        y_data[i,0] = groups[random_int % 16]
        y_data[i,1] = 1 - y_data[i,0]
        
    return x_data, y_data, x_int

x_train, y_train, x_train_int = generate_samples(n_train_samples) # training dataset
x_test, y_test, _ = generate_samples(n_test_samples) # testing dataset

x_train = jnp.array(x_train)
y_train = jnp.array(y_train)
x_int_train = jnp.array(x_train_int)
x_test = jnp.array(x_test)
y_test = jnp.array(y_test)
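The per-sample loop above can also be written in vectorized NumPy. The sketch below is a hypothetical alternative, not the notebook's method: the helper name `generate_samples_vectorized` and the use of a `np.random.default_rng` generator are assumptions. It extracts the 10 bits with right shifts, most-significant bit first, to match the `zfill(10)` string formatting, and looks up each label in the 16-entry group table.

```python
import numpy as np

def generate_samples_vectorized(n_samples, groups, rng):
    """Hypothetical vectorized equivalent of generate_samples.

    groups: length-16 array of 0/1 labels, one per residue class mod 16
    rng:    a numpy.random.Generator
    """
    # Draw integers uniformly from 0..1023 (upper bound exclusive)
    x_int = rng.integers(0, 1024, size=n_samples)
    # Bit k of each integer, most-significant bit first (same order as zfill(10))
    shifts = np.arange(9, -1, -1)
    x_data = ((x_int[:, None] >> shifts) & 1).astype(np.float64)
    # Column 0 is the group label, column 1 its complement
    y0 = np.asarray(groups, dtype=np.float64)[x_int % 16]
    y_data = np.stack([y0, 1.0 - y0], axis=1)
    return x_data, y_data, x_int.astype(np.float64)

rng = np.random.default_rng(123456)
groups_demo = rng.permutation(np.append(np.zeros(8), np.ones(8)))
x_demo, y_demo, x_int_demo = generate_samples_vectorized(5, groups_demo, rng)
print(x_demo.shape, y_demo.shape)
```

Because it draws from a different generator, this sketch produces a different (but identically distributed) dataset than the loop above.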

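For our dataset, assuming the inputs are uniformly distributed, the theoretical mutual information between $X$ and $Y$ is
\begin{align}
I(X;Y) & = \sum_{x\in X, y\in Y}P(x,y)\log\Big(\frac{P(x,y)}{P(x)P(y)}\Big) \\
& = \sum_{x\in X}\Big[P(x,y=0)\log\Big(\frac{P(x,y=0)}{P(x)P(y=0)}\Big) + P(x,y=1)\log\Big(\frac{P(x,y=1)}{P(x)P(y=1)}\Big)\Big] \\
& = 1024 \Big[ \frac{1}{1024}\log\Big(\frac{1/1024}{0.5/1024}\Big) + 0\Big] \\
& = \log 2 \approx 0.693.
\end{align}
Here $\log$ is the natural logarithm, so mutual information is measured in nats. Terms with $P(x,y)=0$ are set to $0$ (the usual convention $0\log 0 = 0$); for each $x$ exactly one of the two terms survives, with $P(x,y)=P(x)=1/1024$ and $P(y)=0.5$ because exactly half of the groups are labeled 1. The network is built and trained with JAX, while the mutual-information estimates below are computed with NumPy.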
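As a numerical check of this derivation, the sketch below builds the exact joint distribution $P(x,y)$ for uniformly distributed inputs and evaluates the sum directly with natural logarithms, skipping zero-probability terms. The helper `exact_mi_xy` and the 16-entry table `groups_demo` are hypothetical stand-ins for the notebook's `groups` array (eight 0s and eight 1s).

```python
import numpy as np

def exact_mi_xy(groups):
    """I(X;Y) in nats for X uniform on {0, ..., 1023} and Y = groups[X % 16]."""
    groups = np.asarray(groups, dtype=int)
    x = np.arange(1024)
    y = groups[x % 16]
    # Joint distribution P(x, y): 1/1024 on the observed label, 0 on the other one
    p_xy = np.zeros((1024, 2))
    p_xy[x, y] = 1.0 / 1024
    p_x = p_xy.sum(axis=1, keepdims=True)  # shape (1024, 1)
    p_y = p_xy.sum(axis=0, keepdims=True)  # shape (1, 2)
    # Convention 0 log 0 = 0: only sum over entries with P(x, y) > 0
    mask = p_xy > 0
    ratio = p_xy[mask] / (p_x * p_y)[mask]
    return float(np.sum(p_xy[mask] * np.log(ratio)))

groups_demo = np.append(np.zeros(8), np.ones(8))
print(exact_mi_xy(groups_demo))  # ≈ 0.6931, i.e. ln 2
```

Because $Y$ is a deterministic function of $X$, $I(X;Y) = H(Y)$, so an unbalanced split of the 16 groups (say 4 labeled 1 and 12 labeled 0) would give less than $\log 2$.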


## Neural network


It's time to build our network using JAX and Equinox!


In [36]:
# Define the MLP model using Equinox
class MLP(eqx.Module):
    layers: list

    def __init__(self, input_size, hidden_layer_neurons, key):
        keys = jax.random.split(key, len(hidden_layer_neurons) + 1)
        self.layers = [eqx.nn.Linear(input_size, hidden_layer_neurons[0], key=keys[0])]
        self.layers += [eqx.nn.Linear(hidden_layer_neurons[i], hidden_layer_neurons[i + 1], key=keys[i + 1]) for i in range(len(hidden_layer_neurons) - 1)]
        self.layers = [*self.layers, eqx.nn.Linear(hidden_layer_neurons[-1], 2, key=keys[-1])]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jax.nn.tanh(layer(x))  # Use tanh as in original
        return self.layers[-1](x)

@eqx.filter_jit
def loss_fn(model, x, y):
    logits = jax.vmap(model)(x)
    return optax.sigmoid_binary_cross_entropy(logits, y).mean()

loss_and_grad_fn = eqx.filter_jit(eqx.filter_value_and_grad(loss_fn))

@eqx.filter_jit
def compute_accuracy(model, x, y):
    logits = jax.vmap(model)(x)
    preds = jnp.argmax(logits, axis=1)
    targets = jnp.argmax(y, axis=1)
    return jnp.mean(preds == targets)

def print_out_summary(model, x_test, y_test, epoch):
    acc = compute_accuracy(model, x_test, y_test)
    loss = loss_fn(model, x_test, y_test)
    print(f"Epoch {epoch:>4}:  Testing accuracy {float(acc):.4f} - Testing loss {float(loss):.4f}")

@eqx.filter_jit
def train_step(model, optimizer, x_batch, y_batch, opt_state):
    loss, grads = loss_and_grad_fn(model, x_batch, y_batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss
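`loss_fn` applies `optax.sigmoid_binary_cross_entropy` to the two logits, so each output unit is scored as an independent binary classifier against its one-hot target, and the result is averaged over units and samples. The NumPy sketch below writes out that element-wise loss, $-y\log\sigma(z) - (1-y)\log(1-\sigma(z))$, in the numerically stable form $\max(z,0) - zy + \log(1+e^{-|z|})$. It illustrates the formula only; it is not optax's source code, and the helper name `sigmoid_bce` is hypothetical.

```python
import numpy as np

def sigmoid_bce(logits, labels):
    """Element-wise sigmoid binary cross-entropy (hypothetical NumPy helper)."""
    z = np.asarray(logits, dtype=float)
    y = np.asarray(labels, dtype=float)
    # Stable rewrite of -y*log(sigmoid(z)) - (1-y)*log(1-sigmoid(z))
    return np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))

# Two logits per sample and one-hot targets, shaped like the model output and y_train
logits = np.array([[2.0, -1.0], [0.0, 0.0]])
targets = np.array([[1.0, 0.0], [0.0, 1.0]])
print(sigmoid_bce(logits, targets).mean())
```

With all-zero logits every element costs $\log 2 \approx 0.693$, which is close to the testing loss reported at epoch 0 in the runs below.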

Let's first train the network without calculating the mutual information, to check that it can learn the task. This quick sanity check uses minibatch SGD with JAX and Optax.


In [27]:
import time

n_epochs = 300
batch_size = 250  # 50,000 / 250 = 200 equal-sized minibatches per epoch

# Model and optimizer for this check; settings mirror train_with_mi below
model = MLP(10, [8, 6, 4], jrandom.PRNGKey(0))
optimizer = optax.sgd(0.1)
opt_state = optimizer.init(model)

now = time.perf_counter()
for epoch in range(n_epochs):
    # Shuffle the training data at the start of each epoch
    perm = jrandom.permutation(jrandom.PRNGKey(epoch), x_train.shape[0])
    x_train_shuffled = x_train[perm]
    y_train_shuffled = y_train[perm]

    for i in range(0, x_train.shape[0], batch_size):
        x_batch = x_train_shuffled[i:i+batch_size] # shape (batch_size, 10)
        y_batch = y_train_shuffled[i:i+batch_size] # shape (batch_size, 2)
        model, opt_state, loss = train_step(model, optimizer, x_batch, y_batch, opt_state)

    if epoch % 10 == 0:
        print_out_summary(model, x_test, y_test, epoch)
print("Training time:", time.perf_counter() - now)

## Mutual information


Now we are ready to explore the information bottleneck theory for our network. To estimate the mutual information between each hidden layer and the input/output layers, we bin the hidden-layer activations as described in the paper (here we use 30 equal-width bins over $[-1, 1]$, the range of the tanh activations, the same number of bins as in the paper), so that each hidden-layer random variable $T_i$ (one per hidden layer $i$) becomes discrete. We can then estimate the joint distributions $P(X,T_i)$ and $P(T_i,Y)$ and use them to calculate the encoder mutual information (between the input $X$ and hidden layer $T_i$)
\begin{equation}
I(X;T_i) = \sum_{x\in X, t\in T_i}P(x,t)\log\Big(\frac{P(x,t)}{P(x)P(t)}\Big)
\end{equation}

and the decoder mutual information (between hidden layer $T_i$ and the desired output $Y$; note that $Y$ is the true label, not the model output $\widehat{Y}$)
\begin{equation}
I(T_i;Y) = \sum_{t\in T_i, y\in Y}P(t,y)\log\Big(\frac{P(t,y)}{P(t)P(y)}\Big).
\end{equation}
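Two consequences of these definitions are useful when reading the information plane. At any given epoch the network weights are fixed, so after binning each $T_i$ is a deterministic function of $X$ and $H(T_i \mid X) = 0$. Moreover, $Y \to X \to T_i$ forms a Markov chain, so the data processing inequality applies. With natural logarithms this gives
\begin{align}
I(X;T_i) & = H(T_i) - H(T_i \mid X) = H(T_i) \le H(X) \le \log 1024 \approx 6.93, \\
I(T_i;Y) & \le I(X;Y) \le H(Y) \le \log 2 \approx 0.693.
\end{align}
In other words, the binned estimate of $I(X;T_i)$ is simply the entropy of the binned activation patterns, and neither quantity can exceed the upper axis limits (7 and 0.7) used in the information-plane plots.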


In [37]:
from collections import Counter

def calc_mutual_information(hidden):
    # discretization 
    n_bins = 30
    bins = jnp.linspace(-1, 1, n_bins+1)
    indices = jnp.digitize(hidden, bins)

    # Convert JAX arrays to NumPy arrays before looping
    indices_np = np.array(indices)
    x_int_train_np = np.array(x_int_train)
    y_train_np = np.array(y_train)

    # initialize pdfs
    pdf_x, pdf_y, pdf_t = Counter(), Counter(), Counter()
    pdf_xt, pdf_yt = Counter(), Counter()

    density = 1/float(n_train_samples)
    for i in range(n_train_samples):
        # Use numpy arrays to create hashable tuple keys
        key_tuple = tuple(indices_np[i,:])
        pdf_x[int(x_int_train_np[i])] += density
        pdf_y[int(y_train_np[i,0])] += density
        pdf_xt[(int(x_int_train_np[i]),) + key_tuple] += density
        pdf_yt[(int(y_train_np[i,0]),) + key_tuple] += density
        pdf_t[key_tuple] += density

    # calculate encoder mutual information I(X;T)
    mi_xt = 0
    for i in pdf_xt:
        # P(x,t), P(x) and P(t)
        p_xt = pdf_xt[i]; p_x = pdf_x[i[0]]; p_t = pdf_t[i[1:]]
        # I(X;T)
        mi_xt += p_xt * np.log(p_xt / p_x / p_t)

    # calculate decoder mutual information I(T;Y)
    mi_ty = 0
    for i in pdf_yt:
        # P(t,y), P(t) and P(y)
        p_yt = pdf_yt[i]; p_t = pdf_t[i[1:]]; p_y = pdf_y[i[0]]
        # I(T;Y)
        mi_ty += p_yt * np.log(p_yt / p_t / p_y)

    return mi_xt, mi_ty

# get mutual information for all hidden layers
def get_mutual_information(hiddens):
    mi_xt_list = []; mi_ty_list = []
    for hidden in hiddens:
        mi_xt, mi_ty = calc_mutual_information(hidden)
        mi_xt_list.append(mi_xt)
        mi_ty_list.append(mi_ty)
    return mi_xt_list, mi_ty_list

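We can now estimate the mutual information while training the network. Note that `train_with_mi` performs full-batch gradient descent: each epoch is a single SGD step (learning rate 0.1) on all 50,000 training samples. The mutual information of every hidden layer is recorded every 20 epochs and saved for later use.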


In [38]:
def get_hidden_layers(model, x):
    activations = []
    h = x
    for lyr in model.layers[:-1]:
        # Apply the layer to each sample in the batch
        h = jax.vmap(lyr)(h)
        h = jax.nn.tanh(h)
        activations.append(h)
    return activations

# train the neural network and obtain mutual information
def train_with_mi(n_epochs, model, x_train, y_train):
    mi_xt_all, mi_ty_all, epochs = [], [], []
    learning_rate = 0.1
    optimizer = optax.sgd(learning_rate)
    opt_state = optimizer.init(model)

    for epoch in range(n_epochs):
        # Full-batch gradient descent: one update on the whole training set per epoch
        model, opt_state, loss = train_step(model, optimizer, x_train, y_train, opt_state)
        if epoch % 200 == 0:
            print_out_summary(model, x_test, y_test, epoch)
        if epoch % 20 == 0:
            # Record I(X;T) and I(T;Y) for every hidden layer at this checkpoint
            hiddens = get_hidden_layers(model, x_train)
            mi_xt, mi_ty = get_mutual_information(hiddens)
            mi_xt_all.append(mi_xt)
            mi_ty_all.append(mi_ty)
            epochs.append(epoch)
    return np.array(mi_xt_all), np.array(mi_ty_all), np.array(epochs)

n_epochs = 3000
hidden_layers = [8, 6, 4]
key = jax.random.PRNGKey(12345)
model = MLP(10, hidden_layers, key)
mi_xt_all, mi_ty_all, epochs = train_with_mi(n_epochs, model, x_train, y_train)

Epoch    0:  Testing accuracy 0.4974 - Testing loss 0.7299
Epoch  200:  Testing accuracy 0.5830 - Testing loss 0.6904
Epoch  400:  Testing accuracy 0.6313 - Testing loss 0.6777
Epoch  600:  Testing accuracy 0.7402 - Testing loss 0.6183
Epoch  800:  Testing accuracy 0.7546 - Testing loss 0.5660
Epoch 1000:  Testing accuracy 0.8205 - Testing loss 0.5092
Epoch 1200:  Testing accuracy 0.8205 - Testing loss 0.4794
Epoch 1400:  Testing accuracy 0.8205 - Testing loss 0.4651
Epoch 1600:  Testing accuracy 0.8205 - Testing loss 0.4527
Epoch 1800:  Testing accuracy 0.8205 - Testing loss 0.4286
Epoch 2000:  Testing accuracy 0.8205 - Testing loss 0.3589
Epoch 2200:  Testing accuracy 0.9374 - Testing loss 0.2601
Epoch 2400:  Testing accuracy 0.9374 - Testing loss 0.1877
Epoch 2600:  Testing accuracy 0.9865 - Testing loss 0.1173
Epoch 2800:  Testing accuracy 1.0000 - Testing loss 0.0801


## Visualization


Below is a movie showing how the hidden layers evolve in the information plane over the training epochs, with $I_X = I(X;T)$ on the horizontal axis and $I_Y = I(T;Y)$ on the vertical axis. Two distinct optimization phases are visible, as discussed in the paper. During the first, _empirical error minimization_ (ERM) phase (until around epoch 1500), the information about the output, $I_Y$, increases quickly; during the second, _representation compression_ phase (from around epoch 1500 onwards), the information about the input, $I_X$, decreases. The evolution is not as smooth as in the paper because it comes from a single network rather than an average over many networks.


In [39]:
fig, ax = plt.subplots(figsize=(8,8))
ax.set_xlim((3,7))
ax.set_ylim((0.1,0.7))
ax.set_xlabel('I(X;T)')
ax.set_ylabel('I(T;Y)')
title = ax.set_title('')
plt.close(fig)
cmap = plt.get_cmap('cool')

def animate(i):
    title.set_text('Epoch %s' % str(epochs[i]).zfill(4))
    ax.plot(mi_xt_all[i,:], mi_ty_all[i,:], 'k-',alpha=0.2)
    if i > 0:
        for j in range(len(hidden_layers)):
            ax.plot(mi_xt_all[(i-1):(i+1),j],mi_ty_all[(i-1):(i+1),j],'.-',c=cmap(j*.2),ms=10)
    return

anim = animation.FuncAnimation(
    fig,
    animate,
    init_func=None,
    frames=len(epochs),
    interval=100
)
HTML(anim.to_html5_video())



**Note**: Github cannot render the movies embedded in the notebook. You could view the movies from this notebook through nbviewer: https://nbviewer.jupyter.org/github/stevenliuyi/information-bottleneck/blob/master/information_bottleneck.ipynb
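One rough way to put a number on the phase boundary described above is to locate, for each layer, the checkpoint at which $I(X;T)$ peaks: before the peak the layer is still gaining information about the input, and after it the layer is compressing. The helper below is a hypothetical heuristic, not a method from the paper; on a single noisy run the argmax can be sensitive to fluctuations, so an optional moving average is included. It expects arrays shaped like `mi_xt_all` (checkpoints × layers) and `epochs`.

```python
import numpy as np

def compression_onset(mi_xt_all, epochs, window=1):
    """Epoch at which I(X;T) peaks for each layer (hypothetical heuristic).

    mi_xt_all: (n_checkpoints, n_layers) array of I(X;T) estimates
    epochs:    (n_checkpoints,) epochs matching the rows
    window:    odd moving-average width, smaller than n_checkpoints
    """
    mi = np.asarray(mi_xt_all, dtype=float)
    if window > 1:
        kernel = np.ones(window) / window
        # Smooth each layer's curve; mode="same" keeps one value per checkpoint.
        # Zero padding lowers the edges, so it cannot create a spurious peak there
        # for non-negative mutual-information curves.
        mi = np.column_stack([np.convolve(mi[:, j], kernel, mode="same")
                              for j in range(mi.shape[1])])
    peak_rows = np.argmax(mi, axis=0)
    return np.asarray(epochs)[peak_rows]
```

With the notebook's arrays this would be called as `compression_onset(mi_xt_all, epochs, window=5)`; the result is only a rough marker of where compression starts for each layer.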


## More hidden layers


Finally, we use 5 hidden layers instead of 3, with 8, 7, 6, 5 and 3 neurons, respectively. The two optimization phases are still clearly seen in the information plane.


In [42]:
n_epochs = 3000
key = jax.random.PRNGKey(12341)
hidden_layers = [8,7,6,5,3]
model = MLP(10, hidden_layers, key)
mi_xt_all, mi_ty_all, epochs = train_with_mi(n_epochs, model, x_train, y_train)

Epoch    0:  Testing accuracy 0.4974 - Testing loss 0.6942
Epoch  200:  Testing accuracy 0.5363 - Testing loss 0.6925
Epoch  400:  Testing accuracy 0.5725 - Testing loss 0.6917
Epoch  600:  Testing accuracy 0.6197 - Testing loss 0.6889
Epoch  800:  Testing accuracy 0.6325 - Testing loss 0.6664
Epoch 1000:  Testing accuracy 0.8205 - Testing loss 0.5269
Epoch 1200:  Testing accuracy 0.8205 - Testing loss 0.4756
Epoch 1400:  Testing accuracy 0.8205 - Testing loss 0.4708
Epoch 1600:  Testing accuracy 0.8205 - Testing loss 0.4689
Epoch 1800:  Testing accuracy 0.8205 - Testing loss 0.4677
Epoch 2000:  Testing accuracy 0.8205 - Testing loss 0.4664
Epoch 2200:  Testing accuracy 0.8205 - Testing loss 0.4648
Epoch 2400:  Testing accuracy 0.8205 - Testing loss 0.4619
Epoch 2600:  Testing accuracy 0.8205 - Testing loss 0.4551
Epoch 2800:  Testing accuracy 0.8205 - Testing loss 0.4354


In [None]:
fig, ax = plt.subplots(figsize=(8,8))
ax.set_xlim((0,7))
ax.set_ylim((0.0,0.7))
ax.set_xlabel('I(X;T)')
ax.set_ylabel('I(T;Y)')
title = ax.set_title('')
plt.close(fig)

anim = animation.FuncAnimation(
    fig,
    animate,
    init_func=None,
    frames=len(epochs),
    interval=100
)
HTML(anim.to_html5_video())

