<a href="https://colab.research.google.com/github/casanovaalonso/TritonTutorials/blob/main/04_triton_dropout.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install triton
!pip install torch

Collecting triton
  Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)
Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.5/209.5 MB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: triton
Successfully installed triton-3.1.0


# Dropout Regularization

Dropout is a technique to prevent overfitting in neural networks by randomly "dropping out" (setting to zero) a fraction of neurons during training.

### 1. **Forward Pass with Dropout**
Each neuron in a layer is dropped with probability $p$, i.e. kept with probability $1-p$. The entries of the keep mask $\mathbf{d}$ are sampled independently:

$$
d_i \sim \text{Bernoulli}(1-p)
$$

The activations are then masked and rescaled by $1/(1-p)$, so that the surviving units make up for the dropped ones:

$$
\hat{\mathbf{a}}_l = \frac{\mathbf{a}_l \odot \mathbf{d}}{(1-p)}
$$
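A framework-free reference of the forward pass can be handy when checking the Triton kernels further down. The following is a minimal plain-Python sketch, assuming $p$ is the drop probability (the convention implied by the $1/(1-p)$ factor and used by both kernels); the name `dropout_forward` is hypothetical.

```python
import random

def dropout_forward(a, p, rng):
    """Inverted dropout on a list of floats; p is the probability of dropping."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    # d_i ~ Bernoulli(1 - p): rng.random() is uniform in [0, 1), so P(draw >= p) = 1 - p
    d = [1 if rng.random() >= p else 0 for _ in a]
    # Kept activations are divided by (1 - p); dropped ones become 0
    a_hat = [ai * di / (1.0 - p) for ai, di in zip(a, d)]
    return a_hat, d

a_hat, d = dropout_forward([0.5, -1.0, 2.0, 3.0], p=0.5, rng=random.Random(0))
```

Passing the generator in explicitly keeps the sketch reproducible, which is also the idea behind the seeded kernel later in the notebook.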

### 2. **Backpropagation**
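During backpropagation, the gradient flows through the same mask that was used in the forward pass. Since $\hat{\mathbf{a}}_l = \mathbf{a}_l \odot \mathbf{d} / (1-p)$, the gradient with respect to $\mathbf{a}_l$ is the upstream gradient multiplied elementwise by $\mathbf{d}/(1-p)$: dropped units receive zero gradient, so the mask must still be available when the backward pass runs.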
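The backward rule can be sketched in the same plain-Python style. The hypothetical helpers below take the keep mask `d` as an argument to make explicit that the forward-pass mask has to be reused; $p$ is again assumed to be the drop probability.

```python
def dropout_apply(a, d, p):
    # Forward pass for a fixed keep mask d (1 = keep, 0 = drop)
    return [ai * di / (1.0 - p) for ai, di in zip(a, d)]

def dropout_backward(grad_out, d, p):
    # Each output depends only on its own input: d(a_hat_i)/d(a_i) = d_i / (1 - p),
    # so the upstream gradient is masked and rescaled with the SAME mask
    return [g * di / (1.0 - p) for g, di in zip(grad_out, d)]

d = [1, 0, 1, 1]
grad = dropout_backward([1.0, 1.0, 1.0, 1.0], d, p=0.25)
# grad == [1/0.75, 0.0, 1/0.75, 1/0.75]; the dropped unit gets no gradient
```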

### 3. **Test Time**
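At inference time, no neurons are dropped and the activations are used unchanged. No scaling is applied at this stage because the division by $(1-p)$ during training (the "inverted dropout" formulation) already keeps the expected activation equal to its value without dropout.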
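The test-time argument follows from taking the expectation over the mask. Treating $p$ as the drop probability, as the $1/(1-p)$ factor implies, each mask entry has $\mathbb{E}[d_i] = 1-p$, so for every unit $i$ of layer $l$:

$$
\mathbb{E}\big[\hat{a}_{l,i}\big] = \frac{a_{l,i}\,\mathbb{E}[d_i]}{1-p} = \frac{a_{l,i}\,(1-p)}{1-p} = a_{l,i}
$$

With $p = 0.5$, for example, a kept unit is doubled and a dropped one is zeroed, which averages back to $a_{l,i}$; this is why the kept entries in the outputs printed below are exactly twice their inputs.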

In [2]:
import tabulate
import torch

import triton
import triton.language as tl

In [5]:
DEVICE = torch.device(f"cuda:{torch.cuda.current_device()}")

In [13]:
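# The kernel treats its inputs as flat 1D arrays of n_elements values
@triton.jit
def _dropout(
    input_ptr,       # pointer to the input activations
    input_keep_ptr,  # pointer to the keep mask (1 = keep, 0 = drop)
    output_ptr,      # pointer to the output buffer
    n_elements,      # total number of elements to process
    p,               # drop probability
    BLOCK_SIZE: tl.constexpr,
):
  # Each program instance handles one contiguous block of BLOCK_SIZE elements
  pid = tl.program_id(0)
  block_start = pid * BLOCK_SIZE
  offsets = block_start + tl.arange(0, BLOCK_SIZE)
  # Guard the last block against reading/writing past the end of the tensors
  mask = offsets < n_elements
  input = tl.load(input_ptr + offsets, mask=mask)
  input_keep = tl.load(input_keep_ptr + offsets, mask=mask)
  # Kept elements are rescaled by 1/(1-p); dropped elements become 0
  output = tl.where(input_keep, input / (1 - p), 0.0)
  tl.store(output_ptr + offsets, output, mask=mask)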

def dropout(input, input_keep, p):
  assert input.device == input_keep.device == DEVICE
  assert input.shape == input_keep.shape
  assert input.is_contiguous()
  assert input_keep.is_contiguous()
  n_elements = input.numel()
  output = torch.empty_like(input)
  grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
  _dropout[grid](
      input,
      input_keep,
      output,
      n_elements,
      p,
      BLOCK_SIZE=1024,
  )
  return output
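The launch grid and the `mask = offsets < n_elements` guard are easiest to see on concrete numbers. The plain-Python sketch below emulates, sequentially, which offsets each program instance touches; `cdiv` mirrors the ceiling division that `triton.cdiv` computes, and `program_offsets` is a hypothetical helper, not part of Triton.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division: number of block_size-sized blocks needed to cover a elements
    return (a + b - 1) // b

def program_offsets(n_elements: int, block_size: int):
    """Return, for each program id, the offsets that pass the bounds mask."""
    blocks = []
    for pid in range(cdiv(n_elements, block_size)):  # grid = (cdiv(n, BLOCK_SIZE),)
        block_start = pid * block_size
        offsets = range(block_start, block_start + block_size)  # like tl.arange(0, BLOCK_SIZE)
        # Lanes with offset >= n_elements are masked: they neither load nor store
        blocks.append([o for o in offsets if o < n_elements])
    return blocks

# The 10-element example uses one program; 1014 of its 1024 lanes are masked out
print([len(b) for b in program_offsets(10, 1024)])    # [10]
print([len(b) for b in program_offsets(2500, 1024)])  # [1024, 1024, 452]
```

Without the bounds mask, the single program launched for the 10-element tensor would read and write 1014 elements past the end of each buffer.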


In [17]:
# Input tensor
x = torch.randn(size=(10, )).cuda()
# Dropout mask
p = 0.5
x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda()
output = dropout(x, input_keep=x_keep, p=p)
print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["keep mask"] + x_keep.tolist(),
    ["output"] + output.tolist(),
]))

---------  --------  ---------  --------  --------  -------  --------  ---------  --------  --------  ------
input      0.589746  -0.701648  0.394555  0.904477  1.23426  0.356815  -0.804551  0.169323  -1.50376  0.3185
keep mask  0          0         0         1         0        0          1         1          0        0
output     0          0         0         1.80895   0        0         -1.6091    0.338647   0        0
---------  --------  ---------  --------  --------  -------  --------  ---------  --------  --------  ------


According to the [tutorial](https://triton-lang.org/main/getting-started/tutorials/04-low-memory-dropout.html), there is a more memory-efficient way to implement dropout. In the version above, the keep mask is a full tensor that lives in GPU memory (one `int32` per element here) and must be kept until the backward pass. The alternative is to generate the mask inside the kernel with a seeded pseudo-random number generator: since the random values depend only on the seed and the element offsets, the same mask can be regenerated for backpropagation from the seed alone.
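The key property is that each element's random number is a pure function of `(seed, offset)` (a counter-based generator), so it does not matter which program instance or block size computes it. The plain-Python sketch below mimics that property; `uniform` and `keep_mask` are hypothetical helpers, and seeding `random.Random` with a string is only a stand-in for Triton's Philox-based `tl.rand`, so the actual values differ from the GPU's.

```python
import random

def uniform(seed: int, offset: int) -> float:
    # Stand-in for tl.rand(seed, offset): a value in [0, 1) that depends only on
    # (seed, offset). random.Random converts a str seed to an int deterministically
    # (not via the randomized hash()), so results are reproducible across runs.
    return random.Random(f"{seed}:{offset}").random()

def keep_mask(seed: int, offsets, p: float):
    # Same rule as the seeded kernel: keep an element where its draw is < 1 - p
    return [uniform(seed, o) < 1.0 - p for o in offsets]

# Computing the mask in one pass or in two blocks gives identical results,
# so the backward pass can rebuild it from the seed alone
whole = keep_mask(123, range(64), p=0.5)
split = keep_mask(123, range(32), p=0.5) + keep_mask(123, range(32, 64), p=0.5)
print(whole == split)  # True
```

For scale: an `int32` keep mask costs 4 bytes per element, e.g. 4096 × 4096 × 4 B = 64 MiB for a single 4096 × 4096 activation, whereas the seeded version only needs to remember one integer.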

In [20]:
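@triton.jit
def _seeded_dropout(
    input_ptr,
    output_ptr,
    n_elements,
    p,     # drop probability
    seed,  # integer seed; together with the offsets it fully determines the mask
    BLOCK_SIZE: tl.constexpr,
):
  pid = tl.program_id(0)
  block_start = pid * BLOCK_SIZE
  offsets = block_start + tl.arange(0, BLOCK_SIZE)
  mask = offsets < n_elements
  input = tl.load(input_ptr + offsets, mask=mask)
  # tl.rand returns uniform values in [0, 1) that depend only on (seed, offsets),
  # so the keep mask is regenerated on the fly instead of being read from memory
  keep = tl.rand(seed, offsets) < (1 - p)
  output = tl.where(keep, input / (1 - p), 0.0)
  tl.store(output_ptr + offsets, output, mask=mask)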

def seeded_dropout(input, p, seed):
  assert input.device == DEVICE
  assert input.is_contiguous()
  n_elements = input.numel()
  output = torch.empty_like(input)
  grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
  _seeded_dropout[grid](
      input,
      output,
      n_elements,
      p,
      seed,
      BLOCK_SIZE=1024,
  )
  return output
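Because the mask is never stored, the backward pass has to regenerate it. For this elementwise operation the Jacobian is diagonal with entries $d_i/(1-p)$, so the gradient is obtained by applying the same seeded dropout to the upstream gradient with the same seed. In PyTorch this would typically be wrapped in a `torch.autograd.Function` that saves only `p` and `seed`; its `backward` could in principle call `seeded_dropout(grad_output.contiguous(), p, seed)`. The sketch below checks that reasoning in plain Python; all names are hypothetical, and the string-seeded `random.Random` again only stands in for `tl.rand`.

```python
import random

def _uniform(seed: int, offset: int) -> float:
    # Stand-in for tl.rand(seed, offset): depends only on (seed, offset)
    return random.Random(f"{seed}:{offset}").random()

def seeded_dropout_ref(x, p, seed):
    # Forward: keep element i where its draw is < 1 - p and rescale by 1 / (1 - p)
    return [xi / (1.0 - p) if _uniform(seed, i) < 1.0 - p else 0.0
            for i, xi in enumerate(x)]

def seeded_dropout_backward_ref(grad_out, p, seed):
    # Backward: same mask (same seed, same offsets) applied to the upstream gradient
    return seeded_dropout_ref(grad_out, p, seed)

x = [0.5, -1.0, 2.0, 3.0, 1.5, -0.25, 4.0, 0.75]
out = seeded_dropout_ref(x, p=0.5, seed=123)
grad = seeded_dropout_backward_ref([1.0] * len(x), p=0.5, seed=123)
# Units zeroed in the forward pass receive exactly zero gradient
print(all((o == 0.0) == (g == 0.0) for o, g in zip(out, grad)))  # True
```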

In [21]:
x = torch.randn(size=(10, )).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)

print(
    tabulate.tabulate([
        ["input"] + x.tolist(),
        ["output (seed = 123)"] + output.tolist(),
        ["output (seed = 123)"] + output2.tolist(),
        ["output (seed = 512)"] + output3.tolist(),
    ]))

-------------------  --------  --------  --------  --------  -------  -------  --------  --------  ---------  ---------
input                0.744358  0.673647  0.805808  0.709042  1.07801  1.33072  0.702965  -2.23653  -0.326397  -0.925518
output (seed = 123)  1.48872   0         1.61162   1.41808   2.15601  0        1.40593   -4.47306   0          0
output (seed = 123)  1.48872   0         1.61162   1.41808   2.15601  0        1.40593   -4.47306   0          0
output (seed = 512)  1.48872   1.34729   0         0         2.15601  0        0         -4.47306  -0.652794  -1.85104
-------------------  --------  --------  --------  --------  -------  -------  --------  --------  ---------  ---------
