In [202]:
import torch as nn
import triton
import triton.language as tt
import numpy as np
import time

In [203]:
# Benchmark configuration shared by every cell below.
# BLOCK_SIZE: number of elements each Triton program instance covers.
BLOCK_SIZE = 1024
# N: total number of elements in each vector (2**24 = 16,777,216).
N = 2 ** 24

In [204]:
@triton.jit
def init(x, val, N, BLOCK_SIZE: tt.constexpr):
    """Write `val` to every offset below `N` starting at base `x`.

    Each program instance along axis 0 covers one run of BLOCK_SIZE
    consecutive offsets; offsets at or beyond N are masked out so the
    last, partially filled block does not write out of range.
    """
    pid = tt.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    idx = block_start + tt.arange(0, BLOCK_SIZE)
    # Only lanes whose offset is inside [0, N) are allowed to store.
    in_range = idx < N
    tt.store(x + idx, val, mask=in_range)

In [205]:
@triton.jit
def add(a_ptr, b_ptr, c_ptr, N, BLOCK_SIZE: tt.constexpr):
    """Element-wise sum: store a[i] + b[i] into c[i] for every i < N.

    Each program instance along axis 0 processes BLOCK_SIZE consecutive
    offsets. Loads and the store share one mask so that offsets at or
    beyond N are neither read nor written.
    """
    pid = tt.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    idx = block_start + tt.arange(0, BLOCK_SIZE)
    in_range = idx < N
    lhs = tt.load(a_ptr + idx, mask=in_range)
    rhs = tt.load(b_ptr + idx, mask=in_range)
    tt.store(c_ptr + idx, lhs + rhs, mask=in_range)

In [206]:
def main():
    """Run and time the Triton vector addition on the GPU.

    Allocates three int32 vectors of length N on the CUDA device, fills
    them with 1, 2 and 0 via the `init` kernel, times a single launch of
    the `add` kernel with CUDA events, then copies the result to the host
    and checks that every element equals 3.

    Prints "Success"/"Failure" and the measured kernel time in ms.
    """
    a = nn.empty(N, dtype=nn.int32, device='cuda')
    b = nn.empty(N, dtype=nn.int32, device='cuda')
    c = nn.empty(N, dtype=nn.int32, device='cuda')

    # 1-D launch grid: one program per BLOCK_SIZE elements, rounded up.
    block = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']),)

    init[block](a, 1, N, BLOCK_SIZE=BLOCK_SIZE)
    nn.cuda.synchronize()
    init[block](b, 2, N, BLOCK_SIZE=BLOCK_SIZE)
    nn.cuda.synchronize()
    init[block](c, 0, N, BLOCK_SIZE=BLOCK_SIZE)
    nn.cuda.synchronize()

    # PyTorch is imported as `nn` in this notebook; the bare name `torch`
    # is never bound, so the events must be created through the alias.
    start = nn.cuda.Event(enable_timing=True)
    end = nn.cuda.Event(enable_timing=True)
    # NOTE(review): the first launch of `add` may include Triton JIT
    # compilation inside the timed window — confirm / add a warm-up if
    # a steady-state number is wanted.
    start.record()
    add[block](a, b, c, N, BLOCK_SIZE=BLOCK_SIZE)
    end.record()
    # Both events must have completed before elapsed_time() is queried.
    nn.cuda.synchronize()
    c_copy = c.cpu().numpy()
    print("Success" if (c_copy == 3).all() else "Failure")
    print("Triton addition time: {:.3f} ms".format(start.elapsed_time(end)))

In [207]:
# Run the GPU benchmark defined above (allocates on device='cuda').
main()

Success
Triton addition time: 0.933 ms


In [208]:
# CPU baselines for the same N-element int32 addition: a pure-Python
# element loop and a single vectorized NumPy add.
a_cpu = np.full(N, 1, dtype=np.int32)
b_cpu = np.full(N, 2, dtype=np.int32)
c_cpu = np.empty(N, dtype=np.int32)

# Baseline 1: explicit Python loop, one element per iteration.
start = time.perf_counter()
for i in range(N):
    c_cpu[i] = a_cpu[i] + b_cpu[i]
end = time.perf_counter()
loop_ms = (end - start) * 1000
print("CPU loop addition time: {:.3f} ms".format(loop_ms))

# Baseline 2: vectorized NumPy addition over the whole array.
start_vec = time.perf_counter()
c_vec = np.add(a_cpu, b_cpu)
end_vec = time.perf_counter()
vec_ms = (end_vec - start_vec) * 1000
print("Numpy addition time: {:.3f} ms".format(vec_ms))

# Both results must be 3 everywhere (1 + 2).
if (c_cpu == 3).all() and (c_vec == 3).all():
    print("Success")
else:
    print("Error in CPU addition")

CPU loop addition time: 5051.784 ms
Numpy addition time: 22.490 ms
Success
