# Testing Layers

As you become more comfortable with neural networks, you will come up with new ideas and new layers to try. 

### Outcomes 
In this tutorial, you will


*   learn how to construct a new neural network layer
*   learn how to test your new implementation

### Suggested Activities
* Check out this tutorial on [Extending Autograd](https://pytorch.org/docs/stable/notes/extending.html) and try it yourself!


Check out the [PyTorch Documentation](https://pytorch.org/docs/stable/index.html) as you explore!








## Step 1: Import Packages

We start by installing and importing the packages needed to run our code:

   * deep learning toolbox [PyTorch](https://pytorch.org/)
   * visualization toolbox [Matplotlib](https://matplotlib.org/)
   * DNN101 repository [https://github.com/elizabethnewman/dnn101](https://github.com/elizabethnewman/dnn101)

In [None]:
!python -m pip install git+https://github.com/elizabethnewman/dnn101.git


In [5]:
import dnn101
import torch
import torch.nn as nn
import matplotlib as mpl
import matplotlib.pyplot as plt

## Step 2: Create a New Layer

We will create a special residual layer of the form
\begin{align*}
\mathbf{z} = \mathbf{u} + \tanh(\mathbf{K}\mathbf{u})
\end{align*}

We base our implementation on the [PyTorch source code for the Linear layer](https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear).
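For reference, `nn.Linear` stores its weight with shape `(out_features, in_features)`, treats each row of the input as one sample, and computes `x @ weight.T + bias`. The small check below (the sizes are arbitrary choices of ours) confirms this convention, which is why our layers multiply by the transposed weight.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# nn.Linear stores its weight as (out_features, in_features)
lin = nn.Linear(in_features=4, out_features=2, dtype=torch.float64)
print(lin.weight.shape)  # torch.Size([2, 4])

# samples are rows: x has shape (N, in_features)
x = torch.randn(5, 4, dtype=torch.float64)

# the forward pass is x W^T + b, so the weight acts from the right
manual = x @ lin.weight.T + lin.bias
print(torch.allclose(lin(x), manual))  # True
```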

We are also going to write our own backward routine. During backpropagation, the layer receives a vector $\mathbf{v}$ (the gradient of the loss with respect to $\mathbf{z}$) and must return its product with the transposed Jacobian. For the features, this is
\begin{align*}
\left(\frac{\partial \mathbf{z}}{\partial \mathbf{u}}\right)^\top \mathbf{v} &= \mathbf{v} + \mathbf{K}^\top (\tanh'(\mathbf{K}\mathbf{u}) \odot \mathbf{v})\\
&=\mathbf{v} + \mathbf{K}^\top ((\mathbf{1} - \tanh^2(\mathbf{K}\mathbf{u})) \odot \mathbf{v})
\end{align*}
where $\odot$ is the Hadamard (pointwise) product and $\tanh'$ and $\tanh^2$ are applied entrywise.
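As a quick sanity check of this formula (our own addition), the sketch below takes a single sample stored as a column vector, as in the equations, and compares the hand-derived expression with PyTorch's vector-Jacobian product `torch.autograd.functional.vjp`. The size and random data are arbitrary.

```python
import torch
from torch.autograd.functional import vjp

torch.manual_seed(0)
n = 4
K = torch.randn(n, n, dtype=torch.float64)
u = torch.randn(n, dtype=torch.float64)
v = torch.randn(n, dtype=torch.float64)  # incoming gradient dL/dz

def z_of_u(u):
    # one sample as a column vector: z = u + tanh(K u)
    return u + torch.tanh(K @ u)

# hand-derived formula: v + K^T ((1 - tanh^2(K u)) * v)
by_hand = v + K.T @ ((1 - torch.tanh(K @ u) ** 2) * v)

# PyTorch's vector-Jacobian product (dz/du)^T v
_, by_autograd = vjp(z_of_u, u, v)

print(torch.allclose(by_hand, by_autograd))  # True
```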

We also want to compute the gradient with respect to the weights. Applying the transposed Jacobian with respect to $\mathbf{K}$ to $\mathbf{v}$ gives a matrix of the same size as $\mathbf{K}$:
\begin{align*}
\left(\frac{\partial\mathbf{z}}{\partial \mathbf{K}}\right)^\top \mathbf{v} = ((\mathbf{1} - \tanh^2(\mathbf{K}\mathbf{u}))\odot \mathbf{v})\mathbf{u}^\top
\end{align*}
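The weight gradient is an outer product. For a single column-vector sample, backpropagating $\mathbf{v}$ through the forward pass should leave exactly this matrix in `K.grad`; the sketch below (again with arbitrary sizes and data) checks that.

```python
import torch

torch.manual_seed(0)
n = 4
K = torch.randn(n, n, dtype=torch.float64, requires_grad=True)
u = torch.randn(n, dtype=torch.float64)
v = torch.randn(n, dtype=torch.float64)  # incoming gradient dL/dz

# forward pass for one sample: z = u + tanh(K u)
z = u + torch.tanh(K @ u)

# backpropagate v; K.grad now holds the weight gradient, shaped like K
z.backward(v)

# hand-derived formula: ((1 - tanh^2(K u)) * v) u^T
with torch.no_grad():
    by_hand = torch.outer((1 - torch.tanh(K @ u) ** 2) * v, u)

print(torch.allclose(by_hand, K.grad))  # True
```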

Note that in PyTorch, the data is stored as $(N,H_{\text{in}})$, where $N$ is the number of samples and $H_{\text{in}}$ is the number of input features per sample. Because of the residual connection, our layer returns features of the same size, $(N,H_{\text{in}})$. Since each sample is a row rather than a column, the operations above are transposed and applied from the right in the implementation; for example, each sample is mapped to $\mathbf{u}^\top + \tanh(\mathbf{u}^\top\mathbf{K}^\top)$.
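To make the row convention concrete, write $\mathbf{U}$ for the $(N, H_{\text{in}})$ matrix whose rows are the samples, $\mathbf{V}$ for the incoming gradient of the same size, and $\mathbf{A} = (\mathbf{1} - \tanh^2(\mathbf{U}\mathbf{K}^\top)) \odot \mathbf{V}$ (this notation is ours). The forward pass becomes $\mathbf{U} + \tanh(\mathbf{U}\mathbf{K}^\top)$, and the two gradients become $\mathbf{V} + \mathbf{A}\mathbf{K}$ and $\mathbf{A}^\top\mathbf{U}$, which are the expressions used in the `backward` method of `SpecialResidualFunction` below. A minimal check against autograd, with arbitrary sizes:

```python
import torch

torch.manual_seed(0)
N, n = 6, 4
U = torch.randn(N, n, dtype=torch.float64, requires_grad=True)  # samples as rows
K = torch.randn(n, n, dtype=torch.float64, requires_grad=True)
V = torch.randn(N, n, dtype=torch.float64)  # incoming gradient, same shape as Z

# row convention: Z = U + tanh(U K^T)
Z = U + torch.tanh(U @ K.T)

# the same result, computed one column-vector sample at a time
Z_cols = torch.stack([U[i] + torch.tanh(K @ U[i]) for i in range(N)])
print(torch.allclose(Z, Z_cols))  # True

# transposed formulas, as used in a hand-written backward
with torch.no_grad():
    A = (1 - torch.tanh(U @ K.T) ** 2) * V
    grad_U = V + A @ K
    grad_K = A.T @ U

# compare with autograd
Z.backward(V)
print(torch.allclose(grad_U, U.grad), torch.allclose(grad_K, K.grad))  # True True
```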

In [45]:
import torch
import torch.nn as nn
import math

# here is a simple way to implement this layer using automatic differentiation
class SpecialResidualLayer(nn.Module):
    def __init__(self, in_features: int, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.weight = nn.Parameter(torch.empty((in_features, in_features), **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + torch.tanh(x @ self.weight.T)

    def extra_repr(self) -> str:
        # this layer has no `activation` attribute, so only report in_features
        return 'in_features={}'.format(self.in_features)


# here is a way to implement your layer with a self-made backward
class SpecialResidualFunction(torch.autograd.Function):

    @staticmethod
    def forward(u, K):
        # samples are stored as rows (PyTorch convention), so K acts from the right: u K^T
        z = u + torch.tanh(u.mm(K.t()))
        return z

    @staticmethod
    def setup_context(ctx, inputs, output):
        u, K = inputs
        ctx.save_for_backward(u, K)

    @staticmethod
    def backward(ctx, v):
        u, K = ctx.saved_tensors
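        grad_u = grad_K = None
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
            # (1 - tanh^2(u K^T)) * v, shared by both gradients
            grad_a = (1 - torch.tanh(u @ K.t()) ** 2) * v

        if ctx.needs_input_grad[0]:
            # row-wise version of v + K^T (tanh'(K u) * v)
            grad_u = v + (grad_a).mm(K)
        if ctx.needs_input_grad[1]:
            # row-wise version of (tanh'(K u) * v) u^T, summed over samples
            grad_K = (grad_a.t()).mm(u)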

        return grad_u, grad_K


class SpecialResidualLayer2(nn.Module):
    def __init__(self, in_features: int, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.weight = nn.Parameter(torch.empty((in_features, in_features), **factory_kwargs))
        self.reset_parameters()


    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return SpecialResidualFunction.apply(x, self.weight)

    def extra_repr(self) -> str:
        return 'in_features={}'.format(
            self.in_features
        )


## Step 3: Test the Layer




### Option 1: Use `torch.autograd.gradcheck`

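`torch.autograd.gradcheck` compares the gradients returned by our `backward` with numerical estimates computed by finite differences. Its default tolerances are designed for double-precision inputs, so we create the test tensors as `torch.float64`.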
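To see what the finite-difference part means, here is a hand-rolled version of the idea for the weight gradient: perturb one entry of $\mathbf{K}$ at a time by $\pm\epsilon$, form a central difference of a scalar loss, and compare with autograd. This is a simplified sketch of the idea rather than `gradcheck`'s actual implementation; the scalar loss built from fixed random weights `w` is an arbitrary choice of ours.

```python
import torch

torch.manual_seed(0)
u = torch.randn(5, 3, dtype=torch.float64)
K = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
w = torch.randn(5, 3, dtype=torch.float64)  # fixed weights that reduce the output to a scalar

def loss(K):
    # scalar loss built from the layer output z = u + tanh(u K^T)
    return (w * (u + torch.tanh(u @ K.T))).sum()

# analytical gradient from autograd
loss(K).backward()

# numerical gradient from central differences, one entry of K at a time
eps = 1e-6
fd = torch.zeros_like(K)
with torch.no_grad():
    for i in range(K.shape[0]):
        for j in range(K.shape[1]):
            E = torch.zeros_like(K)
            E[i, j] = eps
            fd[i, j] = (loss(K + E) - loss(K - E)) / (2 * eps)

print(torch.allclose(fd, K.grad, atol=1e-6))  # True
```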

In [None]:
# check our custom backward against finite differences (gradcheck expects float64 inputs)
func = SpecialResidualFunction.apply

u = torch.randn(11, 3, dtype=torch.float64, requires_grad=True)
K = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
torch.autograd.gradcheck(func, (u, K))
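A check is only useful if it can fail. The sketch below runs `gradcheck` on a deliberately broken copy of the function whose backward drops the identity term $\mathbf{v}$ from the input gradient. `BuggyResidualFunction` is a hypothetical name used only for illustration, and it uses the older `forward(ctx, ...)` style of `torch.autograd.Function`, which PyTorch still accepts. Passing `raise_exception=False` makes `gradcheck` return `False` instead of raising an error.

```python
import torch

class BuggyResidualFunction(torch.autograd.Function):
    # hypothetical, deliberately wrong variant used only to show a failing check

    @staticmethod
    def forward(ctx, u, K):
        ctx.save_for_backward(u, K)
        return u + torch.tanh(u.mm(K.t()))

    @staticmethod
    def backward(ctx, v):
        u, K = ctx.saved_tensors
        grad_a = (1 - torch.tanh(u.mm(K.t())) ** 2) * v
        # BUG: the input gradient should be v + grad_a.mm(K)
        return grad_a.mm(K), grad_a.t().mm(u)

torch.manual_seed(0)
u = torch.randn(11, 3, dtype=torch.float64, requires_grad=True)
K = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)

ok = torch.autograd.gradcheck(BuggyResidualFunction.apply, (u, K), raise_exception=False)
print(ok)  # False
```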

### Option 2: Use Taylor series

We use a Taylor approximation to test the derivative with respect to the weights; the same test can be applied to the input features.

Suppose we have a smooth, scalar-valued objective function $f$. Expanding $f$ about the weights $\boldsymbol{\theta}$ in a fixed direction $\mathbf{v}$ with a Taylor series gives
\begin{align*}
f(\boldsymbol{\theta} + h \mathbf{v}) = f(\boldsymbol{\theta}) + h\langle \nabla f(\boldsymbol{\theta}), \mathbf{v}\rangle + \mathcal{O}(h^2)
\end{align*}

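If we compute the gradient correctly, then as $h\to 0$, the absolute error of the linear approximation
\begin{align*}
| f(\boldsymbol{\theta}) + h\langle \nabla f(\boldsymbol{\theta}), \mathbf{v}\rangle - f(\boldsymbol{\theta} + h \mathbf{v})|
\end{align*}
will decay on the order of $h^2$, while the error of the constant approximation, $|f(\boldsymbol{\theta}) - f(\boldsymbol{\theta} + h \mathbf{v})|$, decays only on the order of $h$. In the code below these are `E1` and `E0`: each time $h$ is halved, `E1` should shrink by about a factor of 4 and `E0` by about a factor of 2. If the gradient is wrong, `E1` also decays only like $h$. The second-order decay of `E1` is what we want to observe for our layer.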
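The expected rates are easy to see on a function whose gradient we know in closed form. In the sketch below, $f(\boldsymbol{\theta}) = \sum_i \exp((\mathbf{A}\boldsymbol{\theta})_i)$ for a random matrix $\mathbf{A}$ is an arbitrary test function of ours, with gradient $\mathbf{A}^\top \exp(\mathbf{A}\boldsymbol{\theta})$. The $\log_2$ of successive error ratios should approach 1 for the zeroth-order error and 2 for the first-order error.

```python
import math
import torch

torch.manual_seed(0)
n = 5
A = torch.randn(n, n, dtype=torch.float64)

def f(theta):
    # smooth scalar-valued test function
    return torch.exp(A @ theta).sum()

def grad_f(theta):
    # its gradient, derived by hand: A^T exp(A theta)
    return A.T @ torch.exp(A @ theta)

theta = torch.randn(n, dtype=torch.float64)
v = torch.randn(n, dtype=torch.float64)  # perturbation direction
f0 = f(theta).item()
slope = torch.dot(grad_f(theta), v).item()  # <grad f(theta), v>

E0, E1 = [], []
for k in range(10):
    h = 2.0 ** (-k)
    fh = f(theta + h * v).item()
    E0.append(abs(f0 - fh))              # zeroth-order error
    E1.append(abs(f0 + h * slope - fh))  # first-order error

# log2 of successive ratios: should approach 1 for E0 and 2 for E1
r0 = [math.log2(E0[k] / E0[k + 1]) for k in range(len(E0) - 1)]
r1 = [math.log2(E1[k] / E1[k + 1]) for k in range(len(E1) - 1)]
print([round(r, 2) for r in r0])
print([round(r, 2) for r in r1])
```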

In [None]:
from copy import deepcopy
from dnn101.utils import convert_to_base

# use double precision so that round-off does not mask the Taylor errors
torch.set_default_dtype(torch.float64)

# set seed for reproducibility
torch.manual_seed(42)

# layer or network to test
in_features = 3
layer = SpecialResidualLayer2(in_features)

# create data and forward propagate
x = torch.randn(11, in_features, requires_grad=True) # this will pass through our backward
y = layer(x)

# choose loss function (any will do!)
loss = nn.MSELoss()

# evaluate the loss without perturbations
y_true = torch.randn_like(y)
out = loss(y, y_true)

# compute gradients
out.backward()

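# choose variable to test: True for the weights, False for the input features
perturb_w = True

if perturb_w:
    # store the unperturbed weights and their gradient (every torch.Tensor has .data and .grad)
    w, dw = deepcopy(layer.weight.data), deepcopy(layer.weight.grad)
else:
    # store the unperturbed input features and their gradient
    x, dx = deepcopy(x.data), deepcopy(x.grad)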


with torch.no_grad():
    # form a random perturbation p and its inner product with the gradient
    if perturb_w:
        p = torch.randn_like(dw)
        pgrad = (p * dw).sum()
    else:
        p = torch.randn_like(dx)
        pgrad = (p * dx).sum()

    # MAIN ITERATION
    headers = ('h', 'E0', 'E1')
    print(('{:<20s}' * len(headers)).format(*headers))

    num_test = 15
    E0, E1 = torch.zeros(num_test), torch.zeros(num_test)
    for k in range(num_test):
        # step size
        h = 2.0 ** (-k)

        # perturb weights (or inputs) and forward propagate
        if perturb_w:
            layer.weight = nn.Parameter(w + h * p)
            y_h = layer(x)
        else:
            y_h = layer(x + h * p)

        # evaluate
        out_h = loss(y_h, y_true)

        # compute zeroth- and first-order Taylor errors
        err0 = torch.norm(out - out_h)
        err1 = torch.norm(out + h * pgrad - out_h)

        printouts = convert_to_base((err0, err1))
        print(((1 + len(printouts) // 2) * '%0.2f x 2^(%0.2d)\t\t') % ((1, -k) + printouts))

        E0[k] = err0.item()
        E1[k] = err1.item()

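    # pass if E1 shrinks by about a factor of 4 (log2 ratio near 2) more than 3 times,
    # or if at least num_test // 3 of the errors are already at round-off level
    tol = 0.1
    eps = torch.finfo(x.dtype).eps
    grad_check = (sum((torch.log2(E1[:-1] / E1[1:])) > (2 - tol)) > 3)
    grad_check = (grad_check or (torch.kthvalue(E1, num_test // 3)[0] < (100 * eps)))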

    if grad_check:
        print('Gradient PASSED!')
    else:
        print('Gradient FAILED.')
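The same test is what exposes a wrong derivative. The sketch below applies it to the input features of the layer, using the loss $\frac{1}{2}\|\mathbf{u} + \tanh(\mathbf{u}\mathbf{K}^\top)\|^2$ (an arbitrary choice) and two hand-written gradients: a correct one and a buggy one that drops the identity term of the residual connection. `grad_loss` and `taylor_errors` are hypothetical helpers of ours, not part of `dnn101`. The $\log_2$ ratios of successive `E1` values should approach 2 for the correct gradient but only 1 for the buggy one.

```python
import math
import torch

torch.manual_seed(0)
u = torch.randn(11, 3, dtype=torch.float64)
K = torch.randn(3, 3, dtype=torch.float64)

def loss(u):
    # scalar loss 1/2 ||z||^2 with z = u + tanh(u K^T)
    z = u + torch.tanh(u @ K.T)
    return 0.5 * (z ** 2).sum()

def grad_loss(u, keep_identity=True):
    # dL/dz = z, so dL/du = z + ((1 - tanh^2(u K^T)) * z) K
    z = u + torch.tanh(u @ K.T)
    a = (1 - torch.tanh(u @ K.T) ** 2) * z
    # the buggy version forgets the identity term z
    return z + a @ K if keep_identity else a @ K

def taylor_errors(grad, p, num_test=13):
    # first-order Taylor errors E1 for step sizes h = 2^0, 2^-1, ..., 2^-(num_test-1)
    f0, slope = loss(u).item(), (grad * p).sum().item()
    return [abs(f0 + 2.0 ** (-k) * slope - loss(u + 2.0 ** (-k) * p).item())
            for k in range(num_test)]

p = torch.randn_like(u)  # random perturbation direction
rates = {}
for name, g in [("correct", grad_loss(u)), ("buggy", grad_loss(u, keep_identity=False))]:
    E1 = taylor_errors(g, p)
    rates[name] = [math.log2(E1[k] / E1[k + 1]) for k in range(len(E1) - 1)]
    print(name, [round(r, 2) for r in rates[name][-3:]])
```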