Skip to content

Commit

Permalink
Filter out the user provided unsafe packages (#1766)
Browse files Browse the repository at this point in the history
  • Loading branch information
q0w committed Dec 11, 2022
1 parent 3c2eb7f commit ef88dfc
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 19 deletions.
1 change: 1 addition & 0 deletions piptools/scripts/compile.py
Expand Up @@ -610,6 +610,7 @@ def cli(
)
writer.write(
results=results,
unsafe_packages=resolver.unsafe_packages,
unsafe_requirements=resolver.unsafe_constraints,
markers={
key_from_ireq(ireq): ireq.markers for ireq in constraints if ireq.markers
Expand Down
24 changes: 12 additions & 12 deletions piptools/writer.py
Expand Up @@ -16,7 +16,6 @@

from .logging import log
from .utils import (
UNSAFE_PACKAGES,
comment,
dedup,
format_requirement,
Expand Down Expand Up @@ -176,13 +175,13 @@ def write_flags(self) -> Iterator[str]:
def _iter_lines(
self,
results: set[InstallRequirement],
unsafe_requirements: set[InstallRequirement] | None = None,
markers: dict[str, Marker] | None = None,
unsafe_requirements: set[InstallRequirement],
unsafe_packages: set[str],
markers: dict[str, Marker],
hashes: dict[InstallRequirement, set[str]] | None = None,
) -> Iterator[str]:
# default values
unsafe_requirements = unsafe_requirements or set()
markers = markers or {}
unsafe_packages = unsafe_packages if not self.allow_unsafe else set()
hashes = hashes or {}

# Check for unhashed or unpinned packages if at least one package does have
Expand All @@ -199,12 +198,10 @@ def _iter_lines(
yield line
yielded = True

unsafe_requirements = (
{r for r in results if r.name in UNSAFE_PACKAGES}
if not unsafe_requirements
else unsafe_requirements
)
packages = {r for r in results if r.name not in UNSAFE_PACKAGES}
unsafe_requirements = unsafe_requirements or {
r for r in results if r.name in unsafe_packages
}
packages = {r for r in results if r.name not in unsafe_packages}

if packages:
for ireq in sorted(packages, key=self._sort_key):
Expand Down Expand Up @@ -247,6 +244,7 @@ def write(
self,
results: set[InstallRequirement],
unsafe_requirements: set[InstallRequirement],
unsafe_packages: set[str],
markers: dict[str, Marker],
hashes: dict[InstallRequirement, set[str]] | None,
) -> None:
Expand All @@ -259,7 +257,9 @@ def write(
line_buffering=True,
)
try:
for line in self._iter_lines(results, unsafe_requirements, markers, hashes):
for line in self._iter_lines(
results, unsafe_requirements, unsafe_packages, markers, hashes
):
if self.dry_run:
# Bypass the log level to always print this during a dry run
log.log(line)
Expand Down
53 changes: 53 additions & 0 deletions tests/test_cli_compile.py
Expand Up @@ -1546,6 +1546,59 @@ def test_allow_unsafe_option(pip_conf, monkeypatch, runner, option, expected):
assert out.stdout == expected


@pytest.mark.parametrize(
("unsafe_package", "expected"),
(
(
"small-fake-with-deps",
dedent(
"""\
small-fake-a==0.1
small-fake-b==0.3
# The following packages are considered to be unsafe in a requirements file:
# small-fake-with-deps
"""
),
),
(
"small-fake-a",
dedent(
"""\
small-fake-b==0.3
small-fake-with-deps==0.1
# The following packages are considered to be unsafe in a requirements file:
# small-fake-a
"""
),
),
),
)
def test_unsafe_package_option(pip_conf, monkeypatch, runner, unsafe_package, expected):
monkeypatch.setattr("piptools.resolver.UNSAFE_PACKAGES", {"small-fake-with-deps"})
with open("requirements.in", "w") as req_in:
req_in.write("small-fake-b\n")
req_in.write("small-fake-with-deps")

out = runner.invoke(
cli,
[
"--output-file",
"-",
"--quiet",
"--no-header",
"--no-emit-options",
"--no-annotate",
"--unsafe-package",
unsafe_package,
],
)

assert out.exit_code == 0, out
assert out.stdout == expected


@pytest.mark.parametrize(
("option", "attr", "expected"),
(("--cert", "cert", "foo.crt"), ("--client-cert", "client_cert", "bar.pem")),
Expand Down
34 changes: 27 additions & 7 deletions tests/test_writer.py
Expand Up @@ -113,7 +113,10 @@ def test_iter_lines__unsafe_dependencies(writer, from_line, allow_unsafe):
writer.emit_header = False

lines = writer._iter_lines(
[from_line("test==1.2")], [from_line("setuptools==1.10.0")]
{from_line("test==1.2")},
{from_line("setuptools==1.10.0")},
unsafe_packages=set(),
markers={},
)

expected_lines = (
Expand All @@ -132,7 +135,9 @@ def test_iter_lines__unsafe_with_hashes(capsys, writer, from_line):
unsafe_ireqs = [from_line("setuptools==1.10.0")]
hashes = {ireqs[0]: {"FAKEHASH"}, unsafe_ireqs[0]: set()}

lines = writer._iter_lines(ireqs, unsafe_ireqs, hashes=hashes)
lines = writer._iter_lines(
ireqs, unsafe_ireqs, unsafe_packages=set(), markers={}, hashes=hashes
)

expected_lines = (
"test==1.2 \\\n --hash=FAKEHASH",
Expand All @@ -152,7 +157,13 @@ def test_iter_lines__hash_missing(capsys, writer, from_line):
ireqs = [from_line("test==1.2"), from_line("file:///example/#egg=example")]
hashes = {ireqs[0]: {"FAKEHASH"}, ireqs[1]: set()}

lines = writer._iter_lines(ireqs, hashes=hashes)
lines = writer._iter_lines(
ireqs,
hashes=hashes,
unsafe_requirements=set(),
unsafe_packages=set(),
markers={},
)

expected_lines = (
MESSAGE_UNHASHED_PACKAGE,
Expand All @@ -178,7 +189,13 @@ def test_iter_lines__no_warn_if_only_unhashable_packages(writer, from_line):
]
hashes = {ireq: set() for ireq in ireqs}

lines = writer._iter_lines(ireqs, hashes=hashes)
lines = writer._iter_lines(
ireqs,
hashes=hashes,
unsafe_requirements=set(),
unsafe_packages=set(),
markers={},
)

expected_lines = (
"unhashable-pkg1 @ file:///unhashable-pkg1/",
Expand Down Expand Up @@ -389,16 +406,19 @@ def test_write_order(writer, from_line):
"""
writer.emit_header = False

packages = [
packages = {
from_line("package_a==0.1"),
from_line("Package-b==2.3.4"),
from_line("Package==5.6"),
from_line("package2==7.8.9"),
]
}
expected_lines = [
"package==5.6",
"package-a==0.1",
"package-b==2.3.4",
"package2==7.8.9",
]
assert list(writer._iter_lines(packages)) == expected_lines
result = writer._iter_lines(
packages, unsafe_requirements=set(), unsafe_packages=set(), markers={}
)
assert list(result) == expected_lines

0 comments on commit ef88dfc

Please sign in to comment.