Skip to content

Commit

Permalink
fix: modules referenced in MapField message type are properly aliased (
Browse files Browse the repository at this point in the history
…#654)

This was noticed when attempting to generate Bigtable Admin in a
message definition: an imported module is given an alias to prevent
collision with a field name. When the module is referenced to describe
the type of a singleton field it is properly disambiguated. When used
to describe the type of a MapField it is _not_ disambiguated.

Fix for that.

Closes #618
  • Loading branch information
software-dov committed Oct 13, 2020
1 parent d7829ef commit 2c79349
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 5 deletions.
14 changes: 14 additions & 0 deletions gapic/schema/metadata.py
Expand Up @@ -90,6 +90,20 @@ def __str__(self) -> str:
# Return the Python identifier.
return '.'.join(self.parent + (self.name,))

def __repr__(self) -> str:
return "({})".format(
", ".join(
(
self.name,
self.module,
str(self.module_path),
str(self.package),
str(self.parent),
str(self.api_naming),
)
)
)

@property
def module_alias(self) -> str:
"""Return an appropriate module alias if necessary.
Expand Down
26 changes: 21 additions & 5 deletions gapic/schema/wrappers.py
Expand Up @@ -222,7 +222,12 @@ def type(self) -> Union['MessageType', 'EnumType', 'PrimitiveType']:
raise TypeError(f'Unrecognized protobuf type: {self.field_pb.type}. '
'This code should not be reachable; please file a bug.')

def with_context(self, *, collisions: FrozenSet[str]) -> 'Field':
def with_context(
self,
*,
collisions: FrozenSet[str],
visited_messages: FrozenSet["MessageType"],
) -> 'Field':
"""Return a derivative of this field with the provided context.
This method is used to address naming collisions. The returned
Expand All @@ -233,7 +238,8 @@ def with_context(self, *, collisions: FrozenSet[str]) -> 'Field':
self,
message=self.message.with_context(
collisions=collisions,
skip_fields=True,
skip_fields=self.message in visited_messages,
visited_messages=visited_messages,
) if self.message else None,
enum=self.enum.with_context(collisions=collisions)
if self.enum else None,
Expand Down Expand Up @@ -406,7 +412,10 @@ def get_field(self, *field_path: str,

# Base case: If this is the last field in the path, return it outright.
if len(field_path) == 1:
return cursor.with_context(collisions=collisions)
return cursor.with_context(
collisions=collisions,
visited_messages=frozenset({self}),
)

# Sanity check: If cursor is a repeated field, then raise an exception.
# Repeated fields are only permitted in the terminal position.
Expand All @@ -433,6 +442,7 @@ def get_field(self, *field_path: str,
def with_context(self, *,
collisions: FrozenSet[str],
skip_fields: bool = False,
visited_messages: FrozenSet["MessageType"] = frozenset(),
) -> 'MessageType':
"""Return a derivative of this message with the provided context.
Expand All @@ -444,10 +454,14 @@ def with_context(self, *,
underlying fields. This provides for an "exit" in the case of circular
references.
"""
visited_messages = visited_messages | {self}
return dataclasses.replace(
self,
fields=collections.OrderedDict(
(k, v.with_context(collisions=collisions))
(k, v.with_context(
collisions=collisions,
visited_messages=visited_messages
))
for k, v in self.fields.items()
) if not skip_fields else self.fields,
nested_enums=collections.OrderedDict(
Expand All @@ -457,7 +471,9 @@ def with_context(self, *,
nested_messages=collections.OrderedDict(
(k, v.with_context(
collisions=collisions,
skip_fields=skip_fields,))
skip_fields=skip_fields,
visited_messages=visited_messages,
))
for k, v in self.nested_messages.items()),
meta=self.meta.with_context(collisions=collisions),
)
Expand Down
73 changes: 73 additions & 0 deletions tests/unit/schema/test_api.py
Expand Up @@ -1214,3 +1214,76 @@ def test_resources_referenced_but_not_typed(reference_attr="type"):

def test_resources_referenced_but_not_typed_child_type():
test_resources_referenced_but_not_typed("child_type")


def test_map_field_name_disambiguation():
squid_file_pb = descriptor_pb2.FileDescriptorProto(
name="mollusc.proto",
package="animalia.mollusca.v2",
message_type=(
descriptor_pb2.DescriptorProto(
name="Mollusc",
),
),
)
method_types_file_pb = descriptor_pb2.FileDescriptorProto(
name="mollusc_service.proto",
package="animalia.mollusca.v2",
message_type=(
descriptor_pb2.DescriptorProto(
name="CreateMolluscRequest",
field=(
descriptor_pb2.FieldDescriptorProto(
name="mollusc",
type="TYPE_MESSAGE",
type_name=".animalia.mollusca.v2.Mollusc",
number=1,
),
descriptor_pb2.FieldDescriptorProto(
name="molluscs_map",
type="TYPE_MESSAGE",
number=2,
type_name=".animalia.mollusca.v2.CreateMolluscRequest.MolluscsMapEntry",
label="LABEL_REPEATED",
),
),
nested_type=(
descriptor_pb2.DescriptorProto(
name="MolluscsMapEntry",
field=(
descriptor_pb2.FieldDescriptorProto(
name="key",
type="TYPE_STRING",
number=1,
),
descriptor_pb2.FieldDescriptorProto(
name="value",
type="TYPE_MESSAGE",
number=2,
# We use the same type for the map value as for
# the singleton above to better highlight the
# problem raised in
# https://github.com/googleapis/gapic-generator-python/issues/618.
# The module _is_ disambiguated for singleton
# fields but NOT for map fields.
type_name=".animalia.mollusca.v2.Mollusc"
),
),
options=descriptor_pb2.MessageOptions(map_entry=True),
),
),
),
),
)
my_api = api.API.build(
file_descriptors=[squid_file_pb, method_types_file_pb],
)
create = my_api.messages['animalia.mollusca.v2.CreateMolluscRequest']
mollusc = create.fields['mollusc']
molluscs_map = create.fields['molluscs_map']
mollusc_ident = str(mollusc.type.ident)
mollusc_map_ident = str(molluscs_map.message.fields['value'].type.ident)

# The same module used in the same place should have the same import alias.
# Because there's a "mollusc" name used, the import should be disambiguated.
assert mollusc_ident == mollusc_map_ident == "am_mollusc.Mollusc"

0 comments on commit 2c79349

Please sign in to comment.