<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/GroupNorm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import triton
import triton.language as tl
from torch import nn

In [30]:
@triton.jit
def groupnorm_kernel(img_ptr,gamma_ptr,bias_ptr,eps, stride_i0, stride_i1, stride_i2,
                      stride_i3, stride_i4, bs, n_grs, c, h, w, BLOCK_SIZE_ROW:tl.constexpr,
                      BLOCK_SIZE_COL:tl.constexpr,num_warps=4):
  pid_b = tl.program_id(axis=0)
  pid_gr = tl.program_id(axis=1)
  offs_row = tl.arange(0,BLOCK_SIZE_ROW)
  offs_col = tl.arange(0,BLOCK_SIZE_COL)
  mask = (offs_row[:,None]<h) & (offs_col[None,:] < w)
  mu = 0.0
  var = 0.0
  for step in range(c//n_grs):
    img_ptrs = img_ptr + pid_b * stride_i0 + pid_gr * stride_i1 + step * stride_i2 + offs_row[:,None] * stride_i3 + offs_col[None,:] * stride_i4
    img_chunk = tl.load(img_ptrs,mask=mask)
    mu += tl.sum(img_chunk) / (h * w)
  mu = mu / ( c // n_grs)
  for step in range(c//n_grs):
    img_ptrs = img_ptr + pid_b * stride_i0 + pid_gr * stride_i1 + step * stride_i2 + offs_row[:,None] * stride_i3 + offs_col[None,:] * stride_i4
    img_chunk = tl.load(img_ptrs,mask=mask)
    var += tl.sum((img_chunk -mu) * (img_chunk -mu)) / (h * w)
  var = var / ( c // n_grs)
  for step in range(c//n_grs):
    img_ptrs = img_ptr + pid_b * stride_i0 + pid_gr * stride_i1 + step * stride_i2 + offs_row[:,None] * stride_i3 + offs_col[None,:] * stride_i4
    img_chunk = tl.load(img_ptrs,mask=mask)
    gamma = tl.load(gamma_ptr + pid_gr * (c//n_grs) + step)
    bias = tl.load(bias_ptr + pid_gr * (c//n_grs) + step)
    img_chunk = ( (img_chunk - mu) * tl.rsqrt(var + eps) )* gamma + bias
    tl.store(img_ptrs,img_chunk,mask=mask)


In [31]:
def groupnorm(image:torch.tensor,n_grs:int,gamma:float,bias:float,eps:float):
  assert image.is_cuda and image.is_contiguous()
  assert image.ndim == 4
  bs,c,h,w = image.shape
  assert c % n_grs == 0
  image = image.view(bs,n_grs,c//n_grs,h,w)
  BLOCK_SIZE_ROW = triton.next_power_of_2(h)
  BLOCK_SIZE_COL = triton.next_power_of_2(w)
  grid = (bs,n_grs)
  groupnorm_kernel[grid](image,gamma,bias,eps,image.stride(0),image.stride(1),
                         image.stride(2),image.stride(3),image.stride(4),bs,n_grs
                         ,c,h,w,BLOCK_SIZE_ROW,BLOCK_SIZE_COL)
  return image.view(bs,c,h,w)

In [34]:
normlayer = nn.GroupNorm(3,12).to('cuda')
image = torch.rand((4,12,16,16),device='cuda',dtype=torch.float32)
img_norm = normlayer(image)
img_norm_triton = groupnorm(image,normlayer.num_groups,normlayer.weight,normlayer.bias,normlayer.eps)

In [37]:
from prettytable import PrettyTable
def benchmark(fn, *args, warmup=8, steps=128):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start.record()
    for _ in range(steps):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / steps

# Define shapes for benchmarking
shapes = [
    (4, 8, 64, 64),
    (4, 16, 32, 32),
    (8, 32, 64, 64),
    (16, 64, 32, 32),
]

# Create random inputs
images = [torch.randn(shape, device='cuda', dtype=torch.float32) for shape in shapes]
gemmas = [torch.ones(shape[1], device='cuda', dtype=torch.float32) for shape in shapes]
betas = [torch.zeros(shape[1], device='cuda', dtype=torch.float32) for shape in shapes]
eps = 1e-5
layers = [torch.nn.GroupNorm(4,shape[1],device='cuda',dtype=torch.float32) for shape in shapes]

triton_time = [benchmark(groupnorm, image,4, gemma, beta, eps) for image, gemma, beta in zip(images, gemmas, betas)]

layers = [torch.nn.BatchNorm2d(shape[1], device='cuda', dtype=torch.float32) for shape in shapes]

for layer in layers:
    layer.eval()

torch_time = [benchmark(lambda x, g, b, layer=layer: layer(x), image, gemma, beta)
              for layer, image, gemma, beta in zip(layers, images, gemmas, betas)]

table = PrettyTable()
table.field_names = ["Shape", "Triton Time (ms)", "PyTorch Time (ms)"]

for shape, triton_t, torch_t in zip(shapes, triton_time, torch_time):
    table.add_row([shape, f"{triton_t:.4f}", f"{torch_t:.4f}"])

print(table)

+------------------+------------------+-------------------+
|      Shape       | Triton Time (ms) | PyTorch Time (ms) |
+------------------+------------------+-------------------+
|  (4, 8, 64, 64)  |      0.0373      |       0.0529      |
| (4, 16, 32, 32)  |      0.0528      |       0.0522      |
| (8, 32, 64, 64)  |      0.0599      |       0.0559      |
| (16, 64, 32, 32) |      0.0625      |       0.0520      |
+------------------+------------------+-------------------+
