<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/element_wise_addition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
import torch
import triton
import triton.language as tl

In [16]:
@triton.jit
def matrix_add(a_ptr, b_ptr, c_ptr, M, N, stride_am, stride_an,
               BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    """Compute ``c = a + b`` for one BLOCK_SIZE_M x BLOCK_SIZE_N tile of an M x N matrix.

    Meant for a 2-D launch grid: program id on axis 0 selects the row tile,
    program id on axis 1 selects the column tile. Lanes that fall outside
    ``[0, M) x [0, N)`` are masked off for both the loads and the store.

    Args:
        a_ptr, b_ptr: pointers to the two input matrices.
        c_ptr: pointer to the output matrix.
        M, N: number of rows / columns.
        stride_am, stride_an: row / column strides (in elements) of ``a``.
            The SAME strides are used to address ``b`` and ``c``, so all three
            buffers must share one memory layout.
        BLOCK_SIZE_M, BLOCK_SIZE_N: tile size (compile-time constants).
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    # tl.arange yields int32, and small integer strides are typically passed as
    # int32 as well; widening to int64 first keeps the element offsets from
    # wrapping around once the matrix holds more than 2**31 - 1 elements.
    offsets = (offs_m[:, None].to(tl.int64) * stride_am
               + offs_n[None, :].to(tl.int64) * stride_an)
    # a, b and c share one layout (see docstring), so one offset tile serves all three.
    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    tl.store(c_ptr + offsets, a + b, mask=mask)

In [19]:
def mat_add(a, b, *, block_size_m=128, block_size_n=128):
    """Return ``a + b`` for two 2-D tensors, computed by the Triton ``matrix_add`` kernel.

    Args:
        a, b: 2-D tensors of identical shape, both contiguous and on the same
            device (one the Triton kernel can run on).
        block_size_m, block_size_n: tile size passed to the kernel as
            BLOCK_SIZE_M / BLOCK_SIZE_N. Defaults to 128 x 128, the values
            that were previously hard-coded.

    Returns:
        A new tensor shaped like ``a`` with ``a``'s dtype and device. Input
        dtypes are not checked against each other; the result is stored in
        ``a.dtype``.

    Raises:
        ValueError: if the inputs are not 2-D, differ in shape or device, or
            are not contiguous.
    """
    # Explicit raises rather than ``assert``: asserts are stripped under
    # ``python -O``, and a skipped shape check would let the kernel read past b.
    if a.dim() != 2:
        raise ValueError(f"mat_add expects 2-D tensors, got a.dim() == {a.dim()}")
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {tuple(a.shape)} vs {tuple(b.shape)}")
    if a.device != b.device:
        raise ValueError(f"device mismatch: {a.device} vs {b.device}")
    # The kernel addresses b and c with a's strides, so the layouts must agree.
    if not (a.is_contiguous() and b.is_contiguous()):
        raise ValueError("mat_add requires contiguous inputs")

    M, N = a.shape
    # empty_like already copies a's device and dtype and preserves a's layout.
    c = torch.empty_like(a)
    if c.numel() == 0:
        return c  # nothing to add; skip the kernel launch entirely

    grid = (triton.cdiv(M, block_size_m), triton.cdiv(N, block_size_n))
    matrix_add[grid](
        a, b, c,
        M, N,
        a.stride(0), a.stride(1),
        BLOCK_SIZE_M=block_size_m, BLOCK_SIZE_N=block_size_n,
    )
    return c

In [23]:
# Two random 16384 x 16384 matrices on the GPU, in torch's default dtype
# (float32 unless changed: 1 GiB each). B is drawn before A, as before,
# so the random stream is consumed in the same order.
B = torch.rand(16384, 16384, device="cuda")
A = torch.rand(16384, 16384, device="cuda")

In [26]:
# Triton result first, then PyTorch's own elementwise add as the reference.
C = mat_add(A, B)
D = torch.add(A, B)

In [27]:
# Prints whether the Triton result matches the PyTorch reference within the default tolerances.
print(C.allclose(D))

True
