In [1]:
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, z_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise addition: writes x + y to z for one block of elements.

    Each program instance handles the BLOCK_SIZE consecutive offsets that
    start at ``program_id * BLOCK_SIZE``. Offsets at or beyond ``n_elements``
    are masked out of both the loads and the store.
    """
    # Index of this program instance along launch axis 0 (one per block).
    block_id = tl.program_id(axis=0)
    # Element offsets covered by this block.
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Bounds check: only offsets below n_elements are touched.
    in_bounds = idx < n_elements
    # Load both operands under the mask.
    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)
    # Store the sum under the same mask.
    tl.store(z_ptr + idx, lhs + rhs, mask=in_bounds)

In [2]:
import torch
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
    # Pointers to the input matrices A, B and the output matrix C
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions: A is indexed as (M, K), B as (K, N), C as (M, N)
    M, N, K,
    # Strides (in elements) for each dimension of A, B and C
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Tile sizes (compile-time constants)
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C = A @ B.

    The kernel expects a 1D launch grid with
    ``cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N)`` programs; the program id
    is decoded row-major into a (tile_row, tile_col) pair. Partial tiles along
    M, N and K are handled by masking, so none of the dimensions has to be a
    multiple of its tile size. Products are accumulated in float32.
    """
    pid = tl.program_id(axis=0)
    # Number of tiles along N. This must be a ceiling division so the decoding
    # agrees with a cdiv-based launch grid; floor division maps programs to the
    # wrong tiles when N is not a multiple of BLOCK_SIZE_N and yields a zero
    # divisor when N < BLOCK_SIZE_N.
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    # Row / column indices of the output tile handled by this program, and the
    # indices of the first K slice.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Validity masks for partial tiles at the M / N edges. They guard the loads
    # as well as the final store so no out-of-range address is dereferenced.
    row_mask = offs_m[:, None] < M
    col_mask = offs_n[None, :] < N

    # Pointers to the first (BLOCK_SIZE_M, BLOCK_SIZE_K) tile of A and the
    # first (BLOCK_SIZE_K, BLOCK_SIZE_N) tile of B.
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    # float32 accumulator for the output tile.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Walk the K dimension one BLOCK_SIZE_K slice at a time.
    for k in range(0, K, BLOCK_SIZE_K):
        # Masked-out elements are read as 0.0 and therefore add nothing.
        k_in_row = offs_k[None, :] + k < K
        k_in_col = offs_k[:, None] + k < K
        a = tl.load(a_ptrs, mask=row_mask & k_in_row, other=0.0)
        b = tl.load(b_ptrs, mask=k_in_col & col_mask, other=0.0)
        accumulator += tl.dot(a, b)
        # Advance both tile pointers to the next K slice.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Write the tile back, skipping rows / columns outside the matrix.
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, accumulator, mask=row_mask & col_mask)

# Python wrapper around the Triton kernel
def triton_matmul(a, b):
    """Multiply two 2-D tensors with ``matmul_kernel`` and return the product.

    Args:
        a: tensor of shape (M, K).
        b: tensor of shape (K, N), on the same device as ``a``.

    Returns:
        A new tensor of shape (M, N) with the dtype and device of ``a``.

    Raises:
        ValueError: if either input is not 2-D, if the inner dimensions do not
            match, or if the inputs live on different devices.
    """
    if a.dim() != 2 or b.dim() != 2:
        raise ValueError(
            f"triton_matmul expects 2-D inputs, got {a.dim()}-D and {b.dim()}-D"
        )
    M, K = a.shape
    K_b, N = b.shape
    # Without this check the kernel would run with b's K and silently produce
    # garbage (or read out of bounds) for incompatible operands.
    if K != K_b:
        raise ValueError(
            f"inner dimensions must match, got a: {tuple(a.shape)} and b: {tuple(b.shape)}"
        )
    if a.device != b.device:
        raise ValueError(
            f"inputs must be on the same device, got {a.device} and {b.device}"
        )

    # Allocate the output matrix.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # Tile sizes (tune for the target GPU). A 128x128x128 tiling needs 128 KiB
    # of float32 operand data per K step, which can exceed the shared memory
    # available to a block on many GPUs; these smaller tiles are a safer default.
    BLOCK_SIZE_M = 64
    BLOCK_SIZE_N = 64
    BLOCK_SIZE_K = 32
    # One program per output tile, laid out on a 1D grid.
    grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),)
    # Launch the Triton kernel.
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K
    )
    return c

# 测试
if __name__ == "__main__":
    # 生成随机矩阵
    a = torch.randn(1024, 512, device="cuda", dtype=torch.float32)
    b = torch.randn(512, 1024, device="cuda", dtype=torch.float32)
    # Triton矩阵乘法
    c_triton = triton_matmul(a, b)
    # PyTorch内置矩阵乘法
    c_torch = torch.matmul(a, b)
    # 验证结果正确性
    print("结果误差:", torch.max(torch.abs(c_triton - c_torch)))
    # 性能对比
    %timeit triton_matmul(a, b)
    %timeit torch.matmul(a, b)

AssertionError: Torch not compiled with CUDA enabled