# Tutorial 1: Vector Addition

In [4]:
import torch
import triton
import triton.language as tl

In [5]:
@triton.jit
def add_kernel(
    x_ptr,  # 1st vector pointer
    y_ptr,  # 2nd vector pointer
    output_ptr,  # output vector pointer
    n_elements,  # vector size
    BLOCK_SIZE: tl.constexpr,  # number of elements each program should process
):
    pid = tl.program_id(axis=0)

    block_start = pid * BLOCK_SIZE  #
    offsets = block_start + tl.arange(0, BLOCK_SIZE)  # list of pointers
    mask = offsets < n_elements  # make sure n_elements is a multiple of BLOCK_SIZE
    
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y

    tl.store(output_ptr + offsets, output, mask=mask)



In [14]:
def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    
    n_elements = output.numel()
    
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    
    return output

In [17]:
n_elements = 1024
x = torch.randn(n_elements, device='cuda')
y = torch.randn(n_elements, device='cuda')
print(x[0])
print(y[0])
print(x[0] + y[0])
add(x, y)

tensor(-0.2313, device='cuda:0')
tensor(0.6526, device='cuda:0')
tensor(0.4213, device='cuda:0')


tensor([ 0.4213, -0.2940,  0.8129,  ...,  0.7071,  0.7383,  1.1865],
       device='cuda:0')

In [21]:
# Benchmark across input sizes
batch_sizes = [2**10, 2**12, 2**14, 2**16, 2**18, 2**20]  # 1K to 1M elements
results = []

for size in batch_sizes:
    x = torch.randn(size, device='cuda')
    y = torch.randn(size, device='cuda')

    # record timing
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    add(x, y)
    end_event.record()
    # pause python and wait for kernel to finish
    torch.cuda.synchronize()

    elapsed_ms = start_event.elapsed_time(end_event)
    results.append((size, elapsed_ms))
    print(f"Size: {size:,} -> {elapsed_ms:.3f} ms")

# save to CSV
import csv
with open("triton_batch_benchmark.csv", "w") as f:
    writer = csv.writer(f)
    writer.writerow(["Batch Size", "Time (ms)"])
    writer.writerows(results)

Size: 1,024 -> 0.128 ms
Size: 4,096 -> 0.088 ms
Size: 16,384 -> 0.074 ms
Size: 65,536 -> 0.069 ms
Size: 262,144 -> 0.069 ms
Size: 1,048,576 -> 0.071 ms
