<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/BatchNorm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [123]:
import triton
import torch
import triton.language as tl


In [154]:
@triton.jit
def batchnorm_kernel(input_ptr,out_ptr,stride_m,stride_n,bs,seq_len,gemma,beta,eps,BLOCK_SEQ:tl.constexpr,BLOCK_BATCH:tl.constexpr):
  """Normalise each column of a 2-D tensor over its rows (batch axis).

  Each program handles BLOCK_SEQ consecutive columns and loads every row of
  those columns in a single (BLOCK_BATCH, BLOCK_SEQ) tile, so BLOCK_BATCH must
  be >= bs. Per column it computes the mean and the biased variance (divided
  by bs) and writes

      out = (x - mean) * rsqrt(var + eps) * gemma + beta

  Arguments:
    input_ptr, out_ptr: pointers to the input / output; both are addressed
      with the same strides (stride_m per row, stride_n per column).
    bs, seq_len: number of rows and columns.
    gemma, beta: pointers to the per-column scale ("gamma") and shift
      vectors; they are read at element offset `column index`, i.e. they are
      assumed to be contiguous 1-D vectors of length seq_len.
    eps: value added to the variance before the reciprocal square root.
    BLOCK_SEQ, BLOCK_BATCH: compile-time tile sizes.
  """
  seq_id = tl.program_id(axis=0)
  seq_offs = seq_id * BLOCK_SEQ + tl.arange(0,BLOCK_SEQ)
  batch_offs = tl.arange(0,BLOCK_BATCH)
  seq_mask = seq_offs < seq_len
  mask = (batch_offs[:,None] < bs) & seq_mask[None,:]

  # Input and output share one addressing scheme, so compute the offsets once.
  tile_offs = batch_offs[:,None] * stride_m + seq_offs[None,:] * stride_n

  # `other=0.0` gives padded lanes a defined value, so they add nothing to the
  # column sums (a masked load without `other` leaves them undefined).
  x = tl.load(input_ptr + tile_offs,mask=mask,other=0.0).to(tl.float32)
  mean = tl.sum(x,axis=0,keep_dims=True) / bs

  # Padded rows would otherwise contribute (0 - mean)^2 to the variance
  # whenever BLOCK_BATCH > bs, so zero the centred values outside the mask.
  centered = tl.where(mask,x - mean,0.0)
  var = tl.sum(centered * centered,axis=0,keep_dims=True) / bs

  # The affine parameters are 1-D vectors with their own layout; index them by
  # column rather than by the input's column stride.
  scale = tl.load(gemma + seq_offs,mask=seq_mask,other=1.0).to(tl.float32)
  shift = tl.load(beta + seq_offs,mask=seq_mask,other=0.0).to(tl.float32)

  eps_f32 = eps.to(tl.float32)
  out = (tl.rsqrt(var + eps_f32) * centered * scale) + shift

  # Masked store: never write past the real rows / columns of the output.
  tl.store(out_ptr + tile_offs,out,mask=mask)

In [155]:
def batchnorm_1d(input,gemma,beta,eps):
  """Apply batch normalisation over dim 0 of a 2-D CUDA tensor using Triton.

  Args:
    input: CUDA tensor of shape (batch, features).
    gemma: CUDA tensor with `features` elements, the per-feature scale
      (gamma).
    beta: CUDA tensor with `features` elements, the per-feature shift.
    eps: value added to the variance for numerical stability.

  Returns:
    A new contiguous tensor with the same shape and dtype as `input`.

  Note:
    The kernel loads the whole batch for a group of features in a single
    tile, so very large batch sizes may exceed what one program can hold.
  """
  assert input.is_cuda, "input must be a CUDA tensor"
  assert gemma.is_cuda and beta.is_cuda, "gemma and beta must be CUDA tensors"
  assert input.ndim == 2, "input must be 2-D (batch, features)"
  bs,seq_len = input.shape
  assert gemma.numel() == seq_len and beta.numel() == seq_len, \
      "gemma and beta must have one element per feature"

  # The kernel addresses input and output with one shared pair of strides and
  # reads the parameter vectors element by element, so give it dense layouts.
  input = input.contiguous()
  gemma = gemma.contiguous()
  beta = beta.contiguous()
  out = torch.empty_like(input)

  # Nothing to normalise; a zero-sized grid / tile cannot be launched.
  if input.numel() == 0:
    return out

  # Upper bound on elements per program, to keep the tile at a size a single
  # program can reasonably hold in registers.
  max_tile_elems = 8192
  BLOCK_BATCH = triton.next_power_of_2(bs)
  # Size the feature tile from the feature count and the remaining tile
  # budget rather than from the batch size.
  BLOCK_SEQ = min(triton.next_power_of_2(seq_len),
                  max(1,max_tile_elems // BLOCK_BATCH))
  grid = (seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ
  batchnorm_kernel[(grid,)](input,out,input.stride(0),input.stride(1),bs,seq_len,gemma,beta,eps,
                         BLOCK_SEQ,BLOCK_BATCH)
  return out

In [172]:
# Compare the Triton implementation against torch.nn.BatchNorm1d on a
# (256, 1024) random batch, reusing the reference layer's weight, bias and eps.
layer = torch.nn.BatchNorm1d(num_features=1024).to('cuda')
input = torch.randn(256, 1024, device='cuda')
out_triton = batchnorm_1d(
    input,
    layer.weight,
    layer.bias,
    layer.eps,
)
# Reference result from the PyTorch layer itself.
out = layer(input)


In [175]:
# The third positional argument of torch.allclose is the relative tolerance.
torch.allclose(out, out_triton, rtol=1e-3)

True