Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from auth import get_auth_dependency
from utils.common import retrieve_user_id
from utils.endpoints import check_configuration_loaded, get_system_prompt
from utils.mcp_headers import mcp_headers_dependency
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.suid import get_suid
from utils.types import GraniteToolParser

Expand Down Expand Up @@ -231,6 +231,7 @@ def retrieve_response(
# preserve compatibility when mcp_headers is not provided
if mcp_headers is None:
mcp_headers = {}
mcp_headers = handle_mcp_headers_with_toolgroups(mcp_headers, configuration)
if not mcp_headers and token:
for mcp_server in configuration.mcp_servers:
mcp_headers[mcp_server.url] = {
Expand Down
5 changes: 4 additions & 1 deletion src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from models.requests import QueryRequest
from utils.endpoints import check_configuration_loaded, get_system_prompt
from utils.common import retrieve_user_id
from utils.mcp_headers import mcp_headers_dependency
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.suid import get_suid
from utils.types import GraniteToolParser

Expand Down Expand Up @@ -290,6 +290,9 @@ async def retrieve_response(
# preserve compatibility when mcp_headers is not provided
if mcp_headers is None:
mcp_headers = {}

mcp_headers = handle_mcp_headers_with_toolgroups(mcp_headers, configuration)

if not mcp_headers and token:
for mcp_server in configuration.mcp_servers:
mcp_headers[mcp_server.url] = {
Expand Down
42 changes: 42 additions & 0 deletions src/utils/mcp_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@

import json
import logging
from urllib.parse import urlparse

from fastapi import Request

from configuration import AppConfig


logger = logging.getLogger("app.endpoints.dependencies")


Expand Down Expand Up @@ -46,3 +51,40 @@ def extract_mcp_headers(request: Request) -> dict[str, dict[str, str]]:
)
mcp_headers = {}
return mcp_headers


def handle_mcp_headers_with_toolgroups(
mcp_headers: dict[str, dict[str, str]], config: AppConfig
) -> dict[str, dict[str, str]]:
"""Process MCP headers by converting toolgroup names to URLs.

This function takes MCP headers where keys can be either valid URLs or
toolgroup names. For valid URLs (HTTP/HTTPS), it keeps them as-is. For
toolgroup names, it looks up the corresponding MCP server URL in the
configuration and replaces the key with the URL. Unknown toolgroup names
are filtered out.

Args:
mcp_headers: Dictionary with keys as URLs or toolgroup names
config: Application configuration containing MCP server definitions

Returns:
Dictionary with URLs as keys and their corresponding headers as values
"""
converted_mcp_headers = {}

for key, item in mcp_headers.items():
key_url_parsed = urlparse(key)
if key_url_parsed.scheme in ("http", "https") and key_url_parsed.netloc:
# a valid url is supplied, deliver it as is
converted_mcp_headers[key] = item
else:
# assume the key is a toolgroup name
# look for toolgroups name in mcp_servers configuration
# if the mcp server is not found, the mcp header gets ignored
for mcp_server in config.mcp_servers:
if mcp_server.name == key and mcp_server.url:
converted_mcp_headers[mcp_server.url] = item
break

return converted_mcp_headers
21 changes: 18 additions & 3 deletions tests/unit/app/endpoints/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,8 +694,14 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
model_id = "fake_model_id"
access_token = ""
mcp_headers = {
"http://localhost:3000": {"Authorization": "Bearer test_token_123"},
"https://git.example.com/mcp": {"Authorization": "Bearer test_token_123"},
"filesystem-server": {"Authorization": "Bearer test_token_123"},
"git-server": {"Authorization": "Bearer test_token_456"},
"http://another-server-mcp-server:3000": {
"Authorization": "Bearer test_token_789"
},
"unknown-mcp-server": {
"Authorization": "Bearer test_token_for_unknown-mcp-server"
},
}

response, conversation_id = retrieve_response(
Expand All @@ -718,11 +724,20 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
None, # conversation_id
)

expected_mcp_headers = {
"http://localhost:3000": {"Authorization": "Bearer test_token_123"},
"https://git.example.com/mcp": {"Authorization": "Bearer test_token_456"},
"http://another-server-mcp-server:3000": {
"Authorization": "Bearer test_token_789"
},
# we do not put "unknown-mcp-server" url as it's unknown to lightspeed-stack
}

# Check that the agent's extra_headers property was set correctly
expected_extra_headers = {
"X-LlamaStack-Provider-Data": json.dumps(
{
"mcp_headers": mcp_headers,
"mcp_headers": expected_mcp_headers,
}
)
}
Expand Down
20 changes: 17 additions & 3 deletions tests/unit/app/endpoints/test_streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,8 +762,14 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
model_id = "fake_model_id"
access_token = ""
mcp_headers = {
"http://localhost:3000": {"Authorization": "Bearer test_token_123"},
"https://git.example.com/mcp": {"Authorization": "Bearer test_token_456"},
"filesystem-server": {"Authorization": "Bearer test_token_123"},
"git-server": {"Authorization": "Bearer test_token_456"},
"http://another-server-mcp-server:3000": {
"Authorization": "Bearer test_token_789"
},
"unknown-mcp-server": {
"Authorization": "Bearer test_token_for_unknown-mcp-server"
},
}

response, conversation_id = await retrieve_response(
Expand All @@ -786,9 +792,17 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
None, # conversation_id
)

expected_mcp_headers = {
"http://localhost:3000": {"Authorization": "Bearer test_token_123"},
"https://git.example.com/mcp": {"Authorization": "Bearer test_token_456"},
"http://another-server-mcp-server:3000": {
"Authorization": "Bearer test_token_789"
},
# we do not put "unknown-mcp-server" url as it's unknown to lightspeed-stack
}
# Check that the agent's extra_headers property was set correctly
expected_extra_headers = {
"X-LlamaStack-Provider-Data": json.dumps({"mcp_headers": mcp_headers})
"X-LlamaStack-Provider-Data": json.dumps({"mcp_headers": expected_mcp_headers})
}
assert mock_agent.extra_headers == expected_extra_headers

Expand Down