# Tutorial #6: Dropout and BatchNorm

Dropout is a method of regularizing networks. Dropout is an extremely simple: during training, multiply the input to each unit by either 0 or 1, multiplying by 1 with probability $p$ and by 0 with probability $(1-p)$. Typically, $p=0.5$ is chosen for hidden units, and $p=0.8$ is chosen for input units. These have been shown to empirically work well. 

During testing, all of the units are retained; none are multipled by 0. Empiricially, it has been shown that dropout has better performance if the expected value of the output of each hidden unit matches the distribution seen during training. Since $1/p$ times more hidden units are included in each linear transformation during testing compared to training, the expected value of each output unit will be multiplied by $1/p$. To counteract this multiplication in expected value, the weights of the linear transformation are simply multiplied by $1/p$ during testing. 

For a layer within a neural network with input $\boldsymbol{x}$, non-linear activation function $g(\boldsymbol{z})$, and linear transformation matrix $\boldsymbol{W}$, and bias $\boldsymbol{b}$, the output $\boldsymbol{h}$ of the hidden layer is given by $$\boldsymbol{h} = g(\boldsymbol{W} \boldsymbol{x} + \boldsymbol{b}).$$ Dropout performs elementwise multiplication of $\boldsymbol{x}$ with a binary mask $\boldsymbol{r}$ where each element of $\boldsymbol{r}$ is drawn from a Bernoulli distribution with probability $p$. During training, the dropout layer becomes $$r_j \sim \textnormal{Bernoulli}(p)$$ $$ \boldsymbol{\tilde{x}} = \boldsymbol{r} * \boldsymbol{x}$$ $$\boldsymbol{h} = g(\boldsymbol{W} \boldsymbol{\tilde{x}} + \boldsymbol{b}).$$ During testing, the layer now becomes $$\boldsymbol{h} = g(\frac{1}{p}\boldsymbol{W} \boldsymbol{x} + \boldsymbol{b}).$$

TODO: is dropout still used in LLMs and state-of-the-art vision models?

Good resources for understanding dropout include the original 2014 paper by [Srivastava et al.](TODO) and section 7.12 of the [Deep Learning Book](TODO).

### 6.1: Dropout in `flax`

Dropout in `flax` is extremely easy to implement. We simply create a `nnx.Dropout(p, rngs)` layer, and call it to mask the input to the each layer. 

In [1]:
import flax
import jax
from flax import nnx
from jax import numpy as jnp

In [2]:
class DropoutTestMLP(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(4, 16, rngs=rngs)
        self.dropout_in = nnx.Dropout(0.8, rngs = rngs)
        self.linear2 = nnx.Linear(16, 10, rngs=rngs)
        self.dropout = nnx.Dropout(0.5, rngs = rngs)

    def __call__(self, x):
        x = nnx.relu(self.linear1(self.dropout_in(x)))
        return self.linear2(self.dropout(x))

We have to call `model.train()` to set the dropout layers to mask the inputs, and `model.eval()` to multiply the weights by $1/p$ for testing. 

In [3]:
model = DropoutTestMLP(nnx.Rngs(0))

model.train()

test_input = jax.random.normal(jax.random.PRNGKey(0), (4,))
print("First train dropout, mask chosen randomly")
print(model(test_input))
print("Second train dropout, mask chosen randomly with different RNG key")
print(model(test_input))

model.eval()

print("First eval dropout, no mask")
print(model(test_input))
print("Second eval dropout, deterministic output same as first eval")
print(model(test_input))

First train dropout, mask chosen randomly
[-3.3543632   1.2700179  -6.3404975  -4.5124435   4.522447   -0.21097493
  2.3068194   0.05415332  1.5272924   4.24723   ]
Second train dropout, mask chosen randomly with different RNG key
[-4.4009852  0.4407386 -3.4723148  1.7922347  7.803579  -2.2121768
 -1.5838387 -4.6009502 -2.6703286 -1.5528221]
First eval dropout, no mask
[-0.87284786  0.84970397 -1.2577161  -0.75616425  0.4537591  -0.6594648
  0.9112347   0.19184317  0.8342736   1.534824  ]
Second eval dropout, deterministic output same as first eval
[-0.87284786  0.84970397 -1.2577161  -0.75616425  0.4537591  -0.6594648
  0.9112347   0.19184317  0.8342736   1.534824  ]


When using Dropout layers, we have to remember to call `model.train()` before training and `model.eval()` before testing.

### 6.2: BatchNorm in `flax`

