<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/Rms_LayerNorm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn as nn
import triton
import triton.language as tl

In [48]:
@triton.jit
def RMSNorm_kernel(input_ptr,out_ptr,weight_ptr,stride_m,stride_n,bs,seq_len,BLOCK_SIZE_SEQ:tl.constexpr,BLOCK_SIZE_BATCH:tl.constexpr):
  """RMS-normalise each row of a 2-D tensor and scale it by a per-column weight.

  Each program handles BLOCK_SIZE_BATCH consecutive rows and the whole row
  width at once, so BLOCK_SIZE_SEQ must be >= seq_len.

  input_ptr / out_ptr : base pointers of the input and output; both are
                        addressed with the SAME strides (stride_m, stride_n).
  weight_ptr          : base pointer of the weight, indexed by column with
                        unit stride.
  bs, seq_len         : number of rows / columns that are actually valid.
  """
  batch_id = tl.program_id(axis=0)
  rows = batch_id * BLOCK_SIZE_BATCH + tl.arange(0,BLOCK_SIZE_BATCH)
  cols = tl.arange(0,BLOCK_SIZE_SEQ)
  # Keep the logical column index separate from the strided memory offset:
  # the weight is indexed by column, not by the input's element stride.
  col_mask = cols < seq_len
  mask = (rows[:,None] < bs) & col_mask[None,:]
  # 64-bit offsets so row * stride cannot wrap around on large tensors.
  offs = rows[:,None].to(tl.int64) * stride_m + cols[None,:].to(tl.int64) * stride_n
  # other=0.0 makes padded lanes contribute nothing to the sum of squares.
  x = tl.load(input_ptr + offs,mask=mask,other=0.0).to(tl.float32)
  w = tl.load(weight_ptr + cols,mask=col_mask,other=0.0).to(tl.float32)
  # Mean of squares over the valid columns only (divide by seq_len, not the block width).
  ms = tl.sum(x*x,axis=1,keep_dims=True)/seq_len
  # Rows past `bs` evaluate to rsqrt(0)*0 = NaN here; the store mask drops them.
  out = tl.rsqrt(ms) * x * w[None,:]
  # Masked store: never write outside the (bs, seq_len) region.
  tl.store(out_ptr + offs,out,mask=mask)


In [49]:
def triton_RMSNorm(input,weight):
  """Apply RMS normalisation over the last dimension of `input` using the Triton kernel.

  Args:
    input:  CUDA tensor of shape (..., seq_len); a 2-D (bs, seq_len) tensor is the common case.
    weight: CUDA tensor of shape (seq_len,), the per-feature scale.

  Returns:
    A new tensor with the same shape and dtype as `input`.

  Raises:
    AssertionError: if a tensor is not on CUDA or the weight shape does not match.
  """
  assert input.is_cuda
  assert weight.is_cuda
  assert input.dim() >= 1, "input must have at least one dimension"
  seq_len = input.shape[-1]
  assert weight.shape == (seq_len,), "weight must have shape (input.shape[-1],)"
  # Nothing to normalise; launching a zero-sized grid is an error.
  if input.numel() == 0:
    return torch.empty_like(input)
  # Flatten leading dims to rows. The kernel addresses the output with the
  # input's strides, so both are made contiguous to guarantee they agree.
  x2d = input.reshape(-1,seq_len).contiguous()
  out2d = torch.empty_like(x2d)
  bs = x2d.shape[0]
  # The kernel indexes the weight with unit stride.
  weight = weight.contiguous()
  BLOCK_SIZE_SEQ = triton.next_power_of_2(seq_len)
  # Bound the tile (rows x cols) held by one program; a square tile grows
  # quadratically with seq_len. All quantities stay powers of two, as
  # tl.arange requires.
  max_tile_elems = 8192
  BLOCK_SIZE_BATCH = max(1,min(BLOCK_SIZE_SEQ,max_tile_elems // BLOCK_SIZE_SEQ))
  grid = (triton.cdiv(bs,BLOCK_SIZE_BATCH),)
  RMSNorm_kernel[grid](x2d,out2d,weight,
                       x2d.stride(0),x2d.stride(1),bs,seq_len,BLOCK_SIZE_SEQ,BLOCK_SIZE_BATCH)
  return out2d.view(input.shape)


In [55]:
# Compare the Triton implementation against PyTorch's reference layer.
# Seed the generator so the random input (and therefore the check) is reproducible.
torch.manual_seed(0)
layer = nn.RMSNorm(128).to('cuda')
layer = layer.to(dtype=torch.float32)
# Inputs drawn uniformly from [0, 5): 512 rows of 128 features.
input = 5 * torch.rand((512,128),device='cuda',dtype=torch.float32)
out = layer(input)
out_triton = triton_RMSNorm(input,layer.weight)

In [56]:
# Fail the notebook run if the two implementations disagree, rather than just printing False.
is_close = torch.allclose(out,out_triton)
assert is_close, "Triton RMSNorm output does not match nn.RMSNorm"
print(is_close)

True
