# Investigating effective information in artificial neural networks

$$ EI = H(\langle W_i^\text{out} \rangle) - \langle H(W_i^\text{out}) \rangle $$

In [1]:
import os
from pathlib import Path
from itertools import islice
import gzip
import pickle

import numpy as np
import matplotlib.pyplot as plt

from bokeh.io import push_notebook, show, output_notebook
from bokeh.plotting import figure, ColumnDataSource
from bokeh.transform import linear_cmap
from bokeh.palettes import Plasma11, Viridis11
from bokeh.util.hex import hexbin
from bokeh.layouts import gridplot
output_notebook()

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as transforms

In [2]:
def H(x, dim=0):
    """Shannon entropy, in bits, of x along dimension `dim`.

    Every slice of x along `dim` is treated as a probability distribution p
    and -sum(p * log2(p)) is returned for it. A 1-d input of shape (N,)
    therefore yields a scalar (0-d tensor). For a 2-d input, dim=0 sums down
    each column and dim=1 sums across each row.

    Terms that evaluate to NaN contribute nothing to the sum. This covers
    p == 0 (0 * -inf) and also negative entries (log2 of a negative number),
    so invalid input is silently tolerated: the function never checks that
    the slices are non-negative or that they sum to 1.

    Args:
        x (torch.tensor) containing probability distribution(s)
        dim (int) dimension along which to compute entropy

    Returns:
        (torch.tensor) of a lower order than input x (`dim` is summed out)
    """
    terms = x * torch.log2(x)
    # NaN is the only value that differs from itself; replace it with zero.
    terms = torch.where(torch.isnan(terms), torch.zeros_like(terms), terms)
    return -terms.sum(dim=dim)


def soft_norm(W):
    """Turn weight matrix W into a transition probability matrix via softmax.

    The out-weights of a neuron are not a probability distribution (they can
    be negative and need not sum to 1), so effective information cannot be
    computed from the raw weight/adjacency matrix. Taking a softmax over
    each row makes every row positive and sum to 1.

    Args:
        W (torch.tensor) of shape (n_rows, n_cols), one row per source neuron

    Returns:
        (torch.tensor) of the same shape as W, each row summing to 1
    """
    return torch.softmax(W, dim=1)


def lin_norm(W):
    """Turn weight matrix W into a transition probability matrix linearly.

    Negative entries are clipped to zero with a relu, then every row is
    divided by its own sum. A row whose clipped entries are all zero is
    divided by 1 instead, so it stays all-zero rather than becoming NaN
    (such a row therefore does not sum to 1).

    Args:
        W (torch.tensor) of shape (n_rows, n_cols), one row per source neuron

    Returns:
        (torch.tensor) of the same shape as W
    """
    clipped = F.relu(W)
    totals = clipped.sum(dim=1)
    # Avoid 0 / 0 for rows that have no positive weight at all.
    divisors = torch.where(totals == 0, torch.ones_like(totals), totals)
    return clipped / divisors.view(-1, 1)


def sig_norm(W):
    """Turn weight matrix W into a transition probability matrix via sigmoid.

    The logistic function is applied to every element, which makes all
    entries positive, and each row is then divided by its own sum.

    Args:
        W (torch.tensor) of shape (n_rows, n_cols), one row per source neuron

    Returns:
        (torch.tensor) of the same shape as W, each row summing to 1
    """
    squashed = torch.sigmoid(W)
    totals = squashed.sum(dim=1).view(-1, 1)
    return squashed / totals


def determinism(model, norm=None, device='cpu'):
    """Average entropy of the out-weight rows over every nn.Linear in `model`.

    Each layer's weight matrix is transposed, so that rows index the layer's
    inputs, and is optionally passed through `norm`. The entropy of every
    row is summed over all layers and divided by the total number of rows.

    Only modules whose type is exactly nn.Linear are considered; subclasses
    and all other module types are skipped.

    Args:
        model (nn.Module) network whose linear layers are analysed
        norm (callable or None) maps a weight matrix to a row-normalized
            matrix (e.g. soft_norm, lin_norm, sig_norm); skipped when falsy
        device unused by this function; accepted for a signature parallel
            to `degeneracy`

    Returns:
        (float) mean row entropy in bits

    Raises:
        ZeroDivisionError if `model` contains no nn.Linear layer
    """
    entropy_total = 0
    row_count = 0
    linear_layers = (m for m in model.modules() if type(m) is nn.Linear)
    for linear in linear_layers:
        out_weights = linear.weight.t()
        if norm:
            out_weights = norm(out_weights)
        entropy_total += H(out_weights, dim=1).sum().item()
        row_count += out_weights.shape[0]
    return entropy_total / row_count


def degeneracy(model, norm=None, device='cpu'):
    """Entropy of the pooled in-weight distribution over all nn.Linear layers.

    Each layer's weight matrix is transposed, so that rows index the layer's
    inputs, and is optionally passed through `norm`. Its column sums (the
    total weight arriving at each output unit) are collected across layers,
    divided by the total weight of all layers, and the entropy of the
    resulting vector is returned.

    Only modules whose type is exactly nn.Linear are considered; subclasses
    and all other module types are skipped.

    Args:
        model (nn.Module) network whose linear layers are analysed
        norm (callable or None) maps a weight matrix to a row-normalized
            matrix (e.g. soft_norm, lin_norm, sig_norm); skipped when falsy
        device device for the empty placeholder used when `model` has no
            nn.Linear layer; otherwise the result lives wherever the model's
            weights live

    Returns:
        (torch.tensor) 0-d tensor holding the entropy in bits
    """
    # Collect per-layer sums and concatenate once at the end. This avoids
    # re-copying the growing tensor for every layer, and it means the result
    # no longer depends on `device` matching the device of the weights.
    column_sums = []
    total_weight = 0
    for layer in model.modules():
        if type(layer) is not nn.Linear:
            continue
        W = layer.weight.t()
        if norm:
            W = norm(W)
        column_sums.append(torch.sum(W, dim=0))
        total_weight += torch.sum(W).item()
    if column_sums:
        in_weights = torch.cat(column_sums)
    else:
        # No linear layers: keep the historical behaviour of taking the
        # entropy of an empty distribution.
        in_weights = torch.zeros((0,)).to(device)
    return H(in_weights / total_weight)
     


def EI(model, norm=None, device='cpu'):
    """Compute the effective information of a fully-connected sequential `model`.

    The value is `degeneracy(model, ...)` minus `determinism(model, ...)`,
    i.e. the entropy of the pooled in-weight distribution minus the mean
    entropy of the out-weight rows.

    Args:
        model (nn.Module) network whose nn.Linear layers are analysed
        norm (callable or None) row normalization forwarded to both terms
        device forwarded unchanged to both terms

    Returns:
        (torch.tensor) 0-d tensor holding the effective information in bits
    """
    entropy_of_mean = degeneracy(model, norm=norm, device=device)
    mean_of_entropy = determinism(model, norm=norm, device=device)
    return entropy_of_mean - mean_of_entropy


In [3]:
import subprocess

# The dataset lives in ../data relative to the notebook's working directory.
dir_path = Path().absolute()
dataset_path = dir_path.parent / "data/mnist.pkl.gz"
if not dataset_path.exists():
    print('Downloading dataset with curl ...')
    dataset_path.parent.mkdir(parents=True, exist_ok=True)
    url = 'http://ericjmichaud.com/downloads/mnist.pkl.gz'
    # Pass an argument list instead of a formatted shell string so that a
    # path containing spaces or shell metacharacters is handed to curl intact.
    try:
        subprocess.run(['curl', '-L', url, '-o', str(dataset_path)], check=False)
    except OSError as err:
        # curl itself could not be started (for example, it is not installed).
        print('Could not run curl: {}'.format(err))
if not dataset_path.exists():
    print('Download failed')
    # Stop here with a clear message rather than failing inside gzip.open.
    raise FileNotFoundError('MNIST dataset not found at {}'.format(dataset_path))
print('Dataset acquired')
# NOTE(security): pickle.load can execute arbitrary code contained in the
# file, and the file is fetched over plain HTTP. Only run this against a
# source you trust.
with gzip.open(dataset_path, 'rb') as f:
    mnist = pickle.load(f)
print('Loaded data to variable `mnist`')

Dataset acquired
Loaded data to variable `mnist`


In [4]:
# Single floating-point type used for the input images and as PyTorch's
# default for newly created floating-point tensors.
dtype = torch.float32
torch.set_default_dtype(dtype)

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [5]:
class MNISTDataset(Dataset):
    """Dataset wrapper around an in-memory, indexable collection of samples.

    Args:
        data indexable collection supporting len() and [] (the MNIST samples)
        transform (callable or None) applied to each sample on access;
            samples are returned unchanged when it is falsy
    """
    def __init__(self, data, transform=None):
        self.transform = transform
        self.mnist = data

    def __len__(self):
        """Number of samples in the wrapped collection."""
        return len(self.mnist)

    def __getitem__(self, idx):
        """Return sample `idx`, passed through `transform` when one is set."""
        raw = self.mnist[idx]
        return self.transform(raw) if self.transform else raw

class ToTensor:
    """Convert samples (np.ndarray, np.ndarray) to (torch.tensor, torch.tensor).

    The image array is converted to a tensor of the module-level `dtype`.
    The one-hot label array is reduced to the index of its largest entry and
    returned as a 0-d torch.long tensor.
    """
    def __call__(self, sample):
        pixels, one_hot = sample
        class_index = np.argmax(one_hot)
        return (
            torch.from_numpy(pixels).to(dtype),
            torch.tensor(class_index).to(torch.long),
        )


In [6]:
class SoftmaxRegression(nn.Module):
    """Single bias-free linear layer followed by a log-softmax.

    Args:
        n_in (int) number of input features
        n_out (int) number of output classes
    """
    def __init__(self, n_in, n_out):
        super().__init__()
        self.linear = nn.Linear(n_in, n_out, bias=False)

    def forward(self, x):
        """Return log-probabilities over classes, normalized along dim 1."""
        logits = self.linear(x)
        return F.log_softmax(logits, dim=1)
    
    
class FullyConnected(nn.Module):
    """Bias-free dense network: tanh hidden layers and a log-softmax output.

    Consecutive entries of `layers` give the input and output width of each
    nn.Linear, so FullyConnected(784, 500, 10) has one hidden layer and
    FullyConnected(784, 10) has none.

    Args:
        *layers (int) layer widths, input first and output last; at least
            two are required

    Raises:
        ValueError if fewer than two widths are given
    """
    def __init__(self, *layers):
        super(FullyConnected, self).__init__()
        if len(layers) < 2:
            raise ValueError(
                'FullyConnected needs at least an input and an output width, '
                'got {}'.format(len(layers))
            )
        self.layers = nn.ModuleList([
            nn.Linear(n_in, n_out, bias=False)
            for n_in, n_out in zip(layers, layers[1:])
        ])

    def forward(self, x):
        """Return log-probabilities over classes, normalized along dim 1."""
        # Address the output layer as layers[-1] rather than through the
        # loop variable, which is unbound when there are no hidden layers.
        for hidden in self.layers[:-1]:
            x = torch.tanh(hidden(x))
        return F.log_softmax(self.layers[-1](x), dim=1)

In [7]:
# The first 60000 samples form the training set; everything after is the
# test set. Both loaders yield shuffled mini-batches of 20 samples.
training = MNISTDataset(mnist[:60000], transform=ToTensor())
training_loader = torch.utils.data.DataLoader(
    training,
    batch_size=20,
    shuffle=True,
)

test = MNISTDataset(mnist[60000:], transform=ToTensor())
test_loader = torch.utils.data.DataLoader(
    test,
    batch_size=20,
    shuffle=True,
)

In [8]:
# 28x28 input pixels, 10 digit classes. Swap in the commented-out line to
# train a deeper network instead.
# model = FullyConnected(28*28, 500, 500, 10).to(device)
model = SoftmaxRegression(n_in=28 * 28, n_out=10)
model = model.to(device)

# The models output log-probabilities, which is what NLLLoss expects.
loss_fn = nn.NLLLoss()
optimizer = torch.optim.Adam(params=model.parameters())

In [9]:
# --- DEFINE OUTPUT ---

# One shared data source. update_metrics() streams a new row into it and
# all three line plots below redraw from it.
source = ColumnDataSource(
    data={column: [] for column in ('num_batches', 'accuracy', 'test_loss', 'EI')}
)

# NOTE(review): `plot_width` / `plot_height` look like the older Bokeh
# spelling of `width` / `height` - confirm against the installed version.
accuracy_plot = figure(title='Accuracy', x_axis_label='batches',
                       plot_width=800, plot_height=150)
accuracy_plot.line('num_batches', 'accuracy', source=source, line_width=2, color='green')

test_loss_plot = figure(title='Loss on Test Data', x_axis_label='batches',
                        plot_width=800, plot_height=150)
test_loss_plot.line('num_batches', 'test_loss', source=source, line_width=2, color='red')

EI_plot = figure(title='Effective Information', x_axis_label='batches',
                 y_axis_label='bits', plot_width=800, plot_height=150)
EI_plot.line('num_batches', 'EI', source=source, line_width=2)

# Optional image of the output weights (disabled).
# weight_image = figure(plot_width=5*150, plot_height=2*150, title='Output Weights')
# weight_image.xgrid.visible = False
# weight_image.ygrid.visible = False
# weight_image.image([np.zeros((784, 10)).reshape(112, -1)], 0, 0, 1, 1, palette=Plasma11)

# Stack the plots vertically and keep the handle so that later
# push_notebook() calls can refresh the rendered output in place.
grid = gridplot([[accuracy_plot], [test_loss_plot], [EI_plot]])
display = show(grid, notebook_handle=True)

def update_metrics():
    """Evaluate `model` on part of the test set and stream the results to the plots.

    Reads the module-level `model`, `loss_fn`, `test_loader`, `device` and
    `num_batches`, appends one row (batch count, accuracy, mean test loss,
    effective information) to `source`, and pushes the update to the
    notebook output referenced by `display`.
    """
    correct = 0
    outof = 0
    loss = 0
    with torch.no_grad():
        ei = EI(model, norm=lin_norm, device=device)
        for x, labels in islice(test_loader, 0, 100):  # 100 batches of 20 samples
            x, labels = x.to(device), labels.to(device)
            output = model(x)
            # loss_fn (nn.NLLLoss with its default reduction) returns the
            # mean over the batch. Weight it by the batch size so that
            # dividing by the sample count below gives a per-sample mean.
            loss += loss_fn(output, labels).item() * len(labels)
            _, pred = torch.max(output, 1)
            correct += (pred == labels).sum().item()
            outof += len(labels)
    loss = loss / outof
    ac = correct / outof
    source.stream({
        'num_batches': [num_batches],
        'accuracy': [ac],
        'test_loss': [loss],
        'EI': [ei.item()]
    })
    # Optional redraw of the output weights (disabled).
#     weight = model.linear.weight.cpu().detach().numpy()
#     for i in range(len(weight)):
#         x = i % 5
#         y = i // 5
#         weight_image.image([weight[i].reshape(28, 28)], x*29, -y*29, 28, 28, palette=Viridis11)
    push_notebook(handle=display)


# --- TRAIN ---

# Running count of optimizer steps; update_metrics() reads it as the x-value.
num_batches = 0
ac = 0
ei = None
for epoch in range(10):
    for sample, target in training_loader:
        optimizer.zero_grad()
        prediction = model(sample.to(device))
        loss = loss_fn(prediction, target.to(device))
        loss.backward()
        optimizer.step()
        num_batches += 1
        # Refresh the plots only on every 100th batch.
        if num_batches % 100:
            continue
        update_metrics()