<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/SwiGLU_Kernel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import triton
import triton.language as tl
import torch

In [2]:
# Pick the compute device once: use CUDA when PyTorch reports it is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [3]:
@triton.jit
def forward_SwiGLU(a_ptr,g_ptr,out_ptr,n_elements,BLOCK_SIZE:tl.constexpr):
  """Element-wise SwiGLU forward pass: out = swish(a) * g, with swish(a) = a * sigmoid(a).

  Each program instance (indexed along axis 0) processes one run of
  BLOCK_SIZE consecutive flat element offsets.

  a_ptr      -- pointer to the activation input
  g_ptr      -- pointer to the gate input
  out_ptr    -- pointer the result is stored to
  n_elements -- number of elements to process; offsets at or beyond it are masked off
  BLOCK_SIZE -- compile-time number of elements handled per program instance
  """
  # first flat offset owned by this program instance
  first = tl.program_id(axis=0) * BLOCK_SIZE
  idx = first + tl.arange(0, BLOCK_SIZE)
  # the last block may run past the end of the data, so guard every access
  in_bounds = idx < n_elements
  act = tl.load(a_ptr + idx, mask=in_bounds)
  gate = tl.load(g_ptr + idx, mask=in_bounds)
  # swish(a) * g, where * is element-wise multiplication
  result = act * tl.sigmoid(act) * gate
  tl.store(out_ptr + idx, result, mask=in_bounds)



In [12]:

def SwiGLU(a,g):
  """Compute swish(a) * g element-wise by launching the forward_SwiGLU kernel.

  a -- 3-D tensor unpacked as (batch, seq_len, hidden)
  g -- gate tensor with the same number of elements as ``a``

  Returns a new 2-D tensor of shape (batch * seq_len, hidden) allocated on the
  same device as ``a``.
  """
  batch , seq_len , hidden = a.shape
  # The kernel addresses memory with flat offsets, so both inputs must be
  # densely packed; .contiguous() is a no-op when they already are.
  a = a.contiguous().view(batch * seq_len , hidden)
  g = g.contiguous().view(batch * seq_len , hidden)
  n_elements = g.numel()
  # allocate memory for the result next to the input instead of a hard-coded
  # 'cuda', so tensors living on a non-default device get a matching output
  out = torch.empty_like(a)
  if n_elements == 0:
    # nothing to compute; skip the launch rather than request an empty grid
    return out
  grid = lambda meta : ( triton.cdiv(n_elements,meta['BLOCK_SIZE']),)
  forward_SwiGLU[grid](a,g,out,n_elements,BLOCK_SIZE=1024)
  return out


In [23]:
@triton.jit
def backward_SwiGLU_kernel(DOUT_ptr,a_ptr,g_ptr,n_elements,BLOCK_SIZE:tl.constexpr):
  """Element-wise SwiGLU backward pass for out = swish(a) * g.

  The gradients are written IN PLACE over the inputs: after the kernel runs,
  the buffer behind a_ptr holds d(loss)/d(a) and the buffer behind g_ptr holds
  d(loss)/d(g).

  DOUT_ptr   -- pointer to the upstream gradient d(loss)/d(out)
  a_ptr      -- pointer to the activation input; overwritten with its gradient
  g_ptr      -- pointer to the gate input; overwritten with its gradient
  n_elements -- number of elements to process; offsets at or beyond it are masked off
  BLOCK_SIZE -- compile-time number of elements handled per program instance
  """
  pid = tl.program_id(axis=0)
  block_id = pid * BLOCK_SIZE
  offsets = block_id + tl.arange(0,BLOCK_SIZE)
  # the last block may run past the end of the data, so guard every access
  mask = offsets < n_elements
  DOUT = tl.load(DOUT_ptr + offsets,mask = mask)
  a = tl.load(a_ptr + offsets,mask = mask)
  g = tl.load(g_ptr + offsets,mask = mask)
  sig = tl.sigmoid(a)
  # swish(a) = a * sigmoid(a); the sigmoid factor must be included here,
  # otherwise the gate gradient degenerates to DOUT * a
  swish = a * sig
  # the gradient of the gate is DOUT * swish since out = swish * g (* is element wise multiplication)
  Dg = DOUT * swish
  # Da = (dout / dswish) * (dswish / da), fused in one line:
  #   dout/dswish = DOUT * g
  #   dswish/da   = sigmoid(a) + sigmoid(a) * (1 - sigmoid(a)) * a
  Da = DOUT * g * (sig + sig * (1 - sig) * a)
  # both inputs were fully loaded above, so overwriting them now is safe
  tl.store(a_ptr + offsets,Da,mask = mask)
  tl.store(g_ptr + offsets,Dg,mask = mask)


In [26]:
def backward_SwiGLU(DOUT,a,g):
  """Compute the SwiGLU gradients by launching backward_SwiGLU_kernel.

  DOUT -- gradient of the loss with respect to out (dloss/dout), unpacked as
          (batch_size, seq_len, hidden)
  a    -- activation input of the forward pass
  g    -- gate input of the forward pass

  Returns (Da, Dg), each of shape (batch_size * seq_len, hidden). The kernel
  stores the gradients over the ``a`` and ``g`` buffers, so when those tensors
  are already contiguous the caller's storage is overwritten in place.
  """
  batch_size , seq_len , hidden = DOUT.shape
  rows = batch_size * seq_len
  # The kernel addresses memory with flat offsets, so every operand must be
  # densely packed. The explicit ``hidden`` (rather than -1) keeps the view
  # well-defined for zero-element tensors.
  DOUT = DOUT.contiguous().view(rows , hidden)
  a = a.contiguous().view(rows , hidden)
  g = g.contiguous().view(rows , hidden)
  n_elements = a.numel()
  if n_elements == 0:
    # nothing to differentiate; skip the launch rather than request an empty grid
    return a , g
  grid = lambda meta : (triton.cdiv(n_elements,meta['BLOCK_SIZE']),)
  backward_SwiGLU_kernel[grid](DOUT,a,g,n_elements,BLOCK_SIZE=1024,)
  return a , g
