diff --git a/src/unearth/finder.py b/src/unearth/finder.py
index e144489..39b9f23 100644
--- a/src/unearth/finder.py
+++ b/src/unearth/finder.py
@@ -26,13 +26,15 @@
     validate_hashes,
 )
 from unearth.link import Link
-from unearth.preparer import unpack_link
+from unearth.preparer import noop_download_reporter, noop_unpack_reporter, unpack_link
 from unearth.session import PyPISession
 from unearth.utils import LazySequence
 
 if TYPE_CHECKING:
     from typing import TypedDict
 
+    from unearth.preparer import DownloadReporter, UnpackReporter
+
     class Source(TypedDict):
         url: str
         type: str
@@ -377,6 +379,8 @@ def download_and_unpack(
         location: str | pathlib.Path,
         download_dir: str | pathlib.Path | None = None,
         hashes: dict[str, list[str]] | None = None,
+        download_reporter: DownloadReporter = noop_download_reporter,
+        unpack_reporter: UnpackReporter = noop_unpack_reporter,
     ) -> pathlib.Path:
         """Download and unpack the package at the given link.
 
@@ -393,21 +397,32 @@
             download_dir: The directory to download to, or None to use a temporary
                 directory created by unearth.
             hashes (dict[str, list[str]]|None): The optional hash dict for validation.
+            download_reporter (DownloadReporter): The download reporter for progress
+                reporting. By default, it does nothing.
+            unpack_reporter (UnpackReporter): The unpack reporter for progress
+                reporting. By default, it does nothing.
 
         Returns:
             The path to the installable file or directory.
         """
-        # Strip the rev part for VCS links
+        import contextlib
+
         if hashes is None:
             hashes = link.hash_option
-        if download_dir is None:
-            download_dir = TemporaryDirectory(prefix="unearth-download-").name
-        file = unpack_link(
-            self.session,
-            link,
-            pathlib.Path(download_dir),
-            pathlib.Path(location),
-            hashes,
-            verbosity=self.verbosity,
-        )
+
+        with contextlib.ExitStack() as stack:
+            if download_dir is None:
+                download_dir = stack.enter_context(
+                    TemporaryDirectory(prefix="unearth-download-")
+                )
+            file = unpack_link(
+                self.session,
+                link,
+                pathlib.Path(download_dir),
+                pathlib.Path(location),
+                hashes,
+                verbosity=self.verbosity,
+                download_reporter=download_reporter,
+                unpack_reporter=unpack_reporter,
+            )
         return file.joinpath(link.subdirectory) if link.subdirectory else file
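
Caller-side sketch (illustrative, not part of this patch): any callables matching the reporter signatures can be passed to download_and_unpack, and total may be None when the server sends no usable Content-Length. The helper names print_download_progress, print_unpack_progress, and fetch are made up for the example.

from __future__ import annotations

import pathlib

from unearth.finder import PackageFinder
from unearth.link import Link


def print_download_progress(link: Link, completed: int, total: int | None) -> None:
    # total is None when the response carries no usable Content-Length header.
    if total:
        print(f"\r{link.filename}: {completed * 100 // total}%", end="")
    else:
        print(f"\r{link.filename}: {completed} bytes", end="")


def print_unpack_progress(
    filename: pathlib.Path, completed: int, total: int | None
) -> None:
    print(f"\runpacking {filename.name}: {completed}/{total or '?'} members", end="")


def fetch(finder: PackageFinder, link: Link, dest: pathlib.Path) -> pathlib.Path:
    # download_dir is omitted, so a temporary directory is created and cleaned up.
    return finder.download_and_unpack(
        link,
        dest,
        download_reporter=print_download_progress,
        unpack_reporter=print_unpack_progress,
    )
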
diff --git a/src/unearth/preparer.py b/src/unearth/preparer.py
index 76cb92a..a4b1fb5 100644
--- a/src/unearth/preparer.py
+++ b/src/unearth/preparer.py
@@ -1,6 +1,7 @@
 """Unpack the link to an installed wheel or source."""
 from __future__ import annotations
 
+import functools
 import hashlib
 import logging
 import mimetypes
@@ -10,7 +11,7 @@
 import tarfile
 import zipfile
 from pathlib import Path
-from typing import Iterable, cast
+from typing import TYPE_CHECKING, Iterable, cast
 
 from requests import HTTPError, Session
 
@@ -23,9 +24,30 @@
     ZIP_EXTENSIONS,
     display_path,
     format_size,
+    iter_with_callback,
 )
 from unearth.vcs import vcs_support
 
+if TYPE_CHECKING:
+    from typing import Protocol
+
+    class DownloadReporter(Protocol):
+        def __call__(self, link: Link, completed: int, total: int | None) -> None:
+            ...
+
+    class UnpackReporter(Protocol):
+        def __call__(self, filename: Path, completed: int, total: int | None) -> None:
+            ...
+
+
+def noop_download_reporter(link: Link, completed: int, total: int | None) -> None:
+    pass
+
+
+def noop_unpack_reporter(filename: Path, completed: int, total: int | None) -> None:
+    pass
+
+
 READ_CHUNK_SIZE = 8192
 
 logger = logging.getLogger(__name__)
@@ -133,30 +155,33 @@ def _check_downloaded(path: Path, hashes: dict[str, list[str]] | None) -> bool:
     return True
 
 
-def unpack_archive(archive: Path, dest: Path) -> None:
+def unpack_archive(
+    archive: Path, dest: Path, reporter: UnpackReporter = noop_unpack_reporter
+) -> None:
     content_type = mimetypes.guess_type(str(archive))[0]
     if (
         content_type == "application/zip"
         or zipfile.is_zipfile(archive)
         or archive.suffix.lower() in ZIP_EXTENSIONS
     ):
-        _unzip_archive(archive, dest)
+        _unzip_archive(archive, dest, reporter=reporter)
     elif (
         content_type == "application/x-gzip"
         or tarfile.is_tarfile(archive)
         or archive.suffix.lower() in (TAR_EXTENSIONS + XZ_EXTENSIONS + BZ2_EXTENSIONS)
     ):
-        _untar_archive(archive, dest)
+        _untar_archive(archive, dest, reporter=reporter)
     else:
         raise UnpackError(f"Unknown archive type: {archive.name}")
 
 
-def _unzip_archive(filename: Path, location: Path) -> None:
+def _unzip_archive(filename: Path, location: Path, reporter: UnpackReporter) -> None:
     os.makedirs(location, exist_ok=True)
     zipfp = open(filename, "rb")
     with zipfile.ZipFile(zipfp, allowZip64=True) as zip:
         leading = has_leading_dir(zip.namelist())
-        for info in zip.infolist():
+        callback = functools.partial(reporter, filename, total=len(zip.infolist()))
+        for info in iter_with_callback(zip.infolist(), callback):
             name = info.filename
             fn = name
             if leading:
@@ -183,7 +208,7 @@ def _unzip_archive(filename: Path, location: Path) -> None:
                     set_extracted_file_to_default_mode_plus_executable(fn)
 
 
-def _untar_archive(filename: Path, location: Path) -> None:
+def _untar_archive(filename: Path, location: Path, reporter: UnpackReporter) -> None:
     """Untar the file (with path `filename`) to the destination `location`."""
     os.makedirs(location, exist_ok=True)
     lower_fn = str(filename).lower()
@@ -203,7 +228,8 @@ def _untar_archive(filename: Path, location: Path) -> None:
         mode = "r:*"
     with tarfile.open(filename, mode, encoding="utf-8") as tar:
         leading = has_leading_dir([member.name for member in tar.getmembers()])
-        for member in tar.getmembers():
+        callback = functools.partial(reporter, filename, total=len(tar.getmembers()))
+        for member in iter_with_callback(tar.getmembers(), callback):
             fn = member.name
             if leading:
                 fn = split_leading_dir(fn)[1]
@@ -261,6 +287,8 @@ def unpack_link(
     location: Path,
     hashes: dict[str, list[str]] | None = None,
     verbosity: int = 0,
+    download_reporter: DownloadReporter = noop_download_reporter,
+    unpack_reporter: UnpackReporter = noop_unpack_reporter,
 ) -> Path:
     """Unpack link into location.
 
@@ -302,13 +330,22 @@ def unpack_link(
             resp.raise_for_status()
         except HTTPError as e:
             raise UnpackError(f"Download failed: {e}") from None
+        try:
+            total = int(resp.headers["Content-Length"])
+        except (KeyError, ValueError, TypeError):
+            total = None
         if getattr(resp, "from_cache", False):
             logger.info("Using cached %s", link)
         else:
             size = format_size(resp.headers.get("Content-Length", ""))
             logger.info("Downloading %s (%s)", link, size)
         with artifact.open("wb") as f:
-            for chunk in resp.iter_content(chunk_size=READ_CHUNK_SIZE):
+            callback = functools.partial(download_reporter, link, total=total)
+            for chunk in iter_with_callback(
+                resp.iter_content(chunk_size=READ_CHUNK_SIZE),
+                callback,
+                stepper=len,
+            ):
                 if chunk:
                     validator.update(chunk)
                     f.write(chunk)
@@ -323,5 +360,5 @@ def unpack_link(
             os.replace(artifact, target_file)
         return target_file
 
-    unpack_archive(artifact, location)
+    unpack_archive(artifact, location, reporter=unpack_reporter)
     return location
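
The reporter types are structural Protocols, so any callable with a matching signature works: a plain function, a functools.partial, or a stateful object. A hypothetical collector (not part of this patch; the class and method names are invented) might look like the sketch below. Note that unpack_link forwards total by keyword, so that parameter name should be kept.

from __future__ import annotations

from pathlib import Path

from unearth.link import Link


class ProgressCollector:
    """Record the most recent progress reported for each artifact."""

    def __init__(self) -> None:
        self.downloads: dict[str, tuple[int, int | None]] = {}
        self.unpacks: dict[str, tuple[int, int | None]] = {}

    def on_download(self, link: Link, completed: int, total: int | None) -> None:
        self.downloads[link.filename] = (completed, total)

    def on_unpack(self, filename: Path, completed: int, total: int | None) -> None:
        self.unpacks[filename.name] = (completed, total)

The bound methods would then be passed as download_reporter=collector.on_download and unpack_reporter=collector.on_unpack.
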
diff --git a/src/unearth/utils.py b/src/unearth/utils.py
index 8370a8d..1e56aef 100644
--- a/src/unearth/utils.py
+++ b/src/unearth/utils.py
@@ -9,7 +9,7 @@
 import urllib.parse as parse
 import warnings
 from pathlib import Path
-from typing import Iterable, Iterator, Sequence, TypeVar
+from typing import Callable, Iterable, Iterator, Sequence, TypeVar
 from urllib.request import pathname2url, url2pathname
 
 WINDOWS = sys.platform == "win32"
@@ -188,7 +188,7 @@ def format_size(size: str) -> str:
         return f"{int(int_size)} bytes"
 
 
-T = TypeVar("T", covariant=True)
+T = TypeVar("T")
 
 
 class LazySequence(Sequence[T]):
@@ -255,3 +255,17 @@ def fix_wildcard(match: re.Match[str]) -> str:
         return f"{operator}{version}"
 
     return _legacy_specifier_re.sub(fix_wildcard, specifier)
+
+
+def iter_with_callback(
+    iterable: Iterable[T],
+    callback: Callable[[int], None],
+    stepper: Callable[[T], int] = lambda _: 1,
+) -> Iterator[T]:
+    completed = 0
+    for item in iterable:
+        try:
+            yield item
+        finally:
+            completed += stepper(item)
+            callback(completed)
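
A quick illustration of iter_with_callback's contract (a sketch, assuming the patch above is applied): the callback receives a running total weighted by stepper, 1 per item by default or len(chunk) for byte counts, and the try/finally ensures the callback still fires for an item even if the consumer stops iterating early.

from unearth.utils import iter_with_callback

chunks = [b"abcd", b"ef", b"ghi"]
reported = []

# stepper=len makes progress advance by bytes rather than by item count.
for chunk in iter_with_callback(chunks, reported.append, stepper=len):
    pass

assert reported == [4, 6, 9]
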
diff --git a/tests/fixtures/files/first-2.0.2.tar.gz b/tests/fixtures/files/first-2.0.2.tar.gz
new file mode 100644
index 0000000..4f6aef7
Binary files /dev/null and b/tests/fixtures/files/first-2.0.2.tar.gz differ
diff --git a/tests/test_finder.py b/tests/test_finder.py
index 7bea699..908bf86 100644
--- a/tests/test_finder.py
+++ b/tests/test_finder.py
@@ -188,3 +188,42 @@ def test_find_requirement_preference_respect_source_order(session, fixtures_dir)
     best = finder.find_best_match("first").best
     assert best.link.filename == "first-2.0.2.tar.gz"
     assert best.link.comes_from == "https://pypi.org/simple/first/"
+
+
+def test_download_package_file(session, tmp_path):
+    finder = PackageFinder(
+        session=session,
+        index_urls=[DEFAULT_INDEX_URL],
+        ignore_compatibility=True,
+    )
+    found = finder.find_best_match("first").best.link
+    assert found.filename == "first-2.0.2.tar.gz"
+    for subdir in ("download", "unpack"):
+        (tmp_path / subdir).mkdir()
+
+    download_reports = []
+    unpack_reports = []
+
+    def download_reporter(link, completed, total):
+        download_reports.append((link, completed, total))
+
+    def unpack_reporter(filename, completed, total):
+        unpack_reports.append((filename, completed, total))
+
+    finder.download_and_unpack(
+        found,
+        tmp_path / "unpack",
+        download_dir=tmp_path / "download",
+        download_reporter=download_reporter,
+        unpack_reporter=unpack_reporter,
+    )
+    downloaded = tmp_path / "download" / found.filename
+    assert downloaded.exists()
+    size = downloaded.stat().st_size
+    assert size > 0
+    _, completed, total = download_reports[-1]
+    assert completed == total == size
+
+    filename, completed, total = unpack_reports[-1]
+    assert completed == total
+    assert filename == downloaded