110 changes: 110 additions & 0 deletions kernel_perf_agent/examples/gemm.py
@@ -0,0 +1,110 @@
"""Matrix Multiplication"""
import torch
import triton
import triton.language as tl
from kernel_opt import kernel_opt

BLOCK_SIZE_M = 32
BLOCK_SIZE_N = 32
BLOCK_SIZE_K = 32
DEVICE = triton.runtime.driver.active.get_active_torch_device()

@triton.jit
def matmul_kernel(
a_ptr,
b_ptr,
c_ptr,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
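    # Map the 1D program id to a 2D output-tile coordinate (row-major over tiles).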
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n
# Range of pointers for loading the block of A and B.
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
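    # Sweep the K dimension one BLOCK_SIZE_K slice at a time, accumulating partial products in float32.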
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
accumulator = tl.dot(a, b, accumulator)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
c = accumulator.to(tl.float16)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)


@kernel_opt(
func_prompt="an unoptimized matrix-matrix multiplication",
    opt_prompt="Integrate a persistent programming style",
model="gpt-3.5-turbo",
dsl="Triton",
kernel_name="gemm",
debug=True,
)
def host(a, b):
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
M, K = a.shape
K, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=torch.float16)

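    # 1D launch grid: one program per (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile.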
grid = lambda META: (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
matmul_kernel[grid](
a,
b,
c,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
)
return c

def main():
"""Run a simple test of the matrix multiplication."""
torch.manual_seed(0)
M, K, N = 128, 128, 128
x = torch.randn(M, K, device=DEVICE, dtype=torch.float16)
    y = torch.randn(K, N, device=DEVICE, dtype=torch.float16)

    # Run the Triton kernel and the PyTorch reference
triton_output = host(x, y)
torch_output = torch.matmul(x, y)
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0):
print("✅ Triton and Torch match")
else:
print("❌ Triton and Torch differ")

if __name__ == "__main__":
main()
106 changes: 106 additions & 0 deletions kernel_perf_agent/examples/gemm_sw.py
@@ -0,0 +1,106 @@
"""Matrix Multiplication with PID Swizzling"""
import torch
import triton
import triton.language as tl
from kernel_opt import kernel_opt

BLOCK_SIZE_M = 128
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = 128
GROUP_SIZE_M = 8
DEVICE = triton.runtime.driver.active.get_active_torch_device()

@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M):
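    # Swizzle the linear tile id so that GROUP_SIZE_M rows of output tiles are
    # launched together, improving L2 reuse of the loaded A and B blocks.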
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
return pid_m, pid_n

@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
pid_m, pid_n = _compute_pid(pid, num_pid_in_group, num_pid_m, GROUP_SIZE_M)

# Range of pointers for loading the block of A and B.
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
accumulator = tl.dot(a, b, accumulator)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
c = accumulator.to(tl.float16)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)

@kernel_opt(
func_prompt="a general matrix-matrix multiplication with PID swizzling",
opt_prompt="Integrate on-device tensor memory accelerator",
model="gpt-3.5-turbo",
dsl="Triton",
kernel_name="gemm_pid_swizzle",
debug=True,
)
def matmul(a, b):
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
M, K = a.shape
K, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=torch.float16)

grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
GROUP_SIZE_M=GROUP_SIZE_M
)
return c

def main():
"""Run a simple test of the matrix multiplication."""
torch.manual_seed(0)
M, K, N = 512, 512, 512
x = torch.randn(M, K, device=DEVICE, dtype=torch.float16)
    y = torch.randn(K, N, device=DEVICE, dtype=torch.float16)

    # Run the Triton kernel and the PyTorch reference
triton_output = matmul(x, y)
torch_output = torch.matmul(x, y)
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0):
print("✅ Triton and Torch match")
else:
print("❌ Triton and Torch differ")

if __name__ == "__main__":
main()
170 changes: 170 additions & 0 deletions kernel_perf_agent/examples/grouped_gemm.py
@@ -0,0 +1,170 @@
"""Grouped Matrix Multiplication"""
import torch
import triton
import triton.language as tl
from kernel_opt import kernel_opt

BLOCK_SIZE_M = 128
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = 128
NUM_SM = 128
DEVICE = triton.runtime.driver.active.get_active_torch_device()

@triton.jit
def grouped_matmul_kernel(
group_a_ptrs,
group_b_ptrs,
group_c_ptrs,
group_gemm_sizes,
g_lds,
group_size,
NUM_SM: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
tile_idx = tl.program_id(0)
last_problem_end = 0
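    # Persistent-kernel style: each of the NUM_SM programs walks over the tiles of
    # every GEMM in the group, advancing its tile index by NUM_SM per iteration.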
for g in range(group_size):
# get the gemm size
gm = tl.load(group_gemm_sizes + g * 3)
gn = tl.load(group_gemm_sizes + g * 3 + 1)
gk = tl.load(group_gemm_sizes + g * 3 + 2)
num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles
# iterate through the tiles in the current gemm problem
while tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
# pick up a tile from the current gemm problem
k = gk
lda = tl.load(g_lds + g * 3)
ldb = tl.load(g_lds + g * 3 + 1)
ldc = tl.load(g_lds + g * 3 + 2)
a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))
b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))
c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))
# tile coordinates
tile_idx_in_gemm = tile_idx - last_problem_end
tile_m_idx = tile_idx_in_gemm // num_n_tiles
tile_n_idx = tile_idx_in_gemm % num_n_tiles

# regular gemm
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
tl.multiple_of(a_ptrs, [16, 16])
tl.multiple_of(b_ptrs, [16, 16])
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K * ldb
c = accumulator.to(tl.float16)

offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]

# assumes full tile for now
tl.store(c_ptrs, c)

# go to the next tile by advancing NUM_SM
tile_idx += NUM_SM

# go to the next gemm problem
last_problem_end = last_problem_end + num_tiles

@kernel_opt(
func_prompt="a grouped of matrix matrix multiplications",
opt_prompt="Integrate on-device tensor memory accelerator",
model="gpt-3.5-turbo",
dsl="Triton",
kernel_name="group_gemm",
debug=True,
)
def group_gemm_fn(group_A, group_B):
assert len(group_A) == len(group_B)
group_size = len(group_A)

A_addrs = []
B_addrs = []
C_addrs = []
g_sizes = []
g_lds = []
group_C = []
for i in range(group_size):
A = group_A[i]
B = group_B[i]
assert A.shape[1] == B.shape[0]
M, K = A.shape
K, N = B.shape
C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
group_C.append(C)
A_addrs.append(A.data_ptr())
B_addrs.append(B.data_ptr())
C_addrs.append(C.data_ptr())
g_sizes += [M, N, K]
g_lds += [A.stride(0), B.stride(0), C.stride(0)]

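    # Upload the per-problem pointers, sizes, and leading dimensions to the device.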
d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)

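    # Launch exactly NUM_SM persistent programs; each strides through the tile space.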
grid = lambda META: (META["NUM_SM"],)
grouped_matmul_kernel[grid](
d_a_ptrs,
d_b_ptrs,
d_c_ptrs,
d_g_sizes,
d_g_lds,
group_size,
NUM_SM=NUM_SM,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
)

return group_C

def main():
"""Run a simple test of the grouped matrix multiplication."""
# Test the kernel
group_m = [1024, 512, 256, 128]
group_n = [1024, 512, 256, 128]
group_k = [1024, 512, 256, 128]
group_A = []
group_B = []
group_B_T = []
assert len(group_m) == len(group_n)
assert len(group_n) == len(group_k)
group_size = len(group_m)
for i in range(group_size):
M = group_m[i]
N = group_n[i]
K = group_k[i]
A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)
B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)
B_T = B.T.contiguous()
group_A.append(A)
group_B.append(B)
group_B_T.append(B_T)

tri_out = group_gemm_fn(group_A, group_B)
ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)]
passed = True
for i in range(group_size):
if not torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=1e-2):
print("❌ Triton and Torch differ")
passed = False
break
if passed:
print("✅ Triton and Torch match")

if __name__ == "__main__":
main()