From 62db6f0daa434661f4b7362a78d5a3db4efeccc1 Mon Sep 17 00:00:00 2001 From: Patrick Ogenstad Date: Thu, 3 Apr 2025 10:07:12 +0200 Subject: [PATCH] Add sync option for jinja2 templates and set it as default --- infrahub_sdk/ctl/cli_commands.py | 6 +- .../pytest_plugin/items/jinja2_transform.py | 9 +- infrahub_sdk/template/__init__.py | 79 ++++++++++----- tests/unit/sdk/test_template.py | 97 ++++++++++++++----- 4 files changed, 132 insertions(+), 59 deletions(-) diff --git a/infrahub_sdk/ctl/cli_commands.py b/infrahub_sdk/ctl/cli_commands.py index 0d9a850f..cf685c27 100644 --- a/infrahub_sdk/ctl/cli_commands.py +++ b/infrahub_sdk/ctl/cli_commands.py @@ -43,7 +43,7 @@ from ..ctl.validate import app as validate_app from ..exceptions import GraphQLError, ModuleImportError from ..schema import MainSchemaTypesAll, SchemaRoot -from ..template import Jinja2Template +from ..template import Jinja2TemplateSync from ..template.exceptions import JinjaTemplateError from ..utils import get_branch, write_to_file from ..yaml import SchemaFile @@ -178,9 +178,9 @@ async def run( async def render_jinja2_template(template_path: Path, variables: dict[str, Any], data: dict[str, Any]) -> str: variables["data"] = data - jinja_template = Jinja2Template(template=Path(template_path), template_directory=Path()) + jinja_template = Jinja2TemplateSync(template=Path(template_path), template_directory=Path()) try: - rendered_tpl = await jinja_template.render(variables=variables) + rendered_tpl = jinja_template.render(variables=variables) except JinjaTemplateError as exc: print_template_errors(error=exc, console=console) raise typer.Exit(1) from exc diff --git a/infrahub_sdk/pytest_plugin/items/jinja2_transform.py b/infrahub_sdk/pytest_plugin/items/jinja2_transform.py index 4ed2e2c5..06e39157 100644 --- a/infrahub_sdk/pytest_plugin/items/jinja2_transform.py +++ b/infrahub_sdk/pytest_plugin/items/jinja2_transform.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio import difflib from pathlib import Path from typing import TYPE_CHECKING, Any @@ -9,7 +8,7 @@ import ujson from httpx import HTTPStatusError -from ...template import Jinja2Template +from ...template import Jinja2TemplateSync from ...template.exceptions import JinjaTemplateError from ..exceptions import OutputMatchError from ..models import InfrahubInputOutputTest, InfrahubTestExpectedResult @@ -20,8 +19,8 @@ class InfrahubJinja2Item(InfrahubItem): - def _get_jinja2(self) -> Jinja2Template: - return Jinja2Template( + def _get_jinja2(self) -> Jinja2TemplateSync: + return Jinja2TemplateSync( template=Path(self.resource_config.template_path), # type: ignore[attr-defined] template_directory=Path(self.session.infrahub_config_path.parent), # type: ignore[attr-defined] ) @@ -38,7 +37,7 @@ def render_jinja2_template(self, variables: dict[str, Any]) -> str | None: jinja2_template = self._get_jinja2() try: - return asyncio.run(jinja2_template.render(variables=variables)) + return jinja2_template.render(variables=variables) except JinjaTemplateError as exc: if self.test.expect == InfrahubTestExpectedResult.PASS: raise exc diff --git a/infrahub_sdk/template/__init__.py b/infrahub_sdk/template/__init__.py index c43f7ad9..90abc440 100644 --- a/infrahub_sdk/template/__init__.py +++ b/infrahub_sdk/template/__init__.py @@ -2,7 +2,7 @@ import linecache from pathlib import Path -from typing import Any, Callable, NoReturn +from typing import Any, Callable, ClassVar, NoReturn import jinja2 from jinja2 import meta, nodes @@ -24,7 +24,9 @@ netutils_filters = jinja2_convenience_function() -class Jinja2Template: +class Jinja2TemplateBase: + _is_async: ClassVar[bool] = True + def __init__( self, template: str | Path, @@ -106,29 +108,8 @@ def validate(self, restricted: bool = True) -> None: f"These operations are forbidden for string based templates: {forbidden_operations}" ) - async def render(self, variables: dict[str, Any]) -> str: - template = self.get_template() - try: - output = await template.render_async(variables) - except jinja2.exceptions.TemplateNotFound as exc: - raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name), base_template=template.name) - except jinja2.TemplateSyntaxError as exc: - self._raise_template_syntax_error(error=exc) - except jinja2.UndefinedError as exc: - traceback = Traceback(show_locals=False) - errors = _identify_faulty_jinja_code(traceback=traceback) - raise JinjaTemplateUndefinedError(message=exc.message, errors=errors) - except Exception as exc: - if error_message := getattr(exc, "message", None): - message = error_message - else: - message = str(exc) - raise JinjaTemplateError(message=message or "Unknown template error") - - return output - def _get_string_based_environment(self) -> jinja2.Environment: - env = SandboxedEnvironment(enable_async=True, undefined=jinja2.StrictUndefined) + env = SandboxedEnvironment(enable_async=self._is_async, undefined=jinja2.StrictUndefined) self._set_filters(env=env) self._environment = env return self._environment @@ -139,7 +120,7 @@ def _get_file_based_environment(self) -> jinja2.Environment: loader=template_loader, trim_blocks=True, lstrip_blocks=True, - enable_async=True, + enable_async=self._is_async, ) self._set_filters(env=env) self._environment = env @@ -177,6 +158,54 @@ def _raise_template_syntax_error(self, error: jinja2.TemplateSyntaxError) -> NoR raise JinjaTemplateSyntaxError(message=error.message, filename=filename, lineno=error.lineno) +class Jinja2Template(Jinja2TemplateBase): + async def render(self, variables: dict[str, Any]) -> str: + template = self.get_template() + try: + output = await template.render_async(variables) + except jinja2.exceptions.TemplateNotFound as exc: + raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name), base_template=template.name) + except jinja2.TemplateSyntaxError as exc: + self._raise_template_syntax_error(error=exc) + except jinja2.UndefinedError as exc: + traceback = Traceback(show_locals=False) + errors = _identify_faulty_jinja_code(traceback=traceback) + raise JinjaTemplateUndefinedError(message=exc.message, errors=errors) + except Exception as exc: + if error_message := getattr(exc, "message", None): + message = error_message + else: + message = str(exc) + raise JinjaTemplateError(message=message or "Unknown template error") + + return output + + +class Jinja2TemplateSync(Jinja2TemplateBase): + _is_async: ClassVar[bool] = False + + def render(self, variables: dict[str, Any]) -> str: + template = self.get_template() + try: + output = template.render(variables) + except jinja2.exceptions.TemplateNotFound as exc: + raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name), base_template=template.name) + except jinja2.TemplateSyntaxError as exc: + self._raise_template_syntax_error(error=exc) + except jinja2.UndefinedError as exc: + traceback = Traceback(show_locals=False) + errors = _identify_faulty_jinja_code(traceback=traceback) + raise JinjaTemplateUndefinedError(message=exc.message, errors=errors) + except Exception as exc: + if error_message := getattr(exc, "message", None): + message = error_message + else: + message = str(exc) + raise JinjaTemplateError(message=message or "Unknown template error") + + return output + + def _identify_faulty_jinja_code(traceback: Traceback, nbr_context_lines: int = 3) -> list[UndefinedJinja2Error]: """This function identifies the faulty Jinja2 code and beautify it to provide meaningful information to the user. diff --git a/tests/unit/sdk/test_template.py b/tests/unit/sdk/test_template.py index b8854e54..4cd962c0 100644 --- a/tests/unit/sdk/test_template.py +++ b/tests/unit/sdk/test_template.py @@ -6,7 +6,7 @@ from rich.syntax import Syntax from rich.traceback import Frame -from infrahub_sdk.template import Jinja2Template +from infrahub_sdk.template import Jinja2Template, Jinja2TemplateSync from infrahub_sdk.template.exceptions import ( JinjaTemplateError, JinjaTemplateNotFoundError, @@ -78,9 +78,15 @@ class JinjaTestCaseFailing: "test_case", [pytest.param(tc, id=tc.name) for tc in SUCCESSFUL_STRING_TEST_CASES], ) -async def test_render_string(test_case: JinjaTestCase) -> None: - jinja = Jinja2Template(template=test_case.template) - assert test_case.expected == await jinja.render(variables=test_case.variables) +@pytest.mark.parametrize("is_async", [True, False]) +async def test_render_string(test_case: JinjaTestCase, is_async: bool) -> None: + if is_async: + jinja = Jinja2Template(template=test_case.template) + assert test_case.expected == await jinja.render(variables=test_case.variables) + else: + jinja = Jinja2TemplateSync(template=test_case.template) + assert test_case.expected == jinja.render(variables=test_case.variables) + assert test_case.expected_variables == jinja.get_variables() @@ -106,9 +112,14 @@ async def test_render_string(test_case: JinjaTestCase) -> None: "test_case", [pytest.param(tc, id=tc.name) for tc in SUCCESSFUL_FILE_TEST_CASES], ) -async def test_render_template_from_file(test_case: JinjaTestCase) -> None: - jinja = Jinja2Template(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY) - assert test_case.expected == await jinja.render(variables=test_case.variables) +@pytest.mark.parametrize("is_async", [True, False]) +async def test_render_template_from_file(test_case: JinjaTestCase, is_async: bool) -> None: + if is_async: + jinja = Jinja2Template(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY) + assert test_case.expected == await jinja.render(variables=test_case.variables) + else: + jinja = Jinja2TemplateSync(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY) + assert test_case.expected == jinja.render(variables=test_case.variables) assert test_case.expected_variables == jinja.get_variables() assert jinja.get_template() @@ -153,10 +164,16 @@ async def test_render_template_from_file(test_case: JinjaTestCase) -> None: "test_case", [pytest.param(tc, id=tc.name) for tc in FAILING_STRING_TEST_CASES], ) -async def test_render_string_errors(test_case: JinjaTestCaseFailing) -> None: - jinja = Jinja2Template(template=test_case.template, template_directory=TEMPLATE_DIRECTORY) - with pytest.raises(test_case.error.__class__) as exc: - await jinja.render(variables=test_case.variables) +@pytest.mark.parametrize("is_async", [True, False]) +async def test_render_string_errors(test_case: JinjaTestCaseFailing, is_async: bool) -> None: + if is_async: + jinja = Jinja2Template(template=test_case.template, template_directory=TEMPLATE_DIRECTORY) + with pytest.raises(test_case.error.__class__) as exc: + await jinja.render(variables=test_case.variables) + else: + jinja = Jinja2TemplateSync(template=test_case.template, template_directory=TEMPLATE_DIRECTORY) + with pytest.raises(test_case.error.__class__) as exc: + jinja.render(variables=test_case.variables) _compare_errors(expected=test_case.error, received=exc.value) @@ -234,27 +251,50 @@ async def test_render_string_errors(test_case: JinjaTestCaseFailing) -> None: "test_case", [pytest.param(tc, id=tc.name) for tc in FAILING_FILE_TEST_CASES], ) -async def test_manage_file_based_errors(test_case: JinjaTestCaseFailing) -> None: - jinja = Jinja2Template(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY) - with pytest.raises(test_case.error.__class__) as exc: - await jinja.render(variables=test_case.variables) +@pytest.mark.parametrize("is_async", [True, False]) +async def test_manage_file_based_errors(test_case: JinjaTestCaseFailing, is_async: bool) -> None: + if is_async: + jinja = Jinja2Template(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY) + with pytest.raises(test_case.error.__class__) as exc: + await jinja.render(variables=test_case.variables) + else: + jinja = Jinja2TemplateSync(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY) + with pytest.raises(test_case.error.__class__) as exc: + jinja.render(variables=test_case.variables) _compare_errors(expected=test_case.error, received=exc.value) -async def test_manage_unhandled_error() -> None: - jinja = Jinja2Template( - template="Hello {{ number | divide_by_zero }}", - filters={"divide_by_zero": _divide_by_zero}, - ) - with pytest.raises(JinjaTemplateError) as exc: - await jinja.render(variables={"number": 1}) +@pytest.mark.parametrize("is_async", [True, False]) +async def test_manage_unhandled_error(is_async: bool) -> None: + template = "Hello {{ number | divide_by_zero }}" + filters = {"divide_by_zero": _divide_by_zero} + if is_async: + jinja = Jinja2Template( + template=template, + filters=filters, + ) + with pytest.raises(JinjaTemplateError) as exc: + await jinja.render(variables={"number": 1}) + else: + jinja = Jinja2TemplateSync( + template=template, + filters=filters, + ) + with pytest.raises(JinjaTemplateError) as exc: + jinja.render(variables={"number": 1}) assert exc.value.message == "division by zero" -async def test_validate_filter() -> None: - jinja = Jinja2Template(template="{{ network | get_all_host }}") +@pytest.mark.parametrize("is_async", [True, False]) +async def test_validate_filter(is_async: bool) -> None: + template = "{{ network | get_all_host }}" + if is_async: + jinja = Jinja2Template(template=template) + else: + jinja = Jinja2TemplateSync(template=template) + jinja.validate(restricted=False) with pytest.raises(JinjaTemplateOperationViolationError) as exc: jinja.validate(restricted=True) @@ -262,8 +302,13 @@ async def test_validate_filter() -> None: assert exc.value.message == "The 'get_all_host' filter isn't allowed to be used" -async def test_validate_operation() -> None: - jinja = Jinja2Template(template="Hello {% include 'very-forbidden.j2' %}") +@pytest.mark.parametrize("is_async", [True, False]) +async def test_validate_operation(is_async: bool) -> None: + if is_async: + jinja = Jinja2Template(template="Hello {% include 'very-forbidden.j2' %}") + else: + jinja = Jinja2TemplateSync(template="Hello {% include 'very-forbidden.j2' %}") + with pytest.raises(JinjaTemplateOperationViolationError) as exc: jinja.validate(restricted=True)