From ea0c1dd18faa84c47359e03ebbf877c4c521480f Mon Sep 17 00:00:00 2001 From: Eran Cohen Date: Thu, 26 Jun 2025 12:24:43 +0300 Subject: [PATCH] fix: resolve asyncio event loop conflicts in MCP server registration Add async/await support for LlamaStackAsLibraryClient to prevent "Cannot run event loop while another loop is running" RuntimeError. Changes: - Introduce register_mcp_servers_async() for async MCP server registration - Refactor MCP registration logic into separate sync/async helper functions - Use async client methods (client.async_client) for library clients to avoid "Cannot run event loop while another loop is running" errors - Add async tests for library client configuration The LlamaStackAsLibraryClient requires async initialization, but the previous sync implementation caused event loop conflicts. This change provides dual support for both service clients (sync) and library clients (async). Fixes RuntimeError when configuration.llama_stack.use_as_library_client = true --- src/app/main.py | 9 ++-- src/utils/common.py | 89 +++++++++++++++++++++++++++++---- tests/unit/utils/test_common.py | 61 +++++++++++++++++++++- 3 files changed, 141 insertions(+), 18 deletions(-) diff --git a/src/app/main.py b/src/app/main.py index 18feb0d83..c315136fb 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -6,7 +6,7 @@ import version from log import get_logger from configuration import configuration -from utils.common import register_mcp_servers +from utils.common import register_mcp_servers_async logger = get_logger(__name__) @@ -32,8 +32,7 @@ @app.on_event("startup") async def startup_event() -> None: """Perform logger setup on service startup.""" + logger.info("Registering MCP servers") + await register_mcp_servers_async(logger, configuration.configuration) get_logger("app.endpoints.handlers") - logger.info("Starting up: registering MCP servers") - register_mcp_servers(logger, configuration.configuration) - logger.info("Including routers") - routers.include_routers(app) + logger.info("App startup complete") diff --git a/src/utils/common.py b/src/utils/common.py index 7ed415031..911be00af 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -1,11 +1,17 @@ """Common utilities for the project.""" -from typing import Any +from typing import Any, List, cast from logging import Logger +from llama_stack_client import LlamaStackClient + +from llama_stack.distribution.library_client import ( + LlamaStackAsLibraryClient, + AsyncLlamaStackAsLibraryClient, +) from client import get_llama_stack_client -from models.config import Configuration +from models.config import Configuration, ModelContextProtocolServer # TODO(lucasagomes): implement this function to retrieve user ID from auth @@ -22,20 +28,81 @@ def retrieve_user_id(auth: Any) -> str: # pylint: disable=unused-argument return "user_id_placeholder" +async def register_mcp_servers_async( + logger: Logger, configuration: Configuration +) -> None: + """Register Model Context Protocol (MCP) servers with the LlamaStack client (async).""" + if configuration.llama_stack.use_as_library_client: + # Library client - use async interface + # config.py validation ensures library_client_config_path is not None + # when use_as_library_client is True + config_path = cast(str, configuration.llama_stack.library_client_config_path) + client = LlamaStackAsLibraryClient(config_path) + await client.async_client.initialize() + + await _register_mcp_toolgroups_async( + client.async_client, configuration.mcp_servers, logger + ) + else: + # Service client - use sync interface + register_mcp_servers(logger, configuration) + + def register_mcp_servers(logger: Logger, configuration: Configuration) -> None: - """Register Model Context Protocol (MCP) servers with the LlamaStack client.""" - # Get list of registered tools and extract their toolgroup IDs + """Register Model Context Protocol (MCP) servers with the LlamaStack client (sync).""" + # Service client - use sync interface client = get_llama_stack_client(configuration.llama_stack) + + _register_mcp_toolgroups_sync(client, configuration.mcp_servers, logger) + + +async def _register_mcp_toolgroups_async( + client: AsyncLlamaStackAsLibraryClient, + mcp_servers: List[ModelContextProtocolServer], + logger: Logger, +) -> None: + """Async logic for registering MCP toolgroups.""" + # Get registered tools + registered_tools = await client.tools.list() + registered_toolgroups = [tool.toolgroup_id for tool in registered_tools] + logger.debug("Registered toolgroups: %s", set(registered_toolgroups)) + + # Register toolgroups for MCP servers if not already registered + for mcp in mcp_servers: + if mcp.name not in registered_toolgroups: + logger.debug("Registering MCP server: %s, %s", mcp.name, mcp.url) + + registration_params = { + "toolgroup_id": mcp.name, + "provider_id": mcp.provider_id, + "mcp_endpoint": {"uri": mcp.url}, + } + + await client.toolgroups.register(**registration_params) + logger.debug("MCP server %s registered successfully", mcp.name) + + +def _register_mcp_toolgroups_sync( + client: LlamaStackClient, + mcp_servers: List[ModelContextProtocolServer], + logger: Logger, +) -> None: + """Sync logic for registering MCP toolgroups.""" + # Get registered tools registered_tools = client.tools.list() registered_toolgroups = [tool.toolgroup_id for tool in registered_tools] logger.debug("Registered toolgroups: %s", set(registered_toolgroups)) + # Register toolgroups for MCP servers if not already registered - for mcp in configuration.mcp_servers: - if mcp.name not in registered_toolgroups: # required + for mcp in mcp_servers: + if mcp.name not in registered_toolgroups: logger.debug("Registering MCP server: %s, %s", mcp.name, mcp.url) - client.toolgroups.register( - toolgroup_id=mcp.name, - provider_id=mcp.provider_id, - mcp_endpoint={"uri": mcp.url}, - ) + + registration_params = { + "toolgroup_id": mcp.name, + "provider_id": mcp.provider_id, + "mcp_endpoint": {"uri": mcp.url}, + } + + client.toolgroups.register(**registration_params) logger.debug("MCP server %s registered successfully", mcp.name) diff --git a/tests/unit/utils/test_common.py b/tests/unit/utils/test_common.py index d727e5ee6..64ff7c93f 100644 --- a/tests/unit/utils/test_common.py +++ b/tests/unit/utils/test_common.py @@ -1,9 +1,14 @@ """Test module for utils/common.py.""" -from unittest.mock import Mock +import pytest +from unittest.mock import Mock, AsyncMock from logging import Logger -from utils.common import retrieve_user_id, register_mcp_servers +from utils.common import ( + retrieve_user_id, + register_mcp_servers, + register_mcp_servers_async, +) from models.config import ( Configuration, ServiceConfiguration, @@ -222,3 +227,55 @@ def test_register_mcp_servers_with_custom_provider(mocker): provider_id="my-custom-provider", mcp_endpoint={"uri": "https://custom.example.com/mcp"}, ) + + +@pytest.mark.asyncio +async def test_register_mcp_servers_async_with_library_client(mocker): + """Test register_mcp_servers_async with library client configuration.""" + # Mock the logger + mock_logger = Mock(spec=Logger) + + # Mock the LlamaStackAsLibraryClient + mock_library_client = Mock() + mock_async_client = AsyncMock() + mock_async_client.initialize = AsyncMock() + mock_library_client.async_client = mock_async_client + + # Mock tools.list to return empty list + mock_tool = Mock() + mock_tool.toolgroup_id = "existing-tool" + mock_async_client.tools.list = AsyncMock(return_value=[mock_tool]) + mock_async_client.toolgroups.register = AsyncMock() + + mocker.patch( + "utils.common.LlamaStackAsLibraryClient", return_value=mock_library_client + ) + + # Create configuration with library client enabled + mcp_server = ModelContextProtocolServer( + name="test-server", url="http://localhost:8080" + ) + config = Configuration( + name="test", + service=ServiceConfiguration(), + llama_stack=LLamaStackConfiguration( + use_as_library_client=True, + library_client_config_path="/path/to/config.yaml", + ), + user_data_collection=UserDataCollection(feedback_disabled=True), + mcp_servers=[mcp_server], + ) + + # Call the async function + await register_mcp_servers_async(mock_logger, config) + + # Verify initialization was called + mock_async_client.initialize.assert_called_once() + # Verify tools.list was called + mock_async_client.tools.list.assert_called_once() + # Verify toolgroups.register was called for the new server + mock_async_client.toolgroups.register.assert_called_once_with( + toolgroup_id="test-server", + provider_id="model-context-protocol", + mcp_endpoint={"uri": "http://localhost:8080"}, + )