Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions kernels/src/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,11 @@ def install_kernel(
package_name = package_name_from_repo_id(repo_id)

variants = get_variants(api, repo_id=repo_id, revision=revision)
variant = resolve_variant(variants, backend)

if variant is None:
try:
variant = resolve_variant(variants, backend)
except FileNotFoundError as e:
raise FileNotFoundError(
f"Cannot find a build variant for this system in {repo_id} (revision: {revision}). Available variants: {', '.join([variant.variant_str for variant in variants])}"
f"Cannot find a build variant: {e.filename} for this system in {repo_id} (revision: {revision}). Available variants: {', '.join([variant.variant_str for variant in variants])}"
)

allow_patterns = [f"build/{variant.variant_str}/*"]
Expand Down Expand Up @@ -478,9 +478,11 @@ def load_kernel(
variants = get_variants(api, repo_id=repo_id, revision=locked_sha)
variant = resolve_variant(variants, backend)

if variant is None:
try:
variant = resolve_variant(variants, backend)
except FileNotFoundError as e:
raise FileNotFoundError(
f"Cannot find a build variant for this system in {repo_id} (revision: {locked_sha}). Available variants: {', '.join([variant.variant_str for variant in variants])}"
f"Cannot find a build variant: {e.filename} for this system in {repo_id} (revision: {locked_sha}). Available variants: {', '.join([variant.variant_str for variant in variants])}"
)

allow_patterns = [f"build/{variant.variant_str}/*"]
Expand Down
7 changes: 6 additions & 1 deletion kernels/src/kernels/variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def resolve_variants(variants: list[Variant], backend: str | None = None) -> lis
tvm_ffi_version = parse(tvm_ffi.__version__)
tvm_ffi_version = Version(f"{tvm_ffi_version.major}.{tvm_ffi_version.minor}")

return _resolve_variant_for_system(
variants = _resolve_variant_for_system(
variants=variants,
selected_backend=selected_backend,
cpu=cpu,
Expand All @@ -353,6 +353,11 @@ def resolve_variants(variants: list[Variant], backend: str | None = None) -> lis
torch_cxx11_abi=torch_cxx11_abi,
tvm_ffi_version=tvm_ffi_version,
)
if not variants:
missing_variant = FileNotFoundError( "Variant not found.")
missing_variant.filename = f"torch{torch_version.major}{torch_version.minor}-{'cxx11' if torch_cxx11_abi else 'cxx98'}-cu{selected_backend.version.major}{selected_backend.version.minor}-{cpu}-{os}"
raise missing_variant



def _resolve_variant_for_system(
Expand Down