In [2]:
import torch
import triton
import triton.language as tl


@triton.jit
def matrix_add_kernel(a_ptr, b_ptr, c_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0) # Think 2D matrix as flattened 1D vector, you just need to add same index of a and b
    offs = pid * BLOCK_SIZE + tl.arange(0,BLOCK_SIZE)
    mask = offs < n_elements
    A = tl.load(a_ptr+offs, mask = mask)
    B = tl.load(b_ptr+offs, mask = mask)
    C = A + B
    tl.store(c_ptr+offs,C, mask = mask)

# a, b, c are tensors on the GPU
def solve(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, N: int):
    BLOCK_SIZE = 1024
    n_elements = N * N
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    matrix_add_kernel[grid](a, b, c, n_elements, BLOCK_SIZE)


In [3]:
def test_matrix_addition(N, dtype=torch.float32, seed=42):
    torch.manual_seed(seed)
    device = "cuda"

    a = torch.randn((N, N), dtype=dtype, device=device)
    b = torch.randn((N, N), dtype=dtype, device=device)
    c = torch.empty((N, N), dtype=dtype, device=device)

    solve(a, b, c, N)

    expected = a + b
    assert torch.allclose(c, expected, atol=1e-5), \
        f"Mismatch: max diff = {(c - expected).abs().max().item()}"
    print(f"[PASS] N={N}: max diff = {(c - expected).abs().max().item():.2e}")


for N in [1, 32, 128, 512, 1024, 4096]:
    test_matrix_addition(N)

[PASS] N=1: max diff = 0.00e+00
[PASS] N=32: max diff = 0.00e+00
[PASS] N=128: max diff = 0.00e+00
[PASS] N=512: max diff = 0.00e+00
[PASS] N=1024: max diff = 0.00e+00
[PASS] N=4096: max diff = 0.00e+00
