<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/MaxPool2d.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import triton
import triton.language as tl

In [8]:
@triton.jit
def maxpool2d_kernel(input_ptr,out_ptr,stride_in_b,stride_in_c,stride_in_m,stride_in_n,stride_b,stride_c,stride_m,stride_n,h,w,kernel_size,BLOCK_SIZE_K:tl.constexpr,BLOCK_SIZE_R:tl.constexpr):
  """Non-overlapping 2D max pooling (stride == kernel_size).

  Program ids on grid axes 0, 1 and 2 are used as the batch index, the
  channel index and the output-row index. Each program walks over
  ``w // kernel_size`` column windows, loads one window, reduces it with
  ``tl.max`` and stores a single value per window.

  Args:
    input_ptr: pointer to the input tensor.
    out_ptr: pointer to the output tensor.
    stride_in_b, stride_in_c, stride_in_m, stride_in_n: element strides used
      to address the input along batch, channel, row and column.
    stride_b, stride_c, stride_m, stride_n: the same for the output.
    h, w: bounds used to mask input row / column indices.
    kernel_size: side length of the pooling window (also the window step).
    BLOCK_SIZE_K: compile-time number of column lanes; must be >= kernel_size.
    BLOCK_SIZE_R: compile-time number of row lanes; must be >= kernel_size.

  NOTE(review): padding masked lanes with -inf assumes a floating-point
  input dtype -- confirm before using with integer tensors.
  """
  batch_id = tl.program_id(axis=0)
  channel_id = tl.program_id(axis=1)
  row_id = tl.program_id(axis=2)
  # Window positions are derived from kernel_size, not from the block size, so
  # the block may be larger than the window; extra lanes are masked off.
  win_rows = tl.arange(0,BLOCK_SIZE_R)
  rows = row_id * kernel_size + win_rows
  row_mask = (win_rows[:,None] < kernel_size) & (rows[:,None] < h)
  row_offs = rows[:,None] * stride_in_m
  in_base = input_ptr + batch_id * stride_in_b + channel_id * stride_in_c
  out_base = out_ptr + batch_id * stride_b + channel_id * stride_c + row_id * stride_m
  win_cols = tl.arange(0,BLOCK_SIZE_K)
  for step in range(0,w//kernel_size):
    cols = step * kernel_size + win_cols
    col_mask = (win_cols[None,:] < kernel_size) & (cols[None,:] < w)
    col_offs = cols[None,:] * stride_in_n
    # Masked-off lanes must not win the reduction, so fill them with -inf
    # instead of leaving them undefined.
    window = tl.load(in_base + row_offs + col_offs,mask=(row_mask&col_mask),other=-float('inf'))
    window_max = tl.max(window)
    # Honor the output's column stride instead of assuming it is 1.
    tl.store(out_base + step * stride_n,window_max)

In [9]:
def triton_maxpool2d(input,kernel_size):
  """Max-pool a 4-D CUDA tensor with a square, non-overlapping window.

  Launches ``maxpool2d_kernel`` on a ``(bs, c, h // kernel_size)`` grid with
  both block sizes set to ``kernel_size``.

  Args:
    input: 4-D CUDA tensor, unpacked as ``(bs, c, h, w)``; ``h`` and ``w``
      must be divisible by ``kernel_size``.
    kernel_size: window side length (also used as the stride). Must be a
      positive power of two because it is passed as the ``tl.arange`` extent.

  Returns:
    Tensor of shape ``(bs, c, h // kernel_size, w // kernel_size)`` on the
    same device and with the same dtype as ``input``.

  Raises:
    AssertionError: if any of the preconditions above is violated.
  """
  assert input.is_cuda, "input must be a CUDA tensor"
  assert input.ndim == 4, "input must be 4-D (bs, c, h, w)"
  # Checked before the modulo below so kernel_size == 0 fails with a clear
  # message instead of ZeroDivisionError.
  assert kernel_size > 0 and (kernel_size & (kernel_size - 1)) == 0, "kernel_size must be a positive power of two"
  assert input.shape[-2] % kernel_size == 0 and input.shape[-1] % kernel_size == 0, "h and w must be divisible by kernel_size"
  bs,c,h,w = input.shape
  out_h = h//kernel_size
  out_w = w//kernel_size
  # Keep the input dtype; torch.empty would otherwise fall back to the
  # default dtype and disagree with torch.nn.MaxPool2d for non-float32 input.
  out = torch.empty((bs,c,out_h,out_w),device=input.device,dtype=input.dtype)
  grid = (bs,c,out_h)
  BLOCK_SIZE_K = kernel_size
  BLOCK_SIZE_R = kernel_size
  maxpool2d_kernel[grid](input,out,input.stride(0),input.stride(1),input.stride(2),input.stride(3),out.stride(0),out.stride(1),
                         out.stride(2),out.stride(3),h,w,kernel_size,BLOCK_SIZE_K,BLOCK_SIZE_R)
  return out

In [11]:
# Sanity check: compare the Triton implementation with torch.nn.MaxPool2d on a
# random (2, 4, 32, 32) CUDA tensor using a 4x4 window and stride 4.
input = torch.rand((2,4,32,32),device='cuda')
# Named `pool` rather than `max` so the builtin max() stays usable afterwards.
pool = torch.nn.MaxPool2d((4,4),4).to('cuda')
out = pool(input)
out_triton = triton_maxpool2d(input,4)
print(torch.allclose(out,out_triton))

True
