<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/SiLU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import triton
import triton.language as tl

In [6]:
@triton.jit
def _forward_kernel(x_ptr,y_ptr,stride_m,stride_n,m,n,BLOCK_SIZE_ROW:tl.constexpr,BLOCK_SIZE_COL:tl.constexpr,num_warps=64):
  pid_m = tl.program_id(axis=0)
  pid_n = tl.program_id(axis=1)
  offs_m = pid_m * BLOCK_SIZE_ROW + tl.arange(0,BLOCK_SIZE_ROW)
  offs_n = pid_n * BLOCK_SIZE_COL + tl.arange(0,BLOCK_SIZE_COL)
  mask = (offs_m[:,None] < m ) & (offs_n[None,:] < n)
  x_ptrs = x_ptr + offs_m[:,None] * stride_m + offs_n[None,:] * stride_n
  x = tl.load(x_ptrs,mask=mask)
  y = x * tl.sigmoid(x)
  y_ptrs = y_ptr + offs_m[:,None] * stride_m + offs_n[None,:] * stride_n
  tl.store(y_ptrs,y,mask=mask)

In [7]:
def _forward_Silu(x:torch.tensor):
  assert x.is_cuda and x.is_contiguous()
  if x.ndim == 2:
    m,n = x.shape
  else:
    x = x.view(-1,x.shape[-1])
    m,n = x.shape
  y = torch.empty_like(x,device=x.device,dtype=x.dtype)
  BLOCK_SIZE_ROW = 32
  BLOCK_SIZE_COL = 64
  grid = (triton.cdiv(m,BLOCK_SIZE_ROW),triton.cdiv(n,BLOCK_SIZE_COL))
  _forward_kernel[grid](x,y,x.stride(0),x.stride(1),m,n,BLOCK_SIZE_ROW,BLOCK_SIZE_COL)
  return y

In [37]:
@triton.jit
def _backward_kernel(x_ptr,dx_ptr,dout_ptr,stride_m,stride_n,
                     m,n,BLOCK_SIZE_ROW:tl.constexpr,BLOCK_SIZE_COL:tl.constexpr,num_warps=64):
  pid_m = tl.program_id(axis=0)
  pid_n = tl.program_id(axis=1)
  offs_m = pid_m * BLOCK_SIZE_ROW + tl.arange(0,BLOCK_SIZE_ROW)
  offs_n = pid_n * BLOCK_SIZE_COL + tl.arange(0,BLOCK_SIZE_COL)
  mask = (offs_m[:,None] < m ) & (offs_n[None,:] < n)
  x_ptrs = x_ptr + offs_m[:,None] * stride_m + offs_n[None,:] * stride_n
  x = tl.load(x_ptrs,mask=mask)
  dout_ptrs = dout_ptr + offs_m[:,None] * stride_m + offs_n[None,:] * stride_n
  dout = tl.load(dout_ptrs,mask=mask)
  dx = dout * ( tl.sigmoid(x) + (tl.sigmoid(x) * (1-tl.sigmoid(x)))* x)
  dx_ptrs = dx_ptr + offs_m[:,None] * stride_m + offs_n[None,:] * stride_n
  tl.store(dx_ptrs,dx,mask=mask)



In [38]:
def _backward_Silu(x:torch.tensor,dout:torch.tensor):
  assert x.is_cuda and dout.is_cuda
  dout = dout.contiguous()
  assert x.is_contiguous() and dout.is_contiguous(),print(f'x is contiguous {x.is_contiguous()} , dout is contiguous {dout.is_contiguous}')
  m,n = x.shape
  dx = torch.empty_like(x,device=x.device,dtype=x.dtype)
  BLOCK_SIZE_ROW = 32
  BLOCK_SIZE_COL = 64
  grid = (triton.cdiv(m,BLOCK_SIZE_ROW),triton.cdiv(n,BLOCK_SIZE_COL))
  _backward_kernel[grid](x,dx,dout,x.stride(0),x.stride(1),m,n,BLOCK_SIZE_ROW,BLOCK_SIZE_COL)
  return dx

In [42]:
class Silu(torch.autograd.Function):
  @staticmethod
  def forward(ctx,input):
    ctx.save_for_backward(input)
    out = _forward_Silu(input)
    return out
  @staticmethod
  def backward(ctx,dout):
    input = ctx.saved_tensors[0]
    dx = _backward_Silu(input,dout)
    return dx

In [53]:
input = torch.rand((1024,2048),device='cuda',requires_grad=True)
out = Silu.apply(input)
out_torch = torch.nn.functional.silu(input)
loss = out_torch.sum()
loss.backward()
dinput_torch,input.grad = input.grad.clone(),None


In [64]:
def test_correctness():
  input = torch.rand((1024,2048),device='cuda',requires_grad=True)
  out = Silu.apply(input)
  out_torch = torch.nn.functional.silu(input)
  loss = out_torch.sum()
  loss.backward()
  dinput_torch,input.grad = input.grad.clone(),None
  out_triton = Silu.apply(input)
  loss = out_triton.sum()
  loss.backward()
  dinput_triton,input.grad = input.grad.clone(),None
  return print(torch.allclose(dinput_triton,dinput_torch),"The gradient of triton silu kernel is similar to pytorch autograd engine")

In [65]:
if __name__=="__main__":
  test_correctness()

True The gradient of triton silu kernel is similar to pytorch autograd engine
