From 8f3d1df938789a56eb5a2f37ca1c31dad8c996c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 20 Nov 2025 09:55:05 +0000 Subject: [PATCH] `to-wheel`: support flat builds This change adds support for flat builds (kernel in `build/`) in `to-wheel`. --- .github/workflows/test.yml | 6 ++++++ src/kernels/wheel.py | 10 +++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 01fcbe5..a9e8518 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -73,6 +73,12 @@ jobs: uv pip install triton_layer_norm-0.0.1*.whl uv run python -c "import triton_layer_norm" + - name: Check kernel conversion (flat build) + run: | + uv run kernels to-wheel kernels-test/flattened-build 0.0.1 + uv pip install flattened_build-0.0.1*.whl + uv run python -c "import flattened_build" + - name: Check README generation # For now, just checks that generation doesn't fail. run: | diff --git a/src/kernels/wheel.py b/src/kernels/wheel.py index 059a24e..bc33c61 100644 --- a/src/kernels/wheel.py +++ b/src/kernels/wheel.py @@ -165,7 +165,13 @@ def build_wheel( metadata_msg.add_header("Requires-Python", ">=3.9") metadata_msg.add_header("Requires-Dist", requires_dist_torch) - source_pkg_dir = variant_path / name + # Check if the kernel uses a flat build. + if (variant_path / "__init__.py").exists(): + flat_build = True + source_pkg_dir = variant_path + else: + flat_build = False + source_pkg_dir = variant_path / name with WheelFile(wheel_path, "w") as wheel_file: for root, dirnames, filenames in os.walk(source_pkg_dir): @@ -175,6 +181,8 @@ def build_wheel( abs_filepath = os.path.join(root, filename) entry_name = os.path.relpath(abs_filepath, variant_path) + if flat_build: + entry_name = os.path.join(name, entry_name) wheel_file.write(abs_filepath, entry_name) wheel_metadata_path = os.path.join(dist_info_dir_name, "WHEEL")