# Deep dive into backpropagation

In [1]:
import torch

In [None]:
# Seed the global RNG so that the randomly initialised parameters (and every
# later torch.randn draw) are identical after "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# A single linear neuron: y = weight * x + bias.
model = torch.nn.Linear(
    in_features=1,
    out_features=1,
    bias=True
)
model.state_dict() # show the initial weight and bias of the model

In [None]:
# One random input sample of shape (1,), i.e. (in_features,) for the layer above.
x = torch.randn((1,))
x

In [None]:
# Parameters of the target line y* = w * x + b. These are plain floats used to
# build the label; they are not the model's own weight and bias.
w, b = 3.0, 2.0

y_true = x * w + b
y_true

In [5]:
# Mean squared error; "mean" is the default reduction, spelled out for clarity.
loss_fn = torch.nn.MSELoss(reduction="mean")

### Forward pass

In [None]:
# Forward pass through the linear layer: y_pred = weight * x + bias
y_pred = model(x)
y_pred

In [None]:
# Scalar MSE loss between the prediction and the target
error = loss_fn(y_pred, y_true)
error

In [None]:
# Prediction and target each hold a single element, so the "mean" in MSE is a
# no-op and the plain squared error reproduces the loss value computed above.
(y_pred - y_true).pow(2)

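For a single sample, the loss is the squared error between the prediction and the target $y^*$:

$E = \left( f(w \cdot x + b) - y^* \right)^2$

where $f$ is the activation function (the plain linear layer used here has none, i.e. $f$ is the identity). It can be decomposed as: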

$$
\begin{align}
E &= \left( y - y^* \right)^2 \\
y &= f(\sigma) \\
\sigma &= w \cdot x + b \\
\end{align}
$$

### Backward pass

In [9]:
# Backpropagate the loss: autograd fills .grad on the model parameters.
# NOTE: gradients accumulate across calls and the graph is freed by default,
# so re-run the forward-pass cells before executing this cell again.
error.backward()

In [None]:
# Get gradients
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"Gradient for {name}: {param.grad}")

In [None]:
# No optimizer step is taken in the cells above, so backward() only populated
# .grad; the weight and bias shown here are the same as at initialisation.
model.state_dict()

### Forward computation of $\frac{\partial E}{\partial w}$

#### General case

Using the chain rule:

$$
\frac{\partial E}{\partial w} =
\frac{\partial E}{\partial \color{red}{y}}
\frac{\partial \color{red}{y}}{\partial \color{green}{\sigma}} ~
\frac{\partial \color{green}{\sigma}}{\partial w} ~
$$

knowing that:

$$
\begin{align}
\frac{\partial E}{\partial \color{red}{y}}                     &= 2 (y - y^*) \\
\frac{\partial \color{red}{y}}{\partial \color{green}{\sigma}} &= f'(\sigma) \\
\frac{\partial \color{green}{\sigma}}{\partial w}              &= x \\
\end{align}
$$

we can write:

$$
\frac{\partial E}{\partial w} = 2(y - y^*) \cdot f'(\sigma) \cdot x
$$

#### Current case (no activation function)

Using the chain rule:

$$
\frac{\partial E}{\partial w} =
\frac{\partial E}{\partial \color{green}{\sigma}} ~
\frac{\partial \color{green}{\sigma}}{\partial w} ~
$$

knowing that:

$$
\begin{align}
\frac{\partial E}{\partial \color{green}{\sigma}} &= 2 (\sigma - y^*) \\
\frac{\partial \color{green}{\sigma}}{\partial w} &= x \\
\end{align}
$$

we can write:

$$
\frac{\partial E}{\partial w} = 2(\sigma - y^*) \cdot x
$$

#### Naive detailed computation

Let's write the forward computation in a (naive) detailed way.

In [None]:
# This model has no activation function, so the pre-activation sigma is the
# model output itself. Detach it: this by-hand check should not extend the
# autograd graph.
sigma = y_pred.detach()

# dE/dw = 2 * (sigma - y*) * x
grad_w = 2 * (sigma - y_true) * x

# Cross-check against autograd (requires the backward-pass cell to have run once).
# model.weight.grad has shape (out_features, in_features) = (1, 1); flatten it.
assert torch.allclose(grad_w, model.weight.grad.view(-1))
grad_w

#### Algebraic computation

Let's rewrite the forward computation in a less naive way (using linear algebra).

...

### Forward computation of $\frac{\partial E}{\partial b}$

#### General case

Using the chain rule:

$$
\frac{\partial E}{\partial b} =
\frac{\partial E}{\partial \color{red}{y}}
\frac{\partial \color{red}{y}}{\partial \color{green}{\sigma}} ~
\frac{\partial \color{green}{\sigma}}{\partial b} ~
$$

knowing that:

$$
\begin{align}
\frac{\partial E}{\partial \color{red}{y}}                     &= 2 (y - y^*) \\
\frac{\partial \color{red}{y}}{\partial \color{green}{\sigma}} &= f'(\sigma) \\
\frac{\partial \color{green}{\sigma}}{\partial b}              &= 1 \\
\end{align}
$$

we can write:

$$
\frac{\partial E}{\partial b} = 2(y - y^*) \cdot f'(\sigma)
$$

#### Current case (no activation function)

Using the chain rule:

$$
\frac{\partial E}{\partial b} =
\frac{\partial E}{\partial \color{green}{\sigma}} ~
\frac{\partial \color{green}{\sigma}}{\partial b} ~
$$

knowing that:

$$
\begin{align}
\frac{\partial E}{\partial \color{green}{\sigma}} &= 2 (\sigma - y^*) \\
\frac{\partial \color{green}{\sigma}}{\partial b} &= 1 \\
\end{align}
$$

we can write:

$$
\frac{\partial E}{\partial b} = 2(\sigma - y^*)
$$

#### Naive detailed computation

Let's write the forward computation in a (naive) detailed way.

In [None]:
# This model has no activation function, so the pre-activation sigma is the
# model output itself. Detach it: this by-hand check should not extend the
# autograd graph.
sigma = y_pred.detach()

# dE/db = 2 * (sigma - y*) * 1
grad_b = 2 * (sigma - y_true)

# Cross-check against autograd (requires the backward-pass cell to have run once).
assert torch.allclose(grad_b, model.bias.grad)
grad_b

#### Algebraic computation

Let's rewrite the forward computation in a less naive way (using linear algebra).

...

In [13]:
# NOTE(review): leftover scratch code, kept commented out. It unrolls a
# recurrent layer by hand (weight_ih_l0 / weight_hh_l0, hidden states h0..h3);
# none of those names are defined in the cells shown here, where the model is a
# torch.nn.Linear. Candidate for removal or for a move to a separate notebook.
# f = torch.nn.functional.tanh

# h1 = f(x[0] @ model.weight_ih_l0 + h0 @ model.weight_hh_l0)   # hidden state at time step 1
# h2 = f(x[1] @ model.weight_ih_l0 + h1 @ model.weight_hh_l0)   # hidden state at time step 2
# h3 = f(x[2] @ model.weight_ih_l0 + h2 @ model.weight_hh_l0)   # hidden state at time step 3

# print(f"Output for time step 1:\nh1 = \n{ h1 }\n\n")
# print(f"Output for time step 2:\nh2 = \n{ h2 }\n\n")
# print(f"Output for time step 3:\nh3 = \n{ h3 }\n\n")