From f0086dfd76c138443e50bc18ae49b232905468f3 Mon Sep 17 00:00:00 2001 From: Jaycee Li Date: Mon, 18 Mar 2024 17:00:11 -0700 Subject: [PATCH] fix: GenAI - Capture content blocked case when validating responses PiperOrigin-RevId: 616988650 --- tests/unit/vertexai/test_generative_models.py | 39 +++++++++++++++++++ .../generative_models/_generative_models.py | 24 ++++++++---- 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/tests/unit/vertexai/test_generative_models.py b/tests/unit/vertexai/test_generative_models.py index c18e42d126..ee32763fa7 100644 --- a/tests/unit/vertexai/test_generative_models.py +++ b/tests/unit/vertexai/test_generative_models.py @@ -133,6 +133,20 @@ def mock_generate_content( ) return response + should_block = ( + last_message_part.text + and "Please block with block_reason=OTHER" in last_message_part.text + ) + if should_block: + response = gapic_prediction_service_types.GenerateContentResponse( + candidates=[], + prompt_feedback=gapic_prediction_service_types.GenerateContentResponse.PromptFeedback( + block_reason=gapic_prediction_service_types.GenerateContentResponse.PromptFeedback.BlockedReason.OTHER, + block_reason_message="Blocked for testing", + ), + ) + return response + is_continued_chat = len(request.contents) > 1 has_retrieval = any( tool.retrieval or tool.google_search_retrieval for tool in request.tools @@ -349,6 +363,31 @@ def test_chat_send_message_response_validation_errors( # Checking that history did not get updated assert len(chat.history) == 2 + @mock.patch.object( + target=prediction_service.PredictionServiceClient, + attribute="generate_content", + new=mock_generate_content, + ) + @pytest.mark.parametrize( + "generative_models", + [generative_models, preview_generative_models], + ) + def test_chat_send_message_response_blocked_errors( + self, generative_models: generative_models + ): + model = generative_models.GenerativeModel("gemini-pro") + chat = model.start_chat() + response1 = chat.send_message("Why is sky blue?") + assert response1.text + assert len(chat.history) == 2 + + with pytest.raises(generative_models.ResponseValidationError) as e: + chat.send_message("Please block with block_reason=OTHER.") + + assert e.match("Blocked for testing") + # Checking that history did not get updated + assert len(chat.history) == 2 + @mock.patch.object( target=prediction_service.PredictionServiceClient, attribute="generate_content", diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index bf9dfaec0f..9389152327 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -640,13 +640,23 @@ def _validate_response( request_contents: Optional[List["Content"]] = None, response_chunks: Optional[List["GenerationResponse"]] = None, ) -> None: - candidate = response.candidates[0] - if candidate.finish_reason not in _SUCCESSFUL_FINISH_REASONS: - message = ( - "The model response did not completed successfully.\n" - f"Finish reason: {candidate.finish_reason}.\n" - f"Finish message: {candidate.finish_message}.\n" - f"Safety ratings: {candidate.safety_ratings}.\n" + message = "" + if not response.candidates: + message += ( + f"The model response was blocked due to {response._raw_response.prompt_feedback.block_reason}.\n" + f"Blocke reason message: {response._raw_response.prompt_feedback.block_reason_message}.\n" + ) + else: + candidate = response.candidates[0] + if candidate.finish_reason not in _SUCCESSFUL_FINISH_REASONS: + message = ( + "The model response did not completed successfully.\n" + f"Finish reason: {candidate.finish_reason}.\n" + f"Finish message: {candidate.finish_message}.\n" + f"Safety ratings: {candidate.safety_ratings}.\n" + ) + if message: + message += ( "To protect the integrity of the chat session, the request and response were not added to chat history.\n" "To skip the response validation, specify `model.start_chat(response_validation=False)`.\n" "Note that letting blocked or otherwise incomplete responses into chat history might lead to future interactions being blocked by the service."