<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/AvgPool2d.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn
import triton
import triton.language as tl

In [60]:
@triton.jit
def AvgPool2d_kernel(
    image_ptr, out_ptr,
    stride_i0, stride_i1, stride_i2, stride_i3,
    stride_o0, stride_o1, stride_o2, stride_o3,
    H, W, k_h, k_w,BLOCK_SIZE_ROW: tl.constexpr, BLOCK_SIZE_COL: tl.constexpr,num_warps=2,num_stages=2):
  pid_bs = tl.program_id(axis=0)
  pid_ch = tl.program_id(axis=1)
  pid_2 = tl.program_id(axis=2)
  nbr_pid_w = W // k_w
  pid_w = pid_2 % nbr_pid_w
  pid_h = pid_2 // nbr_pid_w
  offs_row = pid_h * k_h + tl.arange(0,BLOCK_SIZE_ROW)
  offs_col = pid_w* k_w + tl.arange(0,BLOCK_SIZE_COL)
  mask  = (offs_row[:,None] < H) & (offs_col[None,:] < W)
  offs_row_ker = tl.arange(0,BLOCK_SIZE_ROW)
  offs_col_ker = tl.arange(0,BLOCK_SIZE_COL)
  mask_ker = (offs_row_ker[:,None] < k_h) & (offs_col_ker[None,:] < k_w)
  input_ptrs = image_ptr + pid_bs * stride_i0 + pid_ch * stride_i1 + offs_row[:,None] * stride_i2 + offs_col[None,:] * stride_i3
  input = tl.load(input_ptrs,mask=mask)
  input = tl.where(mask_ker,input,0.0)
  elem = tl.sum(input) / (k_h * k_w)
  out_ptrs = out_ptr + pid_bs * stride_o0 + pid_ch * stride_o1 + pid_h * stride_o2 + pid_w
  tl.store(out_ptrs,elem)

In [61]:
def avgpool2d(image:torch.tensor,kernel_shape:tuple):
  assert image.is_cuda and image.is_contiguous()
  assert image.ndim == 4
  bs,c,h,w = image.shape
  k_h, k_w = kernel_shape[0],kernel_shape[1]
  assert h % k_h == 0 and w % k_w == 0
  out = torch.empty((bs,c,h//k_h,w//k_w),device=image.device,dtype=image.dtype)
  BLOCK_SIZE_ROW = triton.next_power_of_2(k_h)
  BLOCK_SIZE_COL = triton.next_power_of_2(k_w)
  grid = (bs,c,triton.cdiv(h,k_h)*triton.cdiv(w,k_w))
  AvgPool2d_kernel[grid](image,out,image.stride(0),image.stride(1),image.stride(2),
                         image.stride(3),out.stride(0),out.stride(1),out.stride(2),
                         out.stride(3),h,w,k_h,k_w,BLOCK_SIZE_ROW,BLOCK_SIZE_COL)
  return out

In [62]:
avgpool = nn.AvgPool2d(kernel_size=(16,16),stride=(16,16)).to('cuda')
image = torch.randn((12,8,64,64),device='cuda')
out = avgpool(image)
out_triton = avgpool2d(image,avgpool.kernel_size)

In [63]:
print(torch.allclose(out,out_triton,1e-3))

True


In [64]:
def benchmark(fn,*args,warmup=8,steps=128):
  start = torch.cuda.Event(enable_timing=True)
  end = torch.cuda.Event(enable_timing=True)
  for _ in range(warmup):
    fn(*args)
  torch.cuda.synchronize()
  start.record()
  for _ in range(steps):
    fn(*args)
  end.record()
  torch.cuda.synchronize()
  return start.elapsed_time(end) / steps

In [74]:
shapes = [
    (2, 8, 64, 64),
    (2, 16, 32, 32),
    (4, 8, 64, 64),
    (4, 16, 32, 32),]
images = [torch.randn(shape, device='cuda', dtype=torch.float32) for shape in shapes]
triton_time = [benchmark(avgpool2d,image,avgpool.kernel_size) for image in images]
torch_time =  [benchmark(lambda a,b:avgpool(a),image,avgpool.kernel_size) for image in image]

In [75]:
from prettytable import PrettyTable

# Create a PrettyTable
table = PrettyTable()
table.field_names = ["Shape", "Triton Time (ms)", "PyTorch Time (ms)"]

# Add rows to the table
for shape, triton_t, torch_t in zip(shapes, triton_time, torch_time):
    table.add_row([shape, f"{triton_t:.4f}", f"{torch_t:.4f}"])

# Print the table
print(table)

+-----------------+------------------+-------------------+
|      Shape      | Triton Time (ms) | PyTorch Time (ms) |
+-----------------+------------------+-------------------+
|  (2, 8, 64, 64) |      0.0362      |       0.0360      |
| (2, 16, 32, 32) |      0.0343      |       0.0360      |
|  (4, 8, 64, 64) |      0.0332      |       0.0362      |
| (4, 16, 32, 32) |      0.0346      |       0.0363      |
+-----------------+------------------+-------------------+
