Skip to content

Commit

Permalink
feat: add the ability to bring your own (FileDescriptorSet) bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Dec 1, 2021
1 parent 26f7c0b commit 36589a6
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 5 deletions.
18 changes: 17 additions & 1 deletion protoletariat/__main__.py
Expand Up @@ -2,11 +2,13 @@

from __future__ import annotations

import sys
from pathlib import Path
from typing import IO

import click

from .fdsetgen import Buf, Protoc
from .fdsetgen import Buf, Protoc, Raw


def _overwrite(python_file: Path, code: str) -> None:
Expand Down Expand Up @@ -155,5 +157,19 @@ def buf(ctx: click.Context, buf_path: str) -> None:
Buf(buf_path).fix_imports(**ctx.obj)


@main.command(help="Rewrite imports using FileDescriptorSet bytes from a file or stdin")
@click.option(
"--descriptor-set",
type=click.File("rb"),
default=sys.stdin,
show_default=True,
show_envvar=True,
help="Path to the `buf` executable",
)
@click.pass_context
def raw(ctx: click.Context, descriptor_set: IO[bytes]) -> None:
Raw(descriptor_set.read()).fix_imports(**ctx.obj)


if __name__ == "__main__":
main()
22 changes: 18 additions & 4 deletions protoletariat/fdsetgen.py
Expand Up @@ -28,9 +28,6 @@ def _should_ignore(fd_name: str, patterns: Sequence[str]) -> bool:
class FileDescriptorSetGenerator(abc.ABC):
"""Base class that implements fixing imports."""

def __init__(self, fdset_generator_binary: str) -> None:
self.fdset_generator_binary = fdset_generator_binary

@abc.abstractmethod
def generate_file_descriptor_set_bytes(self) -> bytes:
"""Generate the bytes of a `FileDescriptorSet`"""
Expand Down Expand Up @@ -90,13 +87,15 @@ def fix_imports(


class Protoc(FileDescriptorSetGenerator):
"""Generate the FileDescriptorSet using `protoc`."""

def __init__(
self,
protoc_path: str,
proto_files: Iterable[Path],
proto_paths: Iterable[Path],
) -> None:
super().__init__(protoc_path)
self.fdset_generator_binary = protoc_path
self.proto_files = list(proto_files)
self.proto_paths = list(proto_paths)

Expand All @@ -120,6 +119,11 @@ def generate_file_descriptor_set_bytes(self) -> bytes:


class Buf(FileDescriptorSetGenerator):
"""Generate the FileDescriptorSet using `buf`."""

def __init__(self, fdset_generator_binary: str) -> None:
self.fdset_generator_binary = fdset_generator_binary

def generate_file_descriptor_set_bytes(self) -> bytes:
return subprocess.check_output(
[
Expand All @@ -131,3 +135,13 @@ def generate_file_descriptor_set_bytes(self) -> bytes:
"-",
]
)


class Raw(FileDescriptorSetGenerator):
"""Generate the FileDescriptorSet using user-provided bytes."""

def __init__(self, fdset_bytes: bytes) -> None:
self.fdset_bytes = fdset_bytes

def generate_file_descriptor_set_bytes(self) -> bytes:
return self.fdset_bytes
106 changes: 106 additions & 0 deletions protoletariat/tests/conftest.py
Expand Up @@ -6,6 +6,7 @@
import os
import shutil
import subprocess
import tempfile
from functools import partial
from pathlib import Path
from typing import Generator, Iterable, NamedTuple, Sequence
Expand Down Expand Up @@ -196,6 +197,74 @@ def do_generate(self, cli: CliRunner, *, args: Iterable[str] = ()) -> Result:
)


class RawFixture(ProtoletariatFixture):
def __init__(
self,
*,
base_dir: Path,
package: str,
proto_texts: Iterable[ProtoFile],
monkeypatch: pytest.MonkeyPatch,
grpc: bool = False,
mypy: bool = False,
mypy_grpc: bool = False,
) -> None:
super().__init__(
base_dir=base_dir,
package=package,
proto_texts=proto_texts,
monkeypatch=monkeypatch,
)
self.grpc = grpc
self.mypy = mypy
self.mypy_grpc = mypy_grpc

def do_generate(self, cli: CliRunner, *, args: Iterable[str] = ()) -> Result:
# TODO: refactor this, it duplicates a lot of what's in ProtocFixture
with tempfile.NamedTemporaryFile(delete=False) as f:
filename = f.name

protoc_args = [
"protoc",
"--include_imports",
f"--descriptor_set_out={filename}",
"--proto_path",
str(self.base_dir),
"--python_out",
str(self.package_dir),
*(str(fn) for fn, _ in self.proto_texts),
]

if self.grpc:
# XXX: why isn't this found? PATH is set properly
grpc_python_plugin = shutil.which("grpc_python_plugin")
protoc_args.extend(
(
f"--plugin=protoc-gen-grpc_python={grpc_python_plugin}",
"--grpc_python_out",
str(self.package_dir),
)
)
if self.mypy:
protoc_args.extend(("--mypy_out", str(self.package_dir)))
if self.mypy_grpc:
protoc_args.extend(("--mypy_grpc_out", str(self.package_dir)))

subprocess.check_call(protoc_args)

protol_args = [
"--python-out",
str(self.package_dir),
*args,
"raw",
f"--descriptor-set={filename}",
]
try:
return cli.invoke(main, protol_args, catch_exceptions=False)
finally:
os.unlink(filename)


@pytest.fixture
def cli() -> CliRunner:
return CliRunner()
Expand Down Expand Up @@ -250,6 +319,10 @@ def basic_cli_texts() -> list[ProtoFile]:
partial(ProtocFixture, package="basic_cli"),
id="basic_cli_protoc",
),
pytest.param(
partial(RawFixture, package="basic_cli"),
id="basic_cli_raw",
),
]
)
def basic_cli(
Expand Down Expand Up @@ -324,6 +397,10 @@ def thing_service_texts() -> list[ProtoFile]:
partial(ProtocFixture, package="thing_service", grpc=True),
id="thing_service_protoc",
),
pytest.param(
partial(RawFixture, package="thing_service", grpc=True),
id="thing_service_raw",
),
]
)
def thing_service(
Expand Down Expand Up @@ -375,6 +452,7 @@ def nested_texts() -> list[ProtoFile]:
id="nested_buf",
),
pytest.param(partial(ProtocFixture, package="nested"), id="nested_protoc"),
pytest.param(partial(RawFixture, package="nested"), id="nested_raw"),
]
)
def nested(
Expand Down Expand Up @@ -432,6 +510,16 @@ def no_imports_service_texts() -> list[ProtoFile]:
),
id="no_imports_service_protoc",
),
pytest.param(
partial(
RawFixture,
package="no_imports_service",
grpc=True,
mypy=True,
mypy_grpc=True,
),
id="no_imports_service_raw",
),
]
)
def no_imports_service(
Expand Down Expand Up @@ -510,6 +598,16 @@ def imports_service_texts() -> list[ProtoFile]:
),
id="imports_service_protoc",
),
pytest.param(
partial(
RawFixture,
package="imports_service",
grpc=True,
mypy=True,
mypy_grpc=True,
),
id="imports_service_raw",
),
]
)
def grpc_imports(
Expand Down Expand Up @@ -558,6 +656,10 @@ def long_names_texts() -> list[ProtoFile]:
partial(ProtocFixture, package="long_names", mypy=True),
id="long_names_protoc",
),
pytest.param(
partial(RawFixture, package="long_names", mypy=True),
id="long_names_raw",
),
]
)
def long_names(
Expand Down Expand Up @@ -604,6 +706,10 @@ def ignored_import_texts() -> list[ProtoFile]:
partial(ProtocFixture, package="ignored_imports"),
id="ignored_imports_protoc",
),
pytest.param(
partial(RawFixture, package="ignored_imports"),
id="ignored_imports_raw",
),
]
)
def ignored_imports(
Expand Down

0 comments on commit 36589a6

Please sign in to comment.