Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import version
from log import get_logger
from configuration import configuration
from utils.common import register_mcp_servers
from utils.common import register_mcp_servers_async

logger = get_logger(__name__)

Expand All @@ -32,8 +32,7 @@
@app.on_event("startup")
async def startup_event() -> None:
"""Perform logger setup on service startup."""
logger.info("Registering MCP servers")
await register_mcp_servers_async(logger, configuration.configuration)
get_logger("app.endpoints.handlers")
logger.info("Starting up: registering MCP servers")
register_mcp_servers(logger, configuration.configuration)
logger.info("Including routers")
routers.include_routers(app)
logger.info("App startup complete")
89 changes: 78 additions & 11 deletions src/utils/common.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
"""Common utilities for the project."""

from typing import Any
from typing import Any, List, cast
from logging import Logger

from llama_stack_client import LlamaStackClient

from llama_stack.distribution.library_client import (
LlamaStackAsLibraryClient,
AsyncLlamaStackAsLibraryClient,
)

from client import get_llama_stack_client
from models.config import Configuration
from models.config import Configuration, ModelContextProtocolServer


# TODO(lucasagomes): implement this function to retrieve user ID from auth
Expand All @@ -22,20 +28,81 @@ def retrieve_user_id(auth: Any) -> str: # pylint: disable=unused-argument
return "user_id_placeholder"


async def register_mcp_servers_async(
logger: Logger, configuration: Configuration
) -> None:
"""Register Model Context Protocol (MCP) servers with the LlamaStack client (async)."""
if configuration.llama_stack.use_as_library_client:
# Library client - use async interface
# config.py validation ensures library_client_config_path is not None
# when use_as_library_client is True
config_path = cast(str, configuration.llama_stack.library_client_config_path)
client = LlamaStackAsLibraryClient(config_path)
await client.async_client.initialize()

await _register_mcp_toolgroups_async(
client.async_client, configuration.mcp_servers, logger
)
else:
# Service client - use sync interface
register_mcp_servers(logger, configuration)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO it'd be better to get the client here too; to be consistent with the async approach.

async def register_mcp_servers_async(
    logger: Logger, configuration: Configuration
) -> None:
    """Register Model Context Protocol (MCP) servers with the LlamaStack client (async)."""
    if configuration.llama_stack.use_as_library_client:
        # Library client - use async interface
        # config.py validation ensures library_client_config_path is not None
        # when use_as_library_client is True
        config_path = cast(str, configuration.llama_stack.library_client_config_path)
        client = LlamaStackAsLibraryClient(config_path)
        await client.async_client.initialize()

        await _register_mcp_toolgroups_async(
            client.async_client, configuration.mcp_servers, logger
        )
    else:
        # Service client - use sync interface
        client = get_llama_stack_client(configuration.llama_stack)
    
        _register_mcp_toolgroups_sync(client, configuration.mcp_servers, logger)



def register_mcp_servers(logger: Logger, configuration: Configuration) -> None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Delete this if doing the above proposal.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added it to make the tests a bit clearer (async when not using as a library).

"""Register Model Context Protocol (MCP) servers with the LlamaStack client."""
# Get list of registered tools and extract their toolgroup IDs
"""Register Model Context Protocol (MCP) servers with the LlamaStack client (sync)."""
# Service client - use sync interface
client = get_llama_stack_client(configuration.llama_stack)

_register_mcp_toolgroups_sync(client, configuration.mcp_servers, logger)


async def _register_mcp_toolgroups_async(
client: AsyncLlamaStackAsLibraryClient,
mcp_servers: List[ModelContextProtocolServer],
logger: Logger,
) -> None:
"""Async logic for registering MCP toolgroups."""
# Get registered tools
registered_tools = await client.tools.list()
registered_toolgroups = [tool.toolgroup_id for tool in registered_tools]
logger.debug("Registered toolgroups: %s", set(registered_toolgroups))

# Register toolgroups for MCP servers if not already registered
for mcp in mcp_servers:
if mcp.name not in registered_toolgroups:
logger.debug("Registering MCP server: %s, %s", mcp.name, mcp.url)

registration_params = {
"toolgroup_id": mcp.name,
"provider_id": mcp.provider_id,
"mcp_endpoint": {"uri": mcp.url},
}

await client.toolgroups.register(**registration_params)
logger.debug("MCP server %s registered successfully", mcp.name)


def _register_mcp_toolgroups_sync(
client: LlamaStackClient,
mcp_servers: List[ModelContextProtocolServer],
logger: Logger,
) -> None:
"""Sync logic for registering MCP toolgroups."""
# Get registered tools
registered_tools = client.tools.list()
registered_toolgroups = [tool.toolgroup_id for tool in registered_tools]
logger.debug("Registered toolgroups: %s", set(registered_toolgroups))

# Register toolgroups for MCP servers if not already registered
for mcp in configuration.mcp_servers:
if mcp.name not in registered_toolgroups: # required
for mcp in mcp_servers:
if mcp.name not in registered_toolgroups:
logger.debug("Registering MCP server: %s, %s", mcp.name, mcp.url)
client.toolgroups.register(
toolgroup_id=mcp.name,
provider_id=mcp.provider_id,
mcp_endpoint={"uri": mcp.url},
)

registration_params = {
"toolgroup_id": mcp.name,
"provider_id": mcp.provider_id,
"mcp_endpoint": {"uri": mcp.url},
}

client.toolgroups.register(**registration_params)
logger.debug("MCP server %s registered successfully", mcp.name)
61 changes: 59 additions & 2 deletions tests/unit/utils/test_common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
"""Test module for utils/common.py."""

from unittest.mock import Mock
import pytest
from unittest.mock import Mock, AsyncMock
from logging import Logger

from utils.common import retrieve_user_id, register_mcp_servers
from utils.common import (
retrieve_user_id,
register_mcp_servers,
register_mcp_servers_async,
)
from models.config import (
Configuration,
ServiceConfiguration,
Expand Down Expand Up @@ -222,3 +227,55 @@ def test_register_mcp_servers_with_custom_provider(mocker):
provider_id="my-custom-provider",
mcp_endpoint={"uri": "https://custom.example.com/mcp"},
)


@pytest.mark.asyncio
async def test_register_mcp_servers_async_with_library_client(mocker):
"""Test register_mcp_servers_async with library client configuration."""
# Mock the logger
mock_logger = Mock(spec=Logger)

# Mock the LlamaStackAsLibraryClient
mock_library_client = Mock()
mock_async_client = AsyncMock()
mock_async_client.initialize = AsyncMock()
mock_library_client.async_client = mock_async_client

# Mock tools.list to return empty list
mock_tool = Mock()
mock_tool.toolgroup_id = "existing-tool"
mock_async_client.tools.list = AsyncMock(return_value=[mock_tool])
mock_async_client.toolgroups.register = AsyncMock()

mocker.patch(
"utils.common.LlamaStackAsLibraryClient", return_value=mock_library_client
)

# Create configuration with library client enabled
mcp_server = ModelContextProtocolServer(
name="test-server", url="http://localhost:8080"
)
config = Configuration(
name="test",
service=ServiceConfiguration(),
llama_stack=LLamaStackConfiguration(
use_as_library_client=True,
library_client_config_path="/path/to/config.yaml",
),
user_data_collection=UserDataCollection(feedback_disabled=True),
mcp_servers=[mcp_server],
)

# Call the async function
await register_mcp_servers_async(mock_logger, config)

# Verify initialization was called
mock_async_client.initialize.assert_called_once()
# Verify tools.list was called
mock_async_client.tools.list.assert_called_once()
# Verify toolgroups.register was called for the new server
mock_async_client.toolgroups.register.assert_called_once_with(
toolgroup_id="test-server",
provider_id="model-context-protocol",
mcp_endpoint={"uri": "http://localhost:8080"},
)