-
Notifications
You must be signed in to change notification settings - Fork 9
Improve typing #359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve typing #359
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,13 +15,13 @@ | |
| import dataclasses | ||
| import datetime | ||
| import typing | ||
| from collections.abc import Callable | ||
| from collections.abc import Callable, Container, Iterable, Mapping | ||
|
|
||
| import celpy | ||
| from celpy import celtypes | ||
| from google.protobuf import any_pb2, descriptor, message, message_factory | ||
| from google.protobuf import any_pb2, descriptor, duration_pb2, message, message_factory | ||
|
|
||
| from buf.validate import validate_pb2 # type: ignore | ||
| from buf.validate import validate_pb2 | ||
| from protovalidate.config import Config | ||
| from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has | ||
|
|
||
|
|
@@ -30,14 +30,14 @@ class CompilationError(Exception): | |
| pass | ||
|
|
||
|
|
||
| def make_duration(msg: message.Message) -> celtypes.DurationType: | ||
| def make_duration(msg: duration_pb2.Duration) -> celtypes.DurationType: | ||
| return celtypes.DurationType( | ||
| seconds=msg.seconds, # type: ignore | ||
| nanos=msg.nanos, # type: ignore | ||
| seconds=msg.seconds, | ||
| nanos=msg.nanos, | ||
| ) | ||
|
|
||
|
|
||
| def make_timestamp(msg: message.Message) -> celtypes.TimestampType: | ||
| def make_timestamp(msg: duration_pb2.Duration) -> celtypes.TimestampType: | ||
| return celtypes.TimestampType(1970, 1, 1) + make_duration(msg) | ||
|
|
||
|
|
||
|
|
@@ -62,7 +62,6 @@ def unwrap(msg: message.Message) -> celtypes.Value: | |
|
|
||
| class MessageType(celtypes.MapType): | ||
| msg: message.Message | ||
| desc: descriptor.Descriptor | ||
|
|
||
| def __init__(self, msg: message.Message): | ||
| super().__init__() | ||
|
|
@@ -163,7 +162,7 @@ def _field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> c | |
|
|
||
|
|
||
| def _is_empty_field(msg: message.Message, field: descriptor.FieldDescriptor) -> bool: | ||
| if field.has_presence: # type: ignore[attr-defined] | ||
| if field.has_presence: | ||
| return not _proto_message_has_field(msg, field) | ||
| if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: | ||
| return len(_proto_message_get_field(msg, field)) == 0 | ||
|
|
@@ -176,14 +175,11 @@ def _repeated_field_to_cel(msg: message.Message, field: descriptor.FieldDescript | |
| return _repeated_field_value_to_cel(_proto_message_get_field(msg, field), field) | ||
|
|
||
|
|
||
| def _repeated_field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> celtypes.Value: | ||
| result = celtypes.ListType() | ||
| for item in val: | ||
| result.append(_scalar_field_value_to_cel(item, field)) | ||
| return result | ||
| def _repeated_field_value_to_cel(val: Iterable, field: descriptor.FieldDescriptor) -> celtypes.Value: | ||
| return celtypes.ListType(_scalar_field_value_to_cel(item, field) for item in val) | ||
|
|
||
|
|
||
| def _map_field_value_to_cel(mapping: typing.Any, field: descriptor.FieldDescriptor) -> celtypes.Value: | ||
| def _map_field_value_to_cel(mapping: Mapping, field: descriptor.FieldDescriptor) -> celtypes.Value: | ||
| result = celtypes.MapType() | ||
| key_field = field.message_type.fields[0] | ||
| val_field = field.message_type.fields[1] | ||
|
|
@@ -269,6 +265,7 @@ class RuleContext: | |
| """The state associated with a single rule evaluation.""" | ||
|
|
||
| _cfg: Config | ||
| _violations: list[Violation] | ||
|
|
||
| def __init__(self, *, config: Config, violations: typing.Optional[list[Violation]] = None): | ||
| self._cfg = config | ||
|
|
@@ -305,7 +302,7 @@ def done(self) -> bool: | |
| def has_errors(self) -> bool: | ||
| return len(self._violations) > 0 | ||
|
|
||
| def sub_context(self): | ||
| def sub_context(self) -> "RuleContext": | ||
| return RuleContext(config=self._cfg) | ||
|
|
||
|
|
||
|
|
@@ -545,19 +542,17 @@ def __init__( | |
| type_case = field_level.WhichOneof("type") | ||
| super().__init__(None if type_case is None else getattr(field_level, type_case)) | ||
| self._field = field | ||
| self._ignore_empty = field_level.ignore in (validate_pb2.IGNORE_IF_ZERO_VALUE,) or ( | ||
| field.has_presence # type: ignore[attr-defined] | ||
| and not for_items | ||
| self._ignore_empty = field_level.ignore == validate_pb2.IGNORE_IF_ZERO_VALUE or ( | ||
| field.has_presence and not for_items | ||
| ) | ||
| self._required = field_level.required | ||
| type_case = field_level.WhichOneof("type") | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. already set above |
||
| if type_case is not None: | ||
| rules: message.Message = getattr(field_level, type_case) | ||
| # For each set field in the message, look for the private rule | ||
| # extension. | ||
| for list_field, _ in rules.ListFields(): | ||
| if validate_pb2.predefined in list_field.GetOptions().Extensions: | ||
| for cel in list_field.GetOptions().Extensions[validate_pb2.predefined].cel: | ||
| if validate_pb2.predefined in list_field.GetOptions().Extensions: # type: ignore | ||
| for cel in list_field.GetOptions().Extensions[validate_pb2.predefined].cel: # type: ignore | ||
| self.add_rule( | ||
| env, | ||
| funcs, | ||
|
|
@@ -646,25 +641,20 @@ def __init__( | |
| field_level: validate_pb2.FieldRules, | ||
| ): | ||
| super().__init__(env, funcs, field, field_level) | ||
| self._in = [] | ||
| if getattr(field_level.any, "in"): | ||
| self._in = getattr(field_level.any, "in") | ||
| self._not_in = [] | ||
| if field_level.any.not_in: | ||
| self._not_in = field_level.any.not_in | ||
| self._in = getattr(field_level.any, "in") or [] | ||
| self._not_in: Container[str] = field_level.any.not_in or [] | ||
|
|
||
| def _validate_value(self, ctx: RuleContext, value: any_pb2.Any, *, for_key: bool = False): | ||
| if len(self._in) > 0: | ||
| if value.type_url not in self._in: | ||
| ctx.add( | ||
| Violation( | ||
| rule=AnyRules._in_rule_path, | ||
| rule_value=self._in, | ||
| rule_id="any.in", | ||
| message="type URL must be in the allow list", | ||
| for_key=for_key, | ||
| ) | ||
| if len(self._in) > 0 and value.type_url not in self._in: | ||
| ctx.add( | ||
| Violation( | ||
| rule=AnyRules._in_rule_path, | ||
| rule_value=self._in, | ||
| rule_id="any.in", | ||
| message="type URL must be in the allow list", | ||
| for_key=for_key, | ||
| ) | ||
| ) | ||
| if value.type_url in self._not_in: | ||
| ctx.add( | ||
| Violation( | ||
|
|
@@ -710,22 +700,20 @@ def validate(self, ctx: RuleContext, message: message.Message): | |
| super().validate(ctx, message) | ||
| if ctx.done: | ||
| return | ||
| if self._defined_only: | ||
| value = getattr(message, self._field.name) | ||
| if value not in self._field.enum_type.values_by_number: | ||
| ctx.add( | ||
| Violation( | ||
| field=validate_pb2.FieldPath( | ||
| elements=[ | ||
| _field_to_element(self._field), | ||
| ], | ||
| ), | ||
| rule=EnumRules._defined_only_rule_path, | ||
| rule_value=self._defined_only, | ||
| rule_id="enum.defined_only", | ||
| message="value must be one of the defined enum values", | ||
| if self._defined_only and getattr(message, self._field.name) not in self._field.enum_type.values_by_number: | ||
| ctx.add( | ||
| Violation( | ||
| field=validate_pb2.FieldPath( | ||
| elements=[ | ||
| _field_to_element(self._field), | ||
| ], | ||
| ), | ||
| ) | ||
| rule=EnumRules._defined_only_rule_path, | ||
| rule_value=self._defined_only, | ||
| rule_id="enum.defined_only", | ||
| message="value must be one of the defined enum values", | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
| class RepeatedRules(FieldRules): | ||
|
|
@@ -875,7 +863,7 @@ def __init__(self, funcs: dict[str, celpy.CELFunction]): | |
| self._funcs = funcs | ||
| self._cache = {} | ||
|
|
||
| def get(self, descriptor: descriptor.Descriptor) -> list[Rules]: | ||
| def get(self, descriptor) -> list[Rules]: | ||
| if descriptor not in self._cache: | ||
| try: | ||
| self._cache[descriptor] = self._new_rules(descriptor) | ||
|
|
@@ -1042,8 +1030,8 @@ def _new_rules(self, desc: descriptor.Descriptor) -> list[Rules]: | |
| result: list[Rules] = [] | ||
| rule: typing.Optional[Rules] = None | ||
| all_msg_oneof_fields = set() | ||
| if validate_pb2.message in desc.GetOptions().Extensions: | ||
| message_level = desc.GetOptions().Extensions[validate_pb2.message] | ||
| if desc.GetOptions().HasExtension(validate_pb2.message): # type: ignore | ||
| message_level = desc.GetOptions().Extensions[validate_pb2.message] # type: ignore | ||
| for oneof in message_level.oneof: | ||
| all_msg_oneof_fields.update(oneof.fields) | ||
| if rule := self._new_message_rule(message_level, desc): | ||
|
|
@@ -1094,8 +1082,8 @@ def __init__( | |
| def validate(self, ctx: RuleContext, message: message.Message): | ||
| if not message.HasField(self._field.name): | ||
| return | ||
| rules = self._factory.get(self._field.message_type) | ||
| if rules is None: | ||
| rules: list[Rules] = self._factory.get(self._field.message_type) | ||
| if not rules: | ||
|
Comment on lines
+1085
to
+1086
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. noticed we weren't exiting early here, since |
||
| return | ||
| val = getattr(message, self._field.name) | ||
| sub_ctx = ctx.sub_context() | ||
|
|
@@ -1124,8 +1112,8 @@ def validate(self, ctx: RuleContext, message: message.Message): | |
| val = getattr(message, self._field.name) | ||
| if not val: | ||
| return | ||
| rules = self._factory.get(self._value_field.message_type) | ||
| if rules is None: | ||
| rules: list[Rules] = self._factory.get(self._value_field.message_type) | ||
| if not rules: | ||
| return | ||
| for k, v in val.items(): | ||
| sub_ctx = ctx.sub_context() | ||
|
|
@@ -1151,8 +1139,8 @@ def validate(self, ctx: RuleContext, message: message.Message): | |
| val = getattr(message, self._field.name) | ||
| if not val: | ||
| return | ||
| rules = self._factory.get(self._field.message_type) | ||
| if rules is None: | ||
| rules: list[Rules] = self._factory.get(self._field.message_type) | ||
| if not rules: | ||
| return | ||
| for idx, item in enumerate(val): | ||
| sub_ctx = ctx.sub_context() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be a
timestamp_pb2.Timestamp?