# Rank of Neural Hessians

This code illustrates the results of the paper **Analytic Insights into Structure and Rank of Neural Network Hessian Maps** by calculating the rank of several Hessians and comparing them to the derived upper bounds. The code is accompanied by the relevant formulas from the paper; these are meant to provide context and are by no means a substitute for the paper itself.

In [8]:
import jax.numpy as jnp
from jax import random
from jax.config import config
from jax.experimental.stax import softmax, logsoftmax
from initializers import get_init

from data import get_dataset
from hessians import outer_prod, loss_hessian
from architectures import fully_connected

from dataloader import DatasetTorch
from torch.utils.data import DataLoader

High numerical precision is essential to the calculations, thus all tensors will be of type **float64**. Using float32 leads to instabilities and indeed produces wrong results! Moreover, we found that float64 is vital to all calculations, not just the final rank computation!

In [9]:
config.update("jax_enable_x64", True)  
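
To make the precision issue concrete, here is a minimal sketch of how the numerical rank of the same matrix can change with the floating-point type. It uses NumPy rather than JAX, and the toy matrix `A` is an assumption chosen for illustration, not something taken from the paper. The default tolerance of `matrix_rank` scales with the machine epsilon of the input dtype, so a small but genuine singular value is discarded in single precision.

```python
import numpy as np

# Toy matrix with one large and one small (but nonzero) singular value.
A = np.diag([1.0, 1e-9])

# The default tolerance is roughly max(singular value) * max(shape) * eps,
# where eps is the machine epsilon of the input dtype.
rank64 = np.linalg.matrix_rank(A)                     # float64: eps ~ 2.2e-16
rank32 = np.linalg.matrix_rank(A.astype(np.float32))  # float32: eps ~ 1.2e-7

print("float64 rank:", rank64)
print("float32 rank:", rank32)
```

In double precision the singular value `1e-9` lies far above the tolerance, whereas in single precision it falls below it and is treated as zero.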

We consider the following setting. We have a neural network $f_{{\boldsymbol{\theta}}}: \mathbb{R}^{d} \to \mathbb{R}^{k}$ with parameters $\boldsymbol{\theta} \in \mathbb{R}^{p}$, some input data $\boldsymbol{x}_1, \dots, \boldsymbol{x}_n \in \mathbb{R}^{d}$ and targets $\boldsymbol{y}_1, \dots, \boldsymbol{y}_n \in \mathbb{R}^{k}$. Moreover, we have some loss function $\mathcal{L}(\boldsymbol{\theta})$ that measures how well we predict $\boldsymbol{y}$. Here we will focus on the squared loss (with the factor $\frac{1}{2}$ used in the code below),
$$\mathcal{L}(\boldsymbol{\theta}) = \frac{1}{2}\sum_{i=1}^n||f_{\boldsymbol{\theta}}(\boldsymbol{x}_i)-\boldsymbol{y}_i{||}_2^2$$
We want to calculate the Hessian of the loss with respect to the parameters $\boldsymbol{\theta}$, i.e.
$$\boldsymbol{H}_{\mathcal{L}} = \frac{\partial^2}{\partial \boldsymbol{\theta} \partial \boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})$$
Let us first set up the hyperparameters, namely the sample size $n$ of our problem, the dimensionality $d$ of our inputs and the number of classes $k$, i.e. the dimensionality of $\boldsymbol{y}$.

In [10]:
# Define the hyperparameters
n_train = 50                                                      # Sample size
dim = 25                                                          # Dimension of data
widths = [5, 10]                                                  # Width of the network, excluding last layer
classes = 10                                                      # Number of classes
bs = 10                                                           # Batch size used when accumulating the Hessian

all_widths = [dim] + widths + [classes]
p = sum([all_widths[i] * all_widths[i + 1] for i in range(len(all_widths)-1)])   

# Initialize seed                           
key = random.PRNGKey(1)

Next we specify the data distribution; we can choose from 'MNIST', 'FashionMNIST' or 'CIFAR10'. We need to calculate the (uncentered, unnormalized) covariance matrix 
$$\boldsymbol{\Sigma} = \sum_{i=1}^n \boldsymbol{x}_i \boldsymbol{x}_i^T$$
and its rank $r = \text{rank}(\boldsymbol{\Sigma})$, as $r$ will enter in our formulas for the predictions.
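
As a small illustration of this quantity, the following NumPy sketch builds $\boldsymbol{\Sigma}$ for random inputs and reads off $r$. The sizes and the random data are assumptions chosen for illustration. Because $\boldsymbol{\Sigma} = \boldsymbol{X}^T\boldsymbol{X}$ for the data matrix $\boldsymbol{X}$ with rows $\boldsymbol{x}_i^T$, its rank equals the rank of $\boldsymbol{X}$, so $r \leq \min(n, d)$; the sketch computes the rank from $\boldsymbol{X}$, which is one numerically convenient option rather than what the notebook itself does.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 3, 5                       # fewer samples than dimensions (illustrative sizes)
X = rng.standard_normal((n, d))   # rows are the inputs x_i

# Sigma = sum_i x_i x_i^T, written once as a matrix product and once as an explicit sum.
sigma = X.T @ X
sigma_loop = sum(np.outer(x, x) for x in X)

# rank(Sigma) = rank(X); the singular values of Sigma are the squares of those of X.
r = np.linalg.matrix_rank(X)
print("r =", r, "and min(n, d) =", min(n, d))
```

With generic data $r = \min(n, d)$, so with $n = 50$ samples in $d = 25$ dimensions one would expect $r = d$.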

In [12]:
# Define the data; here we use down-scaled CIFAR

data = get_dataset('CIFAR', n_train=n_train, n_test=1, dim=dim, classes=classes)

# Define a train loader so that we can batch the Hessian calculation
train_loader = DataLoader(DatasetTorch(data.x_train, data.y_train), batch_size=bs)

Files already downloaded and verified
Files already downloaded and verified


Next we fix our neural network model, which will be a fully-connected linear network of the following form:
$$f_{\boldsymbol{\theta}}(\boldsymbol{x}) = \boldsymbol{W}_L \dots \boldsymbol{W}_1 \boldsymbol{x}$$
where $\boldsymbol{W}_i \in \mathbb{R}^{m_{i} \times m_{i-1}}$ with $m_0 = d$ and $m_L = k$. For simplicity we will ignore biases, but if you're interested, check out the paper for the corresponding formulas with bias! Our theorems hold for a variety of initialization schemes. Here you can choose from 'glorot' (scaled Gaussian initialization), 'uniform' (scaled uniform initialization) or 'orthogonal' (sampling according to the Haar measure from the space of orthogonal matrices):

In [13]:
# Choose initialization
init = get_init('glorot')                   
# Define linear neural network architecture
init_fn, apply_fn = fully_connected(units=widths, classes=classes, activation='linear', init=init)

# Initialize the parameters
_, params = init_fn(key, (-1, dim))

# Make sure parameters are double precision
params = [jnp.double(param) for param in params]
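
The shapes involved can be checked with a short NumPy sketch of such a linear network. The weights below use a plain scaled Gaussian as a stand-in, which is an assumption and not the `get_init` or `fully_connected` implementation. The sketch confirms that the layers collapse into a single matrix $\boldsymbol{W}_L \dots \boldsymbol{W}_1$ and that the parameter count matches `p` for the hyperparameters chosen above.

```python
import numpy as np

rng = np.random.default_rng(0)
layer_sizes = [25, 5, 10, 10]     # d, m_1, m_2, k as in the hyperparameters above

# W_i has shape (m_i, m_{i-1}); scaled Gaussian entries are an illustrative choice.
weights = [rng.standard_normal((m_out, m_in)) / np.sqrt(m_in)
           for m_in, m_out in zip(layer_sizes[:-1], layer_sizes[1:])]

def forward(weights, x):
    """Apply the layers one after another, without biases or nonlinearities."""
    for W in weights:
        x = W @ x
    return x

x = rng.standard_normal(layer_sizes[0])
out = forward(weights, x)

# The whole network is the single linear map W_L ... W_1.
W_total = np.linalg.multi_dot(weights[::-1])
num_params = sum(W.size for W in weights)
collapsed_rank = np.linalg.matrix_rank(W_total)

print("output shape:", out.shape)
print("number of parameters:", num_params)
print("rank of the collapsed map:", collapsed_rank)
```

The rank of the collapsed map is limited by the narrowest layer, which is why the smallest width enters the rank formulas later on.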

As said above, here we focus on the squared error 
$$\mathcal{L}(\boldsymbol{\theta}) = \frac{1}{2}\sum_{i=1}^n||\boldsymbol{y}_i-f_{\boldsymbol{\theta}}(\boldsymbol{x}_i){||}_2^2$$
Our theorems, however, also extend to many other losses. For instance, use 'cross' for the cross entropy loss or 'cosh' for the cosh loss.

In [14]:
loss_name = 'mse'

if loss_name == 'mse':
    cross = False
    
    def loss(preds, targets):
        return 1/2 * jnp.sum((preds - targets)**2)
    
    
    def loss_params(params, inputs, targets):
        preds = apply_fn(params, inputs)
    
        return 1/2 * jnp.sum((preds - targets)**2)

if loss_name == 'cross':
    cross = True
    
    def loss(preds, targets):
        return -jnp.sum(logsoftmax(preds) * targets)


    def loss_params(params, inputs, targets):
        preds = apply_fn(params, inputs)

        return -jnp.sum(logsoftmax(preds) * targets)
    
if loss_name == 'cosh':
    cross = False
    
    def loss(preds, targets):
        return jnp.sum(jnp.log(jnp.cosh(preds - targets)))

    def loss_params(params, inputs, targets):
        preds = apply_fn(params, inputs)

        return loss(preds, targets)

Now we come to the Hessian. We split it into a functional part and a part that consists of the outer product of gradients, i.e.
$$\boldsymbol{H}_{\mathcal{L}} = \boldsymbol{H}_f + \boldsymbol{H}_o$$
where, for squared-loss, it holds that 
$$\boldsymbol{H}_f = \sum_{i=1}^n\sum_{l=1}^k\left(f_l(\boldsymbol{x}_i)-y_{il}\right)\frac{\partial^2 f_l(\boldsymbol{x}_i)}{\partial \boldsymbol{\theta} \partial \boldsymbol{\theta}}$$
and for the outer-gradient term we have
$$\boldsymbol{H}_o = \sum_{i=1}^n\sum_{l=1}^k \frac{\partial f_l(\boldsymbol{x}_i)}{\partial \boldsymbol{\theta}}\left(\frac{\partial f_l(\boldsymbol{x}_i)}{\partial \boldsymbol{\theta}}\right)^T$$
For the formulas for more general loss functions, please check out our paper! Due to the structure of JAX, it is simpler (and more memory-efficient) to calculate the loss Hessian and the outer product of gradients. Hence we will calculate the functional Hessian as 
$$\boldsymbol{H}_f = \boldsymbol{H}_{\mathcal{L}} - \boldsymbol{H}_o$$
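
The decomposition can be verified on a tiny example. The sketch below is self-contained and does not use the `loss_hessian` or `outer_prod` helpers of this repository; the two-layer linear network, the sizes and the random data are all assumptions made for illustration. It computes $\boldsymbol{H}_{\mathcal{L}}$ with `jax.hessian`, builds $\boldsymbol{H}_o$ from the Jacobian of the network outputs, and compares the difference against the functional Hessian obtained directly from its definition for the squared loss with factor $\frac{1}{2}$.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

jax.config.update("jax_enable_x64", True)   # double precision, as in the notebook

# Illustrative sizes: input dim d, hidden width m, outputs k, samples n.
d, m, k, n = 4, 3, 2, 6
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = [jax.random.normal(k1, (m, d)), jax.random.normal(k2, (k, m))]
X = jax.random.normal(k3, (n, d))
Y = jax.random.normal(k4, (n, k))

# Flatten the parameters so that Hessians are plain (p, p) matrices.
theta, unravel = ravel_pytree(params)

def predict(theta):
    W1, W2 = unravel(theta)
    return X @ W1.T @ W2.T                   # shape (n, k)

def loss(theta):
    return 0.5 * jnp.sum((predict(theta) - Y) ** 2)

# Loss Hessian and outer-product term H_o = J^T J.
H_L = jax.hessian(loss)(theta)
J = jax.jacobian(predict)(theta).reshape(n * k, -1)
H_o = J.T @ J

# Functional Hessian as the difference, as done in the notebook.
H_f = H_L - H_o

# Independent check: sum_{i,l} (f_l(x_i) - y_il) * d^2 f_l(x_i) / d theta^2,
# with the residuals held fixed while differentiating.
residuals = predict(theta) - Y

def weighted_outputs(theta, residuals):
    return jnp.sum(residuals * predict(theta))

H_f_direct = jax.hessian(weighted_outputs)(theta, residuals)

print("max deviation:", jnp.max(jnp.abs(H_f - H_f_direct)))
```

Computing $\boldsymbol{H}_f$ as a difference avoids forming one second-derivative tensor per output, which is the memory argument made above.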

In [15]:
H_L, H_outer = jnp.zeros(shape=(p, p)), jnp.zeros(shape=(p, p))
cov = jnp.zeros(shape=(dim, dim))

for batch_input, batch_label in train_loader:
    batch_input, batch_label = (batch_input.numpy(), batch_label.numpy())
    # Calculate the covariance
    cov += batch_input.T @ batch_input
    # Calculate loss hessian
    H_L += loss_hessian(loss_params, params, batch_input, batch_label)
    # Calculate the outer gradient product
    H_outer += outer_prod(loss, apply_fn, params, batch_input, batch_label, cross=cross)

# To save time we calculate the functional Hessian as the difference
H_F = H_L - H_outer

In [16]:
rank_cov = jnp.linalg.matrix_rank(cov)
rank_L = jnp.linalg.matrix_rank(H_L)
rank_outer = jnp.linalg.matrix_rank(H_outer)
rank_F = jnp.linalg.matrix_rank(H_F)

We can now calculate the upper bounds on the Hessian rank introduced in our paper and compare them with the numerical results. Our predictions look as follows:
\begin{align}
\text{rank}(\boldsymbol{H}_o) &\leq q(r + k - q) \\
\text{rank}(\boldsymbol{H}_f) &\leq 2q \sum_{l=1}^{L-1} m_l + 2qs -Lq^2 \\
\text{rank}(\boldsymbol{H}_{\mathcal{L}}) &\leq 2q \sum_{l=1}^{L-1} m_l -Lq^2 + q(r+k)
\end{align}
where $q=\text{min}(r, k, m_1, \dots, m_{L-1})$ and $s = \text{min}(r, k)$, so the sums run over the hidden-layer widths only.
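
As a worked instance of these formulas, the following plain-Python sketch evaluates the three bounds for the configuration used in this notebook. The function name `rank_bounds` is hypothetical, and the value $r = 25$ assumes that the covariance matrix has full rank $d = 25$. The sums run over the hidden widths `[5, 10]`, and $L = 3$ is the number of weight matrices.

```python
def rank_bounds(r, k, widths):
    """Evaluate the rank upper bounds for a linear network.

    r: rank of the input covariance, k: output dimension,
    widths: hidden-layer widths m_1, ..., m_{L-1}.
    """
    L = len(widths) + 1                 # number of weight matrices
    q = min([r, k] + list(widths))      # smallest dimension along the network
    s = min(r, k)
    m_sum = sum(widths)

    bound_outer = q * (r + k - q)
    bound_functional = 2 * q * m_sum + 2 * q * s - L * q**2
    bound_loss = 2 * q * m_sum - L * q**2 + q * (r + k)
    return {"outer": bound_outer, "functional": bound_functional, "loss": bound_loss}

bounds = rank_bounds(r=25, k=10, widths=[5, 10])
print(bounds)
```

Here $q = 5$ and $s = 10$, so the bounds are $5 \cdot 30 = 150$, $150 + 100 - 75 = 175$ and $150 - 75 + 175 = 250$. All three lie below the total parameter count $p = 275$.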

In [17]:
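# For cross-entropy the bounds use k - 1 in place of k.
# A separate variable keeps `classes` unchanged if this cell is re-run.
k_eff = classes - 1 if loss_name == 'cross' else classes

s = jnp.min(jnp.array([rank_cov, k_eff]))
q = jnp.min(jnp.array([rank_cov, k_eff] + widths))

# len(widths) + 1 is the number of weight matrices L
pred_F = 2 * q * sum(widths) + 2 * q * s - (len(widths)+1) * q**2
pred_outer = (rank_cov + k_eff - q) * q
# Equivalent to 2q * sum(widths) - L q^2 + q (r + k)
pred_L = pred_F + pred_outer + q * (q - 2 * s)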

print(f'Rank of Functional Hessian is {rank_F} and the prediction is {pred_F}')

print(f'Rank of Gradient Outer Product is {rank_outer} and the prediction is {pred_outer}')

print(f'Rank of Loss Hessian is {rank_L} and the prediction is {pred_L}')

Rank of Functional Hessian is 175 and the prediction is 175
Rank of Gradient Outer Product is 150 and the prediction is 150
Rank of Loss Hessian is 250 and the prediction is 250
