Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions infrahub_sdk/ctl/cli_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from ..ctl.validate import app as validate_app
from ..exceptions import GraphQLError, ModuleImportError
from ..schema import MainSchemaTypesAll, SchemaRoot
from ..template import Jinja2Template
from ..template import Jinja2TemplateSync
from ..template.exceptions import JinjaTemplateError
from ..utils import get_branch, write_to_file
from ..yaml import SchemaFile
Expand Down Expand Up @@ -178,9 +178,9 @@ async def run(

async def render_jinja2_template(template_path: Path, variables: dict[str, Any], data: dict[str, Any]) -> str:
variables["data"] = data
jinja_template = Jinja2Template(template=Path(template_path), template_directory=Path())
jinja_template = Jinja2TemplateSync(template=Path(template_path), template_directory=Path())
try:
rendered_tpl = await jinja_template.render(variables=variables)
rendered_tpl = jinja_template.render(variables=variables)
except JinjaTemplateError as exc:
print_template_errors(error=exc, console=console)
raise typer.Exit(1) from exc
Expand Down
9 changes: 4 additions & 5 deletions infrahub_sdk/pytest_plugin/items/jinja2_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import difflib
from pathlib import Path
from typing import TYPE_CHECKING, Any
Expand All @@ -9,7 +8,7 @@
import ujson
from httpx import HTTPStatusError

from ...template import Jinja2Template
from ...template import Jinja2TemplateSync
from ...template.exceptions import JinjaTemplateError
from ..exceptions import OutputMatchError
from ..models import InfrahubInputOutputTest, InfrahubTestExpectedResult
Expand All @@ -20,8 +19,8 @@


class InfrahubJinja2Item(InfrahubItem):
def _get_jinja2(self) -> Jinja2Template:
return Jinja2Template(
def _get_jinja2(self) -> Jinja2TemplateSync:
return Jinja2TemplateSync(
template=Path(self.resource_config.template_path), # type: ignore[attr-defined]
template_directory=Path(self.session.infrahub_config_path.parent), # type: ignore[attr-defined]
)
Expand All @@ -38,7 +37,7 @@ def render_jinja2_template(self, variables: dict[str, Any]) -> str | None:
jinja2_template = self._get_jinja2()

try:
return asyncio.run(jinja2_template.render(variables=variables))
return jinja2_template.render(variables=variables)
except JinjaTemplateError as exc:
if self.test.expect == InfrahubTestExpectedResult.PASS:
raise exc
Expand Down
79 changes: 54 additions & 25 deletions infrahub_sdk/template/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import linecache
from pathlib import Path
from typing import Any, Callable, NoReturn
from typing import Any, Callable, ClassVar, NoReturn

import jinja2
from jinja2 import meta, nodes
Expand All @@ -24,7 +24,9 @@
netutils_filters = jinja2_convenience_function()


class Jinja2Template:
class Jinja2TemplateBase:
_is_async: ClassVar[bool] = True

def __init__(
self,
template: str | Path,
Expand Down Expand Up @@ -106,29 +108,8 @@
f"These operations are forbidden for string based templates: {forbidden_operations}"
)

async def render(self, variables: dict[str, Any]) -> str:
template = self.get_template()
try:
output = await template.render_async(variables)
except jinja2.exceptions.TemplateNotFound as exc:
raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name), base_template=template.name)
except jinja2.TemplateSyntaxError as exc:
self._raise_template_syntax_error(error=exc)
except jinja2.UndefinedError as exc:
traceback = Traceback(show_locals=False)
errors = _identify_faulty_jinja_code(traceback=traceback)
raise JinjaTemplateUndefinedError(message=exc.message, errors=errors)
except Exception as exc:
if error_message := getattr(exc, "message", None):
message = error_message
else:
message = str(exc)
raise JinjaTemplateError(message=message or "Unknown template error")

return output

def _get_string_based_environment(self) -> jinja2.Environment:
env = SandboxedEnvironment(enable_async=True, undefined=jinja2.StrictUndefined)
env = SandboxedEnvironment(enable_async=self._is_async, undefined=jinja2.StrictUndefined)
self._set_filters(env=env)
self._environment = env
return self._environment
Expand All @@ -139,7 +120,7 @@
loader=template_loader,
trim_blocks=True,
lstrip_blocks=True,
enable_async=True,
enable_async=self._is_async,
)
self._set_filters(env=env)
self._environment = env
Expand Down Expand Up @@ -177,6 +158,54 @@
raise JinjaTemplateSyntaxError(message=error.message, filename=filename, lineno=error.lineno)


class Jinja2Template(Jinja2TemplateBase):
async def render(self, variables: dict[str, Any]) -> str:
template = self.get_template()
try:
output = await template.render_async(variables)
except jinja2.exceptions.TemplateNotFound as exc:
raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name), base_template=template.name)
except jinja2.TemplateSyntaxError as exc:
self._raise_template_syntax_error(error=exc)
except jinja2.UndefinedError as exc:
traceback = Traceback(show_locals=False)
errors = _identify_faulty_jinja_code(traceback=traceback)
raise JinjaTemplateUndefinedError(message=exc.message, errors=errors)
except Exception as exc:
if error_message := getattr(exc, "message", None):
message = error_message

Check warning on line 176 in infrahub_sdk/template/__init__.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/template/__init__.py#L176

Added line #L176 was not covered by tests
else:
message = str(exc)
raise JinjaTemplateError(message=message or "Unknown template error")

return output


class Jinja2TemplateSync(Jinja2TemplateBase):
_is_async: ClassVar[bool] = False

def render(self, variables: dict[str, Any]) -> str:
template = self.get_template()
try:
output = template.render(variables)
except jinja2.exceptions.TemplateNotFound as exc:
raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name), base_template=template.name)
except jinja2.TemplateSyntaxError as exc:
self._raise_template_syntax_error(error=exc)
except jinja2.UndefinedError as exc:
traceback = Traceback(show_locals=False)
errors = _identify_faulty_jinja_code(traceback=traceback)
raise JinjaTemplateUndefinedError(message=exc.message, errors=errors)
except Exception as exc:
if error_message := getattr(exc, "message", None):
message = error_message

Check warning on line 201 in infrahub_sdk/template/__init__.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/template/__init__.py#L201

Added line #L201 was not covered by tests
else:
message = str(exc)
raise JinjaTemplateError(message=message or "Unknown template error")

return output


def _identify_faulty_jinja_code(traceback: Traceback, nbr_context_lines: int = 3) -> list[UndefinedJinja2Error]:
"""This function identifies the faulty Jinja2 code and beautify it to provide meaningful information to the user.

Expand Down
97 changes: 71 additions & 26 deletions tests/unit/sdk/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from rich.syntax import Syntax
from rich.traceback import Frame

from infrahub_sdk.template import Jinja2Template
from infrahub_sdk.template import Jinja2Template, Jinja2TemplateSync
from infrahub_sdk.template.exceptions import (
JinjaTemplateError,
JinjaTemplateNotFoundError,
Expand Down Expand Up @@ -78,9 +78,15 @@ class JinjaTestCaseFailing:
"test_case",
[pytest.param(tc, id=tc.name) for tc in SUCCESSFUL_STRING_TEST_CASES],
)
async def test_render_string(test_case: JinjaTestCase) -> None:
jinja = Jinja2Template(template=test_case.template)
assert test_case.expected == await jinja.render(variables=test_case.variables)
@pytest.mark.parametrize("is_async", [True, False])
async def test_render_string(test_case: JinjaTestCase, is_async: bool) -> None:
if is_async:
jinja = Jinja2Template(template=test_case.template)
assert test_case.expected == await jinja.render(variables=test_case.variables)
else:
jinja = Jinja2TemplateSync(template=test_case.template)
assert test_case.expected == jinja.render(variables=test_case.variables)

assert test_case.expected_variables == jinja.get_variables()


Expand All @@ -106,9 +112,14 @@ async def test_render_string(test_case: JinjaTestCase) -> None:
"test_case",
[pytest.param(tc, id=tc.name) for tc in SUCCESSFUL_FILE_TEST_CASES],
)
async def test_render_template_from_file(test_case: JinjaTestCase) -> None:
jinja = Jinja2Template(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY)
assert test_case.expected == await jinja.render(variables=test_case.variables)
@pytest.mark.parametrize("is_async", [True, False])
async def test_render_template_from_file(test_case: JinjaTestCase, is_async: bool) -> None:
if is_async:
jinja = Jinja2Template(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY)
assert test_case.expected == await jinja.render(variables=test_case.variables)
else:
jinja = Jinja2TemplateSync(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY)
assert test_case.expected == jinja.render(variables=test_case.variables)
assert test_case.expected_variables == jinja.get_variables()
assert jinja.get_template()

Expand Down Expand Up @@ -153,10 +164,16 @@ async def test_render_template_from_file(test_case: JinjaTestCase) -> None:
"test_case",
[pytest.param(tc, id=tc.name) for tc in FAILING_STRING_TEST_CASES],
)
async def test_render_string_errors(test_case: JinjaTestCaseFailing) -> None:
jinja = Jinja2Template(template=test_case.template, template_directory=TEMPLATE_DIRECTORY)
with pytest.raises(test_case.error.__class__) as exc:
await jinja.render(variables=test_case.variables)
@pytest.mark.parametrize("is_async", [True, False])
async def test_render_string_errors(test_case: JinjaTestCaseFailing, is_async: bool) -> None:
if is_async:
jinja = Jinja2Template(template=test_case.template, template_directory=TEMPLATE_DIRECTORY)
with pytest.raises(test_case.error.__class__) as exc:
await jinja.render(variables=test_case.variables)
else:
jinja = Jinja2TemplateSync(template=test_case.template, template_directory=TEMPLATE_DIRECTORY)
with pytest.raises(test_case.error.__class__) as exc:
jinja.render(variables=test_case.variables)

_compare_errors(expected=test_case.error, received=exc.value)

Expand Down Expand Up @@ -234,36 +251,64 @@ async def test_render_string_errors(test_case: JinjaTestCaseFailing) -> None:
"test_case",
[pytest.param(tc, id=tc.name) for tc in FAILING_FILE_TEST_CASES],
)
async def test_manage_file_based_errors(test_case: JinjaTestCaseFailing) -> None:
jinja = Jinja2Template(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY)
with pytest.raises(test_case.error.__class__) as exc:
await jinja.render(variables=test_case.variables)
@pytest.mark.parametrize("is_async", [True, False])
async def test_manage_file_based_errors(test_case: JinjaTestCaseFailing, is_async: bool) -> None:
if is_async:
jinja = Jinja2Template(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY)
with pytest.raises(test_case.error.__class__) as exc:
await jinja.render(variables=test_case.variables)
else:
jinja = Jinja2TemplateSync(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY)
with pytest.raises(test_case.error.__class__) as exc:
jinja.render(variables=test_case.variables)

_compare_errors(expected=test_case.error, received=exc.value)


async def test_manage_unhandled_error() -> None:
jinja = Jinja2Template(
template="Hello {{ number | divide_by_zero }}",
filters={"divide_by_zero": _divide_by_zero},
)
with pytest.raises(JinjaTemplateError) as exc:
await jinja.render(variables={"number": 1})
@pytest.mark.parametrize("is_async", [True, False])
async def test_manage_unhandled_error(is_async: bool) -> None:
template = "Hello {{ number | divide_by_zero }}"
filters = {"divide_by_zero": _divide_by_zero}
if is_async:
jinja = Jinja2Template(
template=template,
filters=filters,
)
with pytest.raises(JinjaTemplateError) as exc:
await jinja.render(variables={"number": 1})
else:
jinja = Jinja2TemplateSync(
template=template,
filters=filters,
)
with pytest.raises(JinjaTemplateError) as exc:
jinja.render(variables={"number": 1})

assert exc.value.message == "division by zero"


async def test_validate_filter() -> None:
jinja = Jinja2Template(template="{{ network | get_all_host }}")
@pytest.mark.parametrize("is_async", [True, False])
async def test_validate_filter(is_async: bool) -> None:
template = "{{ network | get_all_host }}"
if is_async:
jinja = Jinja2Template(template=template)
else:
jinja = Jinja2TemplateSync(template=template)

jinja.validate(restricted=False)
with pytest.raises(JinjaTemplateOperationViolationError) as exc:
jinja.validate(restricted=True)

assert exc.value.message == "The 'get_all_host' filter isn't allowed to be used"


async def test_validate_operation() -> None:
jinja = Jinja2Template(template="Hello {% include 'very-forbidden.j2' %}")
@pytest.mark.parametrize("is_async", [True, False])
async def test_validate_operation(is_async: bool) -> None:
if is_async:
jinja = Jinja2Template(template="Hello {% include 'very-forbidden.j2' %}")
else:
jinja = Jinja2TemplateSync(template="Hello {% include 'very-forbidden.j2' %}")

with pytest.raises(JinjaTemplateOperationViolationError) as exc:
jinja.validate(restricted=True)

Expand Down