Skip to content

Commit

Permalink
fix: GenAI - Capture content blocked case when validating responses
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 616988650
  • Loading branch information
jaycee-li authored and Copybara-Service committed Mar 19, 2024
1 parent 2dc7f41 commit f0086df
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 7 deletions.
39 changes: 39 additions & 0 deletions tests/unit/vertexai/test_generative_models.py
Expand Up @@ -133,6 +133,20 @@ def mock_generate_content(
)
return response

should_block = (
last_message_part.text
and "Please block with block_reason=OTHER" in last_message_part.text
)
if should_block:
response = gapic_prediction_service_types.GenerateContentResponse(
candidates=[],
prompt_feedback=gapic_prediction_service_types.GenerateContentResponse.PromptFeedback(
block_reason=gapic_prediction_service_types.GenerateContentResponse.PromptFeedback.BlockedReason.OTHER,
block_reason_message="Blocked for testing",
),
)
return response

is_continued_chat = len(request.contents) > 1
has_retrieval = any(
tool.retrieval or tool.google_search_retrieval for tool in request.tools
Expand Down Expand Up @@ -349,6 +363,31 @@ def test_chat_send_message_response_validation_errors(
# Checking that history did not get updated
assert len(chat.history) == 2

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
)
def test_chat_send_message_response_blocked_errors(
self, generative_models: generative_models
):
model = generative_models.GenerativeModel("gemini-pro")
chat = model.start_chat()
response1 = chat.send_message("Why is sky blue?")
assert response1.text
assert len(chat.history) == 2

with pytest.raises(generative_models.ResponseValidationError) as e:
chat.send_message("Please block with block_reason=OTHER.")

assert e.match("Blocked for testing")
# Checking that history did not get updated
assert len(chat.history) == 2

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
Expand Down
24 changes: 17 additions & 7 deletions vertexai/generative_models/_generative_models.py
Expand Up @@ -640,13 +640,23 @@ def _validate_response(
request_contents: Optional[List["Content"]] = None,
response_chunks: Optional[List["GenerationResponse"]] = None,
) -> None:
candidate = response.candidates[0]
if candidate.finish_reason not in _SUCCESSFUL_FINISH_REASONS:
message = (
"The model response did not completed successfully.\n"
f"Finish reason: {candidate.finish_reason}.\n"
f"Finish message: {candidate.finish_message}.\n"
f"Safety ratings: {candidate.safety_ratings}.\n"
message = ""
if not response.candidates:
message += (
f"The model response was blocked due to {response._raw_response.prompt_feedback.block_reason}.\n"
f"Blocke reason message: {response._raw_response.prompt_feedback.block_reason_message}.\n"
)
else:
candidate = response.candidates[0]
if candidate.finish_reason not in _SUCCESSFUL_FINISH_REASONS:
message = (
"The model response did not completed successfully.\n"
f"Finish reason: {candidate.finish_reason}.\n"
f"Finish message: {candidate.finish_message}.\n"
f"Safety ratings: {candidate.safety_ratings}.\n"
)
if message:
message += (
"To protect the integrity of the chat session, the request and response were not added to chat history.\n"
"To skip the response validation, specify `model.start_chat(response_validation=False)`.\n"
"Note that letting blocked or otherwise incomplete responses into chat history might lead to future interactions being blocked by the service."
Expand Down

0 comments on commit f0086df

Please sign in to comment.