# Tutorial #6: Dropout and BatchNorm

### 6.1: Dropout

Dropout is a method of regularizing neural networks. The idea is extremely simple: during training, multiply the input to each unit by either 0 or 1, multiplying by 1 with probability $(1-p)$ and by 0 with probability $p$. Typically, $p=0.5$ is chosen for hidden units, as it empirically tends to work well in applications.

During testing, none of the units are multiplied by zero. Empirically, dropout has been shown to perform better if the expected value of the output of each hidden unit at test time matches the expected value seen during training. Since, on average, $1/(1-p)$ times more hidden units are included in each linear transformation during testing than during training, the expected value of each output unit is multiplied by $1/(1-p)$. To counteract this, the weights of the linear transformation are simply multiplied by $(1-p)$ during testing.
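The test-time weight scaling described above can be checked numerically. The following is a minimal sketch in plain JAX, not taken from any library: the weights `W`, the hidden units `x`, and the number of sampled masks are illustrative assumptions. It averages the pre-activations over many random training masks and compares them with the test-time pre-activation computed with weights scaled by $(1-p)$.

```python
import jax
import jax.numpy as jnp

p = 0.5  # probability of dropping a unit
key_w, key_x, key_mask = jax.random.split(jax.random.PRNGKey(0), 3)
W = jax.random.normal(key_w, (3, 8))  # illustrative weights (assumption)
x = jax.random.normal(key_x, (8,))    # illustrative hidden units (assumption)

# Training: draw many independent binary masks (1 = keep, 0 = drop)
masks = jax.random.bernoulli(key_mask, 1.0 - p, (20_000, 8)).astype(x.dtype)
train_preacts = (masks * x) @ W.T               # one row per sampled mask
mean_train_preact = train_preacts.mean(axis=0)  # estimate of the expected value

# Testing: keep every unit, but scale the weights by (1 - p)
test_preact = ((1.0 - p) * W) @ x

print(mean_train_preact)
print(test_preact)
```

The two printed vectors should agree up to sampling noise, which is the sense in which the scaled weights match the expected value seen during training.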

A simpler option (called inverted dropout) is to do nothing differently during testing, but during training to multiply the hidden units that are kept by $1/(1-p)$, in addition to setting the others to zero with probability $p$. This keeps the expected value of each hidden unit the same between testing and training.

For a layer within a neural network with input $\boldsymbol{x}$, non-linear activation function $g(\boldsymbol{z})$, linear transformation matrix $\boldsymbol{W}$, and bias $\boldsymbol{b}$, the output $\boldsymbol{h}$ of the hidden layer is given by $$\boldsymbol{h} = g(\boldsymbol{W} \boldsymbol{x} + \boldsymbol{b}).$$ Dropout performs elementwise multiplication of $\boldsymbol{x}$ with a mask $\boldsymbol{r}$, where each element of $\boldsymbol{r}$ is drawn independently from a Bernoulli distribution that equals 1 with probability $1-p$. During training, the dropout layer becomes (assuming inverted dropout, so that the mask is also scaled by $1/(1-p)$) $$r_j \sim \frac{1}{1-p}\textnormal{Bernoulli}(1-p)$$ $$ \boldsymbol{\tilde{x}} = \boldsymbol{r} \odot \boldsymbol{x}$$ $$\boldsymbol{h} = g(\boldsymbol{W} \boldsymbol{\tilde{x}} + \boldsymbol{b}).$$ During testing, the layer is still $$\boldsymbol{h} = g(\boldsymbol{W} \boldsymbol{x} + \boldsymbol{b}).$$
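The equations above translate almost line by line into JAX. This is a hand-written sketch rather than the `flax` implementation: the function name `dropout_layer`, the choice of ReLU for $g$, and the array shapes are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

def dropout_layer(x, W, b, p, key=None, training=True):
    """Compute h = g(W x~ + b) with inverted dropout on x (g is ReLU here)."""
    if training:
        keep = jax.random.bernoulli(key, 1.0 - p, x.shape)  # True with prob. 1 - p
        r = keep.astype(x.dtype) / (1.0 - p)                 # r_j is 0 or 1/(1-p)
        x = r * x                                            # elementwise mask
    return jax.nn.relu(W @ x + b)

key_w, key_x, key_r = jax.random.split(jax.random.PRNGKey(0), 3)
W = jax.random.normal(key_w, (3, 6))  # illustrative weights (assumption)
b = jnp.zeros(3)
x = jax.random.normal(key_x, (6,))

h_train = dropout_layer(x, W, b, p=0.5, key=key_r, training=True)
h_test = dropout_layer(x, W, b, p=0.5, training=False)
print(h_train)
print(h_test)
```

Because the scaling by $1/(1-p)$ happens during training, the test-time branch is just the plain layer and needs no random key.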

Dropout tends to take a different form for convolutional networks, recurrent networks, and attention layers than in MLPs. However, the basic idea of probabilistically setting hidden units to zero still remains.

Good resources for understanding dropout include the original 2014 paper by [Srivastava et al.](TODO), section 7.12 of the [Deep Learning Book](TODO), and [this blog post](https://medium.com/biased-algorithms/the-role-of-dropout-in-neural-networks-fffbaa77eee7#:~:text=Dropout%20has%20been%20successfully%20integrated,layers%20after%20the%20convolutional%20blocks.).

#### 6.1.1: Dropout in `flax`

Dropout in `flax` is extremely easy to implement. We simply create an `nnx.Dropout(p, rngs=rngs)` layer and call it on the input to each layer that we want to mask.

In [1]:
import flax
import jax
from flax import nnx
from jax import numpy as jnp

In [2]:
class DropoutTestMLP(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(4, 16, rngs=rngs)
        self.linear2 = nnx.Linear(16, 10, rngs=rngs)
        self.dropout = nnx.Dropout(0.5, rngs=rngs)

    def __call__(self, x):
        x = nnx.relu(self.linear1(x))
        return self.linear2(self.dropout(x))

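In flax, we have to call `model.train()` to make the dropout layers mask (and rescale) their inputs, and `model.eval()` to switch the masking off for testing. Because `nnx.Dropout` uses inverted dropout, no rescaling of the weights is needed at test time: in eval mode the activations simply pass through unchanged.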

In [3]:
model = DropoutTestMLP(nnx.Rngs(0))

model.train()

test_input = jax.random.normal(jax.random.PRNGKey(0), (4,))
print("First train dropout, mask chosen randomly")
print(model(test_input))
print("Second train dropout, mask chosen randomly with different RNG key")
print(model(test_input))

model = DropoutTestMLP(nnx.Rngs(0))

model.eval()

print("First eval dropout, no mask")
print(model(test_input))
print("Second eval dropout, deterministic output same as first eval")
print(model(test_input))

First train dropout, mask chosen randomly
[-0.24484156  1.3231333  -3.9299839  -3.0473146   3.225388    0.22339313
  0.5151764  -1.9571856   0.47121215  0.05025537]
Second train dropout, mask chosen randomly with different RNG key
[-1.6144576   0.61325073 -3.1674788  -2.2446618   2.230607   -0.08878337
  1.1725848   0.04968039  0.73547506  2.1024349 ]
First eval dropout, no mask
[-0.87284786  0.84970397 -1.2577161  -0.75616425  0.4537591  -0.6594648
  0.9112347   0.19184317  0.8342736   1.534824  ]
Second eval dropout, deterministic output same as first eval
[-0.87284786  0.84970397 -1.2577161  -0.75616425  0.4537591  -0.6594648
  0.9112347   0.19184317  0.8342736   1.534824  ]


When using Dropout layers, we have to remember to call `model.train()` before training and `model.eval()` before evaluating performance on a development set or test set.

#### 6.1.2: Understanding `nnx.Dropout`

Let's create a dropout layer and display it using `nnx.display()`.

In [4]:
layer = nnx.Dropout(0.2, rngs=nnx.Rngs(0))
nnx.display(layer)

Dropout(rate=0.2, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(
  default=RngStream(
    key=RngKey(
      value=Array((), dtype=key<fry>) overlaying:
      [0 0],
      tag='default'
    ),
    count=RngCount(
      value=Array(0, dtype=uint32),
      tag='default'
    )
  )
))


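There are a couple of things to note here. First, the dropout layer takes the keyword argument `broadcast_dims`, which lists the dimensions along which the same dropout mask is shared. This allows an entire slice of activations to be dropped together, rather than dropping individual activations independently at random. In the example below, `broadcast_dims=(1,)` shares the mask along the second axis, so entire rows are either kept or dropped.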

In [5]:
layer = nnx.Dropout(0.5, broadcast_dims=(1,), rngs=nnx.Rngs(0))
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (10,2))
print(layer(x))

[[ 0.          0.        ]
 [-0.8671889  -0.1572347 ]
 [ 0.3521818  -1.9441785 ]
 [-0.9905975   0.9887572 ]
 [ 1.3286986  -1.900327  ]
 [ 0.          0.        ]
 [ 0.          0.        ]
 [ 2.5541694   3.0209296 ]
 [ 0.          0.        ]
 [ 0.04940141 -3.8329544 ]]


This is useful when using dropout with convolutional layers, where it is common to drop out entire feature maps. For example, applying a convolutional layer with D output filters to a 2D image of size H x W x C results (with suitable padding) in a tensor of size H x W x D. Dropout then sets all H x W activations of a feature map to zero together, dropping each of the D feature maps at random with probability $p$.
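To make the feature-map case concrete, here is a minimal hand-written sketch in plain JAX. The function name `feature_map_dropout` and the tensor sizes are assumptions for illustration; the sketch draws one keep/drop decision per filter and shares it across the spatial axes, which is the kind of mask sharing that `broadcast_dims` expresses.

```python
import jax
import jax.numpy as jnp

def feature_map_dropout(x, p, key):
    """Inverted dropout that drops whole feature maps of an (H, W, D) tensor."""
    d = x.shape[-1]
    # One keep/drop decision per output filter, shared across H and W
    keep = jax.random.bernoulli(key, 1.0 - p, (1, 1, d))
    return jnp.where(keep, x / (1.0 - p), 0.0)

x = jnp.ones((4, 4, 8))  # H = W = 4, D = 8 (illustrative sizes)
y = feature_map_dropout(x, p=0.5, key=jax.random.PRNGKey(0))

per_filter = y.reshape(-1, 8)  # rows: spatial positions, columns: filters
print(per_filter[0])           # each entry is either 0 or 1/(1-p) = 2
```

Every column of `per_filter` is constant, so a filter is either kept everywhere (and rescaled) or zeroed everywhere.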

Second, the probability is the probability of dropping the activation and NOT the probability of keeping the activation. In the below example, we choose $p=0.9$ and see that almost all of the activations are dropped out.

In [6]:
layer = nnx.Dropout(0.9, rngs=nnx.Rngs(0))
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (10,))
print(layer(x))
print(layer(x))
print(layer(x))

[ 0.        0.        0.        0.        0.        0.        0.
  0.        0.       21.302025]
[-24.424557 -20.356806   0.         0.         0.         0.
   0.         0.       -13.105359   0.      ]
[ 0.        0.        0.        0.        0.        0.        0.
  0.        0.       21.302025]


Third, by default dropout layers are in training mode (`deterministic=False`) rather than evaluation mode (`deterministic=True`). Calling `eval()` on the layer, or on a model that contains it, sets `deterministic=True`. In eval mode, none of the activations are dropped out.

In [7]:
layer = nnx.Dropout(0.5, rngs=nnx.Rngs(1))
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (10,))

print(layer.deterministic)
print(layer(x))

layer.eval()
print(layer.deterministic)
print(layer(x))

False
[ 0.         -0.8032088   0.          1.7567555  -1.7235099   0.
  0.         -0.24468681  0.          0.14304025]
True
[-1.2574776  -0.4016044  -1.1213601   0.87837774 -0.86175495  0.34651348
  0.9404431  -0.12234341 -1.1891836   0.07152013]


Fourth, in training mode, the non-zero activations are multiplied by $1/(1-p)$, while in testing mode the activations are not changed.

In [8]:
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, (10,))

layer = nnx.Dropout(0.9, rngs=nnx.Rngs(1))
layer.train()
print(layer.deterministic)
print(layer(x))

layer = nnx.Dropout(0.9, rngs=nnx.Rngs(1))
layer.eval()
print(layer.deterministic)
print(layer(x))

False
[ 0.        0.        0.        0.       -8.629354  0.        0.
  0.        0.        0.      ]
True
[-1.3877681   0.77485436  1.5404932   1.8419101  -0.86293536 -0.8070163
 -0.2005241   0.7834719  -0.9859735   0.26376608]


### 6.2: BatchNorm

BatchNorm is a way of improving the training process of neural networks with gradient descent. BatchNorm can be used simultaneously with Dropout, or it can be used without Dropout. 

BatchNorm is performed in two steps. 

First, for each activation, normalize the activation so that it has mean zero and variance 1. Instead of calculating the mean and variance over the entire dataset, during training BatchNorm uses the statistics of the current batch to estimate them. So, for the $k$th activation value $x^{k}$ in a BatchNorm layer, during training the activation is normalized with the equation $$\hat{x}^{k} = \frac{x^k - \mathbb{E}[x^k]}{\sqrt{\textnormal{Var}[x^k]}}$$ where the expectation and variance are calculated by summing over the $m$ examples in the batch, with $x^k_i$ denoting the value of the activation for the $i$th example: $$\mathbb{E}[x^k] = \frac{1}{m}\sum_{i=1}^m x^k_i$$ $$\textnormal{Var}[x^k] = \frac{1}{m}\sum_{i=1}^m (x^k_i - \mathbb{E}[x^k])^2 + \delta$$ where $\delta$ is a small positive constant that prevents division by zero. During testing, the means and variances are not computed from the test batch. Instead, running averages stored during training are used, which are updated after each training batch as $$\mathbb{E}[x^k]_{\textnormal{running}} = \alpha \mathbb{E}[x^k]_{\textnormal{running}} + (1-\alpha) \mathbb{E}[x^k]_{\textnormal{batch}}$$ $$\sqrt{\textnormal{Var}[x^k]_{\textnormal{running}}} = \alpha \sqrt{\textnormal{Var}[x^k]_{\textnormal{running}}} + (1-\alpha) \sqrt{\textnormal{Var}[x^k]_{\textnormal{batch}}}$$ where $\alpha$ is a momentum term close to 1.

Second, each normalized activation $\hat{x}^k$ is then transformed by two trainable parameters, $\gamma^k$ and $\beta^k$. The final transformed activation value $h^k$ is given by $$h^k = \gamma^k \hat{x}^k + \beta^k.$$ This second transformation ensures that BatchNorm can, in principle, learn the identity transformation (by setting $\gamma^k = \sqrt{\textnormal{Var}[x^k]}$ and $\beta^k = \mathbb{E}[x^k]$) and thus does not reduce the functional approximation capability of the neural network.

Each transformation is differentiable, allowing BatchNorm layers to be incorporated into a gradient descent optimization algorithm.
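The two steps can be sketched by hand in a few lines of JAX. This is an illustration of the equations in this section, not the `flax` implementation: the function names `batchnorm_train` and `batchnorm_eval`, the default values of `alpha` and `delta`, and the choice to keep a running average of the standard deviation (libraries may store the variance instead) are assumptions.

```python
import jax
import jax.numpy as jnp

def batchnorm_train(x, gamma, beta, run_mean, run_std, alpha=0.99, delta=1e-5):
    """BatchNorm on a batch x of shape (m, k), using batch statistics."""
    mean = x.mean(axis=0)                  # E[x^k], one value per activation
    std = jnp.sqrt(x.var(axis=0) + delta)  # sqrt(Var[x^k]), with delta included
    x_hat = (x - mean) / std               # step 1: normalize
    h = gamma * x_hat + beta               # step 2: scale and shift
    # Update the running averages that will be used at test time
    run_mean = alpha * run_mean + (1.0 - alpha) * mean
    run_std = alpha * run_std + (1.0 - alpha) * std
    return h, run_mean, run_std

def batchnorm_eval(x, gamma, beta, run_mean, run_std):
    """BatchNorm at test time, using the stored running averages."""
    return gamma * (x - run_mean) / run_std + beta

k = 5
x = 3.0 + 2.0 * jax.random.normal(jax.random.PRNGKey(0), (64, k))
gamma, beta = jnp.ones(k), jnp.zeros(k)

h, run_mean, run_std = batchnorm_train(x, gamma, beta, jnp.zeros(k), jnp.ones(k))
print(h.mean(axis=0), h.var(axis=0))  # close to 0 and 1 per activation

# Choosing gamma = sqrt(Var) and beta = E undoes the normalization
batch_mean = x.mean(axis=0)
batch_std = jnp.sqrt(x.var(axis=0) + 1e-5)
h_identity, _, _ = batchnorm_train(x, batch_std, batch_mean, jnp.zeros(k), jnp.ones(k))
print(jnp.max(jnp.abs(h_identity - x)))  # near zero, up to float32 rounding
```

The last lines illustrate why the scale and shift preserve expressiveness: with a suitable $\gamma$ and $\beta$ the layer reproduces its input. Since every operation is a differentiable JAX primitive, `jax.grad` can differentiate through `batchnorm_train` with respect to `gamma`, `beta`, and `x`.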

One way of intuitively understanding the purpose of the first transformation is in terms of preventing activation function saturation. Recall that the sigmoid activation function $1/(1+e^{-x})$ has an approximately linear regime near $x=0$, where its gradient is largest (equal to $1/4$ at $x=0$), but a saturated regime at large $|x|$ where the gradient is close to zero. If the magnitude of the input to the activation function is very large, the magnitude of the gradient will be very small and gradient descent will take a long time to converge. By normalizing the activations to have mean 0 and variance 1, fewer units will be in the saturated regime and the network can train faster.
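The saturation effect is easy to see numerically. The short sketch below uses `jax.grad` to evaluate the derivative of the sigmoid at a few input values; the particular values in `z` are chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Derivative of the sigmoid, applied elementwise to a vector of inputs
sigmoid_grad = jax.vmap(jax.grad(jax.nn.sigmoid))

z = jnp.array([0.0, 2.0, 5.0, 10.0])  # illustrative inputs
grads = sigmoid_grad(z)
print(grads)  # largest at z = 0 (value 0.25), shrinking rapidly as z grows
```

At $z=10$ the gradient is already several thousand times smaller than at $z=0$, so a unit that sits far from zero contributes almost nothing to the weight updates.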

The purpose of the second transformation is not immediately clear. After all, what is the purpose of multiplying by a scale and adding a shift right after subtracting off the mean and normalizing the variance? The explanation can be found by understanding the training dynamics of stacked transformations. Consider, for example, a linear network with one scalar weight per layer, $$y = w_1 w_2 \dots w_n x.$$ Because $x$ is only ever multiplied by weights, $y$ can only represent linear transformations of $x$. However, when this network is trained with gradient descent, all of the weights are updated simultaneously, each according to its gradient times the learning rate. As a result, the activation means can undergo large changes even after a small gradient update, depending on the values of $w_i$ in the other layers. This is called 'internal covariate shift'. However, when the mean of each hidden unit is normalized using BatchNorm, the mean of each activation depends only on the parameter $\beta^k$, and not on a complicated interaction between the gradient descent step size and the layer weights. As a result, the network activation means don't change dramatically in response to each gradient descent step, and the network becomes easier to train.
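The sensitivity of the stacked linear network can be illustrated with a toy calculation. This sketch makes several assumptions that are not in the text: ten layers, every weight equal to 1.5, $x=1$, a learning rate of 0.01, and a gradient step that decreases $y$ itself rather than a loss. It compares the change in $y$ predicted by a first-order approximation with the change that actually occurs when all weights are updated at once.

```python
import jax
import jax.numpy as jnp

def chain(w, x=1.0):
    """Output of the scalar linear network y = w_1 w_2 ... w_n x."""
    return x * jnp.prod(w)

w = jnp.full((10,), 1.5)   # illustrative weights (assumption)
g = jax.grad(chain)(w)     # dy/dw_i is the product of all the other weights
lr = 0.01                  # a seemingly small learning rate
w_new = w - lr * g         # update every layer simultaneously

y_before = chain(w)
y_after = chain(w_new)
predicted_change = -lr * jnp.sum(g ** 2)  # first-order estimate of the change
actual_change = y_after - y_before

print(y_before, y_after)
print(predicted_change, actual_change)
```

Even with a learning rate of 0.01 the output changes by most of its original value, and the first-order estimate is far from the actual change, because each weight's effect depends on the product of all the others.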

In practice, networks trained with BatchNorm can use larger learning rates than networks without BatchNorm. While training speed is not proportional to the learning rate used, empirically these larger learning rates lead to faster training. Furthermore, networks trained with BatchNorm tend to perform better than networks trained without it.

While these intuitive explanations give the reader some sense of why BatchNorm might be successful, note that there is debate among researchers about why BatchNorm gives better performance in practice. Some researchers have found that BatchNorm doesn't actually reduce internal covariate shift, and that BatchNorm gives improved training speed only due to the larger learning rate allowed, while BatchNorm gives improved generalization due to the regularizing effects of stochastic gradient descent with larger step sizes.

Good introductory resources on BatchNorm include section 8.7.1 of the [Deep Learning Book](https://www.deeplearningbook.org/) and the original [BatchNorm paper](https://arxiv.org/abs/1502.03167). 

#### 6.2.1: BatchNorm in Flax

Using a BatchNorm layer in Flax is also easy. `nnx.BatchNorm` only requires explicitly specifying the number of features (here, the number of hidden activations) of the input to the BatchNorm layer.

In [9]:
class BatchNormMLP(nnx.Module):
    def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x):
        x = nnx.relu(self.bn(self.linear1(x)))
        return self.linear2(x)

In [10]:
model = BatchNormMLP(3, 10, 2, rngs=nnx.Rngs(0))
nnx.display(model)

BatchNormMLP(
  linear1=Linear(
    kernel=Param(
      value=Array(shape=(3, 10), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(10,), dtype=float32)
    ),
    in_features=3,
    out_features=10,
    use_bias=True,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    precision=None,
    kernel_init=<function variance_scaling.<locals>.init at 0x10ca17920>,
    bias_init=<function zeros at 0x10b4474c0>,
    dot_general=<function dot_general at 0x10ac593a0>
  ),
  bn=BatchNorm(
    mean=BatchStat(
      value=Array(shape=(10,), dtype=float32)
    ),
    var=BatchStat(
      value=Array(shape=(10,), dtype=float32)
    ),
    scale=Param(
      value=Array(shape=(10,), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(10,), dtype=float32)
    ),
    num_features=10,
    use_running_average=False,
    axis=-1,
    momentum=0.99,
    epsilon=1e-05,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    use_bias=True,
    use_scale=True,
   

We can see that the BatchNorm layer has two variables of type `nnx.Param`: `scale` and `bias`. It also has two variables of type `nnx.BatchStat`: `mean` and `var`. The scale and bias correspond to the parameters $\gamma$ and $\beta$ from earlier, while the mean and variance correspond to running averages of the batch mean and variance calculated during training. Since there are 10 activations in the hidden layer, BatchNorm adds $2*10=20$ trainable parameters to the model, along with 20 batch statistics that are not trained by gradient descent.

Note also that the BatchNorm layer begins in training mode, with `use_running_average=False`. When `model.eval()` is called, `use_running_average` is set to `True` and the model switches from updating the running averages `mean` and `var` to using them to normalize the activations.

In [11]:
print(model.bn.use_running_average)
model.eval()
print(model.bn.use_running_average)

False
True


Note also that if dropout and BatchNorm are used simultaneously in the same layer, BatchNorm is usually applied first and Dropout is applied second.

In [12]:
class BatchNormMLP(nnx.Module):
    def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
        self.bn = nnx.BatchNorm(dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x: jax.Array):
        x = nnx.relu(self.dropout(self.bn(self.linear1(x))))  # BatchNorm applied first, Dropout second
        return self.linear2(x)