In [2]:
import torch

In [42]:
def dropout_with_scaling(X: torch.Tensor, dropout:float):
    assert 0 <= dropout <= 1
    if dropout == 1: return torch.zeros_like(X)
    mask = (torch.rand(X.shape) > dropout).float()
    return mask * X / (1.0 - dropout)

def dropout_without_scaling(X:torch.Tensor, dropout: float):
    assert 0 <= dropout <= 1
    if dropout == 1: return torch.zeros_like(X)
    mask = (torch.rand(X.shape) > dropout).float()
    return mask * X

In [43]:
X = torch.arange(16, dtype = torch.float32).reshape((2, 8))
print(X)

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11., 12., 13., 14., 15.]])


Why we do need scaling

In [44]:
y_with_scaling = dropout_with_scaling(X, 0.2)
y_without_scaling = dropout_without_scaling(X, 0.2)

In [45]:
X.mean(),X.var()

(tensor(7.5000), tensor(22.6667))

In [46]:
y_with_scaling.mean(),y_with_scaling.var()

(tensor(6.8750), tensor(53.9583))

In [47]:
y_without_scaling.mean(),y_without_scaling.var()

(tensor(3.8750), tensor(21.9833))

Backpropagation

In [205]:
pre_layer = torch.arange(16, dtype = torch.float32,requires_grad=True).reshape((2, 8))
pre_layer

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11., 12., 13., 14., 15.]], grad_fn=<ViewBackward0>)

In [206]:
dropout_layer = torch.nn.Dropout(0.5)

In [207]:
act = dropout_layer(pre_layer)
pre_layer.retain_grad()  

In [208]:
act

tensor([[ 0.,  2.,  0.,  0.,  0., 10., 12.,  0.],
        [ 0., 18., 20.,  0., 24., 26.,  0.,  0.]], grad_fn=<MulBackward0>)

In [209]:
loss = act.sum()
loss

tensor(112., grad_fn=<SumBackward0>)

In [210]:
loss.backward()

In [211]:
pre_layer.grad

tensor([[0., 2., 0., 0., 0., 2., 2., 0.],
        [0., 2., 2., 0., 2., 2., 0., 0.]])

In [39]:
import torch
import triton
import triton.language as tl

@triton.jit
def dropout_kernel(X, Y, MASK, N, P, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(X + offsets, mask=offsets < N, other=0.0)
    mask = tl.load(MASK + offsets, mask=offsets < N, other=0)
    y = x * mask / (1 - P)
    tl.store(Y + offsets, y, mask=offsets < N)

def dropout_triton(x, p=0.5, block_size=1024):
    y = torch.empty_like(x)
    mask = (torch.rand_like(x) > p).float().cuda()
    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
    dropout_kernel[grid](x, y, mask, x.numel(), p, BLOCK_SIZE=block_size)
    return y


In [41]:
x = torch.randn(24).cuda()
p = 0.5
y = dropout_triton(x)
print(y)

tensor([ 0.0000, -0.3780,  0.0000,  0.0000,  2.7230, -2.7820, -1.6979, -0.6969,
        -0.0000, -1.2190,  0.0000, -0.0000,  3.8628,  0.0000, -3.6797, -0.4877,
        -0.0000, -1.1632, -0.0000, -0.0000,  2.4258,  2.9421, -0.0000,  3.0436],
       device='cuda:0')
