diff --git a/src/app/main.py b/src/app/main.py index 18feb0d83..c315136fb 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -6,7 +6,7 @@ import version from log import get_logger from configuration import configuration -from utils.common import register_mcp_servers +from utils.common import register_mcp_servers_async logger = get_logger(__name__) @@ -32,8 +32,7 @@ @app.on_event("startup") async def startup_event() -> None: """Perform logger setup on service startup.""" + logger.info("Registering MCP servers") + await register_mcp_servers_async(logger, configuration.configuration) get_logger("app.endpoints.handlers") - logger.info("Starting up: registering MCP servers") - register_mcp_servers(logger, configuration.configuration) - logger.info("Including routers") - routers.include_routers(app) + logger.info("App startup complete") diff --git a/src/utils/common.py b/src/utils/common.py index 7ed415031..911be00af 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -1,11 +1,17 @@ """Common utilities for the project.""" -from typing import Any +from typing import Any, List, cast from logging import Logger +from llama_stack_client import LlamaStackClient + +from llama_stack.distribution.library_client import ( + LlamaStackAsLibraryClient, + AsyncLlamaStackAsLibraryClient, +) from client import get_llama_stack_client -from models.config import Configuration +from models.config import Configuration, ModelContextProtocolServer # TODO(lucasagomes): implement this function to retrieve user ID from auth @@ -22,20 +28,81 @@ def retrieve_user_id(auth: Any) -> str: # pylint: disable=unused-argument return "user_id_placeholder" +async def register_mcp_servers_async( + logger: Logger, configuration: Configuration +) -> None: + """Register Model Context Protocol (MCP) servers with the LlamaStack client (async).""" + if configuration.llama_stack.use_as_library_client: + # Library client - use async interface + # config.py validation ensures library_client_config_path is not None + # when use_as_library_client is True + config_path = cast(str, configuration.llama_stack.library_client_config_path) + client = LlamaStackAsLibraryClient(config_path) + await client.async_client.initialize() + + await _register_mcp_toolgroups_async( + client.async_client, configuration.mcp_servers, logger + ) + else: + # Service client - use sync interface + register_mcp_servers(logger, configuration) + + def register_mcp_servers(logger: Logger, configuration: Configuration) -> None: - """Register Model Context Protocol (MCP) servers with the LlamaStack client.""" - # Get list of registered tools and extract their toolgroup IDs + """Register Model Context Protocol (MCP) servers with the LlamaStack client (sync).""" + # Service client - use sync interface client = get_llama_stack_client(configuration.llama_stack) + + _register_mcp_toolgroups_sync(client, configuration.mcp_servers, logger) + + +async def _register_mcp_toolgroups_async( + client: AsyncLlamaStackAsLibraryClient, + mcp_servers: List[ModelContextProtocolServer], + logger: Logger, +) -> None: + """Async logic for registering MCP toolgroups.""" + # Get registered tools + registered_tools = await client.tools.list() + registered_toolgroups = [tool.toolgroup_id for tool in registered_tools] + logger.debug("Registered toolgroups: %s", set(registered_toolgroups)) + + # Register toolgroups for MCP servers if not already registered + for mcp in mcp_servers: + if mcp.name not in registered_toolgroups: + logger.debug("Registering MCP server: %s, %s", mcp.name, mcp.url) + + registration_params = { + "toolgroup_id": mcp.name, + "provider_id": mcp.provider_id, + "mcp_endpoint": {"uri": mcp.url}, + } + + await client.toolgroups.register(**registration_params) + logger.debug("MCP server %s registered successfully", mcp.name) + + +def _register_mcp_toolgroups_sync( + client: LlamaStackClient, + mcp_servers: List[ModelContextProtocolServer], + logger: Logger, +) -> None: + """Sync logic for registering MCP toolgroups.""" + # Get registered tools registered_tools = client.tools.list() registered_toolgroups = [tool.toolgroup_id for tool in registered_tools] logger.debug("Registered toolgroups: %s", set(registered_toolgroups)) + # Register toolgroups for MCP servers if not already registered - for mcp in configuration.mcp_servers: - if mcp.name not in registered_toolgroups: # required + for mcp in mcp_servers: + if mcp.name not in registered_toolgroups: logger.debug("Registering MCP server: %s, %s", mcp.name, mcp.url) - client.toolgroups.register( - toolgroup_id=mcp.name, - provider_id=mcp.provider_id, - mcp_endpoint={"uri": mcp.url}, - ) + + registration_params = { + "toolgroup_id": mcp.name, + "provider_id": mcp.provider_id, + "mcp_endpoint": {"uri": mcp.url}, + } + + client.toolgroups.register(**registration_params) logger.debug("MCP server %s registered successfully", mcp.name) diff --git a/tests/unit/utils/test_common.py b/tests/unit/utils/test_common.py index d727e5ee6..64ff7c93f 100644 --- a/tests/unit/utils/test_common.py +++ b/tests/unit/utils/test_common.py @@ -1,9 +1,14 @@ """Test module for utils/common.py.""" -from unittest.mock import Mock +import pytest +from unittest.mock import Mock, AsyncMock from logging import Logger -from utils.common import retrieve_user_id, register_mcp_servers +from utils.common import ( + retrieve_user_id, + register_mcp_servers, + register_mcp_servers_async, +) from models.config import ( Configuration, ServiceConfiguration, @@ -222,3 +227,55 @@ def test_register_mcp_servers_with_custom_provider(mocker): provider_id="my-custom-provider", mcp_endpoint={"uri": "https://custom.example.com/mcp"}, ) + + +@pytest.mark.asyncio +async def test_register_mcp_servers_async_with_library_client(mocker): + """Test register_mcp_servers_async with library client configuration.""" + # Mock the logger + mock_logger = Mock(spec=Logger) + + # Mock the LlamaStackAsLibraryClient + mock_library_client = Mock() + mock_async_client = AsyncMock() + mock_async_client.initialize = AsyncMock() + mock_library_client.async_client = mock_async_client + + # Mock tools.list to return empty list + mock_tool = Mock() + mock_tool.toolgroup_id = "existing-tool" + mock_async_client.tools.list = AsyncMock(return_value=[mock_tool]) + mock_async_client.toolgroups.register = AsyncMock() + + mocker.patch( + "utils.common.LlamaStackAsLibraryClient", return_value=mock_library_client + ) + + # Create configuration with library client enabled + mcp_server = ModelContextProtocolServer( + name="test-server", url="http://localhost:8080" + ) + config = Configuration( + name="test", + service=ServiceConfiguration(), + llama_stack=LLamaStackConfiguration( + use_as_library_client=True, + library_client_config_path="/path/to/config.yaml", + ), + user_data_collection=UserDataCollection(feedback_disabled=True), + mcp_servers=[mcp_server], + ) + + # Call the async function + await register_mcp_servers_async(mock_logger, config) + + # Verify initialization was called + mock_async_client.initialize.assert_called_once() + # Verify tools.list was called + mock_async_client.tools.list.assert_called_once() + # Verify toolgroups.register was called for the new server + mock_async_client.toolgroups.register.assert_called_once_with( + toolgroup_id="test-server", + provider_id="model-context-protocol", + mcp_endpoint={"uri": "http://localhost:8080"}, + )