Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 3 additions & 32 deletions src/google/adk/models/gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,21 +159,14 @@ async def send_realtime(self, input: RealtimeInput):
else:
raise ValueError('Unsupported input type: %s' % type(input))

def __build_full_text_response(
self,
text: str,
is_thought: bool = False,
grounding_metadata: types.GroundingMetadata | None = None,
):
def __build_full_text_response(self, text: str):
"""Builds a full text response.

The text should not be partial and the returned LlmResponse is not
partial.

Args:
text: The text to be included in the response.
is_thought: Whether the text is a thought.
grounding_metadata: The grounding metadata to include.

Returns:
An LlmResponse containing the full text.
Expand All @@ -183,8 +176,6 @@ def __build_full_text_response(
role='model',
parts=[types.Part.from_text(text=text)],
),
grounding_metadata=grounding_metadata,
partial=False,
live_session_id=self._gemini_session.session_id,
)

Expand All @@ -197,7 +188,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:

text = ''
tool_call_parts = []
pending_grounding_metadata = None
async with Aclosing(self._gemini_session.receive()) as agen:
# TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate
# partial content and emit responses as needed.
Expand All @@ -213,10 +203,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
)
if message.server_content:
content = message.server_content.model_turn
if message.server_content.grounding_metadata:
pending_grounding_metadata = (
message.server_content.grounding_metadata
)

# Standalone grounding_metadata event (when content is empty)
if (
Expand All @@ -229,9 +215,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
interrupted=message.server_content.interrupted,
model_version=self._model_version,
live_session_id=live_session_id,
turn_complete_reason=getattr(
message.server_content, 'turn_complete_reason', None
),
)

if content and content.parts:
Expand All @@ -240,9 +223,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
interrupted=message.server_content.interrupted,
model_version=self._model_version,
live_session_id=live_session_id,
turn_complete_reason=getattr(
message.server_content, 'turn_complete_reason', None
),
)
# grounding_metadata is yielded again at turn_complete,
# so avoid duplicating it here if turn_complete is true.
Expand Down Expand Up @@ -344,14 +324,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
)
self._output_transcription_text = ''
if message.server_content.turn_complete:
g_metadata_to_yield = pending_grounding_metadata
if text:
yield self.__build_full_text_response(
text, is_thought, g_metadata_to_yield
)
yield self.__build_full_text_response(text)
text = ''
is_thought = False
g_metadata_to_yield = None
if tool_call_parts:
logger.debug('Returning aggregated tool_call_parts')
yield LlmResponse(
Expand All @@ -363,13 +338,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
yield LlmResponse(
turn_complete=True,
interrupted=message.server_content.interrupted,
grounding_metadata=message.server_content.grounding_metadata
or g_metadata_to_yield,
grounding_metadata=message.server_content.grounding_metadata,
model_version=self._model_version,
live_session_id=live_session_id,
turn_complete_reason=getattr(
message.server_content, 'turn_complete_reason', None
),
)
break
# in case of empty content or parts, we still surface it
Expand Down
6 changes: 0 additions & 6 deletions src/google/adk/models/llm_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,6 @@ class LlmResponse(BaseModel):
Only used for streaming mode.
"""

turn_complete_reason: Optional[types.TurnCompleteReason] = None
"""The reason why the turn is complete.

Only used for streaming mode.
"""

finish_reason: Optional[types.FinishReason] = None
"""The finish reason of the response."""

Expand Down
197 changes: 0 additions & 197 deletions tests/unittests/models/test_gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,200 +1262,3 @@ async def mock_receive_generator():
content_response = next((r for r in responses if r.content), None)
assert content_response is not None
assert content_response.content == mock_content


@pytest.mark.asyncio
async def test_receive_grounding_metadata_pending(
gemini_connection, mock_gemini_session
):
"""Test that grounding metadata in partial chunks is pending and yielded on full text."""
grounding_metadata = types.GroundingMetadata(
web_search_queries=['stock price of google'],
)

def make_msg(text=None, g_meta=None, tc=False):
msg = mock.Mock(
usage_metadata=None,
tool_call=None,
session_resumption_update=None,
go_away=None,
)
msg.server_content = mock.Mock(
interrupted=False,
input_transcription=None,
output_transcription=None,
generation_complete=False,
turn_complete=tc,
grounding_metadata=g_meta,
model_turn=types.Content(
role='model', parts=[types.Part.from_text(text=text)]
)
if text
else None,
)
return msg

msg1 = make_msg(text='hello', g_meta=grounding_metadata)
msg2 = make_msg(text=' world')
msg3 = make_msg(tc=True)

async def gen():
yield msg1
yield msg2
yield msg3

mock_gemini_session.receive = mock.Mock(return_value=gen())

responses = [resp async for resp in gemini_connection.receive()]

# Expected responses:
# 1. Msg 1 partial (hello) with grounding_metadata
# 2. Msg 2 partial ( world) without grounding_metadata
# 3. Full text response (hello world) with PENDING grounding_metadata
# 4. Turn complete response without grounding_metadata (already cleared)
assert len(responses) == 4

assert responses[0].content.parts[0].text == 'hello'
assert responses[0].partial is True
assert responses[0].grounding_metadata == grounding_metadata

assert responses[1].content.parts[0].text == ' world'
assert responses[1].partial is True
assert responses[1].grounding_metadata is None

assert responses[2].content.parts[0].text == 'hello world'
assert responses[2].partial is False
assert responses[2].grounding_metadata == grounding_metadata

assert responses[3].turn_complete is True
assert responses[3].grounding_metadata is None


@pytest.mark.asyncio
async def test_receive_populates_turn_complete_reason(
gemini_connection, mock_gemini_session
):
"""Test that receive populates turn_complete_reason in LlmResponse."""
mock_server_content = mock.create_autospec(
types.LiveServerContent, instance=True
)
mock_server_content.model_turn = None
mock_server_content.grounding_metadata = None
mock_server_content.turn_complete = True
mock_server_content.interrupted = False
mock_server_content.input_transcription = None
mock_server_content.output_transcription = None
mock_server_content.generation_complete = False
mock_server_content.turn_complete_reason = (
types.TurnCompleteReason.RESPONSE_REJECTED
)

mock_message = mock.create_autospec(types.LiveServerMessage, instance=True)
mock_message.usage_metadata = None
mock_message.server_content = mock_server_content
mock_message.tool_call = None
mock_message.session_resumption_update = None
mock_message.go_away = None

async def mock_receive_generator():
yield mock_message

mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator())

responses = [resp async for resp in gemini_connection.receive()]

assert len(responses) == 1
assert responses[0].turn_complete is True
assert (
responses[0].turn_complete_reason
== types.TurnCompleteReason.RESPONSE_REJECTED
)


@pytest.mark.asyncio
async def test_receive_populates_turn_complete_reason_standalone_grounding(
gemini_connection, mock_gemini_session
):
"""Test that receive populates turn_complete_reason in LlmResponse for standalone grounding metadata."""
mock_server_content = mock.create_autospec(
types.LiveServerContent, instance=True
)
mock_server_content.model_turn = None
mock_server_content.grounding_metadata = mock.create_autospec(
types.GroundingMetadata, instance=True
)
mock_server_content.turn_complete = False
mock_server_content.interrupted = False
mock_server_content.input_transcription = None
mock_server_content.output_transcription = None
mock_server_content.generation_complete = False
mock_server_content.turn_complete_reason = (
types.TurnCompleteReason.RESPONSE_REJECTED
)

mock_message = mock.create_autospec(types.LiveServerMessage, instance=True)
mock_message.usage_metadata = None
mock_message.server_content = mock_server_content
mock_message.tool_call = None
mock_message.session_resumption_update = None
mock_message.go_away = None

async def mock_receive_generator():
yield mock_message

mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator())

responses = [resp async for resp in gemini_connection.receive()]

assert len(responses) == 1
assert responses[0].grounding_metadata is not None
assert responses[0].turn_complete is None
assert (
responses[0].turn_complete_reason
== types.TurnCompleteReason.RESPONSE_REJECTED
)


@pytest.mark.asyncio
async def test_receive_populates_turn_complete_reason_with_content(
gemini_connection, mock_gemini_session
):
"""Test that receive populates turn_complete_reason in LlmResponse when model turn has content parts."""
mock_content = types.Content(
role='model',
parts=[types.Part.from_text(text='hello')],
)
mock_server_content = mock.create_autospec(
types.LiveServerContent, instance=True
)
mock_server_content.model_turn = mock_content
mock_server_content.grounding_metadata = None
mock_server_content.turn_complete = False
mock_server_content.interrupted = False
mock_server_content.input_transcription = None
mock_server_content.output_transcription = None
mock_server_content.generation_complete = False
mock_server_content.turn_complete_reason = (
types.TurnCompleteReason.RESPONSE_REJECTED
)

mock_message = mock.create_autospec(types.LiveServerMessage, instance=True)
mock_message.usage_metadata = None
mock_message.server_content = mock_server_content
mock_message.tool_call = None
mock_message.session_resumption_update = None
mock_message.go_away = None

async def mock_receive_generator():
yield mock_message

mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator())

responses = [resp async for resp in gemini_connection.receive()]

assert len(responses) == 1
assert responses[0].content == mock_content
assert (
responses[0].turn_complete_reason
== types.TurnCompleteReason.RESPONSE_REJECTED
)