In [None]:


def matmul(a, b):
    """Multiply two 2D tensors by launching ``matmul_kernel``.

    Args:
        a: 2D tensor of shape (M, K); must be contiguous.
        b: 2D tensor of shape (K, N).

    Returns:
        A newly allocated float16 tensor of shape (M, N) on ``a``'s device.

    Raises:
        AssertionError: if the inner dimensions differ or ``a`` is not
            contiguous.
    """
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    rows, inner = a.shape
    _, cols = b.shape

    result = torch.empty((rows, cols), device=a.device, dtype=torch.float16)

    # Fixed tile configuration forwarded to the kernel as keyword arguments.
    tile_config = dict(
        BLOCK_SIZE_M=128,
        BLOCK_SIZE_N=128,
        BLOCK_SIZE_K=32,
        GROUP_SIZE_M=8,
    )

    def launch_grid(meta):
        # 1D launch: one program per (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile.
        tiles_m = triton.cdiv(rows, meta['BLOCK_SIZE_M'])
        tiles_n = triton.cdiv(cols, meta['BLOCK_SIZE_N'])
        return (tiles_m * tiles_n,)

    matmul_kernel[launch_grid](
        a, b, result,
        rows, cols, inner,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        result.stride(0), result.stride(1),
        **tile_config,
    )

    return result




# Two 512x512 fp16 operands on the GPU, reused by the timing code below.
A, B = (
    torch.randn((512, 512), device='cuda', dtype=torch.float16)
    for _ in range(2)
)

# Call the Triton matmul a few times before anything is timed
# (presumably so one-time costs such as kernel compilation are already paid).
for _ in range(3):
    matmul(A, B)

import torch._dynamo as dynamo

def benchmark():
    """Time one matmul of the module-level ``A`` and ``B`` three ways.

    Measures eager ``torch.matmul``, an Inductor-compiled ``x @ y`` and the
    Triton ``matmul`` defined in this notebook, printing the elapsed GPU time
    of a single call in milliseconds for each.

    Every variant gets the same treatment (warm-up calls, device sync, one
    timed call between CUDA events) so that one-time initialisation or
    compilation cost is not attributed to only some of the contenders.
    """

    def _time_ms(fn, warmup=3):
        # Warm-up runs absorb lazy initialisation / compilation.
        for _ in range(warmup):
            fn()
        # Drain pending work so the start event is not queued behind it.
        torch.cuda.synchronize()

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        # elapsed_time is only valid once both events have completed.
        torch.cuda.synchronize()
        return start.elapsed_time(end)

    vanilla_ms = _time_ms(lambda: torch.matmul(A, B))
    print(f"PyTorch Vanilla: {vanilla_ms:.3f} ms")

    # torch.compile is the supported entry point for the Inductor backend.
    compiled_fn = torch.compile(lambda x, y: x @ y, backend="inductor")
    inductor_ms = _time_ms(lambda: compiled_fn(A, B))
    print(f"TorchInductor: {inductor_ms:.3f} ms")

    triton_ms = _time_ms(lambda: matmul(A, B))
    print(f"Triton: {triton_ms:.3f} ms")

# Run the three-way matmul timing comparison defined above and print the results.
benchmark()

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Preprocessing: convert each image to a tensor, then normalise it with the
# fixed single-channel mean / std below.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Only the training split requests a download; both splits live under ./data.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# Shuffled mini-batches of 64 for training; fixed-order batches of 1000 for evaluation.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)

In [None]:
import triton
import triton.language as tl

@triton.jit
def conv2d_fwd_naive_kernel(
    input_ptr, weight_ptr, bias_ptr, output_ptr,
    B, C_in, H_in, W_in,
    C_out, H_out, W_out,
    K_h, K_w,
    stride_h, stride_w,
    pad_h, pad_w
):
    """Naive direct 2D convolution: each program computes ONE output element.

    The 1D program id is decoded into (n, c_out, h_out, w_out). All buffers
    are addressed as flat row-major arrays:
      * input  indexed as (n, c_in, h, w)     over (B, C_in, H_in, W_in)
      * weight indexed as (c_out, c_in, r, c) over (C_out, C_in, K_h, K_w)
      * output indexed as (n, c_out, h, w)    over (B, C_out, H_out, W_out)

    Input positions that fall outside the image (padding region) contribute
    zero via a masked load. The bias is loaded unconditionally, so bias_ptr
    must point to at least C_out readable elements.
    """
    # Decode the flat program id, fastest-varying dimension first.
    pid = tl.program_id(0)
    w_out_idx = pid % W_out
    pid //= W_out
    h_out_idx = pid % H_out
    pid //= H_out
    c_out_idx = pid % C_out
    pid //= C_out
    n = pid
    # Guard for programs beyond the batch (launch grid larger than the output).
    if n >= B:
        return

    # Scalar accumulator for this output element.
    out_val = 0.0
    # Top-left corner of the receptive field in input coordinates;
    # negative when the window starts inside the padding.
    in_h_start = h_out_idx*stride_h - pad_h
    in_w_start = w_out_idx*stride_w - pad_w

    for c_in_idx in range(C_in):
        for r in range(K_h):
            for c in range(K_w):
                in_h = in_h_start + r
                in_w = in_w_start + c
                # True only when the tap lands on a real input pixel.
                in_bounds = (in_h >= 0) & (in_h < H_in) & (in_w >= 0) & (in_w < W_in)
                in_offset = (n*C_in*H_in*W_in
                             + c_in_idx*H_in*W_in
                             + in_h*W_in
                             + in_w)
                w_offset = (c_out_idx*C_in*K_h*K_w
                            + c_in_idx*K_h*K_w
                            + r*K_w
                            + c)
                # Out-of-bounds taps read 0.0 instead of touching memory.
                val_in = tl.load(input_ptr + in_offset, mask=in_bounds, other=0.0)
                val_w  = tl.load(weight_ptr + w_offset)
                out_val += val_in * val_w

    # One bias value per output channel.
    bias_val = tl.load(bias_ptr + c_out_idx)
    out_val += bias_val

    out_offset = (n*C_out*H_out*W_out
                  + c_out_idx*H_out*W_out
                  + h_out_idx*W_out
                  + w_out_idx)
    tl.store(output_ptr + out_offset, out_val)

def triton_conv2d_naive(input, weight, bias=None, stride=1, padding=0):
    """Run a 2D convolution with ``conv2d_fwd_naive_kernel``.

    Args:
        input: 4D tensor of shape (B, C_in, H_in, W_in).
        weight: 4D tensor of shape (C_out, C_in, K_h, K_w).
        bias: optional tensor with C_out elements. When ``None`` a zero bias
            is used.
        stride: int, or a (stride_h, stride_w) pair.
        padding: int, or a (pad_h, pad_w) pair of zero padding.

    Returns:
        A new tensor of shape (B, C_out, H_out, W_out) with ``input``'s
        device and dtype. The result is written by the kernel into a freshly
        allocated buffer; this wrapper records no autograd graph.

    Raises:
        AssertionError: if the channel counts of ``input`` and ``weight``
            do not match.
    """
    stride_h, stride_w = (stride, stride) if isinstance(stride, int) else stride
    pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding

    B, C_in, H_in, W_in = input.shape
    C_out, C_in2, K_h, K_w = weight.shape
    assert C_in == C_in2, "input and weight disagree on the number of input channels"

    H_out = (H_in + 2*pad_h - K_h)//stride_h + 1
    W_out = (W_in + 2*pad_w - K_w)//stride_w + 1

    out = torch.empty((B, C_out, H_out, W_out), device=input.device, dtype=input.dtype)

    # The kernel addresses flat row-major buffers; flatten() copies only when
    # the source is not already contiguous.
    in_ptr = input.flatten()
    w_ptr = weight.flatten()
    if bias is None:
        # The kernel always executes tl.load(bias_ptr + c_out_idx), so it needs
        # a real buffer here; passing the integer 0 is not a valid pointer.
        b_ptr = torch.zeros(C_out, device=input.device, dtype=input.dtype)
    else:
        b_ptr = bias.flatten()
    # view(-1) guarantees the kernel writes into `out` itself, never a copy.
    out_ptr = out.view(-1)

    # One program per output element.
    num_out_elems = B*C_out*H_out*W_out
    grid = (num_out_elems,)
    conv2d_fwd_naive_kernel[grid](
        in_ptr, w_ptr, b_ptr, out_ptr,
        B, C_in, H_in, W_in,
        C_out, H_out, W_out,
        K_h, K_w,
        stride_h, stride_w,
        pad_h, pad_w
    )
    return out

@triton.jit
def relu_kernel(x_ptr, y_ptr, N, BLOCK_SIZE: tl.constexpr):
    """Elementwise ReLU over a flat buffer of N elements.

    Each program handles one chunk of BLOCK_SIZE consecutive elements; lanes
    at or beyond N are masked out of both the load and the store.
    """
    block_id = tl.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = idx < N
    values = tl.load(x_ptr + idx, mask=in_range)
    # Keep strictly positive values, replace everything else with 0.
    rectified = tl.where(values > 0, values, 0)
    tl.store(y_ptr + idx, rectified, mask=in_range)

def triton_relu(x):
    """Apply ReLU elementwise with ``relu_kernel``.

    Args:
        x: tensor of any shape.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``x``
        holding ``max(x, 0)`` elementwise.
    """
    # relu_kernel walks raw memory linearly from the base pointer, which only
    # matches the logical element order for a contiguous tensor. contiguous()
    # is a no-op (returns x itself) when x is already contiguous.
    x = x.contiguous()
    out = torch.empty_like(x)
    N = x.numel()
    if N == 0:
        # Nothing to compute; avoid launching an empty grid.
        return out

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(N, BLOCK_SIZE),)

    relu_kernel[grid](x, out, N, BLOCK_SIZE=BLOCK_SIZE)
    return out


class TritonCNNModel(nn.Module):
    """Small CNN whose two conv+ReLU stages run through the Triton helpers.

    The convolution weights and biases are stored as bare ``nn.Parameter``
    tensors (not ``nn.Conv2d`` modules) so they can be passed directly to
    ``triton_conv2d_naive``. The classifier head uses regular ``nn.Linear``
    layers.

    Args:
        in_channels: number of channels of the input images.
        num_classes: size of the output logits vector.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()

        # First conv stage: 16 filters of shape (in_channels, 3, 3).
        self.conv1_weight = nn.Parameter(torch.randn(16, in_channels, 3, 3) * 0.01)
        self.conv1_bias = nn.Parameter(torch.zeros(16))

        # Second conv stage: 32 filters of shape (16, 3, 3).
        self.conv2_weight = nn.Parameter(torch.randn(32, 16, 3, 3) * 0.01)
        self.conv2_bias = nn.Parameter(torch.zeros(32))

        # Classifier head. 32*24*24 input features corresponds to 28x28 inputs
        # after two unpadded, stride-1 3x3 convolutions (28 -> 26 -> 24).
        self.fc1 = nn.Linear(32 * 24 * 24, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for ``x``."""
        conv_stages = (
            (self.conv1_weight, self.conv1_bias),
            (self.conv2_weight, self.conv2_bias),
        )
        for kernel, bias in conv_stages:
            x = triton_conv2d_naive(x, kernel, bias, stride=1, padding=0)
            x = triton_relu(x)

        # Collapse (C, H, W) into a single feature axis per sample.
        flat = x.view(x.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)



# Build the Triton-backed model on the target device.
model_triton = TritonCNNModel().to(device)

# Copy every parameter from the reference model so both networks compute the
# same function. no_grad() keeps the in-place copies out of autograd.
with torch.no_grad():
    for target_param, source_param in (
        (model_triton.conv1_weight, model_pt.conv1.weight),
        (model_triton.conv1_bias, model_pt.conv1.bias),
        (model_triton.conv2_weight, model_pt.conv2.weight),
        (model_triton.conv2_bias, model_pt.conv2.bias),
        (model_triton.fc1.weight, model_pt.fc1.weight),
        (model_triton.fc1.bias, model_pt.fc1.bias),
        (model_triton.fc2.weight, model_pt.fc2.weight),
        (model_triton.fc2.bias, model_pt.fc2.bias),
    ):
        target_param.copy_(source_param)

In [None]:
import time

# Take one batch from the test loader and move the images to the device.
data, target = next(iter(test_loader))
data = data.to(device)

# Inference-only comparison: the Triton convolutions write into freshly
# allocated buffers and record no autograd graph, so the PyTorch model is run
# under no_grad() as well to keep the two measurements like-for-like.
with torch.no_grad():
    # Warm-up so one-time costs are not part of the timed calls.
    for _ in range(5):
        _ = model_pt(data)
        _ = model_triton(data)

    torch.cuda.synchronize()

    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring short intervals; time.time() is neither guaranteed.
    t0 = time.perf_counter()
    _ = model_pt(data)
    torch.cuda.synchronize()  # wait for the GPU before reading the clock
    t1 = time.perf_counter()
    pytorch_time = t1 - t0

    t2 = time.perf_counter()
    _ = model_triton(data)
    torch.cuda.synchronize()
    t3 = time.perf_counter()
    triton_time = t3 - t2

print(f"PyTorch forward pass time: {pytorch_time*1e3:.3f} ms")
print(f"Triton forward pass time:  {triton_time*1e3:.3f} ms")

In [None]:
import triton
import triton.language as tl
import torch


def dropout_relu_scale_pytorch(x, p=0.5, alpha=1.0):
    """Unfused reference: inverted dropout, then ReLU, then scaling by alpha.

    Args:
        x: floating-point input tensor.
        p: drop probability; an element is kept when a uniform sample in
            [0, 1) is greater than ``p``. Kept elements are scaled by
            ``1 / (1 - p)``.
        alpha: scalar multiplier applied after the ReLU.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    # Build the mask in x's own dtype. A hard-coded float32 mask promotes
    # half-precision inputs to float32 and rounds 1/(1-p) to float32 even
    # when x is float64.
    mask = (torch.rand_like(x) > p).to(x.dtype) / (1 - p)
    y = x * mask        # dropout
    y = torch.relu(y)   # relu
    z = alpha * y       # scale
    return z


@triton.jit
def fused_dropout_relu_scale_kernel(
    x_ptr, out_ptr, seed_ptr,
    n_elements,
    p, alpha,
    BLOCK_SIZE: tl.constexpr
):
    """Fused dropout -> ReLU -> scale over a flat buffer, one block per program.

    x_ptr / out_ptr: input and output buffers of n_elements values.
    seed_ptr: pointer to a single integer seed for the random generator.
    p: drop probability; kept elements are scaled by 1 / (1 - p).
    alpha: multiplier applied after the ReLU.
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)

    # Draw one uniform sample in [0, 1) per element from Triton's counter-based
    # generator, keyed by (seed, element offset). The previous
    # (seed + offset) % 10000 / 10000 expression was a deterministic sawtooth,
    # so it dropped long contiguous runs instead of independent elements.
    seed = tl.load(seed_ptr)
    random_val = tl.rand(seed, offsets)
    keep = random_val > p
    scale = 1.0 / (1.0 - p)
    # Select instead of multiplying by the mask: dropped lanes become exactly
    # 0 even when scale is infinite (p == 1), where 0 * inf would give NaN.
    x = tl.where(keep, x * scale, 0.0)

    # ReLU
    x = tl.where(x > 0, x, 0.0)

    # scale by alpha
    x = x * alpha

    # store out
    tl.store(out_ptr + offsets, x, mask=mask)

def fused_dropout_relu_scale_triton(x, p=0.5, alpha=1.0):
    """Apply dropout, ReLU and alpha-scaling in a single Triton kernel launch.

    Args:
        x: input tensor; the kernel reads it as a flat buffer.
        p: drop probability forwarded to the kernel.
        alpha: post-ReLU multiplier forwarded to the kernel.

    Returns:
        A new tensor allocated with ``torch.empty_like(x)`` and filled by the
        kernel.
    """
    BLOCK_SIZE = 1024
    total = x.numel()
    result = torch.empty_like(x)

    # Single-element int32 tensor on x's device; the kernel loads the seed from it.
    rng_seed = torch.randint(0, 2**10, (1,), device=x.device, dtype=torch.int32)

    # Ceil-divide so a trailing partial block still gets a program.
    num_blocks = (total + BLOCK_SIZE - 1) // BLOCK_SIZE

    fused_dropout_relu_scale_kernel[(num_blocks,)](
        x, result, rng_seed,
        total,
        p, alpha,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return result

import time

x = torch.randn(10_000_000, device='cuda')

# Warm-up
for _ in range(2):
    _ = dropout_relu_scale_pytorch(x, p=0.5, alpha=1.23)
    _ = fused_dropout_relu_scale_triton(x, p=0.5, alpha=1.23)

# perf_counter() is a monotonic, high-resolution clock meant for interval
# timing. Each timed region ends with a device sync so the GPU work is
# actually finished when the clock is read.
torch.cuda.synchronize()
t0 = time.perf_counter()
y_pt = dropout_relu_scale_pytorch(x, p=0.5, alpha=1.23)
torch.cuda.synchronize()
t1 = time.perf_counter()
pt_time = t1 - t0

t2 = time.perf_counter()
y_tt = fused_dropout_relu_scale_triton(x, p=0.5, alpha=1.23)
torch.cuda.synchronize()
t3 = time.perf_counter()
tt_time = t3 - t2

print(f"PyTorch time: {pt_time*1e3:.3f} ms")
print(f"Triton time:  {tt_time*1e3:.3f} ms")

# NOTE: the two implementations draw their dropout masks independently, so a
# non-zero difference is expected here; this is not a correctness check.
print("Max diff:", (y_pt - y_tt).abs().max().item())

In [None]:
import torch
import triton
import triton.language as tl
import torch.utils.benchmark as benchmark

###############################
# 1. PyTorch reference version
###############################
def dropout_relu_scale_pytorch(x, p=0.5, alpha=1.0):
    """
    Reference implementation built from separate PyTorch ops:
      1) Dropout (inverted: kept elements are scaled by 1 / (1 - p))
      2) ReLU
      3) scale by alpha
    Each step is potentially a separate kernel in PyTorch.

    Args:
        x: floating-point input tensor.
        p: drop probability; an element is kept when a uniform sample in
            [0, 1) is greater than ``p``.
        alpha: scalar multiplier applied after the ReLU.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    # dropout -- the mask is created in x's dtype; forcing float32 would
    # promote half inputs to float32 and round 1/(1-p) to float32 for
    # float64 inputs.
    mask = (torch.rand_like(x) > p).to(x.dtype) / (1.0 - p)
    y = x * mask
    # relu
    y = torch.relu(y)
    # scale
    y = alpha * y
    return y


###############################
# 2. Triton fused kernel
###############################

@triton.jit
def fused_dropout_relu_scale_kernel(
    x_ptr, out_ptr,
    seed_int,     # <-- This will be an actual int, not a pointer
    n_elements,
    p, alpha,
    BLOCK_SIZE: tl.constexpr
):
    """Fused dropout -> ReLU -> scale over a flat buffer, one block per program.

    x_ptr / out_ptr: input and output buffers of n_elements values.
    seed_int: integer seed for the random generator (passed by value).
    p: drop probability; kept elements are scaled by 1 / (1 - p).
    alpha: multiplier applied after the ReLU.
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)

    # One uniform sample in [0, 1) per element from Triton's counter-based
    # generator, keyed by (seed, element offset). The previous
    # (seed + offset) % 100000 / 100000.0 expression was a deterministic ramp,
    # so with p = 0.5 it dropped runs of ~50000 consecutive elements rather
    # than independent elements.
    rand_val = tl.rand(seed_int, offsets)
    keep = rand_val > p
    scale = 1.0 / (1.0 - p)
    # Select rather than multiply by the mask so dropped lanes are exactly 0
    # even when scale is infinite (p == 1), where 0 * inf would give NaN.
    x = tl.where(keep, x * scale, 0.0)
    x = tl.where(x > 0, x, 0.0)
    x = alpha * x

    tl.store(out_ptr + offsets, x, mask=mask)

def fused_dropout_relu_scale_triton(x, p=0.5, alpha=1.0):
    """Apply dropout, ReLU and alpha-scaling in a single Triton kernel launch.

    Args:
        x: input tensor on the GPU.
        p: drop probability forwarded to the kernel.
        alpha: post-ReLU multiplier forwarded to the kernel.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``x``.
    """
    # The kernel walks raw memory linearly from the base pointer, which only
    # matches the logical element order for a contiguous tensor. contiguous()
    # returns x itself when no copy is needed.
    x = x.contiguous()
    out = torch.empty_like(x)
    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; avoid launching an empty grid.
        return out

    # Draw the seed on the host. Generating it on the GPU and calling .item()
    # forces a device-to-host synchronisation on every call (and hard-codes
    # the 'cuda' device), which distorts timings of this function.
    seed_val = int(torch.randint(0, 2**20, (1,), dtype=torch.int32).item())

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    fused_dropout_relu_scale_kernel[grid](
        x, out,
        seed_val,        # pass an int, not a tensor
        n_elements,
        p, alpha,
        BLOCK_SIZE=BLOCK_SIZE
    )
    return out


###############################
# 3. Benchmark with torch.utils.benchmark
###############################

def run_benchmark():
    """Compare the unfused PyTorch path with the fused Triton kernel.

    Builds a 10-million-element float tensor on the GPU, warms both
    implementations up, times each with ``torch.utils.benchmark.Timer`` and
    prints the full measurements followed by the median times.
    """
    device = torch.device("cuda")
    x = torch.randn(10_000_000, device=device)  # 10 million floats
    p = 0.5
    alpha = 1.23

    # Warm up to let GPU reach stable states
    for _ in range(5):
        dropout_relu_scale_pytorch(x, p, alpha)
        fused_dropout_relu_scale_triton(x, p, alpha)
    torch.cuda.synchronize()

    # Hand the timed statements exactly the objects created above. Importing
    # `x` from __main__ would pick up whatever global `x` another cell left
    # behind (or fail on a fresh kernel), not this function's tensor.
    bench_globals = {
        "dropout_relu_scale_pytorch": dropout_relu_scale_pytorch,
        "fused_dropout_relu_scale_triton": fused_dropout_relu_scale_triton,
        "x": x,
        "p": p,
        "alpha": alpha,
    }

    # We'll create two benchmark timers
    t_pytorch = benchmark.Timer(
        stmt="dropout_relu_scale_pytorch(x, p, alpha)",
        globals=bench_globals,
        label="Dropout+ReLU+Scale",
        sub_label="PyTorch multiple kernels",
    )

    t_triton = benchmark.Timer(
        stmt="fused_dropout_relu_scale_triton(x, p, alpha)",
        globals=bench_globals,
        label="Dropout+ReLU+Scale",
        sub_label="Triton fused kernel",
    )

    # blocked_autorange picks the number of repetitions itself, running each
    # statement for at least min_run_time seconds.
    results_pytorch = t_pytorch.blocked_autorange(min_run_time=1)
    results_triton  = t_triton.blocked_autorange(min_run_time=1)

    print(results_pytorch)
    print(results_triton)

    # We can compare the times
    print("PyTorch time (median):", results_pytorch.median)
    print("Triton  time (median):", results_triton.median)


# Run the benchmark only when this code executes as the main module
# (which is also the case inside a notebook kernel), not when imported.
if __name__ == "__main__":
    run_benchmark()