<a href="https://colab.research.google.com/github/maciejskorski/nn_hessian_intialization/blob/master/chain_rule_nn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Hessian chain rules for neural networks 

Consider a feed-forward neural network and its $k$-th layer. Let $w^{k}$ be the kernel, $z^{k}$ be the input and $z^{k+1} = w^{k}\cdot z^{k} + b^{k}$ be its output (forward activation). One can show via chain rules that the Hessian of loss $L(z,t)$ with respect to the layer weights $w=w^{k}$, where $z=z^n$ is the network output and $t$ is true label, reads
$$
D^2_{w} L = D^2_z L\bullet D_{w} z\bullet D_{w} z + D_z L\bullet D^2_w z  
$$
where $\bullet$ means tensor product along appropriate dimensions.

Note that the second term captures non-linear terms and in particular vanishes for activations like Relu!

Below we show how to compute this chain rule for a relu neural network on MNIST dataset.

In [1]:
%tensorflow_version 1.x

import tensorflow as tf
from tensorflow.keras import backend as K

## load MNIST, ignore validation data; normalize inputs

mnist = tf.keras.datasets.mnist
(train_inputs, train_labels), _ = mnist.load_data()
train_inputs = train_inputs / 255.0

## we will use cross-entropy loss

def SparseCategoricalCrossentropy(labels,logits):
  Z = tf.reduce_logsumexp(logits,axis=-1)
  lookup_labels = tf.stack([tf.range(tf.shape(labels)[0]),tf.cast(labels,tf.int32)],1)
  true_logits = tf.gather_nd(logits,lookup_labels,batch_dims=0)
  return -true_logits + Z

## build a model with two hidden layers and relu activation

def build_network(activation):

  inputs = tf.keras.layers.Input(shape=[28,28],dtype=tf.float32,name='inputs',batch_size=1)
  labels = tf.keras.layers.Input(shape=[],dtype=tf.int32,name='labels',batch_size=1)

  layer1 = tf.keras.layers.Flatten()
  out1 = layer1(inputs)

  layer2 = tf.keras.layers.Dense(30,activation=activation,name='dense1')
  out2 = layer2(out1)

  layer3 = tf.keras.layers.Dense(30,activation=activation,name='dense2')
  out3 = layer3(out2)

  layer4 = tf.keras.layers.Dense(10,activation='linear',name='dense3')
  out = layer4(out3)

  model = tf.keras.Model(inputs=[inputs,labels], outputs=out)

  loss = SparseCategoricalCrossentropy(model.input[1],model.output)
  model.add_loss(loss)

  model.compile(optimizer='sgd')
  
  return model

TensorFlow 1.x selected.


In [0]:
## test automated hessians and the chain rule above (linear part); use MNIST data

from tensorflow.python.ops.parallel_for.gradients import jacobian
from tensorflow.keras import backend as K
import numpy as np

def test_chain_rule(layer):
  global model
  # tensor to evaluate hessian form
  g = layer.kernel

  # auto hessian form
  H_auto = tf.hessians(model.total_loss, layer.kernel)[0]
  H_form_auto = tf.einsum('ab,abcd,cd->',g,H_auto,g)

  # manual hessian form
  H = tf.squeeze(tf.hessians(model.total_loss,model.output)[0])
  B = tf.squeeze(jacobian(model.output,layer.kernel))
  V = tf.einsum('abc,bc->a',B,g)
  H_form = tf.einsum('a,ab,b->',V,H,V)

  # test if auto hessian equals manual one
  test = tf.assert_near(H_form_auto,H_form,message='hessian calculation wrong')
  sess = K.get_session()
  sample = np.random.randint(0,len(train_labels),size=[1])
  feed_dict = {model.inputs[0]:train_inputs[sample],model.inputs[1]:train_labels[sample]}
  sess.run(test,feed_dict)

In [3]:
## test for relu, all good as it is piecewise linear !

K.clear_session()
model = build_network('relu')
for layer_name in ['dense1','dense2','dense3']:
  test_chain_rule( model.get_layer(layer_name) )

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [4]:
## test tanh - this will fail as we skipped the quadratic term, expect errors below !

K.clear_session()
model = build_network('tanh')
for layer_name in ['dense1','dense2','dense3']:
  test_chain_rule( model.get_layer(layer_name) )



InvalidArgumentError: ignored