# In [None]:
import torch

import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Element-wise addition: writes x[i] + y[i] to output[i] for every index i
    # this program instance is responsible for.
    #
    # x_ptr, y_ptr  -- base pointers of the two inputs
    # output_ptr    -- base pointer of the destination
    # n_elements    -- number of valid elements; indices at or beyond it are
    #                  masked out
    # BLOCK_SIZE    -- compile-time count of indices handled per instance

    # The instance's index along launch axis 0 selects which BLOCK_SIZE-wide
    # slice it works on.
    first_index = tl.program_id(axis=0) * BLOCK_SIZE

    # tl.arange builds the whole index range [0, BLOCK_SIZE) at once, so the
    # slice is addressed as one block instead of through an explicit
    # `for i in range(BLOCK_SIZE)` loop.
    indices = first_index + tl.arange(0, BLOCK_SIZE)

    # Boolean per-index bounds check, playing the role of
    # `if (offset < n_elements)` in a hand-written CUDA kernel: masked-off
    # positions are neither loaded nor stored.
    # Example: index 100 with n_elements = 1000 gives True.
    in_bounds = indices < n_elements

    lhs = tl.load(x_ptr + indices, mask=in_bounds)
    rhs = tl.load(y_ptr + indices, mask=in_bounds)

    tl.store(output_ptr + indices, lhs + rhs, mask=in_bounds)


def Add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the element-wise sum of ``x`` and ``y`` computed by ``add_kernel``.

    Args:
        x: First operand; must be a CUDA tensor.
        y: Second operand; must have the same shape as ``x`` and live on the
            same CUDA device.

    Returns:
        A newly allocated contiguous tensor with the shape and dtype of ``x``.

    Raises:
        ValueError: If the shapes differ, or if the tensors are not both on
            the same CUDA device.
    """
    # The kernel bounds-checks only against the element count of the output,
    # so a smaller ``y`` would be read past its end. Reject mismatches early.
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got {tuple(x.shape)} "
            f"and {tuple(y.shape)}"
        )
    if not (x.is_cuda and y.is_cuda) or x.device != y.device:
        raise ValueError(
            f"x and y must be on the same CUDA device, got {x.device} "
            f"and {y.device}"
        )

    # The kernel addresses memory as base pointer + flat offset, which only
    # matches the logical element order for contiguous storage.
    # ``contiguous()`` returns the tensor itself when it is already contiguous.
    x = x.contiguous()
    y = y.contiguous()
    output = torch.empty_like(x)

    n_elements = output.numel()
    if n_elements == 0:
        # Nothing to compute; skip the launch entirely.
        return output

    # One program instance per BLOCK_SIZE-sized slice, rounded up so the
    # final partial slice is covered (the kernel masks the overhang).
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )

    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output

# Demo: add two random CUDA vectors with the Triton kernel and print the sum.
# 98432 is not a multiple of the 1024-element block used by Add, so the last
# block is only partially filled and the kernel's bounds mask is exercised.
size = 98432
x, y = (torch.rand(size, device='cuda') for _ in range(2))

output = Add(x, y)
print(output)





# Example output of the cell above (values vary from run to run):
# tensor([1.1092, 0.2770, 1.3412,  ..., 0.8313, 0.9233, 1.6184], device='cuda:0')
