diff --git a/src/google/adk/agents/live_request_queue.py b/src/google/adk/agents/live_request_queue.py index 9b698c81d6..cf12b30f4e 100644 --- a/src/google/adk/agents/live_request_queue.py +++ b/src/google/adk/agents/live_request_queue.py @@ -62,8 +62,15 @@ class LiveRequestQueue: def __init__(self): self._queue = asyncio.Queue() + self._closed = False + + @property + def is_closed(self) -> bool: + """Returns True if close() has been called on this queue.""" + return self._closed def close(self): + self._closed = True self._queue.put_nowait(LiveRequest(close=True)) def send_content(self, content: types.Content): diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 4c253014a9..e2cb345a60 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -617,6 +617,14 @@ async def run_live( except asyncio.CancelledError: pass except (ConnectionClosed, ConnectionClosedOK) as e: + # An intentional close via LiveRequestQueue.close() may surface as a + # ConnectionClosed event. Do not reconnect in that case. + if invocation_context.live_request_queue.is_closed: + logger.info( + 'Live session for agent %s closed by client request.', + invocation_context.agent.name, + ) + return # If we have a session resumption handle, we attempt to reconnect. # This handle is updated dynamically during the session. if invocation_context.live_session_resumption_handle: @@ -630,6 +638,15 @@ async def run_live( logger.error('Connection closed: %s.', e) raise except errors.APIError as e: + # Error code 1000 indicates a normal (intentional) closure. If the + # client called LiveRequestQueue.close(), do not treat this as an error + # and do not attempt to reconnect regardless of session handle state. + if e.code == 1000 and invocation_context.live_request_queue.is_closed: + logger.info( + 'Live session for agent %s closed by client request.', + invocation_context.agent.name, + ) + return # Error code 1000 and 1006 indicates a recoverable connection drop. # In that case, we attempt to reconnect with session handle if available. if e.code in [1000, 1006]: @@ -649,6 +666,15 @@ async def run_live( ) raise + # If the client explicitly closed the queue and no exception was raised + # (e.g. the receive generator returned normally), do not reconnect. + if invocation_context.live_request_queue.is_closed: + logger.info( + 'Live session for agent %s closed by client request.', + invocation_context.agent.name, + ) + return + async def _send_to_model( self, llm_connection: BaseLlmConnection, diff --git a/tests/unittests/agents/test_live_request_queue.py b/tests/unittests/agents/test_live_request_queue.py index ab98894daf..1a17c5143e 100644 --- a/tests/unittests/agents/test_live_request_queue.py +++ b/tests/unittests/agents/test_live_request_queue.py @@ -17,6 +17,24 @@ async def test_close_queue(): mock_put_nowait.assert_called_once_with(LiveRequest(close=True)) +def test_is_closed_initially_false(): + queue = LiveRequestQueue() + assert queue.is_closed is False + + +def test_is_closed_true_after_close(): + queue = LiveRequestQueue() + queue.close() + assert queue.is_closed is True + + +def test_is_closed_not_affected_by_other_sends(): + queue = LiveRequestQueue() + queue.send_content(MagicMock(spec=types.Content)) + queue.send_realtime(MagicMock(spec=types.Blob)) + assert queue.is_closed is False + + def test_send_content(): queue = LiveRequestQueue() content = MagicMock(spec=types.Content) diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index 793ebb83cd..be76e63dd2 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -895,6 +895,168 @@ async def mock_receive(): assert mock_connect.call_count == DEFAULT_MAX_RECONNECT_ATTEMPTS + 2 +@pytest.mark.asyncio +async def test_run_live_no_reconnect_after_queue_close_api_error_1000(): + """Test that run_live does not reconnect after LiveRequestQueue.close() (APIError 1000). + + Calling LiveRequestQueue.close() signals an intentional client-side shutdown. + When the resulting APIError(1000) arrives, run_live must terminate instead of + reconnecting — even when a session resumption handle is present. + """ + from google.adk.agents.live_request_queue import LiveRequestQueue + from google.genai.errors import APIError + + real_model = Gemini() + mock_connection = mock.AsyncMock() + + async def mock_receive(): + # Simulate receiving a session resumption handle from the server. + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ) + ) + # Simulate the normal-close APIError that arrives after llm_connection.close(). + raise APIError(1000, {}) + + mock_connection.receive = mock.Mock(side_effect=mock_receive) + + agent = Agent(name='test_agent', model=real_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + invocation_context.live_request_queue = LiveRequestQueue() + # Simulate what live_request_queue.close() does before the error arrives. + invocation_context.live_request_queue.close() + + flow = BaseLlmFlowForTesting() + + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__.return_value = mock_connection + + events = [] + async for event in flow.run_live(invocation_context): + events.append(event) + + # run_live must terminate after the first connection — no reconnect. + assert mock_connect.call_count == 1 + + +@pytest.mark.asyncio +async def test_run_live_no_reconnect_after_queue_close_connection_closed(): + """Test that run_live does not reconnect after LiveRequestQueue.close() (ConnectionClosed). + + Same as the APIError(1000) case but the connection surfaces as ConnectionClosed, + which can happen depending on the websockets library version or transport layer. + """ + from google.adk.agents.live_request_queue import LiveRequestQueue + from websockets.exceptions import ConnectionClosed + + real_model = Gemini() + mock_connection = mock.AsyncMock() + + async def mock_receive(): + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ) + ) + raise ConnectionClosed(None, None) + + mock_connection.receive = mock.Mock(side_effect=mock_receive) + + agent = Agent(name='test_agent', model=real_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + invocation_context.live_request_queue = LiveRequestQueue() + invocation_context.live_request_queue.close() + + flow = BaseLlmFlowForTesting() + + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__.return_value = mock_connection + + events = [] + async for event in flow.run_live(invocation_context): + events.append(event) + + # run_live must terminate after the first connection — no reconnect. + assert mock_connect.call_count == 1 + + +@pytest.mark.asyncio +async def test_run_live_still_reconnects_on_unintentional_drop_with_handle(): + """Test that session-resumption reconnection still works for genuine drops. + + A genuine network drop (ConnectionClosed without queue.close()) with a session + resumption handle must still trigger reconnection. The queue.close() fix + must not break this existing behaviour. + """ + from google.adk.agents.live_request_queue import LiveRequestQueue + from websockets.exceptions import ConnectionClosed + + real_model = Gemini() + mock_connection = mock.AsyncMock() + + async def mock_receive(): + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ) + ) + # Genuine network drop (queue was NOT closed). + raise ConnectionClosed(None, None) + + mock_connection.receive = mock.Mock(side_effect=mock_receive) + + agent = Agent(name='test_agent', model=real_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + invocation_context.live_request_queue = LiveRequestQueue() + # Note: queue.close() is NOT called — this is an unintentional drop. + + flow = BaseLlmFlowForTesting() + + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + mock_connection_2 = mock.AsyncMock() + + class NonRetryableError(Exception): + pass + + async def mock_receive_2(): + if False: + yield + raise NonRetryableError('stop') + + mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2) + + mock_aenter = mock.AsyncMock() + mock_aenter.side_effect = [mock_connection, mock_connection_2] + + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__ = mock_aenter + + try: + async for _ in flow.run_live(invocation_context): + pass + except NonRetryableError: + pass + + # Reconnection must have been attempted (2 connections). + assert mock_connect.call_count == 2 + assert invocation_context.live_session_resumption_handle == 'test_handle' + + @pytest.mark.asyncio async def test_postprocess_live_session_resumption_update(): """Test that _postprocess_live yields live_session_resumption_update."""