## Understanding ```hessQuik``` Layers

In this notebook, we construct a slightly simplified version of a ```hessQuik``` single layer, show various methods of computing the gradient and Hessian of the network in forward mode, and time these methods.  This notebook serves as a tutorial on constructing new layers, on our methods for testing layers, and on making implementations more efficient on CPUs and GPUs.

## Install ```hessQuik``` and Other Packages

This is also where you can select your data type and device.  We recommend testing the layer derivatives on a CPU in double precision first.

In [None]:
!python -m pip install git+https://github.com/elizabethnewman/hessQuik.git

In [None]:
import torch
import torch.nn as nn
import hessQuik.activations as act
import hessQuik.layers as lay

# set precision
torch.set_default_dtype(torch.float64)
print('Default data type:', torch.get_default_dtype())

# use GPU if available
device = torch.device(f'cuda:0' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

## Constructing Our First Layer

Here, we show a slightly simplified single layer implementation where we remove the bias and do not use PyTorch's initialization of the layer parameters.  We also only implement a forward method for illustrative purposes.  You can find our full implementation here: [singleLayer](https://github.com/elizabethnewman/hessQuik/blob/main/hessQuik/layers/single_layer.py).

<!-- The basics of this layer are the following. Given an input $\mathbf{x}\in \mathbb{R}^{d}$ where $d$ is the dimension of the input,  suppose we have forward propagated using a function $u:\mathbb{R}^d \to \mathbb{R}^n$.  
Then, our current layer has the following dependence on $\mathbf{x}$:
\begin{align}
f(\mathbf{x}) &= \sigma(\mathbf{K} u(\mathbf{x})) && f:\mathbb{R}^d \to \mathbb{R}^m \text{ and } \mathbf{K}\in \mathbb{R}^{m\times n},
\end{align}
and $\sigma:\mathbb{R}\to \mathbb{R}$ is a twice continuously differentiable activation function applied entrywise.

The gradient of $f$ with respect to $\mathbf{x}$ is given by
\begin{align}
\frac{\partial f}{\partial \mathbf{x}}
  &=\frac{\partial u}{\partial \mathbf{x}} \frac{\partial f}{\partial u}
  =\frac{\partial u}{\partial \mathbf{x}} \mathbf{K}^\top \text{diag}(\sigma'(\mathbf{K} u(\mathbf{x})))
\end{align}
where $\dfrac{\partial u}{\partial \mathbf{x}}\in \mathbb{R}^{d\times n}$ has been pre-computed.  Here, $\sigma'$ is the derivative of the activation function applied entrywise and $\text{diag}(\cdot)$ stores the components of a vector as the diagonal entries of a matrix. -->

In [None]:
class simpleHessQuikLayerV1(lay.hessQuikLayer):

  def __init__(self, in_features, out_features, act=act.tanhActivation(),
               device=None, dtype=None):
    factory_kwargs = {'device': device, 'dtype': dtype}
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.act = act
    # weights of shape (in_features, out_features); this simplified layer has no bias
    self.K = nn.Parameter(torch.randn(in_features, out_features, **factory_kwargs))

  def dim_input(self):
    return self.in_features

  def dim_output(self):
    return self.out_features

  def forward(self, u, do_gradient=False, do_Hessian=False, forward_mode=True,
              dudx=None, d2ud2x=None):
    # shapes: u is (nex, n) and K is (n, m); if given, dudx is (nex, d, n) and d2ud2x is (nex, d, d, n)
    (dfdx, d2fd2x) = (None, None)
    f, dsig, d2sig = self.act(u @ self.K, do_gradient=do_gradient, do_Hessian=do_Hessian)

    if (do_gradient or do_Hessian):
      # gradient of the layer w.r.t. its input u, shape (nex, n, m)
      dfdx = dsig.unsqueeze(1) * self.K

      if do_Hessian:
        # Hessian w.r.t. u: d2fd2x[b, i, k, j] = sigma''(z_bj) * K[i, j] * K[k, j], shape (nex, n, n, m)
        d2fd2x = (d2sig.unsqueeze(1) * self.K).unsqueeze(2) * self.K.unsqueeze(0).unsqueeze(0)

        if d2ud2x is not None:
          # Gauss-Newton approximation: dudx * (layer Hessian) * dudx^T, shape (nex, d, d, m)
          d2fd2x = dudx.unsqueeze(1) @ (d2fd2x.permute(0, 3, 1, 2) @ dudx.permute(0, 2, 1).unsqueeze(1))
          d2fd2x = d2fd2x.permute(0, 2, 3, 1)

          # extra term to compute full Hessian: second derivatives of u times the layer gradient
          d2fd2x += d2ud2x @ dfdx.unsqueeze(1)

    if dudx is not None:
      # chain rule for the gradient: (nex, d, n) @ (nex, n, m) -> (nex, d, m)
      dfdx = dudx @ dfdx

    return f, dfdx, d2fd2x
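
To see what the broadcasting in ```forward``` computes, here is a standalone sketch that re-creates the gradient and Hessian of a single layer $\sigma(\mathbf{u}\mathbf{K})$ with $\sigma = \tanh$ in plain PyTorch (no ```hessQuik``` imports) and compares them with ```torch.autograd.functional```. The closed-form derivatives of $\tanh$ and the tensor sizes are our own choices for illustration. Note the convention used above: the gradient is stored as the transpose of the Jacobian, so ```dfdu``` has shape ```(nex, n, m)``` and the Hessian has shape ```(nex, n, n, m)``` with the output index last.

```python
import torch

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

nex, n, m = 5, 3, 4
u = torch.randn(nex, n)
K = torch.randn(n, m)

# tanh and its first two derivatives, evaluated at z = u @ K
z = u @ K
sig = torch.tanh(z)
dsig = 1 - sig ** 2          # sigma'(z)
d2sig = -2 * sig * dsig      # sigma''(z)

# same broadcasting as simpleHessQuikLayerV1
dfdu = dsig.unsqueeze(1) * K                                                 # (nex, n, m)
d2fdu2 = (d2sig.unsqueeze(1) * K).unsqueeze(2) * K.unsqueeze(0).unsqueeze(0)  # (nex, n, n, m)


def layer(ui):
    """Single example through the layer: (n,) -> (m,)."""
    return torch.tanh(ui @ K)


# reference derivatives from autograd, one example at a time
J = torch.stack([torch.autograd.functional.jacobian(layer, u[b]) for b in range(nex)])  # (nex, m, n)
H = torch.stack([
    torch.stack([torch.autograd.functional.hessian(lambda v: layer(v)[j], u[b]) for j in range(m)], dim=-1)
    for b in range(nex)
])  # (nex, n, n, m)

print(torch.allclose(dfdu, J.transpose(1, 2)))
print(torch.allclose(d2fdu2, H))
```

Both comparisons should print ```True```.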


## Testing ```hessQuik``` Layers and Networks

We use a Taylor series approximation to test new layers.  Suppose we have a ```hessQuik``` network $f$ and we have pre-computed the value and derivatives at a particular point $\mathbf{x}$ via

```python
f0, df0, d2f0 = f(x, do_gradient=True, do_Hessian=True)
```
Suppose we perturb the input $\mathbf{x}$ in the direction of a unit vector $\mathbf{p}$ with step size $h$.  If the gradient is correct, then by Taylor's theorem,
\begin{align}
\left\|f(\mathbf{x}) + h \frac{\partial f(\mathbf{x})}{\partial \mathbf{x}}^\top \mathbf{p} - f(\mathbf{x} + h\mathbf{p})\right\| = \mathcal{O}(h^2).
\end{align}
If we have computed the correct gradient $\frac{\partial f(\mathbf{x})}{\partial \mathbf{x}}$, then as $h\to 0$, the first-order error (left-hand side) goes to zero at a rate of $h^2$.  This means that if we divide $h$ by $2$ (a smaller perturbation), the first-order error should decrease by a factor of $4$. Similar logic applies to the zeroth-order error (no derivative information) and the second-order error (Hessian information).
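
Written out, the three errors are the standard Taylor-remainder quantities below. We state them in textbook form; the exact norm and bookkeeping inside ```input_derivative_check``` may differ. For vector-valued $f$, the Hessian term is applied to each output component.
\begin{align}
E_0(h) &= \left\|f(\mathbf{x}) - f(\mathbf{x} + h\mathbf{p})\right\| = \mathcal{O}(h),\\
E_1(h) &= \left\|f(\mathbf{x}) + h \frac{\partial f(\mathbf{x})}{\partial \mathbf{x}}^\top \mathbf{p} - f(\mathbf{x} + h\mathbf{p})\right\| = \mathcal{O}(h^2),\\
E_2(h) &= \left\|f(\mathbf{x}) + h \frac{\partial f(\mathbf{x})}{\partial \mathbf{x}}^\top \mathbf{p} + \frac{h^2}{2}\, \mathbf{p}^\top \frac{\partial^2 f(\mathbf{x})}{\partial \mathbf{x}^2} \mathbf{p} - f(\mathbf{x} + h\mathbf{p})\right\| = \mathcal{O}(h^3).
\end{align}
Halving $h$ therefore divides $E_0$, $E_1$, and $E_2$ by roughly $2$, $4$, and $8$, respectively.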

When ```input_derivative_check``` is called with ```verbose=True```, it prints a table whose first column, ```h```, is repeatedly cut in half. The second column, ```E0```, is the zeroth-order error and should be cut in half each time ```h``` is halved. The third column, ```E1```, is the first-order error and should be divided by 4 each time ```h``` is halved.  The last column, ```E2```, is the second-order error and should be divided by 8 every time ```h``` is halved.  We want to see this behavior consistently, but some steps will not match perfectly: for larger ```h```, higher-order Taylor terms still contribute, and for very small ```h```, floating-point round-off dominates.
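
The following minimal sketch re-creates the idea of the Taylor test for a single example and a single random unit direction. It is not the ```hessQuik``` implementation of ```input_derivative_check```; the test function ```f``` (one $\tanh$ layer with closed-form derivatives), the starting step $h = 1$, and the number of halvings are assumptions made for illustration.

```python
import torch

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

d, m = 3, 4
K = torch.randn(d, m)


def f(x):
    """Hypothetical test function: one tanh layer returning value (m,), gradient (d, m), Hessian (d, d, m)."""
    s = torch.tanh(x @ K)
    ds = 1 - s ** 2              # tanh'
    d2s = -2 * s * ds            # tanh''
    return s, ds * K, d2s * K.unsqueeze(1) * K.unsqueeze(0)


x = torch.randn(d)
p = torch.randn(d)
p = p / torch.norm(p)            # unit direction

f0, df0, d2f0 = f(x)
h = 1.0
E0, E1, E2 = [], [], []
for _ in range(10):
    ft = f(x + h * p)[0]                                                 # true value at the perturbed point
    lin = f0 + h * (df0.T @ p)                                           # first-order Taylor model
    quad = lin + 0.5 * h ** 2 * torch.einsum('i,ijk,j->k', p, d2f0, p)   # second-order Taylor model
    E0.append(torch.norm(f0 - ft).item())
    E1.append(torch.norm(lin - ft).item())
    E2.append(torch.norm(quad - ft).item())
    print(f'h={h:.2e}  E0={E0[-1]:.2e}  E1={E1[-1]:.2e}  E2={E2[-1]:.2e}')
    h /= 2

# ratios between the last two rows
print(E0[-2] / E0[-1], E1[-2] / E1[-1], E2[-2] / E2[-1])
```

The printed ratios should be close to 2, 4, and 8 for ```E0```, ```E1```, and ```E2```, which is the same pattern to look for in the ```input_derivative_check``` printout.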




In [None]:
from hessQuik.utils import input_derivative_check, input_derivative_check_finite_difference

my_layer = simpleHessQuikLayerV1(3, 4).to(device)

nex = 10
x = torch.randn(nex, my_layer.dim_input(), device=device)

grad_check, hess_check = input_derivative_check(my_layer, x, do_Hessian=True, verbose=True, forward_mode=True)

# an alternative, but typically slower test using finite differences
# grad_check, hess_check = input_derivative_check_finite_difference(my_layer, x, do_Hessian=True, verbose=True, forward_mode=True)

We can use the same test to verify that a network made from our layers computes derivatives correctly.

In [None]:
import hessQuik.networks as net

my_networkV1 = net.NN(simpleHessQuikLayerV1(3, 4),
                      simpleHessQuikLayerV1(4, 5)).to(device)

nex = 10
x = torch.randn(nex, my_networkV1.dim_input(), device=device)

grad_check, hess_check = input_derivative_check(my_networkV1, x, do_Hessian=True, verbose=True, forward_mode=True)
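
The network test works because each layer applies the chain rule to the incoming ```dudx``` and ```d2ud2x```. Below is a minimal sketch of that composition for two $\tanh$ layers, compared against autograd. ```tanh_layer``` is a hypothetical helper that mirrors the math of ```simpleHessQuikLayerV1```; it is not part of ```hessQuik```, and the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)


def tanh_layer(u, K, dudx=None, d2ud2x=None):
    """Value, gradient, and Hessian of tanh(u @ K) w.r.t. x, given the derivatives of u w.r.t. x."""
    s = torch.tanh(u @ K)
    ds = 1 - s ** 2
    d2s = -2 * s * ds
    dfdx = ds.unsqueeze(1) * K                                                  # (nex, n, m)
    d2fd2x = (d2s.unsqueeze(1) * K).unsqueeze(2) * K.unsqueeze(0).unsqueeze(0)  # (nex, n, n, m)
    if d2ud2x is not None:
        # first chain-rule term: dudx * (layer Hessian) * dudx^T
        d2fd2x = dudx.unsqueeze(1) @ (d2fd2x.permute(0, 3, 1, 2) @ dudx.permute(0, 2, 1).unsqueeze(1))
        d2fd2x = d2fd2x.permute(0, 2, 3, 1)
        # second chain-rule term: second derivatives of u times the layer gradient
        d2fd2x = d2fd2x + d2ud2x @ dfdx.unsqueeze(1)
    if dudx is not None:
        dfdx = dudx @ dfdx
    return s, dfdx, d2fd2x


nex, d, w, m = 4, 3, 5, 2
K1, K2 = torch.randn(d, w), torch.randn(w, m)
x = torch.randn(nex, d)

u, dudx, d2ud2x = tanh_layer(x, K1)                 # first layer: derivatives w.r.t. x directly
f, dfdx, d2fd2x = tanh_layer(u, K2, dudx, d2ud2x)   # second layer: chain rule


def net_single(xi):
    return torch.tanh(torch.tanh(xi @ K1) @ K2)     # (d,) -> (m,)


J = torch.stack([torch.autograd.functional.jacobian(net_single, x[b]) for b in range(nex)])  # (nex, m, d)
H = torch.stack([
    torch.stack([torch.autograd.functional.hessian(lambda v: net_single(v)[j], x[b]) for j in range(m)], dim=-1)
    for b in range(nex)
])  # (nex, d, d, m)

print(torch.allclose(dfdx, J.transpose(1, 2)), torch.allclose(d2fd2x, H))
```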

## Alternative Methods of Computing Derivatives
There are many ways to use broadcasting and other PyTorch parallelism to make our computations more time- and/or storage-efficient.  The implementation above is the one we use in our package, but we provide some alternatives here for completeness.  In certain settings (e.g., CPU vs. GPU, network architecture), one option may be faster than the others.

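Note: for this simple layer, there is not much difference in speed between the various methods. Implementation 2 changes only how the layer's own Hessian is formed (batched outer products of the columns of ```self.K``` instead of pure broadcasting), and implementation 3 changes only how the chain rule is applied (broadcasting followed by ```torch.sum``` instead of batched matrix multiplication).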

In [None]:
class simpleHessQuikLayerV2(simpleHessQuikLayerV1):
  def __init__(self, in_features, out_features):
    super().__init__(in_features, out_features)

  def forward(self, u, do_gradient=False, do_Hessian=False, forward_mode=True, 
              dudx=None, d2ud2x=None):
    (dfdx, d2fd2x) = (None, None)
    f, dsig, d2sig = self.act(u @ self.K, do_gradient=do_gradient, do_Hessian=do_Hessian)
    
    if (do_gradient or do_Hessian):
      dfdx = dsig.unsqueeze(1) * self.K

      if do_Hessian:
        # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
        # alternative to
        # d2fd2x = (d2sig.unsqueeze(1) * self.K).unsqueeze(2) * self.K.unsqueeze(0).unsqueeze(0)
        d2fd2x = (d2sig.unsqueeze(-1).unsqueeze(-1) * (self.K.T.unsqueeze(-1) @ self.K.T.unsqueeze(1)))
        d2fd2x = d2fd2x.permute(0, 2, 3, 1)
        # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

        if d2ud2x is not None:
          # Gauss-Newton approximation
          d2fd2x = dudx.unsqueeze(1) @ (d2fd2x.permute(0, 3, 1, 2) @ dudx.permute(0, 2, 1).unsqueeze(1))
          d2fd2x = d2fd2x.permute(0, 2, 3, 1)
          
          # extra term to compute full Hessian
          d2fd2x += d2ud2x @ dfdx.unsqueeze(1)

    if dudx is not None:
      dfdx = dudx @ dfdx
    return f, dfdx, d2fd2x


my_networkV2 = net.NN(simpleHessQuikLayerV2(3, 4), 
                      simpleHessQuikLayerV2(4, 5)).to(device)

nex = 10
x = torch.randn(nex, my_networkV2.dim_input(), device=device)

grad_check, hess_check = input_derivative_check(my_networkV2, x, do_Hessian=True, verbose=True, forward_mode=True)
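
Implementations 1 and 2 differ only in how the layer's Hessian with respect to its input is formed. The following sketch checks on random data that both expressions agree, and adds a third variant using ```torch.einsum```, which is our own addition and not one of the notebook's implementations. ```d2sig``` is random here as a stand-in for $\sigma''$, since the identity does not depend on its values.

```python
import torch

torch.manual_seed(0)
nex, n, m = 6, 3, 4
K = torch.randn(n, m, dtype=torch.float64)
d2sig = torch.randn(nex, m, dtype=torch.float64)   # stand-in for sigma''(u @ K)

# V1: pure broadcasting, indexed as (example, i, k, j)
H1 = (d2sig.unsqueeze(1) * K).unsqueeze(2) * K.unsqueeze(0).unsqueeze(0)

# V2: batched outer products K[:, j] K[:, j]^T, then move the output index j last
H2 = (d2sig.unsqueeze(-1).unsqueeze(-1) * (K.T.unsqueeze(-1) @ K.T.unsqueeze(1))).permute(0, 2, 3, 1)

# einsum variant (not in the notebook): H[b, i, k, j] = d2sig[b, j] * K[i, j] * K[k, j]
H3 = torch.einsum('bj,ij,kj->bikj', d2sig, K, K)

print(H1.shape, torch.allclose(H1, H2), torch.allclose(H1, H3))
```

The ```einsum``` form states the index pattern explicitly, which can make it easier to check against the formula; whether it is faster than broadcasting depends on the backend and tensor sizes.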

In [None]:
class simpleHessQuikLayerV3(simpleHessQuikLayerV1):
  def __init__(self, in_features, out_features):
    super().__init__(in_features, out_features)

  def forward(self, u, do_gradient=False, do_Hessian=False, forward_mode=True, 
              dudx=None, d2ud2x=None):
    (dfdx, d2fd2x) = (None, None)
    f, dsig, d2sig = self.act(u @ self.K, do_gradient=do_gradient, do_Hessian=do_Hessian)
    
    if (do_gradient or do_Hessian):
      dfdx = dsig.unsqueeze(1) * self.K

      if do_Hessian:
        d2fd2x = (d2sig.unsqueeze(1) * self.K).unsqueeze(2) * self.K.unsqueeze(0).unsqueeze(0)
        
        if d2ud2x is not None:
          # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
          # alternative to
          # # Gauss-Newton approximation
          # d2fd2x = dudx.unsqueeze(1) @ (d2fd2x.permute(0, 3, 1, 2) @ dudx.permute(0, 2, 1).unsqueeze(1))
          # d2fd2x = d2fd2x.permute(0, 2, 3, 1)
          
          # # extra term to compute full Hessian
          # d2fd2x += d2ud2x @ dfdx.unsqueeze(1)

          # Gauss-Newton approximation
          d2fd2x = torch.sum(dudx.unsqueeze(2).unsqueeze(-1) * d2fd2x.unsqueeze(1), dim=3)
          d2fd2x = torch.sum(dudx.unsqueeze(2).unsqueeze(-1) * d2fd2x.unsqueeze(1), dim=3)
          
          # extra term to compute full Hessian
          d2fd2x += torch.sum(dfdx.unsqueeze(1).unsqueeze(1) * d2ud2x.unsqueeze(4), dim=3)
          # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

    if dudx is not None:
      dfdx = dudx @ dfdx

    return f, dfdx, d2fd2x


my_networkV3 = net.NN(simpleHessQuikLayerV3(3, 4), 
                      simpleHessQuikLayerV3(4, 5)).to(device)

nex = 10
x = torch.randn(nex, my_networkV3.dim_input(), device=device)

grad_check, hess_check = input_derivative_check(my_networkV3, x, do_Hessian=True, verbose=True, forward_mode=True)
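
Implementations 1 and 3 differ only in how the chain-rule terms are contracted. This sketch checks, on random tensors with the shapes used above (the sizes are our own choice), that the batched matrix products and the broadcast-then-```torch.sum``` version agree; the ```torch.einsum``` line is an extra variant of our own, not one of the notebook's implementations.

```python
import torch

torch.manual_seed(0)
nex, d, n, m = 5, 2, 3, 4
dudx = torch.randn(nex, d, n, dtype=torch.float64)       # derivative of the incoming features
d2ud2x = torch.randn(nex, d, d, n, dtype=torch.float64)  # second derivative of the incoming features
dfdu = torch.randn(nex, n, m, dtype=torch.float64)       # layer gradient w.r.t. u
H = torch.randn(nex, n, n, m, dtype=torch.float64)       # layer Hessian w.r.t. u

# V1: batched matrix products
G1 = dudx.unsqueeze(1) @ (H.permute(0, 3, 1, 2) @ dudx.permute(0, 2, 1).unsqueeze(1))
G1 = G1.permute(0, 2, 3, 1) + d2ud2x @ dfdu.unsqueeze(1)

# V3: broadcasting followed by torch.sum
G3 = torch.sum(dudx.unsqueeze(2).unsqueeze(-1) * H.unsqueeze(1), dim=3)
G3 = torch.sum(dudx.unsqueeze(2).unsqueeze(-1) * G3.unsqueeze(1), dim=3)
G3 = G3 + torch.sum(dfdu.unsqueeze(1).unsqueeze(1) * d2ud2x.unsqueeze(4), dim=3)

# einsum variant (not in the notebook)
G4 = (torch.einsum('bai,bikj,bck->bacj', dudx, H, dudx)
      + torch.einsum('baci,bij->bacj', d2ud2x, dfdu))

print(G1.shape, torch.allclose(G1, G3), torch.allclose(G1, G4))
```

One trade-off to keep in mind: the ```torch.sum``` version materializes 5-dimensional intermediates, for example of shape ```(nex, d, n, n, m)``` in the first contraction, while the batched matrix products never form them, so memory use can differ noticeably for wide layers.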

## Computational Efficiency of Derivative Computation

Let's compare the time needed to compute derivatives with networks built from these various layers.  Feel free to experiment with different network parameters.

In [None]:
import time

# parameters to play with
in_features = 2     # number of input features
width = 10          # width of network
depth = 4           # number of hidden (width x width) layers
out_features = 1    # number of output features
nex = 10            # number of examples


# helper functions
def create_network(layer, in_features, out_features, width, depth):

  args = (layer(in_features, width),)
  for _ in range(depth):
    args += (layer(width, width),)
  args += (layer(width, out_features),)

  test_net = net.NN(*args)
  return test_net


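def run_test(test_net, x, num_trials=10):
  # warm-up call so one-time setup costs (e.g., CUDA initialization) are not timed
  test_net(x, do_gradient=True, do_Hessian=True, forward_mode=True)

  total_time = 0.0
  for _ in range(num_trials):
    if x.is_cuda:
      torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before starting the clock
    t1_start = time.perf_counter()
    f0, df0, d2f0 = test_net(x, do_gradient=True, do_Hessian=True, forward_mode=True)
    if x.is_cuda:
      torch.cuda.synchronize()  # wait for all kernels to finish before stopping the clock
    t1_end = time.perf_counter()
    total_time += t1_end - t1_start

  return total_time / num_trials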


# inputs
x = torch.randn(nex, in_features, device=device)

# timing tests
test_net = create_network(simpleHessQuikLayerV1, in_features, out_features, width, depth)
test_net = test_net.to(device)
time1 = run_test(test_net, x)
print('Implementation 1: %0.4f' % time1)

test_net = create_network(simpleHessQuikLayerV2, in_features, out_features, width, depth).to(device)
time2 = run_test(test_net, x)
print('Implementation 2: %0.4f' % time2)

test_net = create_network(simpleHessQuikLayerV3, in_features, out_features, width, depth).to(device)
time3 = run_test(test_net, x)
print('Implementation 3: %0.4f' % time3)
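
As an alternative to a hand-written timing loop, PyTorch provides ```torch.utils.benchmark.Timer```, which, to our understanding, runs a short warm-up and synchronizes CUDA before reading the clock. Below is a minimal sketch that times a standalone function, ```layer_derivatives```, a hypothetical helper mirroring the broadcasting in ```simpleHessQuikLayerV1```, rather than a ```hessQuik``` network; the sizes are arbitrary.

```python
import torch
import torch.utils.benchmark as benchmark


def layer_derivatives(u, K):
    """Value, gradient, and Hessian of tanh(u @ K) using the broadcasting from simpleHessQuikLayerV1."""
    s = torch.tanh(u @ K)
    ds = 1 - s ** 2
    d2s = -2 * s * ds
    dfdu = ds.unsqueeze(1) * K
    d2fdu2 = (d2s.unsqueeze(1) * K).unsqueeze(2) * K.unsqueeze(0).unsqueeze(0)
    return s, dfdu, d2fdu2


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
u = torch.randn(10, 10, device=device)
K = torch.randn(10, 10, device=device)

timer = benchmark.Timer(
    stmt='layer_derivatives(u, K)',
    globals={'layer_derivatives': layer_derivatives, 'u': u, 'K': K},
)
measurement = timer.timeit(100)   # run the statement 100 times
print(f'mean time per call: {measurement.mean:.2e} s')
```

To time one of the networks above instead, one could pass ```test_net``` and ```x``` through ```globals``` and use the statement ```test_net(x, do_gradient=True, do_Hessian=True, forward_mode=True)```.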