-
Notifications
You must be signed in to change notification settings - Fork 2.8k
/
jinja2_system_helpers.py
95 lines (68 loc) · 2.42 KB
/
jinja2_system_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Copyright (c) Microsoft. All rights reserved.
import logging
import re
from enum import Enum
from typing import Callable, Dict
# Module-scoped logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
def _messages(chat_history):
    """Render a chat history as text for use inside a template.

    Args:
        chat_history: Value passed from the template.

    Returns:
        ``str(chat_history)`` when it is a ``ChatHistory`` instance, otherwise
        an empty string.
    """
    # Imported lazily, presumably to avoid an import cycle / import cost at module load.
    from semantic_kernel.contents.chat_history import ChatHistory

    return str(chat_history) if isinstance(chat_history, ChatHistory) else ""
def _message_to_prompt(context):
    """Convert a chat message into its prompt representation.

    Args:
        context: Value passed from the template.

    Returns:
        ``str(context.to_prompt())`` for a ``ChatMessageContent``; ``str(context)``
        for anything else.
    """
    from semantic_kernel.contents.chat_message_content import ChatMessageContent

    rendered = context.to_prompt() if isinstance(context, ChatMessageContent) else context
    return str(rendered)
def _message(item):
    """Wrap a message's content in a chat-message tag that carries its role.

    Produces ``<TAG role="ROLE">CONTENT</TAG>``, where TAG is
    ``CHAT_MESSAGE_CONTENT_TAG``. An ``Enum`` role is replaced by its ``.value``.

    Args:
        item: Object exposing ``role`` and ``content`` attributes.

    Returns:
        The tagged string.

    Note:
        ``role`` and ``content`` are interpolated verbatim (no escaping).
    """
    from semantic_kernel.contents.const import CHAT_MESSAGE_CONTENT_TAG

    # Read each attribute exactly once, role first, then content.
    role, content = item.role, item.content
    if isinstance(role, Enum):
        role = role.value
    tag = CHAT_MESSAGE_CONTENT_TAG
    return f'<{tag} role="{role}">{content}</{tag}>'
def _safe_get_wrapper(context=None, name=None, default=""):
    """Template-facing ``get`` helper that tolerates missing arguments.

    Wraps ``_get`` so that calling it without a context or a name does not fail.

    Args:
        context: Mapping-like object to read from; may be omitted.
        name: Key to look up; may be omitted.
        default: Value returned when an argument is missing or the key is absent.

    Returns:
        ``default`` if ``context`` or ``name`` is ``None``; otherwise the result
        of ``_get(context, name, default)``.
    """
    missing_argument = context is None or name is None
    return default if missing_argument else _get(context, name, default)
def _get(context, name, default=""):
"""Retrieves a value from the context, with a default if not found."""
return context.get(name, default)
def _double_open():
"""Returns the string representing double open braces."""
return "{{"
def _double_close():
"""Returns the string representing double close braces."""
return "}}"
def _array(*args, **kwargs):
print(f"Received args: {args}")
return list(args)
def _camel_case(*args, **kwargs):
return "".join([word.capitalize() for word in args[0].split("_")])
def _snake_case(*args, **kwargs):
arg = args[0]
arg = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", arg)
arg = re.sub(r"([a-z\d])([A-Z])", r"\1_\2", arg)
arg = arg.replace("-", "_")
return arg.lower()
# Registry of helper name -> callable for Jinja2 prompt templates (presumably
# installed as template globals by the Jinja2 template engine; confirm at caller).
# Several helpers are registered under both a snake_case and a camelCase key.
# Keyword-argument order is preserved, so iteration order matches the listing below.
JINJA2_SYSTEM_HELPERS: Dict[str, Callable] = dict(
    get=_safe_get_wrapper,
    double_open=_double_open,
    doubleOpen=_double_open,
    double_close=_double_close,
    doubleClose=_double_close,
    message=_message,
    message_to_prompt=_message_to_prompt,
    messages=_messages,
    messageToPrompt=_message_to_prompt,
    array=_array,
    camel_case=_camel_case,
    camelCase=_camel_case,
    snake_case=_snake_case,
    snakeCase=_snake_case,
)