Skip to content

Commit

Permalink
feat: LLM - Released the Chat models to GA
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 546475152
  • Loading branch information
Ark-kun authored and Copybara-Service committed Jul 8, 2023
1 parent 52d0267 commit 22aa26d
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 2 deletions.
135 changes: 134 additions & 1 deletion tests/unit/aiplatform/test_language_models.py
Expand Up @@ -88,7 +88,7 @@
"name": "publishers/google/models/chat-bison",
"version_id": "001",
"open_source_category": "PROPRIETARY",
"launch_stage": gca_publisher_model.PublisherModel.LaunchStage.PUBLIC_PREVIEW,
"launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA,
"publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/chat-bison@001",
"predict_schemata": {
"instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml",
Expand Down Expand Up @@ -792,6 +792,139 @@ def test_chat(self):
gca_predict_response2 = gca_prediction_service.PredictResponse()
gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2)

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
return_value=gca_predict_response2,
):
message_text2 = "When were these books published?"
expected_response2 = _TEST_CHAT_GENERATION_PREDICTION2["candidates"][0][
"content"
]
response = chat.send_message(message_text2, temperature=0.1)
assert response.text == expected_response2
assert len(chat.message_history) == 6
assert chat.message_history[4].author == chat.USER_AUTHOR
assert chat.message_history[4].content == message_text2
assert chat.message_history[5].author == chat.MODEL_AUTHOR
assert chat.message_history[5].content == expected_response2

# Validating the parameters
chat_temperature = 0.1
chat_max_output_tokens = 100
chat_top_k = 1
chat_top_p = 0.1
message_temperature = 0.2
message_max_output_tokens = 200
message_top_k = 2
message_top_p = 0.2

chat2 = model.start_chat(
temperature=chat_temperature,
max_output_tokens=chat_max_output_tokens,
top_k=chat_top_k,
top_p=chat_top_p,
)

gca_predict_response3 = gca_prediction_service.PredictResponse()
gca_predict_response3.predictions.append(_TEST_CHAT_GENERATION_PREDICTION1)

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
return_value=gca_predict_response3,
) as mock_predict3:
chat2.send_message("Are my favorite movies based on a book series?")
prediction_parameters = mock_predict3.call_args[1]["parameters"]
assert prediction_parameters["temperature"] == chat_temperature
assert prediction_parameters["maxDecodeSteps"] == chat_max_output_tokens
assert prediction_parameters["topK"] == chat_top_k
assert prediction_parameters["topP"] == chat_top_p

chat2.send_message(
"Are my favorite movies based on a book series?",
temperature=message_temperature,
max_output_tokens=message_max_output_tokens,
top_k=message_top_k,
top_p=message_top_p,
)
prediction_parameters = mock_predict3.call_args[1]["parameters"]
assert prediction_parameters["temperature"] == message_temperature
assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens
assert prediction_parameters["topK"] == message_top_k
assert prediction_parameters["topP"] == message_top_p

def test_chat_ga(self):
"""Tests the chat generation model."""
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_CHAT_BISON_PUBLISHER_MODEL_DICT
),
) as mock_get_publisher_model:
model = language_models.ChatModel.from_pretrained("chat-bison@001")

mock_get_publisher_model.assert_called_once_with(
name="publishers/google/models/chat-bison@001", retry=base._DEFAULT_RETRY
)

chat = model.start_chat(
context="""
My name is Ned.
You are my personal assistant.
My favorite movies are Lord of the Rings and Hobbit.
""",
examples=[
language_models.InputOutputTextPair(
input_text="Who do you work for?",
output_text="I work for Ned.",
),
language_models.InputOutputTextPair(
input_text="What do I like?",
output_text="Ned likes watching movies.",
),
],
message_history=[
language_models.ChatMessage(
author=preview_language_models.ChatSession.USER_AUTHOR,
content="Question 1?",
),
language_models.ChatMessage(
author=preview_language_models.ChatSession.MODEL_AUTHOR,
content="Answer 1.",
),
],
temperature=0.0,
)

gca_predict_response1 = gca_prediction_service.PredictResponse()
gca_predict_response1.predictions.append(_TEST_CHAT_GENERATION_PREDICTION1)

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
return_value=gca_predict_response1,
):
message_text1 = "Are my favorite movies based on a book series?"
expected_response1 = _TEST_CHAT_GENERATION_PREDICTION1["candidates"][0][
"content"
]
response = chat.send_message(message_text1)
assert response.text == expected_response1
assert len(chat.message_history) == 4
assert chat.message_history[2].author == chat.USER_AUTHOR
assert chat.message_history[2].content == message_text1
assert chat.message_history[3].author == chat.MODEL_AUTHOR
assert chat.message_history[3].content == expected_response1

gca_predict_response2 = gca_prediction_service.PredictResponse()
gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2)

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
Expand Down
6 changes: 6 additions & 0 deletions vertexai/language_models/__init__.py
Expand Up @@ -15,6 +15,9 @@
"""Classes for working with language models."""

from vertexai.language_models._language_models import (
ChatMessage,
ChatModel,
ChatSession,
CodeChatModel,
CodeChatSession,
CodeGenerationModel,
Expand All @@ -26,6 +29,9 @@
)

__all__ = [
"ChatMessage",
"ChatModel",
"ChatSession",
"CodeChatModel",
"CodeChatSession",
"CodeGenerationModel",
Expand Down
6 changes: 5 additions & 1 deletion vertexai/language_models/_language_models.py
Expand Up @@ -584,7 +584,7 @@ class ChatMessage:
class _ChatModelBase(_LanguageModel):
"""_ChatModelBase is a base class for chat models."""

_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
_LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE

def start_chat(
self,
Expand Down Expand Up @@ -653,6 +653,10 @@ class ChatModel(_ChatModelBase):
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml"


class _PreviewChatModel(ChatModel):
    """Preview variant of `ChatModel` exposed via `vertexai.preview.language_models`.

    Identical to `ChatModel` except that it overrides the SDK launch-stage
    gate back to public preview (the GA `ChatModel` base now uses
    `_SDK_GA_LAUNCH_STAGE`), so `from_pretrained` accepts models that are
    still published at the PUBLIC_PREVIEW launch stage.
    """

    # Re-widen the launch-stage check relaxed for preview surface users;
    # constant is defined in vertexai._model_garden._model_garden_models.
    _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE


class CodeChatModel(_ChatModelBase):
"""CodeChatModel represents a model that is capable of completing code.
Expand Down
2 changes: 2 additions & 0 deletions vertexai/preview/language_models.py
Expand Up @@ -15,6 +15,7 @@
"""Classes for working with language models."""

from vertexai.language_models._language_models import (
_PreviewChatModel,
_PreviewTextEmbeddingModel,
_PreviewTextGenerationModel,
ChatMessage,
Expand All @@ -28,6 +29,7 @@
TextGenerationResponse,
)

ChatModel = _PreviewChatModel
TextGenerationModel = _PreviewTextGenerationModel
TextEmbeddingModel = _PreviewTextEmbeddingModel

Expand Down

0 comments on commit 22aa26d

Please sign in to comment.