feat: LLM - Exposed the chat history
PiperOrigin-RevId: 542386162
Ark-kun authored and copybara-github committed Jun 21, 2023
1 parent 8abd9e4 commit bf0e20b
Showing 3 changed files with 84 additions and 42 deletions.
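
In practical terms, the commit replaces the private _history list of (message, response) tuples with a public message_history property holding ChatMessage(content, author) objects, and lets start_chat() accept a previously captured history. A minimal usage sketch of the exposed surface, based on the diff below; the import path, model name, and project ID are assumptions rather than part of this commit:

    import vertexai
    from vertexai.preview.language_models import ChatModel  # import path is an assumption

    vertexai.init(project="my-project", location="us-central1")  # hypothetical project

    chat_model = ChatModel.from_pretrained("chat-bison@001")  # model name is an assumption
    chat = chat_model.start_chat()
    chat.send_message("Are my favorite movies based on a book series?")

    # New in this commit: the conversation is readable as ChatMessage objects.
    for past_message in chat.message_history:
        print(past_message.author, ":", past_message.content)  # authors are "user" and "bot"
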
23 changes: 17 additions & 6 deletions tests/system/aiplatform/test_language_models.py
@@ -67,13 +67,24 @@ def test_chat_on_chat_model(self):
             temperature=0.0,
         )
 
-        assert chat.send_message("Are my favorite movies based on a book series?").text
-        assert len(chat._history) == 1
-        assert chat.send_message(
-            "When where these books published?",
+        message1 = "Are my favorite movies based on a book series?"
+        response1 = chat.send_message(message1)
+        assert response1.text
+        assert len(chat.message_history) == 2
+        assert chat.message_history[0].author == chat.USER_AUTHOR
+        assert chat.message_history[0].content == message1
+        assert chat.message_history[1].author == chat.MODEL_AUTHOR
+
+        message2 = "When where these books published?"
+        response2 = chat.send_message(
+            message2,
             temperature=0.1,
-        ).text
-        assert len(chat._history) == 2
+        )
+        assert response2.text
+        assert len(chat.message_history) == 4
+        assert chat.message_history[2].author == chat.USER_AUTHOR
+        assert chat.message_history[2].content == message2
+        assert chat.message_history[3].author == chat.MODEL_AUTHOR
 
     def test_text_embedding(self):
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

43 changes: 24 additions & 19 deletions tests/unit/aiplatform/test_language_models.py
@@ -758,14 +758,17 @@ def test_chat(self):
             attribute="predict",
             return_value=gca_predict_response1,
         ):
-            response = chat.send_message(
-                "Are my favorite movies based on a book series?"
-            )
-            assert (
-                response.text
-                == _TEST_CHAT_GENERATION_PREDICTION1["candidates"][0]["content"]
-            )
-            assert len(chat._history) == 1
+            message_text1 = "Are my favorite movies based on a book series?"
+            expected_response1 = _TEST_CHAT_GENERATION_PREDICTION1["candidates"][0][
+                "content"
+            ]
+            response = chat.send_message(message_text1)
+            assert response.text == expected_response1
+            assert len(chat.message_history) == 2
+            assert chat.message_history[0].author == chat.USER_AUTHOR
+            assert chat.message_history[0].content == message_text1
+            assert chat.message_history[1].author == chat.MODEL_AUTHOR
+            assert chat.message_history[1].content == expected_response1
 
         gca_predict_response2 = gca_prediction_service.PredictResponse()
         gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2)
@@ -775,15 +778,17 @@
             attribute="predict",
             return_value=gca_predict_response2,
         ):
-            response = chat.send_message(
-                "When where these books published?",
-                temperature=0.1,
-            )
-            assert (
-                response.text
-                == _TEST_CHAT_GENERATION_PREDICTION2["candidates"][0]["content"]
-            )
-            assert len(chat._history) == 2
+            message_text2 = "When where these books published?"
+            expected_response2 = _TEST_CHAT_GENERATION_PREDICTION2["candidates"][0][
+                "content"
+            ]
+            response = chat.send_message(message_text2, temperature=0.1)
+            assert response.text == expected_response2
+            assert len(chat.message_history) == 4
+            assert chat.message_history[2].author == chat.USER_AUTHOR
+            assert chat.message_history[2].content == message_text2
+            assert chat.message_history[3].author == chat.MODEL_AUTHOR
+            assert chat.message_history[3].content == expected_response2
 
         # Validating the parameters
         chat_temperature = 0.1
@@ -870,7 +875,7 @@ def test_code_chat(self):
                 response.text
                 == _TEST_CHAT_GENERATION_PREDICTION1["candidates"][0]["content"]
             )
-            assert len(code_chat._history) == 1
+            assert len(code_chat.message_history) == 2
 
         gca_predict_response2 = gca_prediction_service.PredictResponse()
         gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2)
@@ -889,7 +894,7 @@
                 response.text
                 == _TEST_CHAT_GENERATION_PREDICTION2["candidates"][0]["content"]
             )
-            assert len(code_chat._history) == 2
+            assert len(code_chat.message_history) == 4
 
         # Validating the parameters
         chat_temperature = 0.1

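The updated unit-test assertions encode the new bookkeeping: each send_message() call appends two ChatMessage entries, the user's message followed by the model's reply, so the history length is twice the number of turns and even indices are authored by USER_AUTHOR ("user") while odd indices are authored by MODEL_AUTHOR ("bot"). A small hedged helper illustrating that convention; the function is invented for illustration and is not part of the SDK:

    from typing import List, Tuple

    def history_as_turns(message_history) -> List[Tuple[str, str]]:
        """Group a [user, bot, user, bot, ...] history into (user_text, bot_text) pairs."""
        # Relies on the ordering asserted in the tests above: even indices are
        # authored by "user", odd indices by "bot".
        return [
            (message_history[i].content, message_history[i + 1].content)
            for i in range(0, len(message_history), 2)
        ]
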
60 changes: 43 additions & 17 deletions vertexai/language_models/_language_models.py
@@ -565,6 +565,19 @@ class InputOutputTextPair:
     output_text: str
 
 
+@dataclasses.dataclass
+class ChatMessage:
+    """A chat message.
+
+    Attributes:
+        content: Content of the message.
+        author: Author of the message.
+    """
+
+    content: str
+    author: str
+
+
 class _ChatModelBase(_LanguageModel):
     """_ChatModelBase is a base class for chat models."""
 
@@ -579,6 +592,7 @@ def start_chat(
         temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
         top_k: int = TextGenerationModel._DEFAULT_TOP_K,
         top_p: float = TextGenerationModel._DEFAULT_TOP_P,
+        message_history: Optional[List[ChatMessage]] = None,
     ) -> "ChatSession":
         """Starts a chat session with the model.
 
@@ -591,6 +605,7 @@
             temperature: Controls the randomness of predictions. Range: [0, 1].
             top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]
             top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
+            message_history: A list of previously sent and received messages.
 
         Returns:
             A `ChatSession` object.
@@ -603,6 +618,7 @@
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            message_history=message_history,
         )
 
 
@@ -678,6 +694,9 @@ def start_chat(
 class _ChatSessionBase:
     """_ChatSessionBase is a base class for all chat sessions."""
 
+    USER_AUTHOR = "user"
+    MODEL_AUTHOR = "bot"
+
     def __init__(
         self,
         model: _ChatModelBase,
@@ -688,16 +707,22 @@ def __init__(
         top_k: int = TextGenerationModel._DEFAULT_TOP_K,
         top_p: float = TextGenerationModel._DEFAULT_TOP_P,
         is_code_chat_session: bool = False,
+        message_history: Optional[List[ChatMessage]] = None,
     ):
         self._model = model
         self._context = context
         self._examples = examples
-        self._history = []
         self._max_output_tokens = max_output_tokens
         self._temperature = temperature
         self._top_k = top_k
         self._top_p = top_p
         self._is_code_chat_session = is_code_chat_session
+        self._message_history: List[ChatMessage] = message_history or []
+
+    @property
+    def message_history(self) -> List[ChatMessage]:
+        """List of previous messages."""
+        return self._message_history
 
     def send_message(
         self,
@@ -737,29 +762,22 @@ def send_message(
         prediction_parameters["topP"] = top_p if top_p is not None else self._top_p
         prediction_parameters["topK"] = top_k if top_k is not None else self._top_k
 
-        messages = []
-        for input_text, output_text in self._history:
-            messages.append(
+        message_structs = []
+        for past_message in self._message_history:
+            message_structs.append(
                 {
-                    "author": "user",
-                    "content": input_text,
+                    "author": past_message.author,
+                    "content": past_message.content,
                 }
             )
-            messages.append(
-                {
-                    "author": "bot",
-                    "content": output_text,
-                }
-            )
-
-        messages.append(
+        message_structs.append(
             {
-                "author": "user",
+                "author": self.USER_AUTHOR,
                 "content": message,
             }
         )
 
-        prediction_instance = {"messages": messages}
+        prediction_instance = {"messages": message_structs}
         if not self._is_code_chat_session and self._context:
             prediction_instance["context"] = self._context
         if not self._is_code_chat_session and self._examples:
@@ -793,7 +811,13 @@ def send_message(
         )
         response_text = response_obj.text
 
-        self._history.append((message, response_text))
+        self._message_history.append(
+            ChatMessage(content=message, author=self.USER_AUTHOR)
+        )
+        self._message_history.append(
+            ChatMessage(content=response_text, author=self.MODEL_AUTHOR)
+        )
+
         return response_obj
 
 
@@ -812,6 +836,7 @@ def __init__(
         temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
         top_k: int = TextGenerationModel._DEFAULT_TOP_K,
         top_p: float = TextGenerationModel._DEFAULT_TOP_P,
+        message_history: Optional[List[ChatMessage]] = None,
     ):
         super().__init__(
             model=model,
@@ -821,6 +846,7 @@
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            message_history=message_history,
         )
 
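Because start_chat() now also accepts a message_history argument, a conversation captured from one session can be restored into a new one. A hedged sketch of that round trip, assuming ChatMessage is importable from the public language_models namespace (this commit defines it in vertexai/language_models/_language_models.py) and reusing the hypothetical setup from the first sketch:

    from vertexai.preview.language_models import ChatMessage, ChatModel  # import path is an assumption

    chat_model = ChatModel.from_pretrained("chat-bison@001")  # model name is an assumption

    # History captured earlier; the author strings mirror USER_AUTHOR ("user")
    # and MODEL_AUTHOR ("bot") from the diff above.
    saved_history = [
        ChatMessage(content="Are my favorite movies based on a book series?", author="user"),
        ChatMessage(content="Yes, they are.", author="bot"),  # illustrative model reply
    ]

    resumed_chat = chat_model.start_chat(message_history=saved_history)
    resumed_chat.send_message("When were these books published?")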
