fix: LLM - Fixed parameters set in ChatModel.start_chat being ignored
PiperOrigin-RevId: 537204205
Ark-kun authored and Copybara-Service committed Jun 2, 2023
1 parent ed1f747 commit a0d815d
Showing 2 changed files with 75 additions and 16 deletions.
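The root cause: send_message defaulted its sampling parameters to the model-level defaults rather than None, so the values passed to ChatModel.start_chat were overridden on every call. With this change, a message falls back to the per-chat values unless it overrides them itself. A minimal usage sketch of the intended behavior (the import path, project, and model name are assumptions, not taken from this diff):

# Sketch only: illustrates the behavior this fix enables.
import vertexai
from vertexai.language_models import ChatModel  # assumed import path

vertexai.init(project="my-project", location="us-central1")  # hypothetical project
model = ChatModel.from_pretrained("chat-bison@001")  # assumed model name

# Parameters set here now apply to every message in the session...
chat = model.start_chat(temperature=0.1, max_output_tokens=100, top_k=1, top_p=0.1)
chat.send_message("Are my favorite movies based on a book series?")

# ...unless an individual message overrides them.
chat.send_message("Is the second one also based on a book?", temperature=0.2)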
45 changes: 45 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
@@ -638,6 +638,51 @@ def test_chat(self):
)
assert len(chat._history) == 2

# Validating the parameters
chat_temperature = 0.1
chat_max_output_tokens = 100
chat_top_k = 1
chat_top_p = 0.1
message_temperature = 0.2
message_max_output_tokens = 200
message_top_k = 2
message_top_p = 0.2

chat2 = model.start_chat(
temperature=chat_temperature,
max_output_tokens=chat_max_output_tokens,
top_k=chat_top_k,
top_p=chat_top_p,
)

gca_predict_response3 = gca_prediction_service.PredictResponse()
gca_predict_response3.predictions.append(_TEST_CHAT_GENERATION_PREDICTION1)

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
return_value=gca_predict_response3,
) as mock_predict3:
chat2.send_message("Are my favorite movies based on a book series?")
prediction_parameters = mock_predict3.call_args[1]["parameters"]
assert prediction_parameters["temperature"] == chat_temperature
assert prediction_parameters["maxDecodeSteps"] == chat_max_output_tokens
assert prediction_parameters["topK"] == chat_top_k
assert prediction_parameters["topP"] == chat_top_p

chat2.send_message(
"Are my favorite movies based on a book series?",
temperature=message_temperature,
max_output_tokens=message_max_output_tokens,
top_k=message_top_k,
top_p=message_top_p,
)
prediction_parameters = mock_predict3.call_args[1]["parameters"]
assert prediction_parameters["temperature"] == message_temperature
assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens
assert prediction_parameters["topK"] == message_top_k
assert prediction_parameters["topP"] == message_top_p

def test_text_embedding(self):
"""Tests the text embedding model."""
aiplatform.init(
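For reference, the SDK keyword arguments map onto the prediction-service parameter keys that the test inspects via mock_predict3.call_args; the maxDecodeSteps, topK, and topP names come straight from the assertions above. A small illustrative mapping (this dict is not part of the SDK):

# Illustrative only: SDK argument name -> key asserted in the predict request.
SDK_TO_PREDICTION_PARAMETER = {
    "temperature": "temperature",
    "max_output_tokens": "maxDecodeSteps",
    "top_k": "topK",
    "top_p": "topP",
}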
46 changes: 30 additions & 16 deletions vertexai/language_models/_language_models.py
@@ -460,19 +460,23 @@ def send_message(
self,
message: str,
*,
max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
top_k: int = TextGenerationModel._DEFAULT_TOP_K,
top_p: float = TextGenerationModel._DEFAULT_TOP_P,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
) -> "TextGenerationResponse":
"""Sends message to the language model and gets a response.
Args:
message: Message to send to the model
max_output_tokens: Max length of the output text in tokens.
Uses the value specified when calling `ChatModel.start_chat` by default.
temperature: Controls the randomness of predictions. Range: [0, 1].
Uses the value specified when calling `ChatModel.start_chat` by default.
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering.
Uses the value specified when calling `ChatModel.start_chat` by default.
top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
Uses the value specified when calling `ChatModel.start_chat` by default.
Returns:
A `TextGenerationResponse` object that contains the text produced by the model.
@@ -484,10 +488,12 @@ def send_message(

response_obj = self._model.predict(
prompt=new_history_text,
max_output_tokens=max_output_tokens or self._max_output_tokens,
temperature=temperature or self._temperature,
top_k=top_k or self._top_k,
top_p=top_p or self._top_p,
max_output_tokens=max_output_tokens
if max_output_tokens is not None
else self._max_output_tokens,
temperature=temperature if temperature is not None else self._temperature,
top_k=top_k if top_k is not None else self._top_k,
top_p=top_p if top_p is not None else self._top_p,
)
response_text = response_obj.text
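Two things change in send_message here: the signature defaults become None, so values set in ChatModel.start_chat are no longer shadowed by method-level defaults, and the fallback uses an explicit is-not-None check instead of `or`, so falsy-but-valid values survive. A standalone sketch of the difference (illustrative variables only, not SDK code):

# Why `or` is not a safe fallback for numeric parameters.
chat_temperature = 0.7       # value supplied to start_chat
message_temperature = 0.0    # caller explicitly requests greedy decoding

buggy = message_temperature or chat_temperature
# buggy == 0.7 -- the explicit 0.0 is discarded because it is falsy

fixed = message_temperature if message_temperature is not None else chat_temperature
# fixed == 0.0 -- only a missing (None) value falls back to the chat default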

@@ -636,28 +642,36 @@ def send_message(
self,
message: str,
*,
max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
top_k: int = TextGenerationModel._DEFAULT_TOP_K,
top_p: float = TextGenerationModel._DEFAULT_TOP_P,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
) -> "TextGenerationResponse":
"""Sends message to the language model and gets a response.
Args:
message: Message to send to the model
max_output_tokens: Max length of the output text in tokens.
Uses the value specified when calling `ChatModel.start_chat` by default.
temperature: Controls the randomness of predictions. Range: [0, 1].
Uses the value specified when calling `ChatModel.start_chat` by default.
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering.
Uses the value specified when calling `ChatModel.start_chat` by default.
top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1].
Uses the value specified when calling `ChatModel.start_chat` by default.
Returns:
A `TextGenerationResponse` object that contains the text produced by the model.
"""
prediction_parameters = {
"temperature": temperature,
"maxDecodeSteps": max_output_tokens,
"topP": top_p,
"topK": top_k,
"temperature": temperature
if temperature is not None
else self._temperature,
"maxDecodeSteps": max_output_tokens
if max_output_tokens is not None
else self._max_output_tokens,
"topP": top_p if top_p is not None else self._top_p,
"topK": top_k if top_k is not None else self._top_k,
}
messages = []
for input_text, output_text in self._history:
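The same prefer-the-message-value-else-use-the-chat-value resolution is repeated for each of the four parameters in the dict above. If it were factored out, it might look like the helper below (a hypothetical refactoring sketch, not part of this commit):

# Hypothetical helper -- not in the SDK; shown only to name the pattern.
from typing import Optional, TypeVar

T = TypeVar("T")

def resolve(message_value: Optional[T], chat_value: Optional[T]) -> Optional[T]:
    # Prefer the per-message value; otherwise fall back to the start_chat value.
    return message_value if message_value is not None else chat_value

print(resolve(None, 0.1))  # 0.1 -- no per-message override, chat value wins
print(resolve(0.2, 0.1))   # 0.2 -- explicit per-message value wins
print(resolve(0.0, 0.1))   # 0.0 -- falsy but explicit values are preserved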
