In [None]:
# Using torch as the underlying tensor library so that we can easily check if
# our gradients match the correct gradients.

!pip3 install torch
import torch
print(torch.__version__)

# Design

The goal is to build an eager-mode, backward-mode (more commonly called reverse-mode) autodiff engine, similar to PyTorch's autograd. This will be an ultra bare-bones implementation for pedagogical purposes.

## Terminology

Eager-mode means that for each computation of outputs and gradients we build a new graph (this makes it easy to create dynamic graph structures, such as loops, using the host language, Python). Backward-mode means we calculate gradients of parameters with a backward pass through the computation graph, from output to leaves. This is what "backprop" refers to. Autodiff ([automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)), distinct from [symbolic differentiation](https://en.wikipedia.org/wiki/Symbolic_differentiation "Symbolic differentiation") and [numerical differentiation](https://en.wikipedia.org/wiki/Numerical_differentiation "Numerical differentiation"), uses the structure of a computation graph, together with the known derivative of each operation in it, to calculate exact derivatives at a given input. The result is a number rather than a symbolic expression, and it is not a numerical estimate but an exact calculation up to floating-point rounding error.
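To make the contrast with numerical differentiation concrete, here is a small pure-Python sketch that is not part of the engine (the function `f`, the point `x0` and the step size `h` are arbitrary choices). The chain-rule value is the kind of exact number autodiff produces, while the central finite difference is only an estimate whose error depends on `h`.

```python
import math

def f(x):
    # Example function: f(x) = sin(x^2)
    return math.sin(x * x)

def df_chain_rule(x):
    # Exact derivative via the chain rule: cos(x^2) * 2x.
    # Autodiff yields this value (up to rounding) without building a symbolic expression.
    return math.cos(x * x) * 2 * x

def df_finite_difference(x, h=1e-5):
    # Numerical differentiation: central difference with O(h^2) truncation error.
    return (f(x + h) - f(x - h)) / (2 * h)

x0 = 1.3
exact = df_chain_rule(x0)
approx = df_finite_difference(x0)
print(exact, approx, abs(exact - approx))
```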

## Design Of An Autodiff Engine

Let's assume we have access to some pre-existing tensor library, i.e. a library that supports multi-dimensional array operations. In this demo, we are using PyTorch, which has the benefit of also being an autodiff engine, so we can easily compare our calculated gradients to the correct gradients for testing purposes.

Suppose `x` is some tensor and `F` is a function that takes in a tensor and outputs a tensor. Then `y = F(x)` is a tensor. However, we also want to record that we called `F` on `x` to produce `y`. So let's have `F` return its output tensor wrapped in a graph node instance, which stores `F` and points to `x` as a child. Then `y` is this node object wrapping the output tensor.

Suppose `G` is another such tensor function so that `z = G(y)` is a node instance containing the result of the computation `G(F(x))` and also points to the node `y` as its child.
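A minimal sketch of this wrapping, using plain Python lists instead of torch tensors. The `Node` class and the ops `F` (elementwise square) and `G` (sum) are hypothetical stand-ins chosen for illustration; the engine's actual node class, `ComputeNode`, is defined in the Implementation section.

```python
class Node:
    """Hypothetical minimal graph node: the output value, the op that produced it,
    and the child nodes (inputs) it was computed from."""
    def __init__(self, output, op=None, children=()):
        self.output = output
        self.op = op
        self.children = tuple(children)

def leaf(value):
    # Leaf nodes wrap raw inputs: no op and no children.
    return Node(value)

def F(node):
    # Elementwise square; records itself as the op and its input node as a child.
    return Node([v * v for v in node.output], op=F, children=[node])

def G(node):
    # Sum of all elements, giving a scalar; again records op and child.
    return Node(sum(node.output), op=G, children=[node])

x = leaf([1.0, 2.0, 3.0])
y = F(x)   # node wrapping [1.0, 4.0, 9.0], child: x
z = G(y)   # node wrapping 14.0, child: y
print(z.output)  # 14.0
```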

In general we might be interested in the Jacobian of `z` w.r.t. `x`, i.e. the tensor of all partial derivatives of elements in `z` w.r.t. elements in `x`, evaluated at `x`. However, we will restrict ourselves to calculating gradients where `z` is a scalar (a 0-dimensional tensor). This is the typical use case in machine learning, where `z` is a loss, and it is also computationally simpler to deal with.

When `z` is a scalar, the Jacobian of `z` w.r.t. `x` is the tensor of partial derivatives of `z` w.r.t. each element of `x`, called the gradient. This gradient has the same shape as `x`.
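As a quick illustration of the shape claim, the following sketch estimates the gradient of a scalar function of a 2x3 input by central finite differences. `numerical_gradient` and `z_of` are hypothetical helpers used only for this check. The gradient comes back as a 2x3 nested list, one partial derivative per input element.

```python
def numerical_gradient(fn, x, h=1e-6):
    """Central-difference gradient of a scalar-valued fn w.r.t. a 2-D nested list x.
    The result has the same shape as x: one partial derivative per element."""
    grad = []
    for i, row in enumerate(x):
        grad_row = []
        for j in range(len(row)):
            plus = [r[:] for r in x]
            minus = [r[:] for r in x]
            plus[i][j] += h
            minus[i][j] -= h
            grad_row.append((fn(plus) - fn(minus)) / (2 * h))
        grad.append(grad_row)
    return grad

def z_of(x):
    # Scalar output: sum of squares of all elements.
    return sum(v * v for row in x for v in row)

x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
g = numerical_gradient(z_of, x)
# Analytically dz/dx_ij = 2 * x_ij, so g should be close to [[2, 4, 6], [8, 10, 12]].
```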


### Tensor Chain Rule

Let $f$ and $g$ be tensor-valued functions.  
The *shape* of a tensor is a tuple of numbers, each referring to the number of positions in each tensor dimension.  
Let $J$ be the Jacobian operator, so that e.g. $J[f]$ is the Jacobian of $f$ w.r.t. its input.

Suppose,  
$f$ has input shape $(n_1,\dots,n_r)$ and output shape $(m_1,\dots,m_s)$, and    
$g$ has input shape $(m_1,\dots,m_s)$ and output shape $(p_1,\dots,p_t)$.


$J[f]$ has shape $(m_1,\dots,m_s \mid n_1,\dots,n_r)$ (tensor consisting of partial derivatives of every output element w.r.t. every input element).    
$J[g]$ has shape $(p_1,\dots,p_t \mid m_1,\dots,m_s)$.  
$J[g\circ f]$ must have shape $(p_1,\dots,p_t \mid n_1,\dots,n_r)$.  
Here I am distinguishing between input and output indices in the Jacobian with a middle bar.

Now we can write the tensor chain rule:  
Letting lower indices be input indices, and upper indices be output indices, we have
$$J[g\circ f]_{i_1,\dots,i_r}^{j_1,\dots,j_t} = \sum_{k_1,\dots,k_s} J[g]_{k_1,\dots,k_s}^{j_1,\dots,j_t}J[f]_{i_1,\dots,i_r}^{k_1,\dots,k_s}$$

Or written out in terms of partial derivatives, $\newcommand{\pd}{\partial}\newcommand{\pdiff}[2]{\frac{\pd{#1}}{\pd{#2}}}$

$$
\pdiff{(g(f(x))^{j_1,\dots,j_t})}{(x_{i_1,\dots,i_r})} = \sum_{k_1,\dots,k_s} \pdiff{(g(f(x))^{j_1,\dots,j_t})}{(f(x)^{k_1,\dots,k_s})}\pdiff{(f(x)^{k_1,\dots,k_s})}{(x_{i_1,\dots,i_r})}
$$
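A small numerical check of the tensor chain rule in pure Python. The functions are arbitrary examples chosen for this sketch: `f` maps a tensor of shape $(2,2)$ to shape $(3)$ and `g` maps shape $(3)$ to shape $(2)$. Contracting their hand-derived Jacobians over the shared index $k$ should match a finite-difference Jacobian of $g\circ f$, which has shape $(2 \mid 2,2)$.

```python
import math

def f(X):
    # f: shape (2, 2) -> shape (3,)
    return [X[0][0] * X[1][1], X[0][1] + X[1][0], X[0][0] ** 2]

def g(u):
    # g: shape (3,) -> shape (2,)
    return [u[0] + u[1] * u[2], math.sin(u[0])]

def jac_f(X):
    # J[f][k][i1][i2] = d f(X)^k / d X_{i1,i2}, shape (3 | 2, 2)
    J = [[[0.0, 0.0], [0.0, 0.0]] for _ in range(3)]
    J[0][0][0] = X[1][1]
    J[0][1][1] = X[0][0]
    J[1][0][1] = 1.0
    J[1][1][0] = 1.0
    J[2][0][0] = 2 * X[0][0]
    return J

def jac_g(u):
    # J[g][j][k] = d g(u)^j / d u^k, shape (2 | 3)
    return [[1.0, u[2], u[1]],
            [math.cos(u[0]), 0.0, 0.0]]

def chain_rule(X):
    # Tensor chain rule: sum over the shared index k.
    Jf, Jg = jac_f(X), jac_g(f(X))
    return [[[sum(Jg[j][k] * Jf[k][i1][i2] for k in range(3))
              for i2 in range(2)] for i1 in range(2)] for j in range(2)]

def finite_difference_jacobian(X, h=1e-6):
    # Central differences of g(f(X)) for comparison, shape (2 | 2, 2).
    J = [[[0.0, 0.0], [0.0, 0.0]] for _ in range(2)]
    for i1 in range(2):
        for i2 in range(2):
            Xp = [row[:] for row in X]
            Xm = [row[:] for row in X]
            Xp[i1][i2] += h
            Xm[i1][i2] -= h
            gp, gm = g(f(Xp)), g(f(Xm))
            for j in range(2):
                J[j][i1][i2] = (gp[j] - gm[j]) / (2 * h)
    return J

X = [[0.5, -1.0], [2.0, 1.5]]
exact = chain_rule(X)
approx = finite_difference_jacobian(X)
```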

### Design of backward-mode differentiation

From the tensor chain rule, we can see how autodiff would work. Each graph node, in addition to storing its input children, output, and the function called, could also store a Jacobian.

However, producing explicit Jacobians and summing over them is often unnecessarily costly, since they tend to be large and sparse: a Jacobian has one entry per (output element, input element) pair, so e.g. a matrix-to-matrix function has a 4-tensor Jacobian, and most of its entries are usually 0. The conventional solution to this problem is to have each graph node provide a `backward` function which takes in a gradient tensor (w.r.t. that node's output tensor) and produces a new gradient tensor (w.r.t. that node's input tensor), performing whatever optimizations are possible under the hood.

Thus this autodiff engine design only supports gradients, i.e. derivatives w.r.t. a scalar output, and does not in general produce explicit Jacobian tensors. But this is sufficient for our purposes. While Jacobian functions are provided in libraries like TensorFlow and Torch, they are not often invoked in deep learning applications.
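To see why a `backward` function is cheaper than an explicit Jacobian, consider an elementwise op such as `tanh` (an arbitrary example for this sketch). Its Jacobian is an $n\times n$ matrix that is zero off the diagonal, so multiplying a gradient by it reduces to an elementwise product, which the backward function computes directly in $O(n)$ time.

```python
import math

def tanh_explicit_jacobian(x):
    # J[i][k] = d tanh(x_i) / d x_k: an n x n matrix that is zero off the diagonal.
    n = len(x)
    return [[(1 - math.tanh(x[i]) ** 2) if i == k else 0.0 for k in range(n)]
            for i in range(n)]

def tanh_backward(x, grad_out):
    # Same vector-Jacobian product without building J: O(n) work instead of O(n^2).
    return [g * (1 - math.tanh(v) ** 2) for v, g in zip(x, grad_out)]

x = [-1.0, 0.0, 0.5, 2.0]
grad_out = [0.3, -1.2, 2.0, 0.7]  # gradient of some scalar w.r.t. tanh's output
J = tanh_explicit_jacobian(x)
n = len(x)
# Row vector times explicit Jacobian: (grad_out J)_k = sum_i grad_out[i] * J[i][k].
via_jacobian = [sum(grad_out[i] * J[i][k] for i in range(n)) for k in range(n)]
via_backward = tanh_backward(x, grad_out)
```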

#### Vector-Jacobian Product (VJP)

When $g\circ f$ outputs a scalar, then the tensor chain rule reduces to 

$$J[g\circ f]_{i_1,\dots,i_r} = \sum_{k_1,\dots,k_s} J[g]_{k_1,\dots,k_s}J[f]_{i_1,\dots,i_r}^{k_1,\dots,k_s}$$

where $J[f]$ has shape $(m_1,\dots,m_s \mid n_1,\dots,n_r)$ and $J[g]$ has shape $(1 \mid m_1,\dots,m_s)$.

If we flatten these tensors so that $J[f]$ has shape $(m \mid n)$ and $J[g]$ has shape $(1 \mid m)$, then the chain rule becomes a matrix product:

$$J[g\circ f] = J[g]J[f]$$

where $J[f]$ is an $m\times n$ matrix and $J[g]$ is a $1\times m$ matrix (a row vector). Hence we have a vector-Jacobian product (as opposed to a Jacobian-vector product, i.e. the Jacobian times a column vector, which is what forward-mode autodiff computes; see [this article](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#how-it-s-made-two-foundational-autodiff-functions) for details).

Again, doing this product with an explicit Jacobian is needlessly compute-intensive, since that Jacobian matrix is large and likely very sparse. What is done in practice, and what we will do, is have each computation node store a `vjp` function which takes in the "vector" $J[g]$ (which is just a gradient since $g$ outputs a scalar) and outputs $J[g\circ f]$ (which is also just a gradient).
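The same point for matrix multiplication, in a pure-Python sketch with arbitrary small inputs: for $y = xw$ with $x$ of shape $(M,N)$ and $w$ of shape $(N,P)$, the flattened Jacobian w.r.t. $x$ has $(MP)\times(MN)$ entries, most of them zero. The VJP formulas `v @ w.T` and `x.T @ v`, which the `affine` op in the Implementation section also uses, give the same gradients without ever building it. The helpers `matmul`, `transpose` and `flatten` exist only for this illustration.

```python
def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def transpose(A):
    return [list(col) for col in zip(*A)]

def flatten(A):
    return [v for row in A for v in row]

# Shapes: x is (M, N), w is (N, P), y = x @ w is (M, P).
x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
w = [[0.5, -1.0], [2.0, 0.0], [-0.5, 1.5]]
v = [[1.0, -2.0], [0.5, 3.0]]   # gradient w.r.t. y (same shape as y)
M, N, P = len(x), len(w), len(w[0])

# Explicit flattened Jacobian of y w.r.t. x, shape (M*P) x (M*N):
# d y[a][b] / d x[d][e] = (a == d) * w[e][b]
J_x = [[(w[e][b] if a == d else 0.0) for d in range(M) for e in range(N)]
       for a in range(M) for b in range(P)]
# Explicit flattened Jacobian of y w.r.t. w, shape (M*P) x (N*P):
# d y[a][b] / d w[d][e] = x[a][d] * (b == e)
J_w = [[(x[a][d] if b == e else 0.0) for d in range(N) for e in range(P)]
       for a in range(M) for b in range(P)]

v_flat = flatten(v)
gx_explicit = [sum(v_flat[r] * J_x[r][c] for r in range(M * P)) for c in range(M * N)]
gw_explicit = [sum(v_flat[r] * J_w[r][c] for r in range(M * P)) for c in range(N * P)]

# VJP formulas: no Jacobian is ever materialized.
gx = matmul(v, transpose(w))   # shape (M, N)
gw = matmul(transpose(x), v)   # shape (N, P)
```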

Specifically, if our tensor-valued function `F` has input shape $(n_1,\dots,n_r)$ and output shape $(m_1,\dots,m_s)$, then the function call `F(x)` should return a computation node which contains `x`, `F`, and `vjp`, where `vjp` is a tensor-valued function with input shape $(m_1,\dots,m_s)$ and output shape $(n_1,\dots,n_r)$ (the shapes of `F` reversed). When called, we have `g_in = vjp(g_out)`, where `g_out` is the gradient of some scalar at the top of our computation graph w.r.t. the output of `F(x)`, and `g_in` is the gradient of that same scalar w.r.t. `x`. In both cases "gradient" means a tensor of partial derivatives with the same shape as the tensor it is taken with respect to.

More specifically, let our computation graph be the chain $g = f_r\circ \dots\circ f_1$ where $f_r$ outputs a scalar, and let $y = g(x)$ with $h_0 = x$ and $h_i = f_i \circ \dots \circ f_1(x)$, so that $y = h_r$. Then $J[f_r\circ \dots\circ f_{i+1}]$ has shape $(1\mid \text{shape}(h_i))$, which we call the gradient of $f_r\circ \dots\circ f_{i+1}$ w.r.t. its input $h_i$. The function `vjp` at node $i$ in this graph takes in $J[f_r\circ \dots\circ f_{i+1}]$ with shape $(1\mid \text{shape}(h_i))$ and outputs $J[f_r\circ \dots\circ f_{i}]$ with shape $(1\mid \text{shape}(h_{i-1}))$. So the explicit gradient tensors flowing through the graph (from top to bottom) are never larger than the intermediate values $h_i$ themselves, and there is no blow-up in size. We avoid explicitly passing around $J[f_i]$ with shape $(\text{shape}(h_{i}) \mid \text{shape}(h_{i-1}))$, which may be a much larger tensor.
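A minimal sketch of this backward pass along a chain, again on plain Python lists. Each hypothetical op (`f1`, `f2`, `f3`) returns its output together with its `vjp`; the forward loop records the `vjp`s and the backward loop applies them in reverse, starting from the gradient $1$ of the scalar output w.r.t. itself.

```python
import math

# Hypothetical chain y = f3(f2(f1(x))); each op returns (output, vjp), where vjp maps
# a gradient w.r.t. the op's output to a gradient w.r.t. its input.

def f1(x):
    # elementwise exp: R^n -> R^n
    y = [math.exp(v) for v in x]
    return y, lambda g: [gi * yi for gi, yi in zip(g, y)]

def f2(x):
    # elementwise 3v - 1: R^n -> R^n
    y = [3.0 * v - 1.0 for v in x]
    return y, lambda g: [3.0 * gi for gi in g]

def f3(x):
    # sum to a scalar: R^n -> R
    return sum(x), lambda g: [g for _ in x]

def forward_backward(x, ops):
    vjps = []
    h = x
    for op in ops:              # forward: h_i = f_i(h_{i-1}), recording each vjp
        h, vjp = op(h)
        vjps.append(vjp)
    grad = 1.0                  # gradient of the scalar output y = h_r w.r.t. itself
    for vjp in reversed(vjps):  # backward: shape (1 | shape(h_i)) -> (1 | shape(h_{i-1}))
        grad = vjp(grad)
    return h, grad

x0 = [0.0, 1.0, -2.0]
y, grad_x = forward_backward(x0, [f1, f2, f3])
# y = sum(3 * exp(x_i) - 1), so dy/dx_i = 3 * exp(x_i)
```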


#### Meta Chain Rule

So far we've only considered computation graphs which are chains. Extending our autodiff engine to arbitrary DAGs is straightforward.

The multivariate chain rule states that for $f(x(z),y(z))$, we have

$$
\frac{\text{d} f}{\text{d} z} = \frac{\partial f}{\partial x}\frac{\text{d} x}{\text{d} z} + \frac{\partial f}{\partial y}\frac{\text{d} y}{\text{d} z}
$$

which extends to more than two arguments in the way you'd expect.

The tensor chain rule contains this logic within it if we consider $f$ to take a single tensor argument which contains two elements.
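A quick numerical check of the two-argument chain rule, with arbitrary example functions $f(x,y) = xy + \sin x$, $x(z) = z^2$ and $y(z) = 3z$ chosen for this sketch:

```python
import math

def f(x, y):
    return x * y + math.sin(x)

def x_of(z):
    return z * z

def y_of(z):
    return 3.0 * z

def df_dz_chain(z):
    x, y = x_of(z), y_of(z)
    df_dx, df_dy = y + math.cos(x), x   # partial derivatives of f
    dx_dz, dy_dz = 2.0 * z, 3.0         # derivatives of the inner functions
    # Multivariate chain rule: add the contributions through x and through y.
    return df_dx * dx_dz + df_dy * dy_dz

def df_dz_numeric(z, h=1e-6):
    # Central finite difference of the composed function z -> f(x(z), y(z)).
    F = lambda t: f(x_of(t), y_of(t))
    return (F(z + h) - F(z - h)) / (2 * h)

z0 = 0.7
```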

However, what about the case when $f$ takes multiple tensor arguments? The multivariate chain rule for scalars extends to tensors in the obvious way, giving what I call the meta chain rule: letting $f,x,y$ be tensor-valued functions and $z$ be a tensor, we have

$$J_{z}[f(x(z),y(z))] = \sum_{\text{indices} \in \text{shape}(x)} J_x[f]_{\text{indices}}J_z[x]^{\text{indices}}+\sum_{\text{indices} \in \text{shape}(y)}J_y[f]_{\text{indices}}J_z[y]^{\text{indices}}$$

and so on for more arguments.  
Using [Einstein summation notation](https://en.wikipedia.org/wiki/Einstein_notation) (where matching upper and lower indices are automatically summed over), we can simplify this notationally to:

$$J_{z}[f(x(z),y(z))] = J_x[f]J_z[x]+J_y[f]J_z[y]$$

What this amounts to for our design is the following:

Suppose `F` takes as input a tuple of inputs and gives a tuple of outputs, all elements being tensors. E.g. `a, b = F(x, y)`. Then we want `a` and `b` to each be distinct graph node instances. Each node contains a single output, but points to many child nodes as inputs. So `a` and `b` both point to `x` and `y` as children. However, `a` and `b` will have different `vjp` functions due to the gradients "passing through" different paths, i.e. through `a` to `(x,y)` vs through `b` to `(x,y)`.
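A scalar sketch of a two-output op, assuming a hypothetical `F(x, y)` that returns `a = x + y` and `b = x * y` together with one `vjp` per output. For a loss $L = 2a + 5b$, each `vjp` gives the contribution to $\partial L/\partial x$ and $\partial L/\partial y$ along its own path, and the meta chain rule adds them.

```python
def F(x, y):
    """Hypothetical two-output op: a = x + y, b = x * y.
    Returns the outputs and one vjp per output; each vjp maps the gradient w.r.t.
    its own output to a tuple of gradients w.r.t. (x, y)."""
    a, b = x + y, x * y
    vjp_a = lambda g: (g, g)          # da/dx = 1, da/dy = 1
    vjp_b = lambda g: (g * y, g * x)  # db/dx = y, db/dy = x
    return (a, b), (vjp_a, vjp_b)

# Scalar loss L = 2a + 5b, so dL/da = 2 and dL/db = 5.
(a, b), (vjp_a, vjp_b) = F(3.0, 4.0)
gx_a, gy_a = vjp_a(2.0)
gx_b, gy_b = vjp_b(5.0)
# Meta chain rule: add the contributions reaching x (and y) via a and via b.
dL_dx = gx_a + gx_b   # 2*1 + 5*4 = 22
dL_dy = gy_a + gy_b   # 2*1 + 5*3 = 17
```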

### Graph Traversals

How to traverse our computation graph and accumulate gradients is the final piece of the puzzle.

Suppose we are calculating the gradient of some top scalar node `y` w.r.t. some leaf node `x`. Further suppose `x` is an input to many different nodes in the graph. Every node where `x` is an input has its own `vjp` function.

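Starting from `y`, we do a directed graph traversal, going from parents to children and passing a gradient along each edge we follow. Whenever we hit a leaf node `x` (a tensor instance), we will at that point have produced the gradient of the output `y` w.r.t. `x` **via that path from `y` to `x`**. The meta chain rule tells us to sum these gradients over all distinct paths from `y` to `x` (in particular over every place `x` appears as a child of any node in the graph). Thus we simply accumulate a running sum of gradients whenever we hit `x`. Once we've completed the traversal, we return the accumulated gradient sum. A simple recursive traversal like this revisits a shared subgraph once for every path through it; that is still correct, because each `vjp` is linear in its incoming gradient, but it can be slow for graphs with a lot of sharing. Visiting nodes in reverse topological order, and summing all of a node's incoming gradients before calling its `vjp`, avoids this by traversing each edge exactly once.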
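A small sketch of the recursive traversal on a "diamond" graph where `x` reaches the output along two paths ($u = x^2$, $v = 3x$, $y = uv = 3x^3$). The `Node` class and the `backprop` function here are simplified, hypothetical stand-ins for `ComputeNode` and `ComputeNode.backprop`; leaves are identified by string keys rather than tensor ids.

```python
class Node:
    """Hypothetical minimal node: `inputs` are Nodes or leaf keys (strings);
    `vjp` maps the gradient w.r.t. this node's output to one gradient per input."""
    def __init__(self, value, inputs, vjp):
        self.value, self.inputs, self.vjp = value, inputs, vjp

def backprop(node, grad, acc):
    # Depth-first: push `grad` through this node's vjp and recurse into each input.
    for inp, g in zip(node.inputs, node.vjp(grad)):
        if isinstance(inp, Node):
            backprop(inp, g, acc)
        else:  # leaf: add this path's contribution to the running sum
            acc[inp] = acc.get(inp, 0.0) + g
    return acc

x = 2.0
u = Node(x * x, ["x"], lambda g: [2 * x * g])     # u = x^2
v = Node(3 * x, ["x"], lambda g: [3 * g])         # v = 3x
y = Node(u.value * v.value, [u, v],
         lambda g: [g * v.value, g * u.value])    # y = u * v = 3x^3

grads = backprop(y, 1.0, {})
# dy/dx = 9x^2 = 36: the path through u contributes 2x * v = 24,
# the path through v contributes 3 * u = 12.
```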

We can just as easily calculate gradients w.r.t. multiple different input tensors, e.g. `x_1`, ..., `x_k`. We keep a separate gradient accumulator for each such distinct tensor instance and add to the corresponding accumulator whenever we hit that tensor instance.
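The same idea with several targets, written as a reverse-topological-order traversal, which is a common alternative to plain recursion: a node's `vjp` is only called once all gradients flowing into it have been summed, so it runs exactly once. The `Node`, `topological_order` and `backprop_topo` names are hypothetical, and leaves are again identified by string keys.

```python
from collections import defaultdict

class Node:
    """Hypothetical node: `inputs` are Nodes or leaf keys; `vjp` returns one gradient per input."""
    def __init__(self, value, inputs, vjp):
        self.value, self.inputs, self.vjp = value, inputs, vjp

def topological_order(root):
    # Post-order DFS: every node appears after all the nodes it depends on.
    order, seen = [], set()
    def visit(n):
        if id(n) in seen:
            return
        seen.add(id(n))
        for inp in n.inputs:
            if isinstance(inp, Node):
                visit(inp)
        order.append(n)
    visit(root)
    return order

def backprop_topo(root, targets):
    node_grads = defaultdict(float)    # summed gradient w.r.t. each node's output
    node_grads[id(root)] = 1.0
    leaf_grads = {t: 0.0 for t in targets}
    # In reverse topological order a node's gradient is complete before its vjp runs.
    for n in reversed(topological_order(root)):
        for inp, g in zip(n.inputs, n.vjp(node_grads[id(n)])):
            if isinstance(inp, Node):
                node_grads[id(inp)] += g
            elif inp in leaf_grads:
                leaf_grads[inp] += g
    return leaf_grads

x, w = 2.0, -1.5
s = Node(x * w, ["x", "w"], lambda g: [g * w, g * x])     # s = x * w
t = Node(s.value + x, [s, "x"], lambda g: [g, g])         # t = s + x
y = Node(t.value * s.value, [t, s],
         lambda g: [g * s.value, g * t.value])            # y = t * s
grads = backprop_topo(y, ["x", "w"])
# y = x^2 w^2 + x^2 w, so dy/dx = 2xw^2 + 2xw = 3 and dy/dw = 2x^2 w + x^2 = -8.
```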


### Other Design Considerations

We want the ability for ops to take construction-time arguments as well as inputs. For example, `F(x, activation=s)` passes in an activation function, or `G(F(x), flag=True)` passes in a flag. These auxiliary arguments are hyperparameters and not intended to be differentiable, so we don't want our autodiff engine to attempt to "pass gradients through" them. Using Python keyword arguments, we require that `F`, `G`, etc. produce graph nodes whose `vjp` functions ignore these keyword arguments, and that the graph structure does not record them as children.
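A pure-Python sketch of this convention: the hyperparameter `scale` is keyword-only, the hypothetical `record` helper stores only positional arguments as graph children, and the `vjp` returns one gradient per positional input. `scaled_square` and `record` are illustrative names, not part of the engine below.

```python
def scaled_square(x, *, scale=1.0):
    # `scale` is a keyword-only hyperparameter: it shapes the computation
    # but is not an input we differentiate with respect to.
    y = scale * x * x
    def vjp(g):
        return (g * 2.0 * scale * x,)  # one gradient per positional input only
    return y, vjp

def record(op, *inputs, **hyperparams):
    """Hypothetical helper: run `op` and record only positional inputs as graph children."""
    y, vjp = op(*inputs, **hyperparams)
    return {"output": y, "children": inputs, "vjp": vjp, "kwargs": hyperparams}

node = record(scaled_square, 3.0, scale=0.5)
```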

## Further Reading

I obtained the bulk of my understanding of how to implement autodiff, and specifically vector-Jacobian products, from the [JAX docs](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#how-it-s-made-two-foundational-autodiff-functions).

PyTorch also has some exposition of autodiff in [their blog post](https://pytorch.org/blog/overview-of-pytorch-autograd-engine/). The references at the bottom of their post seem useful as well.

For a more comprehensive pedagogical implementation of autodiff, see  
https://github.com/mattjj/autodidact





# Implementation

In [None]:
# Decorator for defining node computation functions from "regular" tensor-valued functions.
# The point of this is that we can define graph operations entirely in terms of torch tensors,
# without having to think about ComputeNode objects.
# However, we still need to explicitly define a corresponding `vjp` function for the backward pass.
def graph_computation(func):
  def wrapper(*args, **kwargs):
    # args should be ComputeNode or tensor instances.
    # kwargs are hyperparameters which are not differentiable.
    
    # Unwrap any ComputeNode arguments and pass their output tensors along as input to this function call.
    unwrapped = [(arg.output if isinstance(arg, ComputeNode) else arg) for arg in args]
    results, vjps = func(*unwrapped, **kwargs)
    # We expect `func` to return both output tensors and a `vjp` function for each output.

    if not isinstance(results, (tuple, list)):
      results = [results]
      vjps = [vjps]
    output = [ComputeNode(output=r, inputs=args, vjp=vjp, op=func, kwargs=kwargs) for r, vjp in zip(results, vjps)]
    if len(output) == 1: output = output[0]

    # Return the output tensor(s) wrapped in `ComputeNode` objects containing all the correct information.
    # Returns a single `ComputeNode` if `func` returns a single output, or a list of `ComputeNode` if `func` returns a list of outputs.
    return output
  return wrapper


class ComputeNode:
  
  def __init__(self, *, output, inputs, vjp, op, kwargs):
    # leaf node if `inputs` is None
    self.output = output
    self.inputs = inputs
    self.vjp = vjp
    self.op = op
    self.kwargs = kwargs  # keyword dictionary of hyperparameters
    # `op` should be a pure function (stateless) so that we can reproduce `output`
    # given `op`, `inputs` and `kwargs`.
    self._grads = None   # intermediate gradients for debugging purposes
    
  def backprop(self, targets, grad=None, cum_gradients=None):
    # Calculate gradients of this node's (scalar) output w.r.t. a list of leaf tensors `targets`.
    # Recursively calls backprop on input nodes, so it assumes there are no directed cycles.
    
    if grad is None:
      grad = torch.ones_like(self.output)
    if cum_gradients is None:
      return_grads = True
      cum_gradients = {id(v): torch.zeros_like(v) for v in targets}
    else:
      return_grads = False
      
    input_grads = self._grads = self.vjp(grad)

    if not isinstance(input_grads, (list, tuple)):
      input_grads = [input_grads]
    for a, g in zip(self.inputs, input_grads):
      if torch.is_tensor(a):
        if id(a) in cum_gradients:
          cum_gradients[id(a)] += g
      else:  # instance of ComputeNode
        a.backprop(targets, g, cum_gradients)
        
    if return_grads:
      return [cum_gradients[id(t)] for t in targets]
    
  def __repr__(self):
    return repr(self.output)
  

In [None]:
# Define some operations which will let us construct a feed-forward network and loss.

@graph_computation
def affine(x, w, b):
  # output
  y = x @ w + b
  
  # vjp function
  def vjp(v):
    # calculate vJ, where v is a row vector and J is the Jacobian matrix of this calculation,
    # with J having shape (prod(output_shape), prod(input_shape))
    
    # w.r.t. x
    gx = v @ w.T
    
    # w.r.t. w
    gw = x.T @ v
    
    # w.r.t. b
    gb = v.sum(axis=0)
    
    return gx, gw, gb
  
  return y, vjp

@graph_computation
def sigmoid(x):
  y = 1/(1+(-x).exp())
  
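  def vjp(v):
    # sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)); reusing `y` avoids the overflow
    # (and resulting nan) that exp(x) / (1 + exp(x))**2 produces for large x.
    return v * y * (1 - y)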
  
  return y, vjp


@graph_computation
def relu(x):
  # y = torch.maximum(x, torch.zeros_like(x))
  y = torch.clamp(x, min=0)

  def vjp(v):
    # At x == 0 we use gradient 1, matching what torch's autograd uses for torch.clamp(x, min=0).
    # Cast the mask to v's dtype so the result keeps v's precision.
    return v * (x >= 0).to(v.dtype)

  return y, vjp


@graph_computation
def square_error_loss(x, *, y):
  diff = x - y
  L = (diff**2).sum(axis=-1)
  
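  def vjp(v):
    # v has one entry per row of `diff` (the feature axis was summed away), so add a
    # trailing axis to broadcast it across features instead of misaligning it.
    return v.unsqueeze(-1) * 2 * diff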
  
  return L, vjp

@graph_computation
def tensor_sum(x):
  y = x.sum()
  
  def vjp(v):
    return v * torch.ones_like(x)
  
  return y, vjp
  

@graph_computation
def tensor_concat(*tensors, axis=0):
  y = torch.cat(tensors, dim=axis)

  def vjp(v):
    # Split the incoming gradient along `axis` into one piece per input tensor,
    # each with the same shape as the corresponding input.
    widths = [t.shape[axis] for t in tensors]
    return list(torch.split(v, widths, dim=axis))

  return y, vjp


@graph_computation
def tensor_index(x, *, axis, index):
  if axis < 0:
    axis += len(x.shape)
  get_arg = (slice(None),)*axis + (index,)
  y = x[get_arg]

  def vjp(v):
    r = torch.zeros_like(x)
    r[get_arg] = v
    return r
  
  return y, vjp

# Test

In [None]:
# Compare to gradients calculated by torch

def get_torch_grads(target, params):
  # reset gradients accumulated by earlier backward calls
  for p in params:
    if p.grad is not None:
      p.grad.zero_()
  # accumulate fresh gradients into each p.grad
  target.backward(torch.ones_like(target), retain_graph=True)
  return [p.grad for p in params]
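The comparison below relies on torch's own autograd. An independent check, sketched here in pure Python as an assumption rather than part of the original notebook, is the "dot-product test": for random vectors $u$ and $v$, $\mathrm{vjp}(v)\cdot u$ should match $v\cdot(Ju)$, where $Ju$ is estimated with a central finite difference. `check_vjp`, `sigmoid_op` and `wrong_sigmoid_op` are hypothetical names; the ops follow the same (output, vjp) convention as the ops above but operate on plain lists.

```python
import math
import random

def check_vjp(op, x, h=1e-6, tol=1e-5, seed=0):
    """Hypothetical helper: compares an op's vjp against central finite differences
    using the identity  vjp(v) . u  ==  v . (J u)  for random vectors u and v.
    `op(x)` must return (output_list, vjp) on plain Python lists."""
    rng = random.Random(seed)
    y, vjp = op(x)
    u = [rng.uniform(-1, 1) for _ in x]   # random input direction
    v = [rng.uniform(-1, 1) for _ in y]   # random output "gradient"
    # J u via a central difference along direction u.
    y_plus, _ = op([xi + h * ui for xi, ui in zip(x, u)])
    y_minus, _ = op([xi - h * ui for xi, ui in zip(x, u)])
    Ju = [(a - b) / (2 * h) for a, b in zip(y_plus, y_minus)]
    lhs = sum(gi * ui for gi, ui in zip(vjp(v), u))
    rhs = sum(vi * jui for vi, jui in zip(v, Ju))
    return abs(lhs - rhs) <= tol * max(1.0, abs(lhs), abs(rhs))

def sigmoid_op(x):
    # Elementwise sigmoid with vjp using sigma'(x) = sigma(x) * (1 - sigma(x)).
    y = [1 / (1 + math.exp(-v)) for v in x]
    return y, lambda g: [gi * yi * (1 - yi) for gi, yi in zip(g, y)]

def wrong_sigmoid_op(x):
    # Deliberately buggy vjp (missing the (1 - y) factor) to show the check failing.
    y = [1 / (1 + math.exp(-v)) for v in x]
    return y, lambda g: [gi * yi for gi, yi in zip(g, y)]

print(check_vjp(sigmoid_op, [-2.0, 0.0, 1.5]))        # True
print(check_vjp(wrong_sigmoid_op, [-2.0, 0.0, 1.5]))  # False
```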

## Feed-forward network

In [None]:
# build a simple NN with toy data

# data - shape == (batch, features)
x = torch.tensor([[1,2,3], [4,5,6]], dtype=float, requires_grad=True)
y = torch.tensor([[1, 0], [0, 1]], dtype=float, requires_grad=True)

# params
w1 = torch.tensor([[1, 1], [-1, 1], [-2, 2]], dtype=float, requires_grad=True)
b1 = torch.tensor([[0, 1]], dtype=float, requires_grad=True)
w2 = torch.tensor([[.2, .5], [.5, -.5]], dtype=float, requires_grad=True)
b2 = torch.tensor([[-1, .5]], dtype=float, requires_grad=True)
w3 = torch.tensor([[.7, -.5], [-.2, .3]], dtype=float, requires_grad=True)
b3 = torch.tensor([[.3, -.2]], dtype=float, requires_grad=True)
params = [w1, b1, w2, b2, w3, b3]

# build NN
h0 = x
h1 = sigmoid(affine(h0, w1, b1))
h2 = relu(affine(h1, w2, b2))
h_out = affine(h2, w3, b3)
L = tensor_sum(square_error_loss(h_out, y=y))
L  # view the output

In [None]:
# invoke our autodiff engine!
my_grads = L.backprop(params)
my_grads  # view our calculated gradients

In [None]:
# compare to torch grads
torch_grads = get_torch_grads(L.output, params)
print('matches:', [torch.allclose(my_g, tc_g) for my_g, tc_g in zip(my_grads, torch_grads)])
# If all are True then we've succeeded

## Recurrent network

In [None]:
# build a recurrent NN with toy data

# data - shape == (batch, time, features)
x = torch.tensor([[[1,2,3], [4,5,6], [7, 8, 9]], [[-1, 1, -2], [2, -3, 3], [-2, 3, -4]]], dtype=float, requires_grad=True)
y = torch.tensor([[1, 0], [0, 1]], dtype=float, requires_grad=True)

# params
w1 = torch.tensor([[1, 1], [-1, 1], [-2, 2], [.5, -.5], [2, -2]], dtype=float, requires_grad=True)
b1 = torch.tensor([[0, 1]], dtype=float, requires_grad=True)
w2 = torch.tensor([[.2, .5], [.5, -.5]], dtype=float, requires_grad=True)
b2 = torch.tensor([[-1, .5]], dtype=float, requires_grad=True)
params = [w1, b1, w2, b2]

# build NN
s = torch.zeros((2, 2), dtype=float)  # initial state; float64 to match x when concatenating
for time_step in range(x.shape[1]):
  inp = tensor_index(x, axis=1, index=time_step)
  inp = tensor_concat(inp, s, axis=-1)
  s = relu(affine(inp, w1, b1))
final = affine(s, w2, b2)
L = tensor_sum(square_error_loss(final, y=y))
L  # view the output

In [None]:
# invoke our autodiff engine!
my_grads = L.backprop(params + [x])
my_grads  # view our calculated gradients

In [None]:
# compare to torch grads
torch_grads = get_torch_grads(L.output, params + [x])
print('matches:', [torch.allclose(my_g, tc_g) for my_g, tc_g in zip(my_grads, torch_grads)])
# If all are True then we've succeeded