Skip to content

Commit

Permalink
fix: async call bug in CodeChatModel.send_message_async method
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 581045937
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed Nov 9, 2023
1 parent 2e57983 commit fcf05cb
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
48 changes: 48 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3182,6 +3182,54 @@ def test_code_chat_model_send_message_with_multiple_candidates(self):
assert chat.message_history[1].author == chat.MODEL_AUTHOR
assert chat.message_history[1].content == expected_candidate_0

async def test_code_chat_model_send_message_async(self):
    """Verifies CodeChatModel chat sessions answer send_message_async correctly.

    Mocks the publisher-model lookup so a CodeChatModel can be constructed
    offline, then mocks the async predict RPC to return a canned
    multi-candidate prediction, and checks both the response candidates and
    the session's message history.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
    )

    # Stub out the publisher-model fetch that from_pretrained performs.
    publisher_model = gca_publisher_model.PublisherModel(
        _CODECHAT_BISON_PUBLISHER_MODEL_DICT
    )
    with mock.patch.object(
        target=model_garden_service_client.ModelGardenServiceClient,
        attribute="get_publisher_model",
        return_value=publisher_model,
    ):
        chat = language_models.CodeChatModel.from_pretrained(
            "codechat-bison@001"
        ).start_chat()

    predict_response = gca_prediction_service.PredictResponse()
    predict_response.predictions.append(
        _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION
    )

    # Stub the async prediction service so send_message_async gets the
    # canned two-candidate reply.
    with mock.patch.object(
        target=prediction_service_async_client.PredictionServiceAsyncClient,
        attribute="predict",
        return_value=predict_response,
        autospec=True,
    ):
        user_message = "Are my favorite movies based on a book series?"
        candidates = _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION[
            "candidates"
        ]
        first_expected = candidates[0]["content"]
        second_expected = candidates[1]["content"]

        response = await chat.send_message_async(
            message=user_message,
        )
        # The service can return a different number of candidates.
        assert response.text == first_expected
        assert len(response.candidates) == 2
        assert response.candidates[0].text == first_expected
        assert response.candidates[1].text == second_expected

        # The session records the user turn followed by the model turn.
        history = chat.message_history
        assert len(history) == 2
        assert history[0].author == chat.USER_AUTHOR
        assert history[0].content == user_message
        assert history[1].author == chat.MODEL_AUTHOR
        assert history[1].content == first_expected

def test_code_chat_model_send_message_streaming(self):
"""Tests the chat generation model."""
aiplatform.init(
Expand Down
3 changes: 2 additions & 1 deletion vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2558,12 +2558,13 @@ async def send_message_async(
A `MultiCandidateTextGenerationResponse` object that contains the
text produced by the model.
"""
return super().send_message_async(
response = await super().send_message_async(
message=message,
max_output_tokens=max_output_tokens,
temperature=temperature,
candidate_count=candidate_count,
)
return response

def send_message_streaming(
self,
Expand Down

0 comments on commit fcf05cb

Please sign in to comment.