diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 62863b54..cd35baf1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -49,7 +49,7 @@ jobs: - name: Move compiled files to betterproto2 shell: bash - run: mv betterproto2_compiler/tests/output_betterproto betterproto2_compiler/tests/output_betterproto_pydantic betterproto2_compiler/tests/output_betterproto_descriptor betterproto2_compiler/tests/output_reference betterproto2/tests + run: cp -r betterproto2_compiler/tests/outputs betterproto2/tests - name: Execute test suite working-directory: ./betterproto2 diff --git a/.gitignore b/.gitignore index 442c3f77..c5c01617 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ .python-version build/ */tests/output_* +*/tests/outputs/* **/__pycache__ dist **/*.egg-info @@ -18,4 +19,4 @@ output .asv venv .devcontainer -.ruff_cache \ No newline at end of file +.ruff_cache diff --git a/betterproto2/pyproject.toml b/betterproto2/pyproject.toml index b997cfc5..1187be7d 100644 --- a/betterproto2/pyproject.toml +++ b/betterproto2/pyproject.toml @@ -92,7 +92,7 @@ addopts = "-p no:warnings" # Dev workflow tasks [tool.poe.tasks.get-local-compiled-tests] # task useful for local development. Copies the compiled test files from the compiler folder to the tests folder -shell = "rm -rf tests/output_* && cp -r ../betterproto2_compiler/tests/output_* tests" +shell = "rm -rf tests/outputs* && cp -r ../betterproto2_compiler/tests/outputs tests" [tool.poe.tasks.test] cmd = "pytest" @@ -138,18 +138,6 @@ rm -rf .coverage .mypy_cache .pytest_cache """ help = "Clean out generated files from the workspace" -[tool.poe.tasks.pull-compiled-tests] -shell = """ -rm -rf tests/output_* && -git clone https://github.com/betterproto/python-betterproto2-compiler --branch compiled-test-files --single-branch compiled_files && -mv compiled_files/tests_betterproto tests/output_betterproto && -mv compiled_files/tests_betterproto_pydantic tests/output_betterproto_pydantic && -mv compiled_files/tests_betterproto_pydantic tests/output_betterproto_descriptor && -mv compiled_files/tests_reference tests/output_reference && -rm -rf compiled_files -""" -help = "Pulls the compiled test files from the betterproto2-compiler repository" - [tool.poe.tasks.serve-docs] cmd = "mkdocs serve" help = "Serve the documentation locally" diff --git a/betterproto2/tests/grpc/test_grpclib_client.py b/betterproto2/tests/grpc/test_grpclib_client.py index bc4ef537..af45a654 100644 --- a/betterproto2/tests/grpc/test_grpclib_client.py +++ b/betterproto2/tests/grpc/test_grpclib_client.py @@ -9,7 +9,7 @@ from grpclib.testing import ChannelFor from tests.grpc.async_channel import AsyncChannel -from tests.output_betterproto.service import ( +from tests.outputs.service.service import ( DoThingRequest, DoThingResponse, GetThingRequest, diff --git a/betterproto2/tests/grpc/test_grpclib_reflection.py b/betterproto2/tests/grpc/test_grpclib_reflection.py index 3680cf7a..b7507255 100644 --- a/betterproto2/tests/grpc/test_grpclib_reflection.py +++ b/betterproto2/tests/grpc/test_grpclib_reflection.py @@ -8,15 +8,17 @@ from grpclib.reflection.v1alpha.reflection_grpc import ServerReflectionBase as ServerReflectionBaseV1Alpha from grpclib.testing import ChannelFor -from tests.output_betterproto.example_service import TestBase -from tests.output_betterproto.grpc.reflection.v1 import ( +from tests.outputs.grpclib_reflection.example_service import TestBase +from tests.outputs.grpclib_reflection.grpc.reflection.v1 import ( ErrorResponse, ListServiceResponse, ServerReflectionRequest, ServerReflectionStub, ServiceResponse, ) -from tests.output_betterproto_descriptor.google_proto_descriptor_pool import default_google_proto_descriptor_pool +from tests.outputs.grpclib_reflection_descriptors.google_proto_descriptor_pool import ( + default_google_proto_descriptor_pool, +) class TestService(TestBase): @@ -78,7 +80,7 @@ async def test_grpclib_reflection(): assert response.file_descriptor_response is None # now it should work - import tests.output_betterproto_descriptor.example_service as example_service_with_desc + import tests.outputs.grpclib_reflection_descriptors.example_service as example_service_with_desc requests.put(ServerReflectionRequest(file_containing_symbol="example_service.Test")) response = await anext(responses) diff --git a/betterproto2/tests/grpc/test_message_enum_descriptors.py b/betterproto2/tests/grpc/test_message_enum_descriptors.py index a7922820..3f71425d 100644 --- a/betterproto2/tests/grpc/test_message_enum_descriptors.py +++ b/betterproto2/tests/grpc/test_message_enum_descriptors.py @@ -1,10 +1,14 @@ import pytest -from tests.output_betterproto.import_cousin_package_same_name.test.subpackage import Test +from tests.outputs.import_cousin_package_same_name.import_cousin_package_same_name.test.subpackage import Test # importing the cousin should cause no descriptor pool errors since the subpackage imports it once already -from tests.output_betterproto_descriptor.import_cousin_package_same_name.cousin.subpackage import CousinMessage -from tests.output_betterproto_descriptor.import_cousin_package_same_name.test.subpackage import Test as TestWithDesc +from tests.outputs.import_cousin_package_same_name_descriptors.import_cousin_package_same_name.cousin.subpackage import ( # noqa: E501 + CousinMessage, +) +from tests.outputs.import_cousin_package_same_name_descriptors.import_cousin_package_same_name.test.subpackage import ( + Test as TestWithDesc, +) def test_message_enum_descriptors(): diff --git a/betterproto2/tests/grpc/test_stream_stream.py b/betterproto2/tests/grpc/test_stream_stream.py index a0d4c8e8..a3031e30 100644 --- a/betterproto2/tests/grpc/test_stream_stream.py +++ b/betterproto2/tests/grpc/test_stream_stream.py @@ -4,7 +4,7 @@ import pytest from tests.grpc.async_channel import AsyncChannel -from tests.output_betterproto.stream_stream import Message +from tests.outputs.stream_stream.stream_stream import Message @pytest.fixture diff --git a/betterproto2/tests/grpc/thing_service.py b/betterproto2/tests/grpc/thing_service.py index 392a086d..e76e1e10 100644 --- a/betterproto2/tests/grpc/thing_service.py +++ b/betterproto2/tests/grpc/thing_service.py @@ -1,7 +1,7 @@ import grpclib import grpclib.server -from tests.output_betterproto.service import ( +from tests.outputs.service.service import ( DoThingRequest, DoThingResponse, GetThingRequest, diff --git a/betterproto2/tests/inputs/bool/test_bool.py b/betterproto2/tests/inputs/bool/test_bool.py index b6cfc8a2..aa553e37 100644 --- a/betterproto2/tests/inputs/bool/test_bool.py +++ b/betterproto2/tests/inputs/bool/test_bool.py @@ -2,28 +2,28 @@ def test_value(): - from tests.output_betterproto.bool import Test + from tests.outputs.bool.bool import Test message = Test() assert not message.value, "Boolean is False by default" def test_pydantic_no_value(): - from tests.output_betterproto_pydantic.bool import Test as TestPyd + from tests.outputs.bool_pydantic.bool import Test as TestPyd message = TestPyd() assert not message.value, "Boolean is False by default" def test_pydantic_value(): - from tests.output_betterproto_pydantic.bool import Test as TestPyd + from tests.outputs.bool_pydantic.bool import Test as TestPyd message = TestPyd(value=False) assert not message.value def test_pydantic_bad_value(): - from tests.output_betterproto_pydantic.bool import Test as TestPyd + from tests.outputs.bool_pydantic.bool import Test as TestPyd with pytest.raises(ValueError): TestPyd(value=123) diff --git a/betterproto2/tests/inputs/casing/test_casing.py b/betterproto2/tests/inputs/casing/test_casing.py index feee009a..911c5fb4 100644 --- a/betterproto2/tests/inputs/casing/test_casing.py +++ b/betterproto2/tests/inputs/casing/test_casing.py @@ -1,5 +1,5 @@ -import tests.output_betterproto.casing as casing -from tests.output_betterproto.casing import Test +import tests.outputs.casing.casing as casing +from tests.outputs.casing.casing import Test def test_message_attributes(): diff --git a/betterproto2/tests/inputs/casing_inner_class/test_casing_inner_class.py b/betterproto2/tests/inputs/casing_inner_class/test_casing_inner_class.py index 2560b6c2..5ddba1a7 100644 --- a/betterproto2/tests/inputs/casing_inner_class/test_casing_inner_class.py +++ b/betterproto2/tests/inputs/casing_inner_class/test_casing_inner_class.py @@ -1,4 +1,4 @@ -import tests.output_betterproto.casing_inner_class as casing_inner_class +import tests.outputs.casing_inner_class.casing_inner_class as casing_inner_class def test_message_casing_inner_class_name(): diff --git a/betterproto2/tests/inputs/config.py b/betterproto2/tests/inputs/config.py index 2c10d97c..cd5c11c9 100644 --- a/betterproto2/tests/inputs/config.py +++ b/betterproto2/tests/inputs/config.py @@ -19,10 +19,3 @@ "empty_service", "service_uppercase", } - - -# Indicate json sample messages to skip when testing that json (de)serialization -# is symmetrical becuase some cases legitimately are not symmetrical. -# Each key references the name of the test scenario and the values in the tuple -# Are the names of the json files. -non_symmetrical_json = {"empty_repeated": ("empty_repeated",)} diff --git a/betterproto2/tests/inputs/empty_repeated/empty_repeated.json b/betterproto2/tests/inputs/empty_repeated/empty_repeated.json deleted file mode 100644 index 12a801c6..00000000 --- a/betterproto2/tests/inputs/empty_repeated/empty_repeated.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "msg": [{"values":[]}] -} diff --git a/betterproto2/tests/inputs/enum/test_enum.py b/betterproto2/tests/inputs/enum/test_enum.py index 3ec4cbdd..6f7f6c9e 100644 --- a/betterproto2/tests/inputs/enum/test_enum.py +++ b/betterproto2/tests/inputs/enum/test_enum.py @@ -1,4 +1,4 @@ -from tests.output_betterproto.enum import ( +from tests.outputs.enum.enum import ( ArithmeticOperator, Choice, Test, diff --git a/betterproto2/tests/inputs/example_service/test_example_service.py b/betterproto2/tests/inputs/example_service/test_example_service.py index cc257e94..a465b7c2 100644 --- a/betterproto2/tests/inputs/example_service/test_example_service.py +++ b/betterproto2/tests/inputs/example_service/test_example_service.py @@ -3,7 +3,7 @@ import pytest from grpclib.testing import ChannelFor -from tests.output_betterproto.example_service import ( +from tests.outputs.example_service.example_service import ( ExampleRequest, ExampleResponse, TestBase, diff --git a/betterproto2/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py b/betterproto2/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py index d473f885..81ec5003 100644 --- a/betterproto2/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py +++ b/betterproto2/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py @@ -7,14 +7,14 @@ from google.protobuf import json_format from google.protobuf.timestamp_pb2 import Timestamp -from tests.output_betterproto.google_impl_behavior_equivalence import ( +from tests.outputs.google_impl_behavior_equivalence.google_impl_behavior_equivalence import ( Empty, Foo, Request, Spam, Test, ) -from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import ( +from tests.outputs.google_impl_behavior_equivalence_reference.google_impl_behavior_equivalence_pb2 import ( Empty as ReferenceEmpty, Foo as ReferenceFoo, Request as ReferenceRequest, diff --git a/betterproto2/tests/inputs/googletypes_request/test_googletypes_request.py b/betterproto2/tests/inputs/googletypes_request/test_googletypes_request.py index 79504535..05cc905e 100644 --- a/betterproto2/tests/inputs/googletypes_request/test_googletypes_request.py +++ b/betterproto2/tests/inputs/googletypes_request/test_googletypes_request.py @@ -9,12 +9,9 @@ import pytest -import tests.output_betterproto.google.protobuf as protobuf +import tests.outputs.googletypes_request.google.protobuf as protobuf from tests.mocks import MockChannel -from tests.output_betterproto.googletypes_request import ( - Input, - TestStub, -) +from tests.outputs.googletypes_request.googletypes_request import Input, TestStub test_cases = [ (TestStub.send_double, protobuf.DoubleValue, 2.5), diff --git a/betterproto2/tests/inputs/googletypes_response/test_googletypes_response.py b/betterproto2/tests/inputs/googletypes_response/test_googletypes_response.py index 1bb2ef1d..14f49c0a 100644 --- a/betterproto2/tests/inputs/googletypes_response/test_googletypes_response.py +++ b/betterproto2/tests/inputs/googletypes_response/test_googletypes_response.py @@ -5,12 +5,9 @@ import pytest -import tests.output_betterproto.google.protobuf as protobuf +import tests.outputs.googletypes_response.google.protobuf as protobuf from tests.mocks import MockChannel -from tests.output_betterproto.googletypes_response import ( - Input, - TestStub, -) +from tests.outputs.googletypes_response.googletypes_response import Input, TestStub test_cases = [ (TestStub.get_double, protobuf.DoubleValue, 2.5), diff --git a/betterproto2/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py b/betterproto2/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py index 57ebce1b..74b5cda6 100644 --- a/betterproto2/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py +++ b/betterproto2/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py @@ -1,7 +1,7 @@ import pytest from tests.mocks import MockChannel -from tests.output_betterproto.googletypes_response_embedded import ( +from tests.outputs.googletypes_response_embedded.googletypes_response_embedded import ( Input, Output, TestStub, diff --git a/betterproto2/tests/inputs/import_service_input_message/test_import_service_input_message.py b/betterproto2/tests/inputs/import_service_input_message/test_import_service_input_message.py index 5885da71..5c7cd6cc 100644 --- a/betterproto2/tests/inputs/import_service_input_message/test_import_service_input_message.py +++ b/betterproto2/tests/inputs/import_service_input_message/test_import_service_input_message.py @@ -1,13 +1,13 @@ import pytest from tests.mocks import MockChannel -from tests.output_betterproto.import_service_input_message import ( +from tests.outputs.import_service_input_message.import_service_input_message import ( NestedRequestMessage, RequestMessage, RequestResponse, TestStub, ) -from tests.output_betterproto.import_service_input_message.child import ( +from tests.outputs.import_service_input_message.import_service_input_message.child import ( ChildRequestMessage, ) diff --git a/betterproto2/tests/inputs/invalid_field/test_invalid_field.py b/betterproto2/tests/inputs/invalid_field/test_invalid_field.py index 947b8e13..aefead03 100644 --- a/betterproto2/tests/inputs/invalid_field/test_invalid_field.py +++ b/betterproto2/tests/inputs/invalid_field/test_invalid_field.py @@ -2,7 +2,7 @@ def test_invalid_field(): - from tests.output_betterproto.invalid_field import Test + from tests.outputs.invalid_field.invalid_field import Test with pytest.raises(TypeError): Test(unknown_field=12) @@ -11,7 +11,7 @@ def test_invalid_field(): def test_invalid_field_pydantic(): from pydantic import ValidationError - from tests.output_betterproto_pydantic.invalid_field import Test + from tests.outputs.invalid_field_pydantic.invalid_field import Test with pytest.raises(ValidationError): Test(unknown_field=12) diff --git a/betterproto2/tests/inputs/nestedtwice/test_nestedtwice.py b/betterproto2/tests/inputs/nestedtwice/test_nestedtwice.py index ca0557a7..532d9d56 100644 --- a/betterproto2/tests/inputs/nestedtwice/test_nestedtwice.py +++ b/betterproto2/tests/inputs/nestedtwice/test_nestedtwice.py @@ -1,6 +1,6 @@ import pytest -from tests.output_betterproto.nestedtwice import ( +from tests.outputs.nestedtwice.nestedtwice import ( Test, TestTop, TestTopMiddle, diff --git a/betterproto2/tests/inputs/oneof/test_oneof.py b/betterproto2/tests/inputs/oneof/test_oneof.py index dfb80879..a703cf62 100644 --- a/betterproto2/tests/inputs/oneof/test_oneof.py +++ b/betterproto2/tests/inputs/oneof/test_oneof.py @@ -3,28 +3,28 @@ def test_which_count(): - from tests.output_betterproto.oneof import Test + from tests.outputs.oneof.oneof import Test message = Test.from_json(get_test_case_json_data("oneof")[0].json) assert betterproto2.which_one_of(message, "foo") == ("pitied", 100) def test_which_name(): - from tests.output_betterproto.oneof import Test + from tests.outputs.oneof.oneof import Test message = Test.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json) assert betterproto2.which_one_of(message, "foo") == ("pitier", "Mr. T") def test_which_count_pyd(): - from tests.output_betterproto_pydantic.oneof import Test as TestPyd + from tests.outputs.oneof_pydantic.oneof import Test - message = TestPyd(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar") + message = Test(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar") assert betterproto2.which_one_of(message, "foo") == ("pitier", "Mr. T") def test_oneof_constructor_assign(): - from tests.output_betterproto.oneof import MixedDrink, Test + from tests.outputs.oneof.oneof import MixedDrink, Test message = Test(mixed_drink=MixedDrink(shots=42)) field, value = betterproto2.which_one_of(message, "bar") diff --git a/betterproto2/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py b/betterproto2/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py index 0060261d..3048139a 100644 --- a/betterproto2/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py +++ b/betterproto2/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py @@ -1,7 +1,7 @@ import datetime import betterproto2 -from tests.output_betterproto.oneof_default_value_serialization import ( +from tests.outputs.oneof_default_value_serialization.oneof_default_value_serialization import ( Message, NestedMessage, Test, diff --git a/betterproto2/tests/inputs/oneof_enum/test_oneof_enum.py b/betterproto2/tests/inputs/oneof_enum/test_oneof_enum.py index 25891c0e..ba86ca32 100644 --- a/betterproto2/tests/inputs/oneof_enum/test_oneof_enum.py +++ b/betterproto2/tests/inputs/oneof_enum/test_oneof_enum.py @@ -1,5 +1,5 @@ import betterproto2 -from tests.output_betterproto.oneof_enum import Move, Signal, Test +from tests.outputs.oneof_enum.oneof_enum import Move, Signal, Test from tests.util import get_test_case_json_data diff --git a/betterproto2/tests/inputs/proto3_field_presence/test_proto3_field_presence.py b/betterproto2/tests/inputs/proto3_field_presence/test_proto3_field_presence.py index 9c2d6e69..7847b774 100644 --- a/betterproto2/tests/inputs/proto3_field_presence/test_proto3_field_presence.py +++ b/betterproto2/tests/inputs/proto3_field_presence/test_proto3_field_presence.py @@ -1,8 +1,6 @@ import json -from tests.output_betterproto.proto3_field_presence import ( - Test, -) +from tests.outputs.proto3_field_presence.proto3_field_presence import Test def test_null_fields_json(): diff --git a/betterproto2/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py b/betterproto2/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py index 2320dc64..2a9a0151 100644 --- a/betterproto2/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py +++ b/betterproto2/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py @@ -1,8 +1,4 @@ -from tests.output_betterproto.proto3_field_presence_oneof import ( - Nested, - Test, - WithOptional, -) +from tests.outputs.proto3_field_presence_oneof.proto3_field_presence_oneof import Nested, Test, WithOptional def test_serialization(): diff --git a/betterproto2/tests/inputs/regression_387/test_regression_387.py b/betterproto2/tests/inputs/regression_387/test_regression_387.py index e1200ccd..420a9b54 100644 --- a/betterproto2/tests/inputs/regression_387/test_regression_387.py +++ b/betterproto2/tests/inputs/regression_387/test_regression_387.py @@ -1,4 +1,4 @@ -from tests.output_betterproto.regression_387 import ParentElement, Test +from tests.outputs.regression_387.regression_387 import ParentElement, Test def test_regression_387(): diff --git a/betterproto2/tests/inputs/regression_414/test_regression_414.py b/betterproto2/tests/inputs/regression_414/test_regression_414.py index 61f5bfe6..46a5f355 100644 --- a/betterproto2/tests/inputs/regression_414/test_regression_414.py +++ b/betterproto2/tests/inputs/regression_414/test_regression_414.py @@ -1,4 +1,4 @@ -from tests.output_betterproto.regression_414 import Test +from tests.outputs.regression_414.regression_414 import Test def test_full_cycle(): diff --git a/betterproto2/tests/inputs/repeated_duration_timestamp/test_repeated_duration_timestamp.py b/betterproto2/tests/inputs/repeated_duration_timestamp/test_repeated_duration_timestamp.py index efc34866..b3984f7c 100644 --- a/betterproto2/tests/inputs/repeated_duration_timestamp/test_repeated_duration_timestamp.py +++ b/betterproto2/tests/inputs/repeated_duration_timestamp/test_repeated_duration_timestamp.py @@ -1,9 +1,6 @@ -from datetime import ( - datetime, - timedelta, -) +from datetime import datetime, timedelta -from tests.output_betterproto.repeated_duration_timestamp import Test +from tests.outputs.repeated_duration_timestamp.repeated_duration_timestamp import Test def test_roundtrip(): diff --git a/betterproto2/tests/inputs/rpc_empty_input_message/test_rpc_empty_input_message.py b/betterproto2/tests/inputs/rpc_empty_input_message/test_rpc_empty_input_message.py index f77578f6..29a90a21 100644 --- a/betterproto2/tests/inputs/rpc_empty_input_message/test_rpc_empty_input_message.py +++ b/betterproto2/tests/inputs/rpc_empty_input_message/test_rpc_empty_input_message.py @@ -4,7 +4,7 @@ @pytest.mark.asyncio async def test_rpc_input_message(): - from tests.output_betterproto.rpc_empty_input_message import ( + from tests.outputs.rpc_empty_input_message.rpc_empty_input_message import ( Response, ServiceBase, ServiceStub, diff --git a/betterproto2/tests/inputs/service_uppercase/test_service.py b/betterproto2/tests/inputs/service_uppercase/test_service.py index 35405e13..cbf50f4c 100644 --- a/betterproto2/tests/inputs/service_uppercase/test_service.py +++ b/betterproto2/tests/inputs/service_uppercase/test_service.py @@ -1,6 +1,6 @@ import inspect -from tests.output_betterproto.service_uppercase import TestStub +from tests.outputs.service_uppercase.service_uppercase import TestStub def test_parameters(): diff --git a/betterproto2/tests/inputs/timestamp_dict_encode/test_timestamp_dict_encode.py b/betterproto2/tests/inputs/timestamp_dict_encode/test_timestamp_dict_encode.py index 872acd4b..559922b4 100644 --- a/betterproto2/tests/inputs/timestamp_dict_encode/test_timestamp_dict_encode.py +++ b/betterproto2/tests/inputs/timestamp_dict_encode/test_timestamp_dict_encode.py @@ -6,7 +6,7 @@ import pytest -from tests.output_betterproto.timestamp_dict_encode import Test +from tests.outputs.timestamp_dict_encode.timestamp_dict_encode import Test # Current World Timezone range (UTC-12 to UTC+14) MIN_UTC_OFFSET_MIN = -12 * 60 diff --git a/betterproto2/tests/oneof_pattern_matching.py b/betterproto2/tests/oneof_pattern_matching.py index b46b6e5a..e7d3e591 100644 --- a/betterproto2/tests/oneof_pattern_matching.py +++ b/betterproto2/tests/oneof_pattern_matching.py @@ -2,7 +2,7 @@ def test_oneof_pattern_matching(): - from tests.output_betterproto.features import IntMsg, OneofMsg + from tests.outputs.features.features import IntMsg, OneofMsg msg = OneofMsg(y="test1", b="test2") diff --git a/betterproto2/tests/test_all_definition.py b/betterproto2/tests/test_all_definition.py index e93c064f..26304f5d 100644 --- a/betterproto2/tests/test_all_definition.py +++ b/betterproto2/tests/test_all_definition.py @@ -4,8 +4,8 @@ def test_all_definition(): These modules have been chosen since they contain messages, services and enums. """ - import tests.output_betterproto.enum as enum - import tests.output_betterproto.service as service + import tests.outputs.enum.enum as enum + import tests.outputs.service.service as service assert service.__all__ == ( "DoThingRequest", diff --git a/betterproto2/tests/test_any.py b/betterproto2/tests/test_any.py index 6cf51b89..d25b05d3 100644 --- a/betterproto2/tests/test_any.py +++ b/betterproto2/tests/test_any.py @@ -1,7 +1,7 @@ def test_any() -> None: # TODO using a custom message pool will no longer be necessary when the well-known types will be compiled as well - from tests.output_betterproto.any import Person - from tests.output_betterproto.google.protobuf import Any + from tests.outputs.any.any import Person + from tests.outputs.any.google.protobuf import Any person = Person(first_name="John", last_name="Smith") @@ -14,8 +14,8 @@ def test_any() -> None: def test_any_to_dict() -> None: - from tests.output_betterproto.any import Person - from tests.output_betterproto.google.protobuf import Any + from tests.outputs.any.any import Person + from tests.outputs.any.google.protobuf import Any person = Person(first_name="John", last_name="Smith") diff --git a/betterproto2/tests/test_deprecated.py b/betterproto2/tests/test_deprecated.py index 2930f6cf..f77da7ee 100644 --- a/betterproto2/tests/test_deprecated.py +++ b/betterproto2/tests/test_deprecated.py @@ -3,7 +3,7 @@ import pytest from tests.mocks import MockChannel -from tests.output_betterproto.deprecated import ( +from tests.outputs.deprecated.deprecated import ( Empty, Message, Test, diff --git a/betterproto2/tests/test_documentation.py b/betterproto2/tests/test_documentation.py index de9790d3..86332722 100644 --- a/betterproto2/tests/test_documentation.py +++ b/betterproto2/tests/test_documentation.py @@ -11,7 +11,7 @@ def check(generated_doc: str, type: str) -> None: def test_documentation() -> None: - from .output_betterproto.documentation import ( + from .outputs.documentation.documentation import ( Enum, ServiceBase, ServiceStub, @@ -39,7 +39,7 @@ def test_documentation() -> None: def test_escaping() -> None: - from .output_betterproto.documentation import ComplexDocumentation + from .outputs.documentation.documentation import ComplexDocumentation ComplexDocumentation.__doc__ == """ A comment with backslashes \\ and triple quotes \"\"\" diff --git a/betterproto2/tests/test_encoding_decoding.py b/betterproto2/tests/test_encoding_decoding.py index 3e9e535e..a284ea7f 100644 --- a/betterproto2/tests/test_encoding_decoding.py +++ b/betterproto2/tests/test_encoding_decoding.py @@ -1,6 +1,6 @@ def test_int_overflow(): """Make sure that overflows in encoded values are handled correctly.""" - from tests.output_betterproto_pydantic.encoding_decoding import Overflow32, Overflow64 + from tests.outputs.encoding_decoding.encoding_decoding import Overflow32, Overflow64 b = bytes(Overflow64(uint=2**50 + 42)) msg = Overflow32.parse(b) diff --git a/betterproto2/tests/test_features.py b/betterproto2/tests/test_features.py index ac3f67b8..0c5a4e16 100644 --- a/betterproto2/tests/test_features.py +++ b/betterproto2/tests/test_features.py @@ -17,7 +17,7 @@ def test_class_init(): - from tests.output_betterproto.features import Bar, Foo + from tests.outputs.features.features import Bar, Foo foo = Foo(name="foo", child=Bar(name="bar")) @@ -26,7 +26,7 @@ def test_class_init(): def test_enum_as_int_json(): - from tests.output_betterproto.features import Enum, EnumMsg + from tests.outputs.features.features import Enum, EnumMsg # JSON strings are supported, but ints should still be supported too. enum_msg = EnumMsg().from_dict({"enum": 1}) @@ -43,7 +43,7 @@ def test_enum_as_int_json(): def test_unknown_fields(): - from tests.output_betterproto.features import Newer, Older + from tests.outputs.features.features import Newer, Older newer = Newer(x=True, y=1, z="Hello") serialized_newer = bytes(newer) @@ -57,7 +57,7 @@ def test_unknown_fields(): def test_from_dict_unknown_fields(): - from tests.output_betterproto.features import Older + from tests.outputs.features.features import Older with pytest.raises(KeyError): Older.from_dict({"x": True, "y": 1}) @@ -66,7 +66,7 @@ def test_from_dict_unknown_fields(): def test_from_json_unknown_fields(): - from tests.output_betterproto.features import Older + from tests.outputs.features.features import Older with pytest.raises(KeyError): Older.from_json('{"x": true, "y": 1}') @@ -75,7 +75,7 @@ def test_from_json_unknown_fields(): def test_oneof_support(): - from tests.output_betterproto.features import IntMsg, OneofMsg + from tests.outputs.features.features import IntMsg, OneofMsg msg = OneofMsg() @@ -111,7 +111,7 @@ def test_oneof_support(): def test_json_casing(): - from tests.output_betterproto.features import JsonCasingMsg + from tests.outputs.features.features import JsonCasingMsg # Parsing should accept almost any input msg = JsonCasingMsg().from_dict({"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4}) @@ -135,7 +135,7 @@ def test_json_casing(): def test_dict_casing(): - from tests.output_betterproto.features import JsonCasingMsg + from tests.outputs.features.features import JsonCasingMsg # Parsing should accept almost any input msg = JsonCasingMsg().from_dict({"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4}) @@ -159,7 +159,7 @@ def test_dict_casing(): def test_optional_flag(): - from tests.output_betterproto.features import OptionalBoolMsg + from tests.outputs.features.features import OptionalBoolMsg # Serialization of not passed vs. set vs. zero-value. assert bytes(OptionalBoolMsg()) == b"" @@ -172,7 +172,7 @@ def test_optional_flag(): def test_optional_datetime_to_dict(): - from tests.output_betterproto.features import OptionalDatetimeMsg + from tests.outputs.features.features import OptionalDatetimeMsg # Check dict serialization assert OptionalDatetimeMsg().to_dict() == {} @@ -198,7 +198,7 @@ def test_optional_datetime_to_dict(): def test_to_json_default_values(): - from tests.output_betterproto.features import MsgA + from tests.outputs.features.features import MsgA # Empty dict test = MsgA().from_dict({}) @@ -222,7 +222,7 @@ def test_to_json_default_values(): def test_to_dict_default_values(): - from tests.output_betterproto.features import MsgA, MsgB + from tests.outputs.features.features import MsgA, MsgB # Empty dict test = MsgA() @@ -268,7 +268,7 @@ def test_to_dict_default_values(): def test_to_dict_datetime_values(): - from tests.output_betterproto.features import TimeMsg + from tests.outputs.features.features import TimeMsg test = TimeMsg.from_dict({"timestamp": "2020-01-01T00:00:00Z", "duration": "86400s"}) assert test.to_dict() == {"timestamp": "2020-01-01T00:00:00Z", "duration": "86400s"} @@ -288,7 +288,7 @@ def test_to_dict_datetime_values(): def test_oneof_default_value_set_causes_writes_wire(): - from tests.output_betterproto.features import Empty, MsgC + from tests.outputs.features.features import Empty, MsgC def _round_trip_serialization(msg: MsgC) -> MsgC: return MsgC.parse(bytes(msg)) @@ -328,7 +328,7 @@ def _round_trip_serialization(msg: MsgC) -> MsgC: def test_message_repr(): - from tests.output_betterproto.recursivemessage import Test + from tests.outputs.recursivemessage.recursivemessage import Test assert repr(Test(name="Loki")) == "Test(name='Loki')" assert repr(Test(child=Test(), name="Loki")) == "Test(name='Loki', child=Test())" @@ -346,7 +346,7 @@ def test_bool(): >>> bool(test) ... False """ - from tests.output_betterproto.features import Empty, IntMsg + from tests.outputs.features.features import Empty, IntMsg assert not Empty() t = IntMsg() @@ -400,7 +400,7 @@ def test_bool(): def test_iso_datetime(): - from tests.output_betterproto.features import TimeMsg + from tests.outputs.features.features import TimeMsg for _, candidate in enumerate(iso_candidates): msg = TimeMsg.from_dict({"timestamp": candidate}) @@ -408,7 +408,7 @@ def test_iso_datetime(): def test_iso_datetime_list(): - from tests.output_betterproto.features import MsgD + from tests.outputs.features.features import MsgD msg = MsgD() @@ -417,7 +417,7 @@ def test_iso_datetime_list(): def test_service_argument__expected_parameter(): - from tests.output_betterproto.service import TestStub + from tests.outputs.service.service import TestStub sig = signature(TestStub.do_thing) do_thing_request_parameter = sig.parameters["message"] @@ -426,7 +426,7 @@ def test_service_argument__expected_parameter(): def test_is_set(): - from tests.output_betterproto.features import MsgE + from tests.outputs.features.features import MsgE assert not MsgE().is_set("bool_field") assert not MsgE().is_set("int_field") @@ -435,7 +435,7 @@ def test_is_set(): def test_equality_comparison(): - from tests.output_betterproto.bool import Test as TestMessage + from tests.outputs.bool.bool import Test as TestMessage msg = TestMessage(value=True) diff --git a/betterproto2/tests/test_inputs.py b/betterproto2/tests/test_inputs.py index dd609435..6db984c9 100644 --- a/betterproto2/tests/test_inputs.py +++ b/betterproto2/tests/test_inputs.py @@ -4,16 +4,13 @@ import math import os import sys -from collections import namedtuple -from types import ModuleType +from dataclasses import dataclass +from pathlib import Path from typing import Any import pytest import betterproto2 -from tests.inputs import config as test_input_config -from tests.mocks import MockChannel -from tests.util import find_module, get_directories, get_test_case_json_data, inputs_path # Force pure-python implementation instead of C++, otherwise imports # break things because we can't properly reset the symbol database. @@ -22,42 +19,13 @@ from google.protobuf.json_format import Parse -class TestCases: - def __init__( - self, - path, - services: set[str], - xfail: set[str], - ): - all = set(get_directories(path)) - {"__pycache__"} - messages = {test for test in all - services if get_test_case_json_data(test)} - - unknown_xfail_tests = xfail - all - if unknown_xfail_tests: - raise Exception(f"Unknown test(s) in config.py: {unknown_xfail_tests}") - - self.services = self.apply_xfail_marks(services, xfail) - self.messages = self.apply_xfail_marks(messages, xfail) - - @staticmethod - def apply_xfail_marks(test_set: set[str], xfail: set[str]): - return [pytest.param(test, marks=pytest.mark.xfail) if test in xfail else test for test in test_set] - - -test_cases = TestCases( - path=inputs_path, - services=test_input_config.services, - xfail=test_input_config.xfail, -) - -plugin_output_package = "tests.output_betterproto" -reference_output_package = "tests.output_reference" - -TestData = namedtuple("TestData", ["plugin_module", "reference_module", "json_data"]) - - -def module_has_entry_point(module: ModuleType): - return any(hasattr(module, attr) for attr in ["Test", "TestStub"]) +@dataclass +class TestCase: + jsons: list[str] + plugin_package: str + reference_package: str + reference_path: list[str] | None = None + xfail: bool = False def list_replace_nans(items: list) -> list[Any]: @@ -116,80 +84,158 @@ def reset_sys_path(): sys.path = original -@pytest.fixture -def test_data(request, reset_sys_path): - test_case_name = request.param - - reference_module_root = os.path.join(*reference_output_package.split("."), test_case_name) - sys.path.append(reference_module_root) - - plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}") - - plugin_module_entry_point = find_module(plugin_module, module_has_entry_point) - - if not plugin_module_entry_point: - raise Exception( - f"Test case {repr(test_case_name)} has no entry point. " - "Please add a proto message or service called Test and recompile." - ) - - yield ( - TestData( - plugin_module=plugin_module_entry_point, - reference_module=lambda: importlib.import_module( - f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2" - ), - json_data=get_test_case_json_data(test_case_name), - ) - ) - - -@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True) -def test_message_can_instantiated(test_data: TestData) -> None: - plugin_module, *_ = test_data - plugin_module.Test() - - -@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True) -def test_message_equality(test_data: TestData) -> None: - plugin_module, *_ = test_data - message1 = plugin_module.Test() - message2 = plugin_module.Test() - assert message1 == message2 - - -@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True) -def test_message_json(test_data: TestData) -> None: - plugin_module, _, json_data = test_data - - for sample in json_data: - if sample.belongs_to(test_input_config.non_symmetrical_json): - continue - - message: betterproto2.Message = plugin_module.Test.from_json(sample.json) +TEST_CASES = [ + TestCase(["bool/bool.json"], "bool.bool", "bool_reference.bool_pb2"), + TestCase(["bytes/bytes.json"], "bytes.bytes", "bytes_reference.bytes_pb2"), + TestCase(["casing/casing.json"], "casing.casing", "casing_reference.casing_pb2"), + TestCase(["deprecated/deprecated.json"], "deprecated.deprecated", "deprecated_reference.deprecated_pb2"), + TestCase(["double/double.json", "double/double-negative.json"], "double.double", "double_reference.double_pb2"), + TestCase(["enum/enum.json"], "enum.enum", "enum_reference.enum_pb2"), + TestCase( + ["field_name_identical_to_type/field_name_identical_to_type.json"], + "field_name_identical_to_type.field_name_identical_to_type", + "field_name_identical_to_type_reference.field_name_identical_to_type_pb2", + ), + TestCase(["fixed/fixed.json"], "fixed.fixed", "fixed_reference.fixed_pb2"), + TestCase(["float/float.json"], "float.float", "float_reference.float_pb2"), + TestCase( + ["googletypes/googletypes.json", "googletypes/googletypes-missing.json"], + "googletypes.googletypes", + "googletypes_reference.googletypes_pb2", + ), + TestCase( + ["googletypes_struct/googletypes_struct.json"], + "googletypes_struct.googletypes_struct", + "googletypes_struct_reference.googletypes_struct_pb2", + xfail=True, + ), + TestCase( + ["googletypes_value/googletypes_value.json"], + "googletypes_value.googletypes_value", + "googletypes_value_reference.googletypes_value_pb2", + xfail=True, + ), + TestCase(["int32/int32.json"], "int32.int32", "int32_reference.int32_pb2"), + TestCase(["map/map.json"], "map.map", "map_reference.map_pb2"), + TestCase(["mapmessage/mapmessage.json"], "mapmessage.mapmessage", "mapmessage_reference.mapmessage_pb2"), + TestCase( + ["namespace_builtin_types/namespace_builtin_types.json"], + "namespace_builtin_types.namespace_builtin_types", + "namespace_builtin_types_reference.namespace_builtin_types_pb2", + ), + TestCase( + ["namespace_keywords/namespace_keywords.json"], + "namespace_keywords.namespace_keywords", + "namespace_keywords_reference.namespace_keywords_pb2", + xfail=True, + ), + TestCase(["nested/nested.json"], "nested.nested", "nested_reference.nested_pb2"), + TestCase(["nestedtwice/nestedtwice.json"], "nestedtwice.nestedtwice", "nestedtwice_reference.nestedtwice_pb2"), + TestCase( + ["oneof_empty/oneof_empty.json", "oneof_empty/oneof_empty_maybe1.json", "oneof_empty/oneof_empty_maybe2.json"], + "oneof_empty.oneof_empty", + "oneof_empty_reference.oneof_empty_pb2", + ), + TestCase( + ["oneof_enum/oneof_enum-enum-0.json", "oneof_enum/oneof_enum-enum-1.json", "oneof_enum/oneof_enum.json"], + "oneof_enum.oneof_enum", + "oneof_enum_reference.oneof_enum_pb2", + ), + TestCase( + ["oneof/oneof.json", "oneof/oneof-name.json", "oneof/oneof_name.json"], + "oneof.oneof", + "oneof_reference.oneof_pb2", + ), + TestCase( + ["proto3_field_presence_oneof/proto3_field_presence_oneof.json"], + "proto3_field_presence_oneof.proto3_field_presence_oneof", + "proto3_field_presence_oneof_reference.proto3_field_presence_oneof_pb2", + ), + TestCase( + [ + "proto3_field_presence/proto3_field_presence_default.json", + "proto3_field_presence/proto3_field_presence.json", + "proto3_field_presence/proto3_field_presence_missing.json", + ], + "proto3_field_presence.proto3_field_presence", + "proto3_field_presence_reference.proto3_field_presence_pb2", + ), + TestCase( + ["recursivemessage/recursivemessage.json"], + "recursivemessage.recursivemessage", + "recursivemessage_reference.recursivemessage_pb2", + ), + TestCase(["ref/ref.json"], "ref.ref", "ref_reference.ref_pb2", reference_path=["ref_reference"]), + TestCase( + ["repeated_duration_timestamp/repeated_duration_timestamp.json"], + "repeated_duration_timestamp.repeated_duration_timestamp", + "repeated_duration_timestamp_reference.repeated_duration_timestamp_pb2", + ), + TestCase( + ["repeatedmessage/repeatedmessage.json"], + "repeatedmessage.repeatedmessage", + "repeatedmessage_reference.repeatedmessage_pb2", + ), + TestCase( + ["repeatedpacked/repeatedpacked.json"], + "repeatedpacked.repeatedpacked", + "repeatedpacked_reference.repeatedpacked_pb2", + ), + TestCase(["repeated/repeated.json"], "repeated.repeated", "repeated_reference.repeated_pb2"), + TestCase(["signed/signed.json"], "signed.signed", "signed_reference.signed_pb2"), + TestCase( + ["timestamp_dict_encode/timestamp_dict_encode.json"], + "timestamp_dict_encode.timestamp_dict_encode", + "timestamp_dict_encode_reference.timestamp_dict_encode_pb2", + ), +] + + +@pytest.mark.parametrize("test_case", TEST_CASES, ids=lambda x: x.plugin_package) +def test_message_json(test_case: TestCase) -> None: + if test_case.xfail: + pytest.xfail(f"Test case {test_case.plugin_package} is expected to fail.") + + plugin_module = importlib.import_module(f"tests.outputs.{test_case.plugin_package}") + + current_dir = Path(os.path.dirname(os.path.abspath(__file__))) + + for json_path in test_case.jsons: + with open(current_dir / "inputs" / json_path) as f: + json_data = f.read() + + message: betterproto2.Message = plugin_module.Test.from_json(json_data) message_json = message.to_json(indent=0) - assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(json.loads(sample.json)) + assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(json.loads(json_data)) + +@pytest.mark.parametrize("test_case", TEST_CASES, ids=lambda x: x.plugin_package) +def test_binary_compatibility(test_case: TestCase, reset_sys_path) -> None: + if test_case.xfail: + pytest.xfail(f"Test case {test_case.plugin_package} is expected to fail.") -@pytest.mark.parametrize("test_data", test_cases.services, indirect=True) -def test_service_can_be_instantiated(test_data: TestData) -> None: - test_data.plugin_module.TestStub(MockChannel()) + current_dir = Path(os.path.dirname(os.path.abspath(__file__))) + if test_case.reference_path: + for path in test_case.reference_path: + sys.path.append(str(current_dir / "outputs" / path)) -@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True) -def test_binary_compatibility(test_data: TestData) -> None: - plugin_module, reference_module, json_data = test_data + plugin_module = importlib.import_module(f"tests.outputs.{test_case.plugin_package}") + reference_module = importlib.import_module(f"tests.outputs.{test_case.reference_package}") # TODO fix and delete if "map" in plugin_module.__file__.replace("\\", "/").split("/"): pytest.skip("Skipping this test for now.") - for sample in json_data: - reference_instance = Parse(sample.json, reference_module().Test()) + for json_path in test_case.jsons: + with open(current_dir / "inputs" / json_path) as f: + json_data = f.read() + + reference_instance = Parse(json_data, reference_module.Test()) reference_binary_output = reference_instance.SerializeToString() - plugin_instance_from_json: betterproto2.Message = plugin_module.Test().from_json(sample.json) + plugin_instance_from_json: betterproto2.Message = plugin_module.Test().from_json(json_data) plugin_instance_from_binary = plugin_module.Test.FromString(reference_binary_output) # Generally this can't be relied on, but here we are aiming to match the diff --git a/betterproto2/tests/test_manual_validation.py b/betterproto2/tests/test_manual_validation.py index 3f00d4b4..342ec685 100644 --- a/betterproto2/tests/test_manual_validation.py +++ b/betterproto2/tests/test_manual_validation.py @@ -3,7 +3,7 @@ def test_manual_validation(): - from tests.output_betterproto_pydantic.manual_validation import Msg + from tests.outputs.manual_validation_pydantic.manual_validation import Msg msg = Msg() @@ -16,7 +16,7 @@ def test_manual_validation(): def test_manual_validation_non_pydantic(): - from tests.output_betterproto.manual_validation import Msg + from tests.outputs.manual_validation.manual_validation import Msg # Validation is not available for non-pydantic messages with pytest.raises(TypeError): diff --git a/betterproto2/tests/test_mapmessage.py b/betterproto2/tests/test_mapmessage.py index 16bd6ce6..be324989 100644 --- a/betterproto2/tests/test_mapmessage.py +++ b/betterproto2/tests/test_mapmessage.py @@ -1,4 +1,4 @@ -from tests.output_betterproto.mapmessage import ( +from tests.outputs.mapmessage.mapmessage import ( Nested, Test, ) diff --git a/betterproto2/tests/test_message_wraping.py b/betterproto2/tests/test_message_wraping.py index d63d1648..e4252d98 100644 --- a/betterproto2/tests/test_message_wraping.py +++ b/betterproto2/tests/test_message_wraping.py @@ -2,7 +2,7 @@ def test_message_wrapping_map(): - from tests.output_betterproto.message_wrapping import MapMessage + from tests.outputs.message_wrapping.message_wrapping import MapMessage msg = MapMessage(map1={"key": 12.0}, map2={"key": datetime.timedelta(seconds=1)}) diff --git a/betterproto2/tests/test_pickling.py b/betterproto2/tests/test_pickling.py index 84eb50c3..aba2cf54 100644 --- a/betterproto2/tests/test_pickling.py +++ b/betterproto2/tests/test_pickling.py @@ -3,8 +3,8 @@ import cachelib -from tests.output_betterproto.google import protobuf as google -from tests.output_betterproto.pickling import Complex, Fe, Fi, NestedData, PickledMessage +from tests.outputs.pickling.google import protobuf as google +from tests.outputs.pickling.pickling import Complex, Fe, Fi, NestedData, PickledMessage def unpickled(message): @@ -46,7 +46,7 @@ def test_pickling_complex_message(): def test_recursive_message_defaults(): - from tests.output_betterproto.recursivemessage import Intermediate, Test as RecursiveMessage + from tests.outputs.recursivemessage.recursivemessage import Intermediate, Test as RecursiveMessage msg = RecursiveMessage(name="bob", intermediate=Intermediate(42)) msg = unpickled(msg) diff --git a/betterproto2/tests/test_streams.py b/betterproto2/tests/test_streams.py index 44524acf..3f84f2ef 100644 --- a/betterproto2/tests/test_streams.py +++ b/betterproto2/tests/test_streams.py @@ -6,13 +6,11 @@ import pytest import betterproto2 -from tests.output_betterproto import ( - map, - nested, - oneof, - repeated, - repeatedpacked, -) +from tests.outputs.map import map +from tests.outputs.nested import nested +from tests.outputs.oneof import oneof +from tests.outputs.repeated import repeated +from tests.outputs.repeatedpacked import repeatedpacked oneof_example = oneof.Test().from_dict({"pitied": 1, "just_a_regular_field": 123456789, "bar_name": "Testing"}) diff --git a/betterproto2/tests/test_sync_client.py b/betterproto2/tests/test_sync_client.py index be7b0395..ec0611ba 100644 --- a/betterproto2/tests/test_sync_client.py +++ b/betterproto2/tests/test_sync_client.py @@ -6,7 +6,7 @@ import pytest from grpclib.server import Server -from tests.output_betterproto.simple_service import Request, Response, SimpleServiceBase, SimpleServiceSyncStub +from tests.outputs.simple_service.simple_service import Request, Response, SimpleServiceBase, SimpleServiceSyncStub class SimpleService(SimpleServiceBase): diff --git a/betterproto2/tests/test_timestamp.py b/betterproto2/tests/test_timestamp.py index 4b76a10a..8187939b 100644 --- a/betterproto2/tests/test_timestamp.py +++ b/betterproto2/tests/test_timestamp.py @@ -5,7 +5,7 @@ import pytest -from tests.output_betterproto.google.protobuf import Timestamp +from tests.outputs.google.google.protobuf import Timestamp @pytest.mark.parametrize( diff --git a/betterproto2/tests/test_unwrap.py b/betterproto2/tests/test_unwrap.py index 034189f4..9d547813 100644 --- a/betterproto2/tests/test_unwrap.py +++ b/betterproto2/tests/test_unwrap.py @@ -3,7 +3,7 @@ def test_unwrap() -> None: from betterproto2 import unwrap - from tests.output_betterproto.unwrap import Message, NestedMessage + from tests.outputs.unwrap.unwrap import Message, NestedMessage with pytest.raises(ValueError): unwrap(Message().x) diff --git a/betterproto2/tests/test_validation.py b/betterproto2/tests/test_validation.py index c6f731bc..fa267013 100644 --- a/betterproto2/tests/test_validation.py +++ b/betterproto2/tests/test_validation.py @@ -3,7 +3,7 @@ def test_int32_validation(): - from .output_betterproto_pydantic.validation import Message + from .outputs.validation_pydantic.validation import Message # Test valid values Message(int32_value=1) @@ -18,7 +18,7 @@ def test_int32_validation(): def test_int64_validation(): - from .output_betterproto_pydantic.validation import Message + from .outputs.validation_pydantic.validation import Message # Test valid values Message(int64_value=1) @@ -33,7 +33,7 @@ def test_int64_validation(): def test_uint32_validation(): - from .output_betterproto_pydantic.validation import Message + from .outputs.validation_pydantic.validation import Message # Test valid values Message(uint32_value=0) @@ -47,7 +47,7 @@ def test_uint32_validation(): def test_uint64_validation(): - from .output_betterproto_pydantic.validation import Message + from .outputs.validation_pydantic.validation import Message # Test valid values Message(uint64_value=0) @@ -61,7 +61,7 @@ def test_uint64_validation(): def test_sint32_validation(): - from .output_betterproto_pydantic.validation import Message + from .outputs.validation_pydantic.validation import Message # Test valid values Message(sint32_value=1) @@ -76,7 +76,7 @@ def test_sint32_validation(): def test_sint64_validation(): - from .output_betterproto_pydantic.validation import Message + from .outputs.validation_pydantic.validation import Message # Test valid values Message(sint64_value=1) @@ -91,7 +91,7 @@ def test_sint64_validation(): def test_fixed32_validation(): - from .output_betterproto_pydantic.validation import Message + from .outputs.validation_pydantic.validation import Message # Test valid values Message(fixed32_value=0) @@ -105,7 +105,7 @@ def test_fixed32_validation(): def test_fixed64_validation(): - from .output_betterproto_pydantic.validation import Message + from .outputs.validation_pydantic.validation import Message # Test valid values Message(fixed64_value=0) @@ -119,7 +119,7 @@ def test_fixed64_validation(): def test_sfixed32_validation(): - from .output_betterproto_pydantic.validation import Message + from .outputs.validation_pydantic.validation import Message # Test valid values Message(sfixed32_value=1) @@ -134,7 +134,7 @@ def test_sfixed32_validation(): def test_sfixed64_validation(): - from .output_betterproto_pydantic.validation import Message + from .outputs.validation_pydantic.validation import Message # Test valid values Message(sfixed64_value=1) @@ -149,7 +149,7 @@ def test_sfixed64_validation(): def test_float_validation(): - from .output_betterproto_pydantic.validation import Message + from .outputs.validation_pydantic.validation import Message # Test valid values Message(float_value=0.0) @@ -161,7 +161,7 @@ def test_float_validation(): def test_string_validation(): - from .output_betterproto_pydantic.validation import Message + from .outputs.validation_pydantic.validation import Message # Test valid values Message(string_value="") diff --git a/betterproto2/tests/util.py b/betterproto2/tests/util.py index 90a8d9a5..c4ad0c54 100644 --- a/betterproto2/tests/util.py +++ b/betterproto2/tests/util.py @@ -22,9 +22,6 @@ class TestCaseJsonFile: test_name: str file_name: str - def belongs_to(self, non_symmetrical_json: dict[str, tuple[str, ...]]) -> bool: - return self.file_name in non_symmetrical_json.get(self.test_name, ()) - def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> list[TestCaseJsonFile]: """ diff --git a/betterproto2_compiler/pyproject.toml b/betterproto2_compiler/pyproject.toml index ab3a8c6c..6181a640 100644 --- a/betterproto2_compiler/pyproject.toml +++ b/betterproto2_compiler/pyproject.toml @@ -88,7 +88,8 @@ cmd = "pytest" help = "Run tests" [tool.poe.tasks.generate] -sequence = ["_generate_tests", "_generate_tests_lib"] +# sequence = ["_generate_tests", "_generate_tests_lib"] +sequence = ["_generate_tests"] help = "Generate test cases" [tool.poe.tasks._generate_tests] diff --git a/betterproto2_compiler/src/betterproto2_compiler/known_types/google_values.py b/betterproto2_compiler/src/betterproto2_compiler/known_types/google_values.py index 3b1d4300..e5439011 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/known_types/google_values.py +++ b/betterproto2_compiler/src/betterproto2_compiler/known_types/google_values.py @@ -24,10 +24,10 @@ def to_wrapped(self) -> bool: return self.value @classmethod - def from_dict(cls, value): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False): if isinstance(value, bool): return BoolValue(value=value) - return super().from_dict(value) + return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( self, @@ -48,10 +48,10 @@ def to_wrapped(self) -> int: return self.value @classmethod - def from_dict(cls, value): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False): if isinstance(value, int): return Int32Value(value=value) - return super().from_dict(value) + return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( self, @@ -72,10 +72,10 @@ def to_wrapped(self) -> int: return self.value @classmethod - def from_dict(cls, value): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False): if isinstance(value, int): return Int64Value(value=value) - return super().from_dict(value) + return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( self, @@ -96,10 +96,10 @@ def to_wrapped(self) -> int: return self.value @classmethod - def from_dict(cls, value): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False): if isinstance(value, int): return UInt32Value(value=value) - return super().from_dict(value) + return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( self, @@ -120,10 +120,10 @@ def to_wrapped(self) -> int: return self.value @classmethod - def from_dict(cls, value): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False): if isinstance(value, int): return UInt64Value(value=value) - return super().from_dict(value) + return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( self, @@ -144,10 +144,10 @@ def to_wrapped(self) -> float: return self.value @classmethod - def from_dict(cls, value): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False): if isinstance(value, float): return FloatValue(value=value) - return super().from_dict(value) + return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( self, @@ -168,10 +168,10 @@ def to_wrapped(self) -> float: return self.value @classmethod - def from_dict(cls, value): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False): if isinstance(value, float): return DoubleValue(value=value) - return super().from_dict(value) + return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( self, @@ -192,10 +192,10 @@ def to_wrapped(self) -> str: return self.value @classmethod - def from_dict(cls, value): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False): if isinstance(value, str): return StringValue(value=value) - return super().from_dict(value) + return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( self, @@ -216,10 +216,10 @@ def to_wrapped(self) -> bytes: return self.value @classmethod - def from_dict(cls, value): + def from_dict(cls, value, *, ignore_unknown_fields: bool = False): if isinstance(value, bytes): return BytesValue(value=value) - return super().from_dict(value) + return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields) def to_dict( self, diff --git a/betterproto2_compiler/src/betterproto2_compiler/templates/service_stub_sync.py.j2 b/betterproto2_compiler/src/betterproto2_compiler/templates/service_stub_sync.py.j2 index 97647025..8cf49b29 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/templates/service_stub_sync.py.j2 +++ b/betterproto2_compiler/src/betterproto2_compiler/templates/service_stub_sync.py.j2 @@ -29,24 +29,22 @@ {% block method_body %} {% if method.server_streaming %} {% if method.client_streaming %} - for response in self._channel.stream_stream( + yield from self._channel.stream_stream( "{{ method.route }}", {{ method.py_input_message_type }}.SerializeToString, {{ method.py_output_message_type }}.FromString, - )(iter(messages)): - yield response + )(iter(messages)) {% else %} {% if method.is_input_msg_empty %} if message is None: message = {{ method.py_input_message_type }}() {% endif %} - for response in self._channel.unary_stream( + yield from self._channel.unary_stream( "{{ method.route }}", {{ method.py_input_message_type }}.SerializeToString, {{ method.py_output_message_type }}.FromString, - )(message): - yield response + )(message) {% endif %} {% else %} diff --git a/betterproto2_compiler/tests/generate.py b/betterproto2_compiler/tests/generate.py index 16b5297b..105c2b22 100644 --- a/betterproto2_compiler/tests/generate.py +++ b/betterproto2_compiler/tests/generate.py @@ -2,166 +2,178 @@ import asyncio import os import shutil -import sys -from pathlib import Path - -from tests.util import ( - get_directories, - inputs_path, - output_path_betterproto, - output_path_betterproto_descriptor, - output_path_betterproto_pydantic, - output_path_reference, - protoc, -) + +from tests.util import protoc # Force pure-python implementation instead of C++, otherwise imports # break things because we can't properly reset the symbol database. os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" -def clear_directory(dir_path: Path): - for file_or_directory in dir_path.glob("*"): - if file_or_directory.is_dir(): - shutil.rmtree(file_or_directory) - else: - file_or_directory.unlink() - - -async def generate(verbose: bool): - test_case_names = set(get_directories(inputs_path)) - {"__pycache__"} - - generation_tasks = [] - for test_case_name in sorted(test_case_names): - test_case_input_path = inputs_path.joinpath(test_case_name).resolve() - generation_tasks.append(generate_test_case_output(test_case_input_path, test_case_name, verbose)) - - failed_test_cases = [] - # Wait for all subprocs and match any failures to names to report - for test_case_name, result in zip(sorted(test_case_names), await asyncio.gather(*generation_tasks)): - if result != 0: - failed_test_cases.append(test_case_name) - - if len(failed_test_cases) > 0: - sys.stderr.write("\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n") - for failed_test_case in failed_test_cases: - sys.stderr.write(f"- {failed_test_case}\n") - - sys.exit(1) - - -async def generate_test_case_output(test_case_input_path: Path, test_case_name: str, verbose: bool) -> int: - """ - Returns the max of the subprocess return values - """ - - test_case_output_path_reference = output_path_reference.joinpath(test_case_name) - test_case_output_path_betterproto = output_path_betterproto - test_case_output_path_betterproto_pyd = output_path_betterproto_pydantic - test_case_output_path_betterproto_desc = output_path_betterproto_descriptor - - os.makedirs(test_case_output_path_reference, exist_ok=True) - os.makedirs(test_case_output_path_betterproto, exist_ok=True) - os.makedirs(test_case_output_path_betterproto_pyd, exist_ok=True) - os.makedirs(test_case_output_path_betterproto_desc, exist_ok=True) - - clear_directory(test_case_output_path_reference) - clear_directory(test_case_output_path_betterproto) - clear_directory(test_case_output_path_betterproto_pyd) - clear_directory(test_case_output_path_betterproto_desc) - - ( - (ref_out, ref_err, ref_code), - (plg_out, plg_err, plg_code), - (plg_out_pyd, plg_err_pyd, plg_code_pyd), - (plg_out_desc, plg_err_desc, plg_code_desc), - ) = await asyncio.gather( - protoc(test_case_input_path, test_case_output_path_reference, True), - protoc(test_case_input_path, test_case_output_path_betterproto, False), - protoc(test_case_input_path, test_case_output_path_betterproto_pyd, False, True), - protoc(test_case_input_path, test_case_output_path_betterproto_desc, False, False, True), +async def generate_test( + name, + semaphore: asyncio.Semaphore, + *, + reference: bool = False, + pydantic: bool = False, + descriptors: bool = False, +): + await semaphore.acquire() + + dir_path = os.path.dirname(os.path.realpath(__file__)) + + options = [] + if reference: + options.append("reference") + if pydantic: + options.append("pydantic") + if descriptors: + options.append("descriptors") + + input_dir = dir_path + "/inputs/" + name + output_dir = dir_path + "/outputs/" + name + ("_" + "_".join(options) if options else "") + + os.mkdir(output_dir) + + stdout, stderr, returncode = await protoc( + input_dir, + output_dir, + reference=reference, + pydantic_dataclasses=pydantic, + google_protobuf_descriptors=descriptors, ) - if ref_code == 0: - print(f"\033[31;1;4mGenerated reference output for {test_case_name!r}\033[0m") - else: - print(f"\033[31;1;4mFailed to generate reference output for {test_case_name!r}\033[0m") - print(ref_err.decode()) - - if verbose: - if ref_out: - print("Reference stdout:") - sys.stdout.buffer.write(ref_out) - sys.stdout.buffer.flush() - - if ref_err: - print("Reference stderr:") - sys.stderr.buffer.write(ref_err) - sys.stderr.buffer.flush() - - if plg_code == 0: - print(f"\033[31;1;4mGenerated plugin output for {test_case_name!r}\033[0m") - else: - print(f"\033[31;1;4mFailed to generate plugin output for {test_case_name!r}\033[0m") - print(plg_err.decode()) - - if verbose: - if plg_out: - print("Plugin stdout:") - sys.stdout.buffer.write(plg_out) - sys.stdout.buffer.flush() - - if plg_err: - print("Plugin stderr:") - sys.stderr.buffer.write(plg_err) - sys.stderr.buffer.flush() - - if plg_code_pyd == 0: - print(f"\033[31;1;4mGenerated plugin (pydantic compatible) output for {test_case_name!r}\033[0m") + if options: + options_str = ", ".join(options) + options_str = f" ({options_str})" else: - print(f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m") - print(plg_err_pyd.decode()) - - if verbose: - if plg_out_pyd: - print("Plugin stdout:") - sys.stdout.buffer.write(plg_out_pyd) - sys.stdout.buffer.flush() - - if plg_err_pyd: - print("Plugin stderr:") - sys.stderr.buffer.write(plg_err_pyd) - sys.stderr.buffer.flush() - - if plg_code_desc == 0: - print(f"\033[31;1;4mGenerated plugin (google protobuf descriptor) output for {test_case_name!r}\033[0m") - else: - print( - f"\033[31;1;4mFailed to generate plugin (google protobuf descriptor) output for {test_case_name!r}\033[0m" - ) - print(plg_err_desc.decode()) - - if verbose: - if plg_out_desc: - print("Plugin stdout:") - sys.stdout.buffer.write(plg_out_desc) - sys.stdout.buffer.flush() + options_str = "" - if plg_err_desc: - print("Plugin stderr:") - sys.stderr.buffer.write(plg_err_desc) - sys.stderr.buffer.flush() - - return max(ref_code, plg_code, plg_code_pyd, plg_code_desc) + if returncode == 0: + print(f"\033[31;1;4mGenerated output for {name!r}{options_str}\033[0m") + else: + print(f"\033[31;1;4mFailed to generate reference output for {name!r}{options_str}\033[0m") + print(stderr.decode()) + + semaphore.release() + + +async def main_async(): + # Don't compile too many tests in parallel + semaphore = asyncio.Semaphore(os.cpu_count() or 1) + + tasks = [ + generate_test("any", semaphore), + generate_test("bool", semaphore, pydantic=True), + generate_test("bool", semaphore, reference=True), + generate_test("bool", semaphore), + generate_test("bytes", semaphore, reference=True), + generate_test("bytes", semaphore), + generate_test("casing_inner_class", semaphore), + generate_test("casing", semaphore, reference=True), + generate_test("casing", semaphore), + generate_test("deprecated", semaphore, reference=True), + generate_test("deprecated", semaphore), + generate_test("documentation", semaphore), + generate_test("double", semaphore, reference=True), + generate_test("double", semaphore), + generate_test("encoding_decoding", semaphore), + generate_test("enum", semaphore, reference=True), + generate_test("enum", semaphore), + generate_test("example_service", semaphore), + generate_test("features", semaphore), + generate_test("field_name_identical_to_type", semaphore, reference=True), + generate_test("field_name_identical_to_type", semaphore), + generate_test("fixed", semaphore, reference=True), + generate_test("fixed", semaphore), + generate_test("float", semaphore, reference=True), + generate_test("float", semaphore), + generate_test("google_impl_behavior_equivalence", semaphore, reference=True), + generate_test("google_impl_behavior_equivalence", semaphore), + generate_test("google", semaphore), + generate_test("googletypes_request", semaphore), + generate_test("googletypes_response_embedded", semaphore), + generate_test("googletypes_response", semaphore), + generate_test("googletypes_struct", semaphore, reference=True), + generate_test("googletypes_struct", semaphore), + generate_test("googletypes_value", semaphore, reference=True), + generate_test("googletypes_value", semaphore), + generate_test("googletypes", semaphore, reference=True), + generate_test("googletypes", semaphore), + generate_test("grpclib_reflection", semaphore, descriptors=True), + generate_test("grpclib_reflection", semaphore), + generate_test("import_cousin_package_same_name", semaphore, descriptors=True), + generate_test("import_cousin_package_same_name", semaphore), + generate_test("import_service_input_message", semaphore), + generate_test("int32", semaphore, reference=True), + generate_test("int32", semaphore), + generate_test("invalid_field", semaphore, pydantic=True), + generate_test("invalid_field", semaphore), + generate_test("manual_validation", semaphore, pydantic=True), + generate_test("manual_validation", semaphore), + generate_test("map", semaphore, reference=True), + generate_test("map", semaphore), + generate_test("mapmessage", semaphore, reference=True), + generate_test("mapmessage", semaphore), + generate_test("message_wrapping", semaphore), + generate_test("namespace_builtin_types", semaphore, reference=True), + generate_test("namespace_builtin_types", semaphore), + generate_test("namespace_keywords", semaphore, reference=True), + generate_test("namespace_keywords", semaphore), + generate_test("nested", semaphore, reference=True), + generate_test("nested", semaphore), + generate_test("nestedtwice", semaphore, reference=True), + generate_test("nestedtwice", semaphore), + generate_test("oneof_default_value_serialization", semaphore), + generate_test("oneof_empty", semaphore, reference=True), + generate_test("oneof_empty", semaphore), + generate_test("oneof_enum", semaphore, reference=True), + generate_test("oneof_enum", semaphore), + generate_test("oneof", semaphore, pydantic=True), + generate_test("oneof", semaphore, reference=True), + generate_test("oneof", semaphore), + generate_test("pickling", semaphore), + generate_test("proto3_field_presence_oneof", semaphore, reference=True), + generate_test("proto3_field_presence_oneof", semaphore), + generate_test("proto3_field_presence", semaphore, reference=True), + generate_test("proto3_field_presence", semaphore), + generate_test("recursivemessage", semaphore, reference=True), + generate_test("recursivemessage", semaphore), + generate_test("ref", semaphore, reference=True), + generate_test("ref", semaphore), + generate_test("regression_387", semaphore), + generate_test("regression_414", semaphore), + generate_test("repeated_duration_timestamp", semaphore, reference=True), + generate_test("repeated_duration_timestamp", semaphore), + generate_test("repeated", semaphore, reference=True), + generate_test("repeated", semaphore), + generate_test("repeatedmessage", semaphore, reference=True), + generate_test("repeatedmessage", semaphore), + generate_test("repeatedpacked", semaphore, reference=True), + generate_test("repeatedpacked", semaphore), + generate_test("rpc_empty_input_message", semaphore), + generate_test("service_uppercase", semaphore), + generate_test("service", semaphore), + generate_test("signed", semaphore, reference=True), + generate_test("signed", semaphore), + generate_test("simple_service", semaphore), + generate_test("stream_stream", semaphore), + generate_test("timestamp_dict_encode", semaphore, reference=True), + generate_test("timestamp_dict_encode", semaphore), + generate_test("unwrap", semaphore), + generate_test("validation", semaphore, pydantic=True), + ] + await asyncio.gather(*tasks) def main(): - if sys.argv[1:2] == ["-v"]: - verbose = True - else: - verbose = False + # Clean the output directory + dir_path = os.path.dirname(os.path.realpath(__file__)) + + shutil.rmtree(dir_path + "/outputs", ignore_errors=True) + os.mkdir(dir_path + "/outputs") - asyncio.run(generate(verbose)) + asyncio.run(main_async()) if __name__ == "__main__": diff --git a/betterproto2_compiler/tests/inputs/any/any.proto b/betterproto2_compiler/tests/inputs/any/any.proto index cd136be3..5bb95d71 100644 --- a/betterproto2_compiler/tests/inputs/any/any.proto +++ b/betterproto2_compiler/tests/inputs/any/any.proto @@ -1,5 +1,7 @@ syntax = "proto3"; +import "google/protobuf/any.proto"; + package any; message Person { diff --git a/betterproto2_compiler/tests/inputs/empty_repeated/empty_repeated.proto b/betterproto2_compiler/tests/inputs/empty_repeated/empty_repeated.proto deleted file mode 100644 index f787301f..00000000 --- a/betterproto2_compiler/tests/inputs/empty_repeated/empty_repeated.proto +++ /dev/null @@ -1,11 +0,0 @@ -syntax = "proto3"; - -package empty_repeated; - -message MessageA { - repeated float values = 1; -} - -message Test { - repeated MessageA msg = 1; -} diff --git a/betterproto2_compiler/tests/inputs/google/google.proto b/betterproto2_compiler/tests/inputs/google/google.proto new file mode 100644 index 00000000..67c4bbda --- /dev/null +++ b/betterproto2_compiler/tests/inputs/google/google.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +import "google/protobuf/any.proto"; +import "google/protobuf/api.proto"; +import "google/protobuf/duration.proto"; +import "google/protobuf/empty.proto"; +import "google/protobuf/field_mask.proto"; +import "google/protobuf/source_context.proto"; +import "google/protobuf/struct.proto"; +import "google/protobuf/timestamp.proto"; +import "google/protobuf/type.proto"; +import "google/protobuf/wrappers.proto"; diff --git a/betterproto2_compiler/tests/inputs/grpclib_reflection/example_service.proto b/betterproto2_compiler/tests/inputs/grpclib_reflection/example_service.proto new file mode 100644 index 00000000..4ef60236 --- /dev/null +++ b/betterproto2_compiler/tests/inputs/grpclib_reflection/example_service.proto @@ -0,0 +1,24 @@ +syntax = "proto3"; + +package example_service; + +import "google/protobuf/struct.proto"; + +service Test { + rpc ExampleUnaryUnary(ExampleRequest) returns (ExampleResponse); + rpc ExampleUnaryStream(ExampleRequest) returns (stream ExampleResponse); + rpc ExampleStreamUnary(stream ExampleRequest) returns (ExampleResponse); + rpc ExampleStreamStream(stream ExampleRequest) returns (stream ExampleResponse); +} + +message ExampleRequest { + string example_string = 1; + int64 example_integer = 2; + google.protobuf.Struct example_struct = 3; +} + +message ExampleResponse { + string example_string = 1; + int64 example_integer = 2; + google.protobuf.Struct example_struct = 3; +} diff --git a/betterproto2_compiler/tests/inputs/grpclib_reflection/reflection.proto b/betterproto2_compiler/tests/inputs/grpclib_reflection/reflection.proto new file mode 100644 index 00000000..f9f349fe --- /dev/null +++ b/betterproto2_compiler/tests/inputs/grpclib_reflection/reflection.proto @@ -0,0 +1,146 @@ +// Copyright 2016 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Service exported by server reflection. A more complete description of how +// server reflection works can be found at +// https://github.com/grpc/grpc/blob/master/doc/server-reflection.md +// +// The canonical version of this proto can be found at +// https://github.com/grpc/grpc-proto/blob/master/grpc/reflection/v1/reflection.proto + +syntax = "proto3"; + +package grpc.reflection.v1; + +option go_package = "google.golang.org/grpc/reflection/grpc_reflection_v1"; +option java_multiple_files = true; +option java_package = "io.grpc.reflection.v1"; +option java_outer_classname = "ServerReflectionProto"; + +service ServerReflection { + // The reflection service is structured as a bidirectional stream, ensuring + // all related requests go to a single server. + rpc ServerReflectionInfo(stream ServerReflectionRequest) + returns (stream ServerReflectionResponse); +} + +// The message sent by the client when calling ServerReflectionInfo method. +message ServerReflectionRequest { + string host = 1; + // To use reflection service, the client should set one of the following + // fields in message_request. The server distinguishes requests by their + // defined field and then handles them using corresponding methods. + oneof message_request { + // Find a proto file by the file name. + string file_by_filename = 3; + + // Find the proto file that declares the given fully-qualified symbol name. + // This field should be a fully-qualified symbol name + // (e.g. .[.] or .). + string file_containing_symbol = 4; + + // Find the proto file which defines an extension extending the given + // message type with the given field number. + ExtensionRequest file_containing_extension = 5; + + // Finds the tag numbers used by all known extensions of the given message + // type, and appends them to ExtensionNumberResponse in an undefined order. + // Its corresponding method is best-effort: it's not guaranteed that the + // reflection service will implement this method, and it's not guaranteed + // that this method will provide all extensions. Returns + // StatusCode::UNIMPLEMENTED if it's not implemented. + // This field should be a fully-qualified type name. The format is + // . + string all_extension_numbers_of_type = 6; + + // List the full names of registered services. The content will not be + // checked. + string list_services = 7; + } +} + +// The type name and extension number sent by the client when requesting +// file_containing_extension. +message ExtensionRequest { + // Fully-qualified type name. The format should be . + string containing_type = 1; + int32 extension_number = 2; +} + +// The message sent by the server to answer ServerReflectionInfo method. +message ServerReflectionResponse { + string valid_host = 1; + ServerReflectionRequest original_request = 2; + // The server sets one of the following fields according to the message_request + // in the request. + oneof message_response { + // This message is used to answer file_by_filename, file_containing_symbol, + // file_containing_extension requests with transitive dependencies. + // As the repeated label is not allowed in oneof fields, we use a + // FileDescriptorResponse message to encapsulate the repeated fields. + // The reflection service is allowed to avoid sending FileDescriptorProtos + // that were previously sent in response to earlier requests in the stream. + FileDescriptorResponse file_descriptor_response = 4; + + // This message is used to answer all_extension_numbers_of_type requests. + ExtensionNumberResponse all_extension_numbers_response = 5; + + // This message is used to answer list_services requests. + ListServiceResponse list_services_response = 6; + + // This message is used when an error occurs. + ErrorResponse error_response = 7; + } +} + +// Serialized FileDescriptorProto messages sent by the server answering +// a file_by_filename, file_containing_symbol, or file_containing_extension +// request. +message FileDescriptorResponse { + // Serialized FileDescriptorProto messages. We avoid taking a dependency on + // descriptor.proto, which uses proto2 only features, by making them opaque + // bytes instead. + repeated bytes file_descriptor_proto = 1; +} + +// A list of extension numbers sent by the server answering +// all_extension_numbers_of_type request. +message ExtensionNumberResponse { + // Full name of the base type, including the package name. The format + // is . + string base_type_name = 1; + repeated int32 extension_number = 2; +} + +// A list of ServiceResponse sent by the server answering list_services request. +message ListServiceResponse { + // The information of each service may be expanded in the future, so we use + // ServiceResponse message to encapsulate it. + repeated ServiceResponse service = 1; +} + +// The information of a single service used by ListServiceResponse to answer +// list_services request. +message ServiceResponse { + // Full name of a registered service, including its package name. The format + // is . + string name = 1; +} + +// The error code and error message sent by the server when an error occurs. +message ErrorResponse { + // This field uses the error codes defined in grpc::StatusCode. + int32 error_code = 1; + string error_message = 2; +} diff --git a/betterproto2_compiler/tests/util.py b/betterproto2_compiler/tests/util.py index 8e12dda7..b837bbd5 100644 --- a/betterproto2_compiler/tests/util.py +++ b/betterproto2_compiler/tests/util.py @@ -5,18 +5,6 @@ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" -root_path = Path(__file__).resolve().parent -inputs_path = root_path.joinpath("inputs") -output_path_reference = root_path.joinpath("output_reference") -output_path_betterproto = root_path.joinpath("output_betterproto") -output_path_betterproto_pydantic = root_path.joinpath("output_betterproto_pydantic") -output_path_betterproto_descriptor = root_path.joinpath("output_betterproto_descriptor") - - -def get_directories(path): - for root, directories, files in os.walk(path): - yield from directories - async def protoc( path: str | Path,