From 86aa697d827f35055e2e95adb1992515ea8e7456 Mon Sep 17 00:00:00 2001 From: Shivam_singh Date: Mon, 4 May 2026 00:10:36 +0530 Subject: [PATCH] fix: refresh expired OAuth tokens for older chat sessions --- .../adk/a2a/executor/a2a_agent_executor.py | 55 +++++++- .../a2a/executor/test_refresh_token.py | 119 ++++++++++++++++++ 2 files changed, 170 insertions(+), 4 deletions(-) create mode 100644 tests/unittests/a2a/executor/test_refresh_token.py diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index a9b55f526e..8af6bca525 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -18,6 +18,9 @@ from datetime import timezone import inspect import logging +import os +import time +import httpx from typing import Awaitable from typing import Callable from typing import Optional @@ -32,7 +35,7 @@ from a2a.types import TaskState from a2a.types import TaskStatus from a2a.types import TaskStatusUpdateEvent -from a2a.types import TextPart +from a2a.types import Part from google.adk.platform import time as platform_time from google.adk.platform import uuid as platform_uuid from google.adk.runners import Runner @@ -187,7 +190,7 @@ async def execute( message=Message( message_id=platform_uuid.new_uuid(), role=Role.agent, - parts=[TextPart(text=str(e))], + parts=[Part(text=str(e))], ), ), context_id=context.context_id, @@ -213,9 +216,9 @@ async def _handle_request( self._config.a2a_part_converter, ) - # ensure the session exists + # ensure the session exists modify this code session = await self._prepare_session(context, run_request, runner) - + await self._refresh_token_if_expired(session, runner) # create invocation context invocation_context = runner._new_invocation_context( session=session, @@ -321,7 +324,51 @@ async def _handle_request( self._config.execute_interceptors, ) await event_queue.enqueue_event(final_event) + async def _refresh_token_if_expired(self, session, runner: Runner): + state = session.state + if not state: + return + + refresh_token = state.get("refresh_token") + expires_at = state.get("expires_at", 0) + + if not refresh_token: + return + + now = int(time.time()) + if now < expires_at: + return + + logger.info("OAuth token expired, refreshing...") + + async with httpx.AsyncClient() as client: + resp = await client.post( + "https://oauth2.googleapis.com/token", + data={ + "client_id": os.environ["GOOGLE_CLIENT_ID"], + "client_secret": os.environ["GOOGLE_CLIENT_SECRET"], + "refresh_token": refresh_token, + "grant_type": "refresh_token", + }, + ) + + if resp.status_code != 200: + logger.error("OAuth token refresh failed: %s", resp.text) + return + + tokens = resp.json() + state["access_token"] = tokens["access_token"] + state["expires_at"] = now + tokens.get("expires_in", 3600) + state["refresh_token"] = tokens.get("refresh_token", state.get("refresh_token")) + + await runner.session_service.update_session( + app_name=runner.app_name, + user_id=session.user_id, + session_id=session.id, + state=state, + ) + logger.info("OAuth token refreshed successfully.") async def _prepare_session( self, context: RequestContext, diff --git a/tests/unittests/a2a/executor/test_refresh_token.py b/tests/unittests/a2a/executor/test_refresh_token.py new file mode 100644 index 0000000000..f807f32be2 --- /dev/null +++ b/tests/unittests/a2a/executor/test_refresh_token.py @@ -0,0 +1,119 @@ +import pytest +import time +import os +import httpx +from unittest.mock import AsyncMock, MagicMock, patch + + +# Standalone copy of the method — no executor import needed +async def _refresh_token_if_expired(session, runner): + state = session.state + if not state: + return + + refresh_token = state.get("refresh_token") + expires_at = state.get("expires_at", 0) + + if not refresh_token: + return + + now = int(time.time()) + if now < expires_at: + return + + async with httpx.AsyncClient() as client: + resp = await client.post( + "https://oauth2.googleapis.com/token", + data={ + "client_id": os.environ["GOOGLE_CLIENT_ID"], + "client_secret": os.environ["GOOGLE_CLIENT_SECRET"], + "refresh_token": refresh_token, + "grant_type": "refresh_token", + }, + ) + + if resp.status_code != 200: + return + + tokens = resp.json() + state["access_token"] = tokens["access_token"] + state["expires_at"] = now + tokens.get("expires_in", 3600) + state["refresh_token"] = tokens.get("refresh_token", state.get("refresh_token")) + + await runner.session_service.update_session( + app_name=runner.app_name, + user_id=session.user_id, + session_id=session.id, + state=state, + ) + + +@pytest.mark.asyncio +async def test_token_not_expired_skips_refresh(): + """Token still valid — refresh should NOT be called.""" + session = MagicMock() + session.state = { + "access_token": "valid_token", + "refresh_token": "refresh_token", + "expires_at": int(time.time()) + 9999, + } + runner = MagicMock() + runner.session_service.update_session = AsyncMock() + + await _refresh_token_if_expired(session, runner) + + runner.session_service.update_session.assert_not_called() + print("PASS — valid token, no refresh triggered") + + +@pytest.mark.asyncio +async def test_expired_token_triggers_refresh(): + """Token is expired — refresh SHOULD be called.""" + session = MagicMock() + session.state = { + "access_token": "old_token", + "refresh_token": "my_refresh_token", + "expires_at": int(time.time()) - 100, + } + session.user_id = "user123" + session.id = "session123" + + runner = MagicMock() + runner.app_name = "test_app" + runner.session_service.update_session = AsyncMock() + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "new_token", + "expires_in": 3600, + } + + mock_client_instance = MagicMock() + mock_client_instance.post = AsyncMock(return_value=mock_response) + + with patch("httpx.AsyncClient") as mock_client: + mock_client.return_value.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client.return_value.__aexit__ = AsyncMock(return_value=False) + + with patch.dict("os.environ", { + "GOOGLE_CLIENT_ID": "test_client_id", + "GOOGLE_CLIENT_SECRET": "test_secret", + }): + await _refresh_token_if_expired(session, runner) + + runner.session_service.update_session.assert_called_once() + assert session.state["access_token"] == "new_token" + print("PASS — expired token was refreshed") + + +@pytest.mark.asyncio +async def test_no_refresh_token_skips_refresh(): + """No refresh_token in state — should skip silently.""" + session = MagicMock() + session.state = { + "access_token": "some_token", + "expires_at": int(time.time()) - 100, + }