In [1]:
import torch
import triton
import triton.language as tl


@triton.jit
def reverse_kernel(input, N, BLOCK_SIZE: tl.constexpr):
    """Swap mirrored element pairs of ``input`` in place to reverse it.

    Each program covers ``BLOCK_SIZE`` consecutive indices of the first half
    (indices below ``N // 2``) and exchanges every such element with the one
    at ``N - 1 - index``. For odd ``N`` the middle element is never touched.

    Args:
        input: Base pointer of the data to reverse.
        N: Number of elements to reverse.
        BLOCK_SIZE: Compile-time number of indices handled per program.
    """
    half = N // 2
    block_start = tl.program_id(0) * BLOCK_SIZE
    left_idx = block_start + tl.arange(0, BLOCK_SIZE)
    right_idx = N - 1 - left_idx

    # Lanes at or past the midpoint have no pair to swap; the same mask guards
    # every access below, so their (possibly out-of-range) indices are unused.
    in_first_half = left_idx < half

    left_ptrs = input + left_idx
    right_ptrs = input + right_idx

    # Read both sides before writing either so no value is overwritten early.
    left_vals = tl.load(left_ptrs, mask=in_first_half)
    right_vals = tl.load(right_ptrs, mask=in_first_half)

    tl.store(left_ptrs, right_vals, mask=in_first_half)   # right → left
    tl.store(right_ptrs, left_vals, mask=in_first_half)   # left → right


# input is a tensor on the GPU
def solve(input: torch.Tensor, N: int):
    """Reverse the first ``N`` elements of ``input`` in place on the GPU.

    Launches ``reverse_kernel`` over ``ceil((N // 2) / BLOCK_SIZE)`` programs;
    only ``N // 2`` swaps are needed because each kernel lane exchanges one
    front-half element with its mirror in the back half.

    Args:
        input: Tensor whose data is reversed in place. The kernel addresses it
            with flat element offsets from the base pointer, so it is assumed
            to be 1-D and contiguous -- TODO confirm with callers.
        N: Number of elements to reverse.

    Returns:
        None; ``input`` is modified in place.
    """
    # With fewer than two elements there are no pairs to swap, so skip the
    # launch (and the grid computation, which would be empty or negative).
    if N < 2:
        return

    BLOCK_SIZE = 1024
    n_blocks = triton.cdiv(N // 2, BLOCK_SIZE)
    grid = (n_blocks,)

    reverse_kernel[grid](input, N, BLOCK_SIZE=BLOCK_SIZE)


In [2]:
def _reference_reverse(x: torch.Tensor) -> torch.Tensor:
    return torch.flip(x, dims=[0])


def run_reverse_tests(seed: int = 0):
    """Check ``solve`` against ``_reference_reverse`` across sizes and dtypes.

    For each dtype/size combination a random CUDA tensor is generated, the
    expected result is captured, ``solve`` reverses the tensor in place, and
    the two are compared exactly. A fixed five-element case runs last.

    Args:
        seed: Seed for the generator producing the random inputs, so that a
            failing case can be reproduced.

    Raises:
        RuntimeError: If no CUDA device is available.
        AssertionError: If any reversed tensor differs from the reference.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA GPU is required to run Triton tests")

    # Private generator: reproducible inputs without touching global RNG state.
    gen = torch.Generator(device="cuda")
    gen.manual_seed(seed)

    # Covers empty, single-element, odd and even lengths, and a large input.
    test_sizes = [0, 1, 2, 3, 31, 32, 33, 1023, 1024, 1025, 1000001]
    test_dtypes = [torch.float32, torch.int64]

    for dtype in test_dtypes:
        for n in test_sizes:
            if dtype.is_floating_point:
                x = torch.randn(n, device="cuda", dtype=dtype, generator=gen)
            else:
                x = torch.randint(
                    -1000, 1000, (n,), device="cuda", dtype=dtype, generator=gen
                )

            # Capture the expected result before solve() mutates x in place.
            expected = _reference_reverse(x).clone()
            solve(x, n)

            # Reversal only moves elements, so the result must match exactly;
            # a tolerance-based comparison could hide corrupted values.
            ok = torch.equal(x, expected)

            assert ok, f"reverse mismatch for dtype={dtype}, n={n}"

    # Deterministic sanity case
    x = torch.tensor([1, 2, 3, 4, 5], device="cuda", dtype=torch.int64)
    solve(x, x.numel())
    assert torch.equal(x, torch.tensor([5, 4, 3, 2, 1], device="cuda", dtype=torch.int64))

    print("All reverse kernel tests passed.")


# Run the test suite when this cell executes; raises RuntimeError without CUDA.
run_reverse_tests()

All reverse kernel tests passed.
