diff --git a/lightspeed-stack.yaml b/lightspeed-stack.yaml
index 073446e8..ddbb8d1c 100644
--- a/lightspeed-stack.yaml
+++ b/lightspeed-stack.yaml
@@ -1,6 +1,10 @@
 name: foo bar baz
 llama_stack:
+  # Uses a remote llama-stack service
+  # The service must already be running, started from a llama-stack-run.yaml file
   use_as_library_client: false
+  # Alternative: use llama-stack as an embedded library
+  # use_as_library_client: true
+  # library_client_config_path:
   url: http://localhost:8321
   api_key: xyzzy
-  chat_completion_mode: true
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index 1b8819a1..65689162 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -6,6 +6,7 @@
 from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client import LlamaStackClient
+from llama_stack_client.types import UserMessage
 
 from fastapi import APIRouter, Request
 
@@ -41,43 +42,34 @@ def info_endpoint_handler(request: Request, query: str) -> QueryResponse:
 
     logger.info("Model: %s", model_id)
 
-    response = retrieve_response(
-        client, model_id, llama_stack_config.chat_completion_mode, query
-    )
+    response = retrieve_response(client, model_id, query)
 
     return QueryResponse(query=query, response=response)
 
 
-def retrieve_response(
-    client: LlamaStackClient, model_id: str, chat_completion_mode: bool, query: str
-) -> str:
-    if chat_completion_mode:
-        logger.info("Chat completion mode enabled")
-        response = client.inference.chat_completion(
-            model_id=model_id,
-            messages=[
-                {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": query},
-            ],
-        )
-        return str(response.completion_message.content)
-    else:
-        logger.info("Chat completion mode disabled")
-        agent = Agent(
-            client,
-            model=model_id,
-            instructions="You are a helpful assistant",
-            tools=[],
-        )
+def retrieve_response(client: LlamaStackClient, model_id: str, prompt: str) -> str:
 
-        prompt = "How do you do great work?"
+    available_shields = [shield.identifier for shield in client.shields.list()]
+    if not available_shields:
+        logger.warning("No available shields. Disabling safety.")
+    else:
+        logger.info("Available shields found: %s", available_shields)
+
+    agent = Agent(
+        client,
+        model=model_id,
+        instructions="You are a helpful assistant",
+        input_shields=available_shields if available_shields else [],
+        tools=[],
+    )
+    session_id = agent.create_session("chat_session")
+    response = agent.create_turn(
+        messages=[UserMessage(role="user", content=prompt)],
+        session_id=session_id,
+        stream=False,
+    )
 
-        response = agent.create_turn(
-            messages=[{"role": "user", "content": prompt}],
-            session_id=agent.create_session("rag_session"),
-            stream=False,
-        )
-        return str(response.output_message.content)
+    return str(response.output_message.content)
 
 
 def get_llama_stack_client(
diff --git a/src/models/config.py b/src/models/config.py
index 5601465f..e82ebd81 100644
--- a/src/models/config.py
+++ b/src/models/config.py
@@ -11,7 +11,6 @@ class LLamaStackConfiguration(BaseModel):
     api_key: Optional[str] = None
     use_as_library_client: Optional[bool] = None
     library_client_config_path: Optional[str] = None
-    chat_completion_mode: bool = False
 
     @model_validator(mode="after")
     def check_llama_stack_model(self) -> Self:
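
For reference, a minimal standalone sketch of the agent-based flow this diff introduces in retrieve_response(). It assumes a llama-stack server reachable at http://localhost:8321 (matching the url in lightspeed-stack.yaml) with at least one model registered; the model selection and the __main__ driver below are illustrative only and are not part of the change.

import logging

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types import UserMessage

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def retrieve_response(client: LlamaStackClient, model_id: str, prompt: str) -> str:
    # Discover registered shields; if none exist, the agent runs without input shields.
    available_shields = [shield.identifier for shield in client.shields.list()]
    if not available_shields:
        logger.warning("No available shields. Disabling safety.")
    else:
        logger.info("Available shields found: %s", available_shields)

    agent = Agent(
        client,
        model=model_id,
        instructions="You are a helpful assistant",
        input_shields=available_shields,
        tools=[],
    )
    # A fresh session per call, so turns are not chained into one conversation.
    session_id = agent.create_session("chat_session")
    response = agent.create_turn(
        messages=[UserMessage(role="user", content=prompt)],
        session_id=session_id,
        stream=False,
    )
    return str(response.output_message.content)


if __name__ == "__main__":
    client = LlamaStackClient(base_url="http://localhost:8321")
    # Assumption: pick the first registered model; the real endpoint resolves it from configuration.
    model_id = client.models.list()[0].identifier
    print(retrieve_response(client, model_id, "Hello, what can you do?"))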