Skip to content

Commit

Permalink
Python: updated Chat Message Content with Function Call and Result Co…
Browse files Browse the repository at this point in the history
…ntent (#5946)

### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->
Introducing FunctionCallContent and FunctionResultContent
Changed ChatRole to AuthorRole
Adapting ChatMessageContent to have 1 or more other contents within
(currently TextContent or one of the above)
Closes #5890 

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->
Changed OpenAI classes, to remove OpenAIChatMessageContent and
AzureChatMessageContent
Changed OpenAI classes to create and parse FunctionCallContent and other
new things.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
  • Loading branch information
eavanvalkenburg committed Apr 25, 2024
1 parent e0be616 commit 65e94b3
Show file tree
Hide file tree
Showing 53 changed files with 1,822 additions and 1,357 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
- id: black
files: \.py$
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.7
rev: v0.4.1
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix ]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@
from semantic_kernel.connectors.ai.open_ai import (
AzureAISearchDataSource,
AzureChatCompletion,
AzureChatMessageContent,
AzureChatPromptExecutionSettings,
ExtraBody,
FunctionCall,
ToolCall,
)
from semantic_kernel.contents import ChatHistory, ChatRole
from semantic_kernel.contents import ChatHistory
from semantic_kernel.functions import KernelArguments
from semantic_kernel.prompt_template import InputVariable, PromptTemplateConfig
from semantic_kernel.utils.settings import (
Expand All @@ -22,7 +19,7 @@
)

kernel = Kernel()
logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.INFO)

# Load Azure OpenAI Settings
aoai_settings = azure_openai_settings_from_dot_env_as_dict(include_api_version=True)
Expand All @@ -35,14 +32,15 @@

azure_ai_search_settings = azure_aisearch_settings_from_dot_env_as_dict()

# Our example index has fields "source_title", "source_text", "source_url", and "source_file".
# Add fields mapping to the settings to indicate which fields to use for the title, content, URL, and file path.
azure_ai_search_settings["fieldsMapping"] = {
"titleField": "source_title",
"urlField": "source_url",
"contentFields": ["source_text"],
"filepathField": "source_file",
}
# Depending on the index that you use, you might need to enable the below
# and adapt it so that it accurately reflects your index.

# azure_ai_search_settings["fieldsMapping"] = {
# "titleField": "source_title",
# "urlField": "source_url",
# "contentFields": ["source_text"],
# "filepathField": "source_file",
# }

# Create the data source settings

Expand Down Expand Up @@ -88,38 +86,30 @@ async def chat() -> bool:
if user_input == "exit":
print("\n\nExiting chat...")
return False

# Non streaming
# answer = await kernel.run(chat_function, input_vars=context_vars)
# print(f"Assistant:> {answer}")
arguments = KernelArguments(chat_history=chat_history, user_input=user_input, execution_settings=req_settings)

full_message = None
print("Assistant:> ", end="")
async for message in kernel.invoke_stream(chat_function, arguments=arguments):
print(str(message[0]), end="")
full_message = message[0] if not full_message else full_message + message[0]
print("\n")

# The tool message containing cited sources is available in the context
if full_message:
stream = False
if stream:
# streaming
full_message = None
print("Assistant:> ", end="")
async for message in kernel.invoke_stream(chat_function, arguments=arguments):
print(str(message[0]), end="")
full_message = message[0] if not full_message else full_message + message[0]
print("\n")

# The tool message containing cited sources is available in the context
chat_history.add_user_message(user_input)
if hasattr(full_message, "tool_message"):
chat_history.add_message(
AzureChatMessageContent(
role="assistant",
tool_calls=[
ToolCall(
id="chat_with_your_data",
function=FunctionCall(name="chat_with_your_data", arguments=""),
)
],
)
)
chat_history.add_tool_message(full_message.tool_message, {"tool_call_id": "chat_with_your_data"})
if full_message.role is None:
full_message.role = ChatRole.ASSISTANT
chat_history.add_assistant_message(full_message.content)
for message in AzureChatCompletion.split_message(full_message):
chat_history.add_message(message)
return True

# Non streaming
answer = await kernel.invoke(chat_function, arguments=arguments)
print(f"Assistant:> {answer}")
chat_history.add_user_message(user_input)
for message in AzureChatCompletion.split_message(answer.value[0]):
chat_history.add_message(message)
return True


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,10 @@
from semantic_kernel.connectors.ai.open_ai import (
AzureAISearchDataSource,
AzureChatCompletion,
AzureChatMessageContent,
AzureChatPromptExecutionSettings,
ExtraBody,
FunctionCall,
ToolCall,
)
from semantic_kernel.contents import ChatHistory, ChatRole
from semantic_kernel.contents import ChatHistory
from semantic_kernel.functions import KernelArguments
from semantic_kernel.kernel import Kernel
from semantic_kernel.prompt_template import InputVariable, PromptTemplateConfig
Expand Down Expand Up @@ -114,22 +111,8 @@ async def chat() -> bool:
# The tool message containing cited sources is available in the context
if full_message:
chat_history.add_user_message(user_input)
if hasattr(full_message, "tool_message"):
chat_history.add_message(
AzureChatMessageContent(
role="assistant",
tool_calls=[
ToolCall(
id="chat_with_your_data",
function=FunctionCall(name="chat_with_your_data", arguments=""),
)
],
)
)
chat_history.add_tool_message(full_message.tool_message, {"tool_call_id": "chat_with_your_data"})
if full_message.role is None:
full_message.role = ChatRole.ASSISTANT
chat_history.add_assistant_message(full_message.content)
for message in AzureChatCompletion.split_message(full_message):
chat_history.add_message(message)
return True


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from semantic_kernel import Kernel
from semantic_kernel.connectors.ai.open_ai import AzureTextCompletion, AzureTextEmbedding
from semantic_kernel.connectors.memory import AzureCognitiveSearchMemoryStore
from semantic_kernel.connectors.memory.azure_cognitive_search import AzureCognitiveSearchMemoryStore
from semantic_kernel.core_plugins import TextMemoryPlugin
from semantic_kernel.memory import SemanticTextMemory

Expand Down Expand Up @@ -62,7 +62,7 @@ async def main() -> None:
api_key=AZURE_OPENAI_API_KEY,
),
)
embedding_service_id = ("ada",)
embedding_service_id = "ada"
embedding_gen = AzureTextEmbedding(
service_id=embedding_service_id,
deployment_name="text-embedding-ada-002",
Expand All @@ -81,10 +81,10 @@ async def main() -> None:
kernel.add_plugin(TextMemoryPlugin(memory), "TextMemoryPlugin")

print("Populating memory...")
await populate_memory(kernel)
await populate_memory(memory)

print("Asking questions... (manually)")
await search_acs_memory_questions(kernel)
await search_acs_memory_questions(memory)

await acs_connector.close()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@
import asyncio
import os
from functools import reduce
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, List

from semantic_kernel import Kernel
from semantic_kernel.connectors.ai.open_ai import (
OpenAIChatCompletion,
OpenAIChatMessageContent,
OpenAIChatPromptExecutionSettings,
OpenAIStreamingChatMessageContent,
)
from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion, OpenAIChatPromptExecutionSettings
from semantic_kernel.connectors.ai.open_ai.utils import get_tool_call_object
from semantic_kernel.contents import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.core_plugins import MathPlugin, TimePlugin
from semantic_kernel.functions import KernelArguments
from semantic_kernel.utils.settings import openai_settings_from_dot_env
Expand Down Expand Up @@ -88,61 +86,46 @@
arguments = KernelArguments(settings=execution_settings)


def print_tool_calls(message: OpenAIChatMessageContent) -> None:
def print_tool_calls(message: ChatMessageContent) -> None:
# A helper method to pretty print the tool calls from the message.
# This is only triggered if auto invoke tool calls is disabled.
if isinstance(message, OpenAIChatMessageContent):
tool_calls = message.tool_calls
formatted_tool_calls = []
for i, tool_call in enumerate(tool_calls, start=1):
tool_call_id = tool_call.id
function_name = tool_call.function.name
function_arguments = tool_call.function.arguments
items = message.items
formatted_tool_calls = []
for i, item in enumerate(items, start=1):
if isinstance(item, FunctionCallContent):
tool_call_id = item.id
function_name = item.name
function_arguments = item.arguments
formatted_str = (
f"tool_call {i} id: {tool_call_id}\n"
f"tool_call {i} function name: {function_name}\n"
f"tool_call {i} arguments: {function_arguments}"
)
formatted_tool_calls.append(formatted_str)
print("Tool calls:\n" + "\n\n".join(formatted_tool_calls))
print("Tool calls:\n" + "\n\n".join(formatted_tool_calls))


async def handle_streaming(
kernel: Kernel,
chat_function: "KernelFunction",
user_input: str,
history: ChatHistory,
execution_settings: OpenAIChatPromptExecutionSettings,
arguments: KernelArguments,
) -> None:
response = kernel.invoke_stream(
chat_function,
return_function_results=False,
user_input=user_input,
chat_history=history,
arguments=arguments,
)

print("Mosscap:> ", end="")
streamed_chunks: List[OpenAIStreamingChatMessageContent] = []
tool_call_ids_by_index: Dict[str, Any] = {}

streamed_chunks: List[StreamingChatMessageContent] = []
async for message in response:
if not execution_settings.auto_invoke_kernel_functions and isinstance(
message[0], OpenAIStreamingChatMessageContent
):
if not execution_settings.auto_invoke_kernel_functions:
streamed_chunks.append(message[0])
if message[0].tool_calls is not None:
for tc in message[0].tool_calls:
if tc.id not in tool_call_ids_by_index:
tool_call_ids_by_index[tc.id] = tc
else:
for tc in message[0].tool_calls:
tool_call_ids_by_index[tc.id] += tc
else:
print(str(message[0]), end="")

if streamed_chunks:
streaming_chat_message = reduce(lambda first, second: first + second, streamed_chunks)
streaming_chat_message.tool_calls = list(tool_call_ids_by_index.values())
print("Auto tool calls is disabled, printing returned tool calls...")
print_tool_calls(streaming_chat_message)

Expand All @@ -162,19 +145,19 @@ async def chat() -> bool:
if user_input == "exit":
print("\n\nExiting chat...")
return False
arguments["user_input"] = user_input
arguments["chat_history"] = history

stream = True
if stream:
await handle_streaming(kernel, chat_function, user_input, history, execution_settings)
await handle_streaming(kernel, chat_function, arguments=arguments)
else:
result = await kernel.invoke(chat_function, user_input=user_input, chat_history=history)
result = await kernel.invoke(chat_function, arguments=arguments)

# If tools are used, and auto invoke tool calls is False, the response will be of type
# OpenAIChatMessageContent with information about the tool calls, which need to be sent
# ChatMessageContent with information about the tool calls, which need to be sent
# back to the model to get the final response.
if not execution_settings.auto_invoke_kernel_functions and isinstance(
result.value[0], OpenAIChatMessageContent
):
if not execution_settings.auto_invoke_kernel_functions:
print_tool_calls(result.value[0])
return True

Expand Down
27 changes: 10 additions & 17 deletions python/semantic_kernel/connectors/ai/chat_completion_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,13 @@


class ChatCompletionClientBase(AIServiceClientBase, ABC):
def get_chat_message_content_type(self) -> str:
"""Get the chat message content types used by a class, default is 'ChatMessageContent'."""
return "ChatMessageContent"

@abstractmethod
async def complete_chat(
self,
chat_history: ChatHistory,
settings: PromptExecutionSettings,
chat_history: "ChatHistory",
settings: "PromptExecutionSettings",
**kwargs: Any,
) -> list[ChatMessageContent]:
) -> list["ChatMessageContent"]:
"""
This is the method that is called from the kernel to get a response from a chat-optimized LLM.
Expand All @@ -42,10 +38,10 @@ async def complete_chat(
@abstractmethod
def complete_chat_stream(
self,
chat_history: ChatHistory,
settings: PromptExecutionSettings,
chat_history: "ChatHistory",
settings: "PromptExecutionSettings",
**kwargs: Any,
) -> AsyncGenerator[list[StreamingChatMessageContent], Any]:
) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]:
"""
This is the method that is called from the kernel to get a stream response from a chat-optimized LLM.
Expand All @@ -63,7 +59,9 @@ def complete_chat_stream(

def _prepare_chat_history_for_request(
self,
chat_history: ChatHistory,
chat_history: "ChatHistory",
role_key: str = "role",
content_key: str = "content",
) -> list[dict[str, str | None]]:
"""
Prepare the chat history for a request, allowing customization of the key names for role/author,
Expand All @@ -79,9 +77,4 @@ def _prepare_chat_history_for_request(
Returns:
List[Dict[str, Optional[str]]] -- The prepared chat history.
"""
return [self._chat_message_content_to_dict(message) for message in chat_history.messages]

def _chat_message_content_to_dict(self, message: "ChatMessageContent") -> dict[str, str | None]:
"""can be overridden to customize the serialization of the chat message content"""
msg = message.model_dump(include=["role", "content"])
return msg
return [message.to_dict(role_key=role_key, content_key=content_key) for message in chat_history.messages]
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Copyright (c) Microsoft. All rights reserved.

from typing import Any, Dict, Iterable, List, Optional, Union

from pydantic import Field, model_validator
Expand Down

0 comments on commit 65e94b3

Please sign in to comment.