<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/Dropout.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import torch
import triton
import triton.language as tl


In [20]:
@triton.autotune(configs=[
    triton.Config({'BLOCK_SIZE_M':64,'BLOCK_SIZE_N':64},num_warps=4,num_stages=2),
    triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':128},num_warps=8,num_stages=2)],
    key=['m','n'] )
@triton.jit
def dropout_kernel(input_ptr,mask_ptr,out_ptr,stride_m,stride_n,m,n,pb,BLOCK_SIZE_M:tl.constexpr,BLOCK_SIZE_N:tl.constexpr):
  """Tile-wise dropout over an (m, n) matrix.

  Each program handles one BLOCK_SIZE_M x BLOCK_SIZE_N tile selected by its
  2-D program id. For every in-bounds element the kernel writes
  ``0`` where the sample read from ``mask_ptr`` is below ``pb`` and
  ``input / (1 - pb)`` otherwise.

  Args:
    input_ptr: pointer to the values to be dropped.
    mask_ptr: pointer to per-element samples that are compared against ``pb``.
    out_ptr: pointer to the destination buffer.
    stride_m, stride_n: element strides of the row / column dimension. The
      same pair is used to address all three buffers, so they must share
      one layout.
    m, n: number of rows / columns.
    pb: drop threshold; ``1 - pb`` is used as a divisor, so ``pb == 1``
      is not meaningful here.
    BLOCK_SIZE_M, BLOCK_SIZE_N: compile-time tile sizes chosen by the autotuner.
  """
  pid_m = tl.program_id(axis=0)
  pid_n = tl.program_id(axis=1)
  rows = pid_m * BLOCK_SIZE_M + tl.arange(0,BLOCK_SIZE_M)
  cols = pid_n * BLOCK_SIZE_N + tl.arange(0,BLOCK_SIZE_N)
  # Tiles on the bottom/right edge may overhang the matrix; only touch valid elements.
  in_bounds = (rows[:,None] < m) & (cols[None,:] < n)
  offsets = rows[:,None] * stride_m + cols[None,:] * stride_n
  values = tl.load(input_ptr+offsets,mask=in_bounds,other=0.0)
  samples = tl.load(mask_ptr+offsets,mask=in_bounds,other=1.0)
  kept = tl.where(samples<pb,0.0,values)
  # Rescale the surviving elements by 1/(1-pb).
  scaled = kept /(1-pb)
  # The store must be masked just like the loads: without it, edge tiles
  # write past the end of the output whenever m or n is not a multiple of
  # the block size.
  tl.store(out_ptr+offsets,scaled,mask=in_bounds)


In [21]:
torch.manual_seed(42)  # fixed seed so the random samples drawn below are reproducible
def triton_dropout(input:torch.Tensor,pb:float) -> torch.Tensor:
  """Apply dropout to a 2-D CUDA tensor by launching `dropout_kernel`.

  A uniform random sample is drawn per element; elements whose sample is
  below ``pb`` are zeroed and the rest are divided by ``1 - pb``.

  Args:
    input: 2-D tensor living on a CUDA device.
    pb: drop probability as a Python float in ``[0, 1]``.

  Returns:
    A new tensor with the same shape, dtype and device as ``input``.

  Raises:
    AssertionError: if ``input`` is not on CUDA or ``pb`` is not a float.
    ValueError: if ``pb`` is outside ``[0, 1]`` or ``input`` is not 2-D
      (the latter from unpacking ``input.shape``).
  """
  assert input.is_cuda
  assert isinstance(pb,float)
  if not 0.0 <= pb <= 1.0:
    raise ValueError(f"dropout probability has to be between 0 and 1, but got {pb}")
  m,n = input.shape
  if pb == 1.0:
    # Everything is dropped; the kernel would otherwise evaluate 0/(1-pb) = 0/0.
    return torch.zeros_like(input)
  out = torch.empty_like(input)
  if out.stride() != input.stride():
    # The kernel addresses input, mask and out with a single stride pair, but
    # the *_like factories only reproduce the input's strides when it is
    # dense. For other views (e.g. a strided slice) work on a contiguous copy
    # so all three buffers share one layout.
    input = input.contiguous()
    out = torch.empty_like(input)
  mask = torch.rand_like(input)
  grid = lambda meta: (triton.cdiv(m,meta['BLOCK_SIZE_M']),triton.cdiv(n,meta['BLOCK_SIZE_N']))
  dropout_kernel[grid](input,mask,out,input.stride(0),input.stride(1),m,n,pb)
  return out

In [22]:
# Demo: draw a 1024 x 2048 tensor of normally distributed values on the GPU ...
input = torch.randn((1024,2048),device='cuda')
# ... and run the Triton dropout on it with a drop probability of 0.8.
out = triton_dropout(input,0.8)