Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions src/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,9 @@ async def register_mcp_servers_async(
)
else:
# Service client - use sync interface
register_mcp_servers(logger, configuration)
client = get_llama_stack_client(configuration.llama_stack)


def register_mcp_servers(logger: Logger, configuration: Configuration) -> None:
"""Register Model Context Protocol (MCP) servers with the LlamaStack client (sync)."""
# Service client - use sync interface
client = get_llama_stack_client(configuration.llama_stack)

_register_mcp_toolgroups_sync(client, configuration.mcp_servers, logger)
_register_mcp_toolgroups_sync(client, configuration.mcp_servers, logger)


async def _register_mcp_toolgroups_async(
Expand Down
37 changes: 20 additions & 17 deletions tests/unit/utils/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from utils.common import (
retrieve_user_id,
register_mcp_servers,
register_mcp_servers_async,
)
from models.config import (
Expand All @@ -25,7 +24,8 @@ def test_retrieve_user_id():
assert user_id == "user_id_placeholder"


def test_register_mcp_servers_empty_list(mocker):
@pytest.mark.asyncio
async def test_register_mcp_servers_empty_list(mocker):
"""Test register_mcp_servers with empty MCP servers list."""
# Mock the logger
mock_logger = Mock(spec=Logger)
Expand All @@ -40,22 +40,22 @@ def test_register_mcp_servers_empty_list(mocker):
name="test",
service=ServiceConfiguration(),
llama_stack=LLamaStackConfiguration(
use_as_library_client=True, library_client_config_path="foo"
use_as_library_client=False, url="http://localhost:8321"
),
user_data_collection=UserDataCollection(feedback_disabled=True),
mcp_servers=[],
)

# Call the function
register_mcp_servers(mock_logger, config)
await register_mcp_servers_async(mock_logger, config)

# Verify client.tools.list was called
mock_client.tools.list.assert_called_once()
# Verify client.toolgroups.register was not called since no MCP servers
assert not mock_client.toolgroups.register.called


def test_register_mcp_servers_single_server_not_registered(mocker):
@pytest.mark.asyncio
async def test_register_mcp_servers_single_server_not_registered(mocker):
"""Test register_mcp_servers with single MCP server that is not yet registered."""
# Mock the logger
mock_logger = Mock(spec=Logger)
Expand All @@ -76,14 +76,14 @@ def test_register_mcp_servers_single_server_not_registered(mocker):
name="test",
service=ServiceConfiguration(),
llama_stack=LLamaStackConfiguration(
use_as_library_client=True, library_client_config_path="foo"
use_as_library_client=False, url="http://localhost:8321"
),
user_data_collection=UserDataCollection(feedback_disabled=True),
mcp_servers=[mcp_server],
)

# Call the function
register_mcp_servers(mock_logger, config)
await register_mcp_servers_async(mock_logger, config)

# Verify client.tools.list was called
mock_client.tools.list.assert_called_once()
Expand All @@ -97,7 +97,8 @@ def test_register_mcp_servers_single_server_not_registered(mocker):
mock_logger.debug.assert_called()


def test_register_mcp_servers_single_server_already_registered(mocker):
@pytest.mark.asyncio
async def test_register_mcp_servers_single_server_already_registered(mocker):
"""Test register_mcp_servers with single MCP server that is already registered."""
# Mock the logger
mock_logger = Mock(spec=Logger)
Expand All @@ -117,22 +118,23 @@ def test_register_mcp_servers_single_server_already_registered(mocker):
name="test",
service=ServiceConfiguration(),
llama_stack=LLamaStackConfiguration(
use_as_library_client=True, library_client_config_path="foo"
use_as_library_client=False, url="http://localhost:8321"
),
user_data_collection=UserDataCollection(feedback_disabled=True),
mcp_servers=[mcp_server],
)

# Call the function
register_mcp_servers(mock_logger, config)
await register_mcp_servers_async(mock_logger, config)

# Verify client.tools.list was called
mock_client.tools.list.assert_called_once()
# Verify client.toolgroups.register was NOT called since server already registered
assert not mock_client.toolgroups.register.called


def test_register_mcp_servers_multiple_servers_mixed_registration(mocker):
@pytest.mark.asyncio
async def test_register_mcp_servers_multiple_servers_mixed_registration(mocker):
"""Test register_mcp_servers with multiple MCP servers - some registered, some not."""
# Mock the logger
mock_logger = Mock(spec=Logger)
Expand Down Expand Up @@ -161,14 +163,14 @@ def test_register_mcp_servers_multiple_servers_mixed_registration(mocker):
name="test",
service=ServiceConfiguration(),
llama_stack=LLamaStackConfiguration(
use_as_library_client=True, library_client_config_path="foo"
use_as_library_client=False, url="http://localhost:8321"
),
user_data_collection=UserDataCollection(feedback_disabled=True),
mcp_servers=mcp_servers,
)

# Call the function
register_mcp_servers(mock_logger, config)
await register_mcp_servers_async(mock_logger, config)

# Verify client.tools.list was called
mock_client.tools.list.assert_called_once()
Expand All @@ -191,7 +193,8 @@ def test_register_mcp_servers_multiple_servers_mixed_registration(mocker):
mock_client.toolgroups.register.assert_has_calls(expected_calls, any_order=True)


def test_register_mcp_servers_with_custom_provider(mocker):
@pytest.mark.asyncio
async def test_register_mcp_servers_with_custom_provider(mocker):
"""Test register_mcp_servers with MCP server using custom provider."""
# Mock the logger
mock_logger = Mock(spec=Logger)
Expand All @@ -212,14 +215,14 @@ def test_register_mcp_servers_with_custom_provider(mocker):
name="test",
service=ServiceConfiguration(),
llama_stack=LLamaStackConfiguration(
use_as_library_client=True, library_client_config_path="foo"
use_as_library_client=False, url="http://localhost:8321"
),
user_data_collection=UserDataCollection(feedback_disabled=True),
mcp_servers=[mcp_server],
)

# Call the function
register_mcp_servers(mock_logger, config)
await register_mcp_servers_async(mock_logger, config)

# Verify client.toolgroups.register was called with custom provider
mock_client.toolgroups.register.assert_called_once_with(
Expand Down