# Autodiff and Backpropagation

## Jacobian

Let ${\bf f}:\mathbb{R}^n\to \mathbb{R}^m$. We define its Jacobian as:
\begin{align*}
\newcommand{\bbx}{{\bf x}}
\newcommand{\bbv}{{\bf v}}
\newcommand{\bbw}{{\bf w}}
\newcommand{\bbu}{{\bf u}}
\newcommand{\bbf}{{\bf f}}
\newcommand{\bbg}{{\bf g}}
\frac{\partial \bbf}{\partial \bbx} = J_{\bbf}(\bbx) &= \left( \begin{array}{ccc}
\frac{\partial f_1}{\partial x_1}&\dots& \frac{\partial f_1}{\partial x_n}\\
\vdots&&\vdots\\
\frac{\partial f_m}{\partial x_1}&\dots& \frac{\partial f_m}{\partial x_n}
\end{array}\right)\\
&=\left( \frac{\partial \bbf}{\partial x_1},\dots, \frac{\partial \bbf}{\partial x_n}\right)\\
&=\left(
\begin{array}{c}
\nabla f_1(\bbx)^T\\
\vdots\\
\nabla f_m(\bbx)^T
\end{array}\right)
\end{align*}

Hence the Jacobian $J_{\bbf}(\bbx)\in \mathbb{R}^{m\times n}$ is a linear map from $\mathbb{R}^n$ to $\mathbb{R}^m$ such that for $\bbx,\bbv \in \mathbb{R}^n$ and $h\in \mathbb{R}$:
\begin{align*}
\bbf(\bbx+h\bbv) = \bbf(\bbx) + hJ_{\bbf}(\bbx)\bbv +o(h).
\end{align*}
The term $J_{\bbf}(\bbx)\bbv\in \mathbb{R}^m$ is a Jacobian Vector Product (**JVP**), corresponding to the interpretation where the Jacobian is the linear map $J_{\bbf}(\bbx):\mathbb{R}^n \to \mathbb{R}^m$, with $J_{\bbf}(\bbx)(\bbv)=J_{\bbf}(\bbx)\bbv$.
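
The first-order expansion above can be checked numerically. The following is a minimal sketch, assuming a hypothetical test function `f` from $\mathbb{R}^3$ to $\mathbb{R}^2$ (not part of the original notes) whose Jacobian is written by hand; it compares the exact JVP with a finite-difference quotient.

```python
import numpy as np

def f(x):
    # Hypothetical test function f: R^3 -> R^2 (chosen for illustration only).
    return np.array([x[0] * x[1], np.sin(x[2]) + x[0] ** 2])

def jacobian_f(x):
    # Hand-written Jacobian of f, shape (m, n) = (2, 3): row i is the gradient of f_i.
    return np.array([
        [x[1], x[0], 0.0],
        [2.0 * x[0], 0.0, np.cos(x[2])],
    ])

x = np.array([1.0, 2.0, 0.5])
v = np.array([0.3, -0.1, 0.7])
h = 1e-6

jvp_exact = jacobian_f(x) @ v            # J_f(x) v, a vector in R^m
jvp_approx = (f(x + h * v) - f(x)) / h   # (f(x + h v) - f(x)) / h = J_f(x) v + o(1)

print(np.allclose(jvp_exact, jvp_approx, atol=1e-4))
```

The two vectors agree up to the $o(h)/h$ remainder, which is what the expansion states.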

## Chain composition

In machine learning, we compute the gradient of the loss function with respect to the parameters. In particular, even if the parameters are high-dimensional, the loss is a real number. Hence, consider a real-valued function $\bbf:\mathbb{R}^n\stackrel{\bbg_1}{\to}\mathbb{R}^m \stackrel{\bbg_2}{\to}\mathbb{R}^d\stackrel{h}{\to}\mathbb{R}$, so that $\bbf(\bbx) = h(\bbg_2(\bbg_1(\bbx)))\in \mathbb{R}$. We have
\begin{align*}
\underbrace{\nabla\bbf(\bbx)}_{n\times 1}=\underbrace{J_{\bbg_1}(\bbx)^T}_{n\times m}\underbrace{J_{\bbg_2}(\bbg_1(\bbx))^T}_{m\times d}\underbrace{\nabla h(\bbg_2(\bbg_1(\bbx)))}_{d\times 1}.
\end{align*}
To carry out this computation, we can start from the right: a matrix-vector product gives a vector of size $m$, followed by a second matrix-vector product, for a total of $O(nm+md)$ operations. If we instead start from the left with the matrix-matrix product, we need $O(nmd+nd)$ operations. Hence, as soon as $m\approx d$, starting from the right is much more efficient. Note however that computing from right to left requires keeping in memory the values of $\bbg_1(\bbx)\in\mathbb{R}^m$ and $\bbx\in \mathbb{R}^n$.
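
To make the operation counts concrete, here is a small sketch with random matrices standing in for the two Jacobians and the gradient of $h$. The sizes `n, m, d` are arbitrary illustrative choices, not values from the notes.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, d = 50, 40, 30               # illustrative sizes (assumption)

J1 = rng.standard_normal((m, n))   # stands in for J_{g_1}(x), shape (m, n)
J2 = rng.standard_normal((d, m))   # stands in for J_{g_2}(g_1(x)), shape (d, m)
grad_h = rng.standard_normal(d)    # stands in for the gradient of h, shape (d,)

# Right to left: two matrix-vector products.
right_to_left = J1.T @ (J2.T @ grad_h)

# Left to right: a matrix-matrix product first, then a matrix-vector product.
left_to_right = (J1.T @ J2.T) @ grad_h

# Scalar multiplication counts for each ordering.
cost_right = m * d + n * m
cost_left = n * m * d + n * d

print(np.allclose(right_to_left, left_to_right), cost_right, cost_left)
```

Both orderings return the same gradient in $\mathbb{R}^n$; only the number of multiplications differs.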

**Backpropagation** is an efficient algorithm that computes the gradient "from the right to the left", i.e. backward. In particular, we will need to compute quantities of the form $J_{\bbf}(\bbx)^T\bbu \in \mathbb{R}^n$ with $\bbu \in\mathbb{R}^m$, which is the transpose of the row vector $\bbu^T J_{\bbf}(\bbx)$. This is a Vector Jacobian Product (**VJP**), corresponding to the interpretation where the Jacobian is the linear map $J_{\bbf}(\bbx):\mathbb{R}^n \to \mathbb{R}^m$, composed with the linear map $\bbu:\mathbb{R}^m\to \mathbb{R}$, so that $\bbu^TJ_{\bbf}(\bbx) = \bbu \circ J_{\bbf}(\bbx)$.
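
As an illustration of the VJP, the sketch below uses a hand-written Jacobian of a hypothetical function from $\mathbb{R}^3$ to $\mathbb{R}^2$ (an assumption made for the example). It checks that $J^T\bbu$ and $\bbu^T J$ contain the same numbers and that composing $\bbu$ with the Jacobian gives $\bbu^T(J\bbv) = (J^T\bbu)^T\bbv$.

```python
import numpy as np

def jacobian_f(x):
    # Jacobian of the hypothetical f(x) = (x0 * x1, sin(x2) + x0**2), shape (2, 3).
    return np.array([
        [x[1], x[0], 0.0],
        [2.0 * x[0], 0.0, np.cos(x[2])],
    ])

x = np.array([1.0, 2.0, 0.5])
u = np.array([0.5, -2.0])        # vector in R^m (output space)
v = np.array([0.3, -0.1, 0.7])   # vector in R^n (input space)

J = jacobian_f(x)
vjp = J.T @ u     # J^T u, a vector in R^n
row_form = u @ J  # u^T J, the same numbers read as a row vector

# Composition u o J applied to v equals the inner product of the VJP with v.
lhs = u @ (J @ v)
rhs = vjp @ v

print(vjp.shape, np.allclose(vjp, row_form), np.isclose(lhs, rhs))
```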

**Example:** Let $\bbf(\bbx, W) = \bbx W\in \mathbb{R}^b$ where $W\in \mathbb{R}^{a\times b}$ and $\bbx\in \mathbb{R}^a$. We clearly have
$$
J_{\bbf}(\bbx) = W^T.
$$
Note that here, we are slightly abusing notation and considering the partial function $\bbx\mapsto \bbf(\bbx, W)$. To see this, we can write $f_j = \sum_{i}x_iW_{ij}$ so that
$$
\frac{\partial \bbf}{\partial x_i}= \left( W_{i1}\dots W_{ib}\right)^T
$$
Then recall from definitions that
$$
J_{\bbf}(\bbx) = \left( \frac{\partial \bbf}{\partial x_1},\dots, \frac{\partial \bbf}{\partial x_a}\right)=W^T.
$$
Now we clearly have
$$
J_{\bbf}(W) = \bbx \text{ since, } \bbf(\bbx,W+\Delta W) = \bbf(\bbx,W) + \bbx \Delta W.
$$
Note that writing the product as $\bbx W$, with $\bbx$ on the left and $W$ on the right, is actually convenient when using broadcasting, i.e. we can take a batch of input vectors of shape $\text{bs}\times a$ without modifying the math above.
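
The two Jacobians of this example can be checked with finite differences. The sketch below uses arbitrary small sizes and a random upstream vector `u` (both are assumptions for illustration). It differentiates the scalar $\bbu^T\bbf(\bbx, W)$, whose gradients with respect to $\bbx$ and $W$ are exactly the two VJPs.

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = 4, 3                       # illustrative sizes (assumption)
x = rng.standard_normal(a)        # input in R^a
W = rng.standard_normal((a, b))   # parameters in R^{a x b}
u = rng.standard_normal(b)        # upstream gradient in R^b

def s(x, W):
    # Scalar u^T f(x, W) with f(x, W) = x W.
    return u @ (x @ W)

vjp_x = W @ u            # J_f(x)^T u with J_f(x) = W^T, shape (a,)
vjp_W = np.outer(x, u)   # VJP with respect to W, shape (a, b)

h = 1e-6
fd_x = np.array([(s(x + h * e, W) - s(x, W)) / h for e in np.eye(a)])
fd_W = np.zeros_like(W)
for i in range(a):
    for j in range(b):
        E = np.zeros_like(W)
        E[i, j] = h
        fd_W[i, j] = (s(x, W + E) - s(x, W)) / h

# A batch of inputs of shape (bs, a) goes through the same product unchanged.
X = rng.standard_normal((5, a))
batch_out = X @ W        # shape (5, b)

print(np.allclose(vjp_x, fd_x, atol=1e-5), np.allclose(vjp_W, fd_W, atol=1e-5))
```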

## Implementation

In PyTorch, `torch.autograd` provides classes and functions implementing automatic differentiation of arbitrary scalar-valued functions. To create a custom [autograd.Function](https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function), subclass `torch.autograd.Function` and implement the `forward()` and `backward()` static methods. Here is an example:
```python=
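import torch
from torch.autograd import Function

class Exp(Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        # Save exp(i): it is also the derivative needed in backward.
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        # VJP of exp: multiply the incoming gradient by exp(i).
        return grad_output * result

# Use it by calling the apply method:
input = torch.randn(3, requires_grad=True)
output = Exp.apply(input)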
```


### Backprop the functional way

Here we will implement in `numpy` a different approach mimicking the functional approach of [JAX](https://jax.readthedocs.io/en/latest/index.html); see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#).

Each function will take 2 arguments: one being the input `x` and the other being the parameters `w`. For each function, we build 2 **vjp** functions, one for $J_{\bbf}(\bbx)$ and one for $J_{\bbf}(\bbw)$, each taking as argument a gradient $\bbu$ and returning $J_{\bbf}(\bbx)^T \bbu$ and $J_{\bbf}(\bbw)^T \bbu$ respectively. To summarize, for $\bbx \in \mathbb{R}^n$, $\bbw \in \mathbb{R}^d$, and $\bbf(\bbx,\bbw) \in \mathbb{R}^m$,
\begin{align*}
{\bf vjp}_\bbx(\bbu) &= J_{\bbf}(\bbx)^T \bbu, \text{ with } J_{\bbf}(\bbx)\in\mathbb{R}^{m\times n}, \bbu\in \mathbb{R}^m\\
{\bf vjp}_\bbw(\bbu) &= J_{\bbf}(\bbw)^T \bbu, \text{ with } J_{\bbf}(\bbw)\in\mathbb{R}^{m\times d}, \bbu\in \mathbb{R}^m
\end{align*}
Then backpropagation is simply done by first computing the gradient of the loss and then composing the **vjp** functions in the right order.

### Example: adding bias

We start with the simple example of adding a bias.

In [1]:
import numpy as np

In [2]:
def add(x, b):
    return x + b

def add_make_vjp(x, b):
    def vjp(u):
        return u, u
    return vjp

add.make_vjp = add_make_vjp

In [3]:
rng = np.random.RandomState(0)
x = rng.random((30,2)).astype('float32')
b_source  = np.array([1.])

In [4]:
xb = add(x,b_source)
np.allclose(xb, x+b_source)

True

In [5]:
vjp_add = add.make_vjp(x,b_source)

In [6]:
grad_x, grad_b = vjp_add(rng.random((30,2)))

In [7]:
grad_x.shape

(30, 2)

### Exercise: dot product and squared loss

Implement the corresponding vjp functions. (Note: we are abusing notation for the squared loss, as the target `y` is not a parameter and should not be updated! Moreover, the vjp function with respect to `y_pred` for the squared loss does not depend on its input `u`.)

In [8]:
def dot(x, W):
    return np.dot(x, W)

def dot_make_vjp(x, W):
    def vjp(u):
        return np.dot(u, W.T), np.einsum('na,nb-> nab',x , u)
    return vjp

dot.make_vjp = dot_make_vjp

def squared_loss(y_pred, y):
    return np.array([np.sum((y - y_pred) ** 2)])

def squared_loss_make_vjp(y_pred, y):
    def vjp(u):
        diff = y_pred - y
        return 2*diff, np.zeros_like(y)
    return vjp

squared_loss.make_vjp = squared_loss_make_vjp
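
One way to sanity-check these vjp functions is to compare their composition with central finite differences. The sketch below is self-contained and re-derives the same formulas inline on a small random batch. The batch size and the helper `loss_of_W` are assumptions introduced for the check.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((5, 2))             # small batch of inputs (assumption)
W = rng.standard_normal((2, 1))
y = rng.random((5, 1))

def loss_of_W(W):
    # Hypothetical helper: squared loss of the prediction x W, as a scalar.
    return np.sum((y - np.dot(x, W)) ** 2)

# Compose the two VJPs: first the squared loss, then the dot product.
u = 2 * (np.dot(x, W) - y)                       # VJP of the loss w.r.t. y_pred, shape (5, 1)
grad_W = np.einsum('na,nb->nab', x, u).sum(0)    # per-sample VJP w.r.t. W, summed over the batch

# Central finite differences on each entry of W.
h = 1e-5
fd = np.zeros_like(W)
for i in range(W.shape[0]):
    for j in range(W.shape[1]):
        E = np.zeros_like(W)
        E[i, j] = h
        fd[i, j] = (loss_of_W(W + E) - loss_of_W(W - E)) / (2 * h)

print(np.allclose(grad_W, fd, atol=1e-5))
```

The `einsum` returns one gradient per sample, which is why the result is summed over the first axis before the comparison.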

### Setup

Our model is:
$$
y_t = 2x^1_t-3x^2_t+1, \quad t\in\{1,\dots,30\}
$$

Given the 'observations' $(x_t,y_t)_{t\in\{1,\dots,30\}}$, our task is to recover the weights $w^1=2, w^2=-3$ and the bias $b = 1$.

In order to do so, we will solve the following optimization problem:
$$
\underset{w^1,w^2,b}{\operatorname{argmin}} \sum_{t=1}^{30} \left(w^1x^1_t+w^2x^2_t+b-y_t\right)^2
$$
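
Since this objective is an ordinary least-squares problem, its minimizer can also be obtained in closed form. The sketch below is a reference point only and is not the gradient-based approach developed in these notes: it solves the problem with `np.linalg.lstsq` on freshly generated data (the random generator and seed are assumptions).

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((30, 2))
y = x @ np.array([2.0, -3.0]) + 1.0

# Append a column of ones so the bias is estimated together with the weights.
A = np.hstack([x, np.ones((30, 1))])
solution, *_ = np.linalg.lstsq(A, y, rcond=None)
w1, w2, b = solution

print(w1, w2, b)
```

Because the data are noiseless, the solution recovers the weights and the bias of the model up to floating-point error.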

In [9]:
rng = np.random.RandomState(0)
# generate random input data
x = rng.random((30,2)).astype('float32')
# generate labels corresponding to input data x
y = np.dot(x, [2., -3.]) + 1.
y = np.expand_dims(y, axis=1).astype('float32')
w_source = np.array([2., -3.])
b_source  = np.array([1.])

In [10]:
def create_feed_forward(y, seed=0):
    rng = np.random.RandomState(seed)
    funcs = [dot,add,squared_loss]
    params = [rng.randn(2,1),rng.randn(1),y]
    return funcs, params

### Forward pass

The following function should take a batch of inputs, functions and their parameters and return the final value or all values.

In [11]:
def evaluate_chain(x, funcs, params, return_all=False):
    all_x = [x]
    for (f,p) in zip(funcs,params):
        x = f(all_x[-1],p)
        all_x.append(x)
    
    if return_all:
        return all_x
    else:
        return x

In [12]:
funcs, params = create_feed_forward(y=y, seed=0)
W, b, _ = params

In [13]:
xs = evaluate_chain(x, funcs, params, return_all=True)

### Backward pass

The following function should do the forward pass and then the backward pass.

In [14]:
def backward_diff_chain(x, funcs, params):
    """
    Reverse-mode differentiation of a chain of computations.

    Args:
    x: initial input to the chain.
    funcs: a list of functions of the form func(x, param).
    params: a list of parameters, with len(params) = len(funcs).
    Returns:
    value, vjp_x, all vjp_params
    """
    # Evaluate the feedforward model and store intermediate computations,
    # as they will be needed during the backward pass.
    xs = evaluate_chain(x, funcs, params, return_all=True)
    K = len(funcs)  # Number of functions.
    u = None # the gradient of the loss does not require an input    
    # List that will contain the VJP of each function w.r.t. its parameters.
    J = [None] * K

    for (k,(f,p,x)) in reversed(list(enumerate(zip(funcs,params,xs)))):
        vjp_x, vjp_param = f.make_vjp(x,p)(u)
        u = vjp_x
        J[k] = vjp_param

    return xs[-1], u, J

In [15]:
loss, grad_x, grads = backward_diff_chain(x, funcs, params)
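
The backward pass can be validated against central finite differences of the loss. The sketch below restates the functions of this section in compact form (lambdas instead of named `make_vjp` helpers, which is a presentational assumption) so that it runs on its own.

```python
import numpy as np

def dot(x, W):
    return np.dot(x, W)
dot.make_vjp = lambda x, W: (lambda u: (np.dot(u, W.T), np.einsum('na,nb->nab', x, u)))

def add(x, b):
    return x + b
add.make_vjp = lambda x, b: (lambda u: (u, u))

def squared_loss(y_pred, y):
    return np.array([np.sum((y - y_pred) ** 2)])
squared_loss.make_vjp = lambda y_pred, y: (lambda u: (2 * (y_pred - y), np.zeros_like(y)))

def backward_diff_chain(x, funcs, params):
    # Forward pass, keeping every intermediate value.
    xs = [x]
    for f, p in zip(funcs, params):
        xs.append(f(xs[-1], p))
    # Backward pass: apply the VJPs from the last function to the first.
    u, grads = None, [None] * len(funcs)
    for k in reversed(range(len(funcs))):
        u, grads[k] = funcs[k].make_vjp(xs[k], params[k])(u)
    return xs[-1], u, grads

rng = np.random.RandomState(0)
x = rng.random_sample((30, 2))
y = (np.dot(x, [2., -3.]) + 1.)[:, None]
W, b = rng.randn(2, 1), rng.randn(1)

loss, grad_x, grads = backward_diff_chain(x, [dot, add, squared_loss], [W, b, y])
grad_W, grad_b = grads[0].sum(0), grads[1].sum(0)   # sum per-sample gradients over the batch

def loss_fn(W, b):
    return np.sum((np.dot(x, W) + b - y) ** 2)

h = 1e-5
fd_W = np.zeros_like(W)
for i in range(2):
    E = np.zeros_like(W)
    E[i, 0] = h
    fd_W[i, 0] = (loss_fn(W + E, b) - loss_fn(W - E, b)) / (2 * h)
fd_b = (loss_fn(W, b + h) - loss_fn(W, b - h)) / (2 * h)

print(np.allclose(grad_W, fd_W, atol=1e-4), np.allclose(grad_b, fd_b, atol=1e-4))
```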

### Optimizer

First compute the update for each parameter and then modify the parameters.

In [16]:
def optim_SGD(grads, learning_rate = 1e-2):
    # Sum the per-sample gradients over the batch axis, then take a descent step.
    return [-learning_rate * g.sum(0) for g in grads]

def update_params(updates, params):
    return [p + u for p, u in zip(params, updates)]
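
The role of `g.sum(0)` is easiest to see on shapes. The sketch below uses placeholder gradients filled with ones and zeros (an assumption, not values from the training run) with the per-sample shapes produced by the vjp functions above.

```python
import numpy as np

# Placeholder per-sample gradients for W, b and the target y.
grads = [np.ones((30, 2, 1)), np.ones((30, 1)), np.zeros((30, 1))]
params = [np.zeros((2, 1)), np.zeros(1), np.zeros((30, 1))]

learning_rate = 1e-2
# Summing over axis 0 reduces each per-sample gradient to the parameter's shape.
updates = [-learning_rate * g.sum(0) for g in grads]
new_params = [p + u for p, u in zip(params, updates)]

print([p.shape for p in new_params])
```

The zero gradient attached to `y` yields a zero update, so the target is left unchanged even though it sits in the list of parameters.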

### Training loop

In [17]:
funcs, params = create_feed_forward(y=y, seed=0)
W, b, _ = params
for epoch in range(10):
    loss, grad_x, grads = backward_diff_chain(x, funcs, params)
    print("progress:", "epoch:", epoch, "loss",loss)
    updates = optim_SGD(grads)
    params = update_params(updates, params)
    
# After training
print("estimation of the parameters:")
print(params[:-1])

progress: epoch: 0 loss [110.2900527]
progress: epoch: 1 loss [17.17577308]
progress: epoch: 2 loss [15.48445439]
progress: epoch: 3 loss [14.27434616]
progress: epoch: 4 loss [13.16495088]
progress: epoch: 5 loss [12.14630386]
progress: epoch: 6 loss [11.21068817]
progress: epoch: 7 loss [10.35106573]
progress: epoch: 8 loss [9.56101153]
progress: epoch: 9 loss [8.83465916]
estimation of the parameters:
[array([[ 1.6269806 ],
       [-1.09195793]]), array([0.121909])]


### JAX implementation

In [18]:
import jax
import jax.numpy as jnp
import haiku as hk
import optax
from functools import partial

class config:
    size_out = 1
    w_source = jnp.array(W)
    b_source = jnp.array(b)
    
def _linear(x, config):
    return hk.Linear(config.size_out,w_init=hk.initializers.Constant(config.w_source), b_init=hk.initializers.Constant(config.b_source))(x)

def mse_loss(y_pred, y_t):
    return jax.lax.integer_pow(y_pred - y_t,2).sum()

def loss_fn(x_in, y_t, config):
    return mse_loss(_linear(x=x_in, config=config),y_t)

hk_loss_fn = hk.without_apply_rng(hk.transform(partial(loss_fn, config=config)))
params = hk_loss_fn.init(x_in=x,y_t=y,rng=None)
loss_fn = hk_loss_fn.apply

optimizer = optax.sgd(learning_rate=1e-2)

opt_state = optimizer.init(params)
for epoch in range(10):
    loss, grads = jax.value_and_grad(loss_fn)(params,x_in=x,y_t=y)
    print("progress:", "epoch:", epoch, "loss",loss)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    
# After training
print("estimation of the parameters:")
print(params)



progress: epoch: 0 loss 110.29005
progress: epoch: 1 loss 17.175777
progress: epoch: 2 loss 15.484457
progress: epoch: 3 loss 14.274347
progress: epoch: 4 loss 13.164951
progress: epoch: 5 loss 12.146305
progress: epoch: 6 loss 11.21069
progress: epoch: 7 loss 10.351067
progress: epoch: 8 loss 9.561012
progress: epoch: 9 loss 8.83466
estimation of the parameters:
FlatMap({
  'linear': FlatMap({
              'b': DeviceArray([0.12190904], dtype=float32),
              'w': DeviceArray([[ 1.6269805],
                                [-1.0919579]], dtype=float32),
            }),
})


### PyTorch implementation

In [19]:
import torch

dtype = torch.FloatTensor
w_init_t = torch.from_numpy(W).type(dtype)
b_init_t = torch.from_numpy(b).type(dtype)
x_t = torch.from_numpy(x).type(dtype)
y_t = torch.from_numpy(y).type(dtype)

model = torch.nn.Sequential(torch.nn.Linear(2, 1),)

for m in model.children():
    m.weight.data = w_init_t.T.clone()
    m.bias.data = b_init_t.clone()

loss_fn = torch.nn.MSELoss(reduction='sum')

model.train()

optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for epoch in range(10):
    y_pred = model(x_t)
    loss = loss_fn(y_pred, y_t)
    print("progress:", "epoch:", epoch, "loss",loss.item())
    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
# After training
print("estimation of the parameters:")
for param in model.parameters():
    print(param)

progress: epoch: 0 loss 110.2900619506836
progress: epoch: 1 loss 17.1757755279541
progress: epoch: 2 loss 15.484455108642578
progress: epoch: 3 loss 14.274348258972168
progress: epoch: 4 loss 13.164952278137207
progress: epoch: 5 loss 12.146303176879883
progress: epoch: 6 loss 11.210689544677734
progress: epoch: 7 loss 10.351065635681152
progress: epoch: 8 loss 9.561013221740723
progress: epoch: 9 loss 8.834660530090332
estimation of the parameters:
Parameter containing:
tensor([[ 1.6270, -1.0920]], requires_grad=True)
Parameter containing:
tensor([0.1219], requires_grad=True)
