<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/MaxPool1d.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import triton
import triton.language as tl
import torch

In [3]:
@triton.jit
def maxpool1d_kernel(input_ptr,out_ptr,channels,dim,stride_in_b,stride_in_m,stride_in_n,
                     stride_b,stride_m,stride_n,BLOCK_SIZE_K:tl.constexpr,BLOCK_SIZE_C:tl.constexpr):
  """Reduce one pooling window to its maximum for a block of channels.

  Launch grid is (batch, channel blocks, windows). Each program loads a
  [BLOCK_SIZE_C, BLOCK_SIZE_K] tile of the input, takes the maximum along the
  window axis and writes one value per channel to out[batch, channel, window].

  Args:
    input_ptr: pointer to the input, addressed as (batch, channel, position).
    out_ptr: pointer to the output, addressed as (batch, channel, window).
    channels: number of channels; rows at or beyond it are masked off.
    dim: input length along the pooled axis; columns at or beyond it are masked off.
    stride_in_b, stride_in_m, stride_in_n: input strides for batch, channel
      and position.
    stride_b, stride_m, stride_n: output strides for batch, channel and window.
    BLOCK_SIZE_K: window width handled by one program (tl.arange extent).
    BLOCK_SIZE_C: channels handled by one program (tl.arange extent).

  NOTE(review): masked lanes are padded with -inf, which presumes a
  floating-point input -- confirm before using with integer tensors.
  """
  batch_id = tl.program_id(axis=0)
  channel_id = tl.program_id(axis=1)
  col_id = tl.program_id(axis=2)
  channel_offs = channel_id * BLOCK_SIZE_C + tl.arange(0,BLOCK_SIZE_C)
  col_offs = col_id * BLOCK_SIZE_K + tl.arange(0,BLOCK_SIZE_K)
  channel_mask = channel_offs[:,None] < channels
  col_mask = col_offs[None,:] < dim
  input_ptrs = (input_ptr + batch_id * stride_in_b
                + channel_offs[:,None] * stride_in_m
                + col_offs[None,:] * stride_in_n)
  # Pad masked lanes with -inf so out-of-range elements can never win the max
  # (without `other` their contents are undefined).
  window = tl.load(input_ptrs,mask=(channel_mask & col_mask),other=-float('inf'))
  maximum = tl.max(window,axis=1)
  # Honour the output's last-dimension stride instead of assuming it is 1.
  out_ptrs = (out_ptr + batch_id * stride_b
              + channel_offs[:,None] * stride_m
              + col_id * stride_n)
  # Mask the store so a partial channel block never writes past `channels`.
  tl.store(out_ptrs,maximum[:,None],mask=channel_mask)


In [4]:
def triton_maxpool1d(input,kernel_shape):
  """Max-pool a 3-D CUDA tensor along its last axis with non-overlapping windows.

  The window width and the stride are both `kernel_shape`.

  Args:
    input: CUDA tensor of shape (batch, channels, length); `length` must be a
      multiple of `kernel_shape`.
    kernel_shape: window width. It is passed to the kernel as a `tl.arange`
      extent, which Triton requires to be a power of two.

  Returns:
    Tensor of shape (batch, channels, length // kernel_shape) on the same
    device and with the same dtype as `input`.

  Raises:
    AssertionError: if any of the preconditions above is violated.
  """
  assert input.is_cuda
  assert input.ndim == 3
  assert input.shape[-1] % kernel_shape == 0
  # Fail early with a readable message rather than deep inside Triton compilation.
  assert kernel_shape > 0 and kernel_shape & (kernel_shape - 1) == 0, \
      "kernel_shape must be a positive power of two"
  bs, channels, dim = input.shape
  # Keep the input dtype; the default dtype would silently convert half/double inputs.
  out = torch.empty((bs,channels,dim//kernel_shape),dtype=input.dtype,device=input.device)
  if out.numel() == 0:
    # Nothing to compute, and a zero-sized grid must not be launched.
    return out
  BLOCK_SIZE_K = kernel_shape
  # The channel tile must be a power of two (tl.arange) and must divide
  # `channels` exactly, otherwise the floor-divided grid below would leave the
  # trailing channels of `out` uninitialised. `channels & -channels` is the
  # largest power of two dividing `channels`; any smaller power of two divides
  # it as well, so the preferred tile size is kept whenever it already fits.
  BLOCK_SIZE_C = min(triton.next_power_of_2(kernel_shape), channels & -channels)
  grid = (bs,channels//BLOCK_SIZE_C,dim//BLOCK_SIZE_K)
  maxpool1d_kernel[grid](input,out,channels,dim,input.stride(0),input.stride(1),input.stride(2)
  ,out.stride(0),out.stride(1),out.stride(2),BLOCK_SIZE_K,BLOCK_SIZE_C)
  return out

In [5]:
# Sanity check: the Triton result should match PyTorch's MaxPool1d
# (window 4, stride 4) on a random (2, 8, 12) CUDA tensor.
a = torch.rand((2, 8, 12)).to('cuda')
b = torch.nn.MaxPool1d(kernel_size=4, stride=4).to('cuda')
out, out_triton = b(a), triton_maxpool1d(a, 4)
is_close = torch.allclose(out, out_triton)
print(is_close)

True
