Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,6 @@ async def run_live(
llm_request.live_connect_config.session_resumption.handle = (
invocation_context.live_session_resumption_handle
)
llm_request.live_connect_config.session_resumption.transparent = True

logger.info(
'Establishing live connection for agent: %s',
Expand Down
44 changes: 44 additions & 0 deletions tests/unittests/flows/llm_flows/test_base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,50 @@ async def mock_receive():
mock_connection.send_history.assert_not_called()


@pytest.mark.asyncio
async def test_run_live_resumption_preserves_transparent_setting():
"""Test that reconnect does not force transparent resumption."""
from google.adk.agents.live_request_queue import LiveRequestQueue

real_model = Gemini()
mock_connection = mock.AsyncMock()

async def mock_receive():
yield LlmResponse(
content=types.Content(parts=[types.Part.from_text(text='hi')])
)
raise RuntimeError('stop')

mock_connection.receive = mock.Mock(side_effect=mock_receive)

agent = Agent(name='test_agent', model=real_model)
run_config = RunConfig(session_resumption=types.SessionResumptionConfig())
invocation_context = await testing_utils.create_invocation_context(
agent=agent, run_config=run_config
)
invocation_context.live_session_resumption_handle = 'test_handle'
invocation_context.live_request_queue = LiveRequestQueue()

flow = BaseLlmFlowForTesting()

with mock.patch.object(
flow, '_send_to_model', new_callable=AsyncMock
) as mock_send:
with mock.patch(
'google.adk.models.google_llm.Gemini.connect'
) as mock_connect:
mock_connect.return_value.__aenter__.return_value = mock_connection

with pytest.raises(RuntimeError, match='stop'):
async for _ in flow.run_live(invocation_context):
pass

llm_request = mock_connect.call_args.args[0]
session_resumption = llm_request.live_connect_config.session_resumption
assert session_resumption.handle == 'test_handle'
assert session_resumption.transparent is None


@pytest.mark.asyncio
async def test_live_session_resumption_go_away():
"""Test that go_away triggers reconnection."""
Expand Down