Python: updated Chat Message Content with Function Call and Result Content #5946

Merged: 8 commits, Apr 25, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -23,7 +23,7 @@ repos:
- id: black
files: \.py$
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.7
rev: v0.4.1
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix ]
@@ -7,13 +7,10 @@
from semantic_kernel.connectors.ai.open_ai import (
AzureAISearchDataSource,
AzureChatCompletion,
AzureChatMessageContent,
AzureChatPromptExecutionSettings,
ExtraBody,
FunctionCall,
ToolCall,
)
from semantic_kernel.contents import ChatHistory, ChatRole
from semantic_kernel.contents import ChatHistory
from semantic_kernel.functions import KernelArguments
from semantic_kernel.prompt_template import InputVariable, PromptTemplateConfig
from semantic_kernel.utils.settings import (
@@ -22,7 +19,7 @@
)

kernel = Kernel()
logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.INFO)

# Load Azure OpenAI Settings
aoai_settings = azure_openai_settings_from_dot_env_as_dict(include_api_version=True)
@@ -35,14 +32,15 @@

azure_ai_search_settings = azure_aisearch_settings_from_dot_env_as_dict()

# Our example index has fields "source_title", "source_text", "source_url", and "source_file".
# Add fields mapping to the settings to indicate which fields to use for the title, content, URL, and file path.
azure_ai_search_settings["fieldsMapping"] = {
"titleField": "source_title",
"urlField": "source_url",
"contentFields": ["source_text"],
"filepathField": "source_file",
}
# Depending on the index that you use, you might need to enable the below
# and adapt it so that it accurately reflects your index.

# azure_ai_search_settings["fieldsMapping"] = {
# "titleField": "source_title",
# "urlField": "source_url",
# "contentFields": ["source_text"],
# "filepathField": "source_file",
# }

# Create the data source settings

@@ -88,38 +86,30 @@ async def chat() -> bool:
if user_input == "exit":
print("\n\nExiting chat...")
return False

# Non streaming
# answer = await kernel.run(chat_function, input_vars=context_vars)
# print(f"Assistant:> {answer}")
arguments = KernelArguments(chat_history=chat_history, user_input=user_input, execution_settings=req_settings)

full_message = None
print("Assistant:> ", end="")
async for message in kernel.invoke_stream(chat_function, arguments=arguments):
print(str(message[0]), end="")
full_message = message[0] if not full_message else full_message + message[0]
print("\n")

# The tool message containing cited sources is available in the context
if full_message:
stream = False
if stream:
# streaming
full_message = None
print("Assistant:> ", end="")
async for message in kernel.invoke_stream(chat_function, arguments=arguments):
print(str(message[0]), end="")
full_message = message[0] if not full_message else full_message + message[0]
print("\n")

# The tool message containing cited sources is available in the context
chat_history.add_user_message(user_input)
if hasattr(full_message, "tool_message"):
chat_history.add_message(
AzureChatMessageContent(
role="assistant",
tool_calls=[
ToolCall(
id="chat_with_your_data",
function=FunctionCall(name="chat_with_your_data", arguments=""),
)
],
)
)
chat_history.add_tool_message(full_message.tool_message, {"tool_call_id": "chat_with_your_data"})
if full_message.role is None:
full_message.role = ChatRole.ASSISTANT
chat_history.add_assistant_message(full_message.content)
for message in AzureChatCompletion.split_message(full_message):
chat_history.add_message(message)
return True

# Non streaming
answer = await kernel.invoke(chat_function, arguments=arguments)
print(f"Assistant:> {answer}")
chat_history.add_user_message(user_input)
for message in AzureChatCompletion.split_message(answer.value[0]):
chat_history.add_message(message)
return True


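Assembled from the streaming branch above, this is the pattern the updated sample follows: accumulate the streamed chunks, then let AzureChatCompletion.split_message turn the combined reply (assistant text plus the cited-sources tool output) into separate chat-history entries. The stream_and_record wrapper is hypothetical and only makes the fragment self-contained; the rest mirrors the diff.

from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion
from semantic_kernel.contents import ChatHistory

async def stream_and_record(kernel, chat_function, arguments, chat_history: ChatHistory, user_input: str) -> None:
    # Stream the reply and fold the chunks into a single message.
    full_message = None
    print("Assistant:> ", end="")
    async for message in kernel.invoke_stream(chat_function, arguments=arguments):
        print(str(message[0]), end="")
        full_message = message[0] if not full_message else full_message + message[0]
    print("\n")

    # Record the turn: split_message replaces the manual AzureChatMessageContent /
    # ToolCall bookkeeping the old sample needed for the tool message.
    chat_history.add_user_message(user_input)
    for message in AzureChatCompletion.split_message(full_message):
        chat_history.add_message(message)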
@@ -6,13 +6,10 @@
from semantic_kernel.connectors.ai.open_ai import (
AzureAISearchDataSource,
AzureChatCompletion,
AzureChatMessageContent,
AzureChatPromptExecutionSettings,
ExtraBody,
FunctionCall,
ToolCall,
)
from semantic_kernel.contents import ChatHistory, ChatRole
from semantic_kernel.contents import ChatHistory
from semantic_kernel.functions import KernelArguments
from semantic_kernel.kernel import Kernel
from semantic_kernel.prompt_template import InputVariable, PromptTemplateConfig
@@ -114,22 +111,8 @@ async def chat() -> bool:
# The tool message containing cited sources is available in the context
if full_message:
chat_history.add_user_message(user_input)
if hasattr(full_message, "tool_message"):
chat_history.add_message(
AzureChatMessageContent(
role="assistant",
tool_calls=[
ToolCall(
id="chat_with_your_data",
function=FunctionCall(name="chat_with_your_data", arguments=""),
)
],
)
)
chat_history.add_tool_message(full_message.tool_message, {"tool_call_id": "chat_with_your_data"})
if full_message.role is None:
full_message.role = ChatRole.ASSISTANT
chat_history.add_assistant_message(full_message.content)
for message in AzureChatCompletion.split_message(full_message):
chat_history.add_message(message)
return True


@@ -6,7 +6,7 @@

from semantic_kernel import Kernel
from semantic_kernel.connectors.ai.open_ai import AzureTextCompletion, AzureTextEmbedding
from semantic_kernel.connectors.memory import AzureCognitiveSearchMemoryStore
from semantic_kernel.connectors.memory.azure_cognitive_search import AzureCognitiveSearchMemoryStore
from semantic_kernel.core_plugins import TextMemoryPlugin
from semantic_kernel.memory import SemanticTextMemory

@@ -62,7 +62,7 @@ async def main() -> None:
api_key=AZURE_OPENAI_API_KEY,
),
)
embedding_service_id = ("ada",)
embedding_service_id = "ada"
embedding_gen = AzureTextEmbedding(
service_id=embedding_service_id,
deployment_name="text-embedding-ada-002",
@@ -81,10 +81,10 @@ async def main() -> None:
kernel.add_plugin(TextMemoryPlugin(memory), "TextMemoryPlugin")

print("Populating memory...")
await populate_memory(kernel)
await populate_memory(memory)

print("Asking questions... (manually)")
await search_acs_memory_questions(kernel)
await search_acs_memory_questions(memory)

await acs_connector.close()

@@ -3,17 +3,15 @@
import asyncio
import os
from functools import reduce
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, List

from semantic_kernel import Kernel
from semantic_kernel.connectors.ai.open_ai import (
OpenAIChatCompletion,
OpenAIChatMessageContent,
OpenAIChatPromptExecutionSettings,
OpenAIStreamingChatMessageContent,
)
from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion, OpenAIChatPromptExecutionSettings
from semantic_kernel.connectors.ai.open_ai.utils import get_tool_call_object
from semantic_kernel.contents import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.core_plugins import MathPlugin, TimePlugin
from semantic_kernel.functions import KernelArguments
from semantic_kernel.utils.settings import openai_settings_from_dot_env
@@ -88,61 +86,46 @@
arguments = KernelArguments(settings=execution_settings)


def print_tool_calls(message: OpenAIChatMessageContent) -> None:
def print_tool_calls(message: ChatMessageContent) -> None:
# A helper method to pretty print the tool calls from the message.
# This is only triggered if auto invoke tool calls is disabled.
if isinstance(message, OpenAIChatMessageContent):
tool_calls = message.tool_calls
formatted_tool_calls = []
for i, tool_call in enumerate(tool_calls, start=1):
tool_call_id = tool_call.id
function_name = tool_call.function.name
function_arguments = tool_call.function.arguments
items = message.items
formatted_tool_calls = []
for i, item in enumerate(items, start=1):
if isinstance(item, FunctionCallContent):
tool_call_id = item.id
function_name = item.name
function_arguments = item.arguments
formatted_str = (
f"tool_call {i} id: {tool_call_id}\n"
f"tool_call {i} function name: {function_name}\n"
f"tool_call {i} arguments: {function_arguments}"
)
formatted_tool_calls.append(formatted_str)
print("Tool calls:\n" + "\n\n".join(formatted_tool_calls))
print("Tool calls:\n" + "\n\n".join(formatted_tool_calls))


async def handle_streaming(
kernel: Kernel,
chat_function: "KernelFunction",
user_input: str,
history: ChatHistory,
execution_settings: OpenAIChatPromptExecutionSettings,
arguments: KernelArguments,
) -> None:
response = kernel.invoke_stream(
chat_function,
return_function_results=False,
user_input=user_input,
chat_history=history,
arguments=arguments,
)

print("Mosscap:> ", end="")
streamed_chunks: List[OpenAIStreamingChatMessageContent] = []
tool_call_ids_by_index: Dict[str, Any] = {}

streamed_chunks: List[StreamingChatMessageContent] = []
async for message in response:
if not execution_settings.auto_invoke_kernel_functions and isinstance(
message[0], OpenAIStreamingChatMessageContent
):
if not execution_settings.auto_invoke_kernel_functions:
streamed_chunks.append(message[0])
if message[0].tool_calls is not None:
for tc in message[0].tool_calls:
if tc.id not in tool_call_ids_by_index:
tool_call_ids_by_index[tc.id] = tc
else:
for tc in message[0].tool_calls:
tool_call_ids_by_index[tc.id] += tc
else:
print(str(message[0]), end="")

if streamed_chunks:
streaming_chat_message = reduce(lambda first, second: first + second, streamed_chunks)
streaming_chat_message.tool_calls = list(tool_call_ids_by_index.values())
print("Auto tool calls is disabled, printing returned tool calls...")
print_tool_calls(streaming_chat_message)

@@ -162,19 +145,19 @@ async def chat() -> bool:
if user_input == "exit":
print("\n\nExiting chat...")
return False
arguments["user_input"] = user_input
arguments["chat_history"] = history

stream = True
if stream:
await handle_streaming(kernel, chat_function, user_input, history, execution_settings)
await handle_streaming(kernel, chat_function, arguments=arguments)
else:
result = await kernel.invoke(chat_function, user_input=user_input, chat_history=history)
result = await kernel.invoke(chat_function, arguments=arguments)

# If tools are used, and auto invoke tool calls is False, the response will be of type
# OpenAIChatMessageContent with information about the tool calls, which need to be sent
# ChatMessageContent with information about the tool calls, which need to be sent
# back to the model to get the final response.
if not execution_settings.auto_invoke_kernel_functions and isinstance(
result.value[0], OpenAIChatMessageContent
):
if not execution_settings.auto_invoke_kernel_functions:
print_tool_calls(result.value[0])
return True

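Since tool calls now arrive as FunctionCallContent items on a plain ChatMessageContent rather than as a tool_calls field on an OpenAI-specific type, the reworked print_tool_calls helper above can be exercised with a hand-built message. The constructor keywords and the math-Add / time-Date names below are illustrative assumptions, not taken from the diff.

from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent

# Hypothetical message carrying two function calls as items.
message = ChatMessageContent(
    role="assistant",
    items=[
        FunctionCallContent(id="call_1", name="math-Add", arguments='{"input": 3, "amount": 7}'),
        FunctionCallContent(id="call_2", name="time-Date", arguments="{}"),
    ],
)
print_tool_calls(message)  # the helper defined in the sample above
# Expected shape of the output, one block per FunctionCallContent item:
# tool_call 1 id: call_1
# tool_call 1 function name: math-Add
# tool_call 1 arguments: {"input": 3, "amount": 7}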
@@ -14,17 +14,13 @@


class ChatCompletionClientBase(AIServiceClientBase, ABC):
def get_chat_message_content_type(self) -> str:
"""Get the chat message content types used by a class, default is 'ChatMessageContent'."""
return "ChatMessageContent"

@abstractmethod
async def complete_chat(
self,
chat_history: ChatHistory,
settings: PromptExecutionSettings,
chat_history: "ChatHistory",
settings: "PromptExecutionSettings",
**kwargs: Any,
) -> list[ChatMessageContent]:
) -> list["ChatMessageContent"]:
"""
This is the method that is called from the kernel to get a response from a chat-optimized LLM.

@@ -42,10 +38,10 @@ async def complete_chat(
@abstractmethod
def complete_chat_stream(
self,
chat_history: ChatHistory,
settings: PromptExecutionSettings,
chat_history: "ChatHistory",
settings: "PromptExecutionSettings",
**kwargs: Any,
) -> AsyncGenerator[list[StreamingChatMessageContent], Any]:
) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]:
"""
This is the method that is called from the kernel to get a stream response from a chat-optimized LLM.

@@ -63,7 +59,9 @@ def complete_chat_stream(

def _prepare_chat_history_for_request(
self,
chat_history: ChatHistory,
chat_history: "ChatHistory",
role_key: str = "role",
content_key: str = "content",
) -> list[dict[str, str | None]]:
"""
Prepare the chat history for a request, allowing customization of the key names for role/author,
@@ -79,9 +77,4 @@ def _prepare_chat_history_for_request(
Returns:
List[Dict[str, Optional[str]]] -- The prepared chat history.
"""
return [self._chat_message_content_to_dict(message) for message in chat_history.messages]

def _chat_message_content_to_dict(self, message: "ChatMessageContent") -> dict[str, str | None]:
"""can be overridden to customize the serialization of the chat message content"""
msg = message.model_dump(include=["role", "content"])
return msg
return [message.to_dict(role_key=role_key, content_key=content_key) for message in chat_history.messages]
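The per-connector _chat_message_content_to_dict hook is gone; _prepare_chat_history_for_request now delegates to each message's to_dict and lets the caller choose the key names. A minimal sketch of what that enables, using "author"/"text" as purely hypothetical key names for a connector whose wire format differs from the OpenAI default:

from semantic_kernel.contents import ChatHistory

chat_history = ChatHistory()
chat_history.add_user_message("Hello")
chat_history.add_assistant_message("Hi there!")

# Each message serializes itself; the connector only picks the key names.
payload = [m.to_dict(role_key="author", content_key="text") for m in chat_history.messages]
# payload is roughly: [{"author": "user", "text": "Hello"},
#                      {"author": "assistant", "text": "Hi there!"}]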
@@ -1,3 +1,5 @@
# Copyright (c) Microsoft. All rights reserved.

from typing import Any, Dict, Iterable, List, Optional, Union

from pydantic import Field, model_validator