Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions protovalidate/internal/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _scalar_field_value_to_cel(val: typing.Any, field: descriptor.FieldDescripto


def _field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> celtypes.Value:
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
if field.is_repeated: # type: ignore
if field.message_type is not None and field.message_type.GetOptions().map_entry:
return _map_field_value_to_cel(val, field)
return _repeated_field_value_to_cel(val, field)
Expand All @@ -165,7 +165,7 @@ def _field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> c
def _is_empty_field(msg: message.Message, field: descriptor.FieldDescriptor) -> bool:
if field.has_presence:
return not _proto_message_has_field(msg, field)
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
if field.is_repeated: # type: ignore
return len(_proto_message_get_field(msg, field)) == 0
return _proto_message_get_field(msg, field) == field.default_value

Expand Down Expand Up @@ -194,7 +194,7 @@ def _map_field_to_cel(msg: message.Message, field: descriptor.FieldDescriptor) -


def field_to_cel(msg: message.Message, field: descriptor.FieldDescriptor) -> celtypes.Value:
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
if field.is_repeated: # type: ignore
return _repeated_field_to_cel(msg, field)
elif field.message_type is not None and not _proto_message_has_field(msg, field):
return None
Expand Down Expand Up @@ -484,19 +484,15 @@ def check_field_type(field: descriptor.FieldDescriptor, expected: int, wrapper_n


def _is_map(field: descriptor.FieldDescriptor):
return (
field.label == descriptor.FieldDescriptor.LABEL_REPEATED
and field.message_type is not None
and field.message_type.GetOptions().map_entry
)
return field.is_repeated and field.message_type is not None and field.message_type.GetOptions().map_entry # type: ignore


def _is_list(field: descriptor.FieldDescriptor):
return field.label == descriptor.FieldDescriptor.LABEL_REPEATED and not _is_map(field)
return field.is_repeated and not _is_map(field) # type: ignore


def _zero_value(field: descriptor.FieldDescriptor):
if field.message_type is not None and field.label != descriptor.FieldDescriptor.LABEL_REPEATED:
if field.message_type is not None and not field.is_repeated: # type: ignore
return _field_value_to_cel(message_factory.GetMessageClass(field.message_type)(), field)
else:
return _field_value_to_cel(field.default_value, field)
Expand Down Expand Up @@ -1003,7 +999,7 @@ def _new_field_rule(
field: descriptor.FieldDescriptor,
rules: validate_pb2.FieldRules,
) -> FieldRules:
if field.label != descriptor.FieldDescriptor.LABEL_REPEATED:
if not field.is_repeated: # type: ignore
return self._new_scalar_field_rule(field, rules)
if field.message_type is not None and field.message_type.GetOptions().map_entry:
key_rules = None
Expand Down Expand Up @@ -1057,7 +1053,7 @@ def _new_rules(self, desc: descriptor.Descriptor) -> list[Rules]:
if value_field.type != descriptor.FieldDescriptor.TYPE_MESSAGE:
continue
result.append(MapValMsgRule(self, field, key_field, value_field))
elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
elif field.is_repeated:
result.append(RepeatedMsgRule(self, field))
else:
result.append(SubMsgRule(self, field))
Expand Down
Loading