Skip to content

Commit 780aa63

Browse files
Python: implement async support for templates (#6486)
### Motivation and Context <!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. --> Jinja2 templates now use async functions instead of nested ones! With full testing ### Description <!-- Describe your changes, the overall approach, the underlying design. These notes will help understanding how your code works. Thanks! --> ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄
1 parent d56ca29 commit 780aa63

File tree

4 files changed

+143
-7
lines changed

4 files changed

+143
-7
lines changed

python/semantic_kernel/prompt_template/jinja2_prompt_template.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
from pydantic import PrivateAttr, field_validator
1010

1111
from semantic_kernel.exceptions import Jinja2TemplateRenderException
12-
from semantic_kernel.exceptions.template_engine_exceptions import TemplateRenderException
1312
from semantic_kernel.functions.kernel_arguments import KernelArguments
1413
from semantic_kernel.prompt_template.const import JINJA2_TEMPLATE_FORMAT_NAME
1514
from semantic_kernel.prompt_template.prompt_template_base import PromptTemplateBase
1615
from semantic_kernel.prompt_template.prompt_template_config import PromptTemplateConfig
17-
from semantic_kernel.prompt_template.utils import JINJA2_SYSTEM_HELPERS, create_template_helper_from_function
16+
from semantic_kernel.prompt_template.utils import JINJA2_SYSTEM_HELPERS
17+
from semantic_kernel.prompt_template.utils.template_function_helpers import create_template_helper_from_function
1818

1919
if TYPE_CHECKING:
2020
from semantic_kernel.kernel import Kernel
@@ -63,7 +63,7 @@ def model_post_init(self, _: Any) -> None:
6363
if not self.prompt_template_config.template:
6464
self._env = None
6565
return
66-
self._env = ImmutableSandboxedEnvironment(loader=BaseLoader())
66+
self._env = ImmutableSandboxedEnvironment(loader=BaseLoader(), enable_async=True)
6767

6868
async def render(self, kernel: "Kernel", arguments: Optional["KernelArguments"] = None) -> str:
6969
"""Render the prompt template.
@@ -97,16 +97,16 @@ async def render(self, kernel: "Kernel", arguments: Optional["KernelArguments"]
9797
arguments,
9898
self.prompt_template_config.template_format,
9999
allow_unsafe_function_output,
100+
enable_async=True,
100101
)
101102
for function in plugin
102103
}
103104
)
105+
if self.prompt_template_config.template is None:
106+
raise Jinja2TemplateRenderException("Error rendering template, template is None")
104107
try:
105-
if self.prompt_template_config.template is None:
106-
raise TemplateRenderException("Template is None")
107108
template = self._env.from_string(self.prompt_template_config.template, globals=helpers)
108-
return template.render(**arguments)
109-
109+
return await template.render_async(**arguments)
110110
except TemplateError as exc:
111111
logger.error(
112112
f"Error rendering prompt template: {self.prompt_template_config.template} with arguments: {arguments}"

python/semantic_kernel/prompt_template/utils/template_function_helpers.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,50 @@ def create_template_helper_from_function(
2929
base_arguments: "KernelArguments",
3030
template_format: TEMPLATE_FORMAT_TYPES,
3131
allow_dangerously_set_content: bool = False,
32+
enable_async: bool = False,
33+
) -> Callable[..., Any]:
34+
"""Create a helper function for both the Handlebars and Jinja2 templating engines from a kernel function.
35+
36+
Args:
37+
function (KernelFunction): The kernel function to create a helper for.
38+
kernel (Kernel): The kernel to use for invoking the function.
39+
base_arguments (KernelArguments): The base arguments to use when invoking the function.
40+
template_format (TEMPLATE_FORMAT_TYPES): The template format to create the helper for.
41+
allow_dangerously_set_content (bool, optional): Return the content of the function result
42+
without encoding it or not.
43+
enable_async (bool, optional): Enable async helper function. Defaults to False.
44+
Currently only works for Jinja2 templates.
45+
46+
Returns:
47+
The function with args that are callable by the different templates.
48+
49+
Raises:
50+
ValueError: If the template format is not supported.
51+
52+
"""
53+
if enable_async:
54+
return _create_async_template_helper_from_function(
55+
function=function,
56+
kernel=kernel,
57+
base_arguments=base_arguments,
58+
template_format=template_format,
59+
allow_dangerously_set_content=allow_dangerously_set_content,
60+
)
61+
return _create_sync_template_helper_from_function(
62+
function=function,
63+
kernel=kernel,
64+
base_arguments=base_arguments,
65+
template_format=template_format,
66+
allow_dangerously_set_content=allow_dangerously_set_content,
67+
)
68+
69+
70+
def _create_sync_template_helper_from_function(
71+
function: "KernelFunction",
72+
kernel: "Kernel",
73+
base_arguments: "KernelArguments",
74+
template_format: TEMPLATE_FORMAT_TYPES,
75+
allow_dangerously_set_content: bool = False,
3276
) -> Callable[..., Any]:
3377
"""Create a helper function for both the Handlebars and Jinja2 templating engines from a kernel function."""
3478
if template_format not in [JINJA2_TEMPLATE_FORMAT_NAME, HANDLEBARS_TEMPLATE_FORMAT_NAME]:
@@ -67,3 +111,32 @@ def func(*args, **kwargs):
67111
return escape(str(result))
68112

69113
return func
114+
115+
116+
def _create_async_template_helper_from_function(
117+
function: "KernelFunction",
118+
kernel: "Kernel",
119+
base_arguments: "KernelArguments",
120+
template_format: TEMPLATE_FORMAT_TYPES,
121+
allow_dangerously_set_content: bool = False,
122+
) -> Callable[..., Any]:
123+
"""Create a async helper function for Jinja2 templating engines from a kernel function."""
124+
if template_format not in [JINJA2_TEMPLATE_FORMAT_NAME]:
125+
raise ValueError(f"Invalid template format: {template_format}")
126+
127+
async def func(*args, **kwargs):
128+
arguments = KernelArguments()
129+
if base_arguments and base_arguments.execution_settings:
130+
arguments.execution_settings = base_arguments.execution_settings # pragma: no cover
131+
arguments.update(base_arguments)
132+
arguments.update(kwargs)
133+
logger.debug(
134+
f"Invoking function {function.metadata.fully_qualified_name} "
135+
f"with args: {arguments} and kwargs: {kwargs}."
136+
)
137+
result = await function.invoke(kernel=kernel, arguments=arguments)
138+
if allow_dangerously_set_content:
139+
return result
140+
return escape(str(result))
141+
142+
return func

python/tests/unit/prompt_template/test_jinja2_prompt_template.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,15 @@ async def test_it_renders_fail(kernel: Kernel):
9292
await target.render(kernel, KernelArguments())
9393

9494

95+
@pytest.mark.asyncio
96+
async def test_it_renders_fail_empty_template(kernel: Kernel):
97+
template = "{{ plug-func 'test1'}}"
98+
target = create_jinja2_prompt_template(template)
99+
target.prompt_template_config.template = None
100+
with pytest.raises(Jinja2TemplateRenderException):
101+
await target.render(kernel, KernelArguments())
102+
103+
95104
@pytest.mark.asyncio
96105
async def test_it_renders_list(kernel: Kernel):
97106
template = "List: {% for item in items %}{{ item }}{% endfor %}"
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
3+
import pytest
4+
5+
from semantic_kernel.functions import kernel_function
6+
from semantic_kernel.functions.kernel_function_from_method import KernelFunctionFromMethod
7+
from semantic_kernel.kernel import Kernel
8+
from semantic_kernel.prompt_template.const import JINJA2_TEMPLATE_FORMAT_NAME
9+
from semantic_kernel.prompt_template.utils.template_function_helpers import create_template_helper_from_function
10+
11+
12+
def test_create_helpers(kernel: Kernel):
13+
# Arrange
14+
function = KernelFunctionFromMethod(kernel_function(lambda x: x + 1, name="test"), plugin_name="test")
15+
base_arguments = {}
16+
template_format = JINJA2_TEMPLATE_FORMAT_NAME
17+
allow_dangerously_set_content = False
18+
enable_async = False
19+
20+
# Act
21+
result = create_template_helper_from_function(
22+
function, kernel, base_arguments, template_format, allow_dangerously_set_content, enable_async
23+
)
24+
25+
# Assert
26+
assert int(str(result(x=1))) == 2
27+
28+
29+
@pytest.mark.parametrize(
30+
"template_format, enable_async, exception",
31+
[
32+
("jinja2", True, False),
33+
("jinja2", False, False),
34+
("handlebars", True, True),
35+
("handlebars", False, False),
36+
("semantic-kernel", False, True),
37+
("semantic-kernel", True, True),
38+
],
39+
)
40+
@pytest.mark.asyncio
41+
async def test_create_helpers_fail(kernel: Kernel, template_format: str, enable_async: bool, exception: bool):
42+
# Arrange
43+
function = KernelFunctionFromMethod(kernel_function(lambda x: x + 1, name="test"), plugin_name="test")
44+
45+
if exception:
46+
with pytest.raises(ValueError):
47+
create_template_helper_from_function(function, kernel, {}, template_format, False, enable_async)
48+
return
49+
result = create_template_helper_from_function(function, kernel, {}, template_format, False, enable_async)
50+
if enable_async:
51+
res = await result(x=1)
52+
assert int(str(res)) == 2
53+
else:
54+
assert int(str(result(x=1))) == 2

0 commit comments

Comments
 (0)