<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/GELU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import triton
import triton.language as tl

In [None]:
@triton.jit
def tanh(x):
    """Element-wise hyperbolic tangent that stays finite for large |x|.

    The textbook form (e^x - e^-x) / (e^x + e^-x) turns into inf / inf = NaN
    as soon as exp(x) overflows (and likewise once exp(x) underflows to 0 and
    its reciprocal becomes inf). This version only ever evaluates
    exp(-2|x|), which lies in [0, 1], so the result saturates cleanly
    to +/-1 instead.

    Args:
        x: Floating-point Triton tensor (or scalar).

    Returns:
        tanh(x), with the same shape as x.
    """
    # For |x|: tanh(|x|) = (1 - e^(-2|x|)) / (1 + e^(-2|x|)).
    decay = tl.exp(-2.0 * tl.abs(x))
    magnitude = (1.0 - decay) / (1.0 + decay)
    # tanh is odd, so restore the sign of the input.
    return tl.where(x < 0, -magnitude, magnitude)

In [54]:
PI = 3.141592653589793


@triton.jit
def forward_kernel(x_ptr,out_ptr,stride_m,stride_n,m,n,PI,BLOCK_SIZE_ROW:tl.constexpr,BLOCK_SIZE_COL:tl.constexpr,num_warps=64):
    """Compute the tanh-approximated GELU of one 2-D tile.

    out = 0.5 * x * (1 + tanh(sqrt(2 / PI) * (x + 0.044715 * x^3)))

    Each program handles the (BLOCK_SIZE_ROW, BLOCK_SIZE_COL) tile selected by
    its program ids on grid axes 0 (rows) and 1 (columns).

    Args:
        x_ptr: Pointer to the input elements.
        out_ptr: Pointer to the output elements; it is addressed with the
            same strides as the input.
        stride_m: Element stride between consecutive rows.
        stride_n: Element stride between consecutive columns.
        m: Number of rows.
        n: Number of columns.
        PI: Value of pi, supplied by the caller.
        BLOCK_SIZE_ROW: Tile height (compile-time constant).
        BLOCK_SIZE_COL: Tile width (compile-time constant).
        num_warps: Never read in the kernel body.
            NOTE(review): Triton normally takes num_warps as a launch option;
            confirm whether declaring it here has any effect.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_SIZE_ROW + tl.arange(0, BLOCK_SIZE_ROW)
    offs_n = pid_n * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)
    # Tiles on the bottom/right edge may extend past the (m, n) bounds.
    mask = (offs_m[:, None] < m) & (offs_n[None, :] < n)
    offsets = offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
    x = tl.load(x_ptr + offsets, mask=mask)
    inner = tl.sqrt(2 / PI) * (x + 0.044715 * x * x * x)
    out = 0.5 * x * (1 + tanh(inner))
    # The store must be masked as well; otherwise edge tiles write outside
    # the output buffer whenever m or n is not a multiple of the block size.
    tl.store(out_ptr + offsets, out, mask=mask)


In [55]:
def _gelu_forward(x: torch.Tensor) -> torch.Tensor:
    """Apply the tanh-approximated GELU to a 2-D tensor via ``forward_kernel``.

    Args:
        x: Contiguous 2-D CUDA tensor.

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.

    Raises:
        AssertionError: If ``x`` is not on a CUDA device or is not contiguous.
        ValueError: If ``x`` does not have exactly two dimensions.
    """
    assert x.is_cuda and x.is_contiguous(), "x must be a contiguous CUDA tensor"
    if x.dim() != 2:
        raise ValueError(f"expected a 2-D tensor, got {x.dim()} dimension(s)")
    m, n = x.shape
    BLOCK_SIZE_ROW = 32
    BLOCK_SIZE_COL = 64
    # empty_like already inherits the dtype and device of x.
    out = torch.empty_like(x)
    # One program per (BLOCK_SIZE_ROW, BLOCK_SIZE_COL) tile, rounding up so
    # partial tiles along either edge are covered too.
    grid = (triton.cdiv(m, BLOCK_SIZE_ROW), triton.cdiv(n, BLOCK_SIZE_COL))
    # The kernel addresses the output with the input's strides.
    forward_kernel[grid](
        x, out, x.stride(0), x.stride(1), m, n, PI, BLOCK_SIZE_ROW, BLOCK_SIZE_COL
    )
    return out

In [64]:
# Sample inputs drawn uniformly from [0, 1) on the GPU.
# NOTE: the name `input` shadows the builtin; it is kept because later cells read it.
input = torch.rand((1024,1024),device='cuda')
# The Triton kernel implements the tanh approximation of GELU, so the
# reference uses the same formulation rather than the exact erf-based one.
out = torch.nn.functional.gelu(input, approximate='tanh')
out_triton = _gelu_forward(input)

In [57]:
# Report whether the Triton result matches the PyTorch reference to a
# relative tolerance of 1e-3 (absolute tolerance left at its default).
print(torch.allclose(out, out_triton, rtol=1e-3))

True


In [58]:
def benchmark(fn,*arg,warmup=16,steps=64):
    """Return the average time of one ``fn(*arg)`` call, timed with CUDA events.

    Args:
        fn: Callable to time.
        *arg: Positional arguments forwarded to every call of ``fn``.
        warmup: Number of untimed calls made before the measurement starts.
        steps: Number of timed calls; the elapsed time is divided by this.

    Returns:
        The elapsed time between the two events divided by ``steps``
        (``torch.cuda.Event.elapsed_time`` reports milliseconds).
    """
    tic, toc = (torch.cuda.Event(enable_timing=True) for _ in range(2))

    def run_repeatedly(count):
        # Invoke the callable `count` times back to back.
        for _ in range(count):
            fn(*arg)

    run_repeatedly(warmup)
    # Drain the warm-up work before the timed region begins.
    torch.cuda.synchronize()

    tic.record()
    run_repeatedly(steps)
    toc.record()
    # Wait until the closing event has completed before reading the timer.
    torch.cuda.synchronize()

    elapsed_ms = tic.elapsed_time(toc)
    return elapsed_ms / steps

In [65]:
# Average per-call time of the Triton implementation on `input`.
time_triton = benchmark(_gelu_forward,input)
# Average per-call time of PyTorch's built-in GELU on the same tensor.
time_torch = benchmark(torch.nn.functional.gelu,input)

In [67]:
# Report both average per-call timings with four decimal places.
print(f'triton_time is {time_triton:.4f} ms ')
print(f'torch_time is {time_torch:.4f} ms')

triton_time is 0.0480 ms 
torch_time i s 0.0368 ms
