diff --git a/lightspeed-stack.yaml b/lightspeed-stack.yaml
index 96425376..073446e8 100644
--- a/lightspeed-stack.yaml
+++ b/lightspeed-stack.yaml
@@ -3,3 +3,4 @@ llama_stack:
   use_as_library_client: false
   url: http://localhost:8321
   api_key: xyzzy
+  chat_completion_mode: true
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index fc1f5264..758992ca 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -4,6 +4,7 @@
 from typing import Any
 
 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client import LlamaStackClient
 
 from fastapi import APIRouter, Request
@@ -12,7 +13,7 @@
 from models.config import LLamaStackConfiguration
 from models.responses import QueryResponse
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("app.endpoints.handlers")
 router = APIRouter(tags=["models"])
 
 
@@ -28,10 +29,8 @@ def info_endpoint_handler(request: Request, query: str) -> QueryResponse:
 
     llama_stack_config = configuration.llama_stack_configuration
     logger.info("LLama stack config: %s", llama_stack_config)
-    client = getLLamaStackClient(llama_stack_config)
-    client = LlamaStackClient(
-        base_url=llama_stack_config.url, api_key=llama_stack_config.api_key
-    )
+
+    client = get_llama_stack_client(llama_stack_config)
 
     # retrieve list of available models
     models = client.models.list()
@@ -42,24 +41,55 @@ def info_endpoint_handler(request: Request, query: str) -> QueryResponse:
 
     logger.info("Model: %s", model_id)
 
-    response = client.inference.chat_completion(
-        model_id=model_id,
-        messages=[
-            {"role": "system", "content": "You are a helpful assistant."},
-            {"role": "user", "content": query},
-        ],
+    response = retrieve_response(
+        client, model_id, llama_stack_config.chat_completion_mode, query
     )
-    return QueryResponse(query=query, response=str(response.completion_message.content))
+
+    return QueryResponse(query=query, response=response)
+
+
+def retrieve_response(
+    client: LlamaStackClient, model_id: str, chat_completion_mode: bool, query: str
+) -> str:
+    if chat_completion_mode:
+        logger.info("Chat completion mode enabled")
+        response = client.inference.chat_completion(
+            model_id=model_id,
+            messages=[
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": query},
+            ],
+        )
+        return str(response.completion_message.content)
+    else:
+        logger.info("Chat completion mode disabled")
+        agent = Agent(
+            client,
+            model=model_id,
+            instructions="You are a helpful assistant",
+            tools=[],
+        )
+
+        prompt = "How do you do great work?"
+
+        response = agent.create_turn(
+            messages=[{"role": "user", "content": prompt}],
+            session_id=agent.create_session("rag_session"),
+            stream=False,
+        )
+        return str(response.output_message.content)
 
 
-def getLLamaStackClient(
+def get_llama_stack_client(
     llama_stack_config: LLamaStackConfiguration,
 ) -> LlamaStackClient:
     if llama_stack_config.use_as_library_client is True:
+        logger.info("Using Llama stack as library client")
         client = LlamaStackAsLibraryClient("ollama")
         client.initialize()
         return client
     else:
+        logger.info("Using Llama stack running as a service")
         return LlamaStackClient(
             base_url=llama_stack_config.url, api_key=llama_stack_config.api_key
         )
diff --git a/src/app/endpoints/root.py b/src/app/endpoints/root.py
index 8c1b29b9..5e1c8cee 100644
--- a/src/app/endpoints/root.py
+++ b/src/app/endpoints/root.py
@@ -5,9 +5,11 @@
 from fastapi import APIRouter, Request
 from fastapi.responses import HTMLResponse
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("app.endpoints.handlers")
+
 router = APIRouter(tags=["root"])
 
+
 index_page = """
@@ -25,4 +27,5 @@
 
 @router.get("/", response_class=HTMLResponse)
 def root_endpoint_handler(request: Request) -> HTMLResponse:
+    logger.info("Serving index page")
     return HTMLResponse(index_page)
diff --git a/src/models/config.py b/src/models/config.py
index 6ffb8f51..d3322ceb 100644
--- a/src/models/config.py
+++ b/src/models/config.py
@@ -10,6 +10,7 @@ class LLamaStackConfiguration(BaseModel):
     url: Optional[str] = None
     api_key: Optional[str] = None
     use_as_library_client: Optional[bool] = None
+    chat_completion_mode: bool = False
 
     @model_validator(mode="after")
     def check_llama_stack_model(self) -> Self:
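A minimal standalone sketch, not part of the patch, of the two code paths that the new chat_completion_mode flag selects between in retrieve_response(): a single stateless inference.chat_completion call versus an Agent session with one turn. It reuses only the client calls that already appear in the diff; the base URL, API key, model ID, session name, and query are illustrative placeholders, and unlike the patch's agent branch, which still sends a fixed prompt, both paths below pass the user's query.

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent

# Placeholders: point these at a real Llama Stack deployment and model.
client = LlamaStackClient(base_url="http://localhost:8321", api_key="xyzzy")
model_id = "example-model"  # the endpoint picks one from client.models.list()
query = "How do I configure the service?"

# chat_completion_mode: true -- one stateless chat completion request
completion = client.inference.chat_completion(
    model_id=model_id,
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": query},
    ],
)
print(completion.completion_message.content)

# chat_completion_mode: false -- an agent with a session and a single turn
agent = Agent(
    client,
    model=model_id,
    instructions="You are a helpful assistant",
    tools=[],
)
turn = agent.create_turn(
    messages=[{"role": "user", "content": query}],
    session_id=agent.create_session("demo_session"),
    stream=False,
)
print(turn.output_message.content)

Either branch reduces to a plain string in retrieve_response(), which is what the reworked QueryResponse(query=query, response=response) expects.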