From 0364f01135f97839868276c6dbc0151c5c229d9d Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Sun, 2 Nov 2025 22:02:54 -0800 Subject: [PATCH] Add kernel performance agent with optimization framework - Add kernel_perf_agent module with decorator-based optimization - Include database of code samples (matmul, matadd, TMA examples) - Add retriever for relevant optimization patterns - Add verifier for numeric and benchmark validation - Add examples for gemm, gemm_sw, and grouped_gemm - Update pyproject.toml with new dependencies - Update agent, providers, and worker with perf integration --- kernel_perf_agent/examples/gemm.py | 110 ++++++++ kernel_perf_agent/examples/gemm_sw.py | 106 ++++++++ kernel_perf_agent/examples/grouped_gemm.py | 170 ++++++++++++ .../kernel_opt/database/__init__.py | 0 kernel_perf_agent/kernel_opt/database/base.py | 187 ++++++++++++++ .../database/code_samples/matadd.py | 60 +++++ .../database/code_samples/matadd_perst.py | 66 +++++ .../code_samples/matadd_tma_device.py | 81 ++++++ .../database/code_samples/matadd_tma_host.py | 56 ++++ .../database/code_samples/matmul.py | 84 ++++++ .../database/code_samples/matmul_sw.py | 91 +++++++ .../database/code_samples/matmul_tma_host.py | 65 +++++ .../database/docs/experimental_tma.md | 153 +++++++++++ .../kernel_opt/database/docs/on_device_tma.py | 42 +++ .../kernel_opt/database/docs/on_host_tma.py | 36 +++ .../kernel_opt/database/docs/persistence.py | 29 +++ .../kernel_opt/database/docs/pid_swizzle.py | 24 ++ .../kernel_opt/database/docs/tma.md | 150 +++++++++++ .../kernel_opt/decorator/__init__.py | 3 + .../kernel_opt/decorator/agent.py | 152 +++++++++++ .../kernel_opt/decorator/kernel_opt.py | 183 +++++++++++++ .../kernel_opt/prompts/__init__.py | 0 .../kernel_opt/prompts/prompt_manager.py | 109 ++++++++ .../prompts/rewrite_prompt_template.py | 21 ++ .../kernel_opt/retriever/__init__.py | 0 .../kernel_opt/retriever/retriever.py | 131 ++++++++++ .../kernel_opt/utils/__init__.py | 0 .../kernel_opt/utils/debug_util.py | 13 + kernel_perf_agent/kernel_opt/utils/io_util.py | 16 ++ .../kernel_opt/utils/logging_util.py | 30 +++ .../kernel_opt/utils/parser_util.py | 80 ++++++ .../kernel_opt/utils/proxy_util.py | 57 +++++ .../kernel_opt/verifier/__init__.py | 0 kernel_perf_agent/kernel_opt/verifier/base.py | 24 ++ .../kernel_opt/verifier/benchmark.py | 54 ++++ .../kernel_opt/verifier/numeric.py | 57 +++++ .../kernel_opt/verifier/verifier.py | 51 ++++ pyproject.toml | 4 +- triton_kernel_agent/agent.py | 11 +- triton_kernel_agent/providers/openai_base.py | 21 +- triton_kernel_agent/worker.py | 242 +++++++++++++++++- 41 files changed, 2751 insertions(+), 18 deletions(-) create mode 100644 kernel_perf_agent/examples/gemm.py create mode 100644 kernel_perf_agent/examples/gemm_sw.py create mode 100644 kernel_perf_agent/examples/grouped_gemm.py create mode 100644 kernel_perf_agent/kernel_opt/database/__init__.py create mode 100644 kernel_perf_agent/kernel_opt/database/base.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matadd.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matmul.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py create mode 100644 kernel_perf_agent/kernel_opt/database/docs/experimental_tma.md create mode 100644 kernel_perf_agent/kernel_opt/database/docs/on_device_tma.py create mode 100644 kernel_perf_agent/kernel_opt/database/docs/on_host_tma.py create mode 100644 kernel_perf_agent/kernel_opt/database/docs/persistence.py create mode 100644 kernel_perf_agent/kernel_opt/database/docs/pid_swizzle.py create mode 100644 kernel_perf_agent/kernel_opt/database/docs/tma.md create mode 100644 kernel_perf_agent/kernel_opt/decorator/__init__.py create mode 100644 kernel_perf_agent/kernel_opt/decorator/agent.py create mode 100644 kernel_perf_agent/kernel_opt/decorator/kernel_opt.py create mode 100644 kernel_perf_agent/kernel_opt/prompts/__init__.py create mode 100644 kernel_perf_agent/kernel_opt/prompts/prompt_manager.py create mode 100644 kernel_perf_agent/kernel_opt/prompts/rewrite_prompt_template.py create mode 100644 kernel_perf_agent/kernel_opt/retriever/__init__.py create mode 100644 kernel_perf_agent/kernel_opt/retriever/retriever.py create mode 100644 kernel_perf_agent/kernel_opt/utils/__init__.py create mode 100644 kernel_perf_agent/kernel_opt/utils/debug_util.py create mode 100644 kernel_perf_agent/kernel_opt/utils/io_util.py create mode 100644 kernel_perf_agent/kernel_opt/utils/logging_util.py create mode 100644 kernel_perf_agent/kernel_opt/utils/parser_util.py create mode 100644 kernel_perf_agent/kernel_opt/utils/proxy_util.py create mode 100644 kernel_perf_agent/kernel_opt/verifier/__init__.py create mode 100644 kernel_perf_agent/kernel_opt/verifier/base.py create mode 100644 kernel_perf_agent/kernel_opt/verifier/benchmark.py create mode 100644 kernel_perf_agent/kernel_opt/verifier/numeric.py create mode 100644 kernel_perf_agent/kernel_opt/verifier/verifier.py diff --git a/kernel_perf_agent/examples/gemm.py b/kernel_perf_agent/examples/gemm.py new file mode 100644 index 0000000..ef9d5ef --- /dev/null +++ b/kernel_perf_agent/examples/gemm.py @@ -0,0 +1,110 @@ +"""Matrix Multiplication""" +import torch +import triton +import triton.language as tl +from kernel_opt import kernel_opt + +BLOCK_SIZE_M = 32 +BLOCK_SIZE_N = 32 +BLOCK_SIZE_K = 32 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +@triton.jit +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + # Range of pointers for loading the block of A and B. + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +@kernel_opt( + func_prompt="an unoptimized matrix-matrix multiplication", + opt_prompt="Integrate persistent programming style", + model="gpt-3.5-turbo", + dsl="Triton", + kernel_name="gemm", + debug=True, +) +def host(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + return c + +def main(): + """Run a simple test of the matrix multiplication.""" + torch.manual_seed(0) + M, K, N = 128, 128, 128 + x = torch.randn(M, K, device=DEVICE, dtype=torch.float16) + y = torch.randn(N, K, device=DEVICE, dtype=torch.float16) + + # Run reference implementation + triton_output = host(x, y) + torch_output = torch.matmul(x, y) + if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0): + print("✅ Triton and Torch match") + else: + print("❌ Triton and Torch differ") + +if __name__ == "__main__": + main() diff --git a/kernel_perf_agent/examples/gemm_sw.py b/kernel_perf_agent/examples/gemm_sw.py new file mode 100644 index 0000000..95d5770 --- /dev/null +++ b/kernel_perf_agent/examples/gemm_sw.py @@ -0,0 +1,106 @@ +"""Matrix Multiplication with PID Swizzling""" +import torch +import triton +import triton.language as tl +from kernel_opt import kernel_opt + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +BLOCK_SIZE_K = 128 +GROUP_SIZE_M = 8 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +@triton.jit +def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + return pid_m, pid_n + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + pid_m, pid_n = _compute_pid(pid, num_pid_in_group, num_pid_m, GROUP_SIZE_M) + + # Range of pointers for loading the block of A and B. + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + +@kernel_opt( + func_prompt="a general matrix-matrix multiplication with PID swizzling", + opt_prompt="Integrate on-device tensor memory accelerator", + model="gpt-3.5-turbo", + dsl="Triton", + kernel_name="gemm_pid_swizzle", + debug=True, +) +def matmul(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + matmul_kernel[grid]( + a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M + ) + return c + +def main(): + """Run a simple test of the matrix multiplication.""" + torch.manual_seed(0) + M, K, N = 512, 512, 512 + x = torch.randn(M, K, device=DEVICE, dtype=torch.float16) + y = torch.randn(N, K, device=DEVICE, dtype=torch.float16) + + # Run reference implementation + triton_output = matmul(x, y) + torch_output = torch.matmul(x, y) + if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0): + print("✅ Triton and Torch match") + else: + print("❌ Triton and Torch differ") + +if __name__ == "__main__": + main() diff --git a/kernel_perf_agent/examples/grouped_gemm.py b/kernel_perf_agent/examples/grouped_gemm.py new file mode 100644 index 0000000..0f9261a --- /dev/null +++ b/kernel_perf_agent/examples/grouped_gemm.py @@ -0,0 +1,170 @@ +"""Grouped Matrix Multiplication""" +import torch +import triton +import triton.language as tl +from kernel_opt import kernel_opt + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +BLOCK_SIZE_K = 128 +NUM_SM = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +@triton.jit +def grouped_matmul_kernel( + group_a_ptrs, + group_b_ptrs, + group_c_ptrs, + group_gemm_sizes, + g_lds, + group_size, + NUM_SM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + tile_idx = tl.program_id(0) + last_problem_end = 0 + for g in range(group_size): + # get the gemm size + gm = tl.load(group_gemm_sizes + g * 3) + gn = tl.load(group_gemm_sizes + g * 3 + 1) + gk = tl.load(group_gemm_sizes + g * 3 + 2) + num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N) + num_tiles = num_m_tiles * num_n_tiles + # iterate through the tiles in the current gemm problem + while tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles: + # pick up a tile from the current gemm problem + k = gk + lda = tl.load(g_lds + g * 3) + ldb = tl.load(g_lds + g * 3 + 1) + ldc = tl.load(g_lds + g * 3 + 2) + a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16)) + b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16)) + c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16)) + # tile coordinates + tile_idx_in_gemm = tile_idx - last_problem_end + tile_m_idx = tile_idx_in_gemm // num_n_tiles + tile_n_idx = tile_idx_in_gemm % num_n_tiles + + # regular gemm + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :] + b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + tl.multiple_of(a_ptrs, [16, 16]) + tl.multiple_of(b_ptrs, [16, 16]) + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K * ldb + c = accumulator.to(tl.float16) + + offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :] + + # assumes full tile for now + tl.store(c_ptrs, c) + + # go to the next tile by advancing NUM_SM + tile_idx += NUM_SM + + # go to the next gemm problem + last_problem_end = last_problem_end + num_tiles + +@kernel_opt( + func_prompt="a grouped of matrix matrix multiplications", + opt_prompt="Integrate on-device tensor memory accelerator", + model="gpt-3.5-turbo", + dsl="Triton", + kernel_name="group_gemm", + debug=True, +) +def group_gemm_fn(group_A, group_B): + assert len(group_A) == len(group_B) + group_size = len(group_A) + + A_addrs = [] + B_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + group_C = [] + for i in range(group_size): + A = group_A[i] + B = group_B[i] + assert A.shape[1] == B.shape[0] + M, K = A.shape + K, N = B.shape + C = torch.empty((M, N), device=DEVICE, dtype=A.dtype) + group_C.append(C) + A_addrs.append(A.data_ptr()) + B_addrs.append(B.data_ptr()) + C_addrs.append(C.data_ptr()) + g_sizes += [M, N, K] + g_lds += [A.stride(0), B.stride(0), C.stride(0)] + + d_a_ptrs = torch.tensor(A_addrs, device=DEVICE) + d_b_ptrs = torch.tensor(B_addrs, device=DEVICE) + d_c_ptrs = torch.tensor(C_addrs, device=DEVICE) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE) + + grid = lambda META: (META["NUM_SM"],) + grouped_matmul_kernel[grid]( + d_a_ptrs, + d_b_ptrs, + d_c_ptrs, + d_g_sizes, + d_g_lds, + group_size, + NUM_SM=NUM_SM, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + + return group_C + +def main(): + """Run a simple test of the grouped matrix multiplication.""" + # Test the kernel + group_m = [1024, 512, 256, 128] + group_n = [1024, 512, 256, 128] + group_k = [1024, 512, 256, 128] + group_A = [] + group_B = [] + group_B_T = [] + assert len(group_m) == len(group_n) + assert len(group_n) == len(group_k) + group_size = len(group_m) + for i in range(group_size): + M = group_m[i] + N = group_n[i] + K = group_k[i] + A = torch.rand((M, K), device=DEVICE, dtype=torch.float16) + B = torch.rand((K, N), device=DEVICE, dtype=torch.float16) + B_T = B.T.contiguous() + group_A.append(A) + group_B.append(B) + group_B_T.append(B_T) + + tri_out = group_gemm_fn(group_A, group_B) + ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)] + passed = True + for i in range(group_size): + if not torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=1e-2): + print("❌ Triton and Torch differ") + passed = False + break + if passed: + print("✅ Triton and Torch match") + +if __name__ == "__main__": + main() diff --git a/kernel_perf_agent/kernel_opt/database/__init__.py b/kernel_perf_agent/kernel_opt/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernel_perf_agent/kernel_opt/database/base.py b/kernel_perf_agent/kernel_opt/database/base.py new file mode 100644 index 0000000..73ea0be --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/base.py @@ -0,0 +1,187 @@ +from pathlib import Path + +from kernel_perf_agent.kernel_opt.database.docs import ( + on_device_tma, + on_host_tma, + persistence, + pid_swizzle, +) + + +class OptNode: + + def __init__(self, level: int, dsl: str, opt_desc: str) -> None: + """Initialize the optimization node with the given level, description, and DSL. + :param level: int, Level in the tree + :param dsl: str, DSL used in the node + :param opt_desc: str, Description of the optimization + :param opt_parents: List[str], Parent nodes description + :param opt_children: List[OptNode], Children nodes + """ + + self.level = level # int, Level in the tree + self.dsl = dsl + self.opt_desc = opt_desc # str, Root node description + self.opt_parents = [] # List[str], Parent nodes description + self.opt_children = [] # List[OptNode], Children nodes + + def add_children(self, child_nodes): + """Adds a child node to the current node.""" + self.opt_children.extend(child_nodes) + + def remove_children(self, child_nodes): + """Removes a child node from the current node.""" + for child in child_nodes: + if child in self.opt_children: + self.opt_children.remove(child) + + def add_parents(self, parent_nodes): + """Adds a child node to the current node.""" + self.opt_parents.extend(parent_nodes) + + def remove_parents(self, parent_nodes): + """Removes a child node from the current node.""" + for parent in parent_nodes: + if parent in self.opt_parents: + self.opt_parents.remove(parent) + + def __repr__(self): + """String representation of the node for easy printing.""" + return f"OptNode at level {self.level}: ({self.opt_desc})" + + +class OptHierarchy: + + def __init__(self) -> None: + """Initialize the optimization hierarchy with the root node.""" + self.root = OptNode(level=0, dsl="text", opt_desc="root") + + def get_root(self): + return self.root + + def hard_initialize(self, common_path) -> None: + """Hard initialize the hierarchy with pre-programmed database.""" + + # Level 1 nodes - Latency, Memory, Utilization bottlenecks + optnode_latency = OptNode( + level=1, + dsl="text", + opt_desc="""To optimize compute-bound kernels, we employ techniques to reduce kernel execution latency, including: + - Persistent programming style to minimize kernel launch overhead + - Software pipelining to improve instruction-level parallelism and reduce execution time + """, + ) + optnode_memory = OptNode( + level=1, + dsl="text", + opt_desc="""To optimize memory-bound kernels, we employ techniques to improve performance, including: + - PID swizzling to enhance L2 cache locality + - Leveraging new architecture features, such as Tensor Memory Accelerator (TMA) to overlap memory transfers + with compute operations + """, + ) + optnode_utilization = OptNode( + level=1, + dsl="text", + opt_desc="""To optimize kernels that are not fully utilizing hardware resources, we employ techniques + to increase resource utilization and occupancy rates, including: + - Leveraging Tensor Memory Accelerator (TMA) to overlap memory transfers with compute operations + - Enabling warp specializations to improve instruction-level parallelism and reduce register pressure + - Autotuning to identify and apply optimal kernel configurations that maximize resource usage + """, + ) + level_1_opts = [optnode_latency, optnode_memory, optnode_utilization] + self.root.add_children(level_1_opts) + optnode_latency.add_parents([self.root]) + optnode_memory.add_parents([self.root]) + optnode_utilization.add_parents([self.root]) + + # Level 2 nodes - TMA, PID swizzling, persistent programming style + optnode_host_TMA = OptNode( + level=2, dsl="text", opt_desc=on_host_tma.ON_HOST_TMA + ) + optnode_device_TMA = OptNode( + level=2, dsl="text", opt_desc=on_device_tma.ON_DEVICE_TMA + ) + optnode_PID_swizzling = OptNode( + level=2, dsl="text", opt_desc=pid_swizzle.PID_SWIZZLE + ) + optnode_persistence = OptNode( + level=2, dsl="text", opt_desc=persistence.PERSISTENCE + ) + + optnode_latency.add_children([optnode_persistence]) + optnode_memory.add_children( + [ + optnode_host_TMA, + optnode_device_TMA, + optnode_PID_swizzling, + optnode_persistence, + ] + ) + optnode_utilization.add_children([optnode_persistence]) + + optnode_host_TMA.add_parents([optnode_memory]) + optnode_device_TMA.add_parents([optnode_memory]) + optnode_PID_swizzling.add_parents([optnode_memory]) + optnode_persistence.add_parents( + [optnode_latency, optnode_memory, optnode_utilization] + ) + + # Level 3 nodes - code example of each kernel + # common_path="../kernel_opt/database/code_samples/" + optnode_matmul = OptNode( + level=3, dsl="triton", opt_desc=Path(common_path / "matmul.py").read_text() + ) + optnode_matmul_pid_swizzling = OptNode( + level=3, + dsl="triton", + opt_desc=Path(common_path / "matmul_sw.py").read_text(), + ) + optnode_matmul_tma_host = OptNode( + level=3, + dsl="triton", + opt_desc=Path(common_path / "matmul_tma_host.py").read_text(), + ) + optnode_matadd = OptNode( + level=3, dsl="triton", opt_desc=Path(common_path / "matadd.py").read_text() + ) + optnode_matadd_persistence = OptNode( + level=3, + dsl="triton", + opt_desc=Path(common_path / "matadd_perst.py").read_text(), + ) + optnode_matadd_tma_host = OptNode( + level=3, + dsl="triton", + opt_desc=Path(common_path / "matadd_tma_host.py").read_text(), + ) + optnode_matadd_tma_device = OptNode( + level=3, + dsl="triton", + opt_desc=Path(common_path / "matadd_tma_device.py").read_text(), + ) + + optnode_host_TMA.add_children( + [ + optnode_matmul, + optnode_matmul_tma_host, + optnode_matadd, + optnode_matadd_tma_host, + ] + ) + optnode_device_TMA.add_children([optnode_matadd, optnode_matadd_tma_device]) + optnode_PID_swizzling.add_children( + [optnode_matmul, optnode_matmul_pid_swizzling] + ) + optnode_persistence.add_children([optnode_matadd, optnode_matadd_persistence]) + + optnode_matmul.add_parents([optnode_host_TMA, optnode_PID_swizzling]) + optnode_matmul_pid_swizzling.add_parents([optnode_PID_swizzling]) + optnode_matmul_tma_host.add_parents([optnode_host_TMA]) + optnode_matadd.add_parents( + [optnode_host_TMA, optnode_device_TMA, optnode_persistence] + ) + optnode_matadd_persistence.add_parents([optnode_persistence]) + optnode_matadd_tma_host.add_parents([optnode_host_TMA]) + optnode_matadd_tma_device.add_parents([optnode_device_TMA]) diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py new file mode 100644 index 0000000..2544695 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py @@ -0,0 +1,60 @@ +# ============================ unoptimized matadd ================================= +import torch +import triton +import triton.language as tl + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + M, + N, + stride_m, + stride_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + # Range of pointers for loading the block of A and B. + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + x_ptrs = x_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) + y_ptrs = y_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) + output_ptrs = output_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) + data_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + + x = tl.load(x_ptrs, mask=data_mask, other=0.0) + y = tl.load(y_ptrs, mask=data_mask, other=0.0) + output = x + y + tl.store(output_ptrs, output, mask=data_mask) + + +def add(x: torch.Tensor, y: torch.Tensor): + M, N = x.shape + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + add_kernel[grid]( + x, + y, + output, + M, + N, + x.stride(0), + x.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return output diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py new file mode 100644 index 0000000..c873231 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py @@ -0,0 +1,66 @@ +# ===================== matadd with persistent programming style ================== +import torch +import triton +import triton.language as tl + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +@triton.jit +def add_kernel(x_ptr, + y_ptr, + output_ptr, + M, + N, + stride_m, stride_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + NUM_SMS: tl.constexpr, + ): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + num_tiles = num_pid_m * num_pid_n + + # iterate over the program id with a stride of the total number of blocks + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + + # Range of pointers for loading the block of A and B. + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + x_ptrs = x_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) + y_ptrs = y_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) + output_ptrs = output_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) + data_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + + x = tl.load(x_ptrs, mask=data_mask, other=0.0) + y = tl.load(y_ptrs, mask=data_mask, other=0.0) + output = x + y + tl.store(output_ptrs, output, mask=data_mask) + +def add(x: torch.Tensor, y: torch.Tensor): + M, N = x.shape + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + + # Get the number of streaming multiprocessors and use it to launch a fixed number of blocks + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + grid = lambda meta: ( + min(NUM_SMS, triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"])), + ) + add_kernel[grid]( + x, + y, + output, + M, + N, + x.stride(0), + x.stride(1), + BLOCK_SIZE_M = BLOCK_SIZE_M, + BLOCK_SIZE_N = BLOCK_SIZE_N, + NUM_SMS = NUM_SMS + ) + return output diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py new file mode 100644 index 0000000..6af6ce4 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py @@ -0,0 +1,81 @@ +# ======== matadd with on-device Tensor Memory Accelerator (TMA) integration ========== +from typing import Optional + +import torch +import triton +import triton.language as tl + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + M, + N, + stride_m, + stride_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + # device TMA + x_desc = tl.make_tensor_descriptor( + x_ptr, + shape=[M, N], + strides=[stride_m, stride_n], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) + y_desc = tl.make_tensor_descriptor( + y_ptr, + shape=[M, N], + strides=[stride_m, stride_n], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) + output_desc = tl.make_tensor_descriptor( + output_ptr, + shape=[M, N], + strides=[stride_m, stride_n], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) + + x = x_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) + y = y_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) + output = x + y + output_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], output) + + +def add(x: torch.Tensor, y: torch.Tensor): + M, N = x.shape + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + add_kernel[grid]( + x, + y, + output, + M, + N, + x.stride(0), + x.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return output diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py new file mode 100644 index 0000000..7195774 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py @@ -0,0 +1,56 @@ +# ======== matadd with on-host Tensor Memory Accelerator (TMA) integration ========== +import torch +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +@triton.jit +def add_kernel( + x_desc, + y_desc, + output_desc, + M, + N, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + x = x_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) + y = y_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) + output = x + y + output_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], output) + + +def add(x: torch.Tensor, y: torch.Tensor): + M, N = x.shape + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + + # TMA descriptors for loading A, B and storing C + x_desc = TensorDescriptor(x, x.shape, x.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_N]) + y_desc = TensorDescriptor(y, y.shape, y.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_N]) + output_desc = TensorDescriptor( + output, output.shape, output.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_N] + ) + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + add_kernel[grid]( + x_desc, + y_desc, + output_desc, + M, + N, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return output diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py b/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py new file mode 100644 index 0000000..b79874d --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py @@ -0,0 +1,84 @@ +# ============================ unoptimized matmul ================================= +import torch +import triton +import triton.language as tl + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +BLOCK_SIZE_K = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + # Range of pointers for loading the block of A and B. + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + return c diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py new file mode 100644 index 0000000..1acfe65 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py @@ -0,0 +1,91 @@ +# ==================== matmul with PID swizzling ================================= +import torch +import triton +import triton.language as tl + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +BLOCK_SIZE_K = 128 +GROUP_SIZE_M = 8 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + # Range of pointers for loading the block of A and B. + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + ) + return c diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py new file mode 100644 index 0000000..41b8a76 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py @@ -0,0 +1,65 @@ +# ======== matmul with on-host Tensor Memory Accelerator (TMA) integration ========== +import torch +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +BLOCK_SIZE_K = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def matmul_kernel( + a_desc, + b_desc, + c_desc, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = a_desc.load([pid_m * BLOCK_SIZE_M, k * BLOCK_SIZE_K]) # TMA load of A + b = b_desc.load([k * BLOCK_SIZE_K, pid_n * BLOCK_SIZE_N]) # TMA load of B + accumulator = tl.dot(a, b, accumulator) + c = accumulator.to(tl.float16) + c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c) + + +def matmul(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + + # TMA descriptors for loading A, B and storing C + a_desc = TensorDescriptor(a, a.shape, a.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor(b, b.shape, b.stride(), [BLOCK_SIZE_K, BLOCK_SIZE_N]) + c_desc = TensorDescriptor(c, c.shape, c.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_N]) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( + a_desc, + b_desc, + c_desc, + M, + N, + K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + return c diff --git a/kernel_perf_agent/kernel_opt/database/docs/experimental_tma.md b/kernel_perf_agent/kernel_opt/database/docs/experimental_tma.md new file mode 100644 index 0000000..fb34cbc --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/experimental_tma.md @@ -0,0 +1,153 @@ +# Triton Tutorial: How to integrate NV TMA into kernels +## Background +TMA is a hardware unit introduced by the NV Hopper GPU. It takes over some of the data transfer work from softwares and thus improves the performance by freeing up warps or reducing register pressures etc. In practice, Triton kernel authors can update the kernel code by simply replacing `tl.load` and `tl.store` with TMA API calls to get this performance boost. + +## TMA APIs +TMA API is going through changes (from experimental to official) on upstream Triton. While we’re working out a plan to migrate, we’ll support the “old” experimental API that’s currently being used in our fbsource codebase. This tutorial will be based on the experimental API. + +TMA data load/store needs a TMA tensor descriptor object. The descriptor will describe the tensor address, strides, shapes etc. of the tensor to be copied (treat it as the CUDA `TensorMap` object). The descriptor itself needs to be stored somewhere. Depending on where we initialize the descriptor, we have two types of descriptors: on-host and on-device. The former allocates the memory on host memory, initializes descriptors there and then copies them by value to GMEM. The latter will allocate a big chunk of memory on GMEM, and then have each program to find their own offset and initialize descriptors there. + +To leverage TMA, we need to decide between on-host and on-device descriptors. That decision could be yet another topic. Here we quickly highlight a few key differences: +- Only on-device descriptors can handle dynamic shapes where not all programs are handling the same box size, which is typical in kernels like Jagged Flash Attention or HSTU. The reason is that on-host descriptors are initialized before kernel launch while on-device ones are initialized in the kernel where the box size is known. +- Torch Inductor, especially AOTI, currently only supports on-device descriptors +- On-device descriptors are initialized by every kernel program in the grid while on-host ones are initialized by host code so likely on-device descriptors take more compute resources +- Current on-device descriptors implementation (experimental API) might take more global memory because the number of programs is not necessarily known when allocating memory chunk for descriptors (e.g. depending on auto tuned BLOCK_SIZE_M), so we need to be conservative and allocate more memory + +Note: neither of these two types of TMA is necessarily faster than the other. It depends on actual use cases. + +Now for the sake of this tutorial we’ll start with on-device descriptors. And also we’ll use the example of copying 2d tensors as it’s the most common. + +With those premises, here’re the APIs to call: + +- Allocate memory chunk to store descriptors on host: +``` +TMA_DESC_SIZE = 128 # size in bytes used by a single descriptor, tunable +NUM_DESC_PER_PROGRAM = ... # how many different tensors to load/store by each program. e.g. 3 for GEMM `C=AB`, 4 for HSTU Q,K,V,O tensors +NUM_OF_PROGRAMS = ... # same as specified in kernel `grid`. If grid size is related to auto tune config, use a reasonable upper bound by hard coding "minimal block M size" etc. for now. +workspace = torch.empty( + TMA_DESC_SIZE * NUM_DESC_PER_PROGRAM * NUM_OF_PROGRAMS, + dtype=torch.uint8, + device="cuda",) +# then pass `workspace` to kernel +``` +- Initialize descriptor object: +``` +desc_ptr = workspace + TMA_DESC_SIZE * + TMA_DESC_SIZE * # in program offset in range [0,NUM_DESC_PER_PROGRAM) + + +tl.extra.cuda.experimental_device_tensormap_create2d( +desc_ptr=desc_ptr, +global_address=, # tensor to load into or store from +load_size=[BOX_SIZE_0, BOX_SIZE_1], # size of the 2D box to copy +global_size=[GLOBAL_SIZE_0, GLOBAL_SIZE_1], # this defines a "global box" in GMEM. TMA load/store won't go over this boundary if load_size is not divisble by global_size. e.g. Assuming GLOBAL_SIZE_0 == 1.5 * BLOCK_SIZE_0 and GLOBAL_SIZE_1 == BLOCK_SIZE_1, then: for TMA load, the second box will return a tensor of size (BLOCK_SIZE_0, BLOCK_SIZE_1) but the second half of the tensor is all 0; for TMA store, the second box will only have its first half written to GMEM. +element_ty= # usually tensor_ptr.dtype.element_ty +) +``` +- Acquire fence on a TensorMap/descriptor object: +``` +tl.extra.cuda.experimental_tensormap_fenceproxy_acquire() +``` +- Load data from GMEM to SMEM: +``` +x = tl._experimental_descriptor_load( + , #initialized, and acquired fence above + [OFFSET_0, OFFSET_1], # offset in "global box" for the 2D loading box to start from + [BOX_SIZE_0, BOX_SIZE_1], # keep the same as descriptor's `load_size` + ,) +``` +- Store data from SMEM to GMEM: +``` +tl._experimental_descriptor_store( + , #initialized, and acquired fence above + , #the tensor to be stored on GMEM + [OFFSET_0, OFFSET_1], # offset in "global box" for the 2D loading box to start from +) +``` + +## Example +### Store +Let’s assume we have the following non TMA store code now: + +``` +start_m = pid * BLOCK_M +offs_m = start_m + tl.arange(0, BLOCK_M) +offs_v_d = tl.arange(0, BLOCK_D_V) +off_o = Out + seq_start * stride_om + off_h * stride_oh # TMA will use Out as global address, and include seq_start * stride_om + off_h * stride_oh as part of offsets +out_ptrs = off_o + offs_m[:, None] * stride_om + offs_v_d[None, :] +tl.store(out_ptrs, acc, mask=(offs_m < seq_len)[:, None]) + +# Essentially, it tries to store the tensor `acc` into this box: +# Out[ +# (seq_start + pid * BLOCK_M : seq_start + (pid+1) * BLOCK_M), +# (off_h * stride_oh : off_h * stride_oh + BLOCK_D_V) +# ] +# In other words, it's a box of size (BLOCK_M, BLOCK_D_V) starting at [seq_start + pid * BLOCK_M, off_h * stride_oh]. This will be the bases for our TMA desc init and load/store op. +# And the rows with dim0 larger than (seq_start + seq_len) will be masked. Note that (seq_start + seq_len) == seq_end, which we'll use in TMA store below +``` +The equivalent TMA store code would be: +``` +# pyre-ignore [20] +tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=device_desc_o, + global_address=Out, # Out is of shape (L, H, DimV) + load_size=[BLOCK_M, BLOCK_D_V], #box size as explained in comments above + global_size=[seq_end.to(tl.int32), H * DimV], # this eliminates the need for `mask`, TMA automatically take care of boundaries. + element_ty=Out.dtype.element_ty, +) +# pyre-ignore [20] +tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(device_desc_o) +tl._experimental_descriptor_store( + device_desc_o, + acc, # acc needs to be casted to the right dtype + [ #offset as explained in comments above (where the box starts at) + (seq_start + pid * BLOCK_M).to(tl.int32), + (off_h * stride_oh).to(tl.int32), + ], + ) +``` +### Load +Assume we have this non TMA load code: +``` +Q_block_ptr = tl.make_block_ptr( + base=Q + off_h * stride_qh + seq_start * stride_qm, + shape=(seq_len, BLOCK_D_Q), + strides=(stride_qm, 1), + offsets=(start_m, 0), + block_shape=(BLOCK_M, BLOCK_D_Q), + order=(1, 0), + ) +q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero") + + +# Essentially this tries to load this box into q: +# Q[ +# (seq_start + start_m : seq_start + start_m + BLOCK_M), +# (off_h * stride_qh : off_h * stride_qh + BLOCK_D_Q) +# ] +# In other words, it's a box of size (BLOCK_M, BLOCK_D_Q) starting at [seq_start + start_m, off_h * stride_qh]. This will be the bases for our TMA desc init and load/store op. +# And the rows with dim0 larger than seq_len will be filled with zero, with shape of q always being (BLOCK_M, BLOCK_D_Q). +``` +The equivalent TMA load code will be: +``` +# pyre-ignore [20] +tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=device_desc_q, + global_address=Q, # shape (L, H, DimQ) + load_size=[BLOCK_M,BLOCK_D_Q], #box size as explained in comments above + global_size=[seq_end.to(tl.int32), H * DimQ], # seq_end == seq_start + seq_len + element_ty=Q.dtype.element_ty, + ) +# pyre-ignore [20] + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(device_desc_q) + + +q = tl._experimental_descriptor_load( + device_desc_q, + [ #offset as explained in comments above (where the box starts at) + (seq_start + start_m).to(tl.int32), + (off_h * stride_qh).to(tl.int32), + ], + [BLOCK_M,BLOCK_D_Q], + Q.dtype.element_ty, + ) +``` diff --git a/kernel_perf_agent/kernel_opt/database/docs/on_device_tma.py b/kernel_perf_agent/kernel_opt/database/docs/on_device_tma.py new file mode 100644 index 0000000..d9baa13 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/on_device_tma.py @@ -0,0 +1,42 @@ +ON_DEVICE_TMA = """ +============================= On-Device Tensor Memory Accelerator (TMA) =================================== +## What is TMA? +The Tensor Memory Accelerator (TMA) is a hardware feature introduced in NVIDIA Hopper GPUs +for performing asynchronous memory copies between a GPU's global memory (GMEM) and the +shared memory (SMEM) of its thread blocks (i.e., CTAs). TMA offloads some of the data +transfer work from software, thereby improving performance by overlapping memory transfers +with computation, freeing up warps, and reducing register pressure. + +## On-Device TMA: +TMA data load/store operations require a TMA tensor descriptor object. This descriptor +specifies the tensor's address, strides, shapes, and other attributes necessary for the +copy operation. TMA descriptors can be initialized on the device. On-device descriptors +allocate a large chunk of memory in GMEM, and each program have to find its own offset +and initialize descriptors there. + +## How to integrate on-device TMA into a Triton program? +To enable on-device TMA in a Triton program, we need to add support from both the host and kernel programs. +In the host program, a global memory allocation is needed by adding the following function: +``` +def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + +triton.set_allocator(alloc_fn) +``` +In addition, we need to import the method `from typing import Optional`. +In the kernel program, instead of loading and storing a tensor block with a range of pointers, +we declare a TMA descriptor for each tensor and then use the descriptor to load and store the tensor in blocks. +An example of a TMA descriptor declaration is +``` +x_desc = tl.make_tensor_descriptor( + x_ptr, # the pointer to the tensor + shape=[M, N], # the shape of the tensor + strides=[stride_m, stride_n], # the stride of the tensor + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], # the block size of each TMA load/store +) +``` +An example of the TMA load is +``` +x = x_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) # the start offset of the TMA load +``` +""" diff --git a/kernel_perf_agent/kernel_opt/database/docs/on_host_tma.py b/kernel_perf_agent/kernel_opt/database/docs/on_host_tma.py new file mode 100644 index 0000000..08371c0 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/on_host_tma.py @@ -0,0 +1,36 @@ + +ON_HOST_TMA = """ +============================= On-Host Tensor Memory Accelerator (TMA) =================================== +## What is TMA? +The Tensor Memory Accelerator (TMA) is a hardware feature introduced in NVIDIA Hopper GPUs +for performing asynchronous memory copies between a GPU's global memory (GMEM) and the +shared memory (SMEM) of its thread blocks (i.e., CTAs). TMA offloads some of the data +transfer work from software, thereby improving performance by overlapping memory transfers +with computation, freeing up warps, and reducing register pressure. + +## On-Host TMA: +TMA data load/store operations require a TMA tensor descriptor object. This descriptor +specifies the tensor's address, strides, shapes, and other attributes necessary for the +copy operation. TMA descriptors can be initialized on the host. On-host descriptors +allocate memory in the host memory, initialize the descriptors there, and then copy +them by value to GMEM. + +## How to integrate on-host TMA into a Triton program? +To enable on-host TMA in a Triton program, we need to add support on both the host and kernel programs. +In the host program, we allocate a TMA descriptor for each tensor and pass the descriptor as an argument to the kernel. +An example of a TMA descriptor declaration is +``` +x_desc = TensorDescriptor( + x, # the pointer to the tensor + x.shape, # the shape of the tensor + x.stride(), # the stride of the tensor + [BLOCK_SIZE_M, BLOCK_SIZE_N] # the block size of each TMA load/store +) +``` +And in addition, we need to import the method `from triton.tools.tensor_descriptor import TensorDescriptor`. +In the kernel program, instead of loading and storing a tensor block with a range of pointers, +we use the TMA descriptor to load and store the tensor in blocks. An example of the TMA load is +``` +x = x_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) # the start offset of the TMA load +``` +""" diff --git a/kernel_perf_agent/kernel_opt/database/docs/persistence.py b/kernel_perf_agent/kernel_opt/database/docs/persistence.py new file mode 100644 index 0000000..0165eac --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/persistence.py @@ -0,0 +1,29 @@ +PERSISTENCE=''' +================================ Persistent Programming Style ======================================= +## What it is: +The persistent programming style in GPU is a kernel design pattern where a fixed number of +blocks is launched, typically equal to the number of streaming multiprocessors (SMs), +instead of launching blocks proportional to the problem size. This pattern is particularly effective +for large-scale computations where the problem size exceeds the GPU's parallel capacity. + +## Traditional Approach: +In an unoptimized Triton GPU kernel, the number of blocks launched is dependent on the input size, +typically calculated as `triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]` +in the grid argument. +Each block processes exactly one tile of work, and the number of blocks can be much larger +than the available hardware resources. + +## Persistent Approach: +In a persistent style implementation, a fixed number of blocks is launched, which can be the number +of streaming multiprocessors (SMs) on the GPU by calling `torch.cuda.get_device_properties("cuda").multi_processor_count`. +In the kernel code, each block iterates over the program ID with a stride equal to the total number of blocks, +ensuring that the computation is completed by a fixed number of blocks. +These blocks "persist" and loop until all work is completed. + +## Advantages: +* Better resource utilization: Matches hardware capabilities exactly +* Reduced launch overhead: Fewer kernel launches for large problems +* Improved occupancy: Keeps all SMs busy throughout execution +* Better cache locality: Blocks can reuse data across multiple iterations +* Load balancing: Work is distributed more evenly across SMs +''' diff --git a/kernel_perf_agent/kernel_opt/database/docs/pid_swizzle.py b/kernel_perf_agent/kernel_opt/database/docs/pid_swizzle.py new file mode 100644 index 0000000..dff5359 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/pid_swizzle.py @@ -0,0 +1,24 @@ + +PID_SWIZZLE=""" +===================================== PID Swizzling =========================================== +## What it is: +PID swizzling is a GPU optimization technique used in Triton programming that remaps +program identifiers (`pid_m` and `pid_n`) to create better memory access patterns, +specifically for L2 cache locality. This technique is commonly used in high-performance GPU kernels, +particularly for GEMM (General Matrix Multiply) operations in frameworks like Triton. + +## Traditional Approach: +The program launch order matters as it affects the L2 cache hit rate. +In an unoptimized GPU kernel, each program instance computes a [BLOCK_SIZE_M, BLOCK_SIZE_N] +block of the output tensor, and the program identifiers are arranged in a simple row-major ordering +by `pid_m = pid // num_pid_n` and `pid_n = pid % num_pid_n`. +This creates poor cache locality because adjacent programs access memory locations that are far apart. + +## PID Swizzling Approach: +PID swizzling forms "super-grouping" of programs with a fixed row size `GROUP_SIZE_M`. +The number of programs in a group is `GROUP_SIZE_M * num_pid_n`. +The `group_id` is calculated by dividing the program id by the number of programs in a group. +If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the row size of the last group is smaller +and can be calculated by subtracting `GROUP_SIZE_M * group_id` from `num_pid_m`. +The programs within a group are arranged in a column-major order. +""" diff --git a/kernel_perf_agent/kernel_opt/database/docs/tma.md b/kernel_perf_agent/kernel_opt/database/docs/tma.md new file mode 100644 index 0000000..89087d9 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/tma.md @@ -0,0 +1,150 @@ +**TMA (Tensor Memory Accelerator)** is a hardware feature in NVIDIA GPUs that accelerates memory transfers for tensor operations by providing more efficient block-based memory access patterns. + +What is TMA? +------------ + +TMA replaces traditional pointer-based memory access with **tensor descriptors** that describe the entire tensor layout, enabling the GPU hardware to optimize memory transfers automatically. + +Benefits of TMA: +---------------- + +* **Hardware-accelerated memory transfers** +* **Better memory coalescing** +* **Reduced memory access overhead** +* **Simplified memory access patterns** + +How to Add TMA to Triton Code +----------------------------- + +There are two approaches: **Host-side TMA** and **Device-side TMA**. + +### 1. Host-side TMA Implementation + +**Host-side setup:** + +``` +from triton.tools.tensor_descriptor import TensorDescriptor + +def matmul_with_tma(a, b): + # Create TMA descriptors on host + a_desc = TensorDescriptor( + a, # the tensor + a.shape, # tensor shape + a.stride(), # tensor strides + [BLOCK_SIZE_M, BLOCK_SIZE_K] # block size for TMA operations + ) + + b_desc = TensorDescriptor( + b, + b.shape, + b.stride(), + [BLOCK_SIZE_K, BLOCK_SIZE_N] + ) + + c_desc = TensorDescriptor( + c, + c.shape, + c.stride(), + [BLOCK_SIZE_M, BLOCK_SIZE_N] + ) + + # Pass descriptors to kernel + kernel[grid](a_desc, b_desc, c_desc, ...) +``` + +**Kernel-side usage:** + +``` +@triton.jit +def matmul_kernel(a_desc, b_desc, c_desc, ...): + pid = tl.program_id(axis=0) + # Calculate tile positions + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + # Load using TMA descriptors + a = a_desc.load([pid_m * BLOCK_SIZE_M, 0]) # offset coordinates + b = b_desc.load([0, pid_n * BLOCK_SIZE_N]) + + # Compute + accumulator = tl.dot(a, b) + + # Store using TMA descriptor + c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], accumulator) +``` + +### 2. Device-side TMA Implementation + +**Host-side setup:** + +``` +from typing import Optional + +def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + +# Set custom allocator for TMA +triton.set_allocator(alloc_fn) +``` + +**Kernel-side usage:** + +``` +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, ...): + # Create TMA descriptors in kernel + a_desc = tl.make_tensor_descriptor( + a_ptr, # pointer to tensor + shape=[M, K], # tensor shape + strides=[stride_am, stride_ak], # tensor strides + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K] # TMA block size + ) + + b_desc = tl.make_tensor_descriptor( + b_ptr, + shape=[K, N], + strides=[stride_bk, stride_bn], + block_shape=[BLOCK_SIZE_K, BLOCK_SIZE_N] + ) + + c_desc = tl.make_tensor_descriptor( + c_ptr, + shape=[M, N], + strides=[stride_cm, stride_cn], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N] + ) + + # Use descriptors for memory operations + pid = tl.program_id(axis=0) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + # Load blocks using TMA + a = a_desc.load([pid_m * BLOCK_SIZE_M, 0]) + b = b_desc.load([0, pid_n * BLOCK_SIZE_N]) + + # Compute and store + result = tl.dot(a, b) + c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], result) +``` + +Key Differences from Traditional Approach: +------------------------------------------ + +**Traditional:** + +``` +# Manual pointer arithmetic +offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) +a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak +a = tl.load(a_ptrs, mask=...) +``` + +**TMA:** + +``` +# Descriptor-based access +a = a_desc.load([pid_m * BLOCK_SIZE_M, k_offset]) +``` + +TMA simplifies memory access patterns and leverages hardware acceleration for better performance in tensor operations. diff --git a/kernel_perf_agent/kernel_opt/decorator/__init__.py b/kernel_perf_agent/kernel_opt/decorator/__init__.py new file mode 100644 index 0000000..df0379b --- /dev/null +++ b/kernel_perf_agent/kernel_opt/decorator/__init__.py @@ -0,0 +1,3 @@ +from .kernel_opt import kernel_opt + +__all__ = ["kernel_opt"] diff --git a/kernel_perf_agent/kernel_opt/decorator/agent.py b/kernel_perf_agent/kernel_opt/decorator/agent.py new file mode 100644 index 0000000..d6f53aa --- /dev/null +++ b/kernel_perf_agent/kernel_opt/decorator/agent.py @@ -0,0 +1,152 @@ +import re +from pathlib import Path +from typing import Tuple + +from fastapi import status + +from kernel_perf_agent.kernel_opt.configs.envs import NUM_OF_ROUNDS +from kernel_perf_agent.kernel_opt.database.base import OptHierarchy, OptNode +from kernel_perf_agent.kernel_opt.profiler.profiler import KernelProfiler +from kernel_perf_agent.kernel_opt.prompts.prompt_manager import PromptManager +from kernel_perf_agent.kernel_opt.retriever.retriever import Retriever +from kernel_perf_agent.kernel_opt.rewriter.kernel_rewriter import KernelRewriter +from kernel_perf_agent.kernel_opt.utils.parser_util import extract_code, get_module_path +from kernel_perf_agent.kernel_opt.verifier.verifier import KernelCodeVerifier, KernelFileVerifier + + +def KernelAgent( + model: str, + dsl: str, + kernel_name: str, + func_prompt: str, + func: str, + opt_prompt: str, + debug: bool, +) -> Tuple[str, str, str, str]: + """Agent wrapper for kernel optimization.""" + + status_msg = "" + debug_msg = "" + session_info = "" + func_out = "" + + # Verifier - check syntax errors + verifier_result = KernelCodeVerifier( + code=func, + ) + if not verifier_result.ok: + status_msg += "❌ Input syntax validation failed.\n" + session_info += "❌ Input syntax validation failed.\n" + debug_msg = verifier_result.message + return status_msg, func, debug_msg, "" + + session_info += "✅ Input syntax validation passed.\n" + + # TODO: Profiler - check functional correctness + profiling_result = KernelProfiler(kernel_path=Path()) + if not profiling_result.ok: + status_msg += "❌ Input functional validation failed.\n" + session_info += "❌ Input functional validation failed.\n" + debug_msg = profiling_result.message + return status_msg, func, debug_msg, "" + session_info += "✅ Input functional validation passed.\n" + + # Database Construction + common_path = Path("kernel_opt/database/code_samples/") + opt_hierarchy = OptHierarchy() + opt_hierarchy.hard_initialize(common_path) + + # Retriever - fetch related context from database + retriever = Retriever( + func_prompt=func_prompt, + opt_prompt=opt_prompt, + model=model, + dsl=dsl, + kernel_name=kernel_name, + database=opt_hierarchy, + module_path=Path(), + debug=debug, + ) + opt_node, debug_str = retriever.retrieve() + debug_msg += debug_str + + # Prompt Manager - Build prompt for LLM + prompt_manager = PromptManager( + func_source_code=func, + func_prompt=func_prompt, + opt_prompt=opt_prompt, + model=model, + dsl=dsl, + kernel_name=kernel_name, + database=opt_hierarchy, + opt_node=opt_node, + module_path=Path(), + debug=debug, + ) + prompt, debug_str = prompt_manager.build_rewrite_prompt() + debug_msg += debug_str + error = "" + + # Iterate with error messages + for attempt in range(NUM_OF_ROUNDS): + session_info += "=" * 30 + "\n" + session_info += f"Attempt: {attempt}" + "\n" + + # Rewriter - rewrite kernel + rewriter = KernelRewriter( + prompt=prompt, + model=model, + debug=debug, + module_path=Path(), + error=error, + ) + response, debug_str = rewriter.generate_kernel(error) + debug_msg += debug_str + + # Extractor - extract kernel from response + output_program = extract_code(response_text=response, debug=debug) + # if debug: + # debug_msg += "****** Extracted code ****** : \n" + # debug_msg += output_program + + correct = True + # Verifier - check syntax errors + verifier_result = KernelCodeVerifier( + code=output_program, + ) + if verifier_result.ok: + session_info += "✅ Output syntax validation passed.\n" + else: + correct = False + session_info += "❌ Output syntax validation failed.\n" + session_info += ( + f""" +The previous generated program has a syntax Error: {error} """ + + "\n" + + verifier_result.message + ) + + # TODO: Profiler - check functional correctness + profiling_result = KernelProfiler(kernel_path=Path()) + if profiling_result.ok: + session_info += "✅ Output functional validation passed.\n" + else: + correct = False + session_info += "❌ Output functional validation failed.\n" + session_info += ( + f""" +The previous generated program has a function error: {error} """ + + "\n" + + profiling_result.message + ) + if correct: + func_out = output_program + break + + else: + status_msg = "❌ Kernel Generation Failed. {\n}" + session_info += f""" +❌ Kernel validation failed after {NUM_OF_ROUNDS} attempts with the last error: +{error}""" + + return status_msg, func_out, debug_msg, session_info diff --git a/kernel_perf_agent/kernel_opt/decorator/kernel_opt.py b/kernel_perf_agent/kernel_opt/decorator/kernel_opt.py new file mode 100644 index 0000000..f5b5724 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/decorator/kernel_opt.py @@ -0,0 +1,183 @@ +"""Kernel agent decorator.""" + +import functools +import os +import subprocess +from typing import Callable + +from dotenv import load_dotenv +from kernel_perf_agent.kernel_opt.configs.envs import NUM_OF_ROUNDS +from kernel_perf_agent.kernel_opt.database.base import OptHierarchy +from kernel_perf_agent.kernel_opt.profiler.profiler import KernelProfiler +from kernel_perf_agent.kernel_opt.prompts.prompt_manager import PromptManager +from kernel_perf_agent.kernel_opt.retriever.retriever import Retriever +from kernel_perf_agent.kernel_opt.rewriter.kernel_rewriter import KernelRewriter +from kernel_perf_agent.kernel_opt.utils.parser_util import ( + extract_code, + get_module_path, + remove_decorators_from_file, +) +from kernel_perf_agent.kernel_opt.verifier.verifier import KernelCodeVerifier, KernelFileVerifier + + +def kernel_opt( + func_prompt: str, + opt_prompt: str, + model: str, + dsl: str, + kernel_name: str, + debug: bool, +): + """Decorator for kernel generation. + :param prompt: Description of the kernel to generate + :param model: LLM model to use (e.g., "deepseek-chat") + :param dsl: Target DSL (e.g., "triton") + :param kernel_name: Name of the kernel (defaults to function name) + :param debug: Whether to print debug information + """ + + def decorator(func: Callable): + @functools.wraps(func) + def wrapper(*args, **kwargs): + + func_path = get_module_path(func).resolve() + func_dir = get_module_path(func).parent.resolve() + + # Debug output + if debug: + debug_output_path = func_dir / "debug_output" + if debug_output_path.is_dir(): + subprocess.run(["rm", "-rf", str(debug_output_path)], check=True) + subprocess.run(["mkdir", str(debug_output_path)]) + + # Load environment variables + current_script_dir = os.path.dirname(os.getcwd()) + dotenv_path = os.path.join(current_script_dir, ".env") + load_dotenv(dotenv_path) + + # Verifier - check syntax errors + verifier_result = KernelFileVerifier(module_path=func_path) + if not verifier_result.ok: + print("❌ Input syntax validation failed.") + return func(*args, **kwargs) + print("✅ Input syntax validation passed.") + + # TODO: Profiler - check functional correctness + profiling_result = KernelProfiler(kernel_path=func_path) + if not profiling_result.ok: + print("❌ Input functional validation failed.") + return func(*args, **kwargs) + print("✅ Input functional validation passed.") + + # Database Construction + common_path = func_dir / "../kernel_opt/database/code_samples/" + opt_hierarchy = OptHierarchy() + opt_hierarchy.hard_initialize(common_path) + + # Retriever - fetch related context from database + retriever = Retriever( + func_prompt=func_prompt, + opt_prompt=opt_prompt, + model=model, + dsl=dsl, + kernel_name=kernel_name, + database=opt_hierarchy, + module_path=func_dir, + debug=debug, + ) + opt_node, debug_str = retriever.retrieve() + + # Prompt Manager - Build prompt for LLM + func_source_code = remove_decorators_from_file(func_path) + prompt_manager = PromptManager( + func_source_code=func_source_code, + func_prompt=func_prompt, + opt_prompt=opt_prompt, + model=model, + dsl=dsl, + kernel_name=kernel_name, + database=opt_hierarchy, + opt_node=opt_node, + module_path=func_dir, + debug=debug, + ) + prompt, debug_str = prompt_manager.build_rewrite_prompt() + error = "" + + # Iterate with error messages + for attempt in range(NUM_OF_ROUNDS): + print("=" * 50) + print(f"Attempt: {attempt}") + + # Rewriter - rewrite kernel + rewriter = KernelRewriter( + prompt=prompt, + model=model, + debug=debug, + module_path=func_dir, + error=error, + ) + response, debug_str = rewriter.generate_kernel(error=error) + + # Extractor - extract kernel from response + output_program = extract_code(response_text=response, debug=debug) + if debug: + debug_output_path = func_dir / "debug_output" / "output.log" + with open(str(debug_output_path), "w") as file: + file.write("****** Extracted code ****** : \n") + file.write(output_program) + + correct = True + + # Verifier - check syntax errors + error = "" + verifier_result = KernelCodeVerifier( + code=output_program, + ) + if verifier_result.ok: + print("✅ Output syntax validation passed.") + else: + correct = False + print("❌ Output syntax validation failed.") + error += ( + f""" +The previous generated program has a syntax Error: {error} \n""" + + verifier_result.message + ) + + # TODO: Profiler - check functional correctness + profiling_result = KernelProfiler(kernel_path=func_path) + if profiling_result.ok: + print("✅ Output functional validation passed.") + else: + correct = False + print("❌ Output functional validation failed.") + error += ( + f""" +The previous generated program has a function error: {error} \n""" + + profiling_result.message + ) + + # Stop iteration and store the output if correct + if correct: + with open( + str(func_dir / str(kernel_name + "_opt.py")), "w" + ) as file: + file.write(output_program) + break + + else: + print("❌ Kernel Generation Failed") + raise RuntimeError( + f"❌ Kernel validation failed after {NUM_OF_ROUNDS} attempts with the last error: \n{error}" + ) + + # Run original function + print("=" * 50) + print("Please find the generated kernel in {}_opt.py.".format(kernel_name)) + print("Below runs the original program.") + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/kernel_perf_agent/kernel_opt/prompts/__init__.py b/kernel_perf_agent/kernel_opt/prompts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernel_perf_agent/kernel_opt/prompts/prompt_manager.py b/kernel_perf_agent/kernel_opt/prompts/prompt_manager.py new file mode 100644 index 0000000..4752eff --- /dev/null +++ b/kernel_perf_agent/kernel_opt/prompts/prompt_manager.py @@ -0,0 +1,109 @@ +"""Prompt management.""" + +import ast +import inspect +from pathlib import Path +from typing import Callable, Tuple + +from kernel_perf_agent.kernel_opt.database.base import OptHierarchy, OptNode +from kernel_perf_agent.kernel_opt.prompts.rewrite_prompt_template import REWRITE_PROMPT_TEMPLATE +from kernel_perf_agent.kernel_opt.utils.parser_util import get_module_path, remove_decorators_from_file + + +class PromptManager: + """Manages prompt construction.""" + + def __init__( + self, + func_source_code: str, + func_prompt: str, + opt_prompt: str, + model: str, + dsl: str, + kernel_name: str, + database: OptHierarchy, + opt_node: OptNode, + module_path: Path, + debug: bool, + ): + """Initialize prompt manager. + :param func: Function to optimize + :param func_prompt: Function prompt + :param opt_prompt: Optimization prompt + :param model: LLM model to use + :param dsl: Target DSL (e.g., "triton") + :param kernel_name: Name of the kernel (defaults to function name) + :param database: Knowledge database of kernel optimizations + :param opt_node: The most relevant optimization node in database + :param module_path: Path to the module containing the function + :param debug: Whether to print debug information + """ + + self.func_source_code = func_source_code + self.func_prompt = func_prompt + self.opt_prompt = opt_prompt + self.model = model + self.dsl = dsl + self.kernel_name = kernel_name + self.database = database + self.opt_node = opt_node + self.module_path = module_path + self.debug = debug + + def build_rewrite_prompt(self) -> Tuple[str, str]: + """Build rewrite prompt.""" + + # Get context by traversing opt_node to all leaf nodes + context = "" + leaf = False + cur_level = [self.opt_node] + while cur_level: + child_level = [] + for node in cur_level: + # Leaf nodes are code examples + if not leaf and not node.opt_children: + leaf = True + context += """ +Here are code examples before and after the optimization: +""" + context += node.opt_desc + for child in node.opt_children: + if child not in child_level: + child_level.append(child) + cur_level = child_level + + debug_str = "" + # if self.debug: + # debug_str += f""" + # ****** Context ****** : + # {context} + # """ + # if str(self.module_path) != "": + # debug_context_path = self.module_path / "debug_output" / "context.log" + # with open(str(debug_context_path), "w") as file: + # file.write(debug_str) + # # file.write("****** Context ****** : \n") + # # file.write(context) + + # Rewriting the kernels at the same DSL level as the input. + prompt = REWRITE_PROMPT_TEMPLATE.format( + dsl=self.dsl, + kernel_name=self.kernel_name, + func_prompt=self.func_prompt, + input_kernel=self.func_source_code, + opt_prompt=self.opt_prompt, + context=context, + ) + + if self.debug: + debug_str += f""" +****** Prompt ****** : +{prompt} +""" + # if str(self.module_path) != "": + # debug_prompt_path = self.module_path / "debug_output" / "prompt.log" + # with open(str(debug_prompt_path), "w") as file: + # file.write("****** Prompt ****** : \n") + # file.write(prompt) + + return prompt, debug_str diff --git a/kernel_perf_agent/kernel_opt/prompts/rewrite_prompt_template.py b/kernel_perf_agent/kernel_opt/prompts/rewrite_prompt_template.py new file mode 100644 index 0000000..fe3fb6f --- /dev/null +++ b/kernel_perf_agent/kernel_opt/prompts/rewrite_prompt_template.py @@ -0,0 +1,21 @@ +"""Rewrite Prompt template.""" +REWRITE_PROMPT_TEMPLATE = """ +You are a professional performance engineer who is an expert in rewriting {dsl} kernels to improve their performance. + +Your task is to rewrite the following {dsl} kernel to integrate the specific optimization. +The kernel name is {kernel_name}. +The function of this kernel is {func_prompt}. +The kernel source code is: +{input_kernel} + +The required optimization to integrate is: +{opt_prompt} + +Here are the necessary context about the specific optimization: +{context} + +IMPORTANT: +1. Rewrite the given kernel at {dsl} level. +2. Generate the complete implementation that contains both the host code and the kernel code. +3. Please use markdown formatting (like ```python) in your output to wrap the code that you generate. +""" diff --git a/kernel_perf_agent/kernel_opt/retriever/__init__.py b/kernel_perf_agent/kernel_opt/retriever/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernel_perf_agent/kernel_opt/retriever/retriever.py b/kernel_perf_agent/kernel_opt/retriever/retriever.py new file mode 100644 index 0000000..a5c2021 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/retriever/retriever.py @@ -0,0 +1,131 @@ +import os +import warnings +from pathlib import Path + +import numpy as np +from kernel_perf_agent.kernel_opt.database.base import OptHierarchy +from kernel_perf_agent.kernel_opt.utils.proxy_util import devgpu_proxy_setup +from langchain_community.document_loaders import DirectoryLoader +from langchain_openai import OpenAIEmbeddings + +warnings.filterwarnings("ignore", category=FutureWarning) + + +class Retriever: + """Retriever for kernel generation.""" + + def __init__( + self, + func_prompt: str, + opt_prompt: str, + model: str, + dsl: str, + kernel_name: str, + database: OptHierarchy, + module_path: Path, + debug: bool, + ): + """ + Initialize the retriever. + :param func_prompt: Description of the kernel to generate + :param opt_prompt: Description of the optimization to perform + :param model: LLM model to use + :param dsl: Target DSL (e.g., "triton") + :param kernel_name: Name of the kernel (defaults to function name) + :param database: Knowledge database of kernel optimizations + :param module_path: Path to the module containing the function + :param debug: Whether to print debug information + """ + + self.func_prompt = func_prompt + self.opt_prompt = opt_prompt + self.model = model + self.dsl = dsl + self.kernel_name = kernel_name + self.database = database + self.module_path = module_path + self.debug = debug + + # Configure the proxy + self._original_proxy_env = devgpu_proxy_setup() + + def doc_preprocess(self): + """Preprocess the documents.""" + # Langchain + # ========================================= + # DirectoryLoader + dir_loader = DirectoryLoader("../kernel_opt/database/", recursive=True) + dir_docs = dir_loader.load() + + # ========================================= + def cosine_similarity(vec1, vec2): + dot_product = np.dot(vec1, vec2) + norm_vec1 = np.linalg.norm(vec1) + norm_vec2 = np.linalg.norm(vec2) + return dot_product / (norm_vec1 * norm_vec2) + + query = "add Tensor Memory Accelerator (TMA) support to the kernel" + embeddings = OpenAIEmbeddings() + query_embedding = embeddings.embed_query(query) + for doc in dir_docs: + doc_embedding = embeddings.embed_query(doc.page_content) + similarity = cosine_similarity([query_embedding], doc_embedding)[0] + print(doc.metadata, similarity) + + def _cosine_similarity(self, vec1, vec2): + dot_product = np.dot(vec1, vec2) + norm_vec1 = np.linalg.norm(vec1) + norm_vec2 = np.linalg.norm(vec2) + return dot_product / (norm_vec1 * norm_vec2) + + def retrieve(self): + """Retrieve the relevant context in the database from the key (opt_prompt).""" + embeddings = OpenAIEmbeddings() + key = self.opt_prompt + key_embedding = embeddings.embed_query(key) + + # Compute the similarity score for all nodes in the database tree + root = self.database.get_root() + cur_level = [root_child for root_child in root.opt_children] + opt_similarity = dict() + while cur_level: + child_level = [] + for node in cur_level: + opt_embedding = embeddings.embed_query(node.opt_desc) + opt_similarity[node] = self._cosine_similarity( + key_embedding, opt_embedding + ) + for child in node.opt_children: + if child not in child_level: + child_level.append(child) + cur_level = child_level + + # Get the node with the highest similarity + opt_similarity_sorted = sorted( + opt_similarity.items(), key=lambda item: item[1], reverse=True + ) + opt_most_similar = opt_similarity_sorted[0][0] + + # Print the nodes and their similarity scores + debug_str = "" + if self.debug: + for key, value in opt_similarity_sorted: + debug_str += f""" +--------------------------------- +{key.opt_desc.splitlines(keepends=True)[:2]} +--------------------------------- +{str(value)} +""" + # if str(self.module_path) != "": + # debug_similarity_path = ( + # self.module_path / "debug_output" / "similarity_score.log" + # ) + # with open(str(debug_similarity_path), "w") as file: + # file.write(debug_str) + # # for key, value in opt_similarity_sorted: + # # file.write("\n---------------------------------\n") + # # file.write(key.opt_desc) + # # file.write("\n---------------------------------\n") + # # file.write(str(value)) + + return opt_most_similar, debug_str diff --git a/kernel_perf_agent/kernel_opt/utils/__init__.py b/kernel_perf_agent/kernel_opt/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernel_perf_agent/kernel_opt/utils/debug_util.py b/kernel_perf_agent/kernel_opt/utils/debug_util.py new file mode 100644 index 0000000..2324e38 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/utils/debug_util.py @@ -0,0 +1,13 @@ +"""Utility functions for debugging.""" + +from kernel_perf_agent.kernel_opt.configs.envs import DEBUG_MODE + + +def debug_print(prompt: str, response: str, conversation_length: int): + """Print debug information.""" + if DEBUG_MODE: + print("=" * 20 + " [DEBUG_MODE] START " + "=" * 20) + print(f"Prompt: \n{prompt}\n") + print(f"Response: \n{response}\n") + print(f"Conversation length: {conversation_length}") + print("=" * 20 + " [DEBUG_MODE] END " + "=" * 20) diff --git a/kernel_perf_agent/kernel_opt/utils/io_util.py b/kernel_perf_agent/kernel_opt/utils/io_util.py new file mode 100644 index 0000000..472da09 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/utils/io_util.py @@ -0,0 +1,16 @@ +"""Utility functions for file I/O operations.""" + +from pathlib import Path +from typing import Union + + +def write_file(path: Union[str, Path], content: str) -> None: + """Write content to a file. + + Args: + path: Path to the file + content: Content to write + """ + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content) diff --git a/kernel_perf_agent/kernel_opt/utils/logging_util.py b/kernel_perf_agent/kernel_opt/utils/logging_util.py new file mode 100644 index 0000000..a60abf1 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/utils/logging_util.py @@ -0,0 +1,30 @@ +"""Utility functions for logging configuration and operations.""" + +import logging + + +def setup_basic_logging() -> None: + """Configure basic logging with a modern format.""" + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)s | %(message)s", + datefmt="%H:%M:%S", + ) + + +def log_interaction( + logger: logging.Logger, prompt: str, response: str, conversation_length: int +) -> None: + """Log an interaction with the model. + + Args: + logger: Logger instance to use + prompt: The prompt sent to the model + response: The response from the model + conversation_length: Current length of the conversation + """ + logger.info( + f"Generated response: prompt={len(prompt)}, " + f"response={len(response)}, " + f"conversation_length={conversation_length}" + ) diff --git a/kernel_perf_agent/kernel_opt/utils/parser_util.py b/kernel_perf_agent/kernel_opt/utils/parser_util.py new file mode 100644 index 0000000..272c204 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/utils/parser_util.py @@ -0,0 +1,80 @@ +"""Utility functions for parsing and extracting code from text.""" + +import re +import ast +import inspect +from typing import Callable +from pathlib import Path + +def get_unwrapped_source(func): + # Get the full source (including decorator) + full_source = inspect.getsource(func) + + # Parse the AST + tree = ast.parse(full_source) + + # Get the function node (should be the first statement) + func_node = tree.body[0] + + # Remove decorators + func_node.decorator_list = [] + + # Convert back to source code + return ast.unparse(func_node) + +def get_module_path(torch_fn: Callable) -> Path: + module = inspect.getmodule(torch_fn) + if module is None: + raise ValueError("Could not determine module for function") + module_path = Path(module.__file__) + return module_path + +def remove_decorators_from_file(filepath: Path): + # with open(filepath, 'r') as f: + # source_code = f.read() + source_code = filepath.read_text() + + # Parse the AST + tree = ast.parse(source_code) + + # Remove kernel_opt decorators from all functions and classes + for node in ast.walk(tree): + # if isinstance(node, (ast.FunctionDef, ast.ClassDef)): + # node.decorator_list = [] # Clear the list of decorators + if isinstance(node, (ast.FunctionDef, ast.ClassDef)): + for decorator in node.decorator_list: + if isinstance(decorator, ast.Name) and decorator.id == "kernel_opt": + # print(f" - {decorator.id}") + node.decorator_list.remove(decorator) + elif isinstance(decorator, ast.Call): + if isinstance(decorator.func, ast.Name) and decorator.func.id == "kernel_opt": + node.decorator_list.remove(decorator) + # print(f" - {decorator.func.id} (with arguments)") + + # Convert back to source code + modified_code = ast.unparse(tree) + + return modified_code + +def extract_code(response_text: str, debug: bool) -> str: + """Extract code from response text with proper error handling. + + Args: + response_text: The text containing potential code blocks + + Returns: + The extracted code if found, otherwise the original text stripped + """ + code = "" + # Look for python code blocks + pattern = r"```python\n(.*?)\n```" + matches = re.findall(pattern, response_text, re.DOTALL) + + if matches: + # Return first code block + code = matches[0].strip() + else: + # No markdown found, return original text + code = response_text.strip() + + return code diff --git a/kernel_perf_agent/kernel_opt/utils/proxy_util.py b/kernel_perf_agent/kernel_opt/utils/proxy_util.py new file mode 100644 index 0000000..9b6ea47 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/utils/proxy_util.py @@ -0,0 +1,57 @@ +""" +Meta Proxy Configuration + +Authors: Laura Wang, Jie Liu +""" + +import os +import subprocess +from typing import Dict, Optional + + +def get_meta_proxy_config() -> Optional[Dict[str, str]]: + """ + Get Meta's proxy configuration if available. + + Returns: + Dictionary with proxy settings or None if not available + """ + try: + # Check if with-proxy command exists (Meta environment) + result = subprocess.run( + ["which", "with-proxy"], capture_output=True, text=True, timeout=5 + ) + if result.returncode != 0: + return None + + # Get proxy environment variables from with-proxy + result = subprocess.run( + ["with-proxy", "env"], capture_output=True, text=True, timeout=5 + ) + if result.returncode != 0: + return None + + # Parse proxy settings + proxy_config = {} + for line in result.stdout.split("\n"): + if "=" in line: + key, value = line.split("=", 1) + if key.lower() in ["http_proxy", "https_proxy"]: + proxy_config[key.lower()] = value + + return proxy_config if proxy_config else None + + except Exception: + return None + + +def devgpu_proxy_setup() -> Dict[str, str]: + original_proxy_env = {} + proxy_config = get_meta_proxy_config() + if proxy_config: + for key in ["HTTP_PROXY", "HTTPS_PROXY", "http_proxy", "https_proxy"]: + original_proxy_env[key] = os.environ.get(key) + proxy_url = proxy_config.get(key) + if proxy_url: + os.environ[key] = proxy_url + return original_proxy_env diff --git a/kernel_perf_agent/kernel_opt/verifier/__init__.py b/kernel_perf_agent/kernel_opt/verifier/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/kernel_perf_agent/kernel_opt/verifier/base.py b/kernel_perf_agent/kernel_opt/verifier/base.py new file mode 100644 index 0000000..eaedf3b --- /dev/null +++ b/kernel_perf_agent/kernel_opt/verifier/base.py @@ -0,0 +1,24 @@ +"""Base verification classes.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Optional + + +@dataclass +class VerifyResult: + """Verification result.""" + + ok: bool + message: str + profile_data: Optional[Dict] = None + + +class Verifier(ABC): + """Abstract base class for verifiers.""" + + @abstractmethod + def verify(self, kernel_path: Path, kernel_code: str, test_code: bool) -> VerifyResult: + """Verify kernel implementation.""" + pass diff --git a/kernel_perf_agent/kernel_opt/verifier/benchmark.py b/kernel_perf_agent/kernel_opt/verifier/benchmark.py new file mode 100644 index 0000000..3b5f068 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/verifier/benchmark.py @@ -0,0 +1,54 @@ +"""Benchmark verifier.""" + +import importlib.util +import sys +import time +from pathlib import Path + +import torch + +from .base import Verifier, VerifyResult + + +class BenchmarkVerifier(Verifier): + """Verify kernel performance.""" + + def verify(self, kernel_path: Path) -> VerifyResult: + """Run benchmarks.""" + try: + # Import benchmark module + spec = importlib.util.spec_from_file_location( + "bench", kernel_path.parent / "bench.py" + ) + bench = importlib.util.module_from_spec(spec) + sys.modules["bench"] = bench + spec.loader.exec_module(bench) + + # Run benchmark + bench_fn = getattr(bench, f"bench_{kernel_path.parent.parent.name}") + + # Warmup + for _ in range(self.cfg["warmup"]): + bench_fn() + + # Benchmark + torch.cuda.synchronize() + start = time.time() + for _ in range(self.cfg["repetitions"]): + bench_fn() + torch.cuda.synchronize() + end = time.time() + + # Calculate metrics + avg_time = (end - start) / self.cfg["repetitions"] * 1000 # ms + + return VerifyResult( + ok=True, + message=f"Benchmark verification passed: {avg_time:.2f} ms", + profile_data={"avg_time_ms": avg_time}, + ) + + except Exception as e: + return VerifyResult( + ok=False, message=f"Benchmark verification failed: {str(e)}" + ) diff --git a/kernel_perf_agent/kernel_opt/verifier/numeric.py b/kernel_perf_agent/kernel_opt/verifier/numeric.py new file mode 100644 index 0000000..69b9022 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/verifier/numeric.py @@ -0,0 +1,57 @@ +"""Verify kernel numerical correctness.""" + +import subprocess +from pathlib import Path +import sys +import inspect +from .base import Verifier, VerifyResult +from kernel_perf_agent.kernel_opt.configs.envs import TIMEOUT + +class NumericVerifier(Verifier): + """Verify kernel numerical correctness.""" + + def verify(self, kernel_path: Path, kernel_code: str, test_code: bool) -> VerifyResult: + """Run numeric tests using pytest. + + Args: + build_dir: Path to the build directory containing test.py + + Returns: + VerifyResult indicating success or failure + """ + build_dir = kernel_path.parent + test_file_name = kernel_path.name + exec_code = kernel_code + + print("build_dir: ", build_dir) + if not test_code: + test_file_name = build_dir / "test.py" + try: + with open(test_file_name, 'r') as file: + exec_code = file.read() + except FileNotFoundError: + print(f"Error: The file '{test_file_name}' was not found.") + except Exception as e: + print(f"An error occurred: {e}") + + print(f"[debug] test_path: {test_file_name}") + + if not test_file_name.exists(): + return VerifyResult(ok=False, message=f"Test file not found: {test_file_name}") + + print("test_file_name: ", test_file_name) + try: + # Run the test in the build directory to make relative imports work + # subprocess.run( + # ["python", test_file_name], # Use relative path since we set cwd + # capture_output=True, + # text=True, + # check=True, + # timeout=TIMEOUT, + # cwd=str(build_dir), # Run in the build directory + # ) + # exec(exec_code) + subprocess.run(["python3", "-c", exec_code]) + return VerifyResult(ok=True, message="Numeric verification passed") + except subprocess.CalledProcessError as e: + return VerifyResult(ok=False, message=f"Test failed: {e.stderr}") diff --git a/kernel_perf_agent/kernel_opt/verifier/verifier.py b/kernel_perf_agent/kernel_opt/verifier/verifier.py new file mode 100644 index 0000000..b1353a5 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/verifier/verifier.py @@ -0,0 +1,51 @@ +"""Verification loop.""" + +from pathlib import Path +from typing import List, Optional +from dataclasses import dataclass +import py_compile + +from .base import Verifier, VerifyResult +from .numeric import NumericVerifier +from kernel_perf_agent.kernel_opt.utils.parser_util import get_unwrapped_source + +@dataclass +class VerifyResult: + """Verification result.""" + + ok: bool + message: str + +def KernelFileVerifier( + module_path: Path, +) -> VerifyResult: + """Run verification loop. + + Args: + module_path: Path to the kernel file + """ + + try: + py_compile.compile(module_path) + except py_compile.PyCompileError as e: + return VerifyResult(ok=False, message=f"Input compilation failed with syntax error: {e}") + + return VerifyResult(ok=True, message="Syntax verification passed.") + +def KernelCodeVerifier( + code: str, +) -> VerifyResult: + """Run verification loop. + + Args: + code: Code to be verified + """ + + try: + compiled_code = compile(code, '', 'exec') + except SyntaxError as e: + return VerifyResult(ok=False, message=f"Input compilation failed with syntax error: {e}") + except Exception as e: + return VerifyResult(ok=False, message=f"Input compilation failed with other errors: {e}") + + return VerifyResult(ok=True, message="Syntax verification passed.") diff --git a/pyproject.toml b/pyproject.toml index d03392c..eaffc12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,9 +50,9 @@ kernel-agent = "triton_ui:main" "Repository" = "https://github.com/pytorch-labs/KernelAgent" [tool.setuptools.packages.find] -include = ["triton_kernel_agent*"] +include = ["triton_kernel_agent*", "kernel_perf_agent*"] [tool.pytest.ini_options] testpaths = ["tests"] python_files = ["test_*.py", "*_test.py"] -addopts = "-v --cov=triton_kernel_agent --cov-report=term-missing" \ No newline at end of file +addopts = "-v --cov=triton_kernel_agent --cov-report=term-missing" diff --git a/triton_kernel_agent/agent.py b/triton_kernel_agent/agent.py index 82f694c..10833e8 100644 --- a/triton_kernel_agent/agent.py +++ b/triton_kernel_agent/agent.py @@ -2,13 +2,14 @@ Main Triton Kernel Generation Agent. """ -import os import json +import logging +import os import re -from pathlib import Path -from typing import Optional, List, Dict, Any from datetime import datetime -import logging +from pathlib import Path +from typing import Any, Dict, List, Optional + from dotenv import load_dotenv from .manager import WorkerManager @@ -202,7 +203,7 @@ def _generate_test( # Call LLM API messages = [{"role": "user", "content": prompt}] - response_text = self._call_llm(messages, max_tokens=8192) + response_text = self._call_llm(messages, max_tokens=24576) # Extract test code from response test_code = self._extract_code_from_response(response_text) diff --git a/triton_kernel_agent/providers/openai_base.py b/triton_kernel_agent/providers/openai_base.py index 1b3347e..aed3963 100644 --- a/triton_kernel_agent/providers/openai_base.py +++ b/triton_kernel_agent/providers/openai_base.py @@ -2,9 +2,10 @@ Base provider for OpenAI-compatible APIs. """ -from typing import List, Dict, Any, Optional -from .base import BaseProvider, LLMResponse +from typing import Any, Dict, List, Optional + from ..utils import configure_proxy_environment +from .base import BaseProvider, LLMResponse try: from openai import OpenAI @@ -54,9 +55,11 @@ def get_response( content=response.choices[0].message.content, model=model_name, provider=self.name, - usage=response.usage.dict() - if hasattr(response, "usage") and response.usage - else None, + usage=( + response.usage.dict() + if hasattr(response, "usage") and response.usage + else None + ), ) def get_multiple_responses( @@ -74,9 +77,11 @@ def get_multiple_responses( content=choice.message.content, model=model_name, provider=self.name, - usage=response.usage.dict() - if hasattr(response, "usage") and response.usage - else None, + usage=( + response.usage.dict() + if hasattr(response, "usage") and response.usage + else None + ), ) for choice in response.choices ] diff --git a/triton_kernel_agent/worker.py b/triton_kernel_agent/worker.py index 530900d..3ddd886 100644 --- a/triton_kernel_agent/worker.py +++ b/triton_kernel_agent/worker.py @@ -16,6 +16,12 @@ from .prompt_manager import PromptManager from .providers import get_model_provider +from kernel_perf_agent.kernel_opt.database.base import OptHierarchy +from kernel_perf_agent.kernel_opt.retriever.retriever import Retriever +from kernel_perf_agent.kernel_opt.prompts.prompt_manager import ( + PromptManager as PerfPromptManager, +) + class VerificationWorker: """Worker that verifies and refines a single kernel implementation.""" @@ -30,6 +36,9 @@ def __init__( openai_api_key: Optional[str] = None, openai_model: str = "o3-2025-04-16", high_reasoning_effort: bool = True, + enable_optimization: bool = True, # Flag to enable optimization phase + optimization_hint: Optional[str] = None, # e.g., "use persistent programming" + optimization_database_path: Optional[Path] = None, # Path to code_samples ): """ Initialize a verification worker. @@ -59,6 +68,9 @@ def __init__( # History for LLM context self.history = deque(maxlen=history_size) + # Setup logging FIRST (before any logger.xxx() calls) + self._setup_logging() + # Initialize provider self.provider = None try: @@ -70,8 +82,20 @@ def __init__( # Initialize prompt manager self.prompt_manager = PromptManager() - # Setup logging - self._setup_logging() + # NEW: Optimization setup + self.enable_optimization = enable_optimization + self.optimization_hint = optimization_hint or "optimize for performance" + + # Initialize optimization database if enabled + self.opt_hierarchy = None + if self.enable_optimization: + self.opt_hierarchy = OptHierarchy() + db_path = optimization_database_path or ( + Path(__file__).parent.parent + / "kernel_perf_agent/kernel_opt/database/code_samples" + ) + self.opt_hierarchy.hard_initialize(db_path) + self.logger.info("Initialized optimization database") def _setup_logging(self): """Setup worker-specific logging.""" @@ -247,7 +271,7 @@ def _refine_kernel( # Call LLM API messages = [{"role": "user", "content": prompt}] - response_text = self._call_llm(messages, max_tokens=8192) + response_text = self._call_llm(messages, max_tokens=24576) # Extract refined kernel from response refined_kernel = self._extract_code_from_response(response_text) @@ -276,6 +300,191 @@ def _refine_kernel( return kernel_code + def _optimize_kernel( + self, + kernel_code: str, + problem_description: str, + test_code: str, + max_opt_rounds: int = 3, + additional_code: Optional[str] = None, + ) -> Tuple[bool, str]: + """ + Optimize a working kernel using RAG-based pattern retrieval. + + Args: + kernel_code: Working kernel to optimize + problem_description: Original problem description + test_code: Test code to verify correctness + max_opt_rounds: Maximum optimization attempts + + Returns: + Tuple of (success, optimized_kernel_code) + """ + if not self.enable_optimization or not self.opt_hierarchy: + return False, kernel_code + + self.logger.info("Starting optimization phase") + + self.dsl = "triton" + self.kernel_name = "triton_kernel" + + try: + # Step 1: RAG Retrieval + self.logger.info("Step 1: Retrieving optimization pattern from database") + retriever = Retriever( + func_prompt=problem_description, + opt_prompt=self.optimization_hint, + model=self.openai_model, + dsl=self.dsl, + kernel_name=self.kernel_name, + database=self.opt_hierarchy, + module_path=self.workdir, + debug=True, + ) + + opt_node, debug_info = retriever.retrieve() + self.logger.info( + f"Retrieved optimization pattern: {opt_node.opt_desc[:100]}..." + ) + # Step 2: Build optimization prompt using PerfPromptManager + self.logger.info("Step 2: Building optimization prompt") + perf_prompt_manager = PerfPromptManager( + func_source_code=kernel_code, + func_prompt=problem_description, + opt_prompt=self.optimization_hint, + model=self.openai_model, + dsl=self.dsl, + kernel_name=self.kernel_name, + database=self.opt_hierarchy, + opt_node=opt_node, + module_path=self.workdir, + debug=True, + ) + + opt_prompt, debug_str = perf_prompt_manager.build_rewrite_prompt() + self.logger.info( + f"Optimization prompt built successfully: {opt_prompt[:100]}..." + ) + + # Step 3: Try optimization with multiple rounds + best_kernel = kernel_code + best_perf = None + error_feedback = "" + + for opt_round in range(max_opt_rounds): + self.logger.info(f"Optimization round {opt_round + 1}/{max_opt_rounds}") + + # Build current prompt with error feedback if available + current_prompt = opt_prompt + if error_feedback: + current_prompt = f"""{error_feedback} + +Please fix the issues in the previous attempt and generate a corrected optimized kernel. + +{opt_prompt} +""" + + # Step 3a: Call LLM (same pattern as _refine_kernel) + messages = [{"role": "user", "content": current_prompt}] + try: + response_text = self._call_llm(messages, max_tokens=24576) + except Exception as e: + self.logger.error(f"LLM call failed: {e}") + error_feedback = ( + f"Previous attempt failed: LLM call error - {str(e)}" + ) + continue + + # Step 3b: Extract code (same pattern as _refine_kernel) + optimized_kernel = self._extract_code_from_response(response_text) + + if not optimized_kernel or len(optimized_kernel) < 100: + self.logger.warning( + f"Failed to extract valid optimized kernel (length: {len(optimized_kernel) if optimized_kernel else 0})" + ) + error_feedback = "Previous attempt failed: No valid kernel code extracted from LLM response. Please provide complete Triton kernel code wrapped in ```python code blocks." + continue + + # Step 4: Verify correctness by running tests + self.logger.info("Testing optimized kernel...") + self._write_kernel(optimized_kernel) + success, stdout, stderr = self._run_test() + + if not success: + self.logger.warning( + f"Optimized kernel failed tests: {stderr[:200]}" + ) + error_feedback = f"""Previous optimization attempt FAILED with error: +{stderr[:500]} + +The kernel must: +1. Pass all correctness tests +2. Maintain the same interface as the original kernel +3. Be syntactically valid Python/Triton code +""" + continue + + # Step 5: Passed tests! Extract performance metrics + self.logger.info("✅ Optimized kernel passed tests!") + perf_metrics = self._extract_performance_metrics(stdout) + + if perf_metrics: + speedup = perf_metrics.get("speedup", 0) + self.logger.info(f"Performance metrics: {perf_metrics}") + + # Update best if this is better + if best_perf is None or speedup > best_perf.get("speedup", 0): + best_kernel = optimized_kernel + best_perf = perf_metrics + self.logger.info(f"🎉 New best speedup: {speedup:.2f}x") + error_feedback = "" # Clear error for next iteration + else: + self.logger.info( + f"Speedup {speedup:.2f}x not better than best {best_perf.get('speedup', 0):.2f}x" + ) + else: + # No metrics available, accept first working optimization + self.logger.info( + "No performance metrics found, accepting optimized kernel" + ) + best_kernel = optimized_kernel + break + + # After all rounds, restore original kernel file + self._write_kernel(kernel_code) + + # Return best result + if best_kernel != kernel_code: + self.logger.info("✅ Optimization successful!") + if best_perf: + self.logger.info( + f"Final speedup: {best_perf.get('speedup', 'N/A'):.2f}x" + ) + return True, best_kernel + else: + self.logger.info("No improvement found, keeping original") + return False, kernel_code + + except Exception as e: + self.logger.error(f"Optimization failed: {e}") + import traceback + + self.logger.error(traceback.format_exc()) + return False, kernel_code + + def _extract_performance_metrics(self, stdout: str) -> Optional[Dict[str, float]]: + """ + Extract performance metrics from test output. + Looks for: PERF_METRICS:{"triton_ms": X, "pytorch_ms": Y, "speedup": Z} + """ + try: + match = re.search(r"PERF_METRICS:(\{[^}]+\})", stdout) + if match: + return json.loads(match.group(1)) + except Exception as e: + self.logger.warning(f"Failed to extract metrics: {e}") + return None + def _log_round( self, round_num: int, success: bool, kernel_code: str, stdout: str, stderr: str ): @@ -322,6 +531,7 @@ def run( current_kernel = kernel_code + # PHASE 1: Generation & Correctness (existing code) for round_num in range(self.max_rounds): # Check if another worker has succeeded if success_event.is_set(): @@ -353,6 +563,32 @@ def run( self.logger.info( f"Success! Kernel passed test in round {round_num + 1}" ) + + # PHASE 2: Optimization if enabled + if self.enable_optimization: + self.logger.info("Entering optimization phase...") + opt_success, optimized_kernel = self._optimize_kernel( + kernel_code=current_kernel, + problem_description=problem_description, + test_code=test_code, + additional_code=additional_code, + ) + + if opt_success: + current_kernel = optimized_kernel + self.logger.info("Using optimized kernel") + else: + self.logger.info("Using original working kernel") + + return { + "worker_id": self.worker_id, + "success": True, + "kernel_code": current_kernel, + "rounds": round_num + 1, + "optimized": self.enable_optimization and opt_success, # NEW + "history": list(self.history), + } + return { "worker_id": self.worker_id, "success": True,