In [None]:
import operator
import os

import torch
from IPython.core.debugger import set_trace

In [None]:
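# os.environ['TRITON_INTERPRET'] = '1'  # uncomment to run the @triton.jit kernels in Triton's Python interpreter (must be set before those kernels are defined)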

In [None]:
def check_tensors_gpu_ready(*tensors):
    """Assert that every tensor can be handed to the Triton kernels below.

    Each tensor must be contiguous, because the kernels in this notebook
    address memory as ``ptr + flat_offset``. Unless Triton's interpreter mode
    is enabled (``TRITON_INTERPRET=1``), each tensor must also live on a CUDA
    device.

    Raises:
        AssertionError: if a tensor is non-contiguous or, outside
            interpreter mode, not on the GPU.
    """
    # Triton's interpreter switch is TRITON_INTERPRET. The previously checked
    # TRITON_INTERCEPT is not read by Triton, so it never enabled CPU execution.
    interpret_mode = os.environ.get('TRITON_INTERPRET') == '1'
    for t in tensors:
        # `is_contiguous` is a method: it must be *called*. The bare bound
        # method is always truthy, so the old check could never fail.
        assert t.is_contiguous(), f"Tensor {t} is not contiguous"
        if not interpret_mode:
            assert t.is_cuda, f"Tensor {t} is not on GPU"

In [None]:
# Quick sanity check of the helper on a small GPU-resident int64 tensor [1, 2, 3].
a = torch.arange(1, 4, device='cuda')
check_tensors_gpu_ready(a)

In [None]:
def test_pid_conds(conds, pid_0=[0], pid_1=[0], pid_2=[0]):
    # Q: Are pids 1 element lists?
    pids = pid_0[0], pid_1[0], pid_2[0]
    conds = conds.replace(' ', '').split(',')
    for i, (cond, pid) in enumerate(zip(conds, pids)):
        print(f"{pid} ... {cond}")
        if cond=='': continue
        op, threshold = cond[0], int(cond[1:])
        if op not in ['<','>','>=','<=','=', '!=']:
            raise ValueError(f"Rules may only use these ops: '<','>','>=','<=','=', '!='. Invalid rule: '{condition}'.")
        op = '==' if op == '=' else op
        if not eval(f'{pid} {op} {threshold}'): return False
    return True

In [None]:
assert test_pid_conds(conds='')  # an empty rule string constrains no axis

In [None]:
assert test_pid_conds('>0', pid_0=[1], pid_1=[1])  # pid_0 == 1 satisfies '>0'

In [None]:
# zip() stops at the shortest input, so only the pair (1, 2) is printed.
# test_pid_conds zips its rules with the three pids, so axes without a rule
# are simply not checked.
a = [1]
b = [2, 3, 4]
for i, j in zip(a, b):
    print(f"{i} {j}")

In [None]:
def breakpoint_if(conds, pid_0=[0], pid_1=[0], pid_2=[0]):
    '''Open the IPython debugger when *all* pid conditions in ``conds`` hold.

    ``conds`` is the rule string parsed by ``test_pid_conds`` (e.g. '=0' or
    ',>1'). ``test_pid_conds`` returns False as soon as one rule fails, so the
    debugger opens only when every rule is satisfied, not when any one is.
    NOTE(review): presumably only useful when kernels run under Triton's
    Python interpreter, where this call executes on the host.
    '''
    if test_pid_conds(conds, pid_0, pid_1, pid_2):
        set_trace()

def print_if(txt, conds, pid_0=[0], pid_1=[0], pid_2=[0]):
    '''Print ``txt`` when *all* pid conditions in ``conds`` hold.

    ``conds`` is the rule string parsed by ``test_pid_conds`` (e.g. '=0' or
    ',>1'). ``test_pid_conds`` returns False as soon as one rule fails, so the
    text is printed only when every rule is satisfied, not when any one is.
    '''
    if test_pid_conds(conds, pid_0, pid_1, pid_2):
        print(txt)

In [None]:
def cdiv(a, b):
    """Ceiling division for positive integer ``b``: how many size-``b`` chunks cover ``a`` items."""
    padded = a + b - 1
    return padded // b

assert (cdiv(10, 2), cdiv(10, 3)) == (5, 4)

In [None]:
import triton
import triton.language as tl

In [None]:
def copy(x, bs, kernel_fn):
    """Copy ``x`` into a fresh zero-initialised tensor using a Triton kernel.

    Launches ``kernel_fn`` on a 1-D grid of ``cdiv(x.numel(), bs)`` programs
    with arguments ``(x, z, n, bs)``, where ``z`` is the output and ``n`` the
    element count.

    Args:
        x: source tensor; must pass ``check_tensors_gpu_ready``.
        bs: elements handled per program. It must be a positive power of two,
            because the kernels in this notebook build offsets with
            ``tl.arange(0, bs)``.
        kernel_fn: a ``@triton.jit`` kernel taking ``(x_ptr, z_ptr, n, bs)``.

    Returns:
        ``z``, created with ``torch.zeros_like(x)`` (same shape, dtype and
        device as ``x``).

    Raises:
        ValueError: if ``bs`` is not a positive power of two.
    """
    if bs <= 0 or bs & (bs - 1):
        raise ValueError(f"bs must be a positive power of two, got {bs}")
    z = torch.zeros_like(x)  # zeros_like already places z on x's device
    check_tensors_gpu_ready(x, z)
    n = x.numel()
    if n == 0:
        return z  # nothing to copy; avoid launching an empty grid
    grid = (cdiv(n, bs),)
    kernel_fn[grid](x, z, n, bs)
    return z


In [None]:
@triton.jit
def copy_k(x_ptr, z_ptr, n, bs: tl.constexpr):
    # Program i handles the `bs` consecutive elements starting at i * bs.
    # Lanes whose offset is >= n are masked off for both load and store,
    # so the last (partial) block does not touch memory past the end.
    block_start = tl.program_id(0) * bs
    idx = block_start + tl.arange(0, bs)
    in_bounds = idx < n
    tile = tl.load(x_ptr + idx, mask=in_bounds)
    tl.store(z_ptr + idx, tile, mask=in_bounds)

In [None]:
# Created directly on the GPU; a follow-up `x.cuda()` would just return x unchanged.
x = torch.tensor([1, 2, 3, 4, 5, 6], device='cuda')
x

In [None]:
z = copy(x, bs=2, kernel_fn=copy_k)  # 6 elements / 2 per program -> 3 programs

In [None]:
# The copy must reproduce its input exactly; fail loudly instead of relying on eyeballing.
assert torch.equal(x, z), "copy_k did not reproduce its input"
x, z

### Element-wise vector addition with a Triton kernel

In [None]:
@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, bs: tl.constexpr):
    # Program i sums the `bs` consecutive elements starting at i * bs.
    # Offsets >= n_elements are masked so the final partial block is safe.
    start = tl.program_id(0) * bs
    idx = start + tl.arange(0, bs)
    valid = idx < n_elements

    lhs = tl.load(x_ptr + idx, mask=valid)
    rhs = tl.load(y_ptr + idx, mask=valid)
    total = lhs + rhs
    tl.store(output_ptr + idx, total, mask=valid)

In [None]:
def add(x: torch.Tensor, y: torch.Tensor, bs: int) -> torch.Tensor:
    """Return ``x + y`` computed element-wise by ``add_kernel``.

    Args:
        x, y: tensors of identical shape that pass ``check_tensors_gpu_ready``.
        bs: elements handled per program. It must be a positive power of two,
            because ``add_kernel`` builds offsets with ``tl.arange(0, bs)``.

    Returns:
        A new tensor allocated with ``torch.empty_like(x)``, so it has
        ``x``'s shape, dtype and device.

    Raises:
        ValueError: if the shapes differ, or ``bs`` is not a positive power of
            two. Without the shape check, a shorter ``y`` would be read out of
            bounds, since the kernel reads ``x.numel()`` elements from both.
    """
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got {tuple(x.shape)} and {tuple(y.shape)}")
    if bs <= 0 or bs & (bs - 1):
        raise ValueError(f"bs must be a positive power of two, got {bs}")
    output = torch.empty_like(x)
    check_tensors_gpu_ready(x, y, output)
    n_elements = x.numel()
    if n_elements == 0:
        return output  # nothing to compute; avoid launching an empty grid
    grid = (cdiv(n_elements, bs),)
    add_kernel[grid](x, y, output, n_elements, bs)
    return output

In [None]:
torch.manual_seed(0)  # reproducible random inputs
bs = 128
size = bs * 16  # 2048 elements = 16 full blocks of `bs`

x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y, bs)
# Assert rather than just display a boolean that is easy to overlook.
assert torch.allclose(output_torch, output_triton), "Triton add disagrees with torch"
output_triton[:5]