<a href="https://colab.research.google.com/github/doudi25/Triton/blob/main/Triton_Addition_Kernel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
import triton
import triton.language as tl
import torch


In [18]:
@triton.jit
def add_kernel(a_ptr,b_ptr,out_ptr,n_elements,BLOCK_SIZE:tl.constexpr):
  # assign program id is like assign nth thread block in cuda
  pid = tl.program_id(axis=0)
  # block_id to assign the index of which block we are work with know
  block_id = pid * BLOCK_SIZE
  #
  offsets = block_id + tl.arange(0,BLOCK_SIZE)
  # create  boolean mask to avoid threads that are not needed to participate in computation
  mask = offsets < n_elements
  # load the values located by the pointer that is pointed to the first element in array , since the vector is contiguous sequence of memory and have stride 1 we can add directly the offset to pointer
  a = tl.load(a_ptr + offsets,mask = mask)
  b = tl.load(b_ptr + offsets,mask = mask)
  # compute the addition
  out = a + b
  # store the result out_ptr + offsets , if we are working with first part then we are saving and ex block_size = 4 , out_ptr + [0,1,2,3] = out this will change the first four values
  tl.store(out_ptr + offsets,out,mask=mask)


In [19]:
def add(a,b):
  # allocate the output in the memory
  out = torch.empty_like(a)
  # calculate the number of elements
  n_elements = out.numel()
  # configuration of the grid , to know how much thread block we need (aka pid)
  grid = lambda meta : (triton.cdiv(n_elements,meta['BLOCK_SIZE']),)
  # launch the kernel
  add_kernel[grid](a,b,out,n_elements,BLOCK_SIZE=1024)
  return out

In [20]:
torch.cuda.manual_seed(42)
# generate random vectors with size of 1 million
a = torch.randn(1000000,device='cuda')
b = torch.randn(1000000,device='cuda')
out = add(a,b)

In [21]:
# in pytorch
result = a + b

In [23]:
print(f'the result is {result==out}')
print(f'{torch.allclose(out,result)}')

the result is tensor([True, True, True,  ..., True, True, True], device='cuda:0')
True
