From b69b35abab28ab0a8afcd072447ee5b92087ac04 Mon Sep 17 00:00:00 2001 From: m96-chan Date: Tue, 30 Dec 2025 15:33:03 +0900 Subject: [PATCH] refactor(matmul): reorganize kernel directory structure (#122) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements Option 2 from Issue #122: explicit naming convention w{weight}a{act}_{out}/ for clearer kernel identification. Directory mapping: - gemm/fp8/bf16/ -> gemm/w8a16_bf16/ (FP8 weight, BF16 activation) - gemm/fp8/fp8/ -> gemm/w8a8_bf16/ (pure FP8) - gemm/nvf4/bf16/ -> gemm/w4a16_bf16/ (NVF4 weight) - gemv/bf16/bf16/nvf4* -> gemv/w4a16_bf16/ - gemv/bf16/bf16/fp8* -> gemv/w8a16_bf16/ - gemv/fp8/fp8/ -> gemv/w8a8_bf16/ - gemv/nvf4/nvf4/ -> gemv/w4a4_bf16/ Updated: - CMakeLists.txt with new paths - Include paths in source files - CLAUDE.md with new naming convention docs Build verified: SM 120a CUDA 13.1 SUCCESS 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- CLAUDE.md | 35 ++++++++------ native/CMakeLists.txt | 48 +++++++++---------- .../bf16 => bf16_bf16}/generic/bf16_naive.cuh | 2 +- .../bf16 => bf16_bf16}/generic/bf16_wmma.cuh | 2 +- .../generic/bf16_wmma_generic.cuh | 2 +- .../bf16 => bf16_bf16}/sm100/bf16_cutlass.cuh | 0 .../bf16 => bf16_bf16}/sm120/bf16_cutlass.cuh | 0 .../bf16 => bf16_bf16}/sm80/bf16_cutlass.cuh | 0 .../bf16 => bf16_bf16}/sm90/bf16_cutlass.cuh | 0 .../f32 => f32_f32}/generic/f32_ampere.cu | 0 .../f32 => f32_f32}/generic/f32_ampere.cuh | 2 +- .../f32 => f32_f32}/generic/f32_naive.cuh | 2 +- .../{f32/f32 => f32_f32}/generic/tf32_mma.cuh | 2 +- .../f32 => f32_f32}/generic/tf32_wmma.cuh | 2 +- .../int4 => int4_int4}/sm120/int4_via_int8.cu | 2 +- .../int8 => int8_int8}/sm120/int8_native.cu | 0 .../bf16 => w4a16_bf16}/sm120/nvf4_cutlass.cu | 0 .../sm120/nvf4_nvf4_cutlass.cu | 0 .../sm120/fp8_blockwise.cu | 2 +- .../bf16 => w8a16_bf16}/sm120/grouped_gemm.cu | 0 .../sm120/w8a16_cutlass.cu | 2 +- .../bf16 => w8a16_bf16}/sm120/w8a16_gemm.cu | 0 .../fp8 => w8a8_bf16}/sm120/fp8_cutlass.cu | 2 +- .../fp8 => w8a8_bf16}/sm120/fp8_cutlass_v2.cu | 2 +- .../fp8 => w8a8_bf16}/sm120/fp8_cutlass_v3.cu | 2 +- .../f32 => w8a8_f32}/sm100/fp8_blockwise.cu | 0 .../{fp8/f32 => w8a8_f32}/sm90/fp8_cutlass.cu | 0 .../generic/bf16_cutlass.cuh | 0 .../bf16 => bf16_bf16}/sm120/bf16_opt.cu | 0 .../bf16 => bf16_bf16}/sm120/bf16_opt.cuh | 0 .../int4 => int4_int4}/sm120/int4_gemv.cu | 0 .../int4 => int4_int4}/sm120/int4_gemv.cuh | 0 .../{bf16/bf16 => w4a16_bf16}/sm120/nvf4.cu | 0 .../{bf16/bf16 => w4a16_bf16}/sm120/nvf4.cuh | 0 .../bf16 => w4a16_bf16}/sm120/nvf4_kernels.cu | 0 .../nvf4 => w4a4_bf16}/sm120/nvf4_gemv.cu | 0 .../nvf4 => w4a4_bf16}/sm120/nvf4_gemv.cuh | 0 .../{bf16/bf16 => w8a16_bf16}/sm120/fp8.cuh | 0 .../bf16 => w8a16_bf16}/sm120/fp8_opt.cuh | 0 .../sm120/fp8_opt_kernels.cu | 0 .../fp8 => w8a8_bf16}/sm120/fp8_accurate.cu | 0 .../fp8 => w8a8_bf16}/sm120/fp8_accurate.cuh | 0 .../{fp8/fp8 => w8a8_bf16}/sm120/fp8_gemv.cu | 0 .../{fp8/fp8 => w8a8_bf16}/sm120/fp8_gemv.cuh | 0 native/ops/matmul/matmul.cu | 18 +++---- native/ops/matmul/matmul_cutlass.cu | 2 +- 46 files changed, 68 insertions(+), 61 deletions(-) rename native/ops/matmul/gemm/{bf16/bf16 => bf16_bf16}/generic/bf16_naive.cuh (98%) rename native/ops/matmul/gemm/{bf16/bf16 => bf16_bf16}/generic/bf16_wmma.cuh (99%) rename native/ops/matmul/gemm/{bf16/bf16 => bf16_bf16}/generic/bf16_wmma_generic.cuh (99%) rename native/ops/matmul/gemm/{bf16/bf16 => bf16_bf16}/sm100/bf16_cutlass.cuh (100%) rename native/ops/matmul/gemm/{bf16/bf16 => bf16_bf16}/sm120/bf16_cutlass.cuh (100%) rename native/ops/matmul/gemm/{bf16/bf16 => bf16_bf16}/sm80/bf16_cutlass.cuh (100%) rename native/ops/matmul/gemm/{bf16/bf16 => bf16_bf16}/sm90/bf16_cutlass.cuh (100%) rename native/ops/matmul/gemm/{f32/f32 => f32_f32}/generic/f32_ampere.cu (100%) rename native/ops/matmul/gemm/{f32/f32 => f32_f32}/generic/f32_ampere.cuh (99%) rename native/ops/matmul/gemm/{f32/f32 => f32_f32}/generic/f32_naive.cuh (99%) rename native/ops/matmul/gemm/{f32/f32 => f32_f32}/generic/tf32_mma.cuh (99%) rename native/ops/matmul/gemm/{f32/f32 => f32_f32}/generic/tf32_wmma.cuh (99%) rename native/ops/matmul/gemm/{int4/int4 => int4_int4}/sm120/int4_via_int8.cu (99%) rename native/ops/matmul/gemm/{int8/int8 => int8_int8}/sm120/int8_native.cu (100%) rename native/ops/matmul/gemm/{nvf4/bf16 => w4a16_bf16}/sm120/nvf4_cutlass.cu (100%) rename native/ops/matmul/gemm/{nvf4/bf16 => w4a16_bf16}/sm120/nvf4_nvf4_cutlass.cu (100%) rename native/ops/matmul/gemm/{fp8/bf16 => w8a16_bf16}/sm120/fp8_blockwise.cu (99%) rename native/ops/matmul/gemm/{fp8/bf16 => w8a16_bf16}/sm120/grouped_gemm.cu (100%) rename native/ops/matmul/gemm/{fp8/bf16 => w8a16_bf16}/sm120/w8a16_cutlass.cu (99%) rename native/ops/matmul/gemm/{fp8/bf16 => w8a16_bf16}/sm120/w8a16_gemm.cu (100%) rename native/ops/matmul/gemm/{fp8/fp8 => w8a8_bf16}/sm120/fp8_cutlass.cu (99%) rename native/ops/matmul/gemm/{fp8/fp8 => w8a8_bf16}/sm120/fp8_cutlass_v2.cu (99%) rename native/ops/matmul/gemm/{fp8/fp8 => w8a8_bf16}/sm120/fp8_cutlass_v3.cu (99%) rename native/ops/matmul/gemm/{fp8/f32 => w8a8_f32}/sm100/fp8_blockwise.cu (100%) rename native/ops/matmul/gemm/{fp8/f32 => w8a8_f32}/sm90/fp8_cutlass.cu (100%) rename native/ops/matmul/gemv/{bf16/bf16 => bf16_bf16}/generic/bf16_cutlass.cuh (100%) rename native/ops/matmul/gemv/{bf16/bf16 => bf16_bf16}/sm120/bf16_opt.cu (100%) rename native/ops/matmul/gemv/{bf16/bf16 => bf16_bf16}/sm120/bf16_opt.cuh (100%) rename native/ops/matmul/gemv/{int4/int4 => int4_int4}/sm120/int4_gemv.cu (100%) rename native/ops/matmul/gemv/{int4/int4 => int4_int4}/sm120/int4_gemv.cuh (100%) rename native/ops/matmul/gemv/{bf16/bf16 => w4a16_bf16}/sm120/nvf4.cu (100%) rename native/ops/matmul/gemv/{bf16/bf16 => w4a16_bf16}/sm120/nvf4.cuh (100%) rename native/ops/matmul/gemv/{bf16/bf16 => w4a16_bf16}/sm120/nvf4_kernels.cu (100%) rename native/ops/matmul/gemv/{nvf4/nvf4 => w4a4_bf16}/sm120/nvf4_gemv.cu (100%) rename native/ops/matmul/gemv/{nvf4/nvf4 => w4a4_bf16}/sm120/nvf4_gemv.cuh (100%) rename native/ops/matmul/gemv/{bf16/bf16 => w8a16_bf16}/sm120/fp8.cuh (100%) rename native/ops/matmul/gemv/{bf16/bf16 => w8a16_bf16}/sm120/fp8_opt.cuh (100%) rename native/ops/matmul/gemv/{bf16/bf16 => w8a16_bf16}/sm120/fp8_opt_kernels.cu (100%) rename native/ops/matmul/gemv/{fp8/fp8 => w8a8_bf16}/sm120/fp8_accurate.cu (100%) rename native/ops/matmul/gemv/{fp8/fp8 => w8a8_bf16}/sm120/fp8_accurate.cuh (100%) rename native/ops/matmul/gemv/{fp8/fp8 => w8a8_bf16}/sm120/fp8_gemv.cu (100%) rename native/ops/matmul/gemv/{fp8/fp8 => w8a8_bf16}/sm120/fp8_gemv.cuh (100%) diff --git a/CLAUDE.md b/CLAUDE.md index 5d8bbac..66e6722 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -73,31 +73,38 @@ native/ops/matmul/ ├── common/ # Shared utilities │ └── aligned_copy_sm120.cuh ├── gemm/ # GEMM kernels (M > 1) -│ └── {input_dtype}/{output_dtype}/{arch}/{compute}_{suffix}.{cu,cuh} +│ └── {w_dtype}_{a_dtype}_{out_dtype}/{arch}/{kernel}.{cu,cuh} ├── gemv/ # GEMV kernels (M = 1) -│ └── {input_dtype}/{output_dtype}/{arch}/{compute}_{suffix}.{cu,cuh} +│ └── {w_dtype}_{a_dtype}_{out_dtype}/{arch}/{kernel}.{cu,cuh} ├── cublaslt.cuh # cuBLASLt wrapper ├── matmul.cu # Main dispatcher └── matmul_cutlass.cu # CUTLASS dispatcher ``` -**Path Convention:** `{gemm|gemv}/{input_dtype}/{output_dtype}/{arch}/{compute}_{suffix}.cu` +**Path Convention:** `{gemm|gemv}/{w{weight}a{act}_{out}}/{arch}/{kernel}.cu` -| Component | Values | Examples | -|-----------|--------|----------| -| `input_dtype` | `f32`, `bf16`, `fp8`, `nvf4` | Input tensor dtype | -| `output_dtype` | `f32`, `bf16`, `fp8` | Output tensor dtype | +| Component | Values | Description | +|-----------|--------|-------------| +| `w_dtype` | `w4`, `w8`, `bf16`, `f32`, `int4`, `int8` | Weight dtype (w=weight) | +| `a_dtype` | `a4`, `a8`, `a16`, `bf16`, `f32`, `int4`, `int8` | Activation dtype (a=act) | +| `out_dtype` | `bf16`, `f32` | Output dtype | | `arch` | `generic`, `sm80`, `sm90`, `sm100`, `sm120` | Target architecture | -| `compute` | `naive`, `wmma`, `mma`, `cutlass` | Compute method | -| `suffix` | `blockwise`, `kernels`, etc. | Variant identifier | + +**Naming Rationale (Issue #122 Option 2):** +- `w8a16_bf16`: FP8 weights, BF16 activations, BF16 output (W8A16 GEMM) +- `w4a16_bf16`: NVF4 weights, BF16 activations, BF16 output (NVF4 GEMV) +- `w8a8_bf16`: FP8 weights, FP8 activations, BF16 output (pure FP8) +- `bf16_bf16`: BF16 weights, BF16 activations (no quantization) +- `f32_f32`: FP32 weights, FP32 activations (baseline) **Examples:** ``` -gemm/bf16/bf16/sm120/bf16_cutlass.cuh # BF16->BF16 GEMM, SM120, CUTLASS -gemm/fp8/f32/sm90/fp8_cutlass.cu # FP8->F32 GEMM, SM90, CUTLASS -gemm/nvf4/bf16/sm120/nvf4_cutlass.cu # NVF4->BF16 GEMM, SM120, CUTLASS -gemv/bf16/bf16/sm120/nvf4.cu # NVF4->BF16 GEMV, SM120 -gemm/f32/f32/generic/tf32_mma.cuh # TF32 GEMM, generic (SM80+) +gemm/bf16_bf16/sm80/bf16_cutlass.cuh # BF16 GEMM, SM80, CUTLASS +gemm/w8a8_f32/sm90/fp8_cutlass.cu # FP8->F32 GEMM, SM90, CUTLASS +gemm/w4a16_bf16/sm120/nvf4_cutlass.cu # NVF4 weights, BF16 act->BF16, SM120 +gemv/w4a16_bf16/sm120/nvf4.cu # NVF4 GEMV, SM120 +gemv/w8a16_bf16/sm120/fp8_opt_kernels.cu # FP8 weight, BF16 act GEMV, SM120 +gemm/f32_f32/generic/tf32_mma.cuh # TF32 GEMM, generic (SM80+) ``` ### Module Separation Policy diff --git a/native/CMakeLists.txt b/native/CMakeLists.txt index 6540f9f..ed0f604 100644 --- a/native/CMakeLists.txt +++ b/native/CMakeLists.txt @@ -153,30 +153,30 @@ pybind11_add_module(${MODULE_NAME} ops/reduction/reduction.cu ops/matmul/matmul.cu ops/matmul/matmul_cutlass.cu - # GEMM kernels - ops/matmul/gemm/f32/f32/generic/f32_ampere.cu - ops/matmul/gemm/fp8/f32/sm90/fp8_cutlass.cu - ops/matmul/gemm/fp8/f32/sm100/fp8_blockwise.cu - ops/matmul/gemm/fp8/bf16/sm120/fp8_blockwise.cu - ops/matmul/gemm/fp8/bf16/sm120/w8a16_gemm.cu - ops/matmul/gemm/fp8/bf16/sm120/w8a16_cutlass.cu - ops/matmul/gemm/fp8/bf16/sm120/grouped_gemm.cu - ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass.cu - ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v2.cu - ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v3.cu - ops/matmul/gemm/int8/int8/sm120/int8_native.cu - ops/matmul/gemm/int4/int4/sm120/int4_via_int8.cu - ops/matmul/gemm/nvf4/bf16/sm120/nvf4_cutlass.cu - ops/matmul/gemm/nvf4/bf16/sm120/nvf4_nvf4_cutlass.cu - # GEMV kernels - ops/matmul/gemv/bf16/bf16/sm120/nvf4.cu - ops/matmul/gemv/bf16/bf16/sm120/nvf4_kernels.cu - ops/matmul/gemv/bf16/bf16/sm120/fp8_opt_kernels.cu - ops/matmul/gemv/bf16/bf16/sm120/bf16_opt.cu - ops/matmul/gemv/fp8/fp8/sm120/fp8_gemv.cu - ops/matmul/gemv/fp8/fp8/sm120/fp8_accurate.cu - ops/matmul/gemv/nvf4/nvf4/sm120/nvf4_gemv.cu - ops/matmul/gemv/int4/int4/sm120/int4_gemv.cu + # GEMM kernels (Issue #122: Reorganized with w{weight}a{act}_{out} naming) + ops/matmul/gemm/f32_f32/generic/f32_ampere.cu + ops/matmul/gemm/w8a8_f32/sm90/fp8_cutlass.cu + ops/matmul/gemm/w8a8_f32/sm100/fp8_blockwise.cu + ops/matmul/gemm/w8a16_bf16/sm120/fp8_blockwise.cu + ops/matmul/gemm/w8a16_bf16/sm120/w8a16_gemm.cu + ops/matmul/gemm/w8a16_bf16/sm120/w8a16_cutlass.cu + ops/matmul/gemm/w8a16_bf16/sm120/grouped_gemm.cu + ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass.cu + ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v2.cu + ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v3.cu + ops/matmul/gemm/int8_int8/sm120/int8_native.cu + ops/matmul/gemm/int4_int4/sm120/int4_via_int8.cu + ops/matmul/gemm/w4a16_bf16/sm120/nvf4_cutlass.cu + ops/matmul/gemm/w4a16_bf16/sm120/nvf4_nvf4_cutlass.cu + # GEMV kernels (Issue #122: Reorganized with w{weight}a{act}_{out} naming) + ops/matmul/gemv/w4a16_bf16/sm120/nvf4.cu + ops/matmul/gemv/w4a16_bf16/sm120/nvf4_kernels.cu + ops/matmul/gemv/w8a16_bf16/sm120/fp8_opt_kernels.cu + ops/matmul/gemv/bf16_bf16/sm120/bf16_opt.cu + ops/matmul/gemv/w8a8_bf16/sm120/fp8_gemv.cu + ops/matmul/gemv/w8a8_bf16/sm120/fp8_accurate.cu + ops/matmul/gemv/w4a4_bf16/sm120/nvf4_gemv.cu + ops/matmul/gemv/int4_int4/sm120/int4_gemv.cu ops/nn/nn.cu ops/quantize/quantize.cu ops/attention/paged_attention.cu diff --git a/native/ops/matmul/gemm/bf16/bf16/generic/bf16_naive.cuh b/native/ops/matmul/gemm/bf16_bf16/generic/bf16_naive.cuh similarity index 98% rename from native/ops/matmul/gemm/bf16/bf16/generic/bf16_naive.cuh rename to native/ops/matmul/gemm/bf16_bf16/generic/bf16_naive.cuh index 7d59bfb..98d0be0 100644 --- a/native/ops/matmul/gemm/bf16/bf16/generic/bf16_naive.cuh +++ b/native/ops/matmul/gemm/bf16_bf16/generic/bf16_naive.cuh @@ -14,7 +14,7 @@ #include #include #include -#include "../../../../../../core/cuda_graph.hpp" +#include "../../../../../core/cuda_graph.hpp" namespace pygpukit { namespace ops { diff --git a/native/ops/matmul/gemm/bf16/bf16/generic/bf16_wmma.cuh b/native/ops/matmul/gemm/bf16_bf16/generic/bf16_wmma.cuh similarity index 99% rename from native/ops/matmul/gemm/bf16/bf16/generic/bf16_wmma.cuh rename to native/ops/matmul/gemm/bf16_bf16/generic/bf16_wmma.cuh index 4324778..ab8ecbe 100644 --- a/native/ops/matmul/gemm/bf16/bf16/generic/bf16_wmma.cuh +++ b/native/ops/matmul/gemm/bf16_bf16/generic/bf16_wmma.cuh @@ -14,7 +14,7 @@ #include #include #include -#include "../../../../../../core/cuda_graph.hpp" +#include "../../../../../core/cuda_graph.hpp" namespace pygpukit { namespace ops { diff --git a/native/ops/matmul/gemm/bf16/bf16/generic/bf16_wmma_generic.cuh b/native/ops/matmul/gemm/bf16_bf16/generic/bf16_wmma_generic.cuh similarity index 99% rename from native/ops/matmul/gemm/bf16/bf16/generic/bf16_wmma_generic.cuh rename to native/ops/matmul/gemm/bf16_bf16/generic/bf16_wmma_generic.cuh index 98b68fa..bcf89cc 100644 --- a/native/ops/matmul/gemm/bf16/bf16/generic/bf16_wmma_generic.cuh +++ b/native/ops/matmul/gemm/bf16_bf16/generic/bf16_wmma_generic.cuh @@ -12,7 +12,7 @@ #include #include #include -#include "../../../../../../core/cuda_graph.hpp" +#include "../../../../../core/cuda_graph.hpp" namespace pygpukit { namespace ops { diff --git a/native/ops/matmul/gemm/bf16/bf16/sm100/bf16_cutlass.cuh b/native/ops/matmul/gemm/bf16_bf16/sm100/bf16_cutlass.cuh similarity index 100% rename from native/ops/matmul/gemm/bf16/bf16/sm100/bf16_cutlass.cuh rename to native/ops/matmul/gemm/bf16_bf16/sm100/bf16_cutlass.cuh diff --git a/native/ops/matmul/gemm/bf16/bf16/sm120/bf16_cutlass.cuh b/native/ops/matmul/gemm/bf16_bf16/sm120/bf16_cutlass.cuh similarity index 100% rename from native/ops/matmul/gemm/bf16/bf16/sm120/bf16_cutlass.cuh rename to native/ops/matmul/gemm/bf16_bf16/sm120/bf16_cutlass.cuh diff --git a/native/ops/matmul/gemm/bf16/bf16/sm80/bf16_cutlass.cuh b/native/ops/matmul/gemm/bf16_bf16/sm80/bf16_cutlass.cuh similarity index 100% rename from native/ops/matmul/gemm/bf16/bf16/sm80/bf16_cutlass.cuh rename to native/ops/matmul/gemm/bf16_bf16/sm80/bf16_cutlass.cuh diff --git a/native/ops/matmul/gemm/bf16/bf16/sm90/bf16_cutlass.cuh b/native/ops/matmul/gemm/bf16_bf16/sm90/bf16_cutlass.cuh similarity index 100% rename from native/ops/matmul/gemm/bf16/bf16/sm90/bf16_cutlass.cuh rename to native/ops/matmul/gemm/bf16_bf16/sm90/bf16_cutlass.cuh diff --git a/native/ops/matmul/gemm/f32/f32/generic/f32_ampere.cu b/native/ops/matmul/gemm/f32_f32/generic/f32_ampere.cu similarity index 100% rename from native/ops/matmul/gemm/f32/f32/generic/f32_ampere.cu rename to native/ops/matmul/gemm/f32_f32/generic/f32_ampere.cu diff --git a/native/ops/matmul/gemm/f32/f32/generic/f32_ampere.cuh b/native/ops/matmul/gemm/f32_f32/generic/f32_ampere.cuh similarity index 99% rename from native/ops/matmul/gemm/f32/f32/generic/f32_ampere.cuh rename to native/ops/matmul/gemm/f32_f32/generic/f32_ampere.cuh index 2a6586c..45c9f3a 100644 --- a/native/ops/matmul/gemm/f32/f32/generic/f32_ampere.cuh +++ b/native/ops/matmul/gemm/f32_f32/generic/f32_ampere.cuh @@ -18,7 +18,7 @@ #include #include -#include "../../../../../../core/cuda_graph.hpp" +#include "../../../../../core/cuda_graph.hpp" namespace pygpukit { namespace ops { diff --git a/native/ops/matmul/gemm/f32/f32/generic/f32_naive.cuh b/native/ops/matmul/gemm/f32_f32/generic/f32_naive.cuh similarity index 99% rename from native/ops/matmul/gemm/f32/f32/generic/f32_naive.cuh rename to native/ops/matmul/gemm/f32_f32/generic/f32_naive.cuh index d8afc27..5065ba5 100644 --- a/native/ops/matmul/gemm/f32/f32/generic/f32_naive.cuh +++ b/native/ops/matmul/gemm/f32_f32/generic/f32_naive.cuh @@ -10,7 +10,7 @@ #include #include -#include "../../../../../../core/cuda_graph.hpp" +#include "../../../../../core/cuda_graph.hpp" namespace pygpukit { namespace ops { diff --git a/native/ops/matmul/gemm/f32/f32/generic/tf32_mma.cuh b/native/ops/matmul/gemm/f32_f32/generic/tf32_mma.cuh similarity index 99% rename from native/ops/matmul/gemm/f32/f32/generic/tf32_mma.cuh rename to native/ops/matmul/gemm/f32_f32/generic/tf32_mma.cuh index ace60ac..8df7ded 100644 --- a/native/ops/matmul/gemm/f32/f32/generic/tf32_mma.cuh +++ b/native/ops/matmul/gemm/f32_f32/generic/tf32_mma.cuh @@ -11,7 +11,7 @@ #pragma once #include #include -#include "../../../../../../core/cuda_graph.hpp" +#include "../../../../../core/cuda_graph.hpp" namespace pygpukit { namespace ops { diff --git a/native/ops/matmul/gemm/f32/f32/generic/tf32_wmma.cuh b/native/ops/matmul/gemm/f32_f32/generic/tf32_wmma.cuh similarity index 99% rename from native/ops/matmul/gemm/f32/f32/generic/tf32_wmma.cuh rename to native/ops/matmul/gemm/f32_f32/generic/tf32_wmma.cuh index 15050b3..c505880 100644 --- a/native/ops/matmul/gemm/f32/f32/generic/tf32_wmma.cuh +++ b/native/ops/matmul/gemm/f32_f32/generic/tf32_wmma.cuh @@ -1,7 +1,7 @@ #pragma once #include #include -#include "../../../../../../core/cuda_graph.hpp" +#include "../../../../../core/cuda_graph.hpp" namespace pygpukit { namespace ops { diff --git a/native/ops/matmul/gemm/int4/int4/sm120/int4_via_int8.cu b/native/ops/matmul/gemm/int4_int4/sm120/int4_via_int8.cu similarity index 99% rename from native/ops/matmul/gemm/int4/int4/sm120/int4_via_int8.cu rename to native/ops/matmul/gemm/int4_int4/sm120/int4_via_int8.cu index 08a5cf1..3ce517c 100644 --- a/native/ops/matmul/gemm/int4/int4/sm120/int4_via_int8.cu +++ b/native/ops/matmul/gemm/int4_int4/sm120/int4_via_int8.cu @@ -37,7 +37,7 @@ #include "cutlass/util/device_memory.h" #define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1 -#include "../../../../common/aligned_copy_sm120.cuh" +#include "../../../common/aligned_copy_sm120.cuh" using namespace cute; diff --git a/native/ops/matmul/gemm/int8/int8/sm120/int8_native.cu b/native/ops/matmul/gemm/int8_int8/sm120/int8_native.cu similarity index 100% rename from native/ops/matmul/gemm/int8/int8/sm120/int8_native.cu rename to native/ops/matmul/gemm/int8_int8/sm120/int8_native.cu diff --git a/native/ops/matmul/gemm/nvf4/bf16/sm120/nvf4_cutlass.cu b/native/ops/matmul/gemm/w4a16_bf16/sm120/nvf4_cutlass.cu similarity index 100% rename from native/ops/matmul/gemm/nvf4/bf16/sm120/nvf4_cutlass.cu rename to native/ops/matmul/gemm/w4a16_bf16/sm120/nvf4_cutlass.cu diff --git a/native/ops/matmul/gemm/nvf4/bf16/sm120/nvf4_nvf4_cutlass.cu b/native/ops/matmul/gemm/w4a16_bf16/sm120/nvf4_nvf4_cutlass.cu similarity index 100% rename from native/ops/matmul/gemm/nvf4/bf16/sm120/nvf4_nvf4_cutlass.cu rename to native/ops/matmul/gemm/w4a16_bf16/sm120/nvf4_nvf4_cutlass.cu diff --git a/native/ops/matmul/gemm/fp8/bf16/sm120/fp8_blockwise.cu b/native/ops/matmul/gemm/w8a16_bf16/sm120/fp8_blockwise.cu similarity index 99% rename from native/ops/matmul/gemm/fp8/bf16/sm120/fp8_blockwise.cu rename to native/ops/matmul/gemm/w8a16_bf16/sm120/fp8_blockwise.cu index c612aba..84f1cdd 100644 --- a/native/ops/matmul/gemm/fp8/bf16/sm120/fp8_blockwise.cu +++ b/native/ops/matmul/gemm/w8a16_bf16/sm120/fp8_blockwise.cu @@ -48,7 +48,7 @@ // Provides alignment-safe LDSM operations for Issue #2902 workaround // ============================================================================ #define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1 -#include "../../../../common/aligned_copy_sm120.cuh" +#include "../../../common/aligned_copy_sm120.cuh" using namespace cute; diff --git a/native/ops/matmul/gemm/fp8/bf16/sm120/grouped_gemm.cu b/native/ops/matmul/gemm/w8a16_bf16/sm120/grouped_gemm.cu similarity index 100% rename from native/ops/matmul/gemm/fp8/bf16/sm120/grouped_gemm.cu rename to native/ops/matmul/gemm/w8a16_bf16/sm120/grouped_gemm.cu diff --git a/native/ops/matmul/gemm/fp8/bf16/sm120/w8a16_cutlass.cu b/native/ops/matmul/gemm/w8a16_bf16/sm120/w8a16_cutlass.cu similarity index 99% rename from native/ops/matmul/gemm/fp8/bf16/sm120/w8a16_cutlass.cu rename to native/ops/matmul/gemm/w8a16_bf16/sm120/w8a16_cutlass.cu index 4c9fb5a..5015663 100644 --- a/native/ops/matmul/gemm/fp8/bf16/sm120/w8a16_cutlass.cu +++ b/native/ops/matmul/gemm/w8a16_bf16/sm120/w8a16_cutlass.cu @@ -36,7 +36,7 @@ #include "cutlass/util/device_memory.h" #define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1 -#include "../../../../common/aligned_copy_sm120.cuh" +#include "../../../common/aligned_copy_sm120.cuh" using namespace cute; diff --git a/native/ops/matmul/gemm/fp8/bf16/sm120/w8a16_gemm.cu b/native/ops/matmul/gemm/w8a16_bf16/sm120/w8a16_gemm.cu similarity index 100% rename from native/ops/matmul/gemm/fp8/bf16/sm120/w8a16_gemm.cu rename to native/ops/matmul/gemm/w8a16_bf16/sm120/w8a16_gemm.cu diff --git a/native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass.cu b/native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass.cu similarity index 99% rename from native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass.cu rename to native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass.cu index 360a28e..99e07f6 100644 --- a/native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass.cu +++ b/native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass.cu @@ -38,7 +38,7 @@ // Alignment patch for Issue #2902 workaround #define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1 -#include "../../../../common/aligned_copy_sm120.cuh" +#include "../../../common/aligned_copy_sm120.cuh" using namespace cute; diff --git a/native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v2.cu b/native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v2.cu similarity index 99% rename from native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v2.cu rename to native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v2.cu index 261e2e1..06165e0 100644 --- a/native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v2.cu +++ b/native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v2.cu @@ -24,7 +24,7 @@ #include "cutlass/util/device_memory.h" #define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1 -#include "../../../../common/aligned_copy_sm120.cuh" +#include "../../../common/aligned_copy_sm120.cuh" using namespace cute; diff --git a/native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v3.cu b/native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v3.cu similarity index 99% rename from native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v3.cu rename to native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v3.cu index bd519f1..2775c0f 100644 --- a/native/ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v3.cu +++ b/native/ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v3.cu @@ -31,7 +31,7 @@ #include "cutlass/util/device_memory.h" #define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1 -#include "../../../../common/aligned_copy_sm120.cuh" +#include "../../../common/aligned_copy_sm120.cuh" using namespace cute; diff --git a/native/ops/matmul/gemm/fp8/f32/sm100/fp8_blockwise.cu b/native/ops/matmul/gemm/w8a8_f32/sm100/fp8_blockwise.cu similarity index 100% rename from native/ops/matmul/gemm/fp8/f32/sm100/fp8_blockwise.cu rename to native/ops/matmul/gemm/w8a8_f32/sm100/fp8_blockwise.cu diff --git a/native/ops/matmul/gemm/fp8/f32/sm90/fp8_cutlass.cu b/native/ops/matmul/gemm/w8a8_f32/sm90/fp8_cutlass.cu similarity index 100% rename from native/ops/matmul/gemm/fp8/f32/sm90/fp8_cutlass.cu rename to native/ops/matmul/gemm/w8a8_f32/sm90/fp8_cutlass.cu diff --git a/native/ops/matmul/gemv/bf16/bf16/generic/bf16_cutlass.cuh b/native/ops/matmul/gemv/bf16_bf16/generic/bf16_cutlass.cuh similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/generic/bf16_cutlass.cuh rename to native/ops/matmul/gemv/bf16_bf16/generic/bf16_cutlass.cuh diff --git a/native/ops/matmul/gemv/bf16/bf16/sm120/bf16_opt.cu b/native/ops/matmul/gemv/bf16_bf16/sm120/bf16_opt.cu similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/sm120/bf16_opt.cu rename to native/ops/matmul/gemv/bf16_bf16/sm120/bf16_opt.cu diff --git a/native/ops/matmul/gemv/bf16/bf16/sm120/bf16_opt.cuh b/native/ops/matmul/gemv/bf16_bf16/sm120/bf16_opt.cuh similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/sm120/bf16_opt.cuh rename to native/ops/matmul/gemv/bf16_bf16/sm120/bf16_opt.cuh diff --git a/native/ops/matmul/gemv/int4/int4/sm120/int4_gemv.cu b/native/ops/matmul/gemv/int4_int4/sm120/int4_gemv.cu similarity index 100% rename from native/ops/matmul/gemv/int4/int4/sm120/int4_gemv.cu rename to native/ops/matmul/gemv/int4_int4/sm120/int4_gemv.cu diff --git a/native/ops/matmul/gemv/int4/int4/sm120/int4_gemv.cuh b/native/ops/matmul/gemv/int4_int4/sm120/int4_gemv.cuh similarity index 100% rename from native/ops/matmul/gemv/int4/int4/sm120/int4_gemv.cuh rename to native/ops/matmul/gemv/int4_int4/sm120/int4_gemv.cuh diff --git a/native/ops/matmul/gemv/bf16/bf16/sm120/nvf4.cu b/native/ops/matmul/gemv/w4a16_bf16/sm120/nvf4.cu similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/sm120/nvf4.cu rename to native/ops/matmul/gemv/w4a16_bf16/sm120/nvf4.cu diff --git a/native/ops/matmul/gemv/bf16/bf16/sm120/nvf4.cuh b/native/ops/matmul/gemv/w4a16_bf16/sm120/nvf4.cuh similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/sm120/nvf4.cuh rename to native/ops/matmul/gemv/w4a16_bf16/sm120/nvf4.cuh diff --git a/native/ops/matmul/gemv/bf16/bf16/sm120/nvf4_kernels.cu b/native/ops/matmul/gemv/w4a16_bf16/sm120/nvf4_kernels.cu similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/sm120/nvf4_kernels.cu rename to native/ops/matmul/gemv/w4a16_bf16/sm120/nvf4_kernels.cu diff --git a/native/ops/matmul/gemv/nvf4/nvf4/sm120/nvf4_gemv.cu b/native/ops/matmul/gemv/w4a4_bf16/sm120/nvf4_gemv.cu similarity index 100% rename from native/ops/matmul/gemv/nvf4/nvf4/sm120/nvf4_gemv.cu rename to native/ops/matmul/gemv/w4a4_bf16/sm120/nvf4_gemv.cu diff --git a/native/ops/matmul/gemv/nvf4/nvf4/sm120/nvf4_gemv.cuh b/native/ops/matmul/gemv/w4a4_bf16/sm120/nvf4_gemv.cuh similarity index 100% rename from native/ops/matmul/gemv/nvf4/nvf4/sm120/nvf4_gemv.cuh rename to native/ops/matmul/gemv/w4a4_bf16/sm120/nvf4_gemv.cuh diff --git a/native/ops/matmul/gemv/bf16/bf16/sm120/fp8.cuh b/native/ops/matmul/gemv/w8a16_bf16/sm120/fp8.cuh similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/sm120/fp8.cuh rename to native/ops/matmul/gemv/w8a16_bf16/sm120/fp8.cuh diff --git a/native/ops/matmul/gemv/bf16/bf16/sm120/fp8_opt.cuh b/native/ops/matmul/gemv/w8a16_bf16/sm120/fp8_opt.cuh similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/sm120/fp8_opt.cuh rename to native/ops/matmul/gemv/w8a16_bf16/sm120/fp8_opt.cuh diff --git a/native/ops/matmul/gemv/bf16/bf16/sm120/fp8_opt_kernels.cu b/native/ops/matmul/gemv/w8a16_bf16/sm120/fp8_opt_kernels.cu similarity index 100% rename from native/ops/matmul/gemv/bf16/bf16/sm120/fp8_opt_kernels.cu rename to native/ops/matmul/gemv/w8a16_bf16/sm120/fp8_opt_kernels.cu diff --git a/native/ops/matmul/gemv/fp8/fp8/sm120/fp8_accurate.cu b/native/ops/matmul/gemv/w8a8_bf16/sm120/fp8_accurate.cu similarity index 100% rename from native/ops/matmul/gemv/fp8/fp8/sm120/fp8_accurate.cu rename to native/ops/matmul/gemv/w8a8_bf16/sm120/fp8_accurate.cu diff --git a/native/ops/matmul/gemv/fp8/fp8/sm120/fp8_accurate.cuh b/native/ops/matmul/gemv/w8a8_bf16/sm120/fp8_accurate.cuh similarity index 100% rename from native/ops/matmul/gemv/fp8/fp8/sm120/fp8_accurate.cuh rename to native/ops/matmul/gemv/w8a8_bf16/sm120/fp8_accurate.cuh diff --git a/native/ops/matmul/gemv/fp8/fp8/sm120/fp8_gemv.cu b/native/ops/matmul/gemv/w8a8_bf16/sm120/fp8_gemv.cu similarity index 100% rename from native/ops/matmul/gemv/fp8/fp8/sm120/fp8_gemv.cu rename to native/ops/matmul/gemv/w8a8_bf16/sm120/fp8_gemv.cu diff --git a/native/ops/matmul/gemv/fp8/fp8/sm120/fp8_gemv.cuh b/native/ops/matmul/gemv/w8a8_bf16/sm120/fp8_gemv.cuh similarity index 100% rename from native/ops/matmul/gemv/fp8/fp8/sm120/fp8_gemv.cuh rename to native/ops/matmul/gemv/w8a8_bf16/sm120/fp8_gemv.cuh diff --git a/native/ops/matmul/matmul.cu b/native/ops/matmul/matmul.cu index 17631ae..077111f 100644 --- a/native/ops/matmul/matmul.cu +++ b/native/ops/matmul/matmul.cu @@ -1,22 +1,22 @@ /** * Matrix multiplication dispatch */ -#include "gemm/f32/f32/generic/f32_naive.cuh" +#include "gemm/f32_f32/generic/f32_naive.cuh" #include "../common/error.cuh" #include "../common/device.cuh" #include "../../core/memory.hpp" #include "../../core/cuda_graph.hpp" #include "../ops.cuh" // For transpose() -// Include existing optimized kernels -#include "gemm/f32/f32/generic/f32_ampere.cuh" -#include "gemm/f32/f32/generic/tf32_wmma.cuh" -#include "gemm/f32/f32/generic/tf32_mma.cuh" -#include "gemm/bf16/bf16/generic/bf16_naive.cuh" -#include "gemm/bf16/bf16/generic/bf16_wmma.cuh" -#include "gemm/bf16/bf16/generic/bf16_wmma_generic.cuh" +// Include existing optimized kernels (Issue #122: Updated paths) +#include "gemm/f32_f32/generic/f32_ampere.cuh" +#include "gemm/f32_f32/generic/tf32_wmma.cuh" +#include "gemm/f32_f32/generic/tf32_mma.cuh" +#include "gemm/bf16_bf16/generic/bf16_naive.cuh" +#include "gemm/bf16_bf16/generic/bf16_wmma.cuh" +#include "gemm/bf16_bf16/generic/bf16_wmma_generic.cuh" #include "cublaslt.cuh" -#include "gemm/bf16/bf16/sm80/bf16_cutlass.cuh" +#include "gemm/bf16_bf16/sm80/bf16_cutlass.cuh" #include #include diff --git a/native/ops/matmul/matmul_cutlass.cu b/native/ops/matmul/matmul_cutlass.cu index 56d7660..e50a52c 100644 --- a/native/ops/matmul/matmul_cutlass.cu +++ b/native/ops/matmul/matmul_cutlass.cu @@ -11,7 +11,7 @@ #if PYGPUKIT_HAS_CUTLASS -#include "gemm/bf16/bf16/sm80/bf16_cutlass.cuh" +#include "gemm/bf16_bf16/sm80/bf16_cutlass.cuh" namespace pygpukit { namespace ops {