In [3]:
import triton
import torch
import triton.language as tl

@triton.jit
def add_kernel(x_ptr,
               y_ptr,
               output_ptr,
               n_elements,
               BLOCK_SIZE: tl.constexpr):
    
    pid = tl.program_id(axis=0)

    block_start = pid * BLOCK_SIZE

    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # print(offsets.shape)

    mask = offsets < n_elements
    # print(mask.shape)

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    output = x + y
    # print(output.shape)
    tl.store(output_ptr + offsets, output, mask=mask)



In [4]:
def add(x: torch.Tensor, y: torch.Tensor):

    output = torch.empty_like(x)

    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()

    grid = lambda meta : (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)

    return output

In [5]:
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')

output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)

print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')

tensor([1.3713, 1.3076, 0.4940,  ..., 0.6682, 1.1984, 1.2696], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940,  ..., 0.6682, 1.1984, 1.2696], device='cuda:0')
The maximum difference between torch and triton is 0.0


In [9]:
MASK = torch.triu(torch.ones((9, 9), device='cuda'))
print(MASK)

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1.]], device='cuda:0')


In [12]:
import triton.language as tl
import triton
@triton.jit
def test():
    offs_q = 2 * 4 + tl.arange(0, 4)
    offs_kv = tl.arange(0, 4)
    mask = offs_q[:, None] >= (0 + offs_kv[None, :])
    tl.device_print("offs_q:", mask)


grid = lambda args: (
            1,
            1,
            1,
        )
test[grid]()

<triton.compiler.compiler.CompiledKernel at 0x772372130e00>

pid (0, 0, 0) idx (0, 0) offs_q: 4294967295
pid (0, 0, 0) idx (0, 1) offs_q: 4294967295
pid (0, 0, 0) idx (0, 2) offs_q: 4294967295
pid (0, 0, 0) idx (0, 3) offs_q: 4294967295
pid (0, 0, 0) idx (1, 0) offs_q: 4294967295
pid (0, 0, 0) idx (1, 1) offs_q: 4294967295
pid (0, 0, 0) idx (1, 2) offs_q: 4294967295
pid (0, 0, 0) idx (1, 3) offs_q: 4294967295
pid (0, 0, 0) idx (2, 0) offs_q: 4294967295
pid (0, 0, 0) idx (2, 1) offs_q: 4294967295
pid (0, 0, 0) idx (2, 2) offs_q: 4294967295
pid (0, 0, 0) idx (2, 3) offs_q: 4294967295
pid (0, 0, 0) idx (3, 0) offs_q: 4294967295
pid (0, 0, 0) idx (3, 1) offs_q: 4294967295
pid (0, 0, 0) idx (3, 2) offs_q: 4294967295
pid (0, 0, 0) idx (3, 3) offs_q: 4294967295
pid (0, 0, 0) idx (0, 0) offs_q: 4294967295
pid (0, 0, 0) idx (0, 1) offs_q: 4294967295
pid (0, 0, 0) idx (0, 2) offs_q: 4294967295
pid (0, 0, 0) idx (0, 3) offs_q: 4294967295
pid (0, 0, 0) idx (1, 0) offs_q: 4294967295
pid (0, 0, 0) idx (1, 1) offs_q: 4294967295
pid (0, 0, 0) idx (1, 2) offs_q: