In [1]:
!pip install torch triton



In [2]:
import torch, triton
print(torch.cuda.is_available())
# it has to be True, triton does not support CPU!

True


In [None]:
import torch
import triton
import triton.language as tl


# ---------------------------
# 1) The Triton kernel itself
# ---------------------------

@triton.jit
def vector_add_kernel(
    A,  # *device* pointer to input vector A (float32 elements)
    B,  # *device* pointer to input vector B (float32 elements)
    C,  # *device* pointer to output vector C (float32 elements)
    N,  # total number of elements (int32 or int64 scalar)
    BLOCK_SIZE: tl.constexpr,  # compile-time constant: how many elements a single program handles
):
    """
    A Triton kernel is a function annotated with @triton.jit.
    It will be JIT-compiled to GPU code (PTX) the first time you launch it.

    Conceptual model:
      - Triton launches many "programs" in parallel on the GPU.
      - Each program is responsible for a *tile* (a small chunk) of the problem.
      - `tl.program_id(axis=0)` returns the id of the current program along the 1D grid.
    """

    # Identify which "program" (i.e., CUDA-like block) we are in along axis 0.
    pid = tl.program_id(axis=0)

    # Within this program, we process a contiguous block of indices of length BLOCK_SIZE:
    # global indices = pid * BLOCK_SIZE + [0, 1, 2, ..., BLOCK_SIZE-1]
    # tl.arange(0, BLOCK_SIZE) creates a vector [0..BLOCK_SIZE-1] at *compile-time* (constexpr).
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Compute a boolean mask to guard against out-of-bounds when N is not a multiple of BLOCK_SIZE.
    # Any lane whose index >= N will be masked off (i.e., deactivated for loads/stores).
    mask = offsets < N

    # Load A[offsets] and B[offsets] from global memory.
    # tl.load supports a `mask=` argument so masked-out lanes won't cause invalid memory access.
    # For masked lanes, tl.load returns 0 by default (can be overridden via `other=`).
    a = tl.load(A + offsets, mask=mask)

    # Same for B.
    b = tl.load(B + offsets, mask=mask)

    # Elementwise addition happens on registers; write results back to global memory.
    # Masked lanes won't store anything.
    tl.store(C + offsets, a + b, mask=mask)


# ---------------------------------
# 2) A small helper for benchmarking
# ---------------------------------
def time_op_gpu(fn, sync=True, warmup=5, iters=20):
    """
    Time a GPU operation using CUDA events for better accuracy (no CPU scheduling noise).
    - fn: a callable that launches GPU work
    - sync: whether to synchronize after each iteration (True recommended)
    - warmup: warm-up iterations to let JIT/caches settle
    - iters: timed iterations

    Returns: average time in milliseconds over 'iters' runs.
    """
    # warm-up does JIT and warms caches
    for _ in range(warmup):
        fn()
    if sync:
        torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    elapsed_ms = 0.0
    for _ in range(iters):
        start.record()
        fn()
        end.record()
        # Wait for the events to be recorded & measure GPU time
        torch.cuda.synchronize()
        elapsed_ms += start.elapsed_time(end)
    return elapsed_ms / iters


# ---------------
# 3) Driver code
# ---------------
def main():
    assert torch.cuda.is_available(), "CUDA device not found. Please run on a machine with an NVIDIA GPU."

    # Problem size: try something reasonably large to see speed differences.
    N = 1 << 24  # ~16 million elements

    # Allocate inputs/outputs directly on the GPU.
    # Triton interoperates smoothly with PyTorch tensors as raw pointers are passed under the hood.
    a = torch.rand(N, device="cuda", dtype=torch.float32)
    b = torch.rand(N, device="cuda", dtype=torch.float32)
    c = torch.empty_like(a)

    # Choose how much work each program handles. Typical sizes: 128/256/512/1024.
    # Larger BLOCK_SIZE improves memory coalescing and reduces launch overhead but may reduce occupancy
    # if register/shared-mem pressure is high (not a big concern for this simple kernel).
    BLOCK_SIZE = 1024

    # Grid definition: how many programs to launch? We need one program per "tile" of size BLOCK_SIZE.
    # triton.cdiv(x, y) = ceil_div(x, y).
    grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)

    # 3.1) Correctness check ---------------------------------------------------
    vector_add_kernel[grid](a, b, c, N, BLOCK_SIZE=BLOCK_SIZE)
    # Synchronize to ensure kernel has finished before checking results.
    torch.cuda.synchronize()
    ok = torch.allclose(c, a + b, rtol=1e-5, atol=1e-6)
    print(f"[Correctness] Triton result matches PyTorch: {ok}")

    # 3.2) Performance: Triton vs PyTorch -------------------------------------
    # Define a launcher for Triton (so our timing helper can call it repeatedly).
    def launch_triton():
        vector_add_kernel[grid](a, b, c, N, BLOCK_SIZE=BLOCK_SIZE)

    # PyTorch baseline (also launches a highly optimized GPU kernel internally).
    def launch_torch():
        # We write to c to emulate the same output pattern as Triton; in practice you might reuse ref = a + b.
        c.copy_(a + b)

    triton_ms = time_op_gpu(launch_triton)
    torch_ms = time_op_gpu(launch_torch)

    # Throughput: number of bytes moved per second.
    # Each element reads A[i] and B[i] (2 * 4 bytes) and writes C[i] (4 bytes) => 12 bytes/element.
    bytes_moved = N * 12
    triton_bw = bytes_moved / (triton_ms / 1e3) / 1e9  # GB/s
    torch_bw = bytes_moved / (torch_ms / 1e3) / 1e9    # GB/s

    print(f"[Perf] Triton: {triton_ms:.3f} ms  (~{triton_bw:.1f} GB/s)")
    print(f"[Perf] PyTorch: {torch_ms:.3f} ms  (~{torch_bw:.1f} GB/s)")

    # 3.3) (Optional) Try different BLOCK_SIZE to see the effect on performance
    # for bs in [128, 256, 512, 1024, 2048]:
    #     grid2 = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),)
    #     def launch():
    #         vector_add_kernel[grid2](a, b, c, N, BLOCK_SIZE=bs)
    #     t_ms = time_op_gpu(launch)
    #     bw = bytes_moved / (t_ms / 1e3) / 1e9
    #     print(f"BLOCK_SIZE={bs:4d}  ->  {t_ms:.3f} ms  (~{bw:.1f} GB/s)")


if __name__ == "__main__":
    main()




[Correctness] Triton result matches PyTorch: True
[Perf] Triton: 0.660 ms  (~304.9 GB/s)
[Perf] PyTorch: 1.087 ms  (~185.1 GB/s)
