Skip to content

Commit

Permalink
update to GP
Browse files Browse the repository at this point in the history
  • Loading branch information
eavanvalkenburg committed Apr 22, 2024
1 parent e121eb5 commit 8f3de14
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async def chat() -> bool:
print("\n\nExiting chat...")
return False

stream = False
stream = True
if stream:
answer = kernel.invoke_stream(
chat_function,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def complete_chat_stream(
def _prepare_chat_history_for_request(
self,
chat_history: "ChatHistory",
role_key: str = "role",
content_key: str = "content",
) -> list[dict[str, str | None]]:
"""
Prepare the chat history for a request, allowing customization of the key names for role/author,
Expand All @@ -75,4 +77,4 @@ def _prepare_chat_history_for_request(
Returns:
List[Dict[str, Optional[str]]] -- The prepared chat history.
"""
return [message.to_dict() for message in chat_history.messages]
return [message.to_dict(role_key=role_key, content_key=content_key) for message in chat_history.messages]
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import sys
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, List, Optional, Tuple

if sys.version_info >= (3, 9):
from typing import Annotated
Expand Down Expand Up @@ -76,7 +76,7 @@ async def complete_chat(
Returns:
List[ChatMessageContent] -- A list of ChatMessageContent objects representing the response(s) from the LLM.
"""
settings.messages = self._prepare_chat_history_for_request(chat_history)
settings.messages = self._prepare_chat_history_for_request(chat_history, role_key="author")
if not settings.ai_model_id:
settings.ai_model_id = self.ai_model_id
response = await self._send_chat_request(settings)
Expand Down Expand Up @@ -227,18 +227,3 @@ async def _send_chat_request(
def get_prompt_execution_settings_class(self) -> "PromptExecutionSettings":
    """Return the prompt-execution-settings class used by this service.

    NOTE(review): this returns the GooglePalmChatPromptExecutionSettings
    *class itself*, not an instance, even though the annotation reads like
    an instance — callers are expected to construct it themselves.
    """
    return GooglePalmChatPromptExecutionSettings

def _prepare_chat_history_for_request(
    self,
    chat_history: ChatHistory,
) -> List[Dict[str, Optional[str]]]:
    """
    Prepare the chat history for a Google PaLM request.

    The PaLM API expects the speaker key to be named "author" rather than
    the base class's "role", so each message dict produced by the base
    implementation is re-keyed accordingly.

    Arguments:
        chat_history {ChatHistory} -- The chat history to convert.

    Returns:
        List[Dict[str, Optional[str]]] -- Message dicts keyed by
            "author"/"content", with the final message forced to "user".
    """
    standard_out = super()._prepare_chat_history_for_request(chat_history)
    for message in standard_out:
        # Rename the base class's "role" key to PaLM's "author" key.
        message["author"] = message.pop("role")
    # The last message should always be from the user. Guard against an
    # empty history so standard_out[-1] cannot raise IndexError.
    if standard_out:
        standard_out[-1]["author"] = "user"
    return standard_out
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,8 @@
if sys.version_info >= (3, 9):
from google.generativeai.types import ChatResponse, MessageDict

from semantic_kernel.connectors.ai.google_palm import (
GooglePalmChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.google_palm.services.gp_chat_completion import (
GooglePalmChatCompletion,
)
from semantic_kernel.connectors.ai.google_palm import GooglePalmChatPromptExecutionSettings
from semantic_kernel.connectors.ai.google_palm.services.gp_chat_completion import GooglePalmChatCompletion
from semantic_kernel.contents.chat_history import ChatHistory


Expand Down Expand Up @@ -87,5 +83,5 @@ def reply(self):
top_p=settings.top_p,
top_k=settings.top_k,
candidate_count=settings.candidate_count,
messages=gp_chat_completion._prepare_chat_history_for_request(chats),
messages=[message.to_dict(role_key="author") for message in chats.messages],
)

0 comments on commit 8f3de14

Please sign in to comment.