Python: rebuilt xml creation and parsing (#5550)
### Motivation and Context

We received a report that an empty chat_history placed between two other
params did not work, so the way the chat_history XML is created had to
change; the parsing was subsequently improved as well and now relies
fully on XML.
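
For illustration, a minimal sketch of the reported scenario and how the reworked code handles it (the prompt text and the f-string rendering step are illustrative, not taken from the report):

```python
from semantic_kernel.contents.chat_history import ChatHistory

chat_history = ChatHistory()  # empty history, rendered between two other params
rendered = f"You are a helpful assistant.{chat_history}Summarize the input."
# str(chat_history) now renders as "<chat_history />", so the prompt becomes:
# 'You are a helpful assistant.<chat_history />Summarize the input.'

parsed = ChatHistory.from_rendered_prompt(rendered)
# The XML parser skips the empty <chat_history /> element, yielding:
#   SYSTEM: "You are a helpful assistant."
#   USER:   "Summarize the input."
```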

### Description

- Added `to_element` methods to the CMC (ChatMessageContent) classes; `to_prompt` now builds on them.
- Reworked the ChatHistory `from_rendered_prompt` method to fully utilize XML parsing.
- Added `chat_history` tags so that an empty chat history in a prompt can be detected.
- Made sure that when only one message ends up in the ChatHistory after rendering, it is a USER message instead of SYSTEM; with multiple messages, the first one is still SYSTEM (see the sketch after this list).
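
A short sketch of that last rule, following the new `from_rendered_prompt` code in this PR (the prompt strings are illustrative):

```python
from semantic_kernel.contents.chat_history import ChatHistory

# A single plain-text message after rendering becomes USER, not SYSTEM.
single = ChatHistory.from_rendered_prompt("What is the weather today?")
assert single.messages[0].role.value == "user"

# With multiple messages, the leading plain text is still SYSTEM.
multiple = ChatHistory.from_rendered_prompt(
    'You are a weather bot.<message role="user">What is the weather today?</message>'
)
assert multiple.messages[0].role.value == "system"
assert multiple.messages[1].role.value == "user"
```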

### Contribution Checklist


- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
eavanvalkenburg committed Mar 19, 2024
1 parent 1cf984c commit ccbeb7b
Showing 5 changed files with 105 additions and 93 deletions.
@@ -2,7 +2,6 @@
from typing import List, Optional
from xml.etree.ElementTree import Element

from defusedxml import ElementTree
from openai.types.chat import ChatCompletion

from semantic_kernel.connectors.ai.open_ai.contents.function_call import FunctionCall
@@ -40,7 +39,7 @@ def ToolIdProperty():
# Directly using the class name and the attribute name as strings
return f"{ToolCall.__name__}.{ToolCall.id.__name__}"

def to_prompt(self, root_key: str) -> str:
def to_element(self, root_key: str) -> Element:
"""Convert the OpenAIChatMessageContent to a prompt.
Returns:
@@ -56,7 +55,7 @@ def to_prompt(self, root_key: str) -> str:
if self.tool_call_id:
root.set("tool_call_id", self.tool_call_id)
root.text = self.content or ""
return ElementTree.tostring(root, encoding=self.encoding or "unicode", short_empty_elements=False)
return root

@classmethod
def from_element(cls, element: Element) -> "ChatMessageContent":
77 changes: 33 additions & 44 deletions python/semantic_kernel/contents/chat_history.py
@@ -1,7 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.

import logging
from typing import Any, Dict, Final, Iterator, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Final, Iterator, List, Optional, Type, Union
from xml.etree import ElementTree
from xml.etree.ElementTree import Element

import defusedxml.ElementTree as ET

@@ -14,9 +16,7 @@
logger = logging.getLogger(__name__)

ROOT_KEY_MESSAGE: Final[str] = "message"
START_TAG: Final[str] = f"<{ROOT_KEY_MESSAGE}"
END_TAG: Final[str] = f"</{ROOT_KEY_MESSAGE}>"
LEN_END_TAG: Final[int] = len(END_TAG)
ROOT_KEY_HISTORY: Final[str] = "chat_history"


class ChatHistory(KernelBaseModel):
@@ -159,15 +159,16 @@ def __contains__(self, item: ChatMessageContent) -> bool:

def __str__(self) -> str:
"""Return a string representation of the history."""
if not self.messages:
return ""
return "\n".join([msg.to_prompt(root_key=ROOT_KEY_MESSAGE) for msg in self.messages])
chat_history_xml = Element(ROOT_KEY_HISTORY)
for message in self.messages:
chat_history_xml.append(message.to_element(root_key=ROOT_KEY_MESSAGE))
return ElementTree.tostring(chat_history_xml, encoding="unicode", short_empty_elements=True)

def __iter__(self) -> Iterator[ChatMessageContent]:
"""Return an iterator over the messages in the history."""
return iter(self.messages)

def __eq__(self, other: "ChatHistory") -> bool:
def __eq__(self, other: Any) -> bool:
"""Check if two ChatHistory instances are equal."""
if not isinstance(other, ChatHistory):
return False
@@ -188,38 +189,25 @@ def from_rendered_prompt(
ChatHistory: The ChatHistory instance created from the rendered prompt.
"""
messages: List[chat_message_content_type] = []
result, remainder = cls._render_remaining(rendered_prompt, chat_message_content_type, True)
if result:
messages.append(result)
while remainder:
result, remainder = cls._render_remaining(remainder, chat_message_content_type)
if result:
messages.append(result)
return cls(messages=messages)

@staticmethod
def _render_remaining(
prompt: Optional[str],
chat_message_content_type: Type[ChatMessageContent] = ChatMessageContent,
first: bool = False,
) -> Tuple[Optional[ChatMessageContent], Optional[str]]:
"""Render the remaining messages in the history."""
if not prompt:
return None, None
prompt = prompt.strip()
start = prompt.find(START_TAG)
end = prompt.find(END_TAG)
role = ChatRole.SYSTEM if first else ChatRole.USER
if start == -1 or end == -1:
return chat_message_content_type(role=role, content=prompt), None
if start > 0 and end > 0:
return chat_message_content_type(role=role, content=prompt[:start]), prompt[start:]
end_of_tag = end + LEN_END_TAG
prompt = rendered_prompt.strip()
try:
return chat_message_content_type.from_element(ET.fromstring(prompt[start:end_of_tag])), prompt[end_of_tag:]
except Exception as exc:
logger.warning(f"Unable to parse prompt: {prompt[start:end_of_tag]}, returning as content", exc_info=exc)
return chat_message_content_type(role=role, content=prompt[start:end_of_tag]), prompt[end_of_tag:]
xml_prompt = ET.fromstring(f"<prompt>{prompt}</prompt>")
except ET.ParseError as e:
logger.error(f"Error parsing XML of prompt: {e}")
return cls(messages=[chat_message_content_type(role=ChatRole.USER, content=prompt)])
if xml_prompt.text and xml_prompt.text.strip():
messages.append(chat_message_content_type(role=ChatRole.SYSTEM, content=xml_prompt.text.strip()))
for item in xml_prompt:
if item.tag == ROOT_KEY_MESSAGE:
messages.append(chat_message_content_type.from_element(item))
elif item.tag == ROOT_KEY_HISTORY:
for message in item:
messages.append(chat_message_content_type.from_element(message))
if item.tail and item.tail.strip():
messages.append(chat_message_content_type(role=ChatRole.USER, content=item.tail.strip()))
if len(messages) == 1 and messages[0].role == ChatRole.SYSTEM:
messages[0].role = ChatRole.USER
return cls(messages=messages)

def serialize(self) -> str:
"""
@@ -234,7 +222,7 @@ def serialize(self) -> str:
try:
return self.model_dump_json(indent=4)
except Exception as e:
raise ContentSerializationError(f"Unable to serialize ChatHistory to JSON: {e}")
raise ContentSerializationError(f"Unable to serialize ChatHistory to JSON: {e}") from e

@classmethod
def restore_chat_history(cls, chat_history_json: str) -> "ChatHistory":
@@ -257,19 +245,20 @@ def restore_chat_history(cls, chat_history_json: str) -> "ChatHistory":
except Exception as e:
raise ContentInitializationError(f"Invalid JSON format: {e}")

def store_chat_history_to_file(chat_history: "ChatHistory", file_path: str) -> None:
def store_chat_history_to_file(self, file_path: str) -> None:
"""
Stores the serialized ChatHistory to a file.
Args:
chat_history (ChatHistory): The ChatHistory instance to serialize and store.
file_path (str): The path to the file where the serialized data will be stored.
"""
json_str = chat_history.serialize()
json_str = self.serialize()
with open(file_path, "w") as file:
file.write(json_str)

def load_chat_history_from_file(file_path: str) -> "ChatHistory":
@classmethod
def load_chat_history_from_file(cls, file_path: str) -> "ChatHistory":
"""
Loads the ChatHistory from a file.
@@ -281,4 +270,4 @@ def load_chat_history_from_file(file_path: str) -> "ChatHistory":
"""
with open(file_path, "r") as file:
json_str = file.read()
return ChatHistory.restore_chat_history(json_str)
return cls.restore_chat_history(json_str)
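
As a reading aid (not part of the diff), a small round-trip sketch of how the new `__str__` and `from_rendered_prompt` fit together, based on the code above:

```python
from semantic_kernel.contents.chat_history import ChatHistory

history = ChatHistory.from_rendered_prompt(
    'Be brief.<message role="user">Hi</message><message role="assistant">Hello!</message>'
)
# history now holds SYSTEM("Be brief."), USER("Hi") and ASSISTANT("Hello!").

print(str(history))
# Expected shape (modulo whitespace), wrapped in the new <chat_history> root:
# <chat_history><message role="system">Be brief.</message><message role="user">Hi</message><message role="assistant">Hello!</message></chat_history>
```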
19 changes: 15 additions & 4 deletions python/semantic_kernel/contents/chat_message_content.py
@@ -36,17 +36,28 @@ class ChatMessageContent(KernelContent):
def __str__(self) -> str:
return self.content or ""

def to_element(self, root_key: str) -> Element:
"""Convert the ChatMessageContent to an XML Element.
Args:
root_key: str - The key to use for the root of the XML Element.
Returns:
Element - The XML Element representing the ChatMessageContent.
"""
root = Element(root_key)
root.set("role", self.role.value)
root.text = self.content or ""
return root

def to_prompt(self, root_key: str) -> str:
"""Convert the ChatMessageContent to a prompt.
Returns:
str - The prompt from the ChatMessageContent.
"""

root = Element(root_key)
root.set("role", self.role.value)
root.set("metadata", json.dumps(self.metadata))
root.text = self.content or ""
root = self.to_element(root_key)
return ElementTree.tostring(root, encoding=self.encoding or "unicode", short_empty_elements=False)

@classmethod
@@ -11,6 +11,7 @@
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.core_plugins.math_plugin import MathPlugin
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.kernel import Kernel
from semantic_kernel.prompt_template.prompt_template_config import PromptTemplateConfig


@@ -176,9 +177,7 @@ async def test_azure_oai_chat_service_with_tool_call(setup_tldr_function_for_oai


@pytest.mark.asyncio
async def test_azure_oai_chat_service_with_tool_call_streaming(setup_tldr_function_for_oai_models, get_aoai_config):
kernel, _, _ = setup_tldr_function_for_oai_models

async def test_azure_oai_chat_service_with_tool_call_streaming(kernel: Kernel, get_aoai_config):
_, api_key, endpoint = get_aoai_config

if "Python_Integration_Tests" in os.environ:
@@ -207,7 +206,7 @@ async def test_azure_oai_chat_service_with_tool_call_streaming(setup_tldr_functi
),
)

kernel.import_plugin_from_object(MathPlugin(), plugin_name="math")
kernel.import_plugin_from_object(MathPlugin(), plugin_name="Math")

# Create the prompt function
chat_func = kernel.create_function_from_prompt(prompt="{{$input}}", function_name="chat", plugin_name="chat")
@@ -221,7 +220,7 @@ async def test_azure_oai_chat_service_with_tool_call_streaming(setup_tldr_functi
auto_invoke_kernel_functions=True,
max_auto_invoke_attempts=3,
)
arguments = KernelArguments(input="what is 1+1?", settings=execution_settings)
arguments = KernelArguments(input="what is 101+102?", settings=execution_settings)

result = None
async for message in kernel.invoke_stream(chat_func, arguments=arguments):