Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add callback to report download status #78

Merged
merged 1 commit into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions src/unearth/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
validate_hashes,
)
from unearth.link import Link
from unearth.preparer import unpack_link
from unearth.preparer import noop_download_reporter, noop_unpack_reporter, unpack_link
from unearth.session import PyPISession
from unearth.utils import LazySequence

if TYPE_CHECKING:
from typing import TypedDict

from unearth.preparer import DownloadReporter, UnpackReporter

class Source(TypedDict):
url: str
type: str
Expand Down Expand Up @@ -377,6 +379,8 @@ def download_and_unpack(
location: str | pathlib.Path,
download_dir: str | pathlib.Path | None = None,
hashes: dict[str, list[str]] | None = None,
download_reporter: DownloadReporter = noop_download_reporter,
unpack_reporter: UnpackReporter = noop_unpack_reporter,
) -> pathlib.Path:
"""Download and unpack the package at the given link.

Expand All @@ -393,21 +397,32 @@ def download_and_unpack(
download_dir: The directory to download to, or None to use a
temporary directory created by unearth.
hashes (dict[str, list[str]]|None): The optional hash dict for validation.
download_reporter (DownloadReporter): The download reporter for progress
reporting. By default, it does nothing.
unpack_reporter (UnpackReporter): The unpack reporter for progress
reporting. By default, it does nothing.

Returns:
The path to the installable file or directory.
"""
# Strip the rev part for VCS links
import contextlib

if hashes is None:
hashes = link.hash_option
if download_dir is None:
download_dir = TemporaryDirectory(prefix="unearth-download-").name
file = unpack_link(
self.session,
link,
pathlib.Path(download_dir),
pathlib.Path(location),
hashes,
verbosity=self.verbosity,
)

with contextlib.ExitStack() as stack:
if download_dir is None:
download_dir = stack.enter_context(
TemporaryDirectory(prefix="unearth-download-")
)
file = unpack_link(
self.session,
link,
pathlib.Path(download_dir),
pathlib.Path(location),
hashes,
verbosity=self.verbosity,
download_reporter=download_reporter,
unpack_reporter=unpack_reporter,
)
return file.joinpath(link.subdirectory) if link.subdirectory else file
57 changes: 47 additions & 10 deletions src/unearth/preparer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Unpack the link to an installed wheel or source."""
from __future__ import annotations

import functools
import hashlib
import logging
import mimetypes
Expand All @@ -10,7 +11,7 @@
import tarfile
import zipfile
from pathlib import Path
from typing import Iterable, cast
from typing import TYPE_CHECKING, Iterable, cast

from requests import HTTPError, Session

Expand All @@ -23,9 +24,30 @@
ZIP_EXTENSIONS,
display_path,
format_size,
iter_with_callback,
)
from unearth.vcs import vcs_support

if TYPE_CHECKING:
    from typing import Protocol

    class DownloadReporter(Protocol):
        """Callback signature used to report download progress.

        ``link`` is the link being downloaded, ``completed`` is the running
        sum of the chunk lengths received so far, and ``total`` is the
        parsed ``Content-Length`` header, or ``None`` when the header is
        missing or not an integer.
        """

        def __call__(self, link: Link, completed: int, total: int | None) -> None:
            ...

    class UnpackReporter(Protocol):
        """Callback signature used to report archive unpacking progress.

        ``filename`` is the path of the archive being unpacked,
        ``completed`` is the number of archive entries handled so far, and
        ``total`` is the number of entries in the archive.
        """

        def __call__(self, filename: Path, completed: int, total: int | None) -> None:
            ...


def noop_download_reporter(link: Link, completed: int, total: int | None) -> None:
    """Default download reporter: accept a progress update and ignore it."""
    return None


def noop_unpack_reporter(filename: Path, completed: int, total: int | None) -> None:
    """Default unpack reporter: accept a progress update and ignore it."""
    return None


READ_CHUNK_SIZE = 8192
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -133,30 +155,33 @@ def _check_downloaded(path: Path, hashes: dict[str, list[str]] | None) -> bool:
return True


def unpack_archive(archive: Path, dest: Path) -> None:
def unpack_archive(
archive: Path, dest: Path, reporter: UnpackReporter = noop_unpack_reporter
) -> None:
content_type = mimetypes.guess_type(str(archive))[0]
if (
content_type == "application/zip"
or zipfile.is_zipfile(archive)
or archive.suffix.lower() in ZIP_EXTENSIONS
):
_unzip_archive(archive, dest)
_unzip_archive(archive, dest, reporter=reporter)
elif (
content_type == "application/x-gzip"
or tarfile.is_tarfile(archive)
or archive.suffix.lower() in (TAR_EXTENSIONS + XZ_EXTENSIONS + BZ2_EXTENSIONS)
):
_untar_archive(archive, dest)
_untar_archive(archive, dest, reporter=reporter)
else:
raise UnpackError(f"Unknown archive type: {archive.name}")


def _unzip_archive(filename: Path, location: Path) -> None:
def _unzip_archive(filename: Path, location: Path, reporter: UnpackReporter) -> None:
os.makedirs(location, exist_ok=True)
zipfp = open(filename, "rb")
with zipfile.ZipFile(zipfp, allowZip64=True) as zip:
leading = has_leading_dir(zip.namelist())
for info in zip.infolist():
callback = functools.partial(reporter, filename, total=len(zip.infolist()))
for info in iter_with_callback(zip.infolist(), callback):
name = info.filename
fn = name
if leading:
Expand All @@ -183,7 +208,7 @@ def _unzip_archive(filename: Path, location: Path) -> None:
set_extracted_file_to_default_mode_plus_executable(fn)


def _untar_archive(filename: Path, location: Path) -> None:
def _untar_archive(filename: Path, location: Path, reporter: UnpackReporter) -> None:
"""Untar the file (with path `filename`) to the destination `location`."""
os.makedirs(location, exist_ok=True)
lower_fn = str(filename).lower()
Expand All @@ -203,7 +228,8 @@ def _untar_archive(filename: Path, location: Path) -> None:
mode = "r:*"
with tarfile.open(filename, mode, encoding="utf-8") as tar:
leading = has_leading_dir([member.name for member in tar.getmembers()])
for member in tar.getmembers():
callback = functools.partial(reporter, filename, total=len(tar.getmembers()))
for member in iter_with_callback(tar.getmembers(), callback):
fn = member.name
if leading:
fn = split_leading_dir(fn)[1]
Expand Down Expand Up @@ -261,6 +287,8 @@ def unpack_link(
location: Path,
hashes: dict[str, list[str]] | None = None,
verbosity: int = 0,
download_reporter: DownloadReporter = noop_download_reporter,
unpack_reporter: UnpackReporter = noop_unpack_reporter,
) -> Path:
"""Unpack link into location.

Expand Down Expand Up @@ -302,13 +330,22 @@ def unpack_link(
resp.raise_for_status()
except HTTPError as e:
raise UnpackError(f"Download failed: {e}") from None
try:
total = int(resp.headers["Content-Length"])
except (KeyError, ValueError, TypeError):
total = None
if getattr(resp, "from_cache", False):
logger.info("Using cached %s", link)
else:
size = format_size(resp.headers.get("Content-Length", ""))
logger.info("Downloading %s (%s)", link, size)
with artifact.open("wb") as f:
for chunk in resp.iter_content(chunk_size=READ_CHUNK_SIZE):
callback = functools.partial(download_reporter, link, total=total)
for chunk in iter_with_callback(
resp.iter_content(chunk_size=READ_CHUNK_SIZE),
callback,
stepper=len,
):
if chunk:
validator.update(chunk)
f.write(chunk)
Expand All @@ -323,5 +360,5 @@ def unpack_link(
os.replace(artifact, target_file)
return target_file

unpack_archive(artifact, location)
unpack_archive(artifact, location, reporter=unpack_reporter)
return location
18 changes: 16 additions & 2 deletions src/unearth/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import urllib.parse as parse
import warnings
from pathlib import Path
from typing import Iterable, Iterator, Sequence, TypeVar
from typing import Callable, Iterable, Iterator, Sequence, TypeVar
from urllib.request import pathname2url, url2pathname

WINDOWS = sys.platform == "win32"
Expand Down Expand Up @@ -188,7 +188,7 @@ def format_size(size: str) -> str:
return f"{int(int_size)} bytes"


T = TypeVar("T", covariant=True)
T = TypeVar("T")


class LazySequence(Sequence[T]):
Expand Down Expand Up @@ -255,3 +255,17 @@ def fix_wildcard(match: re.Match[str]) -> str:
return f"{operator}{version}"

return _legacy_specifier_re.sub(fix_wildcard, specifier)


def iter_with_callback(
    iterable: Iterable[T],
    callback: Callable[[int], None],
    stepper: Callable[[T], int] = lambda _: 1,
) -> Iterator[T]:
    """Yield the items of ``iterable`` and report cumulative progress.

    After the consumer is done with each yielded item, ``stepper(item)`` is
    added to a running total and ``callback`` is called with that total.
    With the default ``stepper`` every item counts as one step.

    The report is made from a ``finally`` clause, so it also fires for the
    item that was in flight when the generator is closed early.
    """
    running_total = 0
    for element in iterable:
        try:
            yield element
        finally:
            # Runs both on normal resumption and when the generator is closed
            # while suspended on this element.
            running_total += stepper(element)
            callback(running_total)
Binary file added tests/fixtures/files/first-2.0.2.tar.gz
Binary file not shown.
39 changes: 39 additions & 0 deletions tests/test_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,42 @@ def test_find_requirement_preference_respect_source_order(session, fixtures_dir)
best = finder.find_best_match("first").best
assert best.link.filename == "first-2.0.2.tar.gz"
assert best.link.comes_from == "https://pypi.org/simple/first/"


def test_download_package_file(session, tmp_path):
    """Download and unpack ``first`` and verify the final progress reports."""
    finder = PackageFinder(
        session=session,
        index_urls=[DEFAULT_INDEX_URL],
        ignore_compatibility=True,
    )
    found = finder.find_best_match("first").best.link
    assert found.filename == "first-2.0.2.tar.gz"

    download_dir = tmp_path / "download"
    unpack_dir = tmp_path / "unpack"
    download_dir.mkdir()
    unpack_dir.mkdir()

    # Each reporter records every progress tuple it receives.
    download_log = []
    unpack_log = []

    def record_download(link, completed, total):
        download_log.append((link, completed, total))

    def record_unpack(filename, completed, total):
        unpack_log.append((filename, completed, total))

    finder.download_and_unpack(
        found,
        unpack_dir,
        download_dir=download_dir,
        download_reporter=record_download,
        unpack_reporter=record_unpack,
    )

    downloaded = download_dir / found.filename
    assert downloaded.exists()
    size = downloaded.stat().st_size
    assert size > 0

    # The last download report must account for the whole file.
    last_download = download_log[-1]
    completed, total = last_download[1], last_download[2]
    assert completed == total == size

    # The last unpack report must cover every entry of the downloaded archive.
    filename, completed, total = unpack_log[-1]
    assert completed == total
    assert filename == downloaded