<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/inner_product%2Cipynb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import triton
import triton.language as tl

In [None]:
# Accumulate one BLOCK_SIZE-wide slice of sum(a * b) into out_ptr[0].
#
# Each program instance handles elements [pid*BLOCK_SIZE, (pid+1)*BLOCK_SIZE),
# upcasts them to float32, reduces the element-wise product, and atomically
# adds the partial sum into the single output element. The caller is
# responsible for zero-initialising out_ptr[0] before launch.
@triton.jit
def inner_kernel(a_ptr,b_ptr,out_ptr,n_elements,BLOCK_SIZE:tl.constexpr):
  pid = tl.program_id(axis=0)
  offs = pid * BLOCK_SIZE + tl.arange(0,BLOCK_SIZE)
  # Lanes past the end of the vectors (only in the last block) are masked off.
  mask = offs < n_elements
  # other=0.0: without it, masked-off lanes hold undefined values that would
  # be summed into the partial result below when n_elements is not a multiple
  # of BLOCK_SIZE.
  a = tl.load(a_ptr + offs,mask=mask,other=0.0)
  b = tl.load(b_ptr + offs,mask=mask,other=0.0)
  # Upcast so the per-block reduction is done in float32 regardless of input dtype.
  a = a.to(dtype=tl.float32)
  b = b.to(dtype=tl.float32)
  partial_sum = tl.sum(a*b)
  # Many program instances write to the same element, so the add must be atomic.
  tl.atomic_add(out_ptr,partial_sum)

In [None]:
def inner_product(a,b):
  """Return the inner product sum(a * b) of two CUDA tensors via inner_kernel.

  Both tensors are treated as flat vectors: they may have any shape as long
  as they hold the same number of elements. Partial sums are accumulated in
  float32 (float64 when ``a`` is float64) and the result is returned as a
  1-element tensor with ``a``'s dtype on ``a``'s device.

  Raises:
    ValueError: if either tensor is not on a CUDA device, the tensors are on
      different devices, or their element counts differ.
  """
  # Explicit checks instead of `assert`, which is stripped under `python -O`.
  if not (a.is_cuda and b.is_cuda):
    raise ValueError("inner_product expects CUDA tensors")
  if a.device != b.device:
    raise ValueError(f"inputs are on different devices: {a.device} vs {b.device}")
  n_elements = a.numel()
  if b.numel() != n_elements:
    # A shorter `b` would otherwise be read out of bounds by the kernel.
    raise ValueError(f"size mismatch: a has {n_elements} elements, b has {b.numel()}")
  # The kernel indexes both inputs as flat, contiguous buffers.
  a = a.contiguous()
  b = b.contiguous()
  # Accumulate in at least float32: the kernel produces float32 partial sums,
  # and atomically adding many of them into a half/bfloat16 buffer loses
  # precision. float64 inputs keep a float64 accumulator as before.
  acc_dtype = torch.float64 if a.dtype == torch.float64 else torch.float32
  out = torch.zeros(1,device=a.device,dtype=acc_dtype)
  if n_elements == 0:
    # Empty inner product is 0; avoid launching a zero-sized grid.
    return out.to(a.dtype)
  grid = lambda meta:(triton.cdiv(n_elements,meta['BLOCK_SIZE']),)
  inner_kernel[grid](a,b,out,n_elements,BLOCK_SIZE=256)
  return out.to(a.dtype)

In [None]:
# Two random (1, 256*256) default-dtype rows on the GPU, then the reference
# value from PyTorch and the value from the Triton launcher.
a = torch.randn((1, 256 * 256), device='cuda')
b = torch.randn((1, 256 * 256), device='cuda')
c = torch.sum(a * b)
# Arguments are passed as (b, a); the element-wise product is commutative.
d = inner_product(b, a)

In [None]:
# Compare the Triton result with PyTorch's reference (default rtol/atol);
# the 0-d reference broadcasts against the 1-element result.
c.allclose(d)

True