In [1]:
import torch
import triton
import triton.language as tl


@triton.jit
def matrix_transpose_kernel(
    input_ptr,
    output_ptr,
    rows,
    cols,
    stride_ir,
    stride_ic,
    stride_or,
    stride_oc,
    BLOCK_R: tl.constexpr = 32,
    BLOCK_C: tl.constexpr = 32,
):
    """Transpose a (rows x cols) matrix into a (cols x rows) matrix, tile by tile.

    Each program handles one BLOCK_R x BLOCK_C tile of the input, selected by
    its position on a 2D launch grid (axis 0 -> tile row, axis 1 -> tile column).

    Args:
        input_ptr: pointer to the first element of the input matrix.
        output_ptr: pointer to the first element of the output matrix.
        rows, cols: logical shape of the input matrix.
        stride_ir, stride_ic: element strides of the input along rows / columns.
        stride_or, stride_oc: element strides of the output along its rows /
            columns (output rows correspond to input columns).
        BLOCK_R, BLOCK_C: compile-time tile extents; must be powers of two
            because they are used as ``tl.arange`` bounds.

    Any grid that covers at least ceil(rows / BLOCK_R) x ceil(cols / BLOCK_C)
    programs is valid: tiles that fall partly or fully outside the matrix are
    masked out, so an over-sized grid is wasteful but still correct.
    """
    pid_r = tl.program_id(axis=0)
    pid_c = tl.program_id(axis=1)

    # Input coordinates covered by this tile. Widen to int64 before multiplying
    # by a stride so element offsets cannot wrap around on very large buffers.
    in_r = (pid_r * BLOCK_R + tl.arange(0, BLOCK_R)).to(tl.int64)
    in_c = (pid_c * BLOCK_C + tl.arange(0, BLOCK_C)).to(tl.int64)

    # Edge tiles may extend past the matrix; guard both the load and the store.
    in_bounds = (in_r[:, None] < rows) & (in_c[None, :] < cols)

    src_ptrs = input_ptr + in_r[:, None] * stride_ir + in_c[None, :] * stride_ic
    # Input element (r, c) lands at output element (c, r): the input column
    # index walks the output's row stride and vice versa.
    dst_ptrs = output_ptr + in_c[None, :] * stride_or + in_r[:, None] * stride_oc

    tile = tl.load(src_ptrs, mask=in_bounds)
    tl.store(dst_ptrs, tile, mask=in_bounds)



# input, output are tensors on the GPU
def solve(input: torch.Tensor, output: torch.Tensor, rows: int, cols: int):
    """Write the transpose of ``input`` (rows x cols) into ``output`` (cols x rows).

    Strides are derived from ``rows`` / ``cols`` rather than read from the
    tensors, so both buffers are addressed as densely packed row-major storage;
    non-contiguous tensors are not supported.

    Args:
        input: source buffer holding rows * cols elements.
        output: destination buffer holding cols * rows elements; overwritten.
        rows: number of rows of the input matrix.
        cols: number of columns of the input matrix.
    """
    if rows == 0 or cols == 0:
        # Nothing to copy; avoid building an empty launch grid.
        return

    stride_ir, stride_ic = cols, 1
    stride_or, stride_oc = rows, 1

    # Launch one program per tile instead of one per element. Besides cutting
    # launch overhead, this keeps grid axis 1 far below CUDA's 65535-block cap,
    # which a (rows, cols) grid exceeds as soon as cols > 65535.
    BLOCK_R, BLOCK_C = 32, 32
    grid = (triton.cdiv(rows, BLOCK_R), triton.cdiv(cols, BLOCK_C))
    matrix_transpose_kernel[grid](
        input, output, rows, cols,
        stride_ir, stride_ic, stride_or, stride_oc,
        BLOCK_R=BLOCK_R, BLOCK_C=BLOCK_C,
    )


In [2]:
def test_matrix_transpose(rows, cols, dtype=torch.float32, seed=42):
    """Check ``solve`` against ``torch.Tensor.T`` for one random matrix.

    Args:
        rows, cols: shape of the random input matrix.
        dtype: floating-point dtype of the input and output buffers.
        seed: seed passed to ``torch.manual_seed`` for reproducible input.

    Raises:
        AssertionError: if the output differs from the reference transpose.
    """
    torch.manual_seed(seed)
    device = "cuda"

    inp = torch.randn((rows, cols), dtype=dtype, device=device)
    # Pre-fill with NaN instead of leaving the buffer uninitialised: recycled
    # allocator memory could otherwise hold plausible values and hide elements
    # the kernel never wrote. NaN never compares equal, so any gap fails.
    out = torch.full((cols, rows), float("nan"), dtype=dtype, device=device)

    solve(inp, out, rows, cols)

    expected = inp.T.contiguous()
    max_diff = (out - expected).abs().max().item()
    # A transpose only moves data, so the result must match bit-for-bit.
    assert torch.equal(out, expected), \
        f"[FAIL] ({rows}x{cols}): max diff = {max_diff}"
    print(f"[PASS] ({rows}x{cols}): max diff = {max_diff:.2e}")


# Shapes to exercise: square sizes first, then rectangular (rows, cols) pairs.
SQUARE_SIZES = (1, 32, 128, 512, 1024)
RECT_SHAPES = ((2, 8), (3, 7), (128, 256), (500, 300))

for size in SQUARE_SIZES:
    test_matrix_transpose(size, size)

for n_rows, n_cols in RECT_SHAPES:
    test_matrix_transpose(n_rows, n_cols)

[PASS] (1x1): max diff = 0.00e+00
[PASS] (32x32): max diff = 0.00e+00
[PASS] (128x128): max diff = 0.00e+00
[PASS] (512x512): max diff = 0.00e+00
[PASS] (1024x1024): max diff = 0.00e+00
[PASS] (2x8): max diff = 0.00e+00
[PASS] (3x7): max diff = 0.00e+00
[PASS] (128x256): max diff = 0.00e+00
[PASS] (500x300): max diff = 0.00e+00
