<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/l2_norm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import triton
import triton.language as tl


In [14]:
@triton.jit
def l2_norm_kernel(input_ptr,out_ptr,stride_m,stride_n,m,n,BLOCK_SIZE_M:tl.constexpr,BLOCK_SIZE_N:tl.constexpr):
  """Row-wise L2 normalisation of an (m, n) matrix: out[i, :] = x[i, :] / ||x[i, :]||_2.

  Each program instance handles BLOCK_SIZE_M consecutive rows. A whole row is
  processed as a single tile, so BLOCK_SIZE_N must be >= n (the launcher passes
  next_power_of_2(n)). Input and output are addressed with the same
  (stride_m, stride_n), so both buffers must share a layout. The arithmetic is
  done in float32; tl.store converts back to the output element type.
  """
  pid = tl.program_id(axis=0)
  rows = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  cols = tl.arange(0, BLOCK_SIZE_N)
  # Tile lanes past the last row (rows >= m) or past the last column
  # (cols >= n, power-of-two padding) are outside the tensor.
  mask = (rows[:, None] < m) & (cols[None, :] < n)
  offsets = rows[:, None] * stride_m + cols[None, :] * stride_n
  # other=0.0 keeps padded lanes out of the sum of squares; without it the
  # masked lanes hold undefined values that would corrupt the row norm.
  x = tl.load(input_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
  sum_sq = tl.sum(x * x, axis=1, keep_dims=True)
  out = x * tl.rsqrt(sum_sq)
  # The store must be masked too, otherwise padded lanes write past the end
  # of the output buffer.
  tl.store(out_ptr + offsets, out, mask=mask)

In [15]:
def l2_norm(input):
  """Normalise every row of a 2-D CUDA tensor to unit L2 norm with a Triton kernel.

  Args:
    input: 2-D CUDA tensor of shape (m, n).

  Returns:
    A new tensor with the same shape and dtype as ``input`` in which row i is
    input[i] / ||input[i]||_2. The computation is done in float32 inside the kernel.

  Raises:
    AssertionError: if ``input`` is not on CUDA or is not 2-D.
  """
  assert input.is_cuda
  assert input.ndim == 2
  out = torch.empty_like(input)
  # The kernel indexes input and output with the SAME strides. empty_like only
  # copies strides for dense, non-overlapping inputs; for other views
  # (e.g. x[:, ::2]) switch to a contiguous copy so the layouts match.
  if out.stride() != input.stride():
    input = input.contiguous()
    out = torch.empty_like(input)
  m, n = input.shape
  if input.numel() == 0:
    # Nothing to normalise; also avoids a zero-sized tile / empty launch grid.
    return out
  BLOCK_SIZE_N = triton.next_power_of_2(n)  # one tile spans an entire row
  BLOCK_SIZE_M = 128
  grid = (triton.cdiv(m, BLOCK_SIZE_M),)
  l2_norm_kernel[grid](input, out, input.stride(0), input.stride(1), m, n, BLOCK_SIZE_M, BLOCK_SIZE_N)
  return out

In [17]:
import torch

torch.manual_seed(0)  # reproducible random input under Restart & Run All

x = torch.rand((256, 1024), device='cuda')
# Reference result: divide each row by its Euclidean norm. torch.linalg.vector_norm
# replaces torch.norm, which PyTorch documents as deprecated.
l2norm = torch.linalg.vector_norm(x, ord=2, dim=1, keepdim=True)

x_normalized = x / l2norm
triton_normalized = l2_norm(x)


In [18]:
# Fail the run (rather than silently printing False) if the Triton kernel
# disagrees with the PyTorch reference.
results_match = torch.allclose(x_normalized, triton_normalized)
assert results_match, "Triton l2_norm disagrees with the PyTorch reference"
print(results_match)

True
