3 changes: 3 additions & 0 deletions docs/openapi.json
@@ -96,6 +96,9 @@
           "model_type": "llm"
         }
       ]
     },
+    "503": {
+      "description": "Connection to Llama Stack is broken"
+    }
   }
 }
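The new 503 entry is not hand-edited into the spec so much as regenerated from the code: FastAPI builds the OpenAPI document from the `responses` mapping passed to the route decorator in src/app/endpoints/models.py below. A minimal sketch of that mechanism, using a throwaway app rather than the service's real wiring (the handler body here is only a placeholder):

from typing import Any

from fastapi import APIRouter, FastAPI

router = APIRouter(tags=["models"])

# extra responses to advertise in the generated OpenAPI document
models_responses: dict[int | str, dict[str, Any]] = {
    503: {"description": "Connection to Llama Stack is broken"},
}


@router.get("/models", responses=models_responses)
def models_endpoint_handler() -> dict[str, list]:
    # placeholder body; the real handler is shown in the diff below
    return {"models": []}


app = FastAPI()
app.include_router(router)

# The generated spec now carries the extra response, keyed by status code:
# app.openapi()["paths"]["/models"]["get"]["responses"]["503"]["description"]
# == "Connection to Llama Stack is broken"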
42 changes: 35 additions & 7 deletions src/app/endpoints/models.py
@@ -3,11 +3,13 @@
 import logging
 from typing import Any
 
-from fastapi import APIRouter, Request
+from llama_stack_client import APIConnectionError
+from fastapi import APIRouter, HTTPException, Request, status
 
 from client import get_llama_stack_client
 from configuration import configuration
 from models.responses import ModelsResponse
+from utils.endpoints import check_configuration_loaded
 
 logger = logging.getLogger(__name__)
 router = APIRouter(tags=["models"])
@@ -36,16 +38,42 @@
             },
         ]
     },
+    503: {"description": "Connection to Llama Stack is broken"},
 }
 
 
 @router.get("/models", responses=models_responses)
 def models_endpoint_handler(_request: Request) -> ModelsResponse:
     """Handle requests to the /models endpoint."""
-    llama_stack_config = configuration.llama_stack_configuration
-    logger.info("LLama stack config: %s", llama_stack_config)
+    check_configuration_loaded(configuration)
 
-    client = get_llama_stack_client(llama_stack_config)
-    models = client.models.list()
-    m = [dict(m) for m in models]
-    return ModelsResponse(models=m)
+    llama_stack_configuration = configuration.llama_stack_configuration
+    logger.info("LLama stack config: %s", llama_stack_configuration)
+
+    try:
+        # try to get Llama Stack client
+        client = get_llama_stack_client(llama_stack_configuration)
+        # retrieve models
+        models = client.models.list()
+        m = [dict(m) for m in models]
+        return ModelsResponse(models=m)
+    # connection to Llama Stack server
+    except APIConnectionError as e:
+        logger.error("Unable to connect to Llama Stack: %s", e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail={
+                "response": "Unable to connect to Llama Stack",
+                "cause": str(e),
+            },
+        ) from e
+    # any other exception that can occur during model listing
+    except Exception as e:
+        logger.error("Unable to retrieve list of models: %s", e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail={
+                "response": "Unable to retrieve list of models",
+                "cause": str(e),
+            },
+        ) from e
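For reference, a sketch of how the new error path surfaces to a caller, using FastAPI's TestClient against a throwaway app; the router import path and the assumption that Llama Stack is unreachable are illustrative, not taken from this PR:

from fastapi import FastAPI
from fastapi.testclient import TestClient

from app.endpoints.models import router  # assumed import path for the router above

app = FastAPI()
app.include_router(router)
client = TestClient(app)

# Assuming the service configuration is loaded and the configured Llama Stack
# URL is unreachable, the APIConnectionError branch answers with HTTP 500 and
# the structured detail built above.
response = client.get("/models")
print(response.status_code)  # 500
print(response.json())
# {"detail": {"response": "Unable to connect to Llama Stack", "cause": "..."}}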
132 changes: 132 additions & 0 deletions tests/unit/app/endpoints/test_models.py
@@ -0,0 +1,132 @@
from unittest.mock import Mock

import pytest
from fastapi import HTTPException, status

from app.endpoints.models import models_endpoint_handler
from configuration import AppConfig


def test_models_endpoint_handler_configuration_not_loaded(mocker):
"""Test the models endpoint handler if configuration is not loaded."""
# simulate state when no configuration is loaded
    mocker.patch("app.endpoints.models.configuration", None)

    request = None
    with pytest.raises(HTTPException) as e:
        models_endpoint_handler(request)
    assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert e.value.detail["response"] == "Configuration is not loaded"


def test_models_endpoint_handler_improper_llama_stack_configuration(mocker):
"""Test the models endpoint handler if Llama Stack configuration is not proper."""
# configuration for tests
config_dict = {
"name": "test",
"service": {
"host": "localhost",
"port": 8080,
"auth_enabled": False,
"workers": 1,
"color_log": True,
"access_log": True,
},
"llama_stack": {
"api_key": "test-key",
"url": "http://test.com:1234",
"use_as_library_client": False,
},
"user_data_collection": {
"transcripts_disabled": True,
},
"mcp_servers": [],
}
cfg = AppConfig()
cfg.init_from_dict(config_dict)

mocker.patch(
"app.endpoints.models.configuration",
return_value=None,
)

    request = None
    with pytest.raises(HTTPException) as e:
        models_endpoint_handler(request)
    assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert e.value.detail["response"] == "LLama stack is not configured"


def test_models_endpoint_handler_configuration_loaded(mocker):
"""Test the models endpoint handler if configuration is loaded."""
# configuration for tests
config_dict = {
"name": "foo",
"service": {
"host": "localhost",
"port": 8080,
"auth_enabled": False,
"workers": 1,
"color_log": True,
"access_log": True,
},
"llama_stack": {
"api_key": "xyzzy",
"url": "http://x.y.com:1234",
"use_as_library_client": False,
},
"user_data_collection": {
"feedback_disabled": True,
},
}
cfg = AppConfig()
cfg.init_from_dict(config_dict)

    request = None
    with pytest.raises(HTTPException) as e:
        models_endpoint_handler(request)
    assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert e.value.detail["response"] == "Unable to connect to Llama Stack"


def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker):
"""Test the models endpoint handler if configuration is loaded."""
# configuration for tests
config_dict = {
"name": "foo",
"service": {
"host": "localhost",
"port": 8080,
"auth_enabled": False,
"workers": 1,
"color_log": True,
"access_log": True,
},
"llama_stack": {
"api_key": "xyzzy",
"url": "http://x.y.com:1234",
"use_as_library_client": False,
},
"user_data_collection": {
"feedback_disabled": True,
},
}
cfg = AppConfig()
cfg.init_from_dict(config_dict)

    # Mock the Llama Stack client so no real connection is attempted
    mock_client = Mock()
    mock_client.models.list.return_value = []

    mocker.patch(
        "app.endpoints.models.get_llama_stack_client", return_value=mock_client
    )

request = None
response = models_endpoint_handler(request)
assert response is not None
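The generic `except Exception` branch in the new handler is not covered above; a possible additional test, sketched here with the same imports and configuration shape used in this file (the expected message follows the handler code shown earlier):

def test_models_endpoint_handler_models_list_failure(mocker):
    """Sketch: the handler maps unexpected errors from models.list() to HTTP 500."""
    config_dict = {
        "name": "foo",
        "service": {
            "host": "localhost",
            "port": 8080,
            "auth_enabled": False,
            "workers": 1,
            "color_log": True,
            "access_log": True,
        },
        "llama_stack": {
            "api_key": "xyzzy",
            "url": "http://x.y.com:1234",
            "use_as_library_client": False,
        },
        "user_data_collection": {
            "feedback_disabled": True,
        },
    }
    cfg = AppConfig()
    cfg.init_from_dict(config_dict)

    # Mock the Llama Stack client so that listing models raises an unexpected error
    mock_client = Mock()
    mock_client.models.list.side_effect = RuntimeError("kaboom")
    mocker.patch(
        "app.endpoints.models.get_llama_stack_client", return_value=mock_client
    )

    with pytest.raises(HTTPException) as e:
        models_endpoint_handler(None)
    assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert e.value.detail["response"] == "Unable to retrieve list of models"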