# 03: Back Propagation


In [21]:
import numpy as np
import torch
from torch import nn

![](../img/03_forward_backward.png)

The training loop for a neural network involves:

1. A **forward pass**: Feed the input features ($x_1$, $x_2$) through all the layers of the network to compute our predictions, $\hat{y}$.

2. Compute the **loss** (or cost), $\mathcal{L}(y, \hat{y})$, a function of the predicted values $\hat{y}$ and the actual values $y$.

3. A **backward pass** (or _back propagation_): Feed the loss $\mathcal{L}$ back through the network to compute the rate of change of the loss (i.e. the derivative) with respect to the network parameters (the weights and biases for each node, $w$, $b$).

4. Given their derivatives, update the network parameters ($w$, $b$) using an algorithm like gradient descent.

## Forward Pass: Recap

Each node computes a linear combination of the output of all the nodes in the previous layer, for example:

$$
z_1^{[1]} = w_{1 \rightarrow 1}^{[1]} x_1 + w_{2 \rightarrow 1}^{[1]} x_2 + b_1^{[1]}
$$

This is passed to an activation function, $g$ (assumed to be the same function in all layers here), to create the final output, or "activation", of each node:

$$
a_3^{[2]} = g(z_3^{[2]})
$$

For example, $\hat{y}$ is the activation of the final layer, which can be expressed in terms of the activations of the previous layer as follows:

$$
\hat{y} = a_1^{[3]} = g\left(w_{1 \rightarrow 1}^{[3]} a_{1}^{[2]} + w_{2 \rightarrow 1}^{[3]} a_{2}^{[2]} + w_{3 \rightarrow 1}^{[3]} a_{3}^{[2]} + b_1^{[3]}\right)
$$

The terms not introduced above mean:

- $w_{j \rightarrow k}^{[l]}$: The weight between node $j$ in layer $l-1$ and node $k$ in layer $l$.
- $a_k^{[l]}$: The activation of node $k$ in layer $l$
- $b_k^{[l]}$: The bias term for node $k$ in layer $l$
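
As a rough sketch of the forward pass described above, the snippet below pushes two inputs through a small fully connected network in plain Python. The layer sizes, the parameter values, and the helper names `sigmoid` and `layer_forward` are illustrative assumptions rather than values taken from the figure. `W[k][j]` plays the role of $w_{j \rightarrow k}^{[l]}$.

```python
import math


def sigmoid(z):
    """Activation function g, assumed here to be the sigmoid."""
    return 1.0 / (1.0 + math.exp(-z))


def layer_forward(a_prev, W, b):
    """Compute the activations of one layer.

    W[k][j] is the weight from node j in the previous layer to node k
    in this layer, and b[k] is the bias of node k.
    """
    a = []
    for w_k, b_k in zip(W, b):
        z_k = sum(w * a_j for w, a_j in zip(w_k, a_prev)) + b_k  # linear combination
        a.append(sigmoid(z_k))  # activation
    return a


# Hypothetical 2 -> 3 -> 1 network with made-up parameter values
x = [1.5, 2.5]
W1 = [[0.1, -0.2], [0.4, 0.3], [-0.5, 0.2]]
b1 = [0.0, 0.1, -0.1]
W2 = [[0.7, -0.3, 0.5]]
b2 = [0.2]

a1 = layer_forward(x, W1, b1)  # hidden layer activations
yhat = layer_forward(a1, W2, b2)[0]  # network output
print(f"yhat = {yhat:.4f}")
```

Each call to `layer_forward` only needs the activations of the previous layer, which is why the layers must be evaluated in order.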

## Gradient Descent

Let's consider a simpler network, with one input, two hidden nodes, and one output:

![](../img/03_backprop_example_params.png)

Here I've also included a node after the network's output to represent the calculation of the loss, $\mathcal{L}(y, \hat{y})$, where $\hat{y} = g(z_1^{[2]})$ is the predicted value from the network and $y$ the true value.

This network has seven parameters: $w_1^{[1]}$, $w_2^{[1]}$, $b_1^{[1]}$, $b_2^{[1]}$, $w_1^{[2]}$, $w_2^{[2]}$, $b_1^{[2]}$

In gradient descent we use the partial derivative of the loss function with respect to the parameters to update the network, making small changes to the parameters like:

$$
w_1^{[1]}  = w_1^{[1]} - \alpha\frac{\partial \mathcal{L}}{\partial w_1^{[1]}}
$$

where $\alpha$ is the learning rate.
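
To make the update rule concrete, here is a minimal sketch of gradient descent on a single parameter. The quadratic loss $(w - 3)^2$, the starting value, and the learning rate are arbitrary choices for illustration and have nothing to do with the network above.

```python
def loss(w):
    """Toy loss with its minimum at w = 3."""
    return (w - 3.0) ** 2


def dloss_dw(w):
    """Derivative of the toy loss with respect to w."""
    return 2.0 * (w - 3.0)


alpha = 0.1  # learning rate (arbitrary)
w = 0.0  # arbitrary starting value

for step in range(100):
    w = w - alpha * dloss_dw(w)  # w <- w - alpha * dL/dw

print(f"w = {w:.4f}, loss = {loss(w):.6f}")
```

Each step moves `w` a small distance against the gradient, so it drifts towards the value that minimises the loss.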

So to perform gradient descent we need the derivatives for each parameter, i.e. we need to compute:

$$
\frac{\partial \mathcal{L}}{\partial w_1^{[1]}},
\frac{\partial \mathcal{L}}{\partial w_2^{[1]}},
\frac{\partial \mathcal{L}}{\partial b_1^{[1]}},
\frac{\partial \mathcal{L}}{\partial b_2^{[1]}},
\frac{\partial \mathcal{L}}{\partial w_1^{[2]}},
\frac{\partial \mathcal{L}}{\partial w_2^{[2]}},
\frac{\partial \mathcal{L}}{\partial b_1^{[2]}}
$$

How can we compute all those terms?

## Chain Rule (for derivatives)

### Case 1: $f$ is a function of $g$, and $g$ is a function of $x$

$$
f = f(g(x))
$$

The chain rule states that the derivative of $f$ with respect to $x$ is given by:

$$
\frac{\mathrm{d} f}{\mathrm{d}x} = \frac{\mathrm{d} f}{\mathrm{d} g} \frac{\mathrm{d} g}{\mathrm{d} x}
$$

#### Example

Find the derivative of $f(x) = (e^x + x)^2$:

Let:

$$
g(x) = e^x + x \\
f(g) = g^2 \\
$$

Then by chain rule:

$$
\frac{\mathrm{d} f}{\mathrm{d} g} = 2g \\
\frac{\mathrm{d} g}{\mathrm{d} x} = e^x + 1 \\
\frac{\mathrm{d} f}{\mathrm{d}x} = \frac{\mathrm{d} f}{\mathrm{d} g} \frac{\mathrm{d} g}{\mathrm{d} x} = 2g(e^x + 1) \\
\frac{\mathrm{d} f}{\mathrm{d}x} = 2(e^x + x)(e^x + 1) 
$$
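
One way to sanity check a result like this is to compare it with a finite-difference estimate. The sketch below does that at an arbitrarily chosen point; the function names and the step size `h` are illustrative choices.

```python
import math


def f(x):
    return (math.exp(x) + x) ** 2


def df_dx(x):
    # Result from the chain rule: 2 (e^x + x)(e^x + 1)
    return 2 * (math.exp(x) + x) * (math.exp(x) + 1)


x = 0.5  # arbitrary test point
h = 1e-6  # small step for the finite difference
numeric = (f(x + h) - f(x - h)) / (2 * h)  # central difference estimate
analytic = df_dx(x)

print(f"analytic = {analytic:.6f}, numeric = {numeric:.6f}")
```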

### Case 2: $f$ is a function of $g$ and $h$, which are both functions of $x$

$$
f = f(g(x), h(x))
$$

To find the derivative of $f$ with respect to $x$, the chain rule states that you must sum over its (partial) derivatives for each input ($g$, $h$):

$$
\frac{\mathrm{d} f}{\mathrm{d}x} = \frac{\partial f}{\partial g} \frac{\mathrm{d} g}{\mathrm{d} x} + \frac{\partial f}{\partial h} \frac{\mathrm{d} h}{\mathrm{d} x}
$$

This is the _multi-variable_ chain rule.

#### Example

Find the derivative of:

$$
f(x) = x^2(3x+1) - \sin(x^2) \\
$$

which can be written as:

$$
f(g, h) = h g - \sin(h) \\
g(x) = 3x + 1 \\
h(x) = x^2
$$

Then by chain rule:

$$
\frac{\partial f}{\partial g} = h\\
\frac{\mathrm{d} g}{\mathrm{d} x} = 3 \\
\frac{\partial f}{\partial h} = g - \cos(h) \\
\frac{\mathrm{d} h}{\mathrm{d} x} = 2x \\
\frac{\mathrm{d} f}{\mathrm{d}x} = \frac{\partial f}{\partial g} \frac{\mathrm{d} g}{\mathrm{d} x} + \frac{\partial f}{\partial h} \frac{\mathrm{d} h}{\mathrm{d} x} =
3h + 2x\left(g-\cos(h)\right) \\
\frac{\mathrm{d} f}{\mathrm{d}x} = 3x^2 + 2x\left(3x + 1 -\cos(x^2)\right)
$$
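
The same kind of finite-difference check can be applied to the multi-variable case. In this sketch the derivative is assembled from the four partial terms exactly as in the derivation above; the test point and step size are arbitrary.

```python
import math


def f(x):
    return x**2 * (3 * x + 1) - math.sin(x**2)


def df_dx(x):
    g = 3 * x + 1
    h = x**2
    df_dg, dg_dx = h, 3  # path through g
    df_dh, dh_dx = g - math.cos(h), 2 * x  # path through h
    return df_dg * dg_dx + df_dh * dh_dx  # sum over both paths


x = 0.7  # arbitrary test point
eps = 1e-6  # small step for the finite difference
numeric = (f(x + eps) - f(x - eps)) / (2 * eps)
analytic = df_dx(x)

print(f"analytic = {analytic:.6f}, numeric = {numeric:.6f}")
```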

## Back Propagation

Here's the example network again, but with each edge (arrow) labeled by the partial derivative between the two connected nodes:

![](../img/03_backprop_example_diffs.png)

To compute the derivative of the loss with respect to any term in the network we can use the chain rule. Starting with the loss on the right, we move "backwards" through the network, multiplying the partial derivatives until we get to the term we want.

### Example 1: Computing the gradient for $b_1^{[2]}$

$$
\frac{\color{red}{\partial \mathcal{L}}}{\color{blue}{\partial b_1^{[2]}}} = \frac{\partial z_1^{[2]}}{\color{blue}{\partial b_1^{[2]}}} \frac{\color{green}{\partial a_1^{[2]}}}{\partial z_1^{[2]}} \frac{\color{red}{\partial \mathcal{L}}}{\color{green}{\partial a_1^{[2]}}}
$$

Log loss for one data point (remembering that $\hat{y} = a_1^{[2]}$):

$$
\mathcal{L} = - y \log(a_1^{[2]}) - (1 - y)\log(1 - a_1^{[2]}) \\
\frac{\color{red}{\partial \mathcal{L}}}{\color{green}{\partial a_1^{[2]}}} = -\frac{y}{a_1^{[2]}} + \frac{1-y}{1-a_1^{[2]}}
$$

If using a sigmoid activation function:

$$
a_1^{[2]} = \frac{1}{1+\exp(-z_1^{[2]})} \\
\frac{\color{green}{\partial a_1^{[2]}}}{\partial z_1^{[2]}} = a_1^{[2]} (1 - a_1^{[2]})
$$

$z_1^{[2]}$ is a linear combination of its inputs:

$$
z_1^{[2]} = w_1^{[2]}a_1^{[1]} + w_2^{[2]}a_2^{[1]} + b_1^{[2]} \\
\frac{\partial z_1^{[2]}}{\color{blue}{\partial b_1^{[2]}}} = 1
$$

So overall we could write the loss derivative with respect to the bias as:

$$
\frac{\color{red}{\partial \mathcal{L}}}{\color{blue}{\partial b_1^{[2]}}} =
1 \cdot \frac{\color{green}{\partial a_1^{[2]}}}{\partial z_1^{[2]}}
\frac{\color{red}{\partial \mathcal{L}}}{\color{green}{\partial a_1^{[2]}}} =
\frac{\color{red}{\partial \mathcal{L}}}{\partial z_1^{[2]}}
$$
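
For the specific combination of a sigmoid activation and log loss, multiplying the two derivatives above gives $a_1^{[2]}(1 - a_1^{[2]})\left(-\frac{y}{a_1^{[2]}} + \frac{1-y}{1-a_1^{[2]}}\right) = a_1^{[2]} - y$. The short sketch below checks this numerically; the values of `z` and `y` are arbitrary illustrative choices.

```python
import math

# Arbitrary illustrative values for the output pre-activation and the label
z = 0.5
y = 1.0

a = 1.0 / (1.0 + math.exp(-z))  # sigmoid activation

dL_da = -y / a + (1 - y) / (1 - a)  # derivative of log loss
da_dz = a * (1 - a)  # derivative of sigmoid
dz_db = 1.0  # z is linear in b

dL_db = dz_db * da_dz * dL_da  # chain rule product

print(f"chain rule product = {dL_db:.4f}")
print(f"a - y              = {a - y:.4f}")
```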


### Example 2: Computing the gradient for $w_2^{[1]}$

$$
\frac{\color{red}{\partial \mathcal{L}}}{\color{magenta}{\partial w_2^{[1]}}} =
\frac{\color{gray}{\partial z_2^{[1]}}}{\color{magenta}{\partial w_2^{[1]}}}
\frac{\color{orange}{\partial a_2^{[1]}}}{\color{gray}{\partial z_2^{[1]}}}
\frac{\partial z_1^{[2]}}{\color{orange}{\partial a_2^{[1]}}}
\frac{\color{green}{\partial a_1^{[2]}}}{\partial z_1^{[2]}}
\frac{\color{red}{\partial \mathcal{L}}}{\color{green}{\partial a_1^{[2]}}}
$$

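The last two terms appear in the first example, and $\partial a_2^{[1]} / \partial z_2^{[1]}$ is again the derivative of the activation function. The third term follows from the expression for $z_1^{[2]}$ above: $\partial z_1^{[2]} / \partial a_2^{[1]} = w_2^{[2]}$. That leaves the first term: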

$$
z_2^{[1]} = w_2^{[1]}x + b_2^{[1]} \\
\frac{\color{gray}{\partial z_2^{[1]}}}{\color{magenta}{\partial w_2^{[1]}}} = x
$$

For the weights after the first layer, the inputs $x$ are replaced by node activations $a$. We can relabel $x = a_1^{[0]}$ to make the general trend clearer.

The last four terms on the right side of the expression for the derivative can be simplified to $\color{red}{\partial \mathcal{L}} / \color{gray}{\partial z_2^{[1]}}$. Then we have:

$$
\frac{\color{red}{\partial \mathcal{L}}}{\color{magenta}{\partial w_2^{[1]}}} =
\frac{\color{gray}{\partial z_2^{[1]}}}{\color{magenta}{\partial w_2^{[1]}}}
\frac{\color{red}{\partial \mathcal{L}}}{\color{gray}{\partial z_2^{[1]}}}
=
a_1^{[0]}\frac{\color{red}{\partial \mathcal{L}}}{\color{gray}{\partial z_2^{[1]}}}
$$

We can also compute that $\frac{\color{gray}{\partial z_2^{[1]}}}{\color{magenta}{\partial b_2^{[1]}}} = 1$ (see example 1), so it follows that $\frac{\partial \mathcal{L}}{\partial b_2^{[1]}} = \frac{\partial \mathcal{L}}{\partial z_2^{[1]}}$.
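
The chain of terms for $w_2^{[1]}$ can be sketched in plain Python and checked against a finite-difference estimate. All parameter values and the data point are made up, sigmoid activations and log loss are assumed throughout, and the dictionary keys are hypothetical names: `"w1_2"` stands for $w_2^{[1]}$, `"w2_2"` for $w_2^{[2]}$, and so on.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def forward(p, x, y):
    """Forward pass for the 1 -> 2 -> 1 network; returns loss and activations."""
    a_h1 = sigmoid(p["w1_1"] * x + p["b1_1"])  # hidden node 1
    a_h2 = sigmoid(p["w1_2"] * x + p["b1_2"])  # hidden node 2
    yhat = sigmoid(p["w2_1"] * a_h1 + p["w2_2"] * a_h2 + p["b2_1"])
    loss = -y * math.log(yhat) - (1 - y) * math.log(1 - yhat)
    return loss, a_h1, a_h2, yhat


# Made-up parameter values and a single made-up data point
params = {"w1_1": 0.5, "w1_2": -0.3, "b1_1": 0.1, "b1_2": 0.2,
          "w2_1": 0.8, "w2_2": -0.6, "b2_1": 0.05}
x, y = 1.5, 1.0

loss, a_h1, a_h2, yhat = forward(params, x, y)

# Backward pass: the five terms in the chain, starting from the loss
dL_dyhat = -y / yhat + (1 - y) / (1 - yhat)  # log loss derivative
dyhat_dz_out = yhat * (1 - yhat)  # sigmoid derivative (output node)
dz_out_da_h2 = params["w2_2"]  # z_out is linear in a_h2
da_h2_dz_h2 = a_h2 * (1 - a_h2)  # sigmoid derivative (hidden node 2)
dz_h2_dw = x  # z_h2 is linear in the weight

dL_dz_h2 = da_h2_dz_h2 * dz_out_da_h2 * dyhat_dz_out * dL_dyhat
dL_dw = dz_h2_dw * dL_dz_h2
dL_db = 1.0 * dL_dz_h2  # the bias gradient reuses the same dL/dz term

# Finite-difference check on the weight gradient
eps = 1e-6
plus = dict(params, w1_2=params["w1_2"] + eps)
minus = dict(params, w1_2=params["w1_2"] - eps)
numeric = (forward(plus, x, y)[0] - forward(minus, x, y)[0]) / (2 * eps)

print(f"chain rule = {dL_dw:.6f}, numeric = {numeric:.6f}")
```

Note that `dL_dz_h2` is computed once and shared by the weight and bias gradients.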

### Multiple Paths

There is one case not covered by the simplified network and examples above: multiple paths from the output (loss) back to the term of interest, such as this:

![](../img/03_backprop_multipath.png)

In this case you must sum the contributions from all the possible paths (this also follows from the multi-variable chain rule).
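
A minimal sketch of the multiple-path case, using a hypothetical toy graph rather than the network in the figure: a value `a` feeds two downstream nodes, so its gradient is the sum of the contributions from both paths. The weights and the "loss" function here are arbitrary choices for illustration.

```python
import math

# Hypothetical toy graph: `a` feeds two downstream nodes, z1 and z2
u1, u2 = 0.8, -1.2  # made-up weights on the two outgoing edges


def loss_from_a(a):
    z1 = u1 * a
    z2 = u2 * a
    return z1**2 + math.sin(z2)  # arbitrary differentiable "loss"


a = 0.4
z1, z2 = u1 * a, u2 * a

path_1 = (2 * z1) * u1  # dL/dz1 * dz1/da
path_2 = math.cos(z2) * u2  # dL/dz2 * dz2/da
dL_da = path_1 + path_2  # sum over both paths

# Finite-difference check
eps = 1e-6
numeric = (loss_from_a(a + eps) - loss_from_a(a - eps)) / (2 * eps)

print(f"summed paths = {dL_da:.6f}, numeric = {numeric:.6f}")
```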

### Back Propagation and Efficiency

It's important to note that:

- The derivatives in the two examples share many terms in common (e.g. the derivative of the loss with respect to the final output)
- Each term is a fairly simple combination of quantities that must be computed during the forward pass (like the activation values in hidden layers)

These properties of back propagation form the basis for efficient implementations in major frameworks (PyTorch, TensorFlow, JAX etc.), mostly via:

- Matrix operations
- Computation graphs
- Caching intermediate values
- Automatic differentiation

## Computation Graphs and Auto-Differentiation

In the background, large frameworks like PyTorch use computation graphs and "auto-differentiation" to be able to compute gradients efficiently.

### Example

Here's an example of a simple logistic regression (one layer) network in the form of a _computation graph_:

<img src="../img/03_computation_graph.png" alt="Computation graph" width="500">

- Each node represents either an input variable/parameter (white/clear background), or an operation that applies to one or more of the previously defined values. In this graph the operations are summation ($+$), multiplication ($*$), sigmoid ($g$), and log loss ($\mathcal{L}$).
- The values of all the nodes on a row must be known before the next row can be calculated.

In the computation graph we'll store:

- The relationships between all the nodes (how they are connected and the operations that are performed on them)
- The value of each node for a given input
- The gradient of each node for a given input, with respect to the final node

Having all the node values and the relationships between the nodes lets us compute the gradient at each node efficiently:

#### Forward pass

When doing a forward (top to bottom) pass through the network we store the values computed at all nodes (i.e. including the intermediate values on each row):

In [42]:
# Forward pass

# =================
# Row 1 in diagram
# =================
x1 = 1.5
x2 = 2.5

w1 = -4
w2 = 3

# =================
# Row 2 in diagram
# =================
w1x1 = w1 * x1
w2x2 = w2 * x2

b = -1

# =================
# Row 3 in diagram
# =================
z = w1x1 + w2x2 + b

# =================
# Row 4 in diagram
# =================
yhat = 1 / (1 + np.exp(-z))  # sigmoid

y = 1

# =================
# Row 5 in diagram
# =================
L = -y * np.log(yhat) - (1 - y) * np.log(1 - yhat)  # log loss

print(f"yhat = {yhat:.4f}")
print(f"L = {L:.4f}")

yhat = 0.6225
L = 0.4741


#### Backward pass

Now we can use the node values from the forward pass, our knowledge of the computation graph, and the chain rule to compute the gradients:

In [43]:
# Backward pass

# =================
# Row 5 in diagram
# =================
dL_dyhat = -y / yhat + (1 - y) / (1 - yhat)  # derivative of log loss

# =================
# Row 4 in diagram
# =================
dyhat_dz = yhat * (1 - yhat)  # derivative of sigmoid

dL_dz = dyhat_dz * dL_dyhat

# =================
# Row 3 in diagram
# =================
dz_dw1x1 = 1  # summation nodes pass the same gradient backwards
dz_dw2x2 = 1  # z = w1x1 + w2x2 + b, dz/d(w2x2) = 1
dz_db = 1

dL_dw1x1 = dz_dw1x1 * dL_dz
dL_dw2x2 = dz_dw2x2 * dL_dz
dL_db = dz_db * dL_dz

# =================
# Row 2 in diagram
# =================
dw1x1_dw1 = x1  # multiplication node gradients take the value of the other node
dw2x2_dw2 = x2  # e.g. d(w2x2) / d(w2) = x2

dL_dw1 = dw1x1_dw1 * dL_dw1x1
dL_dw2 = dw2x2_dw2 * dL_dw2x2

print(f"dL_dw1 = {dL_dw1:.4f}")
print(f"dL_dw2 = {dL_dw2:.4f}")
print(f"dL_db = {dL_db:.4f}")

dL_dw1 = -0.5663
dL_dw2 = -0.9439
dL_db = -0.3775


This is an extremely verbose way of representing the calculation. In the next section we'll see matrix notation, which lets us represent networks in a much more concise way.
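
The bookkeeping in the manual backward pass can also be automated. The toy sketch below stores, for each node, its value, its parent nodes, and the local gradient along each edge, then walks the graph backwards applying the chain rule. The `Node` class and the helper functions are hypothetical names for illustration; this is not how any real framework is implemented, only an outline of the idea.

```python
import math


class Node:
    """A value in the graph, with links to the nodes it was computed from."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuple of (parent_node, local_gradient) pairs
        self.grad = 0.0  # dL/d(this node), filled in by backward()

    def __add__(self, other):
        # Summation nodes pass the gradient through unchanged
        return Node(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        # Multiplication nodes use the value of the other input
        return Node(self.value * other.value,
                    ((self, other.value), (other, self.value)))


def sigmoid(n):
    s = 1.0 / (1.0 + math.exp(-n.value))
    return Node(s, ((n, s * (1 - s)),))


def log_loss(yhat, y):
    value = -y * math.log(yhat.value) - (1 - y) * math.log(1 - yhat.value)
    local = -y / yhat.value + (1 - y) / (1 - yhat.value)
    return Node(value, ((yhat, local),))


def backward(output):
    """Propagate gradients from the output node back to every ancestor."""
    order, seen = [], set()

    def visit(node):  # depth-first search, parents are appended before children
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(output)
    output.grad = 1.0
    for node in reversed(order):  # output first, inputs last
        for parent, local_grad in node.parents:
            parent.grad += local_grad * node.grad  # chain rule, summed over paths


# Same values as the manual example above
x1, x2 = Node(1.5), Node(2.5)
w1, w2, b = Node(-4.0), Node(3.0), Node(-1.0)

L = log_loss(sigmoid(w1 * x1 + w2 * x2 + b), y=1)
backward(L)

print(f"dL_dw1 = {w1.grad:.4f}")
print(f"dL_dw2 = {w2.grad:.4f}")
print(f"dL_db = {b.grad:.4f}")
```

With the same inputs as the manual version, the printed gradients should match the values computed by hand above.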

### PyTorch

We can do the same operations in PyTorch using its tensor class, being careful to:

- Set `requires_grad=True` for the parameters whose gradients we want (`w1`, `w2`, and `b`) - we'll come back to this later
- Use torch's implementations of sigmoid (`torch.sigmoid`) and log loss (`torch.nn.functional.binary_cross_entropy`).

Here's the forward pass again:

In [52]:
# Forward pass

# =================
# Row 1 in diagram
# =================
x1 = torch.tensor([1.5], requires_grad=False)
x2 = torch.tensor([2.5], requires_grad=False)

w1 = torch.tensor([-4.0], requires_grad=True)
w2 = torch.tensor([3.0], requires_grad=True)

# =================
# Row 2 in diagram
# =================
w1x1 = w1 * x1
w2x2 = w2 * x2

b = torch.tensor([-1.0], requires_grad=True)

# =================
# Row 3 in diagram
# =================
z = w1x1 + w2x2 + b

# =================
# Row 4 in diagram
# =================
yhat = torch.sigmoid(z)

y = torch.tensor([1.0], requires_grad=False)

# =================
# Row 5 in diagram
# =================
L = nn.functional.binary_cross_entropy(yhat, y)

print(f"yhat = {yhat}")
print(f"L = {L:.4f}")

yhat = tensor([0.6225], grad_fn=<SigmoidBackward0>)
L = 0.4741


Note the values are the same as our own version above.

Now for the magic - in the background PyTorch has built a computation graph from the variables we've defined and can compute the gradients (do a backward pass) for us automatically (for the parameters where we've specified `requires_grad=True`), as follows:

In [51]:
L.backward()

print(f"dL_dw1 = {w1.grad}")
print(f"dL_dw2 = {w2.grad}")
print(f"dL_db = {b.grad}")

dL_dw1 = tensor([-0.5663])
dL_dw2 = tensor([-0.9439])
dL_db = tensor([-0.3775])


Again, these match the gradients in our own version, but with a lot fewer lines of code (for us) 🎉
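
Once the gradients are available they can be used for the gradient descent update. Here is a minimal sketch of a few manual update steps on the same single data point, assuming a learning rate of 0.1 and storing the two weights in one tensor `w` (both are illustrative choices). In practice an optimiser from `torch.optim` would normally handle the update.

```python
import torch
from torch import nn

x = torch.tensor([1.5, 2.5])
y = torch.tensor([1.0])

w = torch.tensor([-4.0, 3.0], requires_grad=True)
b = torch.tensor([-1.0], requires_grad=True)

alpha = 0.1  # learning rate (arbitrary)

losses = []
for step in range(3):
    yhat = torch.sigmoid((w * x).sum() + b)  # forward pass
    loss = nn.functional.binary_cross_entropy(yhat, y)
    loss.backward()  # backward pass fills w.grad and b.grad

    with torch.no_grad():  # the update itself should not be tracked in the graph
        w -= alpha * w.grad
        b -= alpha * b.grad

    w.grad.zero_()  # gradients accumulate across backward() calls, so reset them
    b.grad.zero_()
    losses.append(loss.item())

print(losses)
```

The gradients are reset after each step because PyTorch adds to `.grad` on every call to `.backward()` rather than overwriting it.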

You might remember seeing `.backward()` before in the linear/logistic regression notebooks - hopefully this gives you the intuition for what it's doing! 