
Commit

feat: LLM - Added support for multiple response candidates in code chat models

PiperOrigin-RevId: 573371030
Ark-kun authored and Copybara-Service committed Oct 14, 2023
1 parent 0c371a4 commit 598d57d
Showing 2 changed files with 63 additions and 4 deletions.
51 changes: 51 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
@@ -2419,6 +2419,57 @@ def test_code_chat(self):
         assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens
         assert prediction_parameters["stopSequences"] == message_stop_sequences

+    def test_code_chat_model_send_message_with_multiple_candidates(self):
+        """Tests the code chat model with multiple candidates."""
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _CODECHAT_BISON_PUBLISHER_MODEL_DICT
+            ),
+            autospec=True,
+        ):
+            model = language_models.CodeChatModel.from_pretrained(
+                "google/codechat-bison@001"
+            )
+
+        chat = model.start_chat()
+
+        gca_predict_response1 = gca_prediction_service.PredictResponse()
+        gca_predict_response1.predictions.append(
+            _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION
+        )
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response1,
+            autospec=True,
+        ):
+            message_text1 = "Are my favorite movies based on a book series?"
+            expected_response_candidates = (
+                _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION["candidates"]
+            )
+            expected_candidate_0 = expected_response_candidates[0]["content"]
+            expected_candidate_1 = expected_response_candidates[1]["content"]
+
+            response = chat.send_message(
+                message=message_text1,
+                # candidate_count acts as a maximum number, not exact number.
+                candidate_count=7,
+            )
+            # The service can return a different number of candidates.
+            assert response.text == expected_candidate_0
+            assert len(response.candidates) == 2
+            assert response.candidates[0].text == expected_candidate_0
+            assert response.candidates[1].text == expected_candidate_1
+
+            assert len(chat.message_history) == 2
+            assert chat.message_history[0].author == chat.USER_AUTHOR
+            assert chat.message_history[0].content == message_text1
+            assert chat.message_history[1].author == chat.MODEL_AUTHOR
+            assert chat.message_history[1].content == expected_candidate_0
+
     def test_code_chat_model_send_message_streaming(self):
         """Tests the chat generation model."""
         aiplatform.init(
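The new test pins down the user-visible behavior: `candidate_count` is an upper bound rather than an exact count, `response.text` exposes the first candidate, and only that first candidate is recorded in the message history. A minimal usage sketch of the feature (not part of the commit; it assumes google-cloud-aiplatform at a version that includes this change, and the project ID and prompt below are illustrative):

import vertexai
from vertexai.language_models import CodeChatModel

# Hypothetical project ID and location; replace with your own.
vertexai.init(project="my-project", location="us-central1")

model = CodeChatModel.from_pretrained("codechat-bison@001")
chat = model.start_chat()

response = chat.send_message(
    "Is a linked list or an array a better backing store for a FIFO queue?",
    candidate_count=3,  # upper bound; the service may return fewer candidates
)

print(response.text)  # convenience accessor for the first candidate
for candidate in response.candidates:
    print(candidate.text)

Because `response.text` mirrors `response.candidates[0].text`, the message-history assertions in the test hold no matter how many candidates the service actually returns.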
16 changes: 12 additions & 4 deletions vertexai/language_models/_language_models.py
@@ -2112,7 +2112,8 @@ def send_message(
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         stop_sequences: Optional[List[str]] = None,
-    ) -> "TextGenerationResponse":
+        candidate_count: Optional[int] = None,
+    ) -> "MultiCandidateTextGenerationResponse":
         """Sends message to the code chat model and gets a response.

         Args:
@@ -2122,15 +2123,18 @@ def send_message(
             temperature: Controls the randomness of predictions. Range: [0, 1].
                 Uses the value specified when calling `CodeChatModel.start_chat` by default.
             stop_sequences: Customized stop sequences to stop the decoding process.
+            candidate_count: Number of candidates to return.

         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the
+            text produced by the model.
         """
         return super().send_message(
             message=message,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
             stop_sequences=stop_sequences,
+            candidate_count=candidate_count,
         )

     async def send_message_async(
@@ -2139,7 +2143,8 @@ async def send_message_async(
         *,
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
-    ) -> "TextGenerationResponse":
+        candidate_count: Optional[int] = None,
+    ) -> "MultiCandidateTextGenerationResponse":
         """Asynchronously sends message to the code chat model and gets a response.

         Args:
@@ -2148,14 +2153,17 @@ async def send_message_async(
                 Uses the value specified when calling `CodeChatModel.start_chat` by default.
             temperature: Controls the randomness of predictions. Range: [0, 1].
                 Uses the value specified when calling `CodeChatModel.start_chat` by default.
+            candidate_count: Number of candidates to return.

         Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the
+            text produced by the model.
         """
         return super().send_message_async(
             message=message,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
+            candidate_count=candidate_count,
         )

     def send_message_streaming(
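The asynchronous variant has the same shape. A sketch under the same assumptions as the synchronous example above (note that `send_message_async` does not take `stop_sequences` in this version):

import asyncio

from vertexai.language_models import CodeChatModel

async def main() -> None:
    # Assumes vertexai.init(...) has been called as in the earlier sketch.
    model = CodeChatModel.from_pretrained("codechat-bison@001")
    chat = model.start_chat()
    response = await chat.send_message_async(
        "Suggest a docstring for a string-reversal helper.",
        candidate_count=2,  # again an upper bound, not a guarantee
    )
    for candidate in response.candidates:
        print(candidate.text)

asyncio.run(main())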
