From a7a443ad0171c65c1b59919450d44882ec02c5cf Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Mon, 1 Dec 2025 12:00:10 +0800
Subject: [PATCH 1/2] Adds dot benchmark coverage

Provides a parameterized benchmark over CUDA and MPS devices with multiple dtypes and sizes to compare the Python, PyTorch, Triton, and Cutlass implementations of the dot kernel
---
 tests/test_dot.py | 84 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 84 insertions(+)
 create mode 100644 tests/test_dot.py

diff --git a/tests/test_dot.py b/tests/test_dot.py
new file mode 100644
index 0000000..9e1ec4e
--- /dev/null
+++ b/tests/test_dot.py
@@ -0,0 +1,84 @@
+import pytest
+import torch
+
+from kernel_course import testing
+from kernel_course.python_ops import dot as python_dot
+
+try:
+    from kernel_course.pytorch_ops import dot as pytorch_dot
+
+    HAS_PYTORCH = True
+except Exception:
+    pytorch_dot = None
+    HAS_PYTORCH = False
+
+try:
+    from kernel_course.triton_ops import dot as triton_dot
+
+    HAS_TRITON = True
+except Exception:
+    triton_dot = None
+    HAS_TRITON = False
+
+try:
+    from kernel_course.cute_ops import dot as cute_dot
+
+    HAS_CUTE = True
+except Exception:
+    cute_dot = None
+    HAS_CUTE = False
+
+
+def factory(
+    numel: int,
+    device: torch.device,
+    dtype: torch.dtype = torch.float32,
+):
+    x = torch.linspace(0.0, 1.0, steps=numel, device=device, dtype=dtype)
+    y = torch.linspace(0.0, 1.0, steps=numel, device=device, dtype=dtype)
+    return (x, y), {}
+
+
+@pytest.mark.parametrize(
+    "device",
+    [
+        pytest.param(
+            torch.device("cuda"),
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="requires CUDA"
+            ),
+        ),
+        pytest.param(
+            torch.device("mps"),
+            marks=pytest.mark.skipif(
+                not torch.backends.mps.is_available(), reason="requires MPS"
+            ),
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [torch.float32, torch.float16, torch.bfloat16],
+)
+@pytest.mark.parametrize(
+    "numel",
+    [1 << 4, 1 << 8, 1 << 16],
+)
+def test_dot_benchmark(device: torch.device, dtype: torch.dtype, numel: int) -> None:
+    impls = testing.get_impls(
+        python_impl=python_dot.dot,
+        pytorch_impl=pytorch_dot.dot if HAS_PYTORCH else None,
+        triton_impl=triton_dot.dot if HAS_TRITON else None,
+        cute_impl=cute_dot.dot if HAS_CUTE else None,
+    )
+
+    # Benchmark each implementation
+    config = testing.BenchmarkConfig(warmup=3, repeat=1_000)
+    results = testing.run_benchmarks(
+        impls,
+        lambda: factory(numel, device, dtype),
+        flops=2 * numel,
+        config=config,
+    )
+
+    testing.show_benchmarks(results)

From a7f9fb2e5472d51110d1c4687059c4edb4b735b1 Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Mon, 1 Dec 2025 12:00:54 +0800
Subject: [PATCH 2/2] Marks dot test coverage

Updates the kernel summary so dot now shows an available unit test, keeping the documentation aligned with actual test coverage
---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index f5ed58b..637e9ff 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@ The following common BLAS kernels have been implemented in multiple frameworks.
| [swap](./docs/swap.md) | swap vectors | $x \leftrightarrow y$ | $0$ | $4n$ | [✅](./kernel_course/python_ops/swap.py) | [✅](./kernel_course/pytorch_ops/swap.py) | [✅](./kernel_course/triton_ops/swap.py) | ❌ | [✅](./tests/test_swap.py) | | [scal](./docs/scal.md) | scale vector | $y = \alpha y$ | $n$ | $2n$ | [✅](./kernel_course/python_ops/scal.py) | [✅](./kernel_course/pytorch_ops/scal.py) | [✅](./kernel_course/triton_ops/scal.py) | ❌ | [✅](./tests/test_scal.py) | | [axpby](./docs/axpby.md) | update vector| $y = \alpha x + \beta y$ | $3n$ | $3n$ | [✅](./kernel_course/python_ops/axpby.py) | [✅](./kernel_course/pytorch_ops/axpby.py) | [✅](./kernel_course/triton_ops/axpby.py) | ❌ | [✅](./tests/test_axpby.py) | -| [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [✅](./kernel_course/python_ops/dot.py) | [✅](./kernel_course/pytorch_ops/dot.py) | [✅](./kernel_course/triton_ops/dot.py) | ❌ | ❌ | +| [dot](./docs/dot.md) | dot product | $z = x^\top y$ | $2n$ | $2n$ | [✅](./kernel_course/python_ops/dot.py) | [✅](./kernel_course/pytorch_ops/dot.py) | [✅](./kernel_course/triton_ops/dot.py) | ❌ | [✅](./tests/test_dot.py) | | gemv | general matrix-vector multiply | $y = \alpha A x + \beta y$ | $2mn$ | $mn + n + 2m$ | ❌ | ❌ | ❌ | ❌ | ❌ | | geru | general rank-1 update | $A = A + \alpha x y^\top$ | $2mn$ | $2mn + m + n$ | ❌ | ❌ | ❌ | ❌ | ❌ | | gemm | general matrix-matrix multiply | $C = \alpha A B + \beta C$ | $2mnk$ | $mk + nk + 2mn$ | ❌ | ❌ | ❌ | ❌ | ❌ |
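
For context on the call shape the new benchmark exercises: factory returns (x, y), {} and the README defines dot as z = x^T y with 2n flops, so each implementation is presumably invoked as dot(x, y) on two 1-D tensors and returns a scalar tensor. The loop-based sketch below is illustrative only, built on that assumption; it is not the repository's kernel_course/python_ops/dot.py, whose body is not part of these patches.

    import torch

    def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Scalar dot product z = x^T y: one multiply and one add per element,
        # consistent with the flops=2 * numel accounting in the benchmark.
        assert x.dim() == 1 and x.shape == y.shape
        acc = torch.zeros((), device=x.device, dtype=x.dtype)
        for xi, yi in zip(x, y):
            acc = acc + xi * yi
        return acc

Assuming a standard pytest layout, the benchmark can be run on its own with pytest tests/test_dot.py; parameter combinations without a usable CUDA or MPS device are skipped by the skipif marks in the device parametrization.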