feat: LLM - Support streaming prediction for code chat models
PiperOrigin-RevId: 558364254
Ark-kun authored and Copybara-Service committed Aug 19, 2023
1 parent 3a8348b commit 0359f1d
Showing 3 changed files with 85 additions and 8 deletions.
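
For context, a minimal usage sketch of the streaming code-chat call this commit adds (project, location, and the prompt are illustrative placeholders; codechat-bison@001 is the model exercised by the tests below):

from google.cloud import aiplatform
from vertexai.language_models import CodeChatModel

# Illustrative placeholders only.
aiplatform.init(project="my-project", location="us-central1")

chat_model = CodeChatModel.from_pretrained("codechat-bison@001")
chat = chat_model.start_chat(temperature=0.0)

# Each iteration yields a partial TextGenerationResponse; the exchange is
# appended to chat.message_history only after the stream is fully read.
for response in chat.send_message_streaming(
    "Please help write a function to calculate the max of two numbers"
):
    print(response.text, end="")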
13 changes: 12 additions & 1 deletion tests/system/aiplatform/test_language_models.py
@@ -260,8 +260,19 @@ def test_code_generation_streaming(self):

for response in model.predict_streaming(
prefix="def reverse_string(s):",
suffix=" return s",
# code-bison does not support suffix
# suffix=" return s",
max_output_tokens=128,
temperature=0,
):
assert response.text

def test_code_chat_model_send_message_streaming(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

chat_model = language_models.ChatModel.from_pretrained("codechat-bison@001")
chat = chat_model.start_chat()

message1 = "Please help write a function to calculate the max of two numbers"
for response in chat.send_message_streaming(message1):
assert response.text
45 changes: 45 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
@@ -1938,6 +1938,51 @@ def test_code_chat(self):
assert prediction_parameters["temperature"] == message_temperature
assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens

def test_code_chat_model_send_message_streaming(self):
"""Tests the chat generation model."""
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_CODECHAT_BISON_PUBLISHER_MODEL_DICT
),
):
model = language_models.CodeChatModel.from_pretrained("codechat-bison@001")

chat = model.start_chat(temperature=0.0)

# Using list instead of a generator so that it can be reused.
response_generator = [
gca_prediction_service.StreamingPredictResponse(
outputs=[_streaming_prediction.value_to_tensor(response_dict)]
)
for response_dict in _TEST_CHAT_PREDICTION_STREAMING
]

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="server_streaming_predict",
return_value=response_generator,
):
message_text1 = (
"Please help write a function to calculate the max of two numbers"
)
# New messages are not added until the response is fully read
assert not chat.message_history
for response in chat.send_message_streaming(message_text1):
assert len(response.text) > 10
# New messages are only added after the response is fully read
assert chat.message_history

assert len(chat.message_history) == 2
assert chat.message_history[0].author == chat.USER_AUTHOR
assert chat.message_history[0].content == message_text1
assert chat.message_history[1].author == chat.MODEL_AUTHOR

def test_code_generation(self):
"""Tests code generation with the code generation model."""
aiplatform.init(
35 changes: 28 additions & 7 deletions vertexai/language_models/_language_models.py
@@ -59,13 +59,6 @@ def _get_model_id_from_tuning_model_id(tuning_model_id: str) -> str:
return f"publishers/google/models/{model_name}@{version}"


@dataclasses.dataclass
class _PredictionRequest:
"""A single-instance prediction request."""
instance: Dict[str, Any]
parameters: Optional[Dict[str, Any]] = None


class _LanguageModel(_model_garden_models._ModelGardenModel):
"""_LanguageModel is a base class for all language models."""

@@ -1234,6 +1227,34 @@ def send_message(
temperature=temperature,
)

def send_message_streaming(
self,
message: str,
*,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
) -> Iterator[TextGenerationResponse]:
"""Sends message to the language model and gets a streamed response.
The response is only added to the history once it's fully read.
Args:
message: Message to send to the model
max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
Uses the value specified when calling `ChatModel.start_chat` by default.
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
Uses the value specified when calling `ChatModel.start_chat` by default.
Returns:
A stream of `TextGenerationResponse` objects that contain partial
responses produced by the model.
"""
return super().send_message_streaming(
message=message,
max_output_tokens=max_output_tokens,
temperature=temperature,
)


class CodeGenerationModel(_LanguageModel):
"""A language model that generates code.
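
As the docstring above notes, a CodeChatSession only appends the exchange to message_history once the stream has been fully consumed, so a caller that wants the complete reply can accumulate the chunks itself. A small sketch (the prompt and the final assertion are illustrative assumptions, not part of this commit):

# chat is an existing CodeChatSession created via CodeChatModel.start_chat().
full_text = ""
for chunk in chat.send_message_streaming("Now add type hints to that function"):
    full_text += chunk.text  # each chunk is a partial TextGenerationResponse

# Only after the loop finishes does message_history contain the new user/model pair.
assert chat.message_history[-1].content == full_text  # assumed equivalence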
