In [2]:
import torch
import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()

In [3]:
@triton.jit
def add_kernel(x_ptr, #ptr to first element of x input vec
               y_ptr, #ptr to first element of y input vec
               output_ptr, #ptr to output vec
               n_elements, #size of the vec
               BLOCK_SIZE: tl.constexpr
               #no of elems each program(block) should process and
               #we use 'constexpr' so it can be used as a shape value
               ):
  #there are multiple programs (block) processing different data
  #and we have to identify which program
  pid = tl.program_id(axis=0) #we use a 1D launch grid so axis is 0
  #this program will process inputs that are offset from the initial data
  #for instance, if you had a vector of length 256 and block_size of 64, the programs
  #would each access the elements [0:64, 64:128, 128:192, 192:256]
  #offsets is a list of pointers
  block_start = pid*BLOCK_SIZE
  offsets = block_start + tl.arange(0, BLOCK_SIZE)
  #create a mask to not access out of bounds elements i.e. elements that don't exist
  #since n_elements might not be a multiple of block_size, the last block will be created
  #with only some of the threads required so we have to turn the rest of threads in the block off
  mask = offsets < n_elements
  #we now load x and y from DRAM, masking out any extra elems in case the input
  #is not a multiple of block size
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  #write x + y i.e. output back to DRAM
  tl.store(output_ptr + offsets, output, mask=mask)


In [4]:
#creating a helper function to allocate the output tensor and enqueue the above kernel with appropriate grid/block sizes

def add(x: torch.Tensor, y: torch.Tensor):
  #we need to preallocate the output
  output = torch.empty_like(x)
  assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
  n_elements = output.numel()
  #SPMD launch grid denotes the number of kernel instances that run in parallel
  #in this case, we use a 1D launch grid where the size is the number of blocks
  grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

  add_kernel[grid](x,y, output, n_elements, BLOCK_SIZE=1024)

  return output

In [5]:
import time

In [8]:
torch.manual_seed(0)
size = 100000000
x = torch.rand(size, device=DEVICE)
y = torch.rand(size, device=DEVICE)

start_torch = time.time()
output_torch = x + y
end_torch = time.time()

start_ton = time.time()
output_triton = add(x,y)
end_ton = time.time()

print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')

print(f"Time taken by torch is: {end_torch-start_torch} seconds")
print(f"Time taken by triton is: {end_ton-start_ton} seconds")



tensor([0.9809, 0.6194, 0.9237,  ..., 0.8593, 1.6668, 1.0278], device='cuda:0')
tensor([0.9809, 0.6194, 0.9237,  ..., 0.8593, 1.6668, 1.0278], device='cuda:0')
The maximum difference between torch and triton is 0.0
Time taken by torch is: 0.0005977153778076172 seconds
Time taken by triton is: 0.0005478858947753906 seconds
