# Maths Details for Batch Norm Backprop

This [post](http://cthorey.github.io./backpropagation/) walks through some of the maths behind **backpropagating through batch norm**.

Here I reproduce that derivation, writing out the intermediate steps as I worked through the article.

Notation:

* Input minibatch: $X$, size $N \times D$
* Hidden layer: $H$, with weights $W$ of size $D \times L$ and bias $b$ of size $L$

### Affine Transform:

$$ h = XW + b, \quad \text{size: } (N \times D)(D \times L) + L \;\to\; N \times L $$

Here $b$ is added to every row of $XW$, so the result keeps the size $N \times L$.
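To make the shapes concrete, here is a minimal NumPy sketch of the affine step. The sizes `N, D, L = 4, 3, 2` and the random values are assumptions chosen purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, L = 4, 3, 2            # illustrative sizes only

X = rng.normal(size=(N, D))  # input minibatch
W = rng.normal(size=(D, L))  # weights
b = rng.normal(size=L)       # bias, broadcast across the N rows

h = X @ W + b
print(h.shape)               # (4, 2)
```

NumPy broadcasting takes care of adding the length-$L$ bias to each of the $N$ rows.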

### Batch Norm Transform:

$$
\begin{aligned}
\hat{h} &= \frac{h - \mu}{\sqrt{\sigma^2 + \delta}} &\text{size } N \times L \\
y &= \gamma \hat{h} + \beta &\text{size } N \times L
\end{aligned}
$$

where $\gamma$ and $\beta$ are learnt parameters, and $\delta$ is a small constant added for numerical stability. $\mu$, $\sigma^2$, $\gamma$ and $\beta$ are all vectors of size $L$.

For matrix $\hat{h}$, each element is calculated as:

$$
\begin{aligned}
k &\text{ - row index} \\
l &\text{ - column index} \\
\hat{h}_{kl} &= \frac{h_{kl} - \mu_l}{\sqrt{\sigma^2_l + \delta}}\\
y_{kl} &= \gamma_l \hat{h}_{kl} + \beta_l \\
\mu_l &= \frac{1}{N}\sum_p h_{pl} \\
\sigma^2_l &= \frac{1}{N} \sum_p \big( h_{pl} - \mu_l \big)^2
\end{aligned}
$$

This means that the normalization is done column-wise: each of the $L$ columns of $h$ is normalized using the mean and variance taken over the $N$ rows of the minibatch.
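The element-wise definitions above can be sketched in NumPy as follows. The function name `batchnorm_forward` and the default `delta=1e-5` are my own choices for illustration, not something taken from the post.

```python
import numpy as np

def batchnorm_forward(h, gamma, beta, delta=1e-5):
    """Normalize each column of h (shape N x L), then scale and shift."""
    mu = h.mean(axis=0)                      # shape (L,)
    var = h.var(axis=0)                      # biased variance (1/N), shape (L,)
    h_hat = (h - mu) / np.sqrt(var + delta)  # shape (N, L)
    y = gamma * h_hat + beta                 # shape (N, L)
    return y, h_hat, mu, var

rng = np.random.default_rng(0)
h = rng.normal(loc=3.0, scale=2.0, size=(8, 3))
y, h_hat, mu, var = batchnorm_forward(h, gamma=np.ones(3), beta=np.zeros(3))
print(h_hat.mean(axis=0))  # close to 0 for every column
print(h_hat.var(axis=0))   # close to 1 for every column
```

Note that `axis=0` reduces over the rows, which is exactly the sum over $p$ in the formulas for $\mu_l$ and $\sigma^2_l$.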

### Activation

$$ a = \mathrm{ReLU}(y) $$

### Backprop

Given a loss function $\mathcal{L}$, we need to find:

$$ \frac{\partial \mathcal{L}}{\partial \gamma}, \frac{\partial\mathcal{L}}{\partial \beta}, \frac{\partial \mathcal{L}}{\partial h} $$

With chain rule:

$$ \frac{\partial\mathcal{L}}{\partial h_{ij}} = \sum_{k,l} \frac{\partial\mathcal{L}}{\partial y_{kl}} \frac{\partial y_{kl}}{\partial \hat{h}_{kl}} \frac{\partial \hat{h}_{kl}}{\partial h_{ij}} $$

To derive this:

$$
\begin{aligned}
\frac{\partial y_{kl}}{\partial \hat{h}_{kl}} & = \gamma_l 
\end{aligned}
$$
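The same relation $y_{kl} = \gamma_l \hat{h}_{kl} + \beta_l$ also gives the two parameter gradients directly: $\frac{\partial\mathcal{L}}{\partial\gamma_l} = \sum_k \frac{\partial\mathcal{L}}{\partial y_{kl}} \hat{h}_{kl}$ and $\frac{\partial\mathcal{L}}{\partial\beta_l} = \sum_k \frac{\partial\mathcal{L}}{\partial y_{kl}}$. The post does not spell these out, so treat the sketch below as my own sanity check. The toy loss $\mathcal{L} = \sum_{k,l} G_{kl}\, y_{kl}$ with a random matrix `G` is an assumption made only so that $\partial\mathcal{L}/\partial y = G$ is known exactly.

```python
import numpy as np

def param_grads(dy, h_hat):
    """Gradients of the loss w.r.t. gamma and beta, given dL/dy."""
    dgamma = (dy * h_hat).sum(axis=0)  # sum over the batch index k
    dbeta = dy.sum(axis=0)
    return dgamma, dbeta

rng = np.random.default_rng(0)
N, L = 5, 3
h_hat = rng.normal(size=(N, L))
gamma, beta = rng.normal(size=L), rng.normal(size=L)
G = rng.normal(size=(N, L))            # plays the role of dL/dy

def loss(gamma, beta):
    # Toy loss (assumption): L = sum(G * y)
    return np.sum(G * (gamma * h_hat + beta))

dgamma, dbeta = param_grads(G, h_hat)

# Central finite differences as an independent check
eps = 1e-6
num_dgamma, num_dbeta = np.zeros(L), np.zeros(L)
for l in range(L):
    step = np.zeros(L)
    step[l] = eps
    num_dgamma[l] = (loss(gamma + step, beta) - loss(gamma - step, beta)) / (2 * eps)
    num_dbeta[l] = (loss(gamma, beta + step) - loss(gamma, beta - step)) / (2 * eps)

print(np.allclose(dgamma, num_dgamma, atol=1e-6))  # True if the formulas agree
print(np.allclose(dbeta, num_dbeta, atol=1e-6))
```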

**The last term is more involved. Writing $\hat{h}_{kl} = u \cdot v$, we can differentiate it in two parts using the [product rule](https://en.wikipedia.org/wiki/Product_rule).**

Let's start by defining: 

$$
\begin{aligned}
u &= h_{kl} - \mu_l \\
v &= (\sigma^2_l + \delta)^{-1/2} \\
\frac{\partial u}{\partial h_{ij}} &= \frac{\partial(h_{kl} - \frac{1}{N}\sum_p h_{pl})}{\partial h_{ij}}
\end{aligned}
$$

where:

$$
u = h - \mu = 
 \begin{pmatrix}
  h_{1,1} - \frac{1}{N}\sum_p h_{p,1} & h_{1,2} -\frac{1}{N}\sum_p h_{p,2} & \cdots & h_{1,L} -\frac{1}{N}\sum_p h_{p,L} \\
  \vdots  & \vdots  & \ddots & \vdots  \\
  h_{N,1} - \frac{1}{N}\sum_p h_{p,1} & h_{N,2} -\frac{1}{N}\sum_p h_{p,2} & \cdots & h_{N,L} -\frac{1}{N}\sum_p h_{p,L} 
 \end{pmatrix}
$$

Therefore:

$$ 
\begin{aligned}
u_{1,1} &= h_{1,1} - \frac{1}{N}\big(h_{1,1} + h_{2, 1} + \cdots + h_{N, 1}\big) \\
\frac{\partial u_{1,1}}{\partial h_{1,1}} &= 1 - \frac{1}{N}(1 + 0 + 0 + \cdots + 0)
\end{aligned}
$$

$$
\ \frac{\partial u}{\partial h_{1,1}} = 
 \begin{pmatrix}
  1 - \frac{1}{N} & 0 & \cdots & 0 \\
  -\frac{1}{N} & 0 & \cdots & 0 \\
  \vdots  & \vdots  & \ddots & \vdots  \\
  -\frac{1}{N} & 0 & \cdots & 0 
 \end{pmatrix}
$$

$$
\ \frac{\partial u}{\partial h_{1,2}} = 
 \begin{pmatrix}
  0 & 1 - \frac{1}{N} & 0 & \cdots & 0 \\
  0 & -\frac{1}{N} & 0 & \cdots & 0 \\
  \vdots  & \vdots  & \vdots & \ddots & \vdots  \\
  0 & -\frac{1}{N} & 0 & \cdots & 0 
 \end{pmatrix}
$$

To generalize into a more convenient notation, we try to solve for: 

$$\frac{\partial u_{k,l}}{\partial h_{i,j}}$$

We know that: 

$$
\begin{aligned}
\frac{\partial u_{1,1}}{\partial h_{1,1}} &= 1 - \frac{1}{N} & k=1, l=1, i=1, j=1\\
\frac{\partial u_{1,2}}{\partial h_{1,1}} &= 0 & k=1, l=2, i=1, j=1\\
\frac{\partial u_{2,1}}{\partial h_{1,1}} &= -\frac{1}{N} & k=2, l=1, i=1, j=1\\
\frac{\partial u_{2,1}}{\partial h_{1,2}} &= 0 & k=2, l=1, i=1, j=2\\
\frac{\partial u_{1,2}}{\partial h_{1,2}} &= 1 - \frac{1}{N} & k=1, l=2, i=1, j=2\\
\end{aligned}
$$

We have:

$$
\begin{aligned}
\frac{\partial u_{k,l}}{\partial h_{i,j}} =
    \begin{cases}
    1 - \frac{1}{N} & \text{if } k = i,\ l = j \\
    - \frac{1}{N} & \text{if } k \neq i,\ l = j \\
    0 & \text{if } k = i,\ l \neq j \\
    0 & \text{if } k \neq i,\ l \neq j \\
    \end{cases}
\end{aligned}
$$

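A more convenient way to write this is to use the Kronecker delta (written $\Delta$ here, since $\delta$ is already taken by the stability constant):

$$
\begin{aligned}
\Delta_{i,j} = 
\begin{cases}
1 & \text{if } i = j \\
0 & \text{if } i \neq j
\end{cases}
\end{aligned}
$$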

$$ \frac{\partial u_{k,l}}{\partial h_{i,j}} = \Delta_{i,k}\Delta_{j,l} - \frac{1}{N}\Delta_{j,l} $$
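As a quick check of this formula, the sketch below builds the full four-index array `J[k, l, i, j]` from identity matrices and compares it against a direct computation. The sizes are arbitrary, and the comparison relies on the fact that $u = h - \mu$ is linear in $h$, so applying it to a one-hot matrix returns the exact derivative.

```python
import numpy as np

N, L = 4, 3
d_rows = np.eye(N)[:, None, :, None]   # Delta_{i,k}, indexed [k, -, i, -]
d_cols = np.eye(L)[None, :, None, :]   # Delta_{j,l}, indexed [-, l, -, j]

# J[k, l, i, j] = Delta_{i,k} Delta_{j,l} - (1/N) Delta_{j,l}
J = d_rows * d_cols - d_cols / N

def u(h):
    return h - h.mean(axis=0)

J_direct = np.zeros((N, L, N, L))
for i in range(N):
    for j in range(L):
        E = np.zeros((N, L))
        E[i, j] = 1.0
        J_direct[:, :, i, j] = u(E)    # u is linear, so u(E) is the exact derivative

print(np.allclose(J, J_direct))        # True if the closed form is right
print(J[:, :, 0, 0])                   # the du/dh_{1,1} matrix written out earlier
```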

Recall the product rule: $(f \cdot g)' = f' \cdot g + f \cdot g'$. We already have $u'$; for $v'$, the chain rule gives:

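$$
\begin{aligned}
\frac{\partial v}{\partial h_{i,j}} &= -\frac{1}{2}(\sigma^2_l + \delta)^{-3/2} \cdot \frac{\partial \sigma^2_l}{\partial h_{i,j}} \\
\frac{\partial \sigma^2_l}{\partial h_{i,j}} &= \frac{2}{N} \sum_p (h_{p,l} - \mu_l)\Big(\Delta_{i,p}\Delta_{j,l} - \frac{1}{N}\Delta_{j,l}\Big) \\
&= \frac{2}{N}(h_{i,l} - \mu_l)\Delta_{j,l} - \frac{2}{N^2}\Delta_{j,l}\sum_p (h_{p,l} - \mu_l) \\
&= \frac{2}{N}(h_{i,l} - \mu_l)\Delta_{j,l}
\end{aligned}
$$

The second term vanishes because $\sum_p (h_{p,l} - \mu_l) = 0$ by the definition of $\mu_l$.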

**This is better understood when we write out some concrete examples.**

$$
\begin{aligned}
\sigma^2_l &= \frac{1}{N} \sum_p \big(h_{pl} - \mu_l \big)^2 \\
\sigma^2_l &= \frac{1}{N} \sum_p u_{p,l}^2 \\
\frac{\partial \sigma^2_l}{\partial h_{i,j}} &= \frac{2}{N} \sum_p u_{p,l} \cdot \frac{\partial u_{p,l}}{\partial h_{i,j}} \\
\sigma^2_1 &= \frac{1}{N} \big[ (h_{1,1}-\mu_1)^2 + (h_{2,1}-\mu_1)^2 + \cdots + (h_{N,1}-\mu_1)^2 \big] \\
\frac{\partial \sigma^2_1}{\partial h_{1,1}} &= \frac{2}{N} \big[ (h_{1,1} - \mu_1)(1 - 1/N) + (h_{2,1} - \mu_1)(0 - 1/N) + \cdots + (h_{N,1} - \mu_1)(0 - 1/N) \big] \\
&= \frac{2}{N^2} \big[ (h_{1,1} - \mu_1)(N - 1) + (h_{2,1} - \mu_1)(-1) + \cdots + (h_{N,1} - \mu_1)(-1) \big]\\
&= \frac{2}{N^2} \big[ (h_{1,1} - \mu_1)(N - 1) + (\mu_1 - h_{2,1}) + \cdots + (\mu_1- h_{N,1}) \big]\\
&= \frac{2}{N^2} \big[ h_{1,1}N - h_{1,1} - N\mu_1 + \mu_1 + (N-1)\mu_1 - h_{2,1} - \cdots - h_{N,1} \big]\\
&= \frac{2}{N^2} \big[ h_{1,1}N - N\mu_1\big]\\
&= \frac{2}{N}(h_{1,1}-\mu_1)
\end{aligned}
$$

$$
\begin{aligned}
\sigma^2_2 &= \frac{1}{N} \big[ (h_{1,2}-\mu_2)^2 + (h_{2,2}-\mu_2)^2 + \cdots + (h_{N,2}-\mu_2)^2 \big] \\
\frac{\partial \sigma^2_2}{\partial h_{1,1}} &= \frac{1}{N} \big[ 0 + 0 + \cdots + 0 \big] \\
&= 0
\end{aligned}
$$

$$
\begin{aligned}
\sigma^2_2 &= \frac{1}{N} \big[ (h_{1,2}-\mu_2)^2 + (h_{2,2}-\mu_2)^2 + \cdots + (h_{N,2}-\mu_2)^2 \big] \\
\frac{\partial \sigma^2_2}{\partial h_{2,2}} &= \frac{2}{N} \big[ (h_{1,2}-\mu_2)(0-1/N) + (h_{2,2}-\mu_2)(1-1/N) + \cdots + (h_{N,2}-\mu_2)(0-1/N) \big] \\
&= \frac{2}{N^2}\big[(\mu_2 - h_{1,2}) + (h_{2,2}N - h_{2,2}-N\mu_2 + \mu_2) + (\mu_2 - h_{3,2}) + \cdots + (\mu_2 - h_{N,2}) \big]\\
&= \frac{2}{N^2}\big[N\mu_2 - N\mu_2 + h_{2,2}N -N\mu_2 \big]\\
&= \frac{2}{N}(h_{2,2}-\mu_2)
\end{aligned}
$$
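The worked examples can also be checked numerically. The sketch below compares $\frac{\partial \sigma^2_l}{\partial h_{i,j}} = \frac{2}{N}(h_{i,l} - \mu_l)\Delta_{j,l}$ against central finite differences on a random matrix. The sizes, the seed and the step `eps` are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
N, L = 6, 3
h = rng.normal(size=(N, L))
mu = h.mean(axis=0)

# analytic[l, i, j] = (2/N) (h_{i,l} - mu_l) Delta_{j,l}
analytic = (2.0 / N) * (h - mu).T[:, :, None] * np.eye(L)[:, None, :]

# numeric[l, i, j] = d sigma^2_l / d h_{i,j} by central differences
eps = 1e-6
numeric = np.zeros((L, N, L))
for i in range(N):
    for j in range(L):
        step = np.zeros((N, L))
        step[i, j] = eps
        numeric[:, i, j] = ((h + step).var(axis=0) - (h - step).var(axis=0)) / (2 * eps)

print(np.allclose(analytic, numeric, atol=1e-6))  # True if the formula holds
print(analytic[1, 0, 0])                          # d sigma^2_2 / d h_{1,1}: zero, as in the second example
```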

Finally, we have:

$$
\begin{aligned}
\frac{\partial \hat{h}_{k,l}}{\partial h_{i,j}} &= u' \cdot v + u \cdot v' \\
&= (\Delta_{i,k}\Delta_{j,l} - \frac{1}{N}\Delta_{j,l})(\sigma^2_l + \delta)^{-1/2} -\frac{1}{N}(h_{kl} - \mu_l)(\sigma^2_l + \delta)^{-3/2}(h_{i,l} - \mu_l)\Delta_{j,l}
\end{aligned}
$$
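To convince myself the closed form is right, here is a sketch that assembles the full Jacobian `J[k, l, i, j]` from the expression above and compares it with central finite differences. The function names, sizes and `delta=1e-5` are assumptions for illustration.

```python
import numpy as np

def h_hat_fn(h, delta=1e-5):
    return (h - h.mean(axis=0)) / np.sqrt(h.var(axis=0) + delta)

def h_hat_jacobian(h, delta=1e-5):
    """J[k, l, i, j] = d h_hat[k, l] / d h[i, j], from the closed form."""
    N, L = h.shape
    u = h - h.mean(axis=0)                     # (N, L)
    s = h.var(axis=0) + delta                  # (L,), sigma^2 + delta
    d_rows = np.eye(N)[:, None, :, None]       # Delta_{i,k}
    d_cols = np.eye(L)[None, :, None, :]       # Delta_{j,l}
    # u' * v
    term1 = (d_rows * d_cols - d_cols / N) * (s ** -0.5)[None, :, None, None]
    # u * v' (without the sign): (1/N) u_{k,l} u_{i,l} s_l^{-3/2} Delta_{j,l}
    uu = np.einsum("kl,il->kli", u, u)         # (N, L, N)
    term2 = uu[:, :, :, None] * (s ** -1.5)[None, :, None, None] * d_cols / N
    return term1 - term2

rng = np.random.default_rng(0)
h = rng.normal(size=(5, 3))
J = h_hat_jacobian(h)

eps = 1e-6
J_num = np.zeros_like(J)
for i in range(h.shape[0]):
    for j in range(h.shape[1]):
        step = np.zeros_like(h)
        step[i, j] = eps
        J_num[:, :, i, j] = (h_hat_fn(h + step) - h_hat_fn(h - step)) / (2 * eps)

print(np.allclose(J, J_num, atol=1e-6))  # True if the closed form matches
```

Building the full $N \times L \times N \times L$ array is only practical for tiny inputs; it is here as a check, not as something to use in training.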

Now we have all the elements needed to evaluate $\frac{\partial\mathcal{L}}{\partial h_{ij}}$ with the chain rule above.
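Substituting the pieces into the chain rule and carrying out the sums over $k$ and $l$, I get the collapsed form below. This simplification is my own working rather than something quoted from the post, so treat it as a sketch. The shorthand $g_{kl}$ is a name I introduce here:

$$ \frac{\partial\mathcal{L}}{\partial h_{ij}} = \frac{1}{\sqrt{\sigma^2_j + \delta}}\Big(g_{ij} - \frac{1}{N}\sum_k g_{kj} - \frac{1}{N}\hat{h}_{ij}\sum_k g_{kj}\hat{h}_{kj}\Big), \quad g_{kl} = \gamma_l \frac{\partial\mathcal{L}}{\partial y_{kl}} $$

The code checks it against finite differences, using a toy loss $\mathcal{L} = \sum_{k,l} G_{kl}\, y_{kl}$ (an assumption, chosen so that $\partial\mathcal{L}/\partial y = G$).

```python
import numpy as np

def forward(h, gamma, beta, delta=1e-5):
    h_hat = (h - h.mean(axis=0)) / np.sqrt(h.var(axis=0) + delta)
    return gamma * h_hat + beta

def batchnorm_backward_h(dy, h, gamma, delta=1e-5):
    """dL/dh from dL/dy, using the collapsed form of the chain rule."""
    inv_std = 1.0 / np.sqrt(h.var(axis=0) + delta)
    h_hat = (h - h.mean(axis=0)) * inv_std
    g = dy * gamma                             # g_{kl} = gamma_l * dL/dy_{kl}
    return inv_std * (g - g.mean(axis=0) - h_hat * (g * h_hat).mean(axis=0))

rng = np.random.default_rng(0)
N, L = 5, 3
h = rng.normal(size=(N, L))
gamma, beta = rng.normal(size=L), rng.normal(size=L)
G = rng.normal(size=(N, L))                    # stand-in for dL/dy

dh = batchnorm_backward_h(G, h, gamma)

eps = 1e-6
dh_num = np.zeros_like(h)
for i in range(N):
    for j in range(L):
        step = np.zeros_like(h)
        step[i, j] = eps
        loss_plus = np.sum(G * forward(h + step, gamma, beta))
        loss_minus = np.sum(G * forward(h - step, gamma, beta))
        dh_num[i, j] = (loss_plus - loss_minus) / (2 * eps)

print(np.allclose(dh, dh_num, atol=1e-6))      # True if the collapsed form is right
```

The collapsed form only needs column means, so it avoids ever materializing the four-index Jacobian.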