diff --git a/tests/system/aiplatform/test_language_models.py b/tests/system/aiplatform/test_language_models.py
index b991207257..5f6e9c369c 100644
--- a/tests/system/aiplatform/test_language_models.py
+++ b/tests/system/aiplatform/test_language_models.py
@@ -260,8 +260,19 @@ def test_code_generation_streaming(self):
         for response in model.predict_streaming(
             prefix="def reverse_string(s):",
-            suffix="    return s",
+            # code-bison does not support suffix
+            # suffix="    return s",
             max_output_tokens=128,
             temperature=0,
         ):
             assert response.text
+
+    def test_code_chat_model_send_message_streaming(self):
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        chat_model = language_models.CodeChatModel.from_pretrained("codechat-bison@001")
+        chat = chat_model.start_chat()
+
+        message1 = "Please help write a function to calculate the max of two numbers"
+        for response in chat.send_message_streaming(message1):
+            assert response.text
 
diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py
index 9deaaebd3c..cfc2e2383c 100644
--- a/tests/unit/aiplatform/test_language_models.py
+++ b/tests/unit/aiplatform/test_language_models.py
@@ -1938,6 +1938,51 @@ def test_code_chat(self):
         assert prediction_parameters["temperature"] == message_temperature
         assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens
 
+    def test_code_chat_model_send_message_streaming(self):
+        """Tests the chat generation model."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CODECHAT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = language_models.CodeChatModel.from_pretrained("codechat-bison@001")
+
+        chat = model.start_chat(temperature=0.0)
+
+        # Using list instead of a generator so that it can be reused.
+        response_generator = [
+            gca_prediction_service.StreamingPredictResponse(
+                outputs=[_streaming_prediction.value_to_tensor(response_dict)]
+            )
+            for response_dict in _TEST_CHAT_PREDICTION_STREAMING
+        ]
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="server_streaming_predict",
+            return_value=response_generator,
+        ):
+            message_text1 = (
+                "Please help write a function to calculate the max of two numbers"
+            )
+            # New messages are not added until the response is fully read
+            assert not chat.message_history
+            for response in chat.send_message_streaming(message_text1):
+                assert len(response.text) > 10
+            # New messages are only added after the response is fully read
+            assert chat.message_history
+
+        assert len(chat.message_history) == 2
+        assert chat.message_history[0].author == chat.USER_AUTHOR
+        assert chat.message_history[0].content == message_text1
+        assert chat.message_history[1].author == chat.MODEL_AUTHOR
+
     def test_code_generation(self):
         """Tests code generation with the code generation model."""
         aiplatform.init(
diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py
index 483b8566f5..85e63d830f 100644
--- a/vertexai/language_models/_language_models.py
+++ b/vertexai/language_models/_language_models.py
@@ -59,13 +59,6 @@ def _get_model_id_from_tuning_model_id(tuning_model_id: str) -> str:
     return f"publishers/google/models/{model_name}@{version}"
 
 
-@dataclasses.dataclass
-class _PredictionRequest:
-    """A single-instance prediction request."""
-    instance: Dict[str, Any]
-    parameters: Optional[Dict[str, Any]] = None
-
-
 class _LanguageModel(_model_garden_models._ModelGardenModel):
     """_LanguageModel is a base class for all language models."""
 
@@ -1234,6 +1227,34 @@ def send_message(
             temperature=temperature,
         )
 
+    def send_message_streaming(
+        self,
+        message: str,
+        *,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+    ) -> Iterator[TextGenerationResponse]:
+        """Sends message to the language model and gets a streamed response.
+
+        The response is only added to the history once it's fully read.
+
+        Args:
+            message: Message to send to the model.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+                Uses the value specified when calling `ChatModel.start_chat` by default.
+
+        Returns:
+            A stream of `TextGenerationResponse` objects that contain partial
+            responses produced by the model.
+        """
+        return super().send_message_streaming(
+            message=message,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+        )
+
 
 class CodeGenerationModel(_LanguageModel):
     """A language model that generates code.
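
A minimal usage sketch of the streaming code-chat API this diff adds, assembled from the system test above. The project ID, location, and prompt are illustrative placeholders; the sketch assumes a Google Cloud project with the Vertex AI API enabled.

    from google.cloud import aiplatform
    from vertexai import language_models

    # Placeholder project/location; replace with real values.
    aiplatform.init(project="my-project", location="us-central1")

    chat_model = language_models.CodeChatModel.from_pretrained("codechat-bison@001")
    chat = chat_model.start_chat(temperature=0.0)

    # Partial responses arrive while the model is still generating; the
    # exchange is appended to chat.message_history only after the stream
    # has been fully consumed.
    for response in chat.send_message_streaming(
        "Please help write a function to calculate the max of two numbers",
        max_output_tokens=128,
    ):
        print(response.text, end="")

    assert len(chat.message_history) == 2  # user message + model reply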