Skip to content

Commit

Permalink
Merge pull request #23 from thatch/custom-headers
Browse files Browse the repository at this point in the history
Allow specifying custom request headers.
  • Loading branch information
jwodder committed Feb 24, 2024
2 parents 75569b2 + 1ee05df commit 6b1319d
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 5 deletions.
58 changes: 53 additions & 5 deletions src/pypi_simple/client.py
Expand Up @@ -99,6 +99,7 @@ def get_index_page(
self,
timeout: float | tuple[float, float] | None = None,
accept: Optional[str] = None,
headers: Optional[dict[str, str]] = None,
) -> IndexPage:
"""
Fetches the index/root page from the simple repository and returns an
Expand All @@ -113,12 +114,18 @@ def get_index_page(
``accept`` parameter added
.. versionchanged:: 1.5.0
``headers`` parameter added
:param timeout: optional timeout to pass to the ``requests`` call
:type timeout: float | tuple[float,float] | None
:param Optional[str] accept:
The :mailheader:`Accept` header to send in order to
specify what serialization format the server should return;
defaults to the value supplied on client instantiation
:param Optional[dict[str, str]] headers:
Custom headers to provide for the request.
:rtype: IndexPage
:raises requests.HTTPError: if the repository responds with an HTTP
error code
Expand All @@ -127,8 +134,11 @@ def get_index_page(
:raises UnsupportedRepoVersionError: if the repository version has a
greater major component than the supported repository version
"""
request_headers = {"Accept": accept or self.accept}
if headers:
request_headers.update(headers)
r = self.s.get(
self.endpoint, timeout=timeout, headers={"Accept": accept or self.accept}
self.endpoint, timeout=timeout, headers=request_headers,
)
r.raise_for_status()
return IndexPage.from_response(r)
Expand All @@ -138,6 +148,7 @@ def stream_project_names(
chunk_size: int = 65535,
timeout: float | tuple[float, float] | None = None,
accept: Optional[str] = None,
headers: Optional[dict[str, str]] = None,
) -> Iterator[str]:
"""
Returns a generator of names of projects available in the repository.
Expand All @@ -164,6 +175,10 @@ def stream_project_names(
``accept`` parameter added
.. versionchanged:: 1.5.0
``headers`` parameter added
:param int chunk_size: how many bytes to read from the response at a
time
:param timeout: optional timeout to pass to the ``requests`` call
Expand All @@ -172,6 +187,8 @@ def stream_project_names(
The :mailheader:`Accept` header to send in order to
specify what serialization format the server should return;
defaults to the value supplied on client instantiation
:param Optional[dict[str, str]] headers:
Custom headers to provide for the request.
:rtype: Iterator[str]
:raises requests.HTTPError: if the repository responds with an HTTP
error code
Expand All @@ -180,11 +197,14 @@ def stream_project_names(
:raises UnsupportedRepoVersionError: if the repository version has a
greater major component than the supported repository version
"""
request_headers = {"Accept": accept or self.accept}
if headers:
request_headers.update(headers)
with self.s.get(
self.endpoint,
stream=True,
timeout=timeout,
headers={"Accept": accept or self.accept},
headers=request_headers,
) as r:
r.raise_for_status()
ct = ContentType.parse(r.headers.get("content-type", "text/html"))
Expand All @@ -205,6 +225,7 @@ def get_project_page(
project: str,
timeout: float | tuple[float, float] | None = None,
accept: Optional[str] = None,
headers: Optional[dict[str, str]] = None,
) -> ProjectPage:
"""
Fetches the page for the given project from the simple repository and
Expand All @@ -219,6 +240,10 @@ def get_project_page(
- ``accept`` parameter added
.. versionchanged:: 1.5.0
``headers`` parameter added
:param str project: The name of the project to fetch information on.
The name does not need to be normalized.
:param timeout: optional timeout to pass to the ``requests`` call
Expand All @@ -227,6 +252,8 @@ def get_project_page(
The :mailheader:`Accept` header to send in order to
specify what serialization format the server should return;
defaults to the value supplied on client instantiation
:param Optional[dict[str, str]] headers:
Custom headers to provide for the request.
:rtype: ProjectPage
:raises NoSuchProjectError: if the repository responds with a 404 error
code
Expand All @@ -237,8 +264,11 @@ def get_project_page(
:raises UnsupportedRepoVersionError: if the repository version has a
greater major component than the supported repository version
"""
request_headers = {"Accept": accept or self.accept}
if headers:
request_headers.update(headers)
url = self.get_project_url(project)
r = self.s.get(url, timeout=timeout, headers={"Accept": accept or self.accept})
r = self.s.get(url, timeout=timeout, headers=request_headers)
if r.status_code == 404:
raise NoSuchProjectError(project, url)
r.raise_for_status()
Expand All @@ -262,6 +292,7 @@ def download_package(
keep_on_error: bool = False,
progress: Optional[Callable[[Optional[int]], ProgressTracker]] = None,
timeout: float | tuple[float, float] | None = None,
headers: Optional[dict[str, str]] = None,
) -> None:
"""
Download the given `DistributionPackage` to the given path.
Expand All @@ -276,6 +307,10 @@ def download_package(
``update(increment: int)`` method that will be passed the size of each
downloaded chunk as each chunk is received.
.. versionchanged:: 1.5.0
``headers`` parameter added
:param DistributionPackage pkg: the distribution package to download
:param path:
the path at which to save the downloaded file; any parent
Expand Down Expand Up @@ -304,7 +339,7 @@ def download_package(
digester = DigestChecker(pkg.digests)
else:
digester = NullDigestChecker()
with self.s.get(pkg.url, stream=True, timeout=timeout) as r:
with self.s.get(pkg.url, stream=True, timeout=timeout, headers=headers) as r:
r.raise_for_status()
try:
content_length = int(r.headers["Content-Length"])
Expand Down Expand Up @@ -333,6 +368,7 @@ def get_package_metadata_bytes(
pkg: DistributionPackage,
verify: bool = True,
timeout: float | tuple[float, float] | None = None,
headers: Optional[dict[str, str]] = None,
) -> bytes:
"""
.. versionadded:: 1.5.0
Expand All @@ -356,6 +392,9 @@ def get_package_metadata_bytes(
whether to verify the metadata's digests against the retrieved data
:param timeout: optional timeout to pass to the ``requests`` call
:type timeout: float | tuple[float,float] | None
:param Optional[dict[str, str]] headers:
Custom headers to provide for the request.
:rtype: bytes
:raises NoMetadataError:
if the repository responds with a 404 error code
Expand All @@ -373,7 +412,7 @@ def get_package_metadata_bytes(
digester = DigestChecker(pkg.metadata_digests or {})
else:
digester = NullDigestChecker()
r = self.s.get(pkg.metadata_url, timeout=timeout)
r = self.s.get(pkg.metadata_url, timeout=timeout, headers=headers)
if r.status_code == 404:
raise NoMetadataError(pkg.filename)
r.raise_for_status()
Expand All @@ -386,6 +425,7 @@ def get_package_metadata(
pkg: DistributionPackage,
verify: bool = True,
timeout: float | tuple[float, float] | None = None,
headers: Optional[dict[str, str]] = None,
) -> str:
"""
.. versionadded:: 1.3.0
Expand All @@ -407,12 +447,19 @@ def get_package_metadata(
.. _the packaging package:
https://packaging.pypa.io/en/stable/metadata.html
.. versionchanged:: 1.5.0
``headers`` parameter added
:param DistributionPackage pkg:
the distribution package to retrieve the metadata of
:param bool verify:
whether to verify the metadata's digests against the retrieved data
:param timeout: optional timeout to pass to the ``requests`` call
:type timeout: float | tuple[float,float] | None
:param Optional[dict[str, str]] headers:
Custom headers to provide for the request.
:rtype: str
:raises NoMetadataError:
if the repository responds with a 404 error code
Expand All @@ -429,6 +476,7 @@ def get_package_metadata(
pkg,
verify,
timeout,
headers,
).decode("utf-8", "surrogateescape")


Expand Down
44 changes: 44 additions & 0 deletions test/test_client.py
Expand Up @@ -888,3 +888,47 @@ def test_metadata_encoding() -> None:
)
assert simple.get_package_metadata_bytes(pkg) == b"\xff\xfe\x03\x26"
assert simple.get_package_metadata(pkg) == "\udcff\udcfe\u0003\u0026"


@responses.activate
def test_custom_headers_get_index_page() -> None:
with (DATA_DIR / "simple01.html").open() as fp:
responses.add(
method=responses.GET,
url="https://test.nil/simple/",
body=fp.read(),
content_type="text/html",
match=[responses.matchers.header_matcher({"X-Custom": "foo"})],
)
with PyPISimple("https://test.nil/simple/") as simple:
# Just check that the method returns successfully
simple.get_index_page(headers={"X-Custom": "foo"})


@responses.activate
def test_custom_headers_stream_project_names() -> None:
with (DATA_DIR / "simple01.html").open() as fp:
responses.add(
method=responses.GET,
url="https://test.nil/simple/",
body=fp.read(),
content_type="text/html",
match=[responses.matchers.header_matcher({"X-Custom": "foo"})],
)
with PyPISimple("https://test.nil/simple/") as simple:
# Just check that the method returns successfully
list(simple.stream_project_names(headers={"X-Custom": "foo"}))


@responses.activate
def test_custom_headers_get_project_page() -> None:
with (DATA_DIR / "aws-adfs-ebsco.html").open() as fp:
responses.add(
method=responses.GET,
url="https://test.nil/simple/aws-adfs-ebsco/",
body=fp.read(),
content_type="text/html",
match=[responses.matchers.header_matcher({"X-Custom": "foo"})],
)
with PyPISimple("https://test.nil/simple/") as simple:
simple.get_project_page("aws-adfs-ebsco", headers={"X-Custom": "foo"})

0 comments on commit 6b1319d

Please sign in to comment.