# Batch Normalization
Batch normalization was introduced and popularized by the paper *Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift*. In a deep neural network, the activations passed between layers depend heavily on parameter initialization, which in turn affects how gradients are backpropagated into each layer during training. Poor initialization can hurt both how well a network trains and how fast it trains. Batch normalization is a powerful technique for decoupling the weight updates from parameter initialization. In the paper's words, *batch normalization allows us to use much higher learning rates and be less careful about initialization.*

Much of the derivation below follows the paper itself and Kevin Zakka's blog on GitHub.

## Notations
* **BN** stands for batch normalization
* $x$ is the input matrix/vector to the **BN** layer
* $\mu$ is the batch mean
* $\sigma^{2}$ is the batch variance
* $\epsilon$ is a small constant added to avoid dividing by zero
* $\hat{x}$ is the normalized input matrix/vector
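* $y$ is the output of the linear transformation that scales $\hat{x}$ by $\gamma$ and shifts it by $\beta$ (both learned parameters)
* $f$ represents the layer that follows the **BN** layer in the forward pass
* $m$ is the number of training examples in a mini-batch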

## Forward Pass
The forward pass is straightforward, both intuitively and mathematically.

First, compute the mean across a mini-batch of $m$ training examples:

$$
\mu = \frac{1}{m} \sum^{m}_{i = 1} x_{i}
$$

Next, compute the variance across the same mini-batch:

$$
\sigma^{2} = \frac{1}{m} \sum^{m}_{i = 1} (x_{i} - \mu)^{2}
$$

Then normalize each input:

$$
\hat{x}_{i} = \frac{x_{i} - \mu}{\sqrt{\sigma^{2} + \epsilon}}
$$
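As a quick numerical check of the three steps above, here is a minimal sketch on a hypothetical mini-batch of $m = 4$ scalar inputs (the values are arbitrary and chosen only for illustration). Note that the variance divides by $m$, not $m - 1$, which matches NumPy's default `np.var`:

```python
import numpy as np

# Hypothetical mini-batch of m = 4 scalar activations
x = np.array([1.0, 2.0, 3.0, 4.0])
eps = 1e-5

mu = x.sum() / len(x)                  # batch mean: 2.5
var = ((x - mu) ** 2).sum() / len(x)   # batch variance (divide by m): 1.25
x_hat = (x - mu) / np.sqrt(var + eps)  # approximately [-1.342, -0.447, 0.447, 1.342]

print(mu, var)
print(x_hat)
```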

Finally, apply a linear transformation with learned parameters $\gamma$ and $\beta$ so that the network can still represent the identity transform. The paper explains why this is needed:

> Note that simply normalizing each input of a layer may change what the layer can represent. For instance, normalizing the inputs of a sigmoid would constrain them to the linear regime of the nonlinearity. To address this, we make sure that the transformation inserted in the network can represent the identity transform.

$$
y_{i} = \gamma \hat{x}_{i} + \beta = BN_{\gamma, \beta}(x_{i})
$$

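If $\gamma = \sqrt{\sigma^{2} + \epsilon}$ and $\beta = \mu$, the linear transformation exactly undoes the normalization and $y_{i} = x_{i}$, i.e. the layer computes the identity transform. With $\gamma = 1$ and $\beta = 0$, the layer simply outputs the normalized input $\hat{x}_{i}$.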
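To see the identity case concretely, the sketch below uses an arbitrary random mini-batch, sets $\gamma = \sqrt{\sigma^{2} + \epsilon}$ and $\beta = \mu$ per feature, and checks that $y$ reproduces $x$, whereas $\gamma = 1$, $\beta = 0$ does not:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 4))  # arbitrary (N, D) mini-batch
eps = 1e-5

mean = x.mean(axis=0)
var = x.var(axis=0)
x_hat = (x - mean) / np.sqrt(var + eps)

# gamma = 1, beta = 0: the output is just the normalized input
y_plain = 1.0 * x_hat + 0.0

# gamma = sqrt(var + eps), beta = mean: the normalization is undone
gamma = np.sqrt(var + eps)
beta = mean
y_identity = gamma * x_hat + beta

print(np.allclose(y_identity, x))  # True
print(np.allclose(y_plain, x))     # False
```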

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, bn_params):
    """Forward pass for batch normalization on an (N, D) input."""
    eps = bn_params.get('eps', 1e-5)
    momentum = bn_params.get('momentum', 0.9)
    mode = bn_params.get('mode', 'train')

    N, D = x.shape
    running_mean = bn_params.get('running_mean', np.zeros(D, dtype=x.dtype))
    running_var = bn_params.get('running_var', np.zeros(D, dtype=x.dtype))

    y = None
    if mode == 'train':
        # Normalize with the statistics of the current mini-batch (per feature)
        mean = x.mean(axis=0)
        var = x.var(axis=0)
        x_norm = (x - mean) / np.sqrt(var + eps)
        y = x_norm * gamma + beta

        # Update running mean and running variance during training time
        running_mean = momentum * running_mean + (1 - momentum) * mean
        running_var = momentum * running_var + (1 - momentum) * var

    elif mode == 'test':
        # Use running mean and running variance for making test predictions
        x_norm = (x - running_mean) / np.sqrt(running_var + eps)
        y = x_norm * gamma + beta
    else:
        raise ValueError('Invalid forward pass batch norm mode %s' % mode)

    bn_params['running_mean'] = running_mean
    bn_params['running_var'] = running_var

    return y

x = np.random.rand(5, 5)
bn_params = {}
y = batch_norm_forward(x, 1, 0, bn_params)
print(y.mean(axis=0))
print(y.var(axis=0))
print(bn_params)
```

```
[  1.33226763e-16   1.22124533e-16   0.00000000e+00  -3.60822483e-17
  -2.22044605e-17]
[ 0.9998869   0.99986432  0.99977066  0.99991854  0.99989989]
{'running_var': array([ 0.00884096,  0.00736936,  0.00435933,  0.01227432,  0.00998846]), 'running_mean': array([ 0.04333566,  0.05228123,  0.06205922,  0.05746892,  0.05503796])}
```
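Because `running_mean` and `running_var` start at zero and `momentum` defaults to 0.9, a single training step leaves them at only $0.1$ times the batch statistics, which is why the printed values are small. The hypothetical loop below isolates that exponential-moving-average update, using a made-up constant batch mean, to show how the estimate approaches the batch statistic over many steps:

```python
import numpy as np

momentum = 0.9
batch_mean = np.array([0.5, 0.5])  # hypothetical constant per-feature batch mean
running_mean = np.zeros(2)         # same zero initialization as batch_norm_forward

history = []
for step in range(50):
    # Same update rule as in batch_norm_forward
    running_mean = momentum * running_mean + (1 - momentum) * batch_mean
    history.append(running_mean.copy())

print(history[0])   # one step: 0.1 * batch_mean
print(history[-1])  # fifty steps: close to batch_mean
```

After $k$ steps the estimate equals `batch_mean * (1 - momentum**k)`, so with this momentum the running statistics only become a reasonable estimate after tens of training iterations.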


## Backward Pass
Now comes the hard part. We are given the upstream gradient, i.e., the gradient of the loss function with respect to the output of the batch normalization layer:

$$
\frac{\partial L}{\partial y} = \frac{\partial L}{\partial f} \frac{\partial f}{\partial y}
$$

From this, we need to find:

$$
\frac{\partial L}{\partial \hat{x}}, \; \frac{\partial L}{\partial \sigma^{2}}, \; \frac{\partial L}{\partial \mu}, \; \frac{\partial L}{\partial x}, \; \frac{\partial L}{\partial \gamma}, \; \text{and} \; \frac{\partial L}{\partial \beta}
$$


### Gradient of Normalized Input
The derivative of $y$ with respect to $\hat{x}$ is simple:

$$
\frac{\partial y}{\partial \hat{x}} = \gamma
$$

Thus, 

$$
\frac{\partial L}{\partial \hat{x}} = \frac{\partial L}{\partial y} \gamma
$$

In Python:
```python
grad_x_norm = grad_y * gamma # Element-wise multiplication
```
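Here `gamma` has shape `(D,)` while `grad_y` has shape `(N, D)`, so NumPy broadcasting scales column $j$ of the upstream gradient by $\gamma_j$. A small sketch with hypothetical values:

```python
import numpy as np

grad_y = np.ones((3, 2))       # hypothetical upstream gradient, shape (N, D) = (3, 2)
gamma = np.array([2.0, -0.5])  # one scale per feature, shape (D,)

# Broadcasting multiplies every row of grad_y element-wise by gamma,
# so with an all-ones grad_y each row of the result equals gamma
grad_x_norm = grad_y * gamma
print(grad_x_norm)
```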

### Gradient of Gamma
The derivative of $y$ with respect to $\gamma$ is:

$$
\frac{\partial y}{\partial \gamma} = \hat{x}
$$

Thus,

$$
\frac{\partial L}{\partial \gamma} = \sum^{m}_{i=1}\frac{\partial L}{\partial y_{i}} \cdot \hat{x}_{i}
$$

Because $\gamma$ is shared by every example in the mini-batch, we sum over the batch axis, reducing the shape from `(N, D)` to `(D,)`, where `N` is the batch size ($m$ above) and `D` is the number of features.

In Python:
```python
grad_gamma = (grad_y * x_norm).sum(axis=0)
```
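One way to check this formula is a centered finite-difference test. The sketch below assumes a hypothetical scalar loss $L = \sum_{i,j} w_{ij} y_{ij}$ for a fixed random matrix `w`, so that the upstream gradient is simply `grad_y = w`; `x_norm` is a random stand-in for the normalized input:

```python
import numpy as np

rng = np.random.default_rng(1)
N, D = 6, 3
x_norm = rng.normal(size=(N, D))  # stand-in for the normalized input
gamma = rng.normal(size=D)
beta = rng.normal(size=D)
w = rng.normal(size=(N, D))       # hypothetical loss weights, so grad_y = w

def loss(gamma):
    # Hypothetical scalar loss L = sum(w * y) with y = gamma * x_norm + beta
    return np.sum(w * (gamma * x_norm + beta))

grad_y = w
grad_gamma = (grad_y * x_norm).sum(axis=0)  # analytic gradient from the text

# Centered finite differences, one feature at a time
h = 1e-6
numeric = np.zeros(D)
for j in range(D):
    step = np.zeros(D)
    step[j] = h
    numeric[j] = (loss(gamma + step) - loss(gamma - step)) / (2 * h)

print(np.max(np.abs(grad_gamma - numeric)))  # should be tiny
```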

### Gradient of Beta
The derivative of $y$ with respect to $\beta$ is:

$$
\frac{\partial y}{\partial \beta} = 1
$$

Thus,

$$
\frac{\partial L}{\partial \beta} = \sum^{m}_{i=1}\frac{\partial L}{\partial y_{i}}
$$

Because $\beta$ is shared by every example in the mini-batch, we sum over the batch axis, reducing the shape from `(N, D)` to `(D,)`.

In Python:
```python
grad_beta = grad_y.sum(axis=0)
```
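The sum appears because $\beta_j$ is broadcast to all $N$ rows of column $j$: nudging $\beta_j$ by $h$ shifts every $y_{ij}$ in that column by $h$. A finite-difference sketch, again with a hypothetical loss $L = \sum w \odot y$ so that `grad_y = w`:

```python
import numpy as np

rng = np.random.default_rng(2)
N, D = 6, 3
x_norm = rng.normal(size=(N, D))  # stand-in for the normalized input
gamma = rng.normal(size=D)
beta = rng.normal(size=D)
w = rng.normal(size=(N, D))       # hypothetical loss weights, so grad_y = w

def loss(beta):
    # Hypothetical scalar loss L = sum(w * y) with y = gamma * x_norm + beta
    return np.sum(w * (gamma * x_norm + beta))

grad_beta = w.sum(axis=0)  # analytic gradient: grad_y summed over the batch

h = 1e-6
numeric = np.array([
    (loss(beta + h * e) - loss(beta - h * e)) / (2 * h)
    for e in np.eye(D)  # unit vector for each feature
])
print(np.max(np.abs(grad_beta - numeric)))
```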

### Gradient of Variance
The derivative of $\hat{x}$ with respect to $\sigma^{2}$ is:

$$
\frac{\partial \hat{x}}{\partial \sigma^{2}} = \frac{-1}{2} (x - \mu) (\sigma^{2} + \epsilon)^{-3/2}
$$

Thus,

$$
\frac{\partial L}{\partial \sigma^{2}} = \sum^{m}_{i=1} \frac{\partial L}{\partial \hat{x}_{i}} (\frac{-1}{2}) (x_{i} - \mu) (\sigma^{2} + \epsilon)^{-3/2}
$$

Because $\sigma^{2}$ is a per-feature statistic shared by every example in the mini-batch, we sum over the batch axis, reducing the shape from `(N, D)` to `(D,)`.

In Python:
```python
dvar = (-0.5) * (x - mean) * (var + eps)**(-3.0/2)
grad_var = np.sum(grad_x_norm * dvar, axis=0)
```
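In this step $\sigma^{2}$ is treated as an independent input to $\hat{x}$ with $\mu$ held fixed; its dependence on $x$ enters later through the chain rule. A finite-difference sketch under that assumption, with a hypothetical loss $L = \sum w \odot y$ so that `grad_x_norm = w * gamma`:

```python
import numpy as np

rng = np.random.default_rng(3)
N, D = 6, 3
eps = 1e-5
x = rng.normal(size=(N, D))
gamma = rng.normal(size=D)
w = rng.normal(size=(N, D))  # hypothetical loss weights, so grad_y = w
mean = x.mean(axis=0)
var = x.var(axis=0)

def loss(var):
    # x_norm depends on var directly; mean is held fixed
    x_norm = (x - mean) / np.sqrt(var + eps)
    return np.sum(w * gamma * x_norm)  # beta omitted: it does not affect this gradient

grad_x_norm = w * gamma
dvar = (-0.5) * (x - mean) * (var + eps) ** (-3.0 / 2)
grad_var = np.sum(grad_x_norm * dvar, axis=0)  # analytic gradient from the text

h = 1e-6
numeric = np.array([
    (loss(var + h * e) - loss(var - h * e)) / (2 * h)
    for e in np.eye(D)
])
print(np.max(np.abs(grad_var - numeric)))
```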

### Gradient of Mean
We use the chain rule to solve for this gradient:

$$
\frac{\partial L}{\partial \mu} = \frac{\partial L}{\partial \hat{x}} \cdot \frac{\partial \hat{x}}{\partial \mu} + \frac{\partial L}{\partial \sigma^{2}} \cdot \frac{\partial \sigma^{2}}{\partial \mu}
$$

$$
\frac{\partial \hat{x}}{\partial \mu}  = \frac{-1}{\sqrt{\sigma^{2} + \epsilon}}
$$

$$
\frac{\partial \sigma^{2}}{\partial \mu} = \frac{-2}{m}\sum_{i=1}^{m} (x_{i} - \mu)
$$

Thus,

$$
\frac{\partial L}{\partial \mu} =  \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_{i}} \frac{-1}{\sqrt{\sigma^{2} + \epsilon}} + \frac{\partial L}{\partial \sigma^{2}} \frac{-2}{m}\sum_{i=1}^{m} (x_{i} - \mu)
$$

In Python:
```python
dxnorm_dmean = -1.0 / np.sqrt(var + eps)
# -2.0 (not -2) avoids integer division under Python 2
dvar_dmean = np.sum((-2.0 / x.shape[0]) * (x - mean), axis=0)
grad_mean = np.sum(grad_x_norm * dxnorm_dmean, axis=0) + grad_var * dvar_dmean
```
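The second term is in fact zero: because $\mu$ is the mean of the same mini-batch, $\sum_{i=1}^{m} (x_{i} - \mu) = 0$. The sketch below checks this numerically and compares `grad_mean` with a centered finite difference in which the variance is recomputed around a perturbed $\mu$, matching $\sigma^{2} = \frac{1}{m}\sum_{i}(x_{i} - \mu)^{2}$. The loss $L = \sum w \odot y$ with random `w` is a hypothetical stand-in, so `grad_y = w`:

```python
import numpy as np

rng = np.random.default_rng(4)
N, D = 6, 3
eps = 1e-5
x = rng.normal(size=(N, D))
gamma = rng.normal(size=D)
w = rng.normal(size=(N, D))  # hypothetical loss weights, so grad_y = w
mean = x.mean(axis=0)
var = x.var(axis=0)

grad_x_norm = w * gamma
dvar = (-0.5) * (x - mean) * (var + eps) ** (-3.0 / 2)
grad_var = np.sum(grad_x_norm * dvar, axis=0)

# The two terms of the analytic gradient
term1 = np.sum(grad_x_norm * (-1.0 / np.sqrt(var + eps)), axis=0)
dvar_dmean = np.sum((-2.0 / N) * (x - mean), axis=0)  # ~0 since mean is the batch mean
grad_mean = term1 + grad_var * dvar_dmean

def loss(mu):
    # Treat mu as the variable; the variance is recomputed around it
    v = np.mean((x - mu) ** 2, axis=0)
    x_norm = (x - mu) / np.sqrt(v + eps)
    return np.sum(w * gamma * x_norm)

h = 1e-6
numeric = np.array([
    (loss(mean + h * e) - loss(mean - h * e)) / (2 * h)
    for e in np.eye(D)
])
print(dvar_dmean)  # only floating-point rounding remains
print(np.max(np.abs(grad_mean - numeric)))
```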

### Gradient of Input
Use the chain rule once more to obtain the final gradient:

$$
\frac{\partial L}{\partial x} =  \frac{\partial L}{\partial \hat{x}} \cdot \frac{\partial \hat{x}}{\partial x} + \frac{\partial L}{\partial \sigma^{2}} \cdot \frac{\partial \sigma^{2}}{\partial x} + \frac{\partial L}{\partial \mu} \cdot \frac{\partial \mu}{\partial x}
$$

Now fill in the missing pieces:

$$
\frac{\partial \hat{x}}{\partial x} = \frac{1}{\sqrt{\sigma^{2} + \epsilon}}
$$

$$
\frac{\partial \sigma^{2}}{\partial x} = \frac{2 (x - \mu)}{m}
$$

$$
\frac{\partial \mu}{\partial x} = \frac{1}{m}
$$

Plugging everything in gives:

$$
\frac{\partial L}{\partial x_{i}} =  \frac{\partial L}{\partial \hat{x}_{i}} \cdot \frac{1}{\sqrt{\sigma^{2} + \epsilon}} + \frac{\partial L}{\partial \sigma^{2}} \cdot \frac{2 (x_{i} - \mu)}{m} + \frac{\partial L}{\partial \mu} \cdot \frac{1}{m}
$$

In Python:
```python
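N = x.shape[0]
dxnorm_dx = 1.0 / np.sqrt(var + eps)
dvar_dx = 2.0 * (x - mean) / N  # includes the 1/m factor from the equation
dmean_dx = 1.0 / N              # float division (1 / N is 0 under Python 2)
grad_x = grad_x_norm * dxnorm_dx + grad_var * dvar_dx + grad_mean * dmean_dx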
```
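Putting the pieces together, here is a minimal sketch of a complete training-mode backward pass checked against finite differences on every entry of $x$. `batch_norm_backward` and `forward` are hypothetical names; for simplicity the backward function recomputes `mean`, `var` and `x_norm` from `x` rather than reusing a cache from `batch_norm_forward`. The loss $L = \sum w \odot y$ is again a hypothetical stand-in, so the upstream gradient is `w`:

```python
import numpy as np

def batch_norm_backward(grad_y, x, gamma, eps=1e-5):
    """Hypothetical helper assembling the step-by-step gradients above.

    Recomputes the batch statistics from x instead of reading a forward cache.
    """
    N = x.shape[0]
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    x_norm = (x - mean) / np.sqrt(var + eps)

    grad_x_norm = grad_y * gamma
    grad_gamma = (grad_y * x_norm).sum(axis=0)
    grad_beta = grad_y.sum(axis=0)

    dvar = (-0.5) * (x - mean) * (var + eps) ** (-3.0 / 2)
    grad_var = np.sum(grad_x_norm * dvar, axis=0)

    dxnorm_dmean = -1.0 / np.sqrt(var + eps)
    dvar_dmean = np.sum((-2.0 / N) * (x - mean), axis=0)
    grad_mean = np.sum(grad_x_norm * dxnorm_dmean, axis=0) + grad_var * dvar_dmean

    grad_x = (grad_x_norm / np.sqrt(var + eps)
              + grad_var * 2.0 * (x - mean) / N
              + grad_mean / N)
    return grad_x, grad_gamma, grad_beta


def forward(x, gamma, beta, eps=1e-5):
    # Training-mode forward pass (batch statistics only, no running averages)
    x_norm = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)
    return gamma * x_norm + beta


rng = np.random.default_rng(5)
N, D = 8, 4
x = rng.normal(size=(N, D))
gamma = rng.normal(size=D)
beta = rng.normal(size=D)
w = rng.normal(size=(N, D))  # hypothetical loss L = sum(w * y), so grad_y = w

grad_x, grad_gamma, grad_beta = batch_norm_backward(w, x, gamma)

# Centered finite differences on every entry of x
h = 1e-5
numeric = np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    x_plus, x_minus = x.copy(), x.copy()
    x_plus[idx] += h
    x_minus[idx] -= h
    numeric[idx] = (np.sum(w * forward(x_plus, gamma, beta))
                    - np.sum(w * forward(x_minus, gamma, beta))) / (2 * h)

print(np.max(np.abs(grad_x - numeric)))
```

Because the layer's output does not change when the same constant is added to every example of a feature, each column of `grad_x` should also sum to approximately zero, which makes a cheap extra sanity check.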

## Simplification