<a href="https://colab.research.google.com/github/casanovaalonso/MathForDL/blob/main/01_triton_basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# prompt: install torch

!pip install torch
!pip install triton

Collecting triton
  Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.3 kB)
Downloading triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (209.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.5/209.5 MB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: triton
Successfully installed triton-3.1.0


In [5]:
import torch

import triton
import triton.language as tl

In [8]:
# The Triton kernels below need a GPU visible to torch.cuda. Fail early with a
# readable message instead of an opaque error from inside torch.cuda.
if not torch.cuda.is_available():
  raise RuntimeError("A CUDA GPU is required to run the Triton kernels in this notebook.")
DEVICE = torch.device(f"cuda:{torch.cuda.current_device()}")
print(f"Device: {DEVICE}")

Device: cuda:0


In [11]:
@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
  """
  Triton language tutorial 1: a kernel that adds two vectors element-wise.

  Each launched program instance handles one block of BLOCK_SIZE consecutive
  elements. It computes output[i] = x[i] + y[i] for every index i in its
  block that is below n_elements.

  Args:
    x_ptr, y_ptr: base pointers of the two input vectors.
    output_ptr: base pointer of the output vector.
    n_elements: number of elements to process.
    BLOCK_SIZE: elements per program; a compile-time constant (tl.constexpr).

  Steps of the kernel
    * pid: program identifier along grid axis 0. It tells us which program
      instance ("block") is running and therefore which slice of data to use.
    * block_start: index of the first element of this block. For example,
      with BLOCK_SIZE=64, program 0 covers 0..63, program 1 covers 64..127,
      and program 2 covers 128..191.
    * offsets: the BLOCK_SIZE indices of this block,
      [block_start, block_start + 1, ..., block_start + BLOCK_SIZE - 1].
    * mask: True where offsets < n_elements. It keeps the last block from
      reading or writing past the end of the vectors when n_elements is not
      a multiple of BLOCK_SIZE.
    * x, y: loaded from the base pointers plus the offsets. Element i is read
      from ptr + i, so the vectors are assumed to be contiguous in memory.
    * output: x + y, written back to the output vector under the same mask.
  """
  pid = tl.program_id(axis=0)
  block_start = pid * BLOCK_SIZE
  offsets = block_start + tl.arange(0, BLOCK_SIZE)
  mask = offsets < n_elements
  # Masked-off lanes are not loaded (no `other` is given, so their values are
  # unspecified). They are never written, because tl.store uses the same mask.
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  tl.store(output_ptr + offsets, output, mask=mask)

In [14]:
def add(x: torch.Tensor, y: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
  """Compute the element-wise sum x + y with the Triton ``add_kernel``.

  Args:
    x, y: tensors on ``DEVICE`` with identical shapes. Broadcasting is not
      supported, because the kernel walks both inputs with the same flat index.
    block_size: elements handled by each Triton program (the kernel's
      ``BLOCK_SIZE``). Must be a positive power of two, as ``tl.arange``
      requires. Defaults to 1024.

  Returns:
    A new contiguous tensor with the shape and dtype of ``x`` holding x + y.

  Raises:
    AssertionError: if the devices, shapes or ``block_size`` are invalid.
  """
  assert x.device == y.device == DEVICE, (
      f"inputs must be on {DEVICE}, got {x.device} and {y.device}")
  # Without this check, a smaller y is read out of bounds and a larger y is
  # silently truncated.
  assert x.shape == y.shape, (
      f"shape mismatch: {tuple(x.shape)} vs {tuple(y.shape)}")
  assert block_size > 0 and (block_size & (block_size - 1)) == 0, (
      f"block_size must be a positive power of two, got {block_size}")
  # The kernel addresses element i at ptr + i, which is only valid for
  # contiguous storage. .contiguous() returns the tensor itself when it
  # already is contiguous, so the common case does not copy.
  x = x.contiguous()
  y = y.contiguous()
  output = torch.empty_like(x)
  n_elements = output.numel()
  if n_elements == 0:
    # Nothing to compute; skip compiling and launching the kernel.
    return output
  # One program per block of block_size elements; cdiv rounds up so the
  # final partial block is covered (the kernel masks the overhang).
  grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
  add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=block_size)
  return output


In [15]:
torch.manual_seed(0)
# 98432 = 96 * 1024 + 128 is not a multiple of the 1024-element BLOCK_SIZE
# used by `add`. The last program therefore runs partially masked, which
# exercises the bounds check in add_kernel.
size = 98432
x = torch.rand(size, device=DEVICE)
y = torch.rand(size, device=DEVICE)
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')
# Fail loudly under Restart & Run All instead of relying on reading the printout.
torch.testing.assert_close(output_triton, output_torch)

tensor([1.3713, 1.3076, 0.4940,  ..., 0.4024, 1.7918, 1.0686], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940,  ..., 0.4024, 1.7918, 1.0686], device='cuda:0')
The maximum difference between torch and triton is 0.0
