In [None]:
import torch
import triton
import triton.language as tl


@triton.jit
def vector_add_kernel(a, b, c, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise addition: store ``a[i] + b[i]`` into ``c[i]`` for ``i < n_elements``.

    Each program instance covers one run of ``BLOCK_SIZE`` consecutive
    offsets, selected by its index along launch axis 0.

    Args:
        a: Base address of the first input buffer.
        b: Base address of the second input buffer.
        c: Base address of the output buffer.
        n_elements: Number of leading elements to process; offsets at or
            beyond this bound are masked out of every load and store.
        BLOCK_SIZE: Compile-time number of offsets handled per program.
    """
    # Widen the program id to 64 bits *before* scaling it. In 32-bit
    # arithmetic ``pid * BLOCK_SIZE`` wraps once the element index exceeds
    # 2**31 - 1, which would corrupt both the mask and the addresses.
    pid = tl.program_id(0).to(tl.int64)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # The final block may extend past the end of the data; guard those lanes.
    mask = offs < n_elements

    # Masked-off lanes read the placeholder 0.0; they are never written back
    # because the store below uses the same mask.
    a_vals = tl.load(a + offs, mask=mask, other=0.0)
    b_vals = tl.load(b + offs, mask=mask, other=0.0)
    c_vals = a_vals + b_vals
    tl.store(c + offs, c_vals, mask=mask)


# a, b, c are CUDA tensors
def solve(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, n: int):
    """Add the first ``n`` elements of ``a`` and ``b`` into ``c`` on the GPU.

    The work is done by ``vector_add_kernel``, launched over a 1-D grid with
    one program per 1024 elements. ``c`` is written in place; nothing is
    returned.

    Args:
        a: First input tensor (CUDA, contiguous, at least ``n`` elements).
        b: Second input tensor (CUDA, contiguous, at least ``n`` elements).
        c: Output tensor (CUDA, contiguous, at least ``n`` elements).
        n: Number of leading elements to add; must be non-negative.

    Raises:
        AssertionError: If a tensor is not on CUDA, is not contiguous, has
            fewer than ``n`` elements, or if ``n`` is negative.
    """
    assert a.is_cuda and b.is_cuda and c.is_cuda, "All tensors must be on CUDA"
    assert a.is_contiguous() and b.is_contiguous() and c.is_contiguous(), "Tensors must be contiguous"
    # A negative n would trivially satisfy the size check below, so reject it
    # explicitly.
    assert n >= 0, "n must be non-negative"
    assert a.numel() >= n and b.numel() >= n and c.numel() >= n, "n exceeds tensor size"

    # Nothing to compute. Returning here also avoids requesting an empty
    # launch grid, which some Triton/CUDA versions may reject.
    if n == 0:
        return

    block_size = 1024
    # Ceiling division so a partial trailing block still gets a program.
    grid = (triton.cdiv(n, block_size),)
    vector_add_kernel[grid](a, b, c, n, BLOCK_SIZE=block_size)


In [3]:
def run_vector_add_tests():
    """Check ``solve`` against PyTorch's ``a + b`` for several sizes and dtypes.

    Sizes straddle the 1024-element block boundary so that partially filled
    trailing blocks are exercised. Float16 results are compared with a looser
    tolerance than float32 results.

    Raises:
        RuntimeError: If no CUDA device is available.
        AssertionError: If any result differs from the PyTorch reference.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA GPU is required to run Triton tests")

    torch.manual_seed(0)

    def random_operands(length, dtype):
        # Draw the two inputs (in this order) and allocate a matching output.
        lhs = torch.randn(length, device="cuda", dtype=dtype)
        rhs = torch.randn(length, device="cuda", dtype=dtype)
        return lhs, rhs, torch.empty_like(lhs)

    # One tolerance per dtype, used for both atol and rtol. Insertion order
    # also fixes the order in which dtypes are tested.
    tolerance_by_dtype = {torch.float32: 1e-5, torch.float16: 1e-2}
    sizes = (1, 17, 1023, 1024, 1025, 4096, 1000003)

    cases = [(dtype, n) for dtype in tolerance_by_dtype for n in sizes]
    for dtype, n in cases:
        lhs, rhs, out = random_operands(n, dtype)
        solve(lhs, rhs, out, n)

        tol = tolerance_by_dtype[dtype]
        matches = torch.allclose(out, lhs + rhs, atol=tol, rtol=tol)
        assert matches, f"Mismatch for dtype={dtype}, n={n}"

    # Final float32 smoke test with one million elements.
    n = 1_000_000
    lhs, rhs, out = random_operands(n, torch.float32)
    solve(lhs, rhs, out, n)
    assert torch.allclose(out, lhs + rhs, atol=1e-5, rtol=1e-5)

    print("All Triton vector-add tests passed.")


# Run the suite now: raises RuntimeError when CUDA is unavailable and
# AssertionError on the first mismatch; prints a confirmation on success.
run_vector_add_tests()

All Triton vector-add tests passed.
