# **Concept of information bottleneck in neural networks by systematically varying the network's capacity**

In [1]:
# Information Bottleneck Experiment: Varying Network Capacity
# ==============================================================
# This script trains simple MLP classifiers on MNIST with varying
# hidden-layer sizes (network capacity) and estimates mutual
# information I(X;T) and I(T;Y) for the bottleneck layer via
# discretization (histogram binning).

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from sklearn.metrics import mutual_info_score
import os

# -------------------------
# Helper: Estimate Mutual Info
# -------------------------

def estimate_mutual_info(X, Y, n_bins=30):
    '''
    Estimate mutual information between X and Y via histogram binning.
    X, Y: 1D arrays of same length
    '''
    # Digitize
    bins = np.linspace(np.min(X), np.max(X), n_bins + 1)
    X_binned = np.digitize(X, bins) - 1
    Y_binned = np.digitize(Y, bins) - 1 if Y.dtype.kind in 'f' else Y

    # Compute joint histogram
    return mutual_info_score(X_binned, Y_binned)

# -------------------------
# MLP with Bottleneck Layer
# -------------------------
class BottleneckMLP(nn.Module):
    def __init__(self, input_dim, bottleneck_dim, hidden_dim=256, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.bottleneck = nn.Linear(hidden_dim, bottleneck_dim)
        self.fc_out = nn.Linear(bottleneck_dim, num_classes)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        h = self.relu(self.fc1(x))
        t = self.relu(self.bottleneck(h))  # bottleneck activations
        out = self.fc_out(t)
        return out, t

# -------------------------
# Training & MI logging
# -------------------------

def run_experiment(bottleneck_size, device, epochs=10, batch_size=256):
    # Data
    transform = transforms.ToTensor()
    train_ds = datasets.MNIST(root='data', train=True, download=True, transform=transform)
    test_ds = datasets.MNIST(root='data', train=False, download=True, transform=transform)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    # Model, opt, loss
    model = BottleneckMLP(28*28, bottleneck_size).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    # Containers for MI estimates
    mi_x_t = []  # I(X;T)
    mi_t_y = []  # I(T;Y)

    for epoch in range(1, epochs+1):
        model.train()
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            logits, t = model(imgs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

        # After epoch: estimate MI on a subset
        model.eval()
        all_t = []
        all_x = []  # flattened inputs
        all_y = []
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                logits, t = model(imgs)
                all_t.append(t.cpu().numpy())
                all_x.append(imgs.view(imgs.size(0), -1).cpu().numpy())
                all_y.append(labels.cpu().numpy())
        all_t = np.concatenate(all_t, axis=0)
        all_x = np.concatenate(all_x, axis=0)
        all_y = np.concatenate(all_y, axis=0)

        # Flatten for MI per neuron then average
        i_x_t = np.mean([estimate_mutual_info(all_x[:, i], all_t[:, i])
                           for i in range(min(all_t.shape[1], 50))])
        i_t_y = np.mean([estimate_mutual_info(all_t[:, i], all_y)
                           for i in range(min(all_t.shape[1], 50))])
        mi_x_t.append(i_x_t)
        mi_t_y.append(i_t_y)
        print(f"Bottleneck {bottleneck_size} | Epoch {epoch}: I(X;T)={i_x_t:.4f}, I(T;Y)={i_t_y:.4f}")

    return mi_x_t, mi_t_y

# -------------------------
# Main: Sweep over capacities
# -------------------------
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    bottleneck_sizes = [5, 10, 20, 50, 100]
    results = {}

    for size in bottleneck_sizes:
        print(f"Running experiment for bottleneck size = {size}")
        mi_x_t, mi_t_y = run_experiment(size, device)
        results[size] = {'I(X;T)': mi_x_t, 'I(T;Y)': mi_t_y}

    # Save results for plotting
    os.makedirs('results', exist_ok=True)
    torch.save(results, 'results/ib_results.pth')
    print("All experiments completed. Results saved to results/ib_results.pth")


Running experiment for bottleneck size = 5


100%|██████████| 9.91M/9.91M [00:00<00:00, 42.5MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.19MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 10.8MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.35MB/s]


Bottleneck 5 | Epoch 1: I(X;T)=0.0000, I(T;Y)=0.5872
Bottleneck 5 | Epoch 2: I(X;T)=0.0000, I(T;Y)=0.6147
Bottleneck 5 | Epoch 3: I(X;T)=0.0000, I(T;Y)=0.6157
Bottleneck 5 | Epoch 4: I(X;T)=0.0000, I(T;Y)=0.6115
Bottleneck 5 | Epoch 5: I(X;T)=0.0000, I(T;Y)=0.6069
Bottleneck 5 | Epoch 6: I(X;T)=0.0000, I(T;Y)=0.6130
Bottleneck 5 | Epoch 7: I(X;T)=0.0000, I(T;Y)=0.6117
Bottleneck 5 | Epoch 8: I(X;T)=0.0000, I(T;Y)=0.6135
Bottleneck 5 | Epoch 9: I(X;T)=0.0000, I(T;Y)=0.6184
Bottleneck 5 | Epoch 10: I(X;T)=0.0000, I(T;Y)=0.6164
Running experiment for bottleneck size = 10
Bottleneck 10 | Epoch 1: I(X;T)=0.0000, I(T;Y)=0.4868
Bottleneck 10 | Epoch 2: I(X;T)=0.0000, I(T;Y)=0.4944
Bottleneck 10 | Epoch 3: I(X;T)=0.0000, I(T;Y)=0.5020
Bottleneck 10 | Epoch 4: I(X;T)=0.0000, I(T;Y)=0.5133
Bottleneck 10 | Epoch 5: I(X;T)=0.0000, I(T;Y)=0.5160
Bottleneck 10 | Epoch 6: I(X;T)=0.0000, I(T;Y)=0.5279
Bottleneck 10 | Epoch 7: I(X;T)=0.0000, I(T;Y)=0.5292
Bottleneck 10 | Epoch 8: I(X;T)=0.0000, I(T;Y)=