<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/BatchNorm2d.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [28]:
import torch
import triton
import triton.language as tl

In [10]:
@triton.jit
def batchnorm2d_kernel(img_ptr,stride_i0,stride_i1,stride_i2,stride_i3,gemma_ptr,beta_ptr,eps,
                       bs,c,h,w,BLOCK_SIZE_ROW:tl.constexpr,BLOCK_SIZE_COL:tl.constexpr,
                       num_warps=8):
  pid_c = tl.program_id(axis=0)
  offs_row = tl.arange(0,BLOCK_SIZE_ROW)
  offs_col = tl.arange(0,BLOCK_SIZE_COL)
  mask = (offs_row[:,None] < h) & (offs_col[None,:] < w)
  mu = 0.0
  var = 0.0
  gemma = tl.load(gemma_ptr+pid_c)
  beta = tl.load(beta_ptr+pid_c)
  for step in range(bs):
    img_ptrs = img_ptr + step * stride_i0 + pid_c * stride_i1 + offs_row[:,None] * stride_i2 + offs_col[None,:] * stride_i3
    img_chunk = tl.load(img_ptrs,mask=mask)
    mu += tl.sum(img_chunk)/( h * w )
  mu = mu / bs
  for step in range(bs):
    img_ptrs = img_ptr + step * stride_i0 + pid_c * stride_i1 + offs_row[:,None] * stride_i2 + offs_col[None,:] * stride_i3
    img_chunk = tl.load(img_ptrs,mask=mask)
    var += tl.sum((img_chunk-mu)*(img_chunk-mu))/( h * w )
  var = var / bs
  for step in range(bs):
    img_ptrs = img_ptr + step * stride_i0 + pid_c * stride_i1 + offs_row[:,None] * stride_i2 + offs_col[None,:] * stride_i3
    img_chunk = tl.load(img_ptrs,mask=mask)
    img_chunk = ((img_chunk-mu)* tl.rsqrt(var+eps)) * gemma + beta
    tl.store(img_ptrs,img_chunk)

In [14]:
def batchnorm2d(image:torch.tensor,gemma,beta,eps):
  assert image.is_cuda and image.is_contiguous()
  assert image.ndim == 4
  bs,c,h,w = image.shape
  BLOCK_SIZE_ROW = triton.next_power_of_2(h)
  BLOCK_SIZE_COL = triton.next_power_of_2(w)
  grid = (c,)
  batchnorm2d_kernel[grid](image,image.stride(0),image.stride(1),
                           image.stride(2),image.stride(3),gemma,
                           beta,eps,bs,c,h,w,BLOCK_SIZE_ROW,BLOCK_SIZE_COL)
  return image

In [41]:
image = torch.randn((2,16,32,32),device='cuda')
layer = torch.nn.BatchNorm2d(16).to('cuda')
out = layer(image)
out_triton = batchnorm2d(image,layer.weight,layer.bias,layer.eps)
torch.allclose(out,out_triton)

True

In [43]:
from prettytable import PrettyTable
def benchmark(fn, *args, warmup=8, steps=128):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start.record()
    for _ in range(steps):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / steps

# Define shapes for benchmarking
shapes = [
    (4, 8, 64, 64),
    (4, 16, 32, 32),
    (8, 8, 64, 64),
    (16, 16, 32, 32),
]

# Create random inputs
images = [torch.randn(shape, device='cuda', dtype=torch.float32) for shape in shapes]
gemmas = [torch.ones(shape[1], device='cuda', dtype=torch.float32) for shape in shapes]
betas = [torch.zeros(shape[1], device='cuda', dtype=torch.float32) for shape in shapes]
eps = 1e-5
layers = [torch.nn.BatchNorm2d(shape[1],device='cuda',dtype=torch.float32) for shape in shapes]

triton_time = [benchmark(batchnorm2d, image, gemma, beta, eps) for image, gemma, beta in zip(images, gemmas, betas)]

layers = [torch.nn.BatchNorm2d(shape[1], device='cuda', dtype=torch.float32) for shape in shapes]

for layer in layers:
    layer.eval()

torch_time = [benchmark(lambda x, g, b, layer=layer: layer(x), image, gemma, beta)
              for layer, image, gemma, beta in zip(layers, images, gemmas, betas)]

table = PrettyTable()
table.field_names = ["Shape", "Triton Time (ms)", "PyTorch Time (ms)"]

for shape, triton_t, torch_t in zip(shapes, triton_time, torch_time):
    table.add_row([shape, f"{triton_t:.4f}", f"{torch_t:.4f}"])

print(table)

+------------------+------------------+-------------------+
|      Shape       | Triton Time (ms) | PyTorch Time (ms) |
+------------------+------------------+-------------------+
|  (4, 8, 64, 64)  |      0.0382      |       0.0446      |
| (4, 16, 32, 32)  |      0.0355      |       0.0505      |
|  (8, 8, 64, 64)  |      0.0533      |       0.0446      |
| (16, 16, 32, 32) |      0.0486      |       0.0445      |
+------------------+------------------+-------------------+
