From 453f818a330beea37879cdc3e763af9fb6a150f4 Mon Sep 17 00:00:00 2001 From: puyj Date: Tue, 21 Jun 2022 16:03:10 +0800 Subject: [PATCH] [Python] Implement validate_all function Signed-off-by: puyj --- python/README.md | 4 +- python/protoc_gen_validate/validator.py | 57 +++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/python/README.md b/python/README.md index 017e68eca..e90c1d79c 100644 --- a/python/README.md +++ b/python/README.md @@ -8,11 +8,13 @@ in `validate.proto`. Implemented Python annotations are listed in the [rules com ### Example ```python3 from entities_pb2 import Person -from protoc_gen_validate.validator import validate, ValidationFailed +from protoc_gen_validate.validator import validate, ValidationFailed, validate_all p = Person(first_name="Foo", last_name="Bar", age=42) try: validate(p) + # you can also validate all rules + # validate_all(p) except ValidationFailed as err: print(err) ``` diff --git a/python/protoc_gen_validate/validator.py b/python/protoc_gen_validate/validator.py index bdca00b0a..e7cabcbbd 100644 --- a/python/protoc_gen_validate/validator.py +++ b/python/protoc_gen_validate/validator.py @@ -1,3 +1,4 @@ +import ast import re import struct import sys @@ -63,6 +64,62 @@ def _validate_inner(proto_message: Message): return locals()['generate_validate'] +class _Transformer(ast.NodeTransformer): + """ + Consider generated functions has the following structure: + + ``` + def generate_validate(p): + ... + if rules_stmt: + raise ValidationFailed(msg) + ... + return None + ``` + + Transformer made the following three changes: + + 1. Define a variable `err` that records all ValidationFailed error messages. + 2. Convert all `raise ValidationFailed(error_message)` to `err += error_message`. + 3. When `err` is not an empty string, `raise ValidationFailed(err)`. + """ + + def visit_FunctionDef(self, node: ast.FunctionDef): + self.generic_visit(node) + # add a suffix to the function name + node.name = node.name + "_all" + node.body.insert(0, ast.parse("err = ''").body[0]) + return node + + def visit_Raise(self, node: ast.Raise): + exc_str = " ".join(str(_.value) for _ in node.exc.args) + return ast.parse(rf'err += "\n{exc_str}"').body[0] + + def visit_Return(self, node: ast.Return): + return ast.parse("if err:\n raise ValidationFailed(err)").body[0] + + +# Cache generated functions with the message descriptor's full_name as the cache key +@lru_cache() +def _validate_all_inner(proto_message: Message): + func = file_template(ValidatingMessage(proto_message)) + func_ast = ast.parse(func) + func_ast = _Transformer().visit(func_ast) + func_ast = ast.fix_missing_locations(func_ast) + func = ast.unparse(func_ast) + global printer + printer += func + "\n" + exec(func) + try: + return generate_validate_all + except NameError: + return locals()['generate_validate_all'] + + +def validate_all(proto_message: Message): + return _validate_all_inner(ValidatingMessage(proto_message))(proto_message) + + def print_validate(): return "".join([s for s in printer.splitlines(True) if s.strip()])