feat: LLM - Added support for multiple chat response candidates
PiperOrigin-RevId: 572100735
Ark-kun authored and Copybara-Service committed Oct 10, 2023
1 parent e76abd3 commit 587df74
Showing 2 changed files with 113 additions and 22 deletions.
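At the API surface, this change adds a `candidate_count` argument to `ChatSession.send_message` / `send_message_async` and makes them return a `MultiCandidateTextGenerationResponse`. A minimal usage sketch, mirroring the new unit test below (the project setup and prompt are illustrative and not part of this commit):

import vertexai
from vertexai.language_models import ChatModel

vertexai.init(project="my-project", location="us-central1")  # assumed setup

chat_model = ChatModel.from_pretrained("chat-bison@001")
chat = chat_model.start_chat()

# Ask the service for two response candidates instead of one.
response = chat.send_message(
    "Are my favorite movies based on a book series?", candidate_count=2
)

print(response.text)                    # text of the first candidate
for candidate in response.candidates:   # all returned candidates
    print(candidate.text, candidate.is_blocked)

# Only the first candidate is appended to the chat history.
assert chat.message_history[-1].content == response.candidates[0].text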
71 changes: 71 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
@@ -238,6 +238,30 @@
}
],
}
_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION = {
    "safetyAttributes": [
        {
            "scores": [],
            "categories": [],
            "blocked": False,
        },
        {
            "scores": [0.1],
            "categories": ["Finance"],
            "blocked": True,
        },
    ],
    "candidates": [
        {
            "author": "1",
            "content": "Chat response 2",
        },
        {
            "author": "1",
            "content": "",
        },
    ],
}

_TEST_CHAT_PREDICTION_STREAMING = [
{
@@ -2076,6 +2100,53 @@ def test_chat_ga(self):
assert prediction_parameters["topP"] == message_top_p
assert prediction_parameters["stopSequences"] == message_stop_sequences

    def test_chat_model_send_message_with_multiple_candidates(self):
        """Tests the chat generation model with multiple candidates."""

        with mock.patch.object(
            target=model_garden_service_client.ModelGardenServiceClient,
            attribute="get_publisher_model",
            return_value=gca_publisher_model.PublisherModel(
                _CHAT_BISON_PUBLISHER_MODEL_DICT
            ),
        ) as mock_get_publisher_model:
            model = language_models.ChatModel.from_pretrained("chat-bison@001")

        mock_get_publisher_model.assert_called_once_with(
            name="publishers/google/models/chat-bison@001", retry=base._DEFAULT_RETRY
        )

        chat = model.start_chat()

        gca_predict_response1 = gca_prediction_service.PredictResponse()
        gca_predict_response1.predictions.append(
            _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION
        )

        with mock.patch.object(
            target=prediction_service_client.PredictionServiceClient,
            attribute="predict",
            return_value=gca_predict_response1,
        ):
            message_text1 = "Are my favorite movies based on a book series?"
            expected_response_candidates = (
                _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION["candidates"]
            )
            expected_candidate_0 = expected_response_candidates[0]["content"]
            expected_candidate_1 = expected_response_candidates[1]["content"]

            response = chat.send_message(message_text1, candidate_count=2)
            assert response.text == expected_candidate_0
            assert len(response.candidates) == 2
            assert response.candidates[0].text == expected_candidate_0
            assert response.candidates[1].text == expected_candidate_1

            assert len(chat.message_history) == 2
            assert chat.message_history[0].author == chat.USER_AUTHOR
            assert chat.message_history[0].content == message_text1
            assert chat.message_history[1].author == chat.MODEL_AUTHOR
            assert chat.message_history[1].content == expected_candidate_0

def test_chat_model_send_message_streaming(self):
"""Tests the chat generation model."""
with mock.patch.object(
64 changes: 42 additions & 22 deletions vertexai/language_models/_language_models.py
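The changes below thread the new `candidate_count` argument through `_prepare_request`, `send_message`, and `send_message_async`, map it onto the `candidateCount` prediction parameter, and rework `_parse_chat_prediction_response` to build one `TextGenerationResponse` per returned candidate. Roughly, the prediction parameters sent to the service end up looking like this (values are illustrative; only keys that appear in this diff and its tests are shown):

prediction_parameters = {
    "topP": 0.95,
    "stopSequences": ["\n"],
    "candidateCount": 2,
}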
@@ -1615,6 +1615,7 @@ def _prepare_request(
top_k: Optional[int] = None,
top_p: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
candidate_count: Optional[int] = None,
) -> _PredictionRequest:
"""Prepares a request for the language model.
@@ -1629,6 +1630,7 @@
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
Uses the value specified when calling `ChatModel.start_chat` by default.
stop_sequences: Customized stop sequences to stop the decoding process.
candidate_count: Number of candidates to return.
Returns:
A `_PredictionRequest` object.
@@ -1660,6 +1662,9 @@
if stop_sequences:
prediction_parameters["stopSequences"] = stop_sequences

        if candidate_count is not None:
            prediction_parameters["candidateCount"] = candidate_count

message_structs = []
for past_message in self._message_history:
message_structs.append(
@@ -1697,8 +1702,7 @@ def _parse_chat_prediction_response(
        cls,
        prediction_response: aiplatform.models.Prediction,
        prediction_idx: int = 0,
-        candidate_idx: int = 0,
-    ) -> TextGenerationResponse:
+    ) -> MultiCandidateTextGenerationResponse:
"""Parses prediction response for chat models.
Args:
@@ -1707,25 +1711,33 @@
-            candidate_idx: Index of the candidate to parse.
        Returns:
-            A `TextGenerationResponse` object.
+            A `MultiCandidateTextGenerationResponse` object.
        """
        prediction = prediction_response.predictions[prediction_idx]
        # ! Note: For chat models, the safetyAttributes is a list.
-        safety_attributes = prediction["safetyAttributes"][candidate_idx]
-        return TextGenerationResponse(
-            text=prediction["candidates"][candidate_idx]["content"]
-            if prediction.get("candidates")
-            else None,
+        candidate_count = len(prediction["candidates"])
+        candidates = []
+        for candidate_idx in range(candidate_count):
+            safety_attributes = prediction["safetyAttributes"][candidate_idx]
+            candidate_response = TextGenerationResponse(
+                text=prediction["candidates"][candidate_idx]["content"],
+                _prediction_response=prediction_response,
+                is_blocked=safety_attributes.get("blocked", False),
+                safety_attributes=dict(
+                    zip(
+                        # Unlike with normal prediction, in streaming prediction
+                        # categories and scores can be None
+                        safety_attributes.get("categories") or [],
+                        safety_attributes.get("scores") or [],
+                    )
+                ),
+            )
+            candidates.append(candidate_response)
+        return MultiCandidateTextGenerationResponse(
+            text=candidates[0].text,
            _prediction_response=prediction_response,
-            is_blocked=safety_attributes.get("blocked", False),
-            safety_attributes=dict(
-                zip(
-                    # Unlike with normal prediction, in streaming prediction
-                    # categories and scores can be None
-                    safety_attributes.get("categories") or [],
-                    safety_attributes.get("scores") or [],
-                )
-            ),
+            is_blocked=candidates[0].is_blocked,
+            safety_attributes=candidates[0].safety_attributes,
+            candidates=candidates,
        )

def send_message(
@@ -1737,7 +1749,8 @@
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        stop_sequences: Optional[List[str]] = None,
-    ) -> "TextGenerationResponse":
+        candidate_count: Optional[int] = None,
+    ) -> "MultiCandidateTextGenerationResponse":
"""Sends message to the language model and gets a response.
Args:
@@ -1751,9 +1764,11 @@
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
Uses the value specified when calling `ChatModel.start_chat` by default.
stop_sequences: Customized stop sequences to stop the decoding process.
candidate_count: Number of candidates to return.
Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains the
+            text produced by the model.
"""
prediction_request = self._prepare_request(
message=message,
@@ -1762,6 +1777,7 @@
top_k=top_k,
top_p=top_p,
stop_sequences=stop_sequences,
candidate_count=candidate_count,
)

prediction_response = self._model._endpoint.predict(
@@ -1791,7 +1807,8 @@ async def send_message_async(
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        stop_sequences: Optional[List[str]] = None,
-    ) -> "TextGenerationResponse":
+        candidate_count: Optional[int] = None,
+    ) -> "MultiCandidateTextGenerationResponse":
"""Asynchronously sends message to the language model and gets a response.
Args:
@@ -1805,9 +1822,11 @@
top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
Uses the value specified when calling `ChatModel.start_chat` by default.
stop_sequences: Customized stop sequences to stop the decoding process.
candidate_count: Number of candidates to return.
Returns:
-            A `TextGenerationResponse` object that contains the text produced by the model.
+            A `MultiCandidateTextGenerationResponse` object that contains
+            the text produced by the model.
"""
prediction_request = self._prepare_request(
message=message,
@@ -1816,6 +1835,7 @@
top_k=top_k,
top_p=top_p,
stop_sequences=stop_sequences,
candidate_count=candidate_count,
)

prediction_response = await self._model._endpoint.predict_async(
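For readers who don't have `MultiCandidateTextGenerationResponse` in front of them, its shape can be read off the constructor call in `_parse_chat_prediction_response` above. A rough, dataclass-style sketch (field names are taken from this diff; the real class lives elsewhere in `_language_models.py` and may carry additional fields):

from dataclasses import dataclass, field
from typing import Dict, List, Optional

@dataclass
class TextGenerationResponseSketch:
    # Stand-in for vertexai's TextGenerationResponse, for illustration only.
    text: Optional[str]
    is_blocked: bool
    safety_attributes: Dict[str, float]

@dataclass
class MultiCandidateTextGenerationResponseSketch(TextGenerationResponseSketch):
    # text, is_blocked, and safety_attributes mirror candidates[0], so existing
    # single-candidate callers keep working; candidates holds every response.
    candidates: List[TextGenerationResponseSketch] = field(default_factory=list)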