In [1]:
"""
You will learn about:

The basic programming model of Triton.

The triton.jit decorator, which is used to define Triton kernels.

The best practices for validating and benchmarking your custom ops against native reference implementations.

""" 

'\nYou will learn about:\n\nThe basic programming model of Triton.\n\nThe triton.jit decorator, which is used to define Triton kernels.\n\nThe best practices for validating and benchmarking your custom ops against native reference implementations.\n\n'

In [13]:
# Compute kernel


import torch
import time
import triton
import triton.language as tl
DEVICE = triton.runtime.driver.active.get_current_device()

print(DEVICE, triton.__version__)

@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):  
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.

    
    # We dont need any return, we just directly write to memory of the output_ptr! Thats why we need to know it before and get it initialised!!
    tl.store(output_ptr + offsets, output, mask=mask)



0 3.0.0


In [14]:
def add(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output.

    output = torch.empty_like(x)
    print(f"x device: {x.device}, y device: {y.device}, output device: {output.device}, DEVICE: {DEVICE}")
    start = time.time()
    # assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks:
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    # NOTE:
    #  - Each torch.tensor object is implicitly converted into a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
    #  - Don't forget to pass meta-parameters as keywords arguments.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
    # running asynchronously at this point.
    end = time.time()

    print(f'triton time {end-start}')
    return output

In [15]:
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device=DEVICE)
y = torch.rand(size, device=DEVICE)
start = time.time()
output_torch = x + y
end = time.time()
print(f'torch time {end-start}')
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')

torch time 0.00012874603271484375
x device: cuda:0, y device: cuda:0, output device: cuda:0, DEVICE: 0
triton time 0.10018682479858398
tensor([1.3713, 1.3076, 0.4940,  ..., 1.2472, 1.3889, 0.9225], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940,  ..., 1.2472, 1.3889, 0.9225], device='cuda:0')
The maximum difference between torch and triton is 0.0
