From 636fc40d6d5dc34e739e0dff32e2770ed5143a3f Mon Sep 17 00:00:00 2001 From: Michael Boulton Date: Sat, 2 Mar 2024 14:19:49 +0000 Subject: [PATCH] feat: Download platform agnostic wheels using rctx.download If we can determine that there's a platform agnostic wheel (dep-ver-py3-none-any.whl) then download it using rctx.download, allowing it to be cached. This partially solves https://github.com/bazelbuild/rules_python/issues/1357 but it still doesn't handle dependencies that have platform specific binary wheels like https://pypi.org/project/psycopg2-binary/ --- python/pip_install/pip_repository.bzl | 83 +++++++++++++++++++++++---- 1 file changed, 71 insertions(+), 12 deletions(-) diff --git a/python/pip_install/pip_repository.bzl b/python/pip_install/pip_repository.bzl index 7b8160e95..6093ee09e 100644 --- a/python/pip_install/pip_repository.bzl +++ b/python/pip_install/pip_repository.bzl @@ -416,6 +416,56 @@ def _pip_repository_impl(rctx): return +def _try_finding_platform_agnostic_wheel(rctx, requirement): + """Tries to find a platform agnostic wheel for the given requirement, otherwise returns None""" + split_req = [i for i in requirement.split(" ") if i != ""] + requirement_name, requirement_version = split_req[0].split("==") + if "[" in requirement_name: + # Remove 'extras' specification + requirement_name = requirement_name.split("[")[0] + + # Get all possible expected hashes + requirement_hashes = [s.removeprefix("--hash=") for s in split_req[1:]] + split_hashes = [s.split(":") for s in requirement_hashes] + by_hash_type = {} + for alg, digest in split_hashes: + if alg in by_hash_type: + by_hash_type[alg].append(digest) + else: + by_hash_type[alg] = [digest] + + rctx.download( + url = "https://pypi.org/pypi/{}/{}/json".format(requirement_name, requirement_version), + output = "package_spec.json", + ) + package_contents = json.decode(rctx.read("package_spec.json")) + if not rctx.delete("package_spec.json"): + fail("failed to delete the package_spec.json file") + + # Filter yanked uploads + downloads_for_tag = [d for d in package_contents["urls"] if not d["yanked"]] + platform_agnostic_wheels = [d for d in downloads_for_tag if d["filename"].endswith("-py3-none-any.whl")] + + if not platform_agnostic_wheels: + return None + + with_correct_hash = [] + + # What we want the hash to be + for hash_type, expected_digests in by_hash_type.items(): + for d in platform_agnostic_wheels: + digests = d["digests"] + + # What pypi says the hash is + actual_digest = digests.get(hash_type, None) + if actual_digest and actual_digest in expected_digests: + with_correct_hash.append(d) + + if not with_correct_hash: + fail("Hash mismatch for requirement: {}".format(requirement)) + + return with_correct_hash[0] + common_env = [ "RULES_PYTHON_PIP_ISOLATED", REPO_DEBUG_ENV_VAR, @@ -766,18 +816,28 @@ def _whl_library_impl(rctx): # Manually construct the PYTHONPATH since we cannot use the toolchain here environment = _create_repository_execution_environment(rctx, python_interpreter) - repo_utils.execute_checked( - rctx, - op = "whl_library.ResolveRequirement({}, {})".format(rctx.attr.name, rctx.attr.requirement), - arguments = args, - environment = environment, - quiet = rctx.attr.quiet, - timeout = rctx.attr.timeout, - ) + target_platforms = rctx.attr.experimental_target_platforms + platform_agnostic_wheel = _try_finding_platform_agnostic_wheel(rctx, rctx.attr.requirement) + if platform_agnostic_wheel and not target_platforms: + rctx.download( + url = platform_agnostic_wheel["url"], + sha256 = platform_agnostic_wheel["digests"]["sha256"], + output = platform_agnostic_wheel["filename"], + ) + whl_path = rctx.path(platform_agnostic_wheel["filename"]) + else: + repo_utils.execute_checked( + rctx, + op = "whl_library.ResolveRequirement({}, {})".format(rctx.attr.name, rctx.attr.requirement), + arguments = args, + environment = environment, + quiet = rctx.attr.quiet, + timeout = rctx.attr.timeout, + ) - whl_path = rctx.path(json.decode(rctx.read("whl_file.json"))["whl_file"]) - if not rctx.delete("whl_file.json"): - fail("failed to delete the whl_file.json file") + whl_path = rctx.path(json.decode(rctx.read("whl_file.json"))["whl_file"]) + if not rctx.delete("whl_file.json"): + fail("failed to delete the whl_file.json file") if rctx.attr.whl_patches: patches = {} @@ -795,7 +855,6 @@ def _whl_library_impl(rctx): timeout = rctx.attr.timeout, ) - target_platforms = rctx.attr.experimental_target_platforms if target_platforms: parsed_whl = parse_whl_name(whl_path.basename) if parsed_whl.platform_tag != "any":