In [1]:
import os
from IPython.core.debugger import set_trace
import torch

os.environ['TRITON_INTERPRET'] = '1' # needs to be set *before* triton is imported

import triton
import triton.language as tl

def cdiv(a, b):
    """Ceiling division: the number of size-`b` chunks needed to cover `a` items.

    For non-negative `a` and positive `b` this equals `ceil(a / b)` while
    staying in integer arithmetic.
    """
    # Bump the numerator so that any non-zero remainder rolls over into one extra chunk.
    bumped = a + b - 1
    return bumped // b

def copy(x, bs, kernel_fn):
    """Launch `kernel_fn` on a 1D grid over `x` and return the tensor it wrote into.

    Args:
        x: source torch tensor.
        bs: block size, i.e. how many elements a single program is given.
        kernel_fn: the Triton kernel to launch; it is called as
            `kernel_fn[grid](x, out, n, bs)`.

    Returns:
        A tensor created with `torch.zeros_like(x)` that was handed to the
        kernel as its output buffer.
    """
    out = torch.zeros_like(x)
    n_elems = x.numel()

    # One program per block of `bs` elements, rounded up so a partial last
    # block still gets a program. A grid may be a 1d/2d/3d tuple (or a
    # function returning one); a 1d tuple is all we need here.
    launch_grid = (cdiv(n_elems, bs),)

    # Launch: indexing the kernel with the grid selects how many programs run;
    # the call arguments are the parameters passed into each kernel program.
    kernel_fn[launch_grid](x, out, n_elems, bs)

    return out

# When we pass torch tensors, they are automatically converted into a pointer to their first value
# E.g., the launcher passes x, but here we receive x_ptr
@triton.jit
def copy_k(x_ptr, z_ptr, n, bs: tl.constexpr):
    """Copy `n` elements from `x_ptr` to `z_ptr`, one `bs`-sized block per program.

    Args:
        x_ptr: pointer to the first element of the source buffer.
        z_ptr: pointer to the first element of the destination buffer.
        n: total number of elements to copy.
        bs: block size (compile-time constant); each program handles `bs`
            consecutive elements.
    """
    pid = tl.program_id(0)
    # Each program must start at its own block: program `pid` owns elements
    # [pid * bs, (pid + 1) * bs). Without the `pid * bs` shift every program
    # would copy only the first block and the rest of the output would stay untouched.
    offs = pid * bs + tl.arange(0, bs)
    # The last block may run past the end of the data; mask those lanes out.
    mask = offs < n
    x = tl.load(x_ptr + offs, mask)  # load a vector of values, think of `x_ptr + offs` as `x_ptr[offs]`
    tl.store(z_ptr + offs, x, mask)  # store a vector of values

In [6]:
@triton.jit
def add_one(a, b, c, n: tl.constexpr):
    """Element-wise add: store `a[i] + b[i]` into `c[i]` for `i` in `[0, n)`.

    Despite the name, this adds two vectors together; it does not add 1.

    Args:
        a: pointer to the first input buffer.
        b: pointer to the second input buffer.
        c: pointer to the output buffer.
        n: number of elements to process (compile-time constant, used as the
            `tl.arange` upper bound).

    Notes:
        - The offsets do not depend on the program id, so this is meant for a
          single-program launch; every program of a larger grid would redo the
          same `n` elements.
        - Loads and the store are unmasked, so all three buffers must hold at
          least `n` elements.
        - NOTE(review): Triton documents the `tl.arange` bounds as powers of
          two; confirm callers pass a suitable `n`.
    """
    offsets = tl.arange(0, n)
    total = tl.load(a + offsets) + tl.load(b + offsets)
    tl.store(c + offsets, total)
    
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Not used by this cell; kept in case other cells rely on it.
num_points = 100

# Both inputs and the output must live on the same device, so every tensor is
# created on `device` (or derived from one that was).
x = torch.tensor([1,2,3,4,5,6], device=device)
y = torch.tensor([0,1,0,1,0,1], device=device)
c = torch.zeros_like(x)
print(x)

# Launch a single program that adds the first 5 elements of x and y into c.
# NOTE(review): n=5 covers only five of the six elements (c[5] stays 0), and
# Triton documents the tl.arange bounds as powers of two - confirm 5 is intended.
add_one[(1,)](x, y, c, 5)

tensor([1, 2, 3, 4, 5, 6])
