<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/Sum.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import triton
import triton.language as tl

In [34]:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256}, num_warps=4, num_stages=4),
    ],
    key=['m', 'n'],
    # The kernel accumulates into out_ptr with atomics, so each autotuning
    # benchmark run has to start from a zeroed buffer; otherwise the partial
    # sums of every timed launch pile up in the output.
    reset_to_zero=['out_ptr'],
)
@triton.jit
def sum_kernel(a_ptr,out_ptr,stride_am,stride_an,m,n,BLOCK_SIZE_M:tl.constexpr,BLOCK_SIZE_N:tl.constexpr):
  """Accumulate per-row sums of one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of a 2-D input.

  Launched on a 2-D grid: axis 0 selects the row tile, axis 1 the column tile.
  Each program reduces its tile along the column axis in float32 and atomically
  adds the partial row sums to ``out_ptr``.

  Args:
    a_ptr: pointer to the input matrix.
    out_ptr: pointer to the output; element ``i`` receives the sum of row ``i``.
      It is indexed by the row offset only (unit stride), and because results
      are accumulated with atomics it must hold zeros before the launch.
    stride_am, stride_an: element strides of the input along rows / columns.
    m, n: number of rows / columns of the input.
    BLOCK_SIZE_M, BLOCK_SIZE_N: compile-time tile sizes chosen by the autotuner.
  """
  pid_m = tl.program_id(axis=0)
  pid_n = tl.program_id(axis=1)
  offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0,BLOCK_SIZE_M)
  offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0,BLOCK_SIZE_N)
  # Lanes past the matrix edge are disabled for tiles that overhang it.
  mask = (offs_m[:,None] < m) & (offs_n[None,:] < n)
  a_ptrs = a_ptr + offs_m[:,None] * stride_am + offs_n[None,:] * stride_an
  # other=0.0 gives masked-out lanes a defined, neutral value; without it they
  # are undefined and would corrupt the row sums of edge tiles.
  a = tl.load(a_ptrs,mask=mask,other=0.0)
  # Reduce in float32 regardless of the input element type.
  a = a.to(tl.float32)
  out = tl.sum(a,axis=1,keep_dims=True)
  mask_out = offs_m[:,None] < m
  # Several column tiles contribute to the same row, hence the atomic add.
  tl.atomic_add(out_ptr+offs_m[:,None],out,mask=mask_out)

In [35]:
def sum(a:torch.Tensor):
  """Return the row-wise sum of a 2-D CUDA tensor, computed with ``sum_kernel``.

  NOTE: this name shadows the builtin ``sum`` in the defining namespace.

  Args:
    a: contiguous CUDA tensor of shape ``(m, n)``.

  Returns:
    Tensor of shape ``(m, 1)`` on ``a.device`` with dtype ``a.dtype``; entry
    ``i`` is the sum of row ``i``.

  Raises:
    AssertionError: if ``a`` is not on a CUDA device, is not contiguous, or has
      fewer than two dimensions.
  """
  assert a.is_cuda and a.is_contiguous(), "expected a contiguous CUDA tensor"
  assert a.ndim > 1, "expected a tensor with at least two dimensions"
  m,n = a.shape
  # The kernel only *adds* partial sums into the output, so the buffer must
  # start at zero (torch.empty would leave arbitrary values in it). It is kept
  # in float32 because the kernel produces float32 partial sums.
  acc = torch.zeros((m,1),device=a.device,dtype=torch.float32)
  # An empty input has nothing to reduce; skip the launch instead of
  # requesting a zero-sized grid.
  if m > 0 and n > 0:
    grid = lambda meta:(triton.cdiv(m,meta['BLOCK_SIZE_M']),triton.cdiv(n,meta['BLOCK_SIZE_N']))
    sum_kernel[grid](a,acc,a.stride(0),a.stride(1),m,n)
  # Hand the result back in the caller's dtype (no copy when already float32).
  return acc.to(a.dtype)

In [36]:
# Random 1024 x 2048 test matrix created directly on the GPU.
a = torch.randn(1024, 2048, device='cuda')

In [38]:
# Row sums from the Triton implementation ...
triton_sum = sum(a)
# ... and the PyTorch reference, reduced over the column axis with the same (m, 1) shape.
torch_sum = torch.sum(a, dim=1, keepdim=True)

In [41]:
# Compare both results using a relative tolerance of 1e-3.
print(torch.allclose(triton_sum, torch_sum, rtol=1e-3))

True
