Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ jobs:
run: uv run mypy src/kernels

- name: Run tests
run: uv run pytest tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
uv run pytest tests

- name: Check kernel conversion
run: |
Expand Down
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
title: Kernels
- local: api/layers
title: Layers
- local: cli
title: kernels CLI
title: API Reference
- sections:
- local: kernel-requirements
Expand Down
15 changes: 15 additions & 0 deletions docs/source/cli.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Kernels CLI Reference

## Main Functions

### kernels upload

Use `kernels upload <dir_containing_build> --repo_id="hub-username/kernel"` to upload
your kernel builds to the Hub.

**Notes**:

* This will take care of creating a repository on the Hub with the `repo_id` provided.
* If a repo with the `repo_id` already exists and if it contains a `build` with the build variant
being uploaded, it will attempt to delete the files existing under it.
* Make sure to be authenticated (run `hf auth login` if not) to be able to perform uploads to the Hub.
47 changes: 47 additions & 0 deletions src/kernels/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import sys
from pathlib import Path

from huggingface_hub import create_repo, upload_folder

from kernels.compat import tomllib
from kernels.lockfile import KernelLock, get_kernel_locks
from kernels.utils import install_kernel, install_kernel_all_variants
Expand Down Expand Up @@ -31,6 +33,24 @@ def main():
)
download_parser.set_defaults(func=download_kernels)

upload_parser = subparsers.add_parser("upload", help="Upload kernels to the Hub")
upload_parser.add_argument(
"kernel_dir",
type=Path,
help="Directory of the kernel build",
)
upload_parser.add_argument(
"--repo_id",
type=str,
help="Repository ID to use to upload to the Hugging Face Hub",
)
upload_parser.add_argument(
"--private",
action="store_true",
help="If the repository should be private.",
)
upload_parser.set_defaults(func=upload_kernels)

lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions")
lock_parser.add_argument(
"project_dir",
Expand Down Expand Up @@ -153,6 +173,33 @@ def lock_kernels(args):
json.dump(all_locks, f, cls=_JSONEncoder, indent=2)


def upload_kernels(args):
kernel_dir = Path(args.kernel_dir).resolve()
build_dir = kernel_dir / "build"
if not kernel_dir.is_dir():
raise ValueError(f"{kernel_dir} is not a directory")
if not build_dir.is_dir():
raise ValueError("Couldn't find `build` directory inside `kernel_dir`")

repo_id = create_repo(
repo_id=args.repo_id, private=args.private, exist_ok=True
).repo_id

delete_patterns: set[str] = set()
for build_variant in build_dir.iterdir():
if build_variant.is_dir():
delete_patterns.add(f"{build_variant.name}/**")

upload_folder(
repo_id=repo_id,
folder_path=build_dir,
path_in_repo="build",
delete_patterns=list(delete_patterns),
commit_message="Build uploaded using `kernels`.",
)
print(f"✅ Kernel upload successful. Find the kernel in https://hf.co/{repo_id}.")


class _JSONEncoder(json.JSONEncoder):
def default(self, o):
if dataclasses.is_dataclass(o):
Expand Down
92 changes: 92 additions & 0 deletions tests/test_kernel_upload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import logging
import os
import re
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import List

import pytest
from huggingface_hub import model_info

from kernels.cli import upload_kernels

REPO_ID = "kernels-test/kernels-upload-test"

PY_CONTENT = """\
#!/usr/bin/env python3

def main():
print("Hello from torch-universal!")

if __name__ == "__main__":
main()
"""


@dataclass
class UploadArgs:
kernel_dir: None
repo_id: None
private: False


def next_filename(path: Path) -> Path:
"""
Given a path like foo_2050.py, return foo_2051.py.
"""
m = re.match(r"^(.*?)(\d+)(\.py)$", path.name)
if not m:
raise ValueError(
f"Filename {path.name!r} does not match pattern <prefix>_<number>.py"
)

prefix, number, suffix = m.groups()
new_number = str(int(number) + 1).zfill(len(number))
return path.with_name(f"{prefix}{new_number}{suffix}")


def get_filename_to_change(repo_filenames):
for f in repo_filenames:
if "foo" in f and f.endswith(".py"):
filename_to_change = os.path.basename(f)
break
assert filename_to_change
return filename_to_change


def get_filenames_from_a_repo(repo_id: str) -> List[str]:
try:
repo_info = model_info(repo_id=repo_id, files_metadata=True)
repo_siblings = repo_info.siblings
if repo_siblings is not None:
return [f.rfilename for f in repo_siblings]
else:
raise ValueError("No repo siblings found.")
except Exception as e:
logging.error(f"Error connecting to the Hub: {e}.")


@pytest.mark.xfail(
condition=os.environ.get("GITHUB_ACTIONS") == "true",
reason="There is something weird when writing to the Hub from a GitHub CI.",
strict=True,
)
def test_kernel_upload_deletes_as_expected():
repo_filenames = get_filenames_from_a_repo(REPO_ID)
filename_to_change = get_filename_to_change(repo_filenames)

with tempfile.TemporaryDirectory() as tmpdir:
path = f"{tmpdir}/build/torch-universal/upload_test"
build_dir = Path(path)
build_dir.mkdir(parents=True, exist_ok=True)
changed_filename = next_filename(Path(filename_to_change))
script_path = build_dir / changed_filename
script_path.write_text(PY_CONTENT)
upload_kernels(UploadArgs(tmpdir, REPO_ID, False))

repo_filenames = get_filenames_from_a_repo(REPO_ID)
assert any(str(changed_filename) in k for k in repo_filenames), f"{repo_filenames=}"
assert not any(
str(filename_to_change) in k for k in repo_filenames
), f"{repo_filenames=}"
Loading