Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ jobs:
uv pip install triton_layer_norm-0.0.1*.whl
uv run python -c "import triton_layer_norm"

- name: Check kernel conversion (flat build)
run: |
uv run kernels to-wheel kernels-test/flattened-build 0.0.1
uv pip install flattened_build-0.0.1*.whl
uv run python -c "import flattened_build"

- name: Check README generation
# For now, just checks that generation doesn't fail.
run: |
Expand Down
10 changes: 9 additions & 1 deletion src/kernels/wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,13 @@ def build_wheel(
metadata_msg.add_header("Requires-Python", ">=3.9")
metadata_msg.add_header("Requires-Dist", requires_dist_torch)

source_pkg_dir = variant_path / name
# Check if the kernel uses a flat build.
if (variant_path / "__init__.py").exists():
flat_build = True
source_pkg_dir = variant_path
else:
flat_build = False
source_pkg_dir = variant_path / name

with WheelFile(wheel_path, "w") as wheel_file:
for root, dirnames, filenames in os.walk(source_pkg_dir):
Expand All @@ -175,6 +181,8 @@ def build_wheel(

abs_filepath = os.path.join(root, filename)
entry_name = os.path.relpath(abs_filepath, variant_path)
if flat_build:
entry_name = os.path.join(name, entry_name)
wheel_file.write(abs_filepath, entry_name)

wheel_metadata_path = os.path.join(dist_info_dir_name, "WHEEL")
Expand Down
Loading