src/app/endpoints/query.py: 10 changes (5 additions, 5 deletions)

@@ -183,15 +183,15 @@ async def query_endpoint_handler(
     try:
         # try to get Llama Stack client
         client = AsyncLlamaStackClientHolder().get_client()
-        model_id, provider_id = select_model_and_provider_id(
+        llama_stack_model_id, model_id, provider_id = select_model_and_provider_id(
             await client.models.list(),
             *evaluate_model_hints(
                 user_conversation=user_conversation, query_request=query_request
             ),
         )
         response, conversation_id = await retrieve_response(
             client,
-            model_id,
+            llama_stack_model_id,
             query_request,
             token,
             mcp_headers=mcp_headers,
@@ -239,7 +239,7 @@ async def query_endpoint_handler(

 def select_model_and_provider_id(
     models: ModelListResponse, model_id: str | None, provider_id: str | None
-) -> tuple[str, str]:
+) -> tuple[str, str, str]:
     """Select the model ID and provider ID based on the request or available models."""
     # If model_id and provider_id are provided in the request, use them

@@ -268,7 +268,7 @@ def select_model_and_provider_id(
         model_id = model.identifier
         provider_id = model.provider_id
         logger.info("Selected model: %s", model)
-        return model_id, provider_id
+        return model_id, model_id, provider_id
     except (StopIteration, AttributeError) as e:
         message = "No LLM model found in available models"
         logger.error(message)
@@ -297,7 +297,7 @@ def select_model_and_provider_id(
             },
         )

-    return llama_stack_model_id, provider_id
+    return llama_stack_model_id, model_id, provider_id


 def _is_inout_shield(shield: Shield) -> bool:
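
Net effect of the query.py changes: select_model_and_provider_id now returns a triple instead of a pair, and the handler forwards the provider-qualified ID to retrieve_response while keeping the plain model ID around. A minimal, self-contained sketch of the new contract (the _demo function is a hypothetical stand-in; only the unpacking pattern and the example values mirror the diff and its tests):

def select_model_and_provider_id_demo(model_id: str, provider_id: str) -> tuple[str, str, str]:
    """Return (llama_stack_model_id, model_id, provider_id).

    llama_stack_model_id is the provider-qualified identifier handed to
    Llama Stack; model_id stays plain, e.g. for conversation persistence.
    """
    llama_stack_model_id = f"{provider_id}/{model_id}"
    return llama_stack_model_id, model_id, provider_id

# Callers now unpack three values and pass the qualified ID downstream:
llama_stack_model_id, model_id, provider_id = select_model_and_provider_id_demo(
    "model2", "provider2"
)
assert llama_stack_model_id == "provider2/model2"
assert (model_id, provider_id) == ("model2", "provider2")
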
src/app/endpoints/streaming_query.py: 4 changes (2 additions, 2 deletions)

@@ -421,15 +421,15 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
     try:
         # try to get Llama Stack client
         client = AsyncLlamaStackClientHolder().get_client()
-        model_id, provider_id = select_model_and_provider_id(
+        llama_stack_model_id, model_id, provider_id = select_model_and_provider_id(
             await client.models.list(),
             *evaluate_model_hints(
                 user_conversation=user_conversation, query_request=query_request
             ),
         )
         response, conversation_id = await retrieve_response(
             client,
-            model_id,
+            llama_stack_model_id,
             query_request,
             token,
             mcp_headers=mcp_headers,
tests/unit/app/endpoints/test_query.py: 21 changes (12 additions, 9 deletions)

@@ -132,7 +132,7 @@ async def _test_query_endpoint_handler(mocker, store_transcript_to_file=False):
     )
     mocker.patch(
         "app.endpoints.query.select_model_and_provider_id",
-        return_value=("fake_model_id", "fake_provider_id"),
+        return_value=("fake_model_id", "fake_model_id", "fake_provider_id"),
     )
     mocker.patch(
         "app.endpoints.query.is_transcripts_enabled",
@@ -214,11 +214,12 @@ def test_select_model_and_provider_id_from_request(mocker):
     )

     # Assert the model and provider from request take precedence from the configuration one
-    model_id, provider_id = select_model_and_provider_id(
+    llama_stack_model_id, model_id, provider_id = select_model_and_provider_id(
         model_list, query_request.model, query_request.provider
     )

-    assert model_id == "provider2/model2"
+    assert llama_stack_model_id == "provider2/model2"
+    assert model_id == "model2"
     assert provider_id == "provider2"


@@ -249,12 +250,13 @@ def test_select_model_and_provider_id_from_configuration(mocker):
         query="What is OpenStack?",
     )

-    model_id, provider_id = select_model_and_provider_id(
+    llama_stack_model_id, model_id, provider_id = select_model_and_provider_id(
         model_list, query_request.model, query_request.provider
     )

     # Assert that the default model and provider from the configuration are returned
-    assert model_id == "default_provider/default_model"
+    assert llama_stack_model_id == "default_provider/default_model"
+    assert model_id == "default_model"
     assert provider_id == "default_provider"


@@ -274,12 +276,13 @@ def test_select_model_and_provider_id_first_from_list(mocker):

     query_request = QueryRequest(query="What is OpenStack?")

-    model_id, provider_id = select_model_and_provider_id(
+    llama_stack_model_id, model_id, provider_id = select_model_and_provider_id(
         model_list, query_request.model, query_request.provider
     )

     # Assert return the first available LLM model when no model/provider is
     # specified in the request or in the configuration
+    assert llama_stack_model_id == "first_model"
     assert model_id == "first_model"
     assert provider_id == "provider1"
Comment on lines +279 to 287

💡 Verification agent

🧩 Analysis chain

Add coverage for fallback with provider-prefixed identifiers

This test uses identifiers without provider prefixes, so it won’t catch the regression where the fallback returns a provider-prefixed model as the plain model. Add a companion test where the first available LLM has identifier "provider1/model1" and assert that the plain model is "model1".

Proposed additional test:

def test_select_model_and_provider_id_first_from_list_with_prefixed_identifier(mocker):
    """Ensure fallback splits provider/model correctly for persistence."""
    model_list = [
        mocker.Mock(identifier="provider1/model1", model_type="llm", provider_id="provider1"),
        mocker.Mock(identifier="provider2/model2", model_type="llm", provider_id="provider2"),
    ]
    query_request = QueryRequest(query="What is OpenStack?")
    llama_stack_model_id, model_id, provider_id = select_model_and_provider_id(
        model_list, query_request.model, query_request.provider
    )
    assert llama_stack_model_id == "provider1/model1"
    assert model_id == "model1"        # plain model id
    assert provider_id == "provider1"

🏁 Script executed:

#!/bin/bash
# Locate the implementation of select_model_and_provider_id
rg "def select_model_and_provider_id" -n -A 10

Length of output: 920


🏁 Script executed:

#!/bin/bash
# Display the full implementation of select_model_and_provider_id for review
sed -n '240,340p' src/app/endpoints/query.py

Length of output: 3986


Ensure fallback handles provider-prefixed model identifiers

The current fallback branch in select_model_and_provider_id returns model.identifier as both the Llama Stack model ID and the plain model ID, so if that identifier already contains a provider prefix (e.g. "provider1/model1"), the plain model_id also comes back as "provider1/model1". We should:

  • Add a unit test to catch this case.
  • Update the fallback implementation to split out the plain model ID.

Locations to update:

  • tests/unit/app/endpoints/test_query.py
  • src/app/endpoints/query.py (fallback branch around line 250)

Proposed test addition:

def test_select_model_and_provider_id_first_with_prefixed_identifier(mocker):
    """Fallback should split provider/model correctly when identifier is prefixed."""
    model_list = [
        mocker.Mock(identifier="provider1/model1", model_type="llm", provider_id="provider1"),
        mocker.Mock(identifier="provider2/model2", model_type="llm", provider_id="provider2"),
    ]
    query_request = QueryRequest(query="foo")
    llama_stack_model_id, model_id, provider_id = select_model_and_provider_id(
        model_list, query_request.model, query_request.provider
    )
    assert llama_stack_model_id == "provider1/model1"
    assert model_id == "model1"        # plain model id
    assert provider_id == "provider1"

Proposed implementation diff in src/app/endpoints/query.py:

         model = next(m for m in models if m.model_type == "llm")
-        model_id = model.identifier
-        provider_id = model.provider_id
-        return model_id, model_id, provider_id
+        full_id = model.identifier
+        provider_id = model.provider_id
+        # split off plain model ID if a prefix is present
+        plain_id = full_id.split("/", 1)[1] if "/" in full_id else full_id
+        return full_id, plain_id, provider_id
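
A quick sanity check on the split the suggestion relies on (illustrative snippet, not part of the PR): str.split("/", 1) breaks only on the first slash, so any further slashes stay in the model part.

# maxsplit=1 keeps everything after the first "/" intact:
for full_id in ("provider1/model1", "model1", "hosted/org/model"):
    plain_id = full_id.split("/", 1)[1] if "/" in full_id else full_id
    print(f"{full_id!r} -> {plain_id!r}")
# 'provider1/model1' -> 'model1'
# 'model1' -> 'model1'
# 'hosted/org/model' -> 'org/model'
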
🤖 Prompt for AI Agents
In tests/unit/app/endpoints/test_query.py around lines 279-287, add a unit test
that supplies a model_list whose first model.identifier is provider-prefixed
(e.g. "provider1/model1") and asserts that select_model_and_provider_id returns
llama_stack_model_id as the full "provider1/model1", model_id as the plain
"model1", and provider_id as "provider1"; and in src/app/endpoints/query.py
around line ~250 update the fallback branch so that when using model.identifier
you detect a provider prefix (split on the first '/' if present), return the
full identifier as llama_stack_model_id, the right-hand segment as model_id, and
the provider part as provider_id (fall back to existing model.provider_id when
no prefix), keeping other behavior unchanged.
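
Pulling the prompt together, the suggested fallback amounts to something like the following (a hypothetical standalone sketch; names mirror the diff above, and the real select_model_and_provider_id also covers the request and configuration paths):

from types import SimpleNamespace

def fallback_selection(models) -> tuple[str, str, str]:
    """Pick the first LLM and return (llama_stack_model_id, model_id, provider_id)."""
    model = next(m for m in models if m.model_type == "llm")
    full_id = model.identifier
    if "/" in full_id:
        # Provider-prefixed identifier: split on the first "/" only.
        provider_id, plain_id = full_id.split("/", 1)
    else:
        # No prefix: keep the identifier and fall back to model.provider_id.
        provider_id, plain_id = model.provider_id, full_id
    return full_id, plain_id, provider_id

models = [
    SimpleNamespace(identifier="provider1/model1", model_type="llm", provider_id="provider1"),
    SimpleNamespace(identifier="model2", model_type="llm", provider_id="provider2"),
]
assert fallback_selection(models) == ("provider1/model1", "model1", "provider1")
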


@@ -1135,7 +1138,7 @@ async def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker):

     mocker.patch(
         "app.endpoints.query.select_model_and_provider_id",
-        return_value=("test_model", "test_provider"),
+        return_value=("test_model", "test_model", "test_provider"),
     )
     mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False)
     # Mock database operations
@@ -1174,7 +1177,7 @@ async def test_query_endpoint_handler_no_tools_true(mocker):
     )
     mocker.patch(
         "app.endpoints.query.select_model_and_provider_id",
-        return_value=("fake_model_id", "fake_provider_id"),
+        return_value=("fake_model_id", "fake_model_id", "fake_provider_id"),
     )
     mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False)
     # Mock database operations
@@ -1213,7 +1216,7 @@ async def test_query_endpoint_handler_no_tools_false(mocker):
     )
     mocker.patch(
         "app.endpoints.query.select_model_and_provider_id",
-        return_value=("fake_model_id", "fake_provider_id"),
+        return_value=("fake_model_id", "fake_model_id", "fake_provider_id"),
     )
     mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False)
     # Mock database operations
tests/unit/app/endpoints/test_streaming_query.py: 8 changes (4 additions, 4 deletions)

@@ -265,7 +265,7 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False):
     )
     mocker.patch(
         "app.endpoints.streaming_query.select_model_and_provider_id",
-        return_value=("fake_model_id", "fake_provider_id"),
+        return_value=("fake_model_id", "fake_model_id", "fake_provider_id"),
     )
     mocker.patch(
         "app.endpoints.streaming_query.is_transcripts_enabled",
@@ -1279,7 +1279,7 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker):

     mocker.patch(
         "app.endpoints.streaming_query.select_model_and_provider_id",
-        return_value=("test_model", "test_provider"),
+        return_value=("test_model", "test_model", "test_provider"),
     )
     mocker.patch(
         "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False
@@ -1325,7 +1325,7 @@ async def test_streaming_query_endpoint_handler_no_tools_true(mocker):
     )
     mocker.patch(
         "app.endpoints.streaming_query.select_model_and_provider_id",
-        return_value=("fake_model_id", "fake_provider_id"),
+        return_value=("fake_model_id", "fake_model_id", "fake_provider_id"),
     )
     mocker.patch(
         "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False
@@ -1372,7 +1372,7 @@ async def test_streaming_query_endpoint_handler_no_tools_false(mocker):
     )
     mocker.patch(
         "app.endpoints.streaming_query.select_model_and_provider_id",
-        return_value=("fake_model_id", "fake_provider_id"),
+        return_value=("fake_model_id", "fake_model_id", "fake_provider_id"),
     )
     mocker.patch(
         "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False