Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,31 +73,38 @@ native/ops/matmul/
├── common/ # Shared utilities
│ └── aligned_copy_sm120.cuh
├── gemm/ # GEMM kernels (M > 1)
│ └── {input_dtype}/{output_dtype}/{arch}/{compute}_{suffix}.{cu,cuh}
│ └── {w_dtype}_{a_dtype}_{out_dtype}/{arch}/{kernel}.{cu,cuh}
├── gemv/ # GEMV kernels (M = 1)
│ └── {input_dtype}/{output_dtype}/{arch}/{compute}_{suffix}.{cu,cuh}
│ └── {w_dtype}_{a_dtype}_{out_dtype}/{arch}/{kernel}.{cu,cuh}
├── cublaslt.cuh # cuBLASLt wrapper
├── matmul.cu # Main dispatcher
└── matmul_cutlass.cu # CUTLASS dispatcher
```

**Path Convention:** `{gemm|gemv}/{input_dtype}/{output_dtype}/{arch}/{compute}_{suffix}.cu`
**Path Convention:** `{gemm|gemv}/w{weight}a{act}_{out}/{arch}/{kernel}.cu` for quantized kernels (e.g. `w8a16_bf16`); full-precision kernels use `{w_dtype}_{a_dtype}` instead (e.g. `bf16_bf16`, `f32_f32`)

| Component | Values | Examples |
|-----------|--------|----------|
| `input_dtype` | `f32`, `bf16`, `fp8`, `nvf4` | Input tensor dtype |
| `output_dtype` | `f32`, `bf16`, `fp8` | Output tensor dtype |
| Component | Values | Description |
|-----------|--------|-------------|
| `w_dtype` | `w4`, `w8`, `bf16`, `f32`, `int4`, `int8` | Weight dtype (w=weight) |
| `a_dtype` | `a4`, `a8`, `a16`, `bf16`, `f32`, `int4`, `int8` | Activation dtype (a=act) |
| `out_dtype` | `bf16`, `f32` | Output dtype |
| `arch` | `generic`, `sm80`, `sm90`, `sm100`, `sm120` | Target architecture |
| `compute` | `naive`, `wmma`, `mma`, `cutlass` | Compute method |
| `suffix` | `blockwise`, `kernels`, etc. | Variant identifier |

**Naming Rationale (Issue #122 Option 2):**
- `w8a16_bf16`: FP8 weights, BF16 activations, BF16 output (W8A16 GEMM)
- `w4a16_bf16`: NVF4 weights, BF16 activations, BF16 output (NVF4 GEMV)
- `w8a8_bf16`: FP8 weights, FP8 activations, BF16 output (pure FP8)
- `bf16_bf16`: BF16 weights, BF16 activations (no quantization)
- `f32_f32`: FP32 weights, FP32 activations (baseline)

**Examples:**
```
gemm/bf16/bf16/sm120/bf16_cutlass.cuh # BF16->BF16 GEMM, SM120, CUTLASS
gemm/fp8/f32/sm90/fp8_cutlass.cu # FP8->F32 GEMM, SM90, CUTLASS
gemm/nvf4/bf16/sm120/nvf4_cutlass.cu # NVF4->BF16 GEMM, SM120, CUTLASS
gemv/bf16/bf16/sm120/nvf4.cu # NVF4->BF16 GEMV, SM120
gemm/f32/f32/generic/tf32_mma.cuh # TF32 GEMM, generic (SM80+)
gemm/bf16_bf16/sm80/bf16_cutlass.cuh # BF16 GEMM, SM80, CUTLASS
gemm/w8a8_f32/sm90/fp8_cutlass.cu # FP8->F32 GEMM, SM90, CUTLASS
gemm/w4a16_bf16/sm120/nvf4_cutlass.cu # NVF4 weights, BF16 act->BF16, SM120
gemv/w4a16_bf16/sm120/nvf4.cu # NVF4 GEMV, SM120
gemv/w8a16_bf16/sm120/fp8_opt_kernels.cu # FP8 weight, BF16 act GEMV, SM120
gemm/f32_f32/generic/tf32_mma.cuh # TF32 GEMM, generic (SM80+)
```

### Module Separation Policy
Expand Down
48 changes: 24 additions & 24 deletions native/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -153,30 +153,30 @@ pybind11_add_module(${MODULE_NAME}
ops/reduction/reduction.cu
ops/matmul/matmul.cu
ops/matmul/matmul_cutlass.cu
# GEMM kernels
ops/matmul/gemm/f32/f32/generic/f32_ampere.cu
ops/matmul/gemm/fp8/f32/sm90/fp8_cutlass.cu
ops/matmul/gemm/fp8/f32/sm100/fp8_blockwise.cu
ops/matmul/gemm/fp8/bf16/sm120/fp8_blockwise.cu
ops/matmul/gemm/fp8/bf16/sm120/w8a16_gemm.cu
ops/matmul/gemm/fp8/bf16/sm120/w8a16_cutlass.cu
ops/matmul/gemm/fp8/bf16/sm120/grouped_gemm.cu
ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass.cu
ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v2.cu
ops/matmul/gemm/fp8/fp8/sm120/fp8_cutlass_v3.cu
ops/matmul/gemm/int8/int8/sm120/int8_native.cu
ops/matmul/gemm/int4/int4/sm120/int4_via_int8.cu
ops/matmul/gemm/nvf4/bf16/sm120/nvf4_cutlass.cu
ops/matmul/gemm/nvf4/bf16/sm120/nvf4_nvf4_cutlass.cu
# GEMV kernels
ops/matmul/gemv/bf16/bf16/sm120/nvf4.cu
ops/matmul/gemv/bf16/bf16/sm120/nvf4_kernels.cu
ops/matmul/gemv/bf16/bf16/sm120/fp8_opt_kernels.cu
ops/matmul/gemv/bf16/bf16/sm120/bf16_opt.cu
ops/matmul/gemv/fp8/fp8/sm120/fp8_gemv.cu
ops/matmul/gemv/fp8/fp8/sm120/fp8_accurate.cu
ops/matmul/gemv/nvf4/nvf4/sm120/nvf4_gemv.cu
ops/matmul/gemv/int4/int4/sm120/int4_gemv.cu
# GEMM kernels (Issue #122: Reorganized with w{weight}a{act}_{out} naming)
ops/matmul/gemm/f32_f32/generic/f32_ampere.cu
ops/matmul/gemm/w8a8_f32/sm90/fp8_cutlass.cu
ops/matmul/gemm/w8a8_f32/sm100/fp8_blockwise.cu
ops/matmul/gemm/w8a16_bf16/sm120/fp8_blockwise.cu
ops/matmul/gemm/w8a16_bf16/sm120/w8a16_gemm.cu
ops/matmul/gemm/w8a16_bf16/sm120/w8a16_cutlass.cu
ops/matmul/gemm/w8a16_bf16/sm120/grouped_gemm.cu
ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass.cu
ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v2.cu
ops/matmul/gemm/w8a8_bf16/sm120/fp8_cutlass_v3.cu
ops/matmul/gemm/int8_int8/sm120/int8_native.cu
ops/matmul/gemm/int4_int4/sm120/int4_via_int8.cu
ops/matmul/gemm/w4a16_bf16/sm120/nvf4_cutlass.cu
ops/matmul/gemm/w4a16_bf16/sm120/nvf4_nvf4_cutlass.cu
# GEMV kernels (Issue #122: Reorganized with w{weight}a{act}_{out} naming)
ops/matmul/gemv/w4a16_bf16/sm120/nvf4.cu
ops/matmul/gemv/w4a16_bf16/sm120/nvf4_kernels.cu
ops/matmul/gemv/w8a16_bf16/sm120/fp8_opt_kernels.cu
ops/matmul/gemv/bf16_bf16/sm120/bf16_opt.cu
ops/matmul/gemv/w8a8_bf16/sm120/fp8_gemv.cu
ops/matmul/gemv/w8a8_bf16/sm120/fp8_accurate.cu
ops/matmul/gemv/w4a4_bf16/sm120/nvf4_gemv.cu
ops/matmul/gemv/int4_int4/sm120/int4_gemv.cu
ops/nn/nn.cu
ops/quantize/quantize.cu
ops/attention/paged_attention.cu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "../../../../../../core/cuda_graph.hpp"
#include "../../../../../core/cuda_graph.hpp"

namespace pygpukit {
namespace ops {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "../../../../../../core/cuda_graph.hpp"
#include "../../../../../core/cuda_graph.hpp"

namespace pygpukit {
namespace ops {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "../../../../../../core/cuda_graph.hpp"
#include "../../../../../core/cuda_graph.hpp"

namespace pygpukit {
namespace ops {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include <cuda.h>
#include <cuda_runtime.h>
#include "../../../../../../core/cuda_graph.hpp"
#include "../../../../../core/cuda_graph.hpp"

namespace pygpukit {
namespace ops {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

#include <cuda.h>
#include <cuda_runtime.h>
#include "../../../../../../core/cuda_graph.hpp"
#include "../../../../../core/cuda_graph.hpp"

namespace pygpukit {
namespace ops {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include "../../../../../../core/cuda_graph.hpp"
#include "../../../../../core/cuda_graph.hpp"

namespace pygpukit {
namespace ops {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include "../../../../../../core/cuda_graph.hpp"
#include "../../../../../core/cuda_graph.hpp"

namespace pygpukit {
namespace ops {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
#include "cutlass/util/device_memory.h"

#define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1
#include "../../../../common/aligned_copy_sm120.cuh"
#include "../../../common/aligned_copy_sm120.cuh"

using namespace cute;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
// Provides alignment-safe LDSM operations for Issue #2902 workaround
// ============================================================================
#define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1
#include "../../../../common/aligned_copy_sm120.cuh"
#include "../../../common/aligned_copy_sm120.cuh"

using namespace cute;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
#include "cutlass/util/device_memory.h"

#define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1
#include "../../../../common/aligned_copy_sm120.cuh"
#include "../../../common/aligned_copy_sm120.cuh"

using namespace cute;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

// Alignment patch for Issue #2902 workaround
#define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1
#include "../../../../common/aligned_copy_sm120.cuh"
#include "../../../common/aligned_copy_sm120.cuh"

using namespace cute;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include "cutlass/util/device_memory.h"

#define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1
#include "../../../../common/aligned_copy_sm120.cuh"
#include "../../../common/aligned_copy_sm120.cuh"

using namespace cute;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#include "cutlass/util/device_memory.h"

#define PYGPUKIT_PATCH_CUTLASS_LDSM_POST 1
#include "../../../../common/aligned_copy_sm120.cuh"
#include "../../../common/aligned_copy_sm120.cuh"

using namespace cute;

Expand Down
18 changes: 9 additions & 9 deletions native/ops/matmul/matmul.cu
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
/**
* Matrix multiplication dispatch
*/
#include "gemm/f32/f32/generic/f32_naive.cuh"
#include "gemm/f32_f32/generic/f32_naive.cuh"
#include "../common/error.cuh"
#include "../common/device.cuh"
#include "../../core/memory.hpp"
#include "../../core/cuda_graph.hpp"
#include "../ops.cuh" // For transpose()

// Include existing optimized kernels
#include "gemm/f32/f32/generic/f32_ampere.cuh"
#include "gemm/f32/f32/generic/tf32_wmma.cuh"
#include "gemm/f32/f32/generic/tf32_mma.cuh"
#include "gemm/bf16/bf16/generic/bf16_naive.cuh"
#include "gemm/bf16/bf16/generic/bf16_wmma.cuh"
#include "gemm/bf16/bf16/generic/bf16_wmma_generic.cuh"
// Include existing optimized kernels (Issue #122: Updated paths)
#include "gemm/f32_f32/generic/f32_ampere.cuh"
#include "gemm/f32_f32/generic/tf32_wmma.cuh"
#include "gemm/f32_f32/generic/tf32_mma.cuh"
#include "gemm/bf16_bf16/generic/bf16_naive.cuh"
#include "gemm/bf16_bf16/generic/bf16_wmma.cuh"
#include "gemm/bf16_bf16/generic/bf16_wmma_generic.cuh"
#include "cublaslt.cuh"
#include "gemm/bf16/bf16/sm80/bf16_cutlass.cuh"
#include "gemm/bf16_bf16/sm80/bf16_cutlass.cuh"

#include <cstdlib>
#include <algorithm>
Expand Down
2 changes: 1 addition & 1 deletion native/ops/matmul/matmul_cutlass.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

#if PYGPUKIT_HAS_CUTLASS

#include "gemm/bf16/bf16/sm80/bf16_cutlass.cuh"
#include "gemm/bf16_bf16/sm80/bf16_cutlass.cuh"

namespace pygpukit {
namespace ops {
Expand Down