# Dropout
An example of how PyTorch's `nn.Dropout` module works.  
Reference: https://pytorch.org/docs/stable/nn.html#torch.nn.Dropout

In [1]:
import torch
import torch.nn as nn

In [2]:
# Seed the global torch RNG so "Restart & Run All" reproduces the random input
# tensor and the dropout masks drawn in the cells below.
torch.manual_seed(0)

p = 0.5  # dropout probability
# Each input element is set to 0 with probability p; the surviving elements are
# scaled by 1/(1-p). A freshly constructed module starts in training mode.
dropout = nn.Dropout(p=p)

In [11]:
batch_size, n_features = 5, 4  # a mini-batch of 5 vectors, each with 4 elements
outputs = torch.randn(batch_size, n_features)

In [12]:
outputs  # the input mini-batch before dropout is applied

tensor([[-1.6512, -0.2450,  1.6016,  1.5245],
        [ 0.5314,  0.1052, -0.5480, -0.4406],
        [ 0.5084, -1.0770, -1.7468,  0.1666],
        [ 0.8514, -0.5135,  0.9936,  0.0461],
        [-1.6448, -2.2044,  1.1541, -0.8650]])


In [13]:
# Set training mode explicitly. nn.Dropout starts in training mode, but a later
# cell calls dropout.eval(); re-running this cell after that would otherwise
# silently return the input unchanged instead of applying dropout.
dropout.train()
outputs_dropped = dropout(outputs)

In [14]:
outputs_dropped # every element is independently zeroed with probability p (so a whole row can end up all zeros); the surviving elements are scaled by 1/(1-p)

tensor([[-3.3025, -0.0000,  0.0000,  3.0491],
        [ 0.0000,  0.2103, -1.0959, -0.8811],
        [ 0.0000, -0.0000, -0.0000,  0.0000],
        [ 0.0000, -1.0271,  1.9871,  0.0000],
        [-3.2897, -4.4089,  2.3083, -0.0000]])

In [15]:
# Undo the 1/(1-p) training-time scaling: every element that survived dropout
# should then equal the corresponding element of the original tensor.
rescaled = outputs_dropped * (1 - p)
kept = outputs_dropped != 0  # mask of elements that were not dropped
assert torch.allclose(rescaled[kept], outputs[kept]), "kept elements should be outputs / (1-p)"
rescaled

tensor([[-1.6512, -0.0000,  0.0000,  1.5245],
        [ 0.0000,  0.1052, -0.5480, -0.4406],
        [ 0.0000, -0.0000, -0.0000,  0.0000],
        [ 0.0000, -0.5135,  0.9936,  0.0000],
        [-1.6448, -2.2044,  1.1541, -0.0000]])

## `dropout.train()` vs `dropout.eval()`
Dropout is one of the layers that behave differently in training mode (it randomly zeroes elements and rescales the rest) and in evaluation mode (it passes the input through unchanged). That's why we call **model.train()** at the beginning of the training loop and **model.eval()** at the beginning of the evaluation loop.

In [16]:
# In evaluation mode dropout is the identity: the values come back unchanged.
# eval() returns the module itself, so the mode switch and the call can be chained.
dropout.eval()(outputs)

tensor([[-1.6512, -0.2450,  1.6016,  1.5245],
        [ 0.5314,  0.1052, -0.5480, -0.4406],
        [ 0.5084, -1.0770, -1.7468,  0.1666],
        [ 0.8514, -0.5135,  0.9936,  0.0461],
        [-1.6448, -2.2044,  1.1541, -0.8650]])

In [17]:
# Back in training mode, dropout is active again and draws a fresh random mask,
# so the zero pattern generally differs from the one in outputs_dropped above.
dropout.train(mode=True)(outputs)

tensor([[-0.0000, -0.4899,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0000, -0.0000],
        [ 1.0168, -0.0000, -0.0000,  0.3331],
        [ 0.0000, -1.0271,  0.0000,  0.0000],
        [-0.0000, -4.4089,  0.0000, -0.0000]])