Skip to content

Commit

Permalink
enable old style single-line annotations with --annotation-style=line
Browse files Browse the repository at this point in the history
  • Loading branch information
AndydeCleyre committed Sep 20, 2021
1 parent f1ee721 commit d28dfbe
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 15 deletions.
8 changes: 8 additions & 0 deletions piptools/scripts/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ def _get_default_option(option_name: str) -> Any:
default=True,
help="Annotate results, indicating where dependencies come from",
)
@click.option(
"--annotation-style",
type=click.Choice(("line", "split")),
default="split",
help="Choose the format of annotation comments",
)
@click.option(
"-U",
"--upgrade/--no-upgrade",
Expand Down Expand Up @@ -244,6 +250,7 @@ def cli(
header: bool,
emit_trusted_host: bool,
annotate: bool,
annotation_style: str,
upgrade: bool,
upgrade_packages: Tuple[str, ...],
output_file: Union[LazyFile, IO[Any], None],
Expand Down Expand Up @@ -471,6 +478,7 @@ def cli(
emit_index_url=emit_index_url,
emit_trusted_host=emit_trusted_host,
annotate=annotate,
annotation_style=annotation_style,
strip_extras=strip_extras,
generate_hashes=generate_hashes,
default_index_url=repository.DEFAULT_INDEX_URL,
Expand Down
38 changes: 28 additions & 10 deletions piptools/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,23 @@ def _comes_from_as_string(ireq: InstallRequirement) -> str:
return key_from_ireq(ireq.comes_from)


def annotation_style_split(required_by: Set[str]) -> str:
sorted_required_by = sorted(required_by)
if len(sorted_required_by) == 1:
source = sorted_required_by[0]
annotation = "# via " + source
else:
annotation_lines = ["# via"]
for source in sorted_required_by:
annotation_lines.append(" # " + source)
annotation = "\n".join(annotation_lines)
return annotation


def annotation_style_line(required_by: Set[str]) -> str:
return f"# via {', '.join(sorted(required_by))}"


class OutputWriter:
def __init__(
self,
Expand All @@ -61,6 +78,7 @@ def __init__(
emit_index_url: bool,
emit_trusted_host: bool,
annotate: bool,
annotation_style: str,
strip_extras: bool,
generate_hashes: bool,
default_index_url: str,
Expand All @@ -79,6 +97,7 @@ def __init__(
self.emit_index_url = emit_index_url
self.emit_trusted_host = emit_trusted_host
self.annotate = annotate
self.annotation_style = annotation_style
self.strip_extras = strip_extras
self.generate_hashes = generate_hashes
self.default_index_url = default_index_url
Expand Down Expand Up @@ -258,15 +277,14 @@ def _format_requirement(
required_by.add(_comes_from_as_string(ireq))

if required_by:
sorted_required_by = sorted(required_by)
if len(sorted_required_by) == 1:
source = sorted_required_by[0]
annotation = " # via " + source
else:
annotation_lines = [" # via"]
for source in sorted_required_by:
annotation_lines.append(" # " + source)
annotation = "\n".join(annotation_lines)
line = f"{line}\n{comment(annotation)}"
annotation = {
"split": annotation_style_split,
"line": annotation_style_line,
}[self.annotation_style](required_by)
sep = {"split": "\n ", "line": "\n " if ireq_hashes else " "}[
self.annotation_style
]
lines = f"{line:24}{sep}{comment(annotation)}".splitlines()
line = "\n".join(ln.rstrip() for ln in lines)

return line
11 changes: 6 additions & 5 deletions tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def writer(tmpdir_cwd):
emit_index_url=True,
emit_trusted_host=True,
annotate=True,
annotation_style="split",
generate_hashes=False,
default_index_url=None,
index_urls=[],
Expand All @@ -57,37 +58,37 @@ def test_format_requirement_annotation_editable(from_editable, writer):

assert writer._format_requirement(
ireq
) == "-e git+git://fake.org/x/y.git#egg=y\n" + comment(" # via xyz")
) == "-e git+git://fake.org/x/y.git#egg=y\n " + comment("# via xyz")


def test_format_requirement_annotation(from_line, writer):
ireq = from_line("test==1.2")
ireq.comes_from = "xyz"

assert writer._format_requirement(ireq) == "test==1.2\n" + comment(" # via xyz")
assert writer._format_requirement(ireq) == "test==1.2\n " + comment("# via xyz")


def test_format_requirement_annotation_lower_case(from_line, writer):
ireq = from_line("Test==1.2")
ireq.comes_from = "xyz"

assert writer._format_requirement(ireq) == "test==1.2\n" + comment(" # via xyz")
assert writer._format_requirement(ireq) == "test==1.2\n " + comment("# via xyz")


def test_format_requirement_for_primary(from_line, writer):
"Primary packages should get annotated."
ireq = from_line("test==1.2")
ireq.comes_from = "xyz"

assert writer._format_requirement(ireq) == "test==1.2\n" + comment(" # via xyz")
assert writer._format_requirement(ireq) == "test==1.2\n " + comment("# via xyz")


def test_format_requirement_for_primary_lower_case(from_line, writer):
"Primary packages should get annotated."
ireq = from_line("Test==1.2")
ireq.comes_from = "xyz"

assert writer._format_requirement(ireq) == "test==1.2\n" + comment(" # via xyz")
assert writer._format_requirement(ireq) == "test==1.2\n " + comment("# via xyz")


def test_format_requirement_environment_marker(from_line, writer):
Expand Down

0 comments on commit d28dfbe

Please sign in to comment.