diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index c9c6ac6500..fb9a3a5163 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -159,12 +159,7 @@ async def send_realtime(self, input: RealtimeInput): else: raise ValueError('Unsupported input type: %s' % type(input)) - def __build_full_text_response( - self, - text: str, - is_thought: bool = False, - grounding_metadata: types.GroundingMetadata | None = None, - ): + def __build_full_text_response(self, text: str): """Builds a full text response. The text should not be partial and the returned LlmResponse is not @@ -172,8 +167,6 @@ def __build_full_text_response( Args: text: The text to be included in the response. - is_thought: Whether the text is a thought. - grounding_metadata: The grounding metadata to include. Returns: An LlmResponse containing the full text. @@ -183,8 +176,6 @@ def __build_full_text_response( role='model', parts=[types.Part.from_text(text=text)], ), - grounding_metadata=grounding_metadata, - partial=False, live_session_id=self._gemini_session.session_id, ) @@ -197,7 +188,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: text = '' tool_call_parts = [] - pending_grounding_metadata = None async with Aclosing(self._gemini_session.receive()) as agen: # TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate # partial content and emit responses as needed. @@ -213,10 +203,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: ) if message.server_content: content = message.server_content.model_turn - if message.server_content.grounding_metadata: - pending_grounding_metadata = ( - message.server_content.grounding_metadata - ) # Standalone grounding_metadata event (when content is empty) if ( @@ -229,9 +215,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: interrupted=message.server_content.interrupted, model_version=self._model_version, live_session_id=live_session_id, - turn_complete_reason=getattr( - message.server_content, 'turn_complete_reason', None - ), ) if content and content.parts: @@ -240,9 +223,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: interrupted=message.server_content.interrupted, model_version=self._model_version, live_session_id=live_session_id, - turn_complete_reason=getattr( - message.server_content, 'turn_complete_reason', None - ), ) # grounding_metadata is yielded again at turn_complete, # so avoid duplicating it here if turn_complete is true. @@ -344,14 +324,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: ) self._output_transcription_text = '' if message.server_content.turn_complete: - g_metadata_to_yield = pending_grounding_metadata if text: - yield self.__build_full_text_response( - text, is_thought, g_metadata_to_yield - ) + yield self.__build_full_text_response(text) text = '' - is_thought = False - g_metadata_to_yield = None if tool_call_parts: logger.debug('Returning aggregated tool_call_parts') yield LlmResponse( @@ -363,13 +338,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: yield LlmResponse( turn_complete=True, interrupted=message.server_content.interrupted, - grounding_metadata=message.server_content.grounding_metadata - or g_metadata_to_yield, + grounding_metadata=message.server_content.grounding_metadata, model_version=self._model_version, live_session_id=live_session_id, - turn_complete_reason=getattr( - message.server_content, 'turn_complete_reason', None - ), ) break # in case of empty content or parts, we still surface it diff --git a/src/google/adk/models/llm_response.py b/src/google/adk/models/llm_response.py index 333034565f..c921f197c3 100644 --- a/src/google/adk/models/llm_response.py +++ b/src/google/adk/models/llm_response.py @@ -81,12 +81,6 @@ class LlmResponse(BaseModel): Only used for streaming mode. """ - turn_complete_reason: Optional[types.TurnCompleteReason] = None - """The reason why the turn is complete. - - Only used for streaming mode. - """ - finish_reason: Optional[types.FinishReason] = None """The finish reason of the response.""" diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index 6d28c7a0df..58aace30ed 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -1262,200 +1262,3 @@ async def mock_receive_generator(): content_response = next((r for r in responses if r.content), None) assert content_response is not None assert content_response.content == mock_content - - -@pytest.mark.asyncio -async def test_receive_grounding_metadata_pending( - gemini_connection, mock_gemini_session -): - """Test that grounding metadata in partial chunks is pending and yielded on full text.""" - grounding_metadata = types.GroundingMetadata( - web_search_queries=['stock price of google'], - ) - - def make_msg(text=None, g_meta=None, tc=False): - msg = mock.Mock( - usage_metadata=None, - tool_call=None, - session_resumption_update=None, - go_away=None, - ) - msg.server_content = mock.Mock( - interrupted=False, - input_transcription=None, - output_transcription=None, - generation_complete=False, - turn_complete=tc, - grounding_metadata=g_meta, - model_turn=types.Content( - role='model', parts=[types.Part.from_text(text=text)] - ) - if text - else None, - ) - return msg - - msg1 = make_msg(text='hello', g_meta=grounding_metadata) - msg2 = make_msg(text=' world') - msg3 = make_msg(tc=True) - - async def gen(): - yield msg1 - yield msg2 - yield msg3 - - mock_gemini_session.receive = mock.Mock(return_value=gen()) - - responses = [resp async for resp in gemini_connection.receive()] - - # Expected responses: - # 1. Msg 1 partial (hello) with grounding_metadata - # 2. Msg 2 partial ( world) without grounding_metadata - # 3. Full text response (hello world) with PENDING grounding_metadata - # 4. Turn complete response without grounding_metadata (already cleared) - assert len(responses) == 4 - - assert responses[0].content.parts[0].text == 'hello' - assert responses[0].partial is True - assert responses[0].grounding_metadata == grounding_metadata - - assert responses[1].content.parts[0].text == ' world' - assert responses[1].partial is True - assert responses[1].grounding_metadata is None - - assert responses[2].content.parts[0].text == 'hello world' - assert responses[2].partial is False - assert responses[2].grounding_metadata == grounding_metadata - - assert responses[3].turn_complete is True - assert responses[3].grounding_metadata is None - - -@pytest.mark.asyncio -async def test_receive_populates_turn_complete_reason( - gemini_connection, mock_gemini_session -): - """Test that receive populates turn_complete_reason in LlmResponse.""" - mock_server_content = mock.create_autospec( - types.LiveServerContent, instance=True - ) - mock_server_content.model_turn = None - mock_server_content.grounding_metadata = None - mock_server_content.turn_complete = True - mock_server_content.interrupted = False - mock_server_content.input_transcription = None - mock_server_content.output_transcription = None - mock_server_content.generation_complete = False - mock_server_content.turn_complete_reason = ( - types.TurnCompleteReason.RESPONSE_REJECTED - ) - - mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) - mock_message.usage_metadata = None - mock_message.server_content = mock_server_content - mock_message.tool_call = None - mock_message.session_resumption_update = None - mock_message.go_away = None - - async def mock_receive_generator(): - yield mock_message - - mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) - - responses = [resp async for resp in gemini_connection.receive()] - - assert len(responses) == 1 - assert responses[0].turn_complete is True - assert ( - responses[0].turn_complete_reason - == types.TurnCompleteReason.RESPONSE_REJECTED - ) - - -@pytest.mark.asyncio -async def test_receive_populates_turn_complete_reason_standalone_grounding( - gemini_connection, mock_gemini_session -): - """Test that receive populates turn_complete_reason in LlmResponse for standalone grounding metadata.""" - mock_server_content = mock.create_autospec( - types.LiveServerContent, instance=True - ) - mock_server_content.model_turn = None - mock_server_content.grounding_metadata = mock.create_autospec( - types.GroundingMetadata, instance=True - ) - mock_server_content.turn_complete = False - mock_server_content.interrupted = False - mock_server_content.input_transcription = None - mock_server_content.output_transcription = None - mock_server_content.generation_complete = False - mock_server_content.turn_complete_reason = ( - types.TurnCompleteReason.RESPONSE_REJECTED - ) - - mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) - mock_message.usage_metadata = None - mock_message.server_content = mock_server_content - mock_message.tool_call = None - mock_message.session_resumption_update = None - mock_message.go_away = None - - async def mock_receive_generator(): - yield mock_message - - mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) - - responses = [resp async for resp in gemini_connection.receive()] - - assert len(responses) == 1 - assert responses[0].grounding_metadata is not None - assert responses[0].turn_complete is None - assert ( - responses[0].turn_complete_reason - == types.TurnCompleteReason.RESPONSE_REJECTED - ) - - -@pytest.mark.asyncio -async def test_receive_populates_turn_complete_reason_with_content( - gemini_connection, mock_gemini_session -): - """Test that receive populates turn_complete_reason in LlmResponse when model turn has content parts.""" - mock_content = types.Content( - role='model', - parts=[types.Part.from_text(text='hello')], - ) - mock_server_content = mock.create_autospec( - types.LiveServerContent, instance=True - ) - mock_server_content.model_turn = mock_content - mock_server_content.grounding_metadata = None - mock_server_content.turn_complete = False - mock_server_content.interrupted = False - mock_server_content.input_transcription = None - mock_server_content.output_transcription = None - mock_server_content.generation_complete = False - mock_server_content.turn_complete_reason = ( - types.TurnCompleteReason.RESPONSE_REJECTED - ) - - mock_message = mock.create_autospec(types.LiveServerMessage, instance=True) - mock_message.usage_metadata = None - mock_message.server_content = mock_server_content - mock_message.tool_call = None - mock_message.session_resumption_update = None - mock_message.go_away = None - - async def mock_receive_generator(): - yield mock_message - - mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator()) - - responses = [resp async for resp in gemini_connection.receive()] - - assert len(responses) == 1 - assert responses[0].content == mock_content - assert ( - responses[0].turn_complete_reason - == types.TurnCompleteReason.RESPONSE_REJECTED - )