Skip to content

Commit

Permalink
feat: Download platform agnostic wheels using rctx.download
Browse files Browse the repository at this point in the history
If we can determine that there's a platform agnostic wheel
(dep-ver-py3-none-any.whl) then download it using rctx.download,
allowing it to be cached.

This partially solves bazelbuild#1357 but
it still doesn't handle dependencies that have platform specific binary wheels
like https://pypi.org/project/psycopg2-binary/
  • Loading branch information
michaelboulton committed Mar 3, 2024
1 parent da10ac4 commit 636fc40
Showing 1 changed file with 71 additions and 12 deletions.
83 changes: 71 additions & 12 deletions python/pip_install/pip_repository.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,56 @@ def _pip_repository_impl(rctx):

return

def _try_finding_platform_agnostic_wheel(rctx, requirement):
    """Tries to find a platform agnostic wheel for the given requirement.

    The pinned release is looked up through the PyPI JSON API and its
    non-yanked uploads are searched for a `*-py3-none-any.whl` file whose
    digest is listed in one of the requirement's `--hash=` options.

    Args:
        rctx: The repository context used to fetch and read the PyPI metadata.
        requirement: The requirement line, e.g.
            `name[extra]==1.2.3 --hash=sha256:<digest>`.

    Returns:
        The PyPI `urls` entry (a dict) describing the matching wheel, or None
        when the lookup cannot be done here: the requirement is not a plain
        `name==version` pin followed only by `--hash=` options, it carries no
        hashes to verify against, the metadata cannot be downloaded, or the
        release has no platform agnostic wheel. Returning None lets the caller
        fall back to resolving the requirement some other way.
    """
    split_req = [i for i in requirement.split(" ") if i != ""]
    if not split_req:
        return None

    # Only an exact pin can be mapped onto a PyPI release URL. Anything else
    # (ranges, direct URLs, markers glued to the pin) is left to the caller.
    pinned = split_req[0]
    if ";" in pinned:
        return None
    requirement_name, pin_sep, requirement_version = pinned.partition("==")
    if not pin_sep or not requirement_name or not requirement_version:
        return None

    # Remove 'extras' specification
    requirement_name = requirement_name.partition("[")[0]

    # Get all possible expected hashes, grouped by algorithm
    by_hash_type = {}
    for option in split_req[1:]:
        if not option.startswith("--hash="):
            # Environment markers or other options cannot be evaluated here.
            return None
        alg, hash_sep, digest = option.removeprefix("--hash=").partition(":")
        if not hash_sep or not alg or not digest:
            return None
        by_hash_type.setdefault(alg, []).append(digest)

    # Without any expected hash there is nothing to verify the wheel against,
    # so do not even query PyPI.
    if not by_hash_type:
        return None

    metadata_file = "package_spec.json"
    result = rctx.download(
        url = "https://pypi.org/pypi/{}/{}/json".format(requirement_name, requirement_version),
        output = metadata_file,
        # The package may not be published on pypi.org (e.g. private index);
        # that must not break the build.
        allow_fail = True,
    )
    if not result.success:
        return None

    package_contents = json.decode(rctx.read(metadata_file))
    if not rctx.delete(metadata_file):
        fail("failed to delete the package_spec.json file")

    # Filter yanked uploads
    downloads_for_tag = [d for d in package_contents.get("urls", []) if not d.get("yanked")]
    platform_agnostic_wheels = [d for d in downloads_for_tag if d["filename"].endswith("-py3-none-any.whl")]

    if not platform_agnostic_wheels:
        return None

    # What we want the hash to be
    for hash_type, expected_digests in by_hash_type.items():
        for wheel in platform_agnostic_wheels:
            # What pypi says the hash is
            actual_digest = wheel["digests"].get(hash_type, None)
            if actual_digest and actual_digest in expected_digests:
                return wheel

    fail("Hash mismatch for requirement: {}".format(requirement))
    return None

common_env = [
"RULES_PYTHON_PIP_ISOLATED",
REPO_DEBUG_ENV_VAR,
Expand Down Expand Up @@ -766,18 +816,28 @@ def _whl_library_impl(rctx):
# Manually construct the PYTHONPATH since we cannot use the toolchain here
environment = _create_repository_execution_environment(rctx, python_interpreter)

repo_utils.execute_checked(
rctx,
op = "whl_library.ResolveRequirement({}, {})".format(rctx.attr.name, rctx.attr.requirement),
arguments = args,
environment = environment,
quiet = rctx.attr.quiet,
timeout = rctx.attr.timeout,
)
target_platforms = rctx.attr.experimental_target_platforms
platform_agnostic_wheel = _try_finding_platform_agnostic_wheel(rctx, rctx.attr.requirement)
if platform_agnostic_wheel and not target_platforms:
rctx.download(
url = platform_agnostic_wheel["url"],
sha256 = platform_agnostic_wheel["digests"]["sha256"],
output = platform_agnostic_wheel["filename"],
)
whl_path = rctx.path(platform_agnostic_wheel["filename"])
else:
repo_utils.execute_checked(
rctx,
op = "whl_library.ResolveRequirement({}, {})".format(rctx.attr.name, rctx.attr.requirement),
arguments = args,
environment = environment,
quiet = rctx.attr.quiet,
timeout = rctx.attr.timeout,
)

whl_path = rctx.path(json.decode(rctx.read("whl_file.json"))["whl_file"])
if not rctx.delete("whl_file.json"):
fail("failed to delete the whl_file.json file")
whl_path = rctx.path(json.decode(rctx.read("whl_file.json"))["whl_file"])
if not rctx.delete("whl_file.json"):
fail("failed to delete the whl_file.json file")

if rctx.attr.whl_patches:
patches = {}
Expand All @@ -795,7 +855,6 @@ def _whl_library_impl(rctx):
timeout = rctx.attr.timeout,
)

target_platforms = rctx.attr.experimental_target_platforms
if target_platforms:
parsed_whl = parse_whl_name(whl_path.basename)
if parsed_whl.platform_tag != "any":
Expand Down

0 comments on commit 636fc40

Please sign in to comment.