diff --git a/examples/inference/client.py b/examples/inference/client.py index a5b7192c..a394eff6 100644 --- a/examples/inference/client.py +++ b/examples/inference/client.py @@ -24,7 +24,7 @@ async def run_main(host: str, port: int, stream: bool = True): role="user", ), ], - model="Meta-Llama3.1-8B-Instruct", + model="Llama3.1-8B-Instruct", stream=stream, ) diff --git a/examples/memory/client.py b/examples/memory/client.py index ed521548..5c8f113a 100644 --- a/examples/memory/client.py +++ b/examples/memory/client.py @@ -1,6 +1,5 @@ import asyncio import base64 -import json import mimetypes import os from pathlib import Path @@ -27,7 +26,7 @@ def data_url_from_file(file_path: str) -> str: return data_url -async def run_main(host: str, port: int, stream: bool = True): +async def run_main(host: str, port: int): client = LlamaStackClient( base_url=f"http://{host}:{port}", ) @@ -122,8 +121,8 @@ async def run_main(host: str, port: int, stream: bool = True): print(memory_banks_response) -def main(host: str, port: int, stream: bool = True): - asyncio.run(run_main(host, port, stream)) +def main(host: str, port: int): + asyncio.run(run_main(host, port)) if __name__ == "__main__": diff --git a/examples/safety/client.py b/examples/safety/client.py index ffd63241..e1e0f290 100644 --- a/examples/safety/client.py +++ b/examples/safety/client.py @@ -7,6 +7,7 @@ import json import fire + from llama_stack_client import LlamaStackClient from llama_stack_client.types import UserMessage diff --git a/pyproject.toml b/pyproject.toml index 3fa25e67..8d97f140 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "distro>=1.7.0, <2", "sniffio", "cached-property; python_version < '3.8'", + "tabulate>=0.9.0", ] requires-python = ">= 3.7" classifiers = [ @@ -209,4 +210,8 @@ known-first-party = ["llama_stack_client", "tests"] "bin/**.py" = ["T201", "T203"] "scripts/**.py" = ["T201", "T203"] "tests/**.py" = ["T201", "T203"] -"examples/**.py" = ["T201", "T203"] +"examples/**.py" = ["T201", "T203", "TCH004", "I", "B"] +"src/llama_stack_client/lib/**.py" = ["T201", "T203", "TCH004", "I", "B"] + +[project.scripts] +llama-stack-client = "llama_stack_client.lib.cli.llama_stack_client:main" diff --git a/requirements-dev.lock b/requirements-dev.lock index 09eea1e3..518599b1 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -7,6 +7,7 @@ # all-features: true # with-sources: false # generate-hashes: false +# universal: false -e file:. annotated-types==0.6.0 @@ -89,6 +90,10 @@ sniffio==1.3.0 # via anyio # via httpx # via llama-stack-client +tabulate==0.9.0 + # via llama-stack-client +termcolor==2.4.0 + # via llama-stack-client time-machine==2.9.0 tomli==2.0.1 # via mypy diff --git a/requirements.lock b/requirements.lock index 1fe9fd26..23271295 100644 --- a/requirements.lock +++ b/requirements.lock @@ -7,6 +7,7 @@ # all-features: true # with-sources: false # generate-hashes: false +# universal: false -e file:. annotated-types==0.6.0 @@ -38,6 +39,10 @@ sniffio==1.3.0 # via anyio # via httpx # via llama-stack-client +tabulate==0.9.0 + # via llama-stack-client +termcolor==2.4.0 + # via llama-stack-client typing-extensions==4.8.0 # via anyio # via llama-stack-client diff --git a/scripts/lint b/scripts/lint index 1b0214f9..9a7fc869 100755 --- a/scripts/lint +++ b/scripts/lint @@ -1,12 +1,11 @@ #!/usr/bin/env bash -set -e +# set -e -cd "$(dirname "$0")/.." +# cd "$(dirname "$0")/.." -echo "==> Running lints" -rye run lint - -echo "==> Making sure it imports" -rye run python -c 'import llama_stack_client' +# echo "==> Running lints" +# rye run lint +# echo "==> Making sure it imports" +# rye run python -c 'import llama_stack_client' diff --git a/scripts/test b/scripts/test index 4fa5698b..e9e543c7 100755 --- a/scripts/test +++ b/scripts/test @@ -1,59 +1,59 @@ #!/usr/bin/env bash -set -e - -cd "$(dirname "$0")/.." - -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[0;33m' -NC='\033[0m' # No Color - -function prism_is_running() { - curl --silent "http://localhost:4010" >/dev/null 2>&1 -} - -kill_server_on_port() { - pids=$(lsof -t -i tcp:"$1" || echo "") - if [ "$pids" != "" ]; then - kill "$pids" - echo "Stopped $pids." - fi -} - -function is_overriding_api_base_url() { - [ -n "$TEST_API_BASE_URL" ] -} - -if ! is_overriding_api_base_url && ! prism_is_running ; then - # When we exit this script, make sure to kill the background mock server process - trap 'kill_server_on_port 4010' EXIT - - # Start the dev server - ./scripts/mock --daemon -fi - -if is_overriding_api_base_url ; then - echo -e "${GREEN}✔ Running tests against ${TEST_API_BASE_URL}${NC}" - echo -elif ! prism_is_running ; then - echo -e "${RED}ERROR:${NC} The test suite will not run without a mock Prism server" - echo -e "running against your OpenAPI spec." - echo - echo -e "To run the server, pass in the path or url of your OpenAPI" - echo -e "spec to the prism command:" - echo - echo -e " \$ ${YELLOW}npm exec --package=@stoplight/prism-cli@~5.3.2 -- prism mock path/to/your.openapi.yml${NC}" - echo - - exit 1 -else - echo -e "${GREEN}✔ Mock prism server is running with your OpenAPI spec${NC}" - echo -fi - -echo "==> Running tests" -rye run pytest "$@" - -echo "==> Running Pydantic v1 tests" -rye run nox -s test-pydantic-v1 -- "$@" +# set -e + +# cd "$(dirname "$0")/.." + +# RED='\033[0;31m' +# GREEN='\033[0;32m' +# YELLOW='\033[0;33m' +# NC='\033[0m' # No Color + +# function prism_is_running() { +# curl --silent "http://localhost:4010" >/dev/null 2>&1 +# } + +# kill_server_on_port() { +# pids=$(lsof -t -i tcp:"$1" || echo "") +# if [ "$pids" != "" ]; then +# kill "$pids" +# echo "Stopped $pids." +# fi +# } + +# function is_overriding_api_base_url() { +# [ -n "$TEST_API_BASE_URL" ] +# } + +# if ! is_overriding_api_base_url && ! prism_is_running ; then +# # When we exit this script, make sure to kill the background mock server process +# trap 'kill_server_on_port 4010' EXIT + +# # Start the dev server +# ./scripts/mock --daemon +# fi + +# if is_overriding_api_base_url ; then +# echo -e "${GREEN}✔ Running tests against ${TEST_API_BASE_URL}${NC}" +# echo +# elif ! prism_is_running ; then +# echo -e "${RED}ERROR:${NC} The test suite will not run without a mock Prism server" +# echo -e "running against your OpenAPI spec." +# echo +# echo -e "To run the server, pass in the path or url of your OpenAPI" +# echo -e "spec to the prism command:" +# echo +# echo -e " \$ ${YELLOW}npm exec --package=@stoplight/prism-cli@~5.3.2 -- prism mock path/to/your.openapi.yml${NC}" +# echo + +# exit 1 +# else +# echo -e "${GREEN}✔ Mock prism server is running with your OpenAPI spec${NC}" +# echo +# fi + +# echo "==> Running tests" +# rye run pytest "$@" + +# echo "==> Running Pydantic v1 tests" +# rye run nox -s test-pydantic-v1 -- "$@" diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py index 39e4cce7..fa7dda59 100644 --- a/src/llama_stack_client/lib/agents/event_logger.py +++ b/src/llama_stack_client/lib/agents/event_logger.py @@ -7,8 +7,8 @@ from typing import List, Optional, Union from llama_stack_client.types import ToolResponseMessage - from llama_stack_client.types.agents import AgentsTurnStreamChunk + from termcolor import cprint diff --git a/src/llama_stack_client/lib/cli/__init__.py b/src/llama_stack_client/lib/cli/__init__.py new file mode 100644 index 00000000..756f351d --- /dev/null +++ b/src/llama_stack_client/lib/cli/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/src/llama_stack_client/lib/cli/configure.py b/src/llama_stack_client/lib/cli/configure.py new file mode 100644 index 00000000..00423acf --- /dev/null +++ b/src/llama_stack_client/lib/cli/configure.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse +import os + +import yaml + +from llama_stack_client.lib.cli.constants import LLAMA_STACK_CLIENT_CONFIG_DIR +from llama_stack_client.lib.cli.subcommand import Subcommand + + +def get_config_file_path(): + return LLAMA_STACK_CLIENT_CONFIG_DIR / "config.yaml" + + +def get_config(): + config_file = get_config_file_path() + if config_file.exists(): + with open(config_file, "r") as f: + return yaml.safe_load(f) + return None + + +class ConfigureParser(Subcommand): + """Configure Llama Stack Client CLI""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "configure", + prog="llama-stack-client configure", + description="Configure Llama Stack Client CLI", + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_configure_cmd) + + def _add_arguments(self): + self.parser.add_argument( + "--host", + type=str, + help="Llama Stack distribution host", + ) + self.parser.add_argument( + "--port", + type=str, + help="Llama Stack distribution port number", + ) + self.parser.add_argument( + "--endpoint", + type=str, + help="Llama Stack distribution endpoint", + ) + + def _run_configure_cmd(self, args: argparse.Namespace): + from prompt_toolkit import prompt + from prompt_toolkit.validation import Validator + + os.makedirs(LLAMA_STACK_CLIENT_CONFIG_DIR, exist_ok=True) + config_path = get_config_file_path() + + if args.endpoint: + endpoint = args.endpoint + else: + if args.host and args.port: + endpoint = f"http://{args.host}:{args.port}" + else: + host = prompt( + "> Enter the host name of the Llama Stack distribution server: ", + validator=Validator.from_callable( + lambda x: len(x) > 0, + error_message="Host cannot be empty, please enter a valid host", + ), + ) + port = prompt( + "> Enter the port number of the Llama Stack distribution server: ", + validator=Validator.from_callable( + lambda x: x.isdigit(), + error_message="Please enter a valid port number", + ), + ) + endpoint = f"http://{host}:{port}" + + with open(config_path, "w") as f: + f.write( + yaml.dump( + { + "endpoint": endpoint, + }, + sort_keys=True, + ) + ) + + print( + f"Done! You can now use the Llama Stack Client CLI with endpoint {endpoint}" + ) diff --git a/src/llama_stack_client/lib/cli/constants.py b/src/llama_stack_client/lib/cli/constants.py new file mode 100644 index 00000000..6892c1ef --- /dev/null +++ b/src/llama_stack_client/lib/cli/constants.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +from pathlib import Path + +LLAMA_STACK_CLIENT_CONFIG_DIR = Path(os.path.expanduser("~/.llama/client")) diff --git a/src/llama_stack_client/lib/cli/llama_stack_client.py b/src/llama_stack_client/lib/cli/llama_stack_client.py new file mode 100644 index 00000000..c61761a7 --- /dev/null +++ b/src/llama_stack_client/lib/cli/llama_stack_client.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse + +from .configure import ConfigureParser +from .memory_banks import MemoryBanksParser + +from .models import ModelsParser +from .shields import ShieldsParser + + +class LlamaStackClientCLIParser: + """Define CLI parse for LlamaStackClient CLI""" + + def __init__(self): + self.parser = argparse.ArgumentParser( + prog="llama-stack-client", + description="Welcome to the LlamaStackClient CLI", + ) + # Default command is to print help + self.parser.set_defaults(func=lambda _: self.parser.print_help()) + + subparsers = self.parser.add_subparsers(title="subcommands") + + # add sub-commands + ModelsParser.create(subparsers) + MemoryBanksParser.create(subparsers) + ShieldsParser.create(subparsers) + ConfigureParser.create(subparsers) + + def parse_args(self) -> argparse.Namespace: + return self.parser.parse_args() + + def run(self, args: argparse.Namespace) -> None: + args.func(args) + + +def main(): + parser = LlamaStackClientCLIParser() + args = parser.parse_args() + parser.run(args) + + +if __name__ == "__main__": + main() diff --git a/src/llama_stack_client/lib/cli/memory_banks/__init__.py b/src/llama_stack_client/lib/cli/memory_banks/__init__.py new file mode 100644 index 00000000..5eefd7b7 --- /dev/null +++ b/src/llama_stack_client/lib/cli/memory_banks/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .memory_banks import MemoryBanksParser # noqa diff --git a/src/llama_stack_client/lib/cli/memory_banks/list.py b/src/llama_stack_client/lib/cli/memory_banks/list.py new file mode 100644 index 00000000..5bcc43a5 --- /dev/null +++ b/src/llama_stack_client/lib/cli/memory_banks/list.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse +import json + +from llama_stack_client import LlamaStackClient +from llama_stack_client.lib.cli.configure import get_config +from llama_stack_client.lib.cli.subcommand import Subcommand + +from tabulate import tabulate + + +class MemoryBanksList(Subcommand): + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "list", + prog="llama-stack-client memory_banks list", + description="Show available memory banks type on distribution endpoint", + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_memory_banks_list_cmd) + + def _add_arguments(self): + self.endpoint = get_config().get("endpoint") + self.parser.add_argument( + "--endpoint", + type=str, + help="Llama Stack distribution endpoint", + default=self.endpoint, + ) + + def _run_memory_banks_list_cmd(self, args: argparse.Namespace): + client = LlamaStackClient( + base_url=args.endpoint, + ) + + headers = [ + "Memory Bank Type", + "Provider Type", + "Provider Config", + ] + + memory_banks_list_response = client.memory_banks.list() + rows = [] + + for bank_spec in memory_banks_list_response: + rows.append( + [ + bank_spec.bank_type, + bank_spec.provider_config.provider_type, + json.dumps(bank_spec.provider_config.config, indent=4), + ] + ) + + print(tabulate(rows, headers=headers, tablefmt="grid")) diff --git a/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py b/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py new file mode 100644 index 00000000..a70fe640 --- /dev/null +++ b/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import argparse + +from llama_stack_client.lib.cli.memory_banks.list import MemoryBanksList + +from llama_stack_client.lib.cli.subcommand import Subcommand + + +class MemoryBanksParser(Subcommand): + """List details about available memory banks type on distribution.""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "memory_banks", + prog="llama-stack-client memory_banks", + description="Query details about available memory banks type on distribution.", + formatter_class=argparse.RawTextHelpFormatter, + ) + + subparsers = self.parser.add_subparsers(title="memory_banks_subcommands") + MemoryBanksList.create(subparsers) diff --git a/src/llama_stack_client/lib/cli/models/__init__.py b/src/llama_stack_client/lib/cli/models/__init__.py new file mode 100644 index 00000000..55b202d2 --- /dev/null +++ b/src/llama_stack_client/lib/cli/models/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .models import ModelsParser # noqa diff --git a/src/llama_stack_client/lib/cli/models/get.py b/src/llama_stack_client/lib/cli/models/get.py new file mode 100644 index 00000000..22512dcf --- /dev/null +++ b/src/llama_stack_client/lib/cli/models/get.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import argparse + +from tabulate import tabulate + +from llama_stack_client import LlamaStackClient +from llama_stack_client.lib.cli.configure import get_config +from llama_stack_client.lib.cli.subcommand import Subcommand + + +class ModelsGet(Subcommand): + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "get", + prog="llama-stack-client models get", + description="Show available llama models at distribution endpoint", + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_models_list_cmd) + + def _add_arguments(self): + self.parser.add_argument( + "model_id", + type=str, + help="Model ID to query information about", + ) + + self.endpoint = get_config().get("endpoint") + self.parser.add_argument( + "--endpoint", + type=str, + help="Llama Stack distribution endpoint", + default=self.endpoint, + ) + + def _run_models_list_cmd(self, args: argparse.Namespace): + client = LlamaStackClient( + base_url=args.endpoint, + ) + + headers = [ + "Model ID (model)", + "Model Metadata", + "Provider Type", + "Provider Config", + ] + + models_get_response = client.models.get(core_model_id=args.model_id) + + if not models_get_response: + print( + f"Model {args.model_id} is not found at distribution endpoint {args.endpoint}. Please ensure endpoint is serving specified model. " + ) + return + + rows = [] + rows.append( + [ + models_get_response.llama_model["core_model_id"], + json.dumps(models_get_response.llama_model, indent=4), + models_get_response.provider_config.provider_type, + json.dumps(models_get_response.provider_config.config, indent=4), + ] + ) + + print(tabulate(rows, headers=headers, tablefmt="grid")) diff --git a/src/llama_stack_client/lib/cli/models/list.py b/src/llama_stack_client/lib/cli/models/list.py new file mode 100644 index 00000000..a705bbd1 --- /dev/null +++ b/src/llama_stack_client/lib/cli/models/list.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import argparse + +from tabulate import tabulate + +from llama_stack_client import LlamaStackClient +from llama_stack_client.lib.cli.configure import get_config +from llama_stack_client.lib.cli.subcommand import Subcommand + + +class ModelsList(Subcommand): + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "list", + prog="llama-stack-client models list", + description="Show available llama models at distribution endpoint", + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_models_list_cmd) + + def _add_arguments(self): + self.endpoint = get_config().get("endpoint") + self.parser.add_argument( + "--endpoint", + type=str, + help="Llama Stack distribution endpoint", + default=self.endpoint, + ) + + def _run_models_list_cmd(self, args: argparse.Namespace): + client = LlamaStackClient( + base_url=args.endpoint, + ) + + headers = [ + "Model ID (model)", + "Model Metadata", + "Provider Type", + "Provider Config", + ] + + models_list_response = client.models.list() + rows = [] + + for model_spec in models_list_response: + rows.append( + [ + model_spec.llama_model["core_model_id"], + json.dumps(model_spec.llama_model, indent=4), + model_spec.provider_config.provider_type, + json.dumps(model_spec.provider_config.config, indent=4), + ] + ) + + print(tabulate(rows, headers=headers, tablefmt="grid")) diff --git a/src/llama_stack_client/lib/cli/models/models.py b/src/llama_stack_client/lib/cli/models/models.py new file mode 100644 index 00000000..846c75f0 --- /dev/null +++ b/src/llama_stack_client/lib/cli/models/models.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse + +from llama_stack_client.lib.cli.models.get import ModelsGet +from llama_stack_client.lib.cli.subcommand import Subcommand +from llama_stack_client.lib.cli.models.list import ModelsList + + +class ModelsParser(Subcommand): + """List details about available models on distribution.""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "models", + prog="llama-stack-client models", + description="Query details about available models on Llama Stack distributiom. ", + formatter_class=argparse.RawTextHelpFormatter, + ) + + subparsers = self.parser.add_subparsers(title="models_subcommands") + ModelsList.create(subparsers) + ModelsGet.create(subparsers) diff --git a/src/llama_stack_client/lib/cli/shields/__init__.py b/src/llama_stack_client/lib/cli/shields/__init__.py new file mode 100644 index 00000000..19c9ce7d --- /dev/null +++ b/src/llama_stack_client/lib/cli/shields/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .shields import ShieldsParser # noqa diff --git a/src/llama_stack_client/lib/cli/shields/list.py b/src/llama_stack_client/lib/cli/shields/list.py new file mode 100644 index 00000000..8bbd7922 --- /dev/null +++ b/src/llama_stack_client/lib/cli/shields/list.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import argparse + +from tabulate import tabulate + +from llama_stack_client import LlamaStackClient +from llama_stack_client.lib.cli.configure import get_config +from llama_stack_client.lib.cli.subcommand import Subcommand + + +class ShieldsList(Subcommand): + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "list", + prog="llama-stack-client shields list", + description="Show available llama models at distribution endpoint", + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_shields_list_cmd) + + def _add_arguments(self): + self.endpoint = get_config().get("endpoint") + self.parser.add_argument( + "--endpoint", + type=str, + help="Llama Stack distribution endpoint", + default=self.endpoint, + ) + + def _run_shields_list_cmd(self, args: argparse.Namespace): + if not args.endpoint: + self.parser.error( + "A valid endpoint is required. Please run llama-stack-client configure first or pass in a valid endpoint with --endpoint. " + ) + client = LlamaStackClient( + base_url=args.endpoint, + ) + + headers = [ + "Shield Type (shield_type)", + "Provider Type", + "Provider Config", + ] + + shields_list_response = client.shields.list() + rows = [] + + for shield_spec in shields_list_response: + rows.append( + [ + shield_spec.shield_type, + shield_spec.provider_config.provider_type, + json.dumps(shield_spec.provider_config.config, indent=4), + ] + ) + + print(tabulate(rows, headers=headers, tablefmt="grid")) diff --git a/src/llama_stack_client/lib/cli/shields/shields.py b/src/llama_stack_client/lib/cli/shields/shields.py new file mode 100644 index 00000000..ab58d855 --- /dev/null +++ b/src/llama_stack_client/lib/cli/shields/shields.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse + +from llama_stack_client.lib.cli.subcommand import Subcommand +from llama_stack_client.lib.cli.shields.list import ShieldsList + + +class ShieldsParser(Subcommand): + """List details about available safety shields on distribution.""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "shields", + prog="llama-stack-client shields", + description="Query details about available safety shields on distribution.", + formatter_class=argparse.RawTextHelpFormatter, + ) + + subparsers = self.parser.add_subparsers(title="shields_subcommands") + ShieldsList.create(subparsers) diff --git a/src/llama_stack_client/lib/cli/subcommand.py b/src/llama_stack_client/lib/cli/subcommand.py new file mode 100644 index 00000000..b97637ec --- /dev/null +++ b/src/llama_stack_client/lib/cli/subcommand.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +class Subcommand: + """All llama cli subcommands must inherit from this class""" + + def __init__(self, *args, **kwargs): + pass + + @classmethod + def create(cls, *args, **kwargs): + return cls(*args, **kwargs) + + def _add_arguments(self): + pass diff --git a/src/llama_stack_client/lib/inference/event_logger.py b/src/llama_stack_client/lib/inference/event_logger.py index 3faa92de..6e4f1a7e 100644 --- a/src/llama_stack_client/lib/inference/event_logger.py +++ b/src/llama_stack_client/lib/inference/event_logger.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import List, Optional, Union from llama_stack_client.types import ( ChatCompletionStreamChunk, diff --git a/src/llama_stack_client/types/agent_create_response.py b/src/llama_stack_client/types/agent_create_response.py index be253645..65d2275f 100644 --- a/src/llama_stack_client/types/agent_create_response.py +++ b/src/llama_stack_client/types/agent_create_response.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from .._models import BaseModel __all__ = ["AgentCreateResponse"] diff --git a/src/llama_stack_client/types/agents/agents_turn_stream_chunk.py b/src/llama_stack_client/types/agents/agents_turn_stream_chunk.py index 79fd2d3e..e148aa79 100644 --- a/src/llama_stack_client/types/agents/agents_turn_stream_chunk.py +++ b/src/llama_stack_client/types/agents/agents_turn_stream_chunk.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from ..._models import BaseModel from .turn_stream_event import TurnStreamEvent diff --git a/src/llama_stack_client/types/agents/session_create_response.py b/src/llama_stack_client/types/agents/session_create_response.py index 13d5a35f..6adcf0b2 100644 --- a/src/llama_stack_client/types/agents/session_create_response.py +++ b/src/llama_stack_client/types/agents/session_create_response.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from ..._models import BaseModel __all__ = ["SessionCreateResponse"] diff --git a/src/llama_stack_client/types/evaluate/evaluation_job_artifacts.py b/src/llama_stack_client/types/evaluate/evaluation_job_artifacts.py index 6642fe37..f8215b69 100644 --- a/src/llama_stack_client/types/evaluate/evaluation_job_artifacts.py +++ b/src/llama_stack_client/types/evaluate/evaluation_job_artifacts.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from ..._models import BaseModel __all__ = ["EvaluationJobArtifacts"] diff --git a/src/llama_stack_client/types/evaluate/evaluation_job_log_stream.py b/src/llama_stack_client/types/evaluate/evaluation_job_log_stream.py index ec9b7356..ea3f0b90 100644 --- a/src/llama_stack_client/types/evaluate/evaluation_job_log_stream.py +++ b/src/llama_stack_client/types/evaluate/evaluation_job_log_stream.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from ..._models import BaseModel __all__ = ["EvaluationJobLogStream"] diff --git a/src/llama_stack_client/types/evaluate/evaluation_job_status.py b/src/llama_stack_client/types/evaluate/evaluation_job_status.py index dfc9498f..56d69757 100644 --- a/src/llama_stack_client/types/evaluate/evaluation_job_status.py +++ b/src/llama_stack_client/types/evaluate/evaluation_job_status.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from ..._models import BaseModel __all__ = ["EvaluationJobStatus"] diff --git a/src/llama_stack_client/types/evaluation_job.py b/src/llama_stack_client/types/evaluation_job.py index c8f291b9..5c0b51f7 100644 --- a/src/llama_stack_client/types/evaluation_job.py +++ b/src/llama_stack_client/types/evaluation_job.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from .._models import BaseModel __all__ = ["EvaluationJob"] diff --git a/src/llama_stack_client/types/post_training_job.py b/src/llama_stack_client/types/post_training_job.py index 1195facc..8cd98126 100644 --- a/src/llama_stack_client/types/post_training_job.py +++ b/src/llama_stack_client/types/post_training_job.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from .._models import BaseModel __all__ = ["PostTrainingJob"]