<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/Layer_Norm_Forward.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import triton
import triton.language as tl


In [None]:
# Reference data: 256 rows of width 16 with values in [0, 10), and a PyTorch
# LayerNorm whose output the Triton kernel is compared against later on.
# NOTE: `input` shadows the builtin; the name is kept because later cells use it.
torch.manual_seed(0)  # reproducible input and affine parameters
input = 10 * torch.rand((256,16),dtype=torch.float32,device='cuda')
layer_norm = nn.LayerNorm(16,device='cuda')
# nn.LayerNorm starts with weight=1 and bias=0, which would leave the kernel's
# gamma/beta path untested; give both non-trivial values.
with torch.no_grad():
  layer_norm.weight.uniform_(0.5, 1.5)
  layer_norm.bias.uniform_(-0.5, 0.5)
input_normalized = layer_norm(input)

In [None]:
@triton.jit
def forward_layer_norm(x_ptr,x_norm_ptr,gamma_ptr,beta_ptr,mean_ptr,var_ptr,eps,M,N:tl.constexpr,stride_xm,stride_xn,BLOCK_SIZE:tl.constexpr):
  """LayerNorm forward over the rows of an (M, N) matrix.

  Program ``pid`` handles rows ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``.
  For each of them it writes
    x_norm = (x - mean) / sqrt(var + eps) * gamma + beta
  together with the row mean and the biased (divide-by-N) row variance.

  Args:
    x_ptr, x_norm_ptr: input / output matrices. Both are addressed with
      (stride_xm, stride_xn), so the output must share the input's layout.
    gamma_ptr, beta_ptr: per-column scale and shift, read with unit stride.
    mean_ptr, var_ptr: per-row statistics, written with unit stride. The
      variance itself is stored, not its reciprocal square root.
    eps: added to the variance before the square root.
    M: number of rows. It is a runtime argument so the kernel is not
      specialized on the exact row count.
    N: number of columns. It must be a power of two because
      ``tl.arange(0, N)`` covers a whole row in one tile.
    BLOCK_SIZE: rows per program; must be a power of two.
  """
  pid = tl.program_id(axis=0)
  offs_row = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
  offs_col = tl.arange(0, N)
  row_mask = offs_row < M
  # Always true while N is the tile width; kept so every load stays guarded.
  col_mask = offs_col < N
  mask = row_mask[:, None] & col_mask[None, :]

  # Accumulate the statistics in fp32 whatever the storage dtype is. Casting a
  # sum that was already computed in fp16/bf16 cannot recover the lost precision.
  # other=0.0 keeps rows past M well defined (their results are never stored).
  x = tl.load(x_ptr + offs_row[:, None] * stride_xm + offs_col[None, :] * stride_xn,
              mask=mask, other=0.0).to(tl.float32)
  gamma = tl.load(gamma_ptr + offs_col[None, :], mask=col_mask[None, :], other=0.0).to(tl.float32)
  beta = tl.load(beta_ptr + offs_col[None, :], mask=col_mask[None, :], other=0.0).to(tl.float32)

  mean = tl.sum(x, axis=1, keep_dims=True) / N
  x_centered = x - mean
  var = tl.sum(x_centered * x_centered, axis=1, keep_dims=True) / N
  x_norm = x_centered / tl.sqrt(var + eps) * gamma + beta

  # tl.store converts the fp32 values to each destination's element type.
  tl.store(x_norm_ptr + offs_row[:, None] * stride_xm + offs_col[None, :] * stride_xn, x_norm, mask=mask)
  tl.store(mean_ptr + offs_row, tl.reshape(mean, [BLOCK_SIZE]), mask=row_mask)
  tl.store(var_ptr + offs_row, tl.reshape(var, [BLOCK_SIZE]), mask=row_mask)


In [None]:
def _layer_norm_forward(x,gamma,beta,eps):
  """Run the Triton LayerNorm forward kernel on a 2-D tensor.

  Args:
    x: contiguous (rows, cols) tensor, a CUDA tensor in this notebook. cols must
      be a power of two because the kernel spans a row with ``tl.arange(0, N)``.
    gamma: per-column scale; ``gamma.shape[0]`` must equal cols.
    beta: per-column shift; ``beta.shape[0]`` must equal cols.
    eps: value added to the variance before the square root.

  Returns:
    A new tensor shaped like ``x`` in which every row is normalized, scaled by
    gamma and shifted by beta. The per-row mean/variance buffers that the
    kernel fills are not returned.

  Raises:
    AssertionError: if any of the requirements above is violated.
  """
  assert x.dim() == 2, f'expected a 2-D tensor, got shape {tuple(x.shape)}'
  rows, cols = x.shape
  assert gamma.shape[0] == cols, f'Incompatible shape: x has {cols} columns, gamma has {gamma.shape[0]}'
  assert beta.shape[0] == cols, f'Incompatible shape: x has {cols} columns, beta has {beta.shape[0]}'
  assert x.is_contiguous(), 'x is not contiguous tensor'
  # Fail here with a clear message instead of inside Triton's compiler.
  assert cols > 0 and (cols & (cols - 1)) == 0, f'number of columns must be a power of two, got {cols}'
  if rows == 0:
    return torch.empty_like(x)

  # The kernel reads gamma/beta (and writes mean/var) with unit stride.
  gamma = gamma.contiguous()
  beta = beta.contiguous()
  mean = torch.empty((rows,1),device=x.device,dtype=x.dtype)
  var = torch.empty_like(mean)
  x_norm = torch.empty_like(x)

  # Bound the BLOCK_SIZE x cols tile held by one program. Both factors are
  # powers of two, so the quotient is too (at least 1 for very wide rows).
  block_size = max(1, min(128, 4096 // cols))
  grid = lambda meta : (triton.cdiv(rows,meta['BLOCK_SIZE']),)
  # Only x needs explicit strides; mean, var, gamma and beta are 1-D unit-stride buffers.
  forward_layer_norm[grid](x,
                           x_norm,
                           gamma,
                           beta,
                           mean,
                           var,
                           eps,
                           rows,
                           cols,
                           x.stride(0),
                           x.stride(1),
                           BLOCK_SIZE=block_size,)
  return x_norm


In [None]:
# Normalize the reference input with the Triton kernel, reusing the PyTorch
# module's affine parameters and epsilon so both results are comparable.
input_norm = _layer_norm_forward(
  x=input,
  gamma=layer_norm.weight,
  beta=layer_norm.bias,
  eps=layer_norm.eps,
)

In [None]:
# allclose's third positional argument is rtol. With the default atol=1e-8 the
# check is essentially exact for normalized values near zero, where fp32
# summation-order differences dominate, so both tolerances are passed explicitly.
print(torch.allclose(input_normalized, input_norm, rtol=1e-4, atol=1e-5))

True
