From 0f3e011058f25c602e3c14023766ed41ac0d02ee Mon Sep 17 00:00:00 2001 From: Madhav Kumar <47245900+kumar-ish@users.noreply.github.com> Date: Mon, 22 Apr 2024 18:05:36 +1000 Subject: [PATCH] Move to using protogen's .pyi file generator Signed-off-by: Madhav Kumar <47245900+kumar-ish@users.noreply.github.com> --- .gitignore | 3 +- .reuse/dep5 | 4 +- CMakeLists.txt | 26 +-- tools/protoc-gen-mypy.bat | 4 - tools/protoc-gen-mypy.py | 424 ------------------------------------ tools/protoc-gen-mypy.sh.in | 4 - 6 files changed, 7 insertions(+), 458 deletions(-) delete mode 100644 tools/protoc-gen-mypy.bat delete mode 100755 tools/protoc-gen-mypy.py delete mode 100644 tools/protoc-gen-mypy.sh.in diff --git a/.gitignore b/.gitignore index 84df2918ed3..ab7818a29aa 100644 --- a/.gitignore +++ b/.gitignore @@ -81,7 +81,6 @@ compile_commands.json # generated files onnx/version.py compile_commands.json -tools/protoc-gen-mypy.sh # test generated files .cache @@ -112,4 +111,4 @@ docs/docsgen/**/*.onnx docs/docsgen/**/*.pb # PyEnv files -.python-version \ No newline at end of file +.python-version diff --git a/.reuse/dep5 b/.reuse/dep5 index 115752ab129..e3ab3c3d9e3 100644 --- a/.reuse/dep5 +++ b/.reuse/dep5 @@ -130,11 +130,11 @@ Files: cmake/ONNXConfig*.in CMakeLists.txt cmake/*cmake cmake/ COPYRIGHT: Copyright (c) ONNX Project Contributors License: Apache-2.0 -Files: .vscode/settings.json docs/docsgen/source/_templates/layout.html docs/docsgen/source/_templates/sidebar-nav-bs.html docs/images/onnx_hub_arch.svg docs/onnx-horizontal-color.png tools/protoc-gen-mypy.sh.in +Files: .vscode/settings.json docs/docsgen/source/_templates/layout.html docs/docsgen/source/_templates/sidebar-nav-bs.html docs/images/onnx_hub_arch.svg docs/onnx-horizontal-color.png COPYRIGHT: Copyright (c) ONNX Project Contributors License: Apache-2.0 -Files: onnx/gen_proto.py tools/protoc-gen-mypy.bat onnx/reference/op_run.py +Files: onnx/gen_proto.py onnx/reference/op_run.py COPYRIGHT: Copyright (c) ONNX Project Contributors License: Apache-2.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index 71c2f54752f..2cb1377736e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -373,29 +373,11 @@ function(RELATIVE_PROTOBUF_GENERATE_CPP NAME SRCS HDRS ROOT_DIR DEPEND) --cpp_out ${ONNX_DLLEXPORT_STR}${CMAKE_CURRENT_BINARY_DIR}) if(BUILD_ONNX_PYTHON) - list(APPEND PROTOC_ARGS --python_out - ${CMAKE_CURRENT_BINARY_DIR}) + list(APPEND PROTOC_ARGS --python_out) if(ONNX_GEN_PB_TYPE_STUBS) - if(NOT WIN32) - # use PYTHON_EXECUTABLE to python protoc-gen-mypy.py - configure_file( - ${ROOT_DIR}/tools/protoc-gen-mypy.sh.in - ${PROJECT_BINARY_DIR}/tools/protoc-gen-mypy.sh - @ONLY) - execute_process(COMMAND chmod +x ${PROJECT_BINARY_DIR}/tools/protoc-gen-mypy.sh) - configure_file( - ${ROOT_DIR}/tools/protoc-gen-mypy.py - ${PROJECT_BINARY_DIR}/tools/protoc-gen-mypy.py - COPYONLY) - set(PROTOC_MYPY_PLUGIN_FILE ${PROJECT_BINARY_DIR}/tools/protoc-gen-mypy.sh) - else() - set(PROTOC_MYPY_PLUGIN_FILE ${ROOT_DIR}/tools/protoc-gen-mypy.bat) - endif() - list(APPEND PROTOC_ARGS - --plugin - protoc-gen-mypy=${PROTOC_MYPY_PLUGIN_FILE} - --mypy_out - ${ONNX_DLLEXPORT_STR}${CMAKE_CURRENT_BINARY_DIR}) + list(APPEND PROTOC_ARGS pyi_out:${CMAKE_CURRENT_BINARY_DIR}) + else() + list(APPEND PROTOC_ARGS ${CMAKE_CURRENT_BINARY_DIR}) endif() endif() if(NOT ONNX_PROTOC_EXECUTABLE) diff --git a/tools/protoc-gen-mypy.bat b/tools/protoc-gen-mypy.bat deleted file mode 100644 index bf3d27d7c84..00000000000 --- a/tools/protoc-gen-mypy.bat +++ /dev/null @@ -1,4 +0,0 @@ -:: 
SPDX-License-Identifier: Apache-2.0 - -@echo off -python -u "%~dp0\protoc-gen-mypy.py" diff --git a/tools/protoc-gen-mypy.py b/tools/protoc-gen-mypy.py deleted file mode 100755 index 7226d2cb5e9..00000000000 --- a/tools/protoc-gen-mypy.py +++ /dev/null @@ -1,424 +0,0 @@ -#!/usr/bin/env python - -# Copyright (c) ONNX Project Contributors - -# Copyright (c) ONNX Project Contributors -# -# SPDX-License-Identifier: Apache-2.0 - -# Taken from https://github.com/dropbox/mypy-protobuf/blob/d984389124eae6dbbb517f766b9266bb32171510/python/protoc-gen-mypy -# (Apache 2.0 License) -# with own fixes to -# - appease flake8 -# - exit without error when protobuf isn't installed -# - fix recognition of whether an identifier is defined locally -# (unfortunately, we use a python package name ONNX_NAMESPACE_FOO_BAR_FOR_CI -# on CI, which by the original protoc-gen-mypy script was recognized to be -# camel case and therefore handled as an entry in the local package) - - -"""Protoc Plugin to generate mypy stubs. Loosely based on @zbarsky's go implementation""" -from __future__ import annotations - -import sys -from collections import defaultdict -from contextlib import contextmanager -from typing import Any, Callable, Generator, cast - -try: - import google.protobuf.descriptor_pb2 as d_typed - from google.protobuf.compiler import plugin_pb2 as plugin -except ImportError as e: - raise RuntimeError("Failed to generate mypy stubs") from e - - -# Hax to get around fact that google protobuf libraries aren't in typeshed yet -d: Any = d_typed - -# Split the string so phabricator doesn't think this file is generated -GENERATED = "@ge" + "nerated" -HEADER = ( - f"# {GENERATED} by protoc-gen-mypy.py. Do not edit!\n" - "# mypy: disable-error-code=override\n" -) - - -class Descriptors: - def __init__(self, request: plugin.CodeGeneratorRequest) -> None: - files = {f.name: f for f in request.proto_file} - to_generate = {n: files[n] for n in request.file_to_generate} - self.files: dict[str, d.FileDescriptorProto] = files - self.to_generate: dict[str, d.FileDescriptorProto] = to_generate - self.messages: dict[str, d.DescriptorProto] = {} - self.message_to_fd: dict[str, d.FileDescriptorProto] = {} - - def _add_enums( - enums: d.EnumDescriptorProto, prefix: str, fd: d.FileDescriptorProto - ) -> None: - for enum in enums: - self.message_to_fd[prefix + enum.name] = fd - - def _add_messages( - messages: d.DescriptorProto, prefix: str, fd: d.FileDescriptorProto - ) -> None: - for message in messages: - self.messages[prefix + message.name] = message - self.message_to_fd[prefix + message.name] = fd - sub_prefix = prefix + message.name + "." - _add_messages(message.nested_type, sub_prefix, fd) - _add_enums(message.enum_type, sub_prefix, fd) - - for fd in request.proto_file: - start_prefix = "." + fd.package + "." - _add_messages(fd.message_type, start_prefix, fd) - _add_enums(fd.enum_type, start_prefix, fd) - - -class PkgWriter: - """Writes a single pyi file""" - - def __init__(self, fd: d.FileDescriptorProto, descriptors: Descriptors) -> None: - self.fd = fd - self.descriptors = descriptors - self.lines: list[str] = [] - self.indent = "" - - # dictionary of x->y for `from {x} import {y}` - self.imports: dict[str, set[str]] = defaultdict(set) - self.locals: set[str] = set() - - def _import(self, path: str, name: str, import_as: str | None = None) -> str: - """Imports a stdlib path and returns a handle to it - eg. 
self._import("typing", "Optional") -> "Optional" - """ - imp = path.replace("/", ".") - if import_as is not None: - self.imports[imp].add(f"{name} as {import_as}") - return import_as - else: - self.imports[imp].add(name) - return name - - def _import_message(self, type_name: d.FieldDescriptorProto) -> str: - """Import a referenced message and return a handle""" - name = cast(str, type_name) - - if name[0] == "." and name[1].isupper() and name[2].islower(): - # Message defined in this file - return name[1:] - - message_fd = self.descriptors.message_to_fd[name] - if message_fd.name == self.fd.name: - # message defined in this package - split = name.split(".") - for i, segment in enumerate(split): - if segment and segment[0].isupper() and segment[1].islower(): - return ".".join(split[i:]) - - # Not in package. Must import - split = name.split(".") - for i, segment in enumerate(split): - if segment and segment[0].isupper() and segment[1].islower(): - assert message_fd.name.endswith(".proto") - import_name = self._import( - message_fd.name[:-6].replace("-", "_") + "_pb2", segment - ) - remains = ".".join(split[i + 1 :]) - if not remains: - return import_name - raise AssertionError("Don't support nested imports yet") - # return new_nested_import(import_name, remains) - - raise AssertionError("Could not parse local name " + name) - - @contextmanager # type: ignore - def _indent(self) -> Generator[None, None, None]: - self.indent = self.indent + " " - yield - self.indent = self.indent[:-4] - - def _write_line(self, line: str, *args: str) -> None: - self.lines.append(self.indent + line.format(*args)) - - def write_enums(self, enums: list[d.EnumDescriptorProto]) -> None: - line = self._write_line - for enum in enums: - line("class {}(int):", enum.name) - with self._indent(): - line("@classmethod") - line("def Name(cls, number: int) -> str: ...") - line("@classmethod") - line("def Value(cls, name: str) -> int: ...") - line("@classmethod") - line("def keys(cls) -> {}[str]: ...", self._import("typing", "List")) - line("@classmethod") - line("def values(cls) -> {}[int]: ...", self._import("typing", "List")) - line("@classmethod") - line( - "def items(cls) -> {}[{}[str, int]]: ...", - self._import("typing", "List"), - self._import("typing", "Tuple"), - ) - - for val in enum.value: - line( - "{} = {}({}, {})", - val.name, - self._import("typing", "cast"), - enum.name, - val.number, - ) - line("") - - def write_messages(self, messages: list[d.DescriptorProto], prefix: str) -> None: - line = self._write_line - message_class = self._import("google.protobuf.message", "Message") - - for desc in messages: - self.locals.add(desc.name) - qualified_name = prefix + desc.name - line("class {}({}):", desc.name, message_class) - with self._indent(): - # Nested enums/messages - self.write_enums(desc.enum_type) - self.write_messages(desc.nested_type, qualified_name + ".") - - # Scalar fields - for field in [f for f in desc.field if is_scalar(f)]: - if field.label == d.FieldDescriptorProto.LABEL_REPEATED: - container = self._import( - "google.protobuf.internal.containers", - "RepeatedScalarFieldContainer", - ) - line( - "{} = ... # type: {}[{}]", - field.name, - container, - self.python_type(field), - ) - else: - line("{} = ... 
# type: {}", field.name, self.python_type(field)) - line("") - - # Getters for non-scalar fields - for field in [f for f in desc.field if not is_scalar(f)]: - line("@property") - if field.label == d.FieldDescriptorProto.LABEL_REPEATED: - msg = self.descriptors.messages[field.type_name] - if msg.options.map_entry: - # map generates a special Entry wrapper message - container = self._import("typing", "MutableMapping") - line( - "def {}(self) -> {}[{}, {}]: ...", - field.name, - container, - self.python_type(msg.field[0]), - self.python_type(msg.field[1]), - ) - else: - container = self._import( - "google.protobuf.internal.containers", - "RepeatedCompositeFieldContainer", - ) - line( - "def {}(self) -> {}[{}]: ...", - field.name, - container, - self.python_type(field), - ) - else: - line( - "def {}(self) -> {}: ...", - field.name, - self.python_type(field), - ) - line("") - - # Constructor - line("def __init__(self,") - with self._indent(): - # Required args - for field in [ - f - for f in desc.field - if f.label == d.FieldDescriptorProto.LABEL_REQUIRED - ]: - line("{} : {},", field.name, self.python_type(field)) - for field in [ - f - for f in desc.field - if f.label != d.FieldDescriptorProto.LABEL_REQUIRED - ]: - if field.label == d.FieldDescriptorProto.LABEL_REPEATED: - if ( - field.type_name != "" - and self.descriptors.messages[ - field.type_name - ].options.map_entry - ): - msg = self.descriptors.messages[field.type_name] - line( - "{} : {}[{}[{}, {}]] = None,", - field.name, - self._import("typing", "Optional", "OptionalType"), - self._import("typing", "Mapping"), - self.python_type(msg.field[0]), - self.python_type(msg.field[1]), - ) - else: - line( - "{} : {}[{}[{}]] = None,", - field.name, - self._import("typing", "Optional", "OptionalType"), - self._import("typing", "Iterable"), - self.python_type(field), - ) - else: - line( - "{} : {}[{}] = None,", - field.name, - self._import("typing", "Optional", "OptionalType"), - self.python_type(field), - ) - line(") -> None: ...") - - # Standard message methods - line("@classmethod") - line("def FromString(cls, s: bytes) -> {}: ...", qualified_name) - line("def MergeFrom(self, other_msg: {}) -> None: ...", message_class) - line("def CopyFrom(self, other_msg: {}) -> None: ...", message_class) - line("") - - def write_services(self, services: d.ServiceDescriptorProto) -> None: - line = self._write_line - - for service in services: - # The service definition interface - line( - "class {}({}, metaclass={}):", - service.name, - self._import("google.protobuf.service", "Service"), - self._import("abc", "ABCMeta"), - ) - with self._indent(): - for method in service.method: - line("@{}", self._import("abc", "abstractmethod")) - line("def {}(self,", method.name) - with self._indent(): - line( - "rpc_controller: {},", - self._import("google.protobuf.service", "RpcController"), - ) - line("request: {},", self._import_message(method.input_type)) - line( - "done: {}[{}[[{}], None]],", - self._import("typing", "Optional"), - self._import("typing", "Callable"), - self._import_message(method.output_type), - ) - line( - ") -> {}[{}]: ...", - self._import("concurrent.futures", "Future"), - self._import_message(method.output_type), - ) - - # The stub client - line("class {}({}):", service.name + "_Stub", service.name) - with self._indent(): - line( - "def __init__(self, rpc_channel: {}) -> None: ...", - self._import("google.protobuf.service", "RpcChannel"), - ) - - def python_type(self, field: d.FieldDescriptorProto) -> str: - mapping: dict[int, Callable[[], 
str]] = { - d.FieldDescriptorProto.TYPE_DOUBLE: lambda: "float", - d.FieldDescriptorProto.TYPE_FLOAT: lambda: "float", - d.FieldDescriptorProto.TYPE_INT64: lambda: "int", - d.FieldDescriptorProto.TYPE_UINT64: lambda: "int", - d.FieldDescriptorProto.TYPE_FIXED64: lambda: "int", - d.FieldDescriptorProto.TYPE_SFIXED64: lambda: "int", - d.FieldDescriptorProto.TYPE_SINT64: lambda: "int", - d.FieldDescriptorProto.TYPE_INT32: lambda: "int", - d.FieldDescriptorProto.TYPE_UINT32: lambda: "int", - d.FieldDescriptorProto.TYPE_FIXED32: lambda: "int", - d.FieldDescriptorProto.TYPE_SFIXED32: lambda: "int", - d.FieldDescriptorProto.TYPE_SINT32: lambda: "int", - d.FieldDescriptorProto.TYPE_BOOL: lambda: "bool", - d.FieldDescriptorProto.TYPE_STRING: lambda: "str", - d.FieldDescriptorProto.TYPE_BYTES: lambda: "bytes", - d.FieldDescriptorProto.TYPE_ENUM: lambda: self._import_message( - field.type_name - ), - d.FieldDescriptorProto.TYPE_MESSAGE: lambda: self._import_message( - field.type_name - ), - d.FieldDescriptorProto.TYPE_GROUP: lambda: self._import_message( - field.type_name - ), - } - - assert field.type in mapping, "Unrecognized type: " + field.type - return mapping[field.type]() - - def write(self) -> str: - imports = [] - for pkg, items in self.imports.items(): - if pkg.startswith("google."): - imports.append(f"from {pkg} import ( # type: ignore") - else: - imports.append(f"from {pkg} import (") - for item in sorted(items): - imports.append(f" {item},") # noqa: PERF401 - imports.append(")\n") - - return "\n".join(imports + self.lines) - - -def is_scalar(fd: d.FileDescriptorProto) -> bool: - return not ( - fd.type == d.FieldDescriptorProto.TYPE_MESSAGE # noqa: PLR1714 - or fd.type == d.FieldDescriptorProto.TYPE_GROUP - ) - - -def generate_mypy_stubs( - descriptors: Descriptors, response: plugin.CodeGeneratorResponse -) -> None: - for name, fd in descriptors.to_generate.items(): - pkg_writer = PkgWriter(fd, descriptors) - pkg_writer.write_enums(fd.enum_type) - pkg_writer.write_messages(fd.message_type, "") - pkg_writer.write_services(fd.service) - - assert name == fd.name - assert fd.name.endswith(".proto") - output = response.file.add() - output.name = fd.name[:-6].replace("-", "_") + "_pb2.pyi" - output.content = HEADER + pkg_writer.write() - print("Writing mypy to", output.name, file=sys.stderr) - - -def main() -> None: - # Read request message from stdin - data = sys.stdin.buffer.read() - - # Parse request - request = plugin.CodeGeneratorRequest() - request.ParseFromString(data) - - # Create response - response = plugin.CodeGeneratorResponse() - - # Generate mypy - generate_mypy_stubs(Descriptors(request), response) - - # Serialise response message - output = response.SerializeToString() - - # Write to stdout - sys.stdout.buffer.write(output) - - -if __name__ == "__main__": - main() diff --git a/tools/protoc-gen-mypy.sh.in b/tools/protoc-gen-mypy.sh.in deleted file mode 100644 index 6231cb290a8..00000000000 --- a/tools/protoc-gen-mypy.sh.in +++ /dev/null @@ -1,4 +0,0 @@ -#!/usr/bin/env bash -DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -# use PYTHON_EXECUTABLE from CMake to get the right python to execute -"@PYTHON_EXECUTABLE@" -u "$DIR"/protoc-gen-mypy.py
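Note on the CMakeLists.txt hunk above: the net effect is to drop the custom protoc-gen-mypy plugin and let protoc's own Python generator emit the .pyi stubs next to the _pb2.py modules. As a rough illustration only, not part of the patch, the sketch below shows the command-line equivalent of the new PROTOC_ARGS when ONNX_GEN_PB_TYPE_STUBS is ON; it assumes a protoc recent enough to ship the built-in .pyi generator, and the paths and proto file name are made up.

    # Before: stubs came from the external plugin script, roughly
    #   protoc --plugin protoc-gen-mypy=tools/protoc-gen-mypy.sh \
    #          --python_out build --mypy_out build onnx/onnx.proto
    # After: the pyi_out generator parameter makes --python_out write
    #        both onnx_pb2.py and onnx_pb2.pyi into the output directory
    protoc -I . --python_out=pyi_out:build onnx/onnx.proto

When ONNX_GEN_PB_TYPE_STUBS is OFF, the new else() branch keeps the plain --python_out <binary dir> invocation, so only the _pb2.py files are generated, as before.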