diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 01fcbe5..a9e8518 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -73,6 +73,12 @@ jobs: uv pip install triton_layer_norm-0.0.1*.whl uv run python -c "import triton_layer_norm" + - name: Check kernel conversion (flat build) + run: | + uv run kernels to-wheel kernels-test/flattened-build 0.0.1 + uv pip install flattened_build-0.0.1*.whl + uv run python -c "import flattened_build" + - name: Check README generation # For now, just checks that generation doesn't fail. run: | diff --git a/src/kernels/wheel.py b/src/kernels/wheel.py index 059a24e..bc33c61 100644 --- a/src/kernels/wheel.py +++ b/src/kernels/wheel.py @@ -165,7 +165,13 @@ def build_wheel( metadata_msg.add_header("Requires-Python", ">=3.9") metadata_msg.add_header("Requires-Dist", requires_dist_torch) - source_pkg_dir = variant_path / name + # Check if the kernel uses a flat build. + if (variant_path / "__init__.py").exists(): + flat_build = True + source_pkg_dir = variant_path + else: + flat_build = False + source_pkg_dir = variant_path / name with WheelFile(wheel_path, "w") as wheel_file: for root, dirnames, filenames in os.walk(source_pkg_dir): @@ -175,6 +181,8 @@ def build_wheel( abs_filepath = os.path.join(root, filename) entry_name = os.path.relpath(abs_filepath, variant_path) + if flat_build: + entry_name = os.path.join(name, entry_name) wheel_file.write(abs_filepath, entry_name) wheel_metadata_path = os.path.join(dist_info_dir_name, "WHEEL")