fix: GenAI - Workaround for streaming when content role is missing in service responses

PiperOrigin-RevId: 623296418
Ark-kun authored and Copybara-Service committed Apr 9, 2024
1 parent 40b728b commit fa35b91
Showing 2 changed files with 79 additions and 6 deletions.
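For orientation before the diff: when the service blocks a streaming response, the final chunk carries only finish_reason, finish_message, and safety_ratings, with no content and therefore no role, so the SDK's chunk-aggregation code must tolerate a missing role instead of raising a role-mismatch error. Below is a minimal sketch of that aggregation behavior; the FakeContent/FakeCandidate classes and append_* helpers are hypothetical stand-ins for the gapic types and the SDK's private _append_gapic_* functions, not the real implementation.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeContent:
    # Hypothetical stand-in for gapic_content_types.Content.
    role: str = ""
    parts: List[str] = field(default_factory=list)


@dataclass
class FakeCandidate:
    # Hypothetical stand-in for gapic_content_types.Candidate.
    content: Optional[FakeContent] = None
    finish_reason: str = ""


def append_content(base: FakeContent, new: FakeContent) -> None:
    # Workaround: skip the role check when the incoming chunk has no role set;
    # chunks of a blocked response may omit it.
    if new.role and base.role != new.role:
        raise ValueError(f"Content roles do not match: {base.role} != {new.role}")
    base.parts.extend(new.parts)


def append_candidate(base: FakeCandidate, new: FakeCandidate) -> None:
    # Only merge content when the new chunk actually carries some;
    # the final chunk of a blocked response has none.
    if new.content is not None and base.content is not None:
        append_content(base.content, new.content)
    # For finish_reason, the last non-empty value wins.
    if new.finish_reason:
        base.finish_reason = new.finish_reason


merged = FakeCandidate(content=FakeContent(role="model", parts=["Hello"]))
# Final chunk of a blocked stream: no content (hence no role), only a finish reason.
append_candidate(merged, FakeCandidate(content=None, finish_reason="OTHER"))
print(merged.finish_reason)  # OTHER, and no role-mismatch error was raised

The actual fix in _append_gapic_candidate and _append_gapic_content (second file below) applies the same two checks against the protobuf types, using proto field presence ("content" in new_candidate) rather than a None check.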
tests/unit/vertexai/test_generative_models.py (77 changes: 73 additions & 4 deletions)
@@ -119,7 +119,7 @@ def mock_generate_content(
*,
model: Optional[str] = None,
contents: Optional[MutableSequence[gapic_content_types.Content]] = None,
-) -> Iterable[gapic_prediction_service_types.GenerateContentResponse]:
+) -> gapic_prediction_service_types.GenerateContentResponse:
last_message_part = request.contents[-1].parts[0]
should_fail = last_message_part.text and "Please fail" in last_message_part.text
if should_fail:
@@ -203,8 +203,7 @@ def mock_generate_content(
gapic_content_types.Candidate(
index=0,
content=gapic_content_types.Content(
-# Model currently does not identify itself
-# role="model",
+role="model",
parts=[
gapic_content_types.Part(response_part_struct),
],
@@ -240,6 +239,13 @@ def mock_generate_content(
),
],
)

if "Please block response with finish_reason=OTHER" in (
last_message_part.text or ""
):
finish_reason = gapic_content_types.Candidate.FinishReason.OTHER
response.candidates[0].finish_reason = finish_reason

return response


@@ -250,9 +256,32 @@ def mock_stream_generate_content(
model: Optional[str] = None,
contents: Optional[MutableSequence[gapic_content_types.Content]] = None,
) -> Iterable[gapic_prediction_service_types.GenerateContentResponse]:
-yield mock_generate_content(
+response = mock_generate_content(
self=self, request=request, model=model, contents=contents
)
+# When a streaming response gets blocked, the last chunk has no content.
+# Creating such last chunk.
+blocked_chunk = None
+candidate_0 = response.candidates[0] if response.candidates else None
+if candidate_0 and candidate_0.finish_reason not in (
+    gapic_content_types.Candidate.FinishReason.STOP,
+    gapic_content_types.Candidate.FinishReason.MAX_TOKENS,
+):
+    blocked_chunk = gapic_prediction_service_types.GenerateContentResponse(
+        candidates=[
+            gapic_content_types.Candidate(
+                index=0,
+                finish_reason=candidate_0.finish_reason,
+                finish_message=candidate_0.finish_message,
+                safety_ratings=candidate_0.safety_ratings,
+            )
+        ]
+    )
+    candidate_0.finish_reason = None
+    candidate_0.finish_message = None
+yield response
+if blocked_chunk:
+    yield blocked_chunk


def get_current_weather(location: str, unit: Optional[str] = "centigrade"):
@@ -407,6 +436,25 @@ def test_chat_send_message(self, generative_models: generative_models):
response2 = chat.send_message("Is sky blue on other planets?")
assert response2.text

+@mock.patch.object(
+    target=prediction_service.PredictionServiceClient,
+    attribute="stream_generate_content",
+    new=mock_stream_generate_content,
+)
+@pytest.mark.parametrize(
+    "generative_models",
+    [generative_models, preview_generative_models],
+)
+def test_chat_send_message_streaming(self, generative_models: generative_models):
+    model = generative_models.GenerativeModel("gemini-pro")
+    chat = model.start_chat()
+    stream1 = chat.send_message("Why is sky blue?", stream=True)
+    for chunk in stream1:
+        assert chunk.candidates
+    stream2 = chat.send_message("Is sky blue on other planets?", stream=True)
+    for chunk in stream2:
+        assert chunk.candidates

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
@@ -455,6 +503,27 @@ def test_chat_send_message_response_blocked_errors(
# Checking that history did not get updated
assert len(chat.history) == 2

+@mock.patch.object(
+    target=prediction_service.PredictionServiceClient,
+    attribute="generate_content",
+    new=mock_generate_content,
+)
+@pytest.mark.parametrize(
+    "generative_models",
+    [generative_models, preview_generative_models],
+)
+def test_chat_send_message_response_candidate_blocked_error(
+    self, generative_models: generative_models
+):
+    model = generative_models.GenerativeModel("gemini-pro")
+    chat = model.start_chat()
+
+    with pytest.raises(generative_models.ResponseValidationError):
+        chat.send_message("Please block response with finish_reason=OTHER.")
+
+    # Checking that history did not get updated
+    assert not chat.history

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
vertexai/generative_models/_generative_models.py (8 changes: 6 additions & 2 deletions)
@@ -2113,7 +2113,9 @@ def _append_gapic_candidate(
f"Incorrect candidate indexes: {base_candidate.index} != {new_candidate.index}"
)

-_append_gapic_content(base_candidate.content, new_candidate.content)
+# Only merge content if it exists.
+if "content" in new_candidate:
+    _append_gapic_content(base_candidate.content, new_candidate.content)

# For these attributes, the last value wins
if new_candidate.finish_reason:
@@ -2130,7 +2132,9 @@ def _append_gapic_content(
base_content: gapic_content_types.Content,
new_content: gapic_content_types.Content,
):
-if base_content.role != new_content.role:
+# Handling empty role is a workaround for a case when service returns
+# some chunks with missing role field (e.g. when response is blocked).
+if new_content.role and base_content.role != new_content.role:
raise ValueError(
f"Content roles do not match: {base_content.role} != {new_content.role}"
)
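
For reference, a sketch of how this surfaces through the public SDK, mirroring the new tests above. It is illustrative only: it assumes an initialized environment (the project and location values are placeholders), and the "Please block response with finish_reason=OTHER." prompt only forces a blocked candidate in the mocked unit tests, not against the real service.

import vertexai
from vertexai import generative_models

# Placeholder project/location; assumed to be configured for a real run.
vertexai.init(project="my-project", location="us-central1")

model = generative_models.GenerativeModel("gemini-pro")
chat = model.start_chat()

# Streaming now aggregates cleanly even if a chunk arrives with no content role.
for chunk in chat.send_message("Why is sky blue?", stream=True):
    if chunk.candidates:
        print(chunk.candidates[0].index)

# A candidate blocked with finish_reason=OTHER raises ResponseValidationError,
# and the failed exchange is not appended to chat.history.
try:
    chat.send_message("Please block response with finish_reason=OTHER.")
except generative_models.ResponseValidationError:
    print(len(chat.history))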