Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 50 additions & 62 deletions protovalidate/internal/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
import dataclasses
import datetime
import typing
from collections.abc import Callable
from collections.abc import Callable, Container, Iterable, Mapping

import celpy
from celpy import celtypes
from google.protobuf import any_pb2, descriptor, message, message_factory
from google.protobuf import any_pb2, descriptor, duration_pb2, message, message_factory

from buf.validate import validate_pb2 # type: ignore
from buf.validate import validate_pb2
from protovalidate.config import Config
from protovalidate.internal.cel_field_presence import InterpretedRunner, in_has

Expand All @@ -30,14 +30,14 @@ class CompilationError(Exception):
pass


def make_duration(msg: message.Message) -> celtypes.DurationType:
def make_duration(msg: duration_pb2.Duration) -> celtypes.DurationType:
return celtypes.DurationType(
seconds=msg.seconds, # type: ignore
nanos=msg.nanos, # type: ignore
seconds=msg.seconds,
nanos=msg.nanos,
)


def make_timestamp(msg: message.Message) -> celtypes.TimestampType:
def make_timestamp(msg: duration_pb2.Duration) -> celtypes.TimestampType:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a timestamp_pb2.Timestamp?

return celtypes.TimestampType(1970, 1, 1) + make_duration(msg)


Expand All @@ -62,7 +62,6 @@ def unwrap(msg: message.Message) -> celtypes.Value:

class MessageType(celtypes.MapType):
msg: message.Message
desc: descriptor.Descriptor

def __init__(self, msg: message.Message):
super().__init__()
Expand Down Expand Up @@ -163,7 +162,7 @@ def _field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> c


def _is_empty_field(msg: message.Message, field: descriptor.FieldDescriptor) -> bool:
if field.has_presence: # type: ignore[attr-defined]
if field.has_presence:
return not _proto_message_has_field(msg, field)
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
return len(_proto_message_get_field(msg, field)) == 0
Expand All @@ -176,14 +175,11 @@ def _repeated_field_to_cel(msg: message.Message, field: descriptor.FieldDescript
return _repeated_field_value_to_cel(_proto_message_get_field(msg, field), field)


def _repeated_field_value_to_cel(val: typing.Any, field: descriptor.FieldDescriptor) -> celtypes.Value:
result = celtypes.ListType()
for item in val:
result.append(_scalar_field_value_to_cel(item, field))
return result
def _repeated_field_value_to_cel(val: Iterable, field: descriptor.FieldDescriptor) -> celtypes.Value:
return celtypes.ListType(_scalar_field_value_to_cel(item, field) for item in val)


def _map_field_value_to_cel(mapping: typing.Any, field: descriptor.FieldDescriptor) -> celtypes.Value:
def _map_field_value_to_cel(mapping: Mapping, field: descriptor.FieldDescriptor) -> celtypes.Value:
result = celtypes.MapType()
key_field = field.message_type.fields[0]
val_field = field.message_type.fields[1]
Expand Down Expand Up @@ -269,6 +265,7 @@ class RuleContext:
"""The state associated with a single rule evaluation."""

_cfg: Config
_violations: list[Violation]

def __init__(self, *, config: Config, violations: typing.Optional[list[Violation]] = None):
self._cfg = config
Expand Down Expand Up @@ -305,7 +302,7 @@ def done(self) -> bool:
def has_errors(self) -> bool:
return len(self._violations) > 0

def sub_context(self):
def sub_context(self) -> "RuleContext":
return RuleContext(config=self._cfg)


Expand Down Expand Up @@ -545,19 +542,17 @@ def __init__(
type_case = field_level.WhichOneof("type")
super().__init__(None if type_case is None else getattr(field_level, type_case))
self._field = field
self._ignore_empty = field_level.ignore in (validate_pb2.IGNORE_IF_ZERO_VALUE,) or (
field.has_presence # type: ignore[attr-defined]
and not for_items
self._ignore_empty = field_level.ignore == validate_pb2.IGNORE_IF_ZERO_VALUE or (
field.has_presence and not for_items
)
self._required = field_level.required
type_case = field_level.WhichOneof("type")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

already set above

if type_case is not None:
rules: message.Message = getattr(field_level, type_case)
# For each set field in the message, look for the private rule
# extension.
for list_field, _ in rules.ListFields():
if validate_pb2.predefined in list_field.GetOptions().Extensions:
for cel in list_field.GetOptions().Extensions[validate_pb2.predefined].cel:
if validate_pb2.predefined in list_field.GetOptions().Extensions: # type: ignore
for cel in list_field.GetOptions().Extensions[validate_pb2.predefined].cel: # type: ignore
self.add_rule(
env,
funcs,
Expand Down Expand Up @@ -646,25 +641,20 @@ def __init__(
field_level: validate_pb2.FieldRules,
):
super().__init__(env, funcs, field, field_level)
self._in = []
if getattr(field_level.any, "in"):
self._in = getattr(field_level.any, "in")
self._not_in = []
if field_level.any.not_in:
self._not_in = field_level.any.not_in
self._in = getattr(field_level.any, "in") or []
self._not_in: Container[str] = field_level.any.not_in or []

def _validate_value(self, ctx: RuleContext, value: any_pb2.Any, *, for_key: bool = False):
if len(self._in) > 0:
if value.type_url not in self._in:
ctx.add(
Violation(
rule=AnyRules._in_rule_path,
rule_value=self._in,
rule_id="any.in",
message="type URL must be in the allow list",
for_key=for_key,
)
if len(self._in) > 0 and value.type_url not in self._in:
ctx.add(
Violation(
rule=AnyRules._in_rule_path,
rule_value=self._in,
rule_id="any.in",
message="type URL must be in the allow list",
for_key=for_key,
)
)
if value.type_url in self._not_in:
ctx.add(
Violation(
Expand Down Expand Up @@ -710,22 +700,20 @@ def validate(self, ctx: RuleContext, message: message.Message):
super().validate(ctx, message)
if ctx.done:
return
if self._defined_only:
value = getattr(message, self._field.name)
if value not in self._field.enum_type.values_by_number:
ctx.add(
Violation(
field=validate_pb2.FieldPath(
elements=[
_field_to_element(self._field),
],
),
rule=EnumRules._defined_only_rule_path,
rule_value=self._defined_only,
rule_id="enum.defined_only",
message="value must be one of the defined enum values",
if self._defined_only and getattr(message, self._field.name) not in self._field.enum_type.values_by_number:
ctx.add(
Violation(
field=validate_pb2.FieldPath(
elements=[
_field_to_element(self._field),
],
),
)
rule=EnumRules._defined_only_rule_path,
rule_value=self._defined_only,
rule_id="enum.defined_only",
message="value must be one of the defined enum values",
),
)


class RepeatedRules(FieldRules):
Expand Down Expand Up @@ -875,7 +863,7 @@ def __init__(self, funcs: dict[str, celpy.CELFunction]):
self._funcs = funcs
self._cache = {}

def get(self, descriptor: descriptor.Descriptor) -> list[Rules]:
def get(self, descriptor) -> list[Rules]:
if descriptor not in self._cache:
try:
self._cache[descriptor] = self._new_rules(descriptor)
Expand Down Expand Up @@ -1042,8 +1030,8 @@ def _new_rules(self, desc: descriptor.Descriptor) -> list[Rules]:
result: list[Rules] = []
rule: typing.Optional[Rules] = None
all_msg_oneof_fields = set()
if validate_pb2.message in desc.GetOptions().Extensions:
message_level = desc.GetOptions().Extensions[validate_pb2.message]
if desc.GetOptions().HasExtension(validate_pb2.message): # type: ignore
message_level = desc.GetOptions().Extensions[validate_pb2.message] # type: ignore
for oneof in message_level.oneof:
all_msg_oneof_fields.update(oneof.fields)
if rule := self._new_message_rule(message_level, desc):
Expand Down Expand Up @@ -1094,8 +1082,8 @@ def __init__(
def validate(self, ctx: RuleContext, message: message.Message):
if not message.HasField(self._field.name):
return
rules = self._factory.get(self._field.message_type)
if rules is None:
rules: list[Rules] = self._factory.get(self._field.message_type)
if not rules:
Comment on lines +1085 to +1086
Copy link
Member Author

@stefanvanburen stefanvanburen Aug 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noticed we weren't exiting early here, since self._factory.get is never returning None. Switched to check if we have any rules in the list. Should be slightly faster.

return
val = getattr(message, self._field.name)
sub_ctx = ctx.sub_context()
Expand Down Expand Up @@ -1124,8 +1112,8 @@ def validate(self, ctx: RuleContext, message: message.Message):
val = getattr(message, self._field.name)
if not val:
return
rules = self._factory.get(self._value_field.message_type)
if rules is None:
rules: list[Rules] = self._factory.get(self._value_field.message_type)
if not rules:
return
for k, v in val.items():
sub_ctx = ctx.sub_context()
Expand All @@ -1151,8 +1139,8 @@ def validate(self, ctx: RuleContext, message: message.Message):
val = getattr(message, self._field.name)
if not val:
return
rules = self._factory.get(self._field.message_type)
if rules is None:
rules: list[Rules] = self._factory.get(self._field.message_type)
if not rules:
return
for idx, item in enumerate(val):
sub_ctx = ctx.sub_context()
Expand Down
8 changes: 4 additions & 4 deletions protovalidate/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from google.protobuf import message

from buf.validate import validate_pb2 # type: ignore
from buf.validate import validate_pb2
from protovalidate.config import Config
from protovalidate.internal import extra_func
from protovalidate.internal import rules as _rules
Expand All @@ -38,7 +38,7 @@ class Validator:
_factory: _rules.RuleFactory
_cfg: Config

def __init__(self, config=None):
def __init__(self, config: typing.Optional[Config] = None):
self._cfg = config if config is not None else Config()
funcs = extra_func.make_extra_funcs()
self._factory = _rules.RuleFactory(funcs)
Expand Down Expand Up @@ -92,9 +92,9 @@ def collect_violations(
break
for violation in ctx.violations:
if violation.proto.HasField("field"):
violation.proto.field.elements.reverse()
violation.proto.field.elements.reverse() # type: ignore
if violation.proto.HasField("rule"):
violation.proto.rule.elements.reverse()
violation.proto.rule.elements.reverse() # type: ignore
return ctx.violations


Expand Down
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ dev = [
"google-re2-stubs>=0.1.1",
"mypy>=1.17.1",
"ruff>=0.12.0",
"types-protobuf>=5",
"types-protobuf>=5.29.1.20250315",
]

[tool.hatch.version]
Expand Down Expand Up @@ -106,3 +106,9 @@ ban-relative-imports = "all"
[tool.ruff.lint.per-file-ignores]
# Tests can use magic values, assertions, and relative imports.
"tests/**/*" = ["PLR2004", "S101", "TID252"]

[tool.mypy]
mypy_path = "gen"

[tool.ty.environment]
extra-paths = ["gen"]
Loading