7 changes: 6 additions & 1 deletion src/app/endpoints/models.py
@@ -6,6 +6,8 @@
 from fastapi import APIRouter, Request
 from llama_stack_client import LlamaStackClient
 
+from client import get_llama_stack_client
+from configuration import configuration
 from models.responses import ModelsResponse
 
 logger = logging.getLogger(__name__)
@@ -40,7 +42,10 @@
 
 @router.get("/models", responses=models_responses)
 def models_endpoint_handler(request: Request) -> ModelsResponse:
-    client = LlamaStackClient(base_url="http://localhost:8321")
+    llama_stack_config = configuration.llama_stack_configuration
+    logger.info("LLama stack config: %s", llama_stack_config)
+
+    client = get_llama_stack_client(llama_stack_config)
     models = client.models.list()
     m = [dict(m) for m in models]
     return ModelsResponse(models=m)
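
For illustration only, not part of the PR: a minimal sketch of how the refactored models handler could be unit-tested with the new client factory patched out. The module path app.endpoints.models, passing request=None, and the empty model list are assumptions drawn from the diff above.

# Hypothetical test; module path and request=None are assumptions, not PR content.
from unittest.mock import MagicMock, patch

from app.endpoints.models import models_endpoint_handler


@patch("app.endpoints.models.configuration")
@patch("app.endpoints.models.get_llama_stack_client")
def test_models_endpoint_handler(mock_get_client, _mock_configuration):
    # The patched factory returns a stub client whose model list is empty.
    mock_client = MagicMock()
    mock_client.models.list.return_value = []
    mock_get_client.return_value = mock_client

    response = models_endpoint_handler(request=None)
    assert response.models == []
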
22 changes: 2 additions & 20 deletions src/app/endpoints/query.py
@@ -3,15 +3,14 @@
 import logging
 from typing import Any
 
-from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client import LlamaStackClient
 from llama_stack_client.types import UserMessage
 
 from fastapi import APIRouter, Request
 
+from client import get_llama_stack_client
 from configuration import configuration
-from models.config import LLamaStackConfiguration
 from models.responses import QueryResponse
 
 logger = logging.getLogger("app.endpoints.handlers")
@@ -27,7 +26,7 @@
 
 
 @router.post("/query", responses=query_response)
-def info_endpoint_handler(request: Request, query: str) -> QueryResponse:
+def query_endpoint_handler(request: Request, query: str) -> QueryResponse:
     llama_stack_config = configuration.llama_stack_configuration
     logger.info("LLama stack config: %s", llama_stack_config)
 
@@ -70,20 +69,3 @@ def retrieve_response(client: LlamaStackClient, model_id: str, prompt: str) -> s
     )
 
     return str(response.output_message.content)
-
-
-def get_llama_stack_client(
-    llama_stack_config: LLamaStackConfiguration,
-) -> LlamaStackClient:
-    if llama_stack_config.use_as_library_client is True:
-        logger.info("Using Llama stack as library client")
-        client = LlamaStackAsLibraryClient(
-            llama_stack_config.library_client_config_path
-        )
-        client.initialize()
-        return client
-    else:
-        logger.info("Using Llama stack running as a service")
-        return LlamaStackClient(
-            base_url=llama_stack_config.url, api_key=llama_stack_config.api_key
-        )
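
A rough usage sketch of the renamed endpoint, not part of the PR: because the handler declares a plain query: str parameter, FastAPI reads it from the query string. The base URL below is an assumption; adjust it to wherever the service is running.

# Sketch of calling POST /query once the service is up (base URL is assumed).
import requests

response = requests.post(
    "http://localhost:8080/query",
    params={"query": "What is Llama Stack?"},
)
print(response.json())  # should deserialize into the QueryResponse shape
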
26 changes: 26 additions & 0 deletions src/client.py
@@ -0,0 +1,26 @@
+"""LLama stack client retrieval."""
+
+import logging
+
+from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+from llama_stack_client import LlamaStackClient
+from models.config import LLamaStackConfiguration
+
+logger = logging.getLogger(__name__)
+
+
+def get_llama_stack_client(
+    llama_stack_config: LLamaStackConfiguration,
+) -> LlamaStackClient:
+    if llama_stack_config.use_as_library_client is True:
+        logger.info("Using Llama stack as library client")
+        client = LlamaStackAsLibraryClient(
+            llama_stack_config.library_client_config_path
+        )
+        client.initialize()
+        return client
+    else:
+        logger.info("Using Llama stack running as a service")
+        return LlamaStackClient(
+            base_url=llama_stack_config.url, api_key=llama_stack_config.api_key
+        )
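
A short usage sketch of the new factory, not part of the PR; the configuration values simply mirror the remote-client unit test below.

# Sketch: build a configuration for a remotely running Llama Stack and list models.
from client import get_llama_stack_client
from models.config import LLamaStackConfiguration

cfg = LLamaStackConfiguration(
    url="http://localhost:8321",
    api_key=None,
    use_as_library_client=False,
    library_client_config_path="ollama",
)
client = get_llama_stack_client(cfg)
for model in client.models.list():
    print(dict(model))
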
31 changes: 31 additions & 0 deletions tests/unit/test_client.py
@@ -0,0 +1,31 @@
+"""Unit tests for functions defined in src/client.py."""
+
+import os
+from unittest.mock import patch
+
+from client import get_llama_stack_client
+from models.config import LLamaStackConfiguration
+
+
+@patch.dict(os.environ, {"INFERENCE_MODEL": "llama3.2:3b-instruct-fp16"})
+def test_get_llama_stack_library_client():
+    cfg = LLamaStackConfiguration(
+        url=None,
+        api_key=None,
+        use_as_library_client=True,
+        library_client_config_path="ollama",
+    )
+    client = get_llama_stack_client(cfg)
+    assert client is not None
+
+
+@patch.dict(os.environ, {"INFERENCE_MODEL": "llama3.2:3b-instruct-fp16"})
+def test_get_llama_stack_remote_client():
+    cfg = LLamaStackConfiguration(
+        url="http://localhost:8321",
+        api_key=None,
+        use_as_library_client=False,
+        library_client_config_path="ollama",
+    )
+    client = get_llama_stack_client(cfg)
+    assert client is not None
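
A possible follow-up, not part of the PR: the same fixtures could also assert which client type the factory returns in each mode. This assumes the imports and environment patching from the test file above, and that the "ollama" library configuration initializes in the test environment.

# Hypothetical additional test building on the file above.
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
from llama_stack_client import LlamaStackClient


@patch.dict(os.environ, {"INFERENCE_MODEL": "llama3.2:3b-instruct-fp16"})
def test_get_llama_stack_client_types():
    remote_cfg = LLamaStackConfiguration(
        url="http://localhost:8321",
        api_key=None,
        use_as_library_client=False,
        library_client_config_path="ollama",
    )
    assert isinstance(get_llama_stack_client(remote_cfg), LlamaStackClient)

    library_cfg = LLamaStackConfiguration(
        url=None,
        api_key=None,
        use_as_library_client=True,
        library_client_config_path="ollama",
    )
    assert isinstance(get_llama_stack_client(library_cfg), LlamaStackAsLibraryClient)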