In [1]:
import torch
import triton
import triton.language as tl

# Torch device associated with Triton's currently active backend driver.
# Presumably used by later cells to place input/output tensors on the same
# device the kernels run on (those cells are not visible here — confirm).
DEVICE = triton.runtime.driver.active.get_active_torch_device()

In [None]:
@triton.jit
def add_kernel(x_ptr,
               y_ptr,
               output_ptr,
               n_elements,
               BLOCK_SIZE: tl.constexpr):
    """Element-wise vector addition: output[i] = x[i] + y[i] for 0 <= i < n_elements.

    The work is split into chunks of BLOCK_SIZE consecutive elements. Each
    program instance handles one chunk and is identified by its id along
    launch-grid axis 0, so a 1-D grid is expected. For example, with
    n_elements = 256 and BLOCK_SIZE = 64, the programs cover [0:64], [64:128],
    [128:192] and [192:256].

    Args:
        x_ptr: pointer to the first element of the first input vector.
        y_ptr: pointer to the first element of the second input vector.
        output_ptr: pointer to the first element of the output vector.
        n_elements: number of valid elements in the vectors.
        BLOCK_SIZE: number of elements handled by each program. It is declared
            tl.constexpr because it fixes the shape passed to tl.arange at
            compile time. Per Triton's tl.arange docs, that extent must also be
            a power of two.
    """
    # Integer index of every element this program owns. These are plain
    # offsets, not pointers; they become addresses only when added to a base
    # pointer below.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # n_elements need not be a multiple of BLOCK_SIZE, so the last chunk can
    # extend past the end of the data. Those lanes are switched off here.
    in_bounds = idx < n_elements

    # No `other=` value is supplied, so masked-off lanes hold unspecified data.
    # That is harmless because the store below uses the identical mask.
    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)
    tl.store(output_ptr + idx, lhs + rhs, mask=in_bounds)
