Commit
fixed tests and some samples
eavanvalkenburg committed Apr 23, 2024
1 parent 6e82c66 commit 189856d
Showing 14 changed files with 610 additions and 220 deletions.
@@ -9,10 +9,8 @@
     AzureChatCompletion,
     AzureChatPromptExecutionSettings,
     ExtraBody,
-    FunctionCallContent,
-    ToolCall,
 )
-from semantic_kernel.contents import AuthorRole, ChatHistory
+from semantic_kernel.contents import ChatHistory
 from semantic_kernel.functions import KernelArguments
 from semantic_kernel.prompt_template import InputVariable, PromptTemplateConfig
 from semantic_kernel.utils.settings import (
@@ -21,7 +19,7 @@
 )
 
 kernel = Kernel()
-logging.basicConfig(level=logging.DEBUG)
+logging.basicConfig(level=logging.INFO)
 
 # Load Azure OpenAI Settings
 aoai_settings = azure_openai_settings_from_dot_env_as_dict(include_api_version=True)
@@ -36,12 +34,12 @@
 
 # Our example index has fields "source_title", "source_text", "source_url", and "source_file".
 # Add fields mapping to the settings to indicate which fields to use for the title, content, URL, and file path.
-azure_ai_search_settings["fieldsMapping"] = {
-    "titleField": "source_title",
-    "urlField": "source_url",
-    "contentFields": ["source_text"],
-    "filepathField": "source_file",
-}
+# azure_ai_search_settings["fieldsMapping"] = {
+#     "titleField": "source_title",
+#     "urlField": "source_url",
+#     "contentFields": ["source_text"],
+#     "filepathField": "source_file",
+# }
 
 # Create the data source settings

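For context, a minimal sketch of how these settings typically feed the request in this sample; AzureAISearchDataSource and the exact parameter shape are assumptions not visible in this diff, so treat it as illustrative only:

az_source = AzureAISearchDataSource(parameters=azure_ai_search_settings)  # assumed class from the same connector package
req_settings = AzureChatPromptExecutionSettings(
    service_id="chat-gpt",  # hypothetical service id
    extra_body=ExtraBody(data_sources=[az_source]),  # the "On Your Data" payload
)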
@@ -87,12 +85,18 @@ async def chat() -> bool:
     if user_input == "exit":
         print("\n\nExiting chat...")
         return False
 
-    # Non streaming
-    # answer = await kernel.run(chat_function, input_vars=context_vars)
-    # print(f"Assistant:> {answer}")
     arguments = KernelArguments(chat_history=chat_history, user_input=user_input, execution_settings=req_settings)
 
+    stream = False
+    if not stream:
+        # Non streaming
+        answer = await kernel.invoke(chat_function, arguments=arguments)
+        print(f"Assistant:> {answer}")
+        chat_history.add_user_message(user_input)
+        for message in AzureChatCompletion.split_message(answer.value[0]):
+            chat_history.add_message(message)
+        return True
+
     full_message = None
     print("Assistant:> ", end="")
     async for message in kernel.invoke_stream(chat_function, arguments=arguments):
@@ -101,24 +105,9 @@
     print("\n")
 
     # The tool message containing cited sources is available in the context
-    if full_message:
-        chat_history.add_user_message(user_input)
-        if hasattr(full_message, "tool_message"):
-            chat_history.add_message(
-                ChatMessageContent(
-                    role="assistant",
-                    tool_calls=[
-                        ToolCall(
-                            id="chat_with_your_data",
-                            function=FunctionCallContent(name="chat_with_your_data", arguments=""),
-                        )
-                    ],
-                )
-            )
-        chat_history.add_tool_message(full_message.tool_message, {"tool_call_id": "chat_with_your_data"})
-        if full_message.role is None:
-            full_message.role = AuthorRole.ASSISTANT
-        chat_history.add_assistant_message(full_message.content)
+    chat_history.add_user_message(user_input)
+    for message in AzureChatCompletion.split_message(full_message):
+        chat_history.add_message(message)
     return True
 
@@ -8,10 +8,8 @@
     AzureChatCompletion,
     AzureChatPromptExecutionSettings,
     ExtraBody,
-    FunctionCallContent,
-    ToolCall,
 )
-from semantic_kernel.contents import AuthorRole, ChatHistory
+from semantic_kernel.contents import ChatHistory
 from semantic_kernel.functions import KernelArguments
 from semantic_kernel.kernel import Kernel
 from semantic_kernel.prompt_template import InputVariable, PromptTemplateConfig
@@ -113,22 +111,8 @@ async def chat() -> bool:
     # The tool message containing cited sources is available in the context
     if full_message:
         chat_history.add_user_message(user_input)
-        if hasattr(full_message, "tool_message"):
-            chat_history.add_message(
-                ChatMessageContent(
-                    role="assistant",
-                    tool_calls=[
-                        ToolCall(
-                            id="chat_with_your_data",
-                            function=FunctionCallContent(name="chat_with_your_data", arguments=""),
-                        )
-                    ],
-                )
-            )
-        chat_history.add_tool_message(full_message.tool_message, {"tool_call_id": "chat_with_your_data"})
-        if full_message.role is None:
-            full_message.role = AuthorRole.ASSISTANT
-        chat_history.add_assistant_message(full_message.content)
+        for message in AzureChatCompletion.split_message(full_message):
+            chat_history.add_message(message)
     return True
 
@@ -6,7 +6,7 @@
 
 from semantic_kernel import Kernel
 from semantic_kernel.connectors.ai.open_ai import AzureTextCompletion, AzureTextEmbedding
-from semantic_kernel.connectors.memory import AzureCognitiveSearchMemoryStore
+from semantic_kernel.connectors.memory.azure_cognitive_search import AzureCognitiveSearchMemoryStore
 from semantic_kernel.core_plugins import TextMemoryPlugin
 from semantic_kernel.memory import SemanticTextMemory
 
@@ -62,7 +62,7 @@ async def main() -> None:
             api_key=AZURE_OPENAI_API_KEY,
         ),
     )
-    embedding_service_id = ("ada",)
+    embedding_service_id = "ada"
    embedding_gen = AzureTextEmbedding(
        service_id=embedding_service_id,
        deployment_name="text-embedding-ada-002",
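The fix above removes an accidental one-element tuple: in Python, a parenthesized value with a trailing comma is a tuple, so the old line produced a tuple where a plain string service id was expected.

embedding_service_id = ("ada",)  # tuple -- wrong, breaks lookup of the service by id
embedding_service_id = "ada"     # str -- what AzureTextEmbedding(service_id=...) expects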
@@ -81,10 +81,10 @@ async def main() -> None:
     kernel.add_plugin(TextMemoryPlugin(memory), "TextMemoryPlugin")
 
     print("Populating memory...")
-    await populate_memory(kernel)
+    await populate_memory(memory)
 
     print("Asking questions... (manually)")
-    await search_acs_memory_questions(kernel)
+    await search_acs_memory_questions(memory)
 
     await acs_connector.close()
 
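The helpers themselves are not shown in this diff, but the corrected call sites imply they take the SemanticTextMemory rather than the Kernel. A hypothetical signature that matches, using the standard SemanticTextMemory.save_information API:

async def populate_memory(memory: SemanticTextMemory) -> None:
    # Illustrative record only; the sample's real data is not shown here.
    await memory.save_information(collection="generic", id="info1", text="My budget for 2024 is $100,000")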
@@ -6,12 +6,7 @@
 from typing import TYPE_CHECKING, List
 
 from semantic_kernel import Kernel
-from semantic_kernel.connectors.ai.open_ai import (  # OpenAIChatMessageContent,
-    OpenAIChatCompletion,
-    OpenAIChatPromptExecutionSettings,
-)
-
-# OpenAIStreamingChatMessageContent,
+from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion, OpenAIChatPromptExecutionSettings
 from semantic_kernel.connectors.ai.open_ai.utils import get_tool_call_object
 from semantic_kernel.contents import ChatHistory
 from semantic_kernel.contents.chat_message_content import ChatMessageContent
@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft. All rights reserved.
 import json
 import logging
+from copy import deepcopy
 from typing import Any, Dict, Mapping, Optional, Union, overload
 from uuid import uuid4
 
@@ -20,9 +21,11 @@
 from semantic_kernel.connectors.ai.open_ai.services.open_ai_text_completion_base import OpenAITextCompletionBase
 from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
 from semantic_kernel.contents.chat_message_content import ChatMessageContent
+from semantic_kernel.contents.finish_reason import FinishReason
 from semantic_kernel.contents.function_call_content import FunctionCallContent
 from semantic_kernel.contents.function_result_content import FunctionResultContent
 from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
+from semantic_kernel.contents.text_content import TextContent
 from semantic_kernel.kernel_pydantic import HttpsUrl
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -265,9 +268,19 @@ def _add_tool_message_to_chat_message_content(
         self, content: ChatMessageContent | StreamingChatMessageContent, choice: Choice
     ) -> "ChatMessageContent | StreamingChatMessageContent":
         if tool_message := self._get_tool_message_from_chat_choice(choice=choice):
-            function_call = FunctionCallContent(id=uuid4(), name="Azure-OnYourData")
+            try:
+                tool_message_dict = json.loads(tool_message)
+            except json.JSONDecodeError:
+                logger.error("Failed to parse tool message JSON: %s", tool_message)
+                tool_message_dict = {"citations": tool_message}
+
+            function_call = FunctionCallContent(
+                id=str(uuid4()),
+                name="Azure-OnYourData",
+                arguments=json.dumps({"query": tool_message_dict.get("intent", [])}),
+            )
             result = FunctionResultContent.from_function_call_content_and_result(
-                result=tool_message, function_call_content=function_call
+                result=tool_message_dict["citations"], function_call_content=function_call
             )
             content.items.insert(0, function_call)
             content.items.insert(1, result)
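For reference, a hedged example of the tool-message payload this code parses; the real shape comes from the model's model_extra["context"] and may vary by service version, but "citations" and "intent" are the keys the code above relies on:

tool_message = json.dumps({  # illustrative payload only
    "citations": [{"title": "Doc 1", "url": "https://example.com/doc1", "content": "..."}],
    "intent": ["what is my budget?"],
})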
@@ -283,3 +296,26 @@ def _get_tool_message_from_chat_choice(self, choice: Union[Choice, ChunkChoice])
             return json.dumps(content.model_extra["context"])
 
         return None
+
+    @staticmethod
+    def split_message(message: "ChatMessageContent") -> list["ChatMessageContent"]:
+        """Split an Azure On Your Data response into separate ChatMessageContents.
+
+        If the message does not contain exactly three items, one each of
+        FunctionCallContent, FunctionResultContent, and TextContent, it is
+        returned unchanged as a single-element list. Otherwise it is split
+        into three messages, in the order expected by OpenAI.
+        """
+        if len(message.items) != 3:
+            return [message]
+        messages = {"tool_call": deepcopy(message), "tool_result": deepcopy(message), "assistant": deepcopy(message)}
+        for key, msg in messages.items():
+            if key == "tool_call":
+                msg.items = [item for item in msg.items if isinstance(item, FunctionCallContent)]
+                msg.finish_reason = FinishReason.FUNCTION_CALL
+            if key == "tool_result":
+                msg.items = [item for item in msg.items if isinstance(item, FunctionResultContent)]
+            if key == "assistant":
+                msg.items = [item for item in msg.items if isinstance(item, TextContent)]
+        return [messages["tool_call"], messages["tool_result"], messages["assistant"]]
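A minimal usage sketch of the new helper, mirroring how the updated samples above consume it (response is a hypothetical ChatMessageContent carrying the three On Your Data items):

for msg in AzureChatCompletion.split_message(response):
    chat_history.add_message(msg)
# chat_history now ends with three messages: the synthetic "Azure-OnYourData"
# tool call, its result (the citations), and the assistant's text.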
@@ -24,6 +24,7 @@
 from semantic_kernel.contents.function_call_content import FunctionCallContent
 from semantic_kernel.contents.function_result_content import FunctionResultContent
 from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
+from semantic_kernel.contents.streaming_text_content import StreamingTextContent
 from semantic_kernel.contents.text_content import TextContent
 from semantic_kernel.exceptions import (
     FunctionCallInvalidArgumentsException,
@@ -251,7 +252,7 @@ def _create_streaming_chat_message_content(
         items: list[Any] = self._get_tool_calls_from_chat_choice(choice)
         items.extend(self._get_function_call_from_chat_choice(choice))
         if choice.delta.content is not None:
-            items.append(TextContent(text=choice.delta.content))
+            items.append(StreamingTextContent(choice_index=choice.index, text=choice.delta.content))
         return StreamingChatMessageContent(
             choice_index=choice.index,
             inner_content=chunk,
48 changes: 27 additions & 21 deletions python/semantic_kernel/contents/chat_message_content.py
@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft. All rights reserved.
 from __future__ import annotations
 
+import logging
 from enum import Enum
 from typing import Any, Union, overload
 from xml.etree.ElementTree import Element
@@ -19,16 +20,18 @@
 from semantic_kernel.contents.function_call_content import FunctionCallContent
 from semantic_kernel.contents.function_result_content import FunctionResultContent
 from semantic_kernel.contents.kernel_content import KernelContent
+from semantic_kernel.contents.streaming_text_content import StreamingTextContent
 from semantic_kernel.contents.text_content import TextContent
 from semantic_kernel.kernel_pydantic import KernelBaseModel
 
 TAG_CONTENT_MAP = {
     TEXT_CONTENT_TAG: TextContent,
     FUNCTION_CALL_CONTENT_TAG: FunctionCallContent,
     FUNCTION_RESULT_CONTENT_TAG: FunctionResultContent,
 }
 
-ITEM_TYPES = Union[TextContent, FunctionCallContent, FunctionResultContent]
+ITEM_TYPES = Union[TextContent, StreamingTextContent, FunctionCallContent, FunctionResultContent]
 
+logger = logging.getLogger(__name__)
 
 
 class ChatMessageContent(KernelContent):
@@ -145,8 +148,6 @@ def __init__(  # type: ignore
             items = [item]
         if items:
             kwargs["items"] = items
-        if not items and "finish_reason" not in kwargs:
-            raise ValueError("ChatMessageContent must have either items or content.")
         if inner_content:
             kwargs["inner_content"] = inner_content
         if metadata:
@@ -159,7 +160,7 @@
 
     @property
     def content(self) -> str:
-        """Get the content of the response."""
+        """Get the content of the response; returns the text of the first TextContent item."""
         for item in self.items:
             if isinstance(item, TextContent):
                 return item.text
@@ -169,6 +170,10 @@ def content(self, value: str):
     def content(self, value: str):
         """Set the content of the response."""
         if not value:
+            logger.warning(
+                "Setting empty content on ChatMessageContent does not work; "
+                "set it through the underlying items if needed. Ignoring."
+            )
             return
         for item in self.items:
             if isinstance(item, TextContent):
@@ -199,20 +204,11 @@ def to_element(self) -> "Element":
         """
         root = Element(CHAT_MESSAGE_CONTENT_TAG)
         for field in self.model_fields_set:
-            if field in ["items", "metadata", "inner_content"]:
+            if field not in ["role", "name", "encoding", "finish_reason", "ai_model_id"]:
                 continue
             value = getattr(self, field)
-            if value is None:
-                continue
             if isinstance(value, Enum):
                 value = value.value
-            if isinstance(value, KernelBaseModel):
-                value = value.model_dump_json(exclude_none=True)
-            if isinstance(value, list):
-                if isinstance(value[0], KernelBaseModel):
-                    value = "|".join([val.model_dump_json(exclude_none=True) for val in value])
-                else:
-                    value = "|".join(value)
             root.set(field, value)
         for index, item in enumerate(self.items):
             root.insert(index, item.to_element())
@@ -228,18 +224,26 @@ def from_element(cls, element: Element) -> "ChatMessageContent":
         Returns:
             ChatMessageContent - The new instance of ChatMessageContent or a subclass.
         """
+        kwargs: dict[str, Any] = {key: value for key, value in element.items()}
         items: list[KernelContent] = []
         for child in element:
-            items.append(TAG_CONTENT_MAP[child.tag].from_element(child))  # type: ignore
-        kwargs: dict[str, Any] = {}
+            if child.tag not in TAG_CONTENT_MAP:
+                logger.warning('Unknown tag "%s" in ChatMessageContent, treating as text', child.tag)
+                text = ElementTree.tostring(child, encoding="unicode", short_empty_elements=False)
+                items.append(TextContent(text=text or ""))
+            else:
+                items.append(TAG_CONTENT_MAP[child.tag].from_element(child))  # type: ignore
         if items:
             kwargs["items"] = items
         if element.text:
             kwargs["content"] = element.text
-        if not kwargs:
-            raise ValueError("ChatMessageContent must have either items or content.")
-        for key, value in element.items():
-            kwargs[key] = value
+        if "choice_index" in kwargs and cls is ChatMessageContent:
+            logger.warning(
+                "Seems like you are trying to create a StreamingChatMessageContent, "
+                "use StreamingChatMessageContent.from_element instead; ignoring that field "
+                "and creating a ChatMessageContent instance."
+            )
+            kwargs.pop("choice_index")
         return cls(**kwargs)
 
     def to_prompt(self) -> str:
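A small round-trip sketch of the two methods touched here (hypothetical values; assumes the plain TextContent path):

msg = ChatMessageContent(role="user", content="Hello")
elem = msg.to_element()                        # serialize fields and items to XML
restored = ChatMessageContent.from_element(elem)
assert restored.content == "Hello"             # round-trips via the element tree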
@@ -280,4 +284,6 @@ def _parse_items(self) -> str | list[dict[str, Any]]:
         """
         if len(self.items) == 1 and isinstance(self.items[0], TextContent):
             return self.items[0].text
+        if len(self.items) == 1 and isinstance(self.items[0], FunctionResultContent):
+            return self.items[0].result
         return [item.to_dict() for item in self.items]
