<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/LayerNorm_backward.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import triton
import triton.language as tl


In [104]:
# LayerNorm subclass that additionally records the per-row mean and variance
# of its input on every forward pass, so a hand-written backward can reuse them.
class Custom_LayerNorm(torch.nn.LayerNorm):
  def forward(self,x :torch.Tensor):
    # Statistics are reduced over dim=1, i.e. one value per row of a 2D input.
    self.mean = x.mean(dim=1)
    # torch.nn.LayerNorm normalizes with the *biased* variance (divide by N).
    # torch.var defaults to the unbiased estimator (divide by N-1), which would
    # make the saved statistic disagree with what the forward pass really used.
    self.var = x.var(dim=1,unbiased=False)
    return super().forward(x)


In [145]:
# Accumulates the LayerNorm parameter gradients for one
# (BLOCK_SIZE_ROW x BLOCK_SIZE_COL) tile of the input:
#   dgamma[j] += sum_i x_norm[i, j] * dout[i, j]
#   dbias[j]  += sum_i dout[i, j]
# Each program reduces its tile over the row axis and atomically adds the
# partial column sums into dgamma_ptr / dbias_ptr, so those buffers must be
# zero-initialised by the caller.
# x and dout are addressed with the same (stride_m, stride_n) strides;
# mean_ptr / var_ptr are indexed by row with unit stride.
# NOTE(review): num_warps is declared here as an ordinary argument and is never
# read in the body; the launch-time warp count is normally given as a keyword at
# the call site -- confirm whether this default has any effect.
@triton.jit
def _backward_kernel(dout_ptr,x_ptr,dgamma_ptr,dbias_ptr,mean_ptr,var_ptr,stride_m,
                     stride_n,m,n,eps,BLOCK_SIZE_ROW:tl.constexpr,BLOCK_SIZE_COL:tl.constexpr,num_warps=16):
  pid_m = tl.program_id(axis=0)
  pid_n = tl.program_id(axis=1)
  # row / column indices covered by this program
  offs_m = pid_m * BLOCK_SIZE_ROW + tl.arange(0,BLOCK_SIZE_ROW)
  offs_n = pid_n * BLOCK_SIZE_COL + tl.arange(0,BLOCK_SIZE_COL)
  # bounds masks: the last tile along either axis may hang over the edge
  row_mask = offs_m < m
  col_mask = offs_n < n
  mask = row_mask[:,None] & col_mask[None,:]
  # x and dout share one layout, so one offset tile serves both
  tile_offs = offs_m[:,None] * stride_m + offs_n[None,:] * stride_n
  # Masked-out lanes still take part in the axis-0 reductions below, so they
  # need explicit neutral values instead of whatever a masked load leaves there.
  x = tl.load(x_ptr + tile_offs,mask=mask,other=0.0)
  dout = tl.load(dout_ptr + tile_offs,mask=mask,other=0.0)
  # per-row statistics; padded rows get mean=0 / var=1 so x_norm stays finite
  # even when eps is 0
  mean = tl.load(mean_ptr + offs_m,mask=row_mask,other=0.0)
  mean = tl.reshape(mean,(BLOCK_SIZE_ROW,1))
  var = tl.load(var_ptr + offs_m,mask=row_mask,other=1.0)
  var = tl.reshape(var,(BLOCK_SIZE_ROW,1))
  # normalized input: (x - mean) / sqrt(var + eps)
  x_norm = (x - mean) * tl.rsqrt(var + eps)
  # dgamma = x_norm * dout summed over the rows of the tile
  dgamma = tl.sum((x_norm * dout),axis=0)
  # dbias = dout summed over the rows of the tile
  dbias = tl.sum(dout,axis=0)
  # several row-tiles contribute to the same columns, hence the atomic adds
  tl.atomic_add(dgamma_ptr + offs_n,dgamma,mask=col_mask)
  tl.atomic_add(dbias_ptr + offs_n,dbias,mask=col_mask)

In [212]:
def _LayerNorm_backward(dout:torch.Tensor,x: torch.Tensor, mean:torch.Tensor,var:torch.Tensor,eps:float):
  """Compute the LayerNorm bias and weight gradients with the Triton kernel.

  Args:
    dout: upstream gradient, same shape as ``x``.
    x: 2D contiguous CUDA tensor that was fed to the layer.
    mean: per-row mean of ``x`` (one value per row).
    var: per-row variance of ``x`` (one value per row).
    eps: epsilon added to the variance before the reciprocal square root.

  Returns:
    The tuple ``(dbias, dgamma)`` -- note the order -- each of shape
    ``(x.shape[1],)`` with the device and dtype of ``x``.
  """
  assert x.is_cuda and x.is_contiguous()
  # the kernel addresses dout with x's strides, so both must share one layout
  dout = dout.contiguous()
  assert dout.shape == x.shape
  m,n = x.shape
  # the kernel reads mean/var at offsets 0..m-1 with unit stride; enforce that
  # here instead of silently reading wrong or out-of-bounds memory
  assert mean.numel() == m and var.numel() == m, "mean and var must hold one value per row of x"
  mean = mean.contiguous()
  var = var.contiguous()
  # zero-initialised because the kernel accumulates into them with atomic adds
  dgamma = torch.zeros(n,device=x.device,dtype=x.dtype)
  dbias =  torch.zeros(n,device=x.device,dtype=x.dtype)
  BLOCK_SIZE_ROW = 32
  BLOCK_SIZE_COL = 32
  # one program per (row-tile, column-tile) pair
  grid = (triton.cdiv(m,BLOCK_SIZE_ROW),triton.cdiv(n,BLOCK_SIZE_COL))
  _backward_kernel[grid](dout,x,dgamma,dbias,mean,var,x.stride(0),x.stride(1),
                         m,n,eps,BLOCK_SIZE_ROW,BLOCK_SIZE_COL)
  return dbias , dgamma

In [228]:
# Reference run: push a random square batch through the layer and let autograd
# fill in layer.weight.grad and layer.bias.grad.
layer = Custom_LayerNorm(768, device='cuda')
input = torch.rand(768, 768, device='cuda')
out = layer(input)
# The gradient of a plain sum w.r.t. its operand is all ones, so backprop
# through it is equivalent to using dout = ones_like(out).
loss = torch.sum(out)
loss.backward()

In [229]:
# Same upstream gradient autograd saw for out.sum(): a tensor of ones.
upstream_grad = torch.ones_like(input)
dbias, dgamma = _LayerNorm_backward(
  upstream_grad, input, layer.mean, layer.var, layer.eps
)

In [230]:
# Bias gradient is compared at torch.allclose's default tolerances.
bias_matches = torch.allclose(dbias, layer.bias.grad)
print(bias_matches)
# The weight gradient goes through the recomputed x_norm, which depends on the
# saved mean/var and need not match PyTorch's internal statistics exactly, so
# it is compared with a looser relative tolerance of 1e-3.
gamma_matches = torch.allclose(dgamma, layer.weight.grad, rtol=1e-3)
print(gamma_matches)

True
True
