In [3]:
import torch

import triton
import triton.language as tl

In [4]:
@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Element-wise addition kernel: every program instance sums one
    # BLOCK_SIZE-wide slice of the operands and writes it to the output.
    #   x_ptr, y_ptr : pointers to the first element of each operand
    #   output_ptr   : pointer to the first element of the result
    #   n_elements   : number of elements to process
    #   BLOCK_SIZE   : compile-time number of elements handled per instance

    # Position of this instance along axis 0 of the launch grid.
    block_id = tl.program_id(0)

    # Element indices owned by this instance: the block's first index plus
    # the lane numbers 0 .. BLOCK_SIZE-1.
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # The last block may extend beyond n_elements; lanes with an index at or
    # past n_elements are excluded from the loads and the store via the mask.
    in_bounds = idx < n_elements

    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)

    tl.store(output_ptr + idx, lhs + rhs, mask=in_bounds)

In [5]:

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the element-wise sum of ``x`` and ``y`` computed by ``add_kernel``.

    Args:
        x: First operand; must live on a CUDA device.
        y: Second operand; must live on a CUDA device and have the same
            shape as ``x``.

    Returns:
        A newly allocated, contiguous tensor with the shape and dtype of ``x``.
        The kernel is launched asynchronously, so it may still be running
        when this function returns (``torch.cuda.synchronize()`` is not
        called here).

    Raises:
        ValueError: If ``x`` and ``y`` do not have the same shape.
        AssertionError: If either input is not a CUDA tensor.
    """
    # The kernel masks accesses using the element count of the output only, so
    # a `y` with fewer elements would be read out of bounds and a larger `y`
    # would be silently truncated. Reject mismatched shapes up front.
    if x.shape != y.shape:
        raise ValueError(
            f'add() expects tensors of identical shape, got {tuple(x.shape)} and {tuple(y.shape)}'
        )
    # Validate the devices before allocating anything.
    assert x.is_cuda and y.is_cuda, 'add() requires CUDA tensors'
    # The kernel walks memory linearly from each tensor's first element, which
    # is only correct for contiguous storage. `.contiguous()` returns the same
    # tensor when it is already contiguous and a compact copy otherwise.
    x = x.contiguous()
    y = y.contiguous()
    # We need to preallocate the output.
    output = torch.empty_like(x)
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks:
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    # NOTE:
    #  - Each torch.tensor object is implicitly converted into a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
    #  - Don't forget to pass meta-parameters as keyword arguments.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # We return a handle to `output` but, since `torch.cuda.synchronize()` hasn't been called, the kernel
    # may still be running asynchronously at this point.
    return output



In [6]:
# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects
# and test its correctness against PyTorch's own addition.

torch.manual_seed(0)
# 98432 = 96 * 1024 + 128, i.e. not a multiple of the BLOCK_SIZE used by `add`,
# so the final block is only partially filled and the kernel's mask is exercised.
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')
# Printing alone would let a regression go unnoticed on "Run All": fail loudly
# if the Triton result diverges from the PyTorch reference.
torch.testing.assert_close(output_triton, output_torch)


tensor([1.3713, 1.3076, 0.4940,  ..., 0.9592, 0.3409, 1.2567], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940,  ..., 0.9592, 0.3409, 1.2567], device='cuda:0')
The maximum difference between torch and triton is 0.0
