# **Impact of network architecture on information flow and learning capabilities**

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from sklearn.metrics import mutual_info_score
import os

def _discretize(values, n_bins):
    """Map continuous *values* to integer bin indices in ``[0, n_bins - 1]``."""
    values = np.asarray(values)
    edges = np.linspace(np.min(values), np.max(values), n_bins + 1)
    # Digitize against the interior edges only: with the outer edges included
    # the maximum sample would fall past the last edge and get a bin of its
    # own, yielding n_bins + 1 bins instead of n_bins.
    return np.digitize(values, edges[1:-1])


def estimate_mutual_info(X, Y, n_bins=30):
    """Estimate the mutual information between two 1-D samples by binning.

    Args:
        X: Continuous samples. Always discretized into ``n_bins``
            equal-width bins spanning ``[min(X), max(X)]``.
        Y: Samples aligned with ``X``. Integer-typed arrays (signed or
            unsigned) are treated as already-discrete labels and used as-is;
            anything else is discretized over its *own* value range.
        n_bins: Number of equal-width bins used for each continuous variable.

    Returns:
        The value returned by ``sklearn.metrics.mutual_info_score`` for the
        two discretized label vectors (scikit-learn documents it in nats).
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    x_b = _discretize(X, n_bins)
    # Y must be binned with edges derived from Y itself; reusing X's edges
    # collapses Y into one or two bins whenever the two variables live on
    # different scales (e.g. pixels in [0, 1] vs. unbounded activations).
    y_b = Y if Y.dtype.kind in 'iu' else _discretize(Y, n_bins)
    return mutual_info_score(x_b, y_b)

class MLP(nn.Module):
    """Fully connected ReLU network that also reports hidden activations.

    ``dims`` lists the layer widths from input to output. Every consecutive
    pair except the last is joined by ``Linear`` followed by ``ReLU``; the
    final pair is joined by a bare ``Linear`` that produces the logits.
    """

    def __init__(self, dims):
        super().__init__()
        modules = []
        # Hidden stages: (dims[0] -> dims[1]), ..., (dims[-3] -> dims[-2]).
        for fan_in, fan_out in zip(dims[:-2], dims[1:-1]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output projection, deliberately without a non-linearity.
        modules.append(nn.Linear(dims[-2], dims[-1]))
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Return ``(logits, activations)`` for a batch ``x``.

        The input is flattened to ``(batch, -1)``. ``activations`` holds the
        detached output of every ``ReLU`` stage, in network order.
        """
        hidden = x.view(x.size(0), -1)
        relu_outputs = []
        for module in self.network:
            hidden = module(hidden)
            if isinstance(module, nn.ReLU):
                # Detach so stored activations do not keep the graph alive.
                relu_outputs.append(hidden.detach())
        return hidden, relu_outputs

class SimpleCNN(nn.Module):
    """Two-stage conv/pool classifier that also reports its conv activations.

    Each stage is a 5x5 convolution (padding 2, so spatial size is kept)
    followed by ReLU and a 2x2 max-pool. The classifier head expects a 7x7
    feature map after both pools, which corresponds to 28x28 inputs, and
    emits 10 logits.
    """

    def __init__(self, channels):
        super().__init__()
        wide = channels * 2
        self.conv1 = nn.Conv2d(1, channels, kernel_size=5, padding=2)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(channels, wide, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.fc = nn.Linear(wide * 7 * 7, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``(logits, activations)`` for a batch of images ``x``.

        ``activations`` contains the detached post-ReLU output of each
        convolution (taken before pooling), flattened to ``(batch, -1)``.
        """
        stage1 = self.relu(self.conv1(x))
        stage2 = self.relu(self.conv2(self.pool1(stage1)))
        pooled = self.pool2(stage2)
        logits = self.fc(pooled.view(pooled.size(0), -1))
        activations = [
            feat.view(feat.size(0), -1).detach() for feat in (stage1, stage2)
        ]
        return logits, activations

def run_arch_experiment(arch_name, model, train_loader, test_loader, device,
                        epochs=5, max_neurons=50):
    """Train ``model`` and track per-layer mutual-information estimates.

    After every training epoch the model is run over ``test_loader`` and, for
    each activation tensor the model returns, two averages are recorded:
    the mean estimated MI between input feature ``n`` and unit ``n``, and the
    mean estimated MI between unit ``n`` and the label.

    Args:
        arch_name: Label used in the progress message.
        model: Module whose forward returns ``(logits, activations)`` where
            ``activations`` is a sequence of ``(batch, features)`` tensors.
        train_loader: Iterable of ``(images, labels)`` training batches.
        test_loader: Iterable of ``(images, labels)`` evaluation batches.
        device: Device the model and batches are moved to.
        epochs: Number of training epochs.
        max_neurons: Upper bound on how many leading units per layer (and
            leading input features) take part in the estimate.

    Returns:
        ``(mi_x_t, mi_t_y)``: two dicts mapping layer index to a list holding
        one float per epoch.

    Raises:
        ValueError: If ``max_neurons`` is below 1 or ``test_loader`` yields
            no batches.
    """
    if max_neurons < 1:
        raise ValueError("max_neurons must be at least 1")

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    mi_x_t = {}
    mi_t_y = {}

    for epoch in range(1, epochs + 1):
        model.train()
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs, _ = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        model.eval()
        layer_batches = None
        input_batches = []
        label_batches = []
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs = imgs.to(device)
                _, acts = model(imgs)
                if layer_batches is None:
                    layer_batches = [[] for _ in acts]
                # Only the first max_neurons columns are ever read below, so
                # slice before the host copy instead of keeping full
                # activation maps for the whole evaluation set in memory.
                for bucket, act in zip(layer_batches, acts):
                    bucket.append(act[:, :max_neurons].cpu().numpy())
                flat_inputs = imgs.view(imgs.size(0), -1)
                input_batches.append(flat_inputs[:, :max_neurons].cpu().numpy())
                label_batches.append(labels.cpu().numpy())

        if not label_batches:
            raise ValueError(
                "test_loader yielded no batches; cannot estimate mutual information"
            )

        all_x = np.concatenate(input_batches, axis=0)
        all_y = np.concatenate(label_batches, axis=0)

        for i, batches in enumerate(layer_batches):
            layer_acts = np.concatenate(batches, axis=0)
            n_units = layer_acts.shape[1]
            # Unit n is paired with input feature n, so the input-side
            # estimate cannot use more units than there are input features.
            n_paired = min(n_units, all_x.shape[1])
            # NOTE(review): this per-index pairing is only a proxy for I(X;T).
            # For image inputs the leading features are border pixels, which
            # may be near-constant and drive the estimate towards 0 - confirm
            # that this is the intended estimator.
            mi_x_vals = [
                estimate_mutual_info(all_x[:, n], layer_acts[:, n])
                for n in range(n_paired)
            ]
            mi_y_vals = [
                estimate_mutual_info(layer_acts[:, n], all_y)
                for n in range(n_units)
            ]
            # Store plain floats rather than NumPy scalars so the returned
            # history contains only built-in types.
            mi_x_t.setdefault(i, []).append(float(np.mean(mi_x_vals)))
            mi_t_y.setdefault(i, []).append(float(np.mean(mi_y_vals)))
        print(f"{arch_name} Epoch {epoch} complete")

    return mi_x_t, mi_t_y

def main():
    """Run the information-flow experiment for every architecture on MNIST.

    Downloads MNIST into ``data``, runs each model through
    ``run_arch_experiment`` with its default settings, and saves the
    collected estimates to ``results/arch_info_flow.pth``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    to_tensor = transforms.ToTensor()
    train_ds, test_ds = (
        datasets.MNIST('data', train=is_train, download=True, transform=to_tensor)
        for is_train in (True, False)
    )
    train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False)

    # Hidden-layer widths per MLP variant; input (28*28) and output (10)
    # sizes are shared.
    # NOTE(review): 'MLP_5layer' has four hidden layers (five Linear layers),
    # while the other two names count hidden layers - confirm the naming.
    hidden_widths = {
        'MLP_1layer': [128],
        'MLP_3layer': [256, 128, 64],
        'MLP_5layer': [512, 256, 128, 64],
    }
    architectures = {
        name: MLP([28*28, *hidden, 10]) for name, hidden in hidden_widths.items()
    }
    architectures['SimpleCNN'] = SimpleCNN(16)

    results = {}
    for name in architectures:
        print(f"Running {name}")
        input_mi, label_mi = run_arch_experiment(
            name, architectures[name], train_loader, test_loader, device
        )
        results[name] = {'I(X;T)': input_mi, 'I(T;Y)': label_mi}

    os.makedirs('results', exist_ok=True)
    torch.save(results, 'results/arch_info_flow.pth')
    print("Experiments done. Results saved to results/arch_info_flow.pth")

# Run the full experiment only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()


Running MLP_1layer
MLP_1layer Epoch 1 complete
MLP_1layer Epoch 2 complete
MLP_1layer Epoch 3 complete
MLP_1layer Epoch 4 complete
MLP_1layer Epoch 5 complete
Running MLP_3layer
MLP_3layer Epoch 1 complete
MLP_3layer Epoch 2 complete
MLP_3layer Epoch 3 complete
MLP_3layer Epoch 4 complete
MLP_3layer Epoch 5 complete
Running MLP_5layer
MLP_5layer Epoch 1 complete
MLP_5layer Epoch 2 complete
MLP_5layer Epoch 3 complete
MLP_5layer Epoch 4 complete
MLP_5layer Epoch 5 complete
Running SimpleCNN
SimpleCNN Epoch 1 complete
SimpleCNN Epoch 2 complete
SimpleCNN Epoch 3 complete
SimpleCNN Epoch 4 complete
SimpleCNN Epoch 5 complete
Experiments done. Results saved to results/arch_info_flow.pth
