Skip to content

Commit

Permalink
fix: fix package relative imports
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Nov 24, 2021
1 parent af6151f commit ae75384
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 17 deletions.
3 changes: 2 additions & 1 deletion protoletariat/fdsetgen.py
Expand Up @@ -33,9 +33,10 @@ def fix_imports(
fdset = FileDescriptorSet.FromString(self.generate_file_descriptor_set_bytes())

for fd in fdset.file:
fd_name = _PROTO_SUFFIX_PATTERN.sub(r"\1", fd.name)
for dep in fd.dependency:
register_import_rewrite(
build_import_rewrite(_PROTO_SUFFIX_PATTERN.sub(r"\1", dep))
build_import_rewrite(fd_name, _PROTO_SUFFIX_PATTERN.sub(r"\1", dep))
)

rewriter = ImportRewriter()
Expand Down
9 changes: 6 additions & 3 deletions protoletariat/rewrite.py
Expand Up @@ -89,7 +89,7 @@ def __call__(self, node: ast.AST) -> ast.AST:
rewrite = Rewriter()


def build_import_rewrite(dep: str) -> Replacement:
def build_import_rewrite(proto: str, dep: str) -> Replacement:
"""Construct a replacement import for `dep`.
Parameters
Expand All @@ -107,17 +107,20 @@ def build_import_rewrite(dep: str) -> Replacement:

*import_parts, part = parts

# compute the number of dots needed to get to the package root
num_leading_dots = proto.count("/") + 1
leading_dots = "." * num_leading_dots
if not import_parts:
old = f"import {part}_pb2 as {part}__pb2"
new = f"from . import {part}_pb2 as {part}__pb2"
new = f"from {leading_dots} import {part}_pb2 as {part}__pb2"
else:
last_part = "__".join(f"{part}_pb2".split("_"))

from_ = ".".join(import_parts)
as_ = f"{'_dot_'.join(import_parts)}_dot_{last_part}"

old = f"from {from_} import {part}_pb2 as {as_}"
new = f"from .{from_} import {part}_pb2 as {as_}"
new = f"from {leading_dots}{from_} import {part}_pb2 as {as_}"

return Replacement(old=old, new=new)

Expand Down
205 changes: 192 additions & 13 deletions protoletariat/tests/test_rewrite.py
@@ -1,6 +1,7 @@
import importlib
import json
import os
import subprocess
import sys
from pathlib import Path
from typing import Iterable

Expand All @@ -22,19 +23,36 @@ def check_proto_out(out: Path) -> None:


@pytest.mark.parametrize(
("case", "expected"),
("proto", "dep", "expected"),
[
("foo", "from . import foo_pb2 as foo__pb2"),
("foo/bar", "from .foo import bar_pb2 as foo_dot_bar__pb2"),
("foo/bar/baz", "from .foo.bar import baz_pb2 as foo_dot_bar_dot_baz__pb2"),
("a", "foo", "from . import foo_pb2 as foo__pb2"),
("a/b", "foo", "from .. import foo_pb2 as foo__pb2"),
("a", "foo/bar", "from .foo import bar_pb2 as foo_dot_bar__pb2"),
("a/b", "foo/bar", "from ..foo import bar_pb2 as foo_dot_bar__pb2"),
(
"a",
"foo/bar/baz",
"from .foo.bar import baz_pb2 as foo_dot_bar_dot_baz__pb2",
),
(
"a/b",
"foo/bar/baz",
"from ..foo.bar import baz_pb2 as foo_dot_bar_dot_baz__pb2",
),
(
"a",
"foo/bar/bizz_buzz",
"from .foo.bar import bizz_buzz_pb2 as foo_dot_bar_dot_bizz__buzz__pb2",
),
(
"a/b",
"foo/bar/bizz_buzz",
"from ..foo.bar import bizz_buzz_pb2 as foo_dot_bar_dot_bizz__buzz__pb2",
),
],
)
def test_build_import_rewrite(case: str, expected: str) -> None:
old, new = build_import_rewrite(case)
def test_build_import_rewrite(proto: str, dep: str, expected: str) -> None:
old, new = build_import_rewrite(proto, dep)
assert new == expected


Expand Down Expand Up @@ -99,9 +117,10 @@ def test_cli(
this_proto: Path,
other_proto: Path,
baz_bizz_buzz_other_proto: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
# 1. generated code using protoc
out = tmp_path.joinpath("out")
out = tmp_path.joinpath("out_cli")
out.mkdir()
subprocess.run(
[
Expand Down Expand Up @@ -149,10 +168,39 @@ def test_cli(

check_import_lines(result, expected_lines)

result = cli.invoke(
main,
[
"-g",
str(out),
"--overwrite",
"--create-init",
"protoc",
"-p",
str(this_proto.parent),
"-p",
str(other_proto.parent),
"-p",
str(baz_bizz_buzz_other_proto.parent),
str(this_proto),
str(other_proto),
str(baz_bizz_buzz_other_proto),
],
)
assert result.exit_code == 0

with monkeypatch.context() as m:
m.syspath_prepend(str(tmp_path))
importlib.import_module("out_cli")
importlib.import_module("out_cli.other_pb2")
importlib.import_module("out_cli.this_pb2")

assert str(tmp_path) not in sys.path


@pytest.fixture
def thing1(tmp_path: Path) -> Path:
code = """
def thing1_text() -> str:
return """
// thing1.proto
syntax = "proto3";
Expand All @@ -164,13 +212,31 @@ def thing1(tmp_path: Path) -> Path:
Thing2 thing2 = 1;
}
"""


@pytest.fixture
def thing1(thing1_text: str, tmp_path: Path) -> Path:
p = tmp_path.joinpath("thing1.proto")
p.write_text(code)
p.write_text(thing1_text)
return p


@pytest.fixture
def thing2(tmp_path: Path) -> Path:
def thing2_text() -> str:
return """
// thing2.proto
syntax = "proto3";
package things;
message Thing2 {
string data = 1;
}
"""


@pytest.fixture
def thing2(thing2_text: str, tmp_path: Path) -> Path:
code = """
// thing2.proto
syntax = "proto3";
Expand All @@ -186,6 +252,36 @@ def thing2(tmp_path: Path) -> Path:
return p


@pytest.fixture
def thing1_nested_text() -> str:
return """
// thing1.proto
syntax = "proto3";
import "a/b/c/thing2.proto";
package thing1.a;
message Thing1 {
thing2.a.b.c.Thing2 thing2 = 1;
}
"""


@pytest.fixture
def thing2_nested_text() -> str:
return """
// thing2.proto
syntax = "proto3";
package thing2.a.b.c;
message Thing2 {
string data = 1;
}
"""


def test_example_protoc(
cli: CliRunner,
tmp_path: Path,
Expand Down Expand Up @@ -237,11 +333,12 @@ def test_example_buf(
tmp_path: Path,
thing1: Path,
thing2: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
out = tmp_path.joinpath("out")
out.mkdir()

os.chdir(tmp_path)
monkeypatch.chdir(tmp_path)
with tmp_path.joinpath("buf.yaml").open(mode="w") as f:
json.dump(
{
Expand Down Expand Up @@ -298,3 +395,85 @@ def test_example_buf(
check_import_lines(result, expected_lines)

assert out.joinpath("__init__.py").exists()


def test_nested_buf(
cli: CliRunner,
tmp_path: Path,
thing1_nested_text: str,
thing2_nested_text: str,
monkeypatch: pytest.MonkeyPatch,
) -> None:
out = tmp_path.joinpath("out_nested")
out.mkdir()

c = tmp_path / "a" / "b" / "c"
c.mkdir(parents=True, exist_ok=True)

thing2 = c / "thing2.proto"
thing2.write_text(thing2_nested_text)

d = tmp_path / "d"
d.mkdir()

thing1 = d / "thing1.proto"
thing1.write_text(thing1_nested_text)

monkeypatch.chdir(tmp_path)
with tmp_path.joinpath("buf.yaml").open(mode="w") as f:
json.dump(
{
"version": "v1",
"lint": {"use": ["DEFAULT"]},
"breaking": {"use": ["FILE"]},
},
f,
)

with tmp_path.joinpath("buf.gen.yaml").open(mode="w") as f:
json.dump(
{
"version": "v1",
"plugins": [
{
"name": "python",
"out": str(out),
},
],
},
f,
)

subprocess.check_call(
[
"protoc",
str(thing1),
str(thing2),
"--proto_path",
str(tmp_path),
"--python_out",
str(out),
],
)

check_proto_out(out)

result = cli.invoke(
main,
[
"-g",
str(out),
"--overwrite",
"--create-init",
"buf",
],
)
assert result.exit_code == 0

# check that we can import nested things
with monkeypatch.context() as m:
m.syspath_prepend(str(tmp_path))

importlib.import_module("out_nested")
importlib.import_module("out_nested.a.b.c.thing2_pb2")
importlib.import_module("out_nested.d.thing1_pb2")

0 comments on commit ae75384

Please sign in to comment.