In [16]:
import torch
import random
random.seed(42)

N = 16
a = torch.randn(N, device='cuda')
b = torch.randn(N, device='cuda')
c = a + b
print(a, b, c, sep="\n")

tensor([-0.3327, -0.3802,  0.0867,  0.1939,  0.4104, -0.9106, -1.0601, -0.1699,
         0.4178, -1.8070, -1.0283,  0.2256,  0.2209, -1.0756,  0.1709, -0.6684],
       device='cuda:0')
tensor([ 0.7863,  0.1051,  0.0466, -0.1470,  1.3219,  0.5543, -0.5274,  0.7996,
         1.1139, -0.1291, -1.2053,  1.1623, -1.4873, -0.4576,  1.1796, -0.7119],
       device='cuda:0')
tensor([ 0.4535, -0.2751,  0.1332,  0.0469,  1.7324, -0.3563, -1.5874,  0.6298,
         1.5317, -1.9362, -2.2336,  1.3879, -1.2665, -1.5332,  1.3505, -1.3803],
       device='cuda:0')


In [17]:
print(a.data_ptr())

140081249648640


In [18]:
import triton
import triton.language as tl

@triton.jit
def vector_add_kernel(a_ptr, b_ptr, c_ptr):
    pass

def solve(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, N: int):
    grid = (1,) # `1` denotes the number of blocks we used, for a 16-element vector, 1 is more than enough
    vector_add_kernel[grid](a, b, c)

In [20]:
import triton
import triton.language as tl

@triton.jit
def vector_add_kernel(a_ptr, b_ptr, c_ptr):
    offsets = tl.arange(0, 16) # generates indexes [0, 1, ..., 15] for 16 elements

    a = tl.load(a_ptr + offsets) # load values from memory with their indexes, variable `a` is a Triton tensor, more specifically, the datatype of `a` is tl.tensor(float32, (16,))
    b = tl.load(b_ptr + offsets)

    c = a + b # element wise addition

    tl.store(c_ptr + offsets, c)

In [21]:
import torch
import triton
import triton.language as tl

@triton.jit
def vector_add_kernel(a_ptr, b_ptr, c_ptr):
    offsets = tl.arange(0, 16)
    a = tl.load(a_ptr + offsets)
    b = tl.load(b_ptr + offsets)
    c = a + b
    tl.store(c_ptr + offsets, c)

def solve(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, N: int):
    grid = (1,)
    vector_add_kernel[grid](a, b, c)

if __name__ == "__main__":
    N = 16
    a = torch.randn(N, device='cuda')
    b = torch.randn(N, device='cuda')
    torch_output = a + b
    triton_output = torch.empty_like(a)
    solve(a, b , triton_output, N)
    if torch.allclose(triton_output, torch_output):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")

✅ Triton and Torch match


In [22]:
""" WARNING: FOLLOWING CODE SAMPLE DEMONSTRATES A WRONG PATTERN"""
import torch
import triton
import triton.language as tl

@triton.jit
def vector_add_kernel(a_ptr, b_ptr, c_ptr):
    offsets = tl.arange(0, 16)
    a = tl.load(a_ptr + offsets)
    b = tl.load(b_ptr + offsets)
    c = a + b
    tl.store(c_ptr + offsets, c)

def solve(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, N: int):
    grid = (1,)
    vector_add_kernel[grid](a, b, c)

if __name__ == "__main__":
    N = 15 # <- the only line we edit
    a = torch.randn(N, device='cuda')
    b = torch.randn(N, device='cuda')
    torch_output = a + b
    triton_output = torch.empty_like(a)
    solve(a, b , triton_output, N)
    if torch.allclose(triton_output, torch_output):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")

✅ Triton and Torch match


In [23]:
""" WARNING: FOLLOWING CODE SAMPLE DEMONSTRATES A WRONG PATTERN"""
import torch
import triton
import triton.language as tl

@triton.jit
def vector_add_kernel(a_ptr, b_ptr, c_ptr):
    offsets = tl.arange(0, 15) # <- the line we edit
    a = tl.load(a_ptr + offsets)
    b = tl.load(b_ptr + offsets)
    c = a + b
    tl.store(c_ptr + offsets, c)

def solve(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, N: int):
    grid = (1,)
    vector_add_kernel[grid](a, b, c)

if __name__ == "__main__":
    N = 15 # <- the line we edit
    a = torch.randn(N, device='cuda')
    b = torch.randn(N, device='cuda')
    torch_output = a + b
    triton_output = torch.empty_like(a)
    solve(a, b , triton_output, N)
    if torch.allclose(triton_output, torch_output):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")

CompilationError: at 2:14:
def vector_add_kernel(a_ptr, b_ptr, c_ptr):
    offsets = tl.arange(0, 15) # <- the line we edit
              ^

In [24]:
import torch
import triton
import triton.language as tl

@triton.jit
def vector_add_kernel(a_ptr, b_ptr, c_ptr, N):
    offsets = tl.arange(0, 16)

    mask = offsets < N # size of your input vector

    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    c = a + b
    tl.store(c_ptr + offsets, c, mask=mask)

def solve(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, N: int):
    grid = (1,)
    vector_add_kernel[grid](a, b, c, N)

if __name__ == "__main__":
    for N in range(1, 16+1):
        a = torch.randn(N, device='cuda')
        b = torch.randn(N, device='cuda')
        torch_output = a + b
        triton_output = torch.empty_like(a)
        solve(a, b , triton_output, N)
        if torch.allclose(triton_output, torch_output):
            print("✅ Triton and Torch match")
        else:
            print("❌ Triton and Torch differ")

✅ Triton and Torch match
✅ Triton and Torch match
✅ Triton and Torch match
✅ Triton and Torch match
✅ Triton and Torch match
✅ Triton and Torch match
✅ Triton and Torch match
✅ Triton and Torch match
✅ Triton and Torch match
✅ Triton and Torch match
✅ Triton and Torch match
✅ Triton and Torch match
✅ Triton and Torch match
✅ Triton and Torch match
✅ Triton and Torch match
✅ Triton and Torch match


In [29]:
import time
import torch
import triton
import triton.language as tl

@triton.jit
def vector_add_kernel(a_ptr, b_ptr, c_ptr, N):

    pid = tl.program_id(axis=0) # pid is a unique ID for each Thread Block

    block_start = pid * 16 # slicing data for each block

    offsets = block_start + tl.arange(0, 16)
    mask = offsets < N
    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    c = a + b
    tl.store(c_ptr + offsets, c, mask=mask)

def solve(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, N: int):
    grid = (triton.cdiv(N, 16), )
    vector_add_kernel[grid](a, b, c, N)


def time_op_gpu(fn, sync=True, warmup=5, iters=20):
    """
    Time a GPU operation using CUDA events for better accuracy (no CPU scheduling noise).
    - fn: a callable that launches GPU work
    - sync: whether to synchronize after each iteration (True recommended)
    - warmup: warm-up iterations to let JIT/caches settle
    - iters: timed iterations

    Returns: average time in milliseconds over 'iters' runs.
    """
    # warm-up does JIT and warms caches
    for _ in range(warmup):
        fn()
    if sync:
        torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    elapsed_ms = 0.0
    for _ in range(iters):
        start.record()
        fn()
        end.record()
        # Wait for the events to be recorded & measure GPU time
        torch.cuda.synchronize()
        elapsed_ms += start.elapsed_time(end)
    return elapsed_ms / iters

if __name__ == "__main__":
    for power in range(1, 25 ,2):
        N = 2 ** power
        N = 1 << 24
        a = torch.randn(N, device='cuda')
        b = torch.randn(N, device='cuda')
        triton_output = torch.empty_like(a)

        def torch_op():
            return a + b

        def triton_op():
            triton_output = torch.empty_like(a)
            solve(a, b, triton_output, N)
            return triton_output

        torch_output = torch_op()  # warm-up
        torch_time_elapsed = time_op_gpu(torch_op)

        triton_output = triton_op()  # warm-up
        triton_time_elapsed = time_op_gpu(triton_op)

        if torch.allclose(triton_output, torch_output):
            print(f"✅ Triton and Torch match with input size 2^{power}")
            print(f"Torch  time: {torch_time_elapsed:.5f} ms, \nTriton time: {triton_time_elapsed:.5f} ms")
        else:
            print(f"❌ Triton and Torch differ with input size 2^{power}")

        print("grid size: ", triton.cdiv(N, 16), "\n")

✅ Triton and Torch match with input size 2^1
Torch  time: 1.33550 ms, 
Triton time: 4.31417 ms
grid size:  1048576 

✅ Triton and Torch match with input size 2^3
Torch  time: 1.61238 ms, 
Triton time: 3.68864 ms
grid size:  1048576 

✅ Triton and Torch match with input size 2^5
Torch  time: 1.14208 ms, 
Triton time: 3.15389 ms
grid size:  1048576 

✅ Triton and Torch match with input size 2^7
Torch  time: 1.25037 ms, 
Triton time: 2.98867 ms
grid size:  1048576 

✅ Triton and Torch match with input size 2^9
Torch  time: 1.57355 ms, 
Triton time: 2.96891 ms
grid size:  1048576 

✅ Triton and Torch match with input size 2^11
Torch  time: 1.12892 ms, 
Triton time: 2.86296 ms
grid size:  1048576 

✅ Triton and Torch match with input size 2^13
Torch  time: 1.80937 ms, 
Triton time: 3.51826 ms
grid size:  1048576 

✅ Triton and Torch match with input size 2^15
Torch  time: 1.25404 ms, 
Triton time: 2.88556 ms
grid size:  1048576 

✅ Triton and Torch match with input size 2^17
Torch  time: 1.

In [32]:
import time
import torch
import triton
import triton.language as tl

@triton.jit
def vector_add_kernel(a_ptr, b_ptr, c_ptr, N, BLOCK: tl.constexpr):

    pid = tl.program_id(axis=0) # pid is a unique ID for each Thread Block

    # block_start = pid * 16 # slicing data for each block
    # offsets = block_start + tl.arange(0, 16)

    offsets = pid * BLOCK + tl.arange(0, BLOCK)

    mask = offsets < N
    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    c = a + b
    tl.store(c_ptr + offsets, c, mask=mask)

def solve(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, N: int):
    BLOCK = 1024
    grid = (triton.cdiv(N, BLOCK), )
    vector_add_kernel[grid](a, b, c, N, BLOCK=BLOCK, num_warps=4)


def time_op_gpu(fn, sync=True, warmup=5, iters=20):
    """
    Time a GPU operation using CUDA events for better accuracy (no CPU scheduling noise).
    - fn: a callable that launches GPU work
    - sync: whether to synchronize after each iteration (True recommended)
    - warmup: warm-up iterations to let JIT/caches settle
    - iters: timed iterations

    Returns: average time in milliseconds over 'iters' runs.
    """
    # warm-up does JIT and warms caches
    for _ in range(warmup):
        fn()
    if sync:
        torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    elapsed_ms = 0.0
    for _ in range(iters):
        start.record()
        fn()
        end.record()
        # Wait for the events to be recorded & measure GPU time
        torch.cuda.synchronize()
        elapsed_ms += start.elapsed_time(end)
    return elapsed_ms / iters

if __name__ == "__main__":
    for power in range(1, 25, 2):
        N = 2 ** power
        a = torch.randn(N, device='cuda')
        b = torch.randn(N, device='cuda')
        torch_output  = torch.empty_like(a)
        triton_output = torch.empty_like(a)

        def torch_op():
            return torch_output.copy_(a + b)

        def triton_op():
            triton_output = torch.empty_like(a)
            solve(a, b, triton_output, N)
            return triton_output

        torch_output = torch_op()  # warm-up
        torch_time_elapsed = time_op_gpu(torch_op)

        triton_output = triton_op()  # warm-up
        triton_time_elapsed = time_op_gpu(triton_op)

        if torch.allclose(triton_output, torch_output):
            print(f"✅ Triton and Torch match with input size 2^{power}")
            print(f"Torch  time: {torch_time_elapsed:.5f} ms, \nTriton time: {triton_time_elapsed:.5f} ms")
        else:
            print(f"❌ Triton and Torch differ with input size 2^{power}")

        print("grid size: ", triton.cdiv(N, 16), "\n")

✅ Triton and Torch match with input size 2^1
Torch  time: 0.17645 ms, 
Triton time: 0.09037 ms
grid size:  1 

✅ Triton and Torch match with input size 2^3
Torch  time: 0.04901 ms, 
Triton time: 0.18940 ms
grid size:  1 

✅ Triton and Torch match with input size 2^5
Torch  time: 0.31917 ms, 
Triton time: 0.07534 ms
grid size:  2 

✅ Triton and Torch match with input size 2^7
Torch  time: 0.34054 ms, 
Triton time: 0.06917 ms
grid size:  8 

✅ Triton and Torch match with input size 2^9
Torch  time: 0.04664 ms, 
Triton time: 0.19412 ms
grid size:  32 

✅ Triton and Torch match with input size 2^11
Torch  time: 0.07051 ms, 
Triton time: 0.06189 ms
grid size:  128 

✅ Triton and Torch match with input size 2^13
Torch  time: 0.23926 ms, 
Triton time: 0.08479 ms
grid size:  512 

✅ Triton and Torch match with input size 2^15
Torch  time: 0.07361 ms, 
Triton time: 0.07344 ms
grid size:  2048 

✅ Triton and Torch match with input size 2^17
Torch  time: 0.05223 ms, 
Triton time: 0.06151 ms
grid 