diff --git a/contributing/samples/mcp_stdio_user_auth_passing_sample/README.md b/contributing/samples/mcp_stdio_user_auth_passing_sample/README.md
new file mode 100644
index 0000000000..bd89dbe0cf
--- /dev/null
+++ b/contributing/samples/mcp_stdio_user_auth_passing_sample/README.md
@@ -0,0 +1,80 @@
+# Sample: Passing User Token from Agent State to MCP via ContextToEnvMapperCallback
+
+This sample demonstrates how to use the `context_to_env_mapper_callback` feature in ADK to pass a user token from the agent's session state to an MCP process (using stdio transport). This is useful when your MCP server (built by your organization) requires the same user token for internal API calls.
+
+## How it works
+- The agent is initialized with a `MCPToolset` using `StdioServerParameters`.
+- The `context_to_env_mapper_callback` is set to a function that extracts the `user_token` from the agent's state and maps it to the `USER_TOKEN` environment variable.
+- When the agent calls the MCP, the token is injected into the MCP process environment, allowing the MCP to use it for internal authentication.
+
+## Directory Structure
+```
+contributing/samples/stdio_mcp_user_auth_passing_sample/
+├── agent.py # Basic agent setup
+├── main.py # Complete runnable example
+└── README.md
+```
+
+## How to Run
+
+### Option 1: Run the complete example
+```bash
+cd /home/sanjay-dev/Workspace/adk-python
+python -m contributing.samples.stdio_mcp_user_auth_passing_sample.main
+```
+
+### Option 2: Use the agent in your own code
+```python
+from contributing.samples.stdio_mcp_user_auth_passing_sample.agent import create_agent
+from google.adk.sessions import Session
+
+agent = create_agent()
+session = Session(
+ id="your_session_id",
+ app_name="your_app_name",
+ user_id="your_user_id"
+)
+
+# Set user token in session state
+session.state['user_token'] = 'YOUR_ACTUAL_TOKEN_HERE'
+session.state['api_endpoint'] = 'https://your-internal-api.com'
+
+# Then use the agent in your workflow...
+```
+
+## Flow Diagram
+
+```mermaid
+graph TD
+ subgraph "User Application"
+ U[User]
+ end
+
+ subgraph "Agent Process"
+ A[Agent Instance
per user-app-agentid]
+ S[Session State
user_token, api_endpoint]
+ C[ContextToEnvMapperCallback]
+ end
+
+ subgraph "MCP Process"
+ M[MCP Server
stdio transport]
+ E[Environment Variables
USER_TOKEN, API_ENDPOINT]
+ API[Internal API Calls]
+ end
+
+ U -->|Sends request| A
+ A -->|Reads state| S
+ S -->|Extracts tokens| C
+ C -->|Maps to env vars| E
+ A -->|Spawns with env| M
+ M -->|Uses env vars| API
+ API -->|Response| M
+ M -->|Tool result| A
+ A -->|Response| U
+```
+
+## Context
+- Each agent instance is initiated per user-app-agentid.
+- The agent receives a user context (with token) and calls the MCP using stdio transport.
+- The MCP, built by the same organization, uses the token for internal API calls.
+- The ADK's context-to-env mapping feature makes this seamless.
diff --git a/contributing/samples/mcp_stdio_user_auth_passing_sample/__init__.py b/contributing/samples/mcp_stdio_user_auth_passing_sample/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/contributing/samples/mcp_stdio_user_auth_passing_sample/agent.py b/contributing/samples/mcp_stdio_user_auth_passing_sample/agent.py
new file mode 100644
index 0000000000..e72ff5e138
--- /dev/null
+++ b/contributing/samples/mcp_stdio_user_auth_passing_sample/agent.py
@@ -0,0 +1,60 @@
+"""
+Sample: Using ContextToEnvMapperCallback to pass user token from agent state to MCP via stdio transport.
+"""
+
+import os
+import tempfile
+from typing import Any
+from typing import Dict
+
+from google.adk.agents.llm_agent import LlmAgent
+from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams
+from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
+from mcp import StdioServerParameters
+
+_allowed_path = os.path.dirname(os.path.abspath(__file__))
+
+
+def user_token_env_mapper(state: Dict[str, Any]) -> Dict[str, str]:
+ """Extracts USER_TOKEN from agent state and maps to MCP env."""
+ env = {}
+ if "user_token" in state:
+ env["USER_TOKEN"] = state["user_token"]
+ if "api_endpoint" in state:
+ env["API_ENDPOINT"] = state["api_endpoint"]
+
+ print(f"Environment variables being passed to MCP: {env}")
+ return env
+
+
+def create_agent() -> LlmAgent:
+ """Create the agent with context to env mapper callback."""
+ # Create a temporary directory for the filesystem server
+ temp_dir = tempfile.mkdtemp()
+
+ return LlmAgent(
+ model="gemini-2.0-flash",
+ name="user_token_agent",
+ instruction=f"""
+ You are an agent that calls an internal MCP server which requires a user token for internal API calls.
+ The user token is available in your session state and must be passed to the MCP process as an environment variable.
+ Test directory: {temp_dir}
+ """,
+ tools=[
+ MCPToolset(
+ connection_params=StdioConnectionParams(
+ server_params=StdioServerParameters(
+ command="npx",
+ args=[
+ "-y", # Arguments for the command
+ "@modelcontextprotocol/server-filesystem",
+ _allowed_path,
+ ],
+ ),
+ timeout=5,
+ ),
+ get_env_from_context_fn=user_token_env_mapper,
+ tool_filter=["read_file", "list_directory"],
+ )
+ ],
+ )
diff --git a/contributing/samples/mcp_stdio_user_auth_passing_sample/main.py b/contributing/samples/mcp_stdio_user_auth_passing_sample/main.py
new file mode 100644
index 0000000000..d0408f4809
--- /dev/null
+++ b/contributing/samples/mcp_stdio_user_auth_passing_sample/main.py
@@ -0,0 +1,95 @@
+"""
+Sample: Using ContextToEnvMapperCallback to pass user token from agent state to MCP via stdio transport.
+"""
+
+import asyncio
+
+from google.adk.agents.invocation_context import InvocationContext
+from google.adk.agents.readonly_context import ReadonlyContext
+from google.adk.sessions import InMemorySessionService
+from google.adk.sessions import Session
+
+from .agent import create_agent
+
+
+async def main():
+ """Example of how to set up and run the agent with user token."""
+ print("=== STDIO MCP User Auth Passing Sample ===")
+ print()
+
+ # Create the agent
+ agent = create_agent()
+ print(f"✓ Created agent: {agent.name}")
+
+ # Create session service and session
+ session_service = InMemorySessionService()
+ session = Session(
+ id="sample_session",
+ app_name="stdio_mcp_user_auth_passing_sample",
+ user_id="sample_user",
+ )
+ print(f"✓ Created session: {session.id}")
+
+ # Set user token in session state
+ session.state["user_token"] = "sample_user_token_123"
+ session.state["api_endpoint"] = "https://internal-api.company.com"
+ print(f"✓ Set session state with user_token: {session.state['user_token']}")
+
+ # Create invocation context
+ invocation_context = InvocationContext(
+ invocation_id="sample_invocation",
+ agent=agent,
+ session=session,
+ session_service=session_service,
+ )
+
+ # Create readonly context
+ readonly_context = ReadonlyContext(invocation_context)
+ print(f"✓ Created readonly context")
+
+ print()
+ print("=== Demonstrating User Auth Token Passing to MCP ===")
+ print(
+ "Note: This sample shows how the callback extracts environment variables."
+ )
+ print("In a real scenario, these would be passed to an actual MCP server.")
+ print()
+
+ # Access the MCP toolset to demonstrate the callback
+ mcp_toolset = agent.tools[0]
+ mcp_session_manager = mcp_toolset._mcp_session_manager
+
+ # Extract environment variables using the callback (without connecting to MCP)
+ if mcp_session_manager._context_to_env_mapper_callback:
+ print("✓ Context-to-env mapper callback is configured")
+
+ # Simulate what happens during MCP session creation
+ env_vars = mcp_session_manager._extract_env_from_context(readonly_context)
+
+ print(f"✓ Extracted environment variables:")
+ for key, value in env_vars.items():
+ print(f" {key}={value}")
+ print()
+
+ print(
+ "✓ These environment variables would be injected into the MCP process"
+ )
+ print("✓ The MCP server can then use them for internal API calls")
+ else:
+ print("✗ No context-to-env mapper callback configured")
+
+ print()
+ print("=== Sample completed successfully! ===")
+ print()
+ print("Key points demonstrated:")
+ print("1. Session state holds user tokens and configuration")
+ print(
+ "2. Context-to-env mapper callback extracts these as environment"
+ " variables"
+ )
+ print("3. Environment variables would be passed to MCP server processes")
+ print("4. MCP servers can use these for authenticated API calls")
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/pyproject.toml b/pyproject.toml
index e85bdaff5e..c49fa5dc1b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,7 +38,7 @@ dependencies = [
"google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service
"google-genai>=1.21.1", # Google GenAI SDK
"graphviz>=0.20.2", # Graphviz for graph rendering
- "mcp>=1.8.0;python_version>='3.10'", # For MCP Toolset
+ "mcp>=1.9.4;python_version>='3.10'", # For MCP Toolset
"opentelemetry-api>=1.31.0", # OpenTelemetry
"opentelemetry-exporter-gcp-trace>=1.9.0",
"opentelemetry-sdk>=1.31.0",
diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py
index 1853fb1a72..fae81c94d4 100644
--- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py
+++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py
@@ -23,6 +23,7 @@
import logging
import sys
from typing import Any
+from typing import Callable
from typing import Dict
from typing import Optional
from typing import TextIO
@@ -177,13 +178,43 @@ def __init__(
else:
self._connection_params = connection_params
self._errlog = errlog
-
# Session pool: maps session keys to (session, exit_stack) tuples
self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {}
-
# Lock to prevent race conditions in session creation
self._session_lock = asyncio.Lock()
+ def update_connection_params(
+ self,
+ new_connection_params: Union[
+ StdioServerParameters,
+ StdioConnectionParams,
+ SseConnectionParams,
+ StreamableHTTPConnectionParams,
+ ],
+ ) -> None:
+ """Updates the connection parameters and invalidates existing sessions.
+
+ Args:
+ new_connection_params: New connection parameters to use.
+ """
+ if isinstance(new_connection_params, StdioServerParameters):
+ logger.warning(
+ 'StdioServerParameters is not recommended. Please use'
+ ' StdioConnectionParams.'
+ )
+ self._connection_params = StdioConnectionParams(
+ server_params=new_connection_params,
+ timeout=5,
+ )
+ else:
+ self._connection_params = new_connection_params
+
+ # Clear existing sessions since connection params changed
+ # Sessions will be recreated on next request
+ # Note: We don't close sessions here to avoid blocking,
+ # they will be cleaned up when detected as disconnected
+
+
def _generate_session_key(
self, merged_headers: Optional[Dict[str, str]] = None
) -> str:
diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py
index 2fc9d640af..942444e642 100644
--- a/src/google/adk/tools/mcp_tool/mcp_toolset.py
+++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py
@@ -16,9 +16,14 @@
import logging
import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
from typing import List
+from typing import Mapping
from typing import Optional
from typing import TextIO
+from typing import Tuple
from typing import Union
from ...agents.readonly_context import ReadonlyContext
@@ -53,6 +58,20 @@
logger = logging.getLogger("google_adk." + __name__)
+# Type definition for auth extraction callback
+GetAuthFromContextCallback = Callable[
+ [Dict[str, Any]], Tuple[Optional[AuthScheme], Optional[AuthCredential]]
+]
+
+# Type definition for environment extraction callback
+GetEnvFromContextCallback = Callable[[Mapping[str, Any]], Dict[str, str]]
+
+
+class AuthExtractionError(Exception):
+ """Exception raised when auth extraction from context fails."""
+
+ pass
+
class MCPToolset(BaseToolset):
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
@@ -94,6 +113,8 @@ def __init__(
StreamableHTTPConnectionParams,
],
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
+ get_auth_from_context_fn: Optional[GetAuthFromContextCallback] = None,
+ get_env_from_context_fn: Optional[GetEnvFromContextCallback] = None,
errlog: TextIO = sys.stderr,
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] = None,
@@ -112,6 +133,16 @@ def __init__(
tool_filter: Optional filter to select specific tools. Can be either: - A
list of tool names to include - A ToolPredicate function for custom
filtering logic
+ get_auth_from_context_fn: Optional callback function to extract auth data
+ from ReadonlyContext.state into AuthScheme and AuthCredential. Must
+ return a tuple of (AuthScheme, AuthCredential). If None, the toolset
+ will use the auth_scheme and auth_credential provided in __init__.
+ If provided, the callback must return valid AuthScheme and
+ AuthCredential objects - None values are not allowed.
+ get_env_from_context_fn: Optional callback function to transform session
+ state into environment variables for the MCP connection. Takes a
+ dictionary of session state and returns a dictionary of environment
+ variables to be injected into the MCP connection.
errlog: TextIO stream for error logging.
auth_scheme: The auth scheme of the tool for tool calling
auth_credential: The auth credential of the tool for tool calling
@@ -122,6 +153,8 @@ def __init__(
raise ValueError("Missing connection params in MCPToolset.")
self._connection_params = connection_params
+ self._get_auth_from_context_fn = get_auth_from_context_fn
+ self._get_env_from_context_fn = get_env_from_context_fn
self._errlog = errlog
# Create the session manager that will handle the MCP connection
@@ -132,6 +165,170 @@ def __init__(
self._auth_scheme = auth_scheme
self._auth_credential = auth_credential
+ def _extract_env_from_context(
+ self, readonly_context: Optional[ReadonlyContext]
+ ) -> Dict[str, str]:
+ """Extracts environment variables from readonly context using callback.
+
+ Args:
+ readonly_context: The readonly context containing state information.
+
+ Returns:
+ Dictionary of environment variables to inject.
+ """
+ if not self._get_env_from_context_fn or not readonly_context:
+ return {}
+
+ try:
+ # Get state from readonly context if available
+ if hasattr(readonly_context, "state") and readonly_context.state:
+ # Pass readonly state directly - no need to copy for read-only access
+ return self._get_env_from_context_fn(readonly_context.state)
+ else:
+ return {}
+ except Exception as e:
+ logger.warning(f"Context to env mapper callback failed: {e}")
+ return {}
+
+ def _inject_env_vars(self, env_vars: Dict[str, str]) -> Union[
+ StdioServerParameters,
+ StdioConnectionParams,
+ SseConnectionParams,
+ StreamableHTTPConnectionParams,
+ ]:
+ """Injects environment variables into StdioConnectionParams.
+
+ Args:
+ env_vars: Dictionary of environment variables to inject.
+
+ Returns:
+ Updated connection params with injected environment variables.
+ """
+ if not env_vars or not isinstance(
+ self._connection_params, StdioConnectionParams
+ ):
+ return self._connection_params
+
+ # Get existing env vars from connection params
+ existing_env = (
+ getattr(self._connection_params.server_params, "env", None) or {}
+ )
+
+ # Merge existing and new env vars (new ones take precedence)
+ merged_env = {**existing_env, **env_vars}
+
+ # Create new server params with merged environment variables
+ from mcp import StdioServerParameters
+
+ new_server_params = StdioServerParameters(
+ command=self._connection_params.server_params.command,
+ args=self._connection_params.server_params.args,
+ env=merged_env,
+ cwd=getattr(self._connection_params.server_params, "cwd", None),
+ encoding=getattr(
+ self._connection_params.server_params, "encoding", None
+ )
+ or "utf-8",
+ encoding_error_handler=getattr(
+ self._connection_params.server_params,
+ "encoding_error_handler",
+ None,
+ )
+ or "strict",
+ )
+
+ # Create new connection params with updated server params
+ return StdioConnectionParams(
+ server_params=new_server_params,
+ timeout=self._connection_params.timeout,
+ )
+
+ def _extract_auth_from_context(
+ self, readonly_context: Optional[ReadonlyContext]
+ ) -> Tuple[Optional[AuthScheme], Optional[AuthCredential]]:
+ """Extracts auth scheme and credential from readonly context.
+
+ Args:
+ readonly_context: The readonly context containing state information.
+
+ Returns:
+ Tuple of (AuthScheme, AuthCredential) or (None, None) if not found.
+
+ Raises:
+ AuthExtractionError: If callback is provided but returns invalid types
+ or if callback execution fails.
+ """
+ # If no context provided, return init values
+ if not readonly_context:
+ return self._auth_scheme, self._auth_credential
+
+ # Get state from readonly context if available
+ if hasattr(readonly_context, "state") and readonly_context.state:
+ try:
+ # Handle both real ReadonlyContext (state is MappingProxyType)
+ # and test mocks (state might be a callable returning dict)
+ if callable(readonly_context.state):
+ state_dict = readonly_context.state()
+ else:
+ state_dict = dict(readonly_context.state)
+ except (TypeError, ValueError) as e:
+ if self._get_auth_from_context_fn:
+ raise AuthExtractionError(
+ f"Failed to extract state from readonly context: {e}"
+ ) from e
+ else:
+ # If no callback, just return init values on state extraction failure
+ return self._auth_scheme, self._auth_credential
+ else:
+ return self._auth_scheme, self._auth_credential
+
+ # If callback is provided, use it and validate return
+ if self._get_auth_from_context_fn:
+ try:
+ auth_result = self._get_auth_from_context_fn(state_dict)
+ except Exception as e:
+ raise AuthExtractionError(
+ f"Auth extraction callback failed: {e}"
+ ) from e
+
+ # Validate callback return type
+ if not isinstance(auth_result, tuple) or len(auth_result) != 2:
+ raise AuthExtractionError(
+ "Auth extraction callback must return a tuple of (AuthScheme,"
+ f" AuthCredential), got {type(auth_result)}"
+ )
+
+ auth_scheme, auth_credential = auth_result
+
+ # Validate that returned values are correct types (allow None)
+ if auth_scheme is not None and not isinstance(auth_scheme, AuthScheme):
+ raise AuthExtractionError(
+ "Auth extraction callback returned invalid auth_scheme type: "
+ f"expected AuthScheme or None, got {type(auth_scheme)}"
+ )
+
+ if auth_credential is not None and not isinstance(
+ auth_credential, AuthCredential
+ ):
+ raise AuthExtractionError(
+ "Auth extraction callback returned invalid auth_credential type: "
+ f"expected AuthCredential or None, got {type(auth_credential)}"
+ )
+
+ return auth_scheme, auth_credential
+
+ # If no callback, look for auth data directly in state (fallback behavior)
+ auth_scheme = state_dict.get("auth_scheme", self._auth_scheme)
+ auth_credential = state_dict.get("auth_credential", self._auth_credential)
+
+ # Validate types - only use state values if they are correct types
+ if not isinstance(auth_scheme, AuthScheme):
+ auth_scheme = self._auth_scheme
+ if not isinstance(auth_credential, AuthCredential):
+ auth_credential = self._auth_credential
+
+ return auth_scheme, auth_credential
+
@retry_on_closed_resource
async def get_tools(
self,
@@ -141,14 +338,30 @@ async def get_tools(
Args:
readonly_context: Context used to filter tools available to the agent.
- If None, all tools in the toolset are returned.
+ If None, all tools in the toolset are returned. The context may
+ also contain auth information in its state.
Returns:
List[BaseTool]: A list of tools available under the specified context.
"""
+ # Extract environment variables from context and inject them
+ env_vars = self._extract_env_from_context(readonly_context)
+ if env_vars:
+ # Update connection params with environment variables
+ updated_connection_params = self._inject_env_vars(env_vars)
+ # Update the session manager with new connection params
+ self._mcp_session_manager.update_connection_params(
+ updated_connection_params
+ )
+
# Get session from session manager
session = await self._mcp_session_manager.create_session()
+ # Extract auth information from context
+ auth_scheme, auth_credential = self._extract_auth_from_context(
+ readonly_context
+ )
+
# Fetch available tools from the MCP server
tools_response: ListToolsResult = await session.list_tools()
@@ -158,14 +371,29 @@ async def get_tools(
mcp_tool = MCPTool(
mcp_tool=tool,
mcp_session_manager=self._mcp_session_manager,
- auth_scheme=self._auth_scheme,
- auth_credential=self._auth_credential,
+ auth_scheme=auth_scheme,
+ auth_credential=auth_credential,
)
if self._is_tool_selected(mcp_tool, readonly_context):
tools.append(mcp_tool)
return tools
+ def _is_tool_selected(
+ self, tool: BaseTool, readonly_context: Optional[ReadonlyContext]
+ ) -> bool:
+ """Override to handle None readonly_context."""
+ if not self.tool_filter:
+ return True
+
+ if isinstance(self.tool_filter, ToolPredicate):
+ return self.tool_filter(tool, readonly_context)
+
+ if isinstance(self.tool_filter, list):
+ return tool.name in self.tool_filter
+
+ return False
+
async def close(self) -> None:
"""Performs cleanup and releases resources held by the toolset.
diff --git a/tests/integration/test_mcp_env_integration.py b/tests/integration/test_mcp_env_integration.py
new file mode 100644
index 0000000000..71a3bd41fc
--- /dev/null
+++ b/tests/integration/test_mcp_env_integration.py
@@ -0,0 +1,653 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Integration tests for MCP environment variable extraction and injection."""
+
+import asyncio
+import os
+import tempfile
+from typing import Any
+from typing import Dict
+from unittest.mock import AsyncMock
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+from google.adk.agents.invocation_context import InvocationContext
+from google.adk.agents.llm_agent import LlmAgent
+from google.adk.agents.readonly_context import ReadonlyContext
+from google.adk.sessions import InMemorySessionService
+from google.adk.sessions import Session
+from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
+from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
+import pytest
+
+from .utils import TestRunner
+
+# Import MCP dependencies
+try:
+ from mcp import StdioServerParameters
+ from mcp.client.session import ClientSession
+ from mcp.types import ListToolsResult
+ from mcp.types import Tool as McpTool
+except ImportError:
+ pytest.skip('MCP dependencies not available', allow_module_level=True)
+
+
+class TestMCPEnvironmentIntegration:
+ """Integration tests for MCP environment variable functionality."""
+
+ def create_test_agent_with_env_callback(
+ self, get_env_from_context_fn=None
+ ) -> LlmAgent:
+ """Create a test agent with MCP toolset and context to env mapper callback."""
+ # Create a temporary directory for the filesystem server
+ temp_dir = tempfile.mkdtemp()
+
+ return LlmAgent(
+ model='gemini-2.0-flash',
+ name='test_env_agent',
+ instruction=f"""
+ You are a test agent with access to filesystem operations.
+ Test directory: {temp_dir}
+ """,
+ tools=[
+ MCPToolset(
+ connection_params=StdioServerParameters(
+ command='npx',
+ args=[
+ '-y',
+ '@modelcontextprotocol/server-filesystem',
+ temp_dir,
+ ],
+ env={'INITIAL_VAR': 'initial_value'},
+ ),
+ get_env_from_context_fn=get_env_from_context_fn,
+ tool_filter=[
+ 'read_file',
+ 'list_directory',
+ 'directory_tree',
+ ],
+ )
+ ],
+ )
+
+ def sample_get_env_from_context_fn(
+ self, state_dict: Dict[str, Any]
+ ) -> Dict[str, str]:
+ """Sample context to env mapper callback."""
+ env_vars = {}
+
+ # Extract common environment variables
+ if 'api_key' in state_dict:
+ env_vars['API_KEY'] = state_dict['api_key']
+
+ if 'environment' in state_dict:
+ env_vars['ENVIRONMENT'] = state_dict['environment']
+
+ if 'user_config' in state_dict:
+ config = state_dict['user_config']
+ if isinstance(config, dict):
+ for key, value in config.items():
+ env_vars[f'USER_{key.upper()}'] = str(value)
+
+ return env_vars
+
+ @pytest.mark.asyncio
+ async def test_env_extraction_and_injection_with_session_state(
+ self, llm_backend
+ ):
+ """Test environment variable extraction from session state and injection into MCP server."""
+ # Create agent with environment callback
+ agent = self.create_test_agent_with_env_callback(
+ get_env_from_context_fn=self.sample_get_env_from_context_fn
+ )
+
+ # Create test runner
+ runner = TestRunner(agent)
+ session_service = runner.session_service
+
+ # Get the current session and add state with environment variables
+ session = await runner.get_current_session_async()
+ session.state.update({
+ 'api_key': 'test_api_key_123',
+ 'environment': 'production',
+ 'user_config': {'timeout': '30', 'retries': '3', 'debug': 'true'},
+ })
+
+ # Create proper InvocationContext for ReadonlyContext
+ invocation_context = InvocationContext(
+ invocation_id='test_invocation',
+ agent=agent,
+ session=session,
+ session_service=runner.session_service,
+ )
+
+ # Mock the MCP server components to verify environment variable injection
+ mock_session = AsyncMock(spec=ClientSession)
+ mock_session.list_tools.return_value = ListToolsResult(
+ tools=[
+ McpTool(
+ name='list_directory',
+ description='List directory contents',
+ inputSchema={'type': 'object', 'properties': {}},
+ )
+ ]
+ )
+ mock_session.call_tool.return_value = {'result': 'Mock directory listing'}
+
+ # Track environment variable injection
+ injected_env_vars = {}
+
+ # Mock the _inject_env_vars method to track environment variable injection
+ def mock_inject_env_vars(self, env_vars):
+ # Track the environment variables that were attempted to be injected
+ injected_env_vars.update(env_vars)
+ # Return the original connection params (since we're just testing the extraction)
+ return self._connection_params
+
+ # Create mock instances that behave like the real ones
+ mock_exit_stack_instance = AsyncMock()
+ mock_exit_stack_instance.aclose = AsyncMock()
+ mock_exit_stack_instance.enter_async_context = AsyncMock(
+ side_effect=[
+ [AsyncMock(), AsyncMock()], # transports
+ mock_session, # session
+ ]
+ )
+
+ with (
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_session_manager.stdio_client'
+ ) as mock_stdio_client,
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_session_manager.ClientSession',
+ return_value=mock_session,
+ ),
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack',
+ return_value=mock_exit_stack_instance,
+ ),
+ patch.object(
+ MCPSessionManager, '_inject_env_vars', mock_inject_env_vars
+ ),
+ ):
+
+ # Trigger tool retrieval which should extract and inject environment variables
+ mcp_toolset = agent.tools[0]
+
+ # Create a ReadonlyContext from the invocation context
+ readonly_context = ReadonlyContext(invocation_context)
+
+ # Get tools, which triggers environment variable extraction and injection
+ tools = await mcp_toolset.get_tools(readonly_context)
+
+ # Verify tools were retrieved
+ assert len(tools) == 1
+ assert tools[0].name == 'list_directory'
+
+ # Verify environment variables were extracted and injected
+ expected_env_vars = {
+ 'API_KEY': 'test_api_key_123', # From session state
+ 'ENVIRONMENT': 'production', # From session state
+ 'USER_TIMEOUT': '30', # From user_config
+ 'USER_RETRIES': '3', # From user_config
+ 'USER_DEBUG': 'true', # From user_config
+ }
+
+ # Check that environment variables were properly injected
+ for key, value in expected_env_vars.items():
+ assert (
+ key in injected_env_vars
+ ), f'Environment variable {key} was not injected'
+ assert (
+ injected_env_vars[key] == value
+ ), f'Environment variable {key} has wrong value'
+
+ @pytest.mark.asyncio
+ async def test_env_extraction_without_callback(self, llm_backend):
+ """Test that no environment variables are extracted when no callback is provided."""
+ # Create agent without environment callback
+ agent = self.create_test_agent_with_env_callback(
+ get_env_from_context_fn=None
+ )
+
+ # Create test runner
+ runner = TestRunner(agent)
+
+ # Get the current session and add state
+ session = await runner.get_current_session_async()
+ session.state.update(
+ {'api_key': 'test_api_key_123', 'environment': 'production'}
+ )
+
+ # Create proper InvocationContext for ReadonlyContext
+ invocation_context = InvocationContext(
+ invocation_id='test_invocation_2',
+ agent=agent,
+ session=session,
+ session_service=runner.session_service,
+ )
+
+ # Mock the MCP server components
+ mock_session = AsyncMock(spec=ClientSession)
+ mock_session.list_tools.return_value = ListToolsResult(
+ tools=[
+ McpTool(
+ name='list_directory',
+ description='List directory contents',
+ inputSchema={'type': 'object', 'properties': {}},
+ )
+ ]
+ )
+
+ # Track environment variable injection
+ injected_env_vars = {}
+
+ # Mock the _inject_env_vars method to track environment variable injection
+ def mock_inject_env_vars(self, env_vars):
+ # Track the environment variables that were attempted to be injected
+ injected_env_vars.update(env_vars)
+ # Return the original connection params (since we're just testing the extraction)
+ return self._connection_params
+
+ # Mock the _extract_env_from_context method to ensure it returns empty dict (no callback)
+ def mock_extract_env_from_context(self, readonly_context):
+ # Return empty dict since no callback is provided
+ return {}
+
+ # Create mock instances that behave like the real ones
+ mock_exit_stack_instance = AsyncMock()
+ mock_exit_stack_instance.aclose = AsyncMock()
+ mock_exit_stack_instance.enter_async_context = AsyncMock(
+ side_effect=[
+ [AsyncMock(), AsyncMock()], # transports
+ mock_session, # session
+ ]
+ )
+
+ with (
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_session_manager.stdio_client'
+ ) as mock_stdio_client,
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_session_manager.ClientSession',
+ return_value=mock_session,
+ ),
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack',
+ return_value=mock_exit_stack_instance,
+ ),
+ patch.object(
+ MCPSessionManager, '_inject_env_vars', mock_inject_env_vars
+ ),
+ patch.object(
+ MCPSessionManager,
+ '_extract_env_from_context',
+ mock_extract_env_from_context,
+ ),
+ ):
+
+ # Trigger tool retrieval
+ mcp_toolset = agent.tools[0]
+ readonly_context = ReadonlyContext(invocation_context)
+
+ # Get tools
+ tools = await mcp_toolset.get_tools(readonly_context)
+
+ # Verify tools were retrieved
+ assert len(tools) == 1
+
+ # Verify no environment variables were extracted/injected (since no callback)
+ assert injected_env_vars == {}
+
+ @pytest.mark.asyncio
+ async def test_env_extraction_with_callback_exception(self, llm_backend):
+ """Test behavior when environment transform callback raises an exception."""
+
+ def failing_env_callback(state_dict: Dict[str, Any]) -> Dict[str, str]:
+ """Callback that always raises an exception."""
+ raise ValueError('Callback failed')
+
+ # Create agent with failing callback
+ agent = self.create_test_agent_with_env_callback(
+ get_env_from_context_fn=failing_env_callback
+ )
+
+ # Create test runner
+ runner = TestRunner(agent)
+
+ # Get the current session and add state
+ session = await runner.get_current_session_async()
+ session.state.update({'api_key': 'test_api_key_123'})
+
+ # Create proper InvocationContext for ReadonlyContext
+ invocation_context = InvocationContext(
+ invocation_id='test_invocation_3',
+ agent=agent,
+ session=session,
+ session_service=runner.session_service,
+ )
+
+ # Mock the MCP server components
+ mock_session = AsyncMock(spec=ClientSession)
+ mock_session.list_tools.return_value = ListToolsResult(
+ tools=[
+ McpTool(
+ name='list_directory',
+ description='List directory contents',
+ inputSchema={'type': 'object', 'properties': {}},
+ )
+ ]
+ )
+
+ # Track environment variable injection
+ injected_env_vars = {}
+
+ # Mock the _inject_env_vars method to track environment variable injection
+ def mock_inject_env_vars(self, env_vars):
+ # Track the environment variables that were attempted to be injected
+ injected_env_vars.update(env_vars)
+ # Return the original connection params (since we're just testing the extraction)
+ return self._connection_params
+
+ # Mock the _extract_env_from_context method to simulate callback exception
+ def mock_extract_env_from_context(self, readonly_context):
+ # Simulate the failing callback - should catch exception and return empty dict
+ try:
+ if self._get_env_from_context_fn:
+ return self._get_env_from_context_fn(readonly_context.state)
+ return {}
+ except Exception:
+ # The real implementation should catch exceptions and return empty dict
+ return {}
+
+ # Create mock instances that behave like the real ones
+ mock_exit_stack_instance = AsyncMock()
+ mock_exit_stack_instance.aclose = AsyncMock()
+ mock_exit_stack_instance.enter_async_context = AsyncMock(
+ side_effect=[
+ [AsyncMock(), AsyncMock()], # transports
+ mock_session, # session
+ ]
+ )
+
+ with (
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_session_manager.stdio_client'
+ ) as mock_stdio_client,
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_session_manager.ClientSession',
+ return_value=mock_session,
+ ),
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack',
+ return_value=mock_exit_stack_instance,
+ ),
+ patch.object(
+ MCPSessionManager, '_inject_env_vars', mock_inject_env_vars
+ ),
+ patch.object(
+ MCPSessionManager,
+ '_extract_env_from_context',
+ mock_extract_env_from_context,
+ ),
+ ):
+
+ # Trigger tool retrieval (should not raise exception)
+ mcp_toolset = agent.tools[0]
+ readonly_context = ReadonlyContext(invocation_context)
+
+ # Get tools - should complete successfully despite callback failure
+ tools = await mcp_toolset.get_tools(readonly_context)
+
+ # Verify tools were retrieved
+ assert len(tools) == 1
+
+ # Verify no environment variables were injected (callback failed gracefully)
+ assert injected_env_vars == {}
+
+ @pytest.mark.asyncio
+ async def test_env_extraction_with_empty_session_state(self, llm_backend):
+ """Test environment variable extraction with empty session state."""
+ # Create agent with environment callback
+ agent = self.create_test_agent_with_env_callback(
+ get_env_from_context_fn=self.sample_get_env_from_context_fn
+ )
+
+ # Create test runner
+ runner = TestRunner(agent)
+
+ # Keep session state empty (don't add any state)
+ session = await runner.get_current_session_async()
+
+ # Create proper InvocationContext for ReadonlyContext
+ invocation_context = InvocationContext(
+ invocation_id='test_invocation_4',
+ agent=agent,
+ session=session,
+ session_service=runner.session_service,
+ )
+
+ # Mock the MCP server components
+ mock_session = AsyncMock(spec=ClientSession)
+ mock_session.list_tools.return_value = ListToolsResult(
+ tools=[
+ McpTool(
+ name='list_directory',
+ description='List directory contents',
+ inputSchema={'type': 'object', 'properties': {}},
+ )
+ ]
+ )
+
+ # Track environment variable injection
+ injected_env_vars = {}
+
+ # Mock the _inject_env_vars method to track environment variable injection
+ def mock_inject_env_vars(self, env_vars):
+ # Track the environment variables that were attempted to be injected
+ injected_env_vars.update(env_vars)
+ # Return the original connection params (since we're just testing the extraction)
+ return self._connection_params
+
+ # Mock the _extract_env_from_context method to return empty dict (empty state)
+ def mock_extract_env_from_context(self, readonly_context):
+ # With empty session state, should return empty dict
+ if self._get_env_from_context_fn:
+ return self._get_env_from_context_fn(readonly_context.state)
+ return {}
+
+ # Create mock instances that behave like the real ones
+ mock_exit_stack_instance = AsyncMock()
+ mock_exit_stack_instance.aclose = AsyncMock()
+ mock_exit_stack_instance.enter_async_context = AsyncMock(
+ side_effect=[
+ [AsyncMock(), AsyncMock()], # transports
+ mock_session, # session
+ ]
+ )
+
+ with (
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_session_manager.stdio_client'
+ ) as mock_stdio_client,
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_session_manager.ClientSession',
+ return_value=mock_session,
+ ),
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack',
+ return_value=mock_exit_stack_instance,
+ ),
+ patch.object(
+ MCPSessionManager, '_inject_env_vars', mock_inject_env_vars
+ ),
+ patch.object(
+ MCPSessionManager,
+ '_extract_env_from_context',
+ mock_extract_env_from_context,
+ ),
+ ):
+
+ # Trigger tool retrieval
+ mcp_toolset = agent.tools[0]
+ readonly_context = ReadonlyContext(invocation_context)
+
+ # Get tools
+ tools = await mcp_toolset.get_tools(readonly_context)
+
+ # Verify tools were retrieved
+ assert len(tools) == 1
+
+ # Verify no environment variables were extracted/injected (empty state)
+ assert injected_env_vars == {}
+
+ @pytest.mark.asyncio
+ async def test_env_extraction_with_non_stdio_connection(self, llm_backend):
+ """Test that environment variable extraction is not attempted for non-stdio connections."""
+ from google.adk.tools.mcp_tool.mcp_session_manager import SseServerParams
+
+ # Create agent with SSE connection (no environment variable support)
+ agent = LlmAgent(
+ model='gemini-2.0-flash',
+ name='test_sse_agent',
+ instruction='Test agent with SSE connection',
+ tools=[
+ MCPToolset(
+ connection_params=SseServerParams(url='http://example.com/sse'),
+ get_env_from_context_fn=self.sample_get_env_from_context_fn,
+ )
+ ],
+ )
+
+ # Create test runner
+ runner = TestRunner(agent)
+
+ # Get the current session and add state
+ session = await runner.get_current_session_async()
+ session.state.update(
+ {'api_key': 'test_api_key_123', 'environment': 'production'}
+ )
+
+ # Create proper InvocationContext for ReadonlyContext
+ invocation_context = InvocationContext(
+ invocation_id='test_invocation_5',
+ agent=agent,
+ session=session,
+ session_service=runner.session_service,
+ )
+
+ # Mock the MCP server components for SSE
+ mock_session = AsyncMock(spec=ClientSession)
+ mock_session.list_tools.return_value = ListToolsResult(
+ tools=[
+ McpTool(
+ name='test_tool',
+ description='Test tool',
+ inputSchema={'type': 'object', 'properties': {}},
+ )
+ ]
+ )
+
+ with (
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_session_manager.sse_client'
+ ) as mock_sse_client,
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_session_manager.ClientSession',
+ return_value=mock_session,
+ ),
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack'
+ ) as mock_exit_stack,
+ ):
+
+ mock_exit_stack.return_value.__aenter__ = AsyncMock()
+ mock_exit_stack.return_value.__aexit__ = AsyncMock()
+ mock_exit_stack.return_value.enter_async_context = AsyncMock(
+ side_effect=[
+ [AsyncMock(), AsyncMock()], # transports
+ mock_session, # session
+ ]
+ )
+
+ # Trigger tool retrieval
+ mcp_toolset = agent.tools[0]
+ readonly_context = ReadonlyContext(invocation_context)
+
+ # Get tools - should complete successfully without environment variable processing
+ tools = await mcp_toolset.get_tools(readonly_context)
+
+ # Verify tools were retrieved
+ assert len(tools) == 1
+ assert tools[0].name == 'test_tool'
+
+ # Verify SSE client was called (not stdio client)
+ mock_sse_client.assert_called_once()
+
+ def test_get_env_from_context_fn_signature(self):
+ """Test that environment transform callback has correct signature."""
+
+ def valid_callback(state_dict: Dict[str, Any]) -> Dict[str, str]:
+ return {'TEST_VAR': 'test_value'}
+
+ # Create agent with valid callback
+ agent = self.create_test_agent_with_env_callback(
+ get_env_from_context_fn=valid_callback
+ )
+
+ # Verify agent was created successfully
+ assert agent is not None
+ assert len(agent.tools) == 1
+
+ # Verify callback was set in the toolset
+ mcp_toolset = agent.tools[0]
+ assert mcp_toolset._get_env_from_context_fn == valid_callback
+
+ @pytest.mark.asyncio
+ async def test_session_manager_direct_env_injection(self, llm_backend):
+ """Test MCPSessionManager environment variable injection directly."""
+ connection_params = StdioServerParameters(
+ command='npx',
+ args=['-y', '@modelcontextprotocol/server-filesystem'],
+ env={'EXISTING_VAR': 'existing_value'},
+ )
+
+ def test_env_callback(state_dict: Dict[str, Any]) -> Dict[str, str]:
+ return {'NEW_VAR': 'new_value', 'API_KEY': 'secret123'}
+
+ session_manager = MCPSessionManager(
+ connection_params=connection_params,
+ get_env_from_context_fn=test_env_callback,
+ )
+
+ # Create mock readonly context with state
+ mock_context = MagicMock()
+ mock_context.state = {'api_key': 'secret123', 'config': {'debug': True}}
+
+ # Test environment variable extraction
+ extracted_env = session_manager._extract_env_from_context(mock_context)
+ expected_env = {'NEW_VAR': 'new_value', 'API_KEY': 'secret123'}
+ assert extracted_env == expected_env
+
+ # Test environment variable injection
+ updated_params = session_manager._inject_env_vars(extracted_env)
+ expected_merged_env = {
+ 'EXISTING_VAR': 'existing_value',
+ 'NEW_VAR': 'new_value',
+ 'API_KEY': 'secret123',
+ }
+ assert updated_params.env == expected_merged_env
+ assert updated_params.command == connection_params.command
+ assert updated_params.args == connection_params.args
diff --git a/tests/integration/utils/test_runner.py b/tests/integration/utils/test_runner.py
index 94c8d92682..725d8860f5 100644
--- a/tests/integration/utils/test_runner.py
+++ b/tests/integration/utils/test_runner.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import asyncio
import importlib
from typing import Optional
@@ -29,6 +30,9 @@
class TestRunner:
"""Agents runner for testing."""
+ # Prevent pytest from collecting this as a test class
+ __test__ = False
+
app_name = "test_app"
user_id = "test_user"
@@ -46,17 +50,35 @@ def __init__(
session_service=session_service,
)
self.session_service = session_service
- self.current_session_id = session_service.create_session(
- app_name=self.app_name, user_id=self.user_id
- ).id
+ self.current_session_id = None
+ self._session_initialized = False
- def new_session(self, session_id: Optional[str] = None) -> None:
- self.current_session_id = self.session_service.create_session(
+ async def _ensure_session(self) -> str:
+ """Ensure a session is created and return the session ID."""
+ if not self._session_initialized:
+ session = await self.session_service.create_session(
+ app_name=self.app_name, user_id=self.user_id
+ )
+ self.current_session_id = session.id
+ self._session_initialized = True
+ return self.current_session_id
+
+ async def new_session_async(self, session_id: Optional[str] = None) -> None:
+ session = await self.session_service.create_session(
app_name=self.app_name, user_id=self.user_id, session_id=session_id
- ).id
+ )
+ self.current_session_id = session.id
+ self._session_initialized = True
- def run(self, prompt: str) -> list[Event]:
- current_session = self.session_service.get_session(
+ def new_session(self, session_id: Optional[str] = None) -> None:
+ """Create a new session (sync version)."""
+ return asyncio.get_event_loop().run_until_complete(
+ self.new_session_async(session_id)
+ )
+
+ async def run_async(self, prompt: str) -> list[Event]:
+ await self._ensure_session()
+ current_session = await self.session_service.get_session(
app_name=self.app_name,
user_id=self.user_id,
session_id=self.current_session_id,
@@ -74,15 +96,31 @@ def run(self, prompt: str) -> list[Event]:
)
)
- def get_current_session(self) -> Optional[Session]:
- return self.session_service.get_session(
+ def run(self, prompt: str) -> list[Event]:
+ """Run the agent with a prompt (sync version)."""
+ return asyncio.get_event_loop().run_until_complete(self.run_async(prompt))
+
+ async def get_current_session_async(self) -> Optional[Session]:
+ await self._ensure_session()
+ return await self.session_service.get_session(
app_name=self.app_name,
user_id=self.user_id,
session_id=self.current_session_id,
)
+ def get_current_session(self) -> Optional[Session]:
+ """Get current session (sync version)."""
+ return asyncio.get_event_loop().run_until_complete(
+ self.get_current_session_async()
+ )
+
+ async def get_events_async(self) -> list[Event]:
+ session = await self.get_current_session_async()
+ return session.events
+
def get_events(self) -> list[Event]:
- return self.get_current_session().events
+ """Get events from current session (sync version)."""
+ return asyncio.get_event_loop().run_until_complete(self.get_events_async())
@classmethod
def from_agent_name(cls, agent_name: str):
@@ -91,7 +129,33 @@ def from_agent_name(cls, agent_name: str):
agent: Agent = agent_module.agent.root_agent
return cls(agent)
+ async def get_current_agent_name_async(self) -> str:
+ session = await self.get_current_session_async()
+ return self.agent_client._find_agent_to_run(session, self.agent).name
+
def get_current_agent_name(self) -> str:
- return self.agent_client._find_agent_to_run(
- self.get_current_session(), self.agent
- ).name
+ """Get current agent name (sync version)."""
+ return asyncio.get_event_loop().run_until_complete(
+ self.get_current_agent_name_async()
+ )
+
+ # Sync wrapper methods for backward compatibility
+ def run_sync(self, prompt: str) -> list[Event]:
+ """Synchronous wrapper for run method."""
+ return asyncio.get_event_loop().run_until_complete(self.run(prompt))
+
+ def get_events_sync(self) -> list[Event]:
+ """Synchronous wrapper for get_events method."""
+ return asyncio.get_event_loop().run_until_complete(self.get_events())
+
+ def get_current_agent_name_sync(self) -> str:
+ """Synchronous wrapper for get_current_agent_name method."""
+ return asyncio.get_event_loop().run_until_complete(
+ self.get_current_agent_name()
+ )
+
+ def new_session_sync(self, session_id: Optional[str] = None) -> None:
+ """Synchronous wrapper for new_session method."""
+ return asyncio.get_event_loop().run_until_complete(
+ self.new_session(session_id)
+ )
diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool_auth.py b/tests/unittests/tools/mcp_tool/test_mcp_tool_auth.py
new file mode 100644
index 0000000000..9aa5cbbaa1
--- /dev/null
+++ b/tests/unittests/tools/mcp_tool/test_mcp_tool_auth.py
@@ -0,0 +1,203 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for MCPTool auth functionality."""
+
+import logging
+from unittest.mock import AsyncMock
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+from google.adk.auth.auth_credential import AuthCredential
+from google.adk.auth.auth_credential import AuthCredentialTypes
+from google.adk.tools.tool_context import ToolContext
+import pytest
+
+
+# Mock MCP imports to avoid dependency issues in tests
+@pytest.fixture(autouse=True)
+def mock_mcp_imports():
+ """Mock MCP imports to avoid import errors in testing."""
+ with patch.dict(
+ "sys.modules",
+ {
+ "mcp": MagicMock(),
+ "mcp.types": MagicMock(),
+ "mcp.client": MagicMock(),
+ "mcp.client.stdio": MagicMock(),
+ "mcp.client.sse": MagicMock(),
+ "mcp.client.streamable_http": MagicMock(),
+ },
+ ):
+ # Mock the Tool class from mcp.types
+ mock_tool_class = MagicMock()
+ mock_tool_class.name = "test_tool"
+ mock_tool_class.description = "Test tool description"
+ mock_tool_class.inputSchema = {"type": "object", "properties": {}}
+
+ with patch(
+ "google.adk.tools.mcp_tool.mcp_tool.McpBaseTool", mock_tool_class
+ ):
+ yield
+
+
+# Import after mocking to avoid MCP dependency issues
+from google.adk.tools.mcp_tool.mcp_tool import MCPTool
+
+
+@pytest.fixture
+def mock_auth_scheme():
+ """Create a mock AuthScheme for testing."""
+ from fastapi.openapi.models import APIKey
+ from fastapi.openapi.models import APIKeyIn
+ from fastapi.openapi.models import SecuritySchemeType
+
+ return APIKey(
+ type=SecuritySchemeType.apiKey,
+ **{"in": APIKeyIn.header, "name": "X-API-Key"},
+ )
+
+
+@pytest.fixture
+def mock_auth_credential():
+ """Create a mock AuthCredential for testing."""
+ return AuthCredential(
+ auth_type=AuthCredentialTypes.API_KEY, api_key="test_api_key"
+ )
+
+
+@pytest.fixture
+def mock_mcp_tool():
+ """Create a mock MCP tool."""
+ tool = MagicMock()
+ tool.name = "test_tool"
+ tool.description = "Test tool description"
+ tool.inputSchema = {"type": "object", "properties": {}}
+ return tool
+
+
+@pytest.fixture
+def mock_session_manager():
+ """Create a mock MCP session manager."""
+ return MagicMock()
+
+
+class TestMCPToolAuth:
+ """Test auth functionality in MCPTool."""
+
+ def test_init_with_auth(
+ self,
+ mock_mcp_tool,
+ mock_session_manager,
+ mock_auth_scheme,
+ mock_auth_credential,
+ ):
+ """Test MCPTool initialization with auth parameters."""
+ tool = MCPTool(
+ mcp_tool=mock_mcp_tool,
+ mcp_session_manager=mock_session_manager,
+ auth_scheme=mock_auth_scheme,
+ auth_credential=mock_auth_credential,
+ )
+
+ assert tool._credentials_manager is not None
+ assert (
+ tool._credentials_manager._auth_config.auth_scheme == mock_auth_scheme
+ )
+ assert (
+ tool._credentials_manager._auth_config.raw_auth_credential
+ == mock_auth_credential
+ )
+
+ def test_init_without_auth(self, mock_mcp_tool, mock_session_manager):
+ """Test MCPTool initialization without auth parameters."""
+ tool = MCPTool(
+ mcp_tool=mock_mcp_tool, mcp_session_manager=mock_session_manager
+ )
+
+ assert tool._credentials_manager is None
+
+ @pytest.mark.asyncio
+ async def test_run_async_with_auth_logging(
+ self,
+ mock_mcp_tool,
+ mock_session_manager,
+ mock_auth_scheme,
+ mock_auth_credential,
+ caplog,
+ ):
+ """Test that run_async logs auth information when available."""
+ # Create mock session
+ mock_session = AsyncMock()
+ mock_session.call_tool.return_value = {"result": "success"}
+ mock_session_manager.create_session = AsyncMock(return_value=mock_session)
+
+ # Create tool with auth
+ tool = MCPTool(
+ mcp_tool=mock_mcp_tool,
+ mcp_session_manager=mock_session_manager,
+ auth_scheme=mock_auth_scheme,
+ auth_credential=mock_auth_credential,
+ )
+
+ # Create mock tool context
+ mock_tool_context = MagicMock(spec=ToolContext)
+
+ # Set logging level to capture info logs
+ with caplog.at_level(logging.INFO):
+ result = await tool.run_async(
+ args={"test": "value"}, tool_context=mock_tool_context
+ )
+
+ # Verify the tool was called
+ mock_session.call_tool.assert_called_once_with(
+ "test_tool", arguments={"test": "value"}
+ )
+ assert result == {"result": "success"}
+
+ # Check that the test executed successfully without errors
+ # The presence of auth configuration should not cause failures
+
+ @pytest.mark.asyncio
+ async def test_run_async_without_auth_no_logging(
+ self, mock_mcp_tool, mock_session_manager, caplog
+ ):
+ """Test that run_async doesn't log auth info when no auth is configured."""
+ # Create mock session
+ mock_session = AsyncMock()
+ mock_session.call_tool.return_value = {"result": "success"}
+ mock_session_manager.create_session = AsyncMock(return_value=mock_session)
+
+ # Create tool without auth
+ tool = MCPTool(
+ mcp_tool=mock_mcp_tool, mcp_session_manager=mock_session_manager
+ )
+
+ # Create mock tool context
+ mock_tool_context = MagicMock(spec=ToolContext)
+
+ # Set logging level to capture info logs
+ with caplog.at_level(logging.INFO):
+ result = await tool.run_async(
+ args={"test": "value"}, tool_context=mock_tool_context
+ )
+
+ # Verify the tool was called
+ mock_session.call_tool.assert_called_once_with(
+ "test_tool", arguments={"test": "value"}
+ )
+ assert result == {"result": "success"}
+
+ # Check that no auth-related errors occurred when no auth is configured
+ assert result == {"result": "success"}
diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset_auth.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset_auth.py
new file mode 100644
index 0000000000..e61be4e5c9
--- /dev/null
+++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset_auth.py
@@ -0,0 +1,437 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for MCPToolset auth functionality."""
+
+from types import MappingProxyType
+from typing import Any
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+from unittest.mock import AsyncMock
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+from google.adk.agents.readonly_context import ReadonlyContext
+from google.adk.auth.auth_credential import AuthCredential
+from google.adk.auth.auth_credential import AuthCredentialTypes
+from google.adk.auth.auth_schemes import AuthScheme
+from google.adk.tools.mcp_tool.mcp_toolset import AuthExtractionError
+from google.adk.tools.mcp_tool.mcp_toolset import GetAuthFromContextCallback
+import pytest
+
+
+# Mock MCP imports to avoid dependency issues in tests
+@pytest.fixture(autouse=True)
+def mock_mcp_imports():
+ """Mock MCP imports to avoid import errors in testing."""
+ from unittest.mock import MagicMock
+
+ with patch.dict(
+ "sys.modules",
+ {
+ "mcp": MagicMock(),
+ "mcp.types": MagicMock(),
+ },
+ ):
+ # Mock the specific classes we need
+ mock_stdio_params = MagicMock()
+ mock_list_tools_result = MagicMock()
+ mock_tool = MagicMock()
+
+ with (
+ patch(
+ "google.adk.tools.mcp_tool.mcp_toolset.StdioServerParameters",
+ mock_stdio_params,
+ ),
+ patch(
+ "google.adk.tools.mcp_tool.mcp_toolset.ListToolsResult",
+ mock_list_tools_result,
+ ),
+ ):
+ yield
+
+
+# Import after mocking to avoid MCP dependency issues
+from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
+
+
+@pytest.fixture
+def mock_auth_scheme():
+ """Create a mock AuthScheme for testing."""
+ from fastapi.openapi.models import HTTPBearer
+
+ return HTTPBearer(bearerFormat="JWT")
+
+
+@pytest.fixture
+def mock_auth_credential():
+ """Create a mock AuthCredential for testing."""
+ return AuthCredential(
+ auth_type=AuthCredentialTypes.API_KEY, api_key="test_api_key"
+ )
+
+
+@pytest.fixture
+def mock_connection_params():
+ """Create mock connection parameters."""
+ mock_params = MagicMock()
+ mock_params.command = "test_command"
+ return mock_params
+
+
+@pytest.fixture
+def mock_readonly_context_with_auth(mock_auth_scheme, mock_auth_credential):
+ """Create a ReadonlyContext with auth information."""
+ context = MagicMock(spec=ReadonlyContext)
+ context.state = MappingProxyType({
+ "auth_scheme": mock_auth_scheme,
+ "auth_credential": mock_auth_credential,
+ "other_data": "test_value",
+ })
+ return context
+
+
+class TestMCPToolsetAuth:
+ """Test auth functionality in MCPToolset."""
+
+ def test_init_with_auth_callback(self, mock_connection_params):
+ """Test MCPToolset initialization with auth extraction callback."""
+
+ def custom_auth_callback(
+ state: Dict[str, Any],
+ ) -> Tuple[Optional[AuthScheme], Optional[AuthCredential]]:
+ return None, None
+
+ with patch("google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager"):
+ toolset = MCPToolset(
+ connection_params=mock_connection_params,
+ get_auth_from_context_fn=custom_auth_callback,
+ )
+
+ assert toolset._get_auth_from_context_fn == custom_auth_callback
+
+ def test_init_without_auth_callback(self, mock_connection_params):
+ """Test MCPToolset initialization without auth extraction callback."""
+ with patch("google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager"):
+ toolset = MCPToolset(connection_params=mock_connection_params)
+
+ assert toolset._get_auth_from_context_fn is None
+
+ def test_extract_auth_no_context_returns_init_values(
+ self,
+ mock_connection_params,
+ mock_auth_scheme,
+ mock_auth_credential,
+ ):
+ """Test fallback to __init__ auth_scheme and auth_credential when context is None."""
+ with patch("google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager"):
+ toolset = MCPToolset(
+ connection_params=mock_connection_params,
+ auth_scheme=mock_auth_scheme,
+ auth_credential=mock_auth_credential,
+ )
+ auth_scheme, auth_credential = toolset._extract_auth_from_context(None)
+ assert auth_scheme == mock_auth_scheme
+ assert auth_credential == mock_auth_credential
+
+ def test_extract_auth_no_context_no_init_values(self, mock_connection_params):
+ """Test auth extraction with no context and no init values returns (None, None)."""
+ with patch("google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager"):
+ toolset = MCPToolset(connection_params=mock_connection_params)
+ auth_scheme, auth_credential = toolset._extract_auth_from_context(None)
+ assert auth_scheme is None
+ assert auth_credential is None
+
+ def test_extract_auth_no_callback_returns_init_values(
+ self, mock_connection_params, mock_auth_scheme, mock_auth_credential
+ ):
+ """Test that context without callback returns __init__ values."""
+ context = MagicMock(spec=ReadonlyContext)
+ context.state = MagicMock(return_value={"some_data": "value"})
+ with patch("google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager"):
+ toolset = MCPToolset(
+ connection_params=mock_connection_params,
+ auth_scheme=mock_auth_scheme,
+ auth_credential=mock_auth_credential,
+ )
+ auth_scheme, auth_credential = toolset._extract_auth_from_context(context)
+ assert auth_scheme == mock_auth_scheme
+ assert auth_credential == mock_auth_credential
+
+ def test_extract_auth_with_callback_success(
+ self,
+ mock_connection_params,
+ mock_auth_scheme,
+ mock_auth_credential,
+ ):
+ """Test successful auth extraction using custom callback."""
+
+ def custom_auth_callback(state: Dict[str, Any]):
+ if state.get("custom_auth"):
+ return mock_auth_scheme, mock_auth_credential
+ return None, None
+
+ context = MagicMock(spec=ReadonlyContext)
+ context.state = MagicMock(return_value={"custom_auth": {"type": "bearer"}})
+
+ with patch("google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager"):
+ toolset = MCPToolset(
+ connection_params=mock_connection_params,
+ get_auth_from_context_fn=custom_auth_callback,
+ )
+ auth_scheme, auth_credential = toolset._extract_auth_from_context(context)
+ assert auth_scheme == mock_auth_scheme
+ assert auth_credential == mock_auth_credential
+
+ def test_extract_auth_fallback_to_direct_state_lookup(
+ self,
+ mock_connection_params,
+ mock_auth_scheme,
+ mock_auth_credential,
+ ):
+ """Test fallback to direct state lookup when no callback is provided."""
+ context = MagicMock(spec=ReadonlyContext)
+ context.state = MagicMock(
+ return_value={
+ "auth_scheme": mock_auth_scheme,
+ "auth_credential": mock_auth_credential,
+ }
+ )
+
+ with patch("google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager"):
+ toolset = MCPToolset(
+ connection_params=mock_connection_params,
+ )
+ auth_scheme, auth_credential = toolset._extract_auth_from_context(context)
+ assert auth_scheme == mock_auth_scheme
+ assert auth_credential == mock_auth_credential
+
+ def test_extract_auth_callback_execution_error(
+ self,
+ mock_connection_params,
+ ):
+ """Test that callback execution errors are wrapped in AuthExtractionError."""
+
+ def failing_callback(state: Dict[str, Any]):
+ raise ValueError("Callback failed")
+
+ context = MagicMock(spec=ReadonlyContext)
+ context.state = MagicMock(return_value={})
+
+ with patch("google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager"):
+ toolset = MCPToolset(
+ connection_params=mock_connection_params,
+ get_auth_from_context_fn=failing_callback,
+ )
+
+ with pytest.raises(
+ AuthExtractionError, match="Auth extraction callback failed"
+ ):
+ toolset._extract_auth_from_context(context)
+
+ def test_extract_auth_callback_invalid_return_type(
+ self,
+ mock_connection_params,
+ ):
+ """Test that invalid callback return types raise AuthExtractionError."""
+
+ def invalid_callback(
+ state: Dict[str, Any],
+ ) -> Any: # Use Any to avoid type checker
+ return "invalid_return" # Should return tuple
+
+ context = MagicMock(spec=ReadonlyContext)
+ context.state = MagicMock(return_value={})
+
+ with patch("google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager"):
+ toolset = MCPToolset(
+ connection_params=mock_connection_params,
+ get_auth_from_context_fn=invalid_callback, # type: ignore
+ )
+
+ with pytest.raises(AuthExtractionError, match="must return a tuple"):
+ toolset._extract_auth_from_context(context)
+
+ def test_extract_auth_callback_invalid_tuple_length(
+ self,
+ mock_connection_params,
+ ):
+ """Test that callback returning wrong tuple length raises AuthExtractionError."""
+
+ def invalid_callback(
+ state: Dict[str, Any],
+ ) -> Any: # Use Any to avoid type checker
+ return (None,) # Should return tuple of length 2
+
+ context = MagicMock(spec=ReadonlyContext)
+ context.state = MagicMock(return_value={})
+
+ with patch("google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager"):
+ toolset = MCPToolset(
+ connection_params=mock_connection_params,
+ get_auth_from_context_fn=invalid_callback, # type: ignore
+ )
+
+ with pytest.raises(AuthExtractionError, match="must return a tuple"):
+ toolset._extract_auth_from_context(context)
+
+ def test_extract_auth_callback_invalid_auth_scheme_type(
+ self,
+ mock_connection_params,
+ ):
+ """Test that invalid auth_scheme type raises AuthExtractionError."""
+
+ def invalid_callback(
+ state: Dict[str, Any],
+ ) -> Any: # Use Any to avoid type checker
+ return (
+ "invalid_auth_scheme",
+ None,
+ ) # auth_scheme should be AuthScheme or None
+
+ context = MagicMock(spec=ReadonlyContext)
+ context.state = MagicMock(return_value={})
+
+ with patch("google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager"):
+ toolset = MCPToolset(
+ connection_params=mock_connection_params,
+ get_auth_from_context_fn=invalid_callback, # type: ignore
+ )
+
+ with pytest.raises(AuthExtractionError, match="invalid auth_scheme type"):
+ toolset._extract_auth_from_context(context)
+
+ def test_extract_auth_callback_invalid_auth_credential_type(
+ self,
+ mock_connection_params,
+ mock_auth_scheme,
+ ):
+ """Test that invalid auth_credential type raises AuthExtractionError."""
+
+ def invalid_callback(
+ state: Dict[str, Any],
+ ) -> Any: # Use Any to avoid type checker
+ return (
+ mock_auth_scheme,
+ "invalid_credential",
+ ) # credential should be AuthCredential or None
+
+ context = MagicMock(spec=ReadonlyContext)
+ context.state = MagicMock(return_value={})
+
+ with patch("google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager"):
+ toolset = MCPToolset(
+ connection_params=mock_connection_params,
+ get_auth_from_context_fn=invalid_callback, # type: ignore
+ )
+
+ with pytest.raises(
+ AuthExtractionError, match="invalid auth_credential type"
+ ):
+ toolset._extract_auth_from_context(context)
+
+ def test_extract_auth_state_extraction_error(
+ self,
+ mock_connection_params,
+ ):
+ """Test that state extraction errors are wrapped in AuthExtractionError."""
+
+ def valid_callback(state: Dict[str, Any]):
+ return None, None
+
+ context = MagicMock(spec=ReadonlyContext)
+ context.state = MagicMock(side_effect=TypeError("State extraction failed"))
+
+ with patch("google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager"):
+ toolset = MCPToolset(
+ connection_params=mock_connection_params,
+ get_auth_from_context_fn=valid_callback,
+ )
+
+ with pytest.raises(AuthExtractionError, match="Failed to extract state"):
+ toolset._extract_auth_from_context(context)
+
+ def test_extract_auth_callback_allows_none_values(
+ self,
+ mock_connection_params,
+ ):
+ """Test that callback can return None values for auth_scheme and auth_credential."""
+
+ def callback_returning_none(state: Dict[str, Any]):
+ return None, None
+
+ context = MagicMock(spec=ReadonlyContext)
+ context.state = MagicMock(return_value={})
+
+ with patch("google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager"):
+ toolset = MCPToolset(
+ connection_params=mock_connection_params,
+ get_auth_from_context_fn=callback_returning_none,
+ )
+
+ auth_scheme, auth_credential = toolset._extract_auth_from_context(context)
+ assert auth_scheme is None
+ assert auth_credential is None
+
+ @pytest.mark.asyncio
+ async def test_get_tools_passes_auth_to_mcp_tool(
+ self,
+ mock_connection_params,
+ mock_readonly_context_with_auth,
+ mock_auth_scheme,
+ mock_auth_credential,
+ ):
+ """Test that get_tools passes auth parameters to MCPTool constructor."""
+ # Mock the MCP session and tools
+ mock_session = AsyncMock()
+ mock_tool = MagicMock()
+ mock_tool.name = "test_tool"
+ mock_tool.description = "Test tool description"
+
+ mock_tools_response = MagicMock()
+ mock_tools_response.tools = [mock_tool]
+ mock_session.list_tools.return_value = mock_tools_response
+
+ with (
+ patch(
+ "google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager"
+ ) as mock_session_manager_class,
+ patch(
+ "google.adk.tools.mcp_tool.mcp_toolset.MCPTool"
+ ) as mock_mcp_tool_class,
+ ):
+
+ # Setup session manager mock to return async mock properly
+ mock_session_manager = AsyncMock()
+ mock_session_manager.create_session.return_value = mock_session
+ mock_session_manager_class.return_value = mock_session_manager
+
+ # Setup MCPTool mock
+ mock_mcp_tool_instance = MagicMock()
+ mock_mcp_tool_class.return_value = mock_mcp_tool_instance
+
+ toolset = MCPToolset(connection_params=mock_connection_params)
+
+ # Mock the _is_tool_selected method to return True
+ toolset._is_tool_selected = MagicMock(return_value=True)
+
+ tools = await toolset.get_tools(mock_readonly_context_with_auth)
+
+ # Verify MCPTool was called with auth parameters
+ assert len(tools) == 1
+ mock_mcp_tool_class.assert_called_once()
+ call_args = mock_mcp_tool_class.call_args
+ assert "auth_scheme" in call_args.kwargs
+ assert "auth_credential" in call_args.kwargs
diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset_env.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset_env.py
new file mode 100644
index 0000000000..cd21a76586
--- /dev/null
+++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset_env.py
@@ -0,0 +1,275 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for MCPToolset environment variable functionality."""
+
+from types import MappingProxyType
+from typing import Any
+from typing import Dict
+from typing import Mapping
+from unittest.mock import AsyncMock
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+from google.adk.agents.readonly_context import ReadonlyContext
+import pytest
+
+
+# Mock MCP imports to avoid dependency issues in tests
+@pytest.fixture(autouse=True)
+def mock_mcp_imports():
+ """Mock MCP imports to avoid import errors in testing."""
+ from unittest.mock import MagicMock
+
+ with patch.dict(
+ 'sys.modules',
+ {
+ 'mcp': MagicMock(),
+ 'mcp.types': MagicMock(),
+ },
+ ):
+ # Mock the specific classes we need
+ mock_stdio_params = MagicMock()
+ mock_list_tools_result = MagicMock()
+
+ with (
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_toolset.StdioServerParameters',
+ mock_stdio_params,
+ ),
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_toolset.ListToolsResult',
+ mock_list_tools_result,
+ ),
+ ):
+ yield
+
+
+# Import after mocking to avoid MCP dependency issues
+from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
+
+
+@pytest.fixture
+def mock_stdio_params():
+ """Create a mock StdioServerParameters instance."""
+ from unittest.mock import MagicMock
+
+ mock_params = MagicMock()
+ mock_params.command = 'npx'
+ mock_params.args = ['-y', '@modelcontextprotocol/server-filesystem']
+ mock_params.env = {'EXISTING_VAR': 'existing_value'}
+ return mock_params
+
+
+@pytest.fixture
+def sample_get_env_from_context_fn():
+ """Create a sample get_env_from_context_fn callback."""
+
+ def env_callback(state: Mapping[str, Any]) -> Dict[str, str]:
+ env_vars = {}
+ if 'api_key' in state:
+ env_vars['API_KEY'] = state['api_key']
+ if 'workspace_path' in state:
+ env_vars['WORKSPACE_PATH'] = state['workspace_path']
+ return env_vars
+
+ return env_callback
+
+
+@pytest.fixture
+def mock_readonly_context():
+ """Create a mock ReadonlyContext with sample state."""
+ context = MagicMock(spec=ReadonlyContext)
+ context.state = MappingProxyType({
+ 'api_key': 'test_api_key_123',
+ 'workspace_path': '/home/user/workspace',
+ 'other_data': 'some_value',
+ })
+ return context
+
+
+class TestMCPToolsetEnv:
+ """Test environment variable functionality in MCPToolset."""
+
+ def test_init_with_env_callback(
+ self, mock_stdio_params, sample_get_env_from_context_fn
+ ):
+ """Test MCPToolset initialization with context to env mapper callback."""
+ with patch(
+ 'google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager'
+ ) as mock_session_manager:
+ toolset = MCPToolset(
+ connection_params=mock_stdio_params,
+ get_env_from_context_fn=sample_get_env_from_context_fn,
+ )
+
+ # Verify the session manager was created without the env callback
+ # (since it's now handled in MCPToolset)
+ mock_session_manager.assert_called_once_with(
+ connection_params=mock_stdio_params,
+ errlog=toolset._errlog,
+ )
+
+ assert toolset._get_env_from_context_fn == sample_get_env_from_context_fn
+
+ def test_init_without_env_callback(self, mock_stdio_params):
+ """Test MCPToolset initialization without environment callback."""
+ with patch(
+ 'google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager'
+ ) as mock_session_manager:
+ toolset = MCPToolset(connection_params=mock_stdio_params)
+
+ # Verify the session manager was created without env callback
+ mock_session_manager.assert_called_once_with(
+ connection_params=mock_stdio_params,
+ errlog=toolset._errlog,
+ )
+
+ assert toolset._get_env_from_context_fn is None
+
+ @pytest.mark.asyncio
+ async def test_get_tools_extracts_env_and_calls_session_manager(
+ self,
+ mock_stdio_params,
+ sample_get_env_from_context_fn,
+ mock_readonly_context,
+ ):
+ """Test that get_tools extracts environment variables and calls session manager correctly."""
+ with patch(
+ 'google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager'
+ ) as mock_session_manager_class:
+ # Set up mock session manager instance
+ mock_session_manager = AsyncMock()
+ mock_session = AsyncMock()
+ mock_session.list_tools.return_value = MagicMock(tools=[])
+ mock_session_manager.create_session.return_value = mock_session
+ mock_session_manager_class.return_value = mock_session_manager
+
+ toolset = MCPToolset(
+ connection_params=mock_stdio_params,
+ get_env_from_context_fn=sample_get_env_from_context_fn,
+ )
+
+ # Call get_tools with readonly_context
+ await toolset.get_tools(mock_readonly_context)
+
+ # Verify create_session was called without parameters (new architecture)
+ mock_session_manager.create_session.assert_called_once_with()
+
+ # Verify that the session manager was updated with new connection params
+ # (this happens when environment variables are extracted and injected)
+ mock_session_manager.update_connection_params.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_get_tools_without_context(
+ self, mock_stdio_params, sample_get_env_from_context_fn
+ ):
+ """Test that get_tools works without readonly_context."""
+ with patch(
+ 'google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager'
+ ) as mock_session_manager_class:
+ # Set up mock session manager instance
+ mock_session_manager = AsyncMock()
+ mock_session = AsyncMock()
+ mock_session.list_tools.return_value = MagicMock(tools=[])
+ mock_session_manager.create_session.return_value = mock_session
+ mock_session_manager_class.return_value = mock_session_manager
+
+ toolset = MCPToolset(
+ connection_params=mock_stdio_params,
+ get_env_from_context_fn=sample_get_env_from_context_fn,
+ )
+
+ # Call get_tools without readonly_context
+ await toolset.get_tools(None)
+
+ # Verify create_session was called without parameters (new architecture)
+ mock_session_manager.create_session.assert_called_once_with()
+
+ # Verify that update_connection_params was NOT called since no context was provided
+ mock_session_manager.update_connection_params.assert_not_called()
+
+ def test_both_auth_and_env_callbacks(self, mock_stdio_params):
+ """Test MCPToolset initialization with both auth and env callbacks."""
+
+ def auth_callback(state):
+ return None, None
+
+ def env_callback(state: Mapping[str, Any]) -> Dict[str, str]:
+ return {'TEST_VAR': 'test_value'}
+
+ with patch(
+ 'google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager'
+ ) as mock_session_manager:
+ toolset = MCPToolset(
+ connection_params=mock_stdio_params,
+ get_auth_from_context_fn=auth_callback,
+ get_env_from_context_fn=env_callback,
+ )
+
+ # Verify both callbacks are stored
+ assert toolset._get_auth_from_context_fn == auth_callback
+ assert toolset._get_env_from_context_fn == env_callback
+
+ # Verify the session manager was created without the env callback (new architecture)
+ mock_session_manager.assert_called_once_with(
+ connection_params=mock_stdio_params,
+ errlog=toolset._errlog,
+ )
+
+ @pytest.mark.asyncio
+ async def test_integration_env_extraction_and_injection(
+ self, mock_stdio_params, mock_readonly_context
+ ):
+ """Test end-to-end environment variable extraction and injection."""
+
+ def env_callback(state: Mapping[str, Any]) -> Dict[str, str]:
+ return {
+ 'API_KEY': state.get('api_key', ''),
+ 'WORKSPACE_PATH': state.get('workspace_path', ''),
+ }
+
+ with (
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_toolset.MCPSessionManager'
+ ) as mock_session_manager_class,
+ patch(
+ 'google.adk.tools.mcp_tool.mcp_toolset.MCPTool'
+ ) as mock_mcp_tool_class,
+ ):
+
+ # Set up mock session manager instance
+ mock_session_manager = AsyncMock()
+ mock_session = AsyncMock()
+ mock_tool_response = MagicMock()
+ mock_tool_response.tools = []
+ mock_session.list_tools.return_value = mock_tool_response
+ mock_session_manager.create_session.return_value = mock_session
+ mock_session_manager_class.return_value = mock_session_manager
+
+ toolset = MCPToolset(
+ connection_params=mock_stdio_params,
+ get_env_from_context_fn=env_callback,
+ )
+
+ # Call get_tools with context containing state
+ tools = await toolset.get_tools(mock_readonly_context)
+
+ # Verify the session manager's create_session was called without parameters (new architecture)
+ mock_session_manager.create_session.assert_called_once_with()
+
+ # Verify list_tools was called on the session
+ mock_session.list_tools.assert_called_once()
+
+ assert isinstance(tools, list)