From 2e057e25b8b16ba564646ad3ca28834b1bc417c8 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 13 Nov 2024 10:03:56 -0800 Subject: [PATCH 1/4] migrate to click --- .../lib/cli/datasets/__init__.py | 4 +- .../lib/cli/datasets/datasets.py | 24 ++---- .../lib/cli/datasets/list.py | 43 +++-------- .../lib/cli/eval_tasks/__init__.py | 2 +- .../lib/cli/eval_tasks/eval_tasks.py | 24 ++---- .../lib/cli/eval_tasks/list.py | 40 +++------- .../lib/cli/llama_stack_client.py | 76 ++++++++----------- .../lib/cli/memory_banks/__init__.py | 2 +- .../lib/cli/memory_banks/list.py | 54 ++++--------- .../lib/cli/memory_banks/memory_banks.py | 25 ++---- .../lib/cli/models/__init__.py | 2 +- src/llama_stack_client/lib/cli/models/get.py | 66 +++++----------- src/llama_stack_client/lib/cli/models/list.py | 44 +++-------- .../lib/cli/models/models.py | 28 +++---- src/llama_stack_client/lib/cli/providers.py | 59 -------------- .../lib/cli/providers/__init__.py | 1 + .../lib/cli/providers/list.py | 20 +++++ .../lib/cli/providers/providers.py | 13 ++++ .../lib/cli/scoring_functions/__init__.py | 4 +- .../lib/cli/scoring_functions/list.py | 52 ++++--------- .../scoring_functions/scoring_functions.py | 25 +++--- .../lib/cli/shields/__init__.py | 2 +- .../lib/cli/shields/list.py | 48 +++--------- .../lib/cli/shields/shields.py | 24 ++---- 24 files changed, 209 insertions(+), 473 deletions(-) delete mode 100644 src/llama_stack_client/lib/cli/providers.py create mode 100644 src/llama_stack_client/lib/cli/providers/__init__.py create mode 100644 src/llama_stack_client/lib/cli/providers/list.py create mode 100644 src/llama_stack_client/lib/cli/providers/providers.py diff --git a/src/llama_stack_client/lib/cli/datasets/__init__.py b/src/llama_stack_client/lib/cli/datasets/__init__.py index 049d72ee..ec7b144f 100644 --- a/src/llama_stack_client/lib/cli/datasets/__init__.py +++ b/src/llama_stack_client/lib/cli/datasets/__init__.py @@ -4,6 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .datasets import DatasetsParser +from .datasets import datasets -__all__ = ["DatasetsParser"] +__all__ = ["datasets"] diff --git a/src/llama_stack_client/lib/cli/datasets/datasets.py b/src/llama_stack_client/lib/cli/datasets/datasets.py index e53743c7..c37a7aa5 100644 --- a/src/llama_stack_client/lib/cli/datasets/datasets.py +++ b/src/llama_stack_client/lib/cli/datasets/datasets.py @@ -4,24 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import argparse +import click -from llama_stack_client.lib.cli.subcommand import Subcommand -from .list import DatasetsList +from .list import list_datasets -class DatasetsParser(Subcommand): - """Parser for datasets commands""" +@click.group() +def datasets(): + """Query details about available datasets on Llama Stack distribution.""" + pass - @classmethod - def create(cls, subparsers: argparse._SubParsersAction): - parser = subparsers.add_parser( - "datasets", - help="Manage datasets", - formatter_class=argparse.RawTextHelpFormatter, - ) - parser.set_defaults(func=lambda _: parser.print_help()) - # Create subcommands - datasets_subparsers = parser.add_subparsers(title="subcommands") - DatasetsList(datasets_subparsers) +# Register subcommands +datasets.add_command(list_datasets) diff --git a/src/llama_stack_client/lib/cli/datasets/list.py b/src/llama_stack_client/lib/cli/datasets/list.py index 923e5548..731dc4fd 100644 --- a/src/llama_stack_client/lib/cli/datasets/list.py +++ b/src/llama_stack_client/lib/cli/datasets/list.py @@ -4,42 +4,19 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import argparse +import click -from llama_stack_client import LlamaStackClient from llama_stack_client.lib.cli.common.utils import print_table_from_response -from llama_stack_client.lib.cli.configure import get_config -from llama_stack_client.lib.cli.subcommand import Subcommand -class DatasetsList(Subcommand): - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "list", - prog="llama-stack-client datasets list", - description="Show available datasets on distribution endpoint", - formatter_class=argparse.RawTextHelpFormatter, - ) - self._add_arguments() - self.parser.set_defaults(func=self._run_datasets_list_cmd) +@click.command("list") +@click.pass_context +def list_datasets(ctx): + """Show available datasets on distribution endpoint""" + client = ctx.obj["client"] - def _add_arguments(self): - self.parser.add_argument( - "--endpoint", - type=str, - help="Llama Stack distribution endpoint", - ) + headers = ["identifier", "provider_id", "metadata", "type"] - def _run_datasets_list_cmd(self, args: argparse.Namespace): - args.endpoint = get_config().get("endpoint") or args.endpoint - - client = LlamaStackClient( - base_url=args.endpoint, - ) - - headers = ["identifier", "provider_id", "metadata", "type"] - - datasets_list_response = client.datasets.list() - if datasets_list_response: - print_table_from_response(datasets_list_response, headers) + datasets_list_response = client.datasets.list() + if datasets_list_response: + print_table_from_response(datasets_list_response, headers) diff --git a/src/llama_stack_client/lib/cli/eval_tasks/__init__.py b/src/llama_stack_client/lib/cli/eval_tasks/__init__.py index ae642407..010ffb76 100644 --- a/src/llama_stack_client/lib/cli/eval_tasks/__init__.py +++ b/src/llama_stack_client/lib/cli/eval_tasks/__init__.py @@ -4,4 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .eval_tasks import EvalTasksParser # noqa +from .eval_tasks import eval_tasks diff --git a/src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py b/src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py index f50da552..1983d782 100644 --- a/src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py +++ b/src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py @@ -5,24 +5,16 @@ # the root directory of this source tree. -import argparse +import click -from llama_stack_client.lib.cli.eval_tasks.list import EvalTasksList +from .list import list_eval_tasks -from llama_stack_client.lib.cli.subcommand import Subcommand +@click.group() +def eval_tasks(): + """Query details about available eval tasks type on distribution.""" + pass -class EvalTasksParser(Subcommand): - """List details about available eval banks type on distribution.""" - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "eval_tasks", - prog="llama-stack-client eval_tasks", - description="Query details about available eval tasks type on distribution.", - formatter_class=argparse.RawTextHelpFormatter, - ) - - subparsers = self.parser.add_subparsers(title="eval_tasks_subcommands") - EvalTasksList.create(subparsers) +# Register subcommands +eval_tasks.add_command(list_eval_tasks) diff --git a/src/llama_stack_client/lib/cli/eval_tasks/list.py b/src/llama_stack_client/lib/cli/eval_tasks/list.py index 2b60539a..88585a4b 100644 --- a/src/llama_stack_client/lib/cli/eval_tasks/list.py +++ b/src/llama_stack_client/lib/cli/eval_tasks/list.py @@ -4,40 +4,20 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import argparse +import click -from llama_stack_client import LlamaStackClient from llama_stack_client.lib.cli.common.utils import print_table_from_response -from llama_stack_client.lib.cli.configure import get_config -from llama_stack_client.lib.cli.subcommand import Subcommand -class EvalTasksList(Subcommand): - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "list", - prog="llama-stack-client eval_tasks list", - description="Show available evaluation tasks on distribution endpoint", - formatter_class=argparse.RawTextHelpFormatter, - ) - self._add_arguments() - self.parser.set_defaults(func=self._run_eval_tasks_list_cmd) +@click.command("list") +@click.pass_context +def list_eval_tasks(ctx): + """Show available eval tasks on distribution endpoint""" - def _add_arguments(self): - self.parser.add_argument( - "--endpoint", - type=str, - help="Llama Stack distribution endpoint", - ) + client = ctx.obj["client"] - def _run_eval_tasks_list_cmd(self, args: argparse.Namespace): - args.endpoint = get_config().get("endpoint") or args.endpoint + headers = ["identifier", "provider_id", "description", "type"] - client = LlamaStackClient( - base_url=args.endpoint, - ) - - eval_tasks_list_response = client.eval_tasks.list() - if eval_tasks_list_response: - print_table_from_response(eval_tasks_list_response) + eval_tasks_list_response = client.eval_tasks.list() + if eval_tasks_list_response: + print_table_from_response(eval_tasks_list_response, headers) diff --git a/src/llama_stack_client/lib/cli/llama_stack_client.py b/src/llama_stack_client/lib/cli/llama_stack_client.py index 3631aa25..a56be509 100644 --- a/src/llama_stack_client/lib/cli/llama_stack_client.py +++ b/src/llama_stack_client/lib/cli/llama_stack_client.py @@ -4,62 +4,46 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import argparse +import click +from llama_stack_client import LlamaStackClient -from llama_stack_client.lib.cli.constants import get_config_file_path +# from .configure import configure +from .datasets import datasets +from .eval_tasks import eval_tasks +from .memory_banks import memory_banks +from .models import models -from .configure import ConfigureParser -from .datasets import DatasetsParser -from .eval_tasks import EvalTasksParser -from .memory_banks import MemoryBanksParser +from .providers import providers +from .scoring_functions import scoring_functions -from .models import ModelsParser -from .providers import ProvidersParser -from .scoring_functions import ScoringFunctionsParser -from .shields import ShieldsParser +from .shields import shields -class LlamaStackClientCLIParser: - """Define CLI parse for LlamaStackClient CLI""" +@click.group() +@click.option("--endpoint", type=str, help="Llama Stack distribution endpoint", default="http://localhost:5000") +@click.pass_context +def cli(ctx, endpoint: str): + """Welcome to the LlamaStackClient CLI""" + ctx.ensure_object(dict) + client = LlamaStackClient( + base_url=endpoint, + ) + ctx.obj = {"client": client} - def __init__(self): - self.parser = argparse.ArgumentParser( - prog="llama-stack-client", - description="Welcome to the LlamaStackClient CLI", - ) - # Default command is to print help - self.parser.set_defaults(func=lambda _: self.parser.print_help()) - subparsers = self.parser.add_subparsers(title="subcommands") - - # add sub-commands - ModelsParser.create(subparsers) - MemoryBanksParser.create(subparsers) - ShieldsParser.create(subparsers) - EvalTasksParser.create(subparsers) - ConfigureParser.create(subparsers) - ProvidersParser.create(subparsers) - DatasetsParser.create(subparsers) - ScoringFunctionsParser.create(subparsers) - - def parse_args(self) -> argparse.Namespace: - return self.parser.parse_args() - - def command_requires_config(self, args: argparse.Namespace) -> bool: - return not (hasattr(args.func, "__self__") and isinstance(args.func.__self__, ConfigureParser)) - - def run(self, args: argparse.Namespace) -> None: - if self.command_requires_config(args) and not get_config_file_path().exists(): - print("Config file not found. Please run 'llama-stack-client configure' to create one.") - return - - args.func(args) +# Register all subcommands +cli.add_command(models, "models") +cli.add_command(memory_banks, "memory_banks") +cli.add_command(shields, "shields") +cli.add_command(eval_tasks, "eval_tasks") +# cli.add_command(configure) +cli.add_command(providers, "providers") +cli.add_command(datasets, "datasets") +cli.add_command(scoring_functions, "scoring_functions") def main(): - parser = LlamaStackClientCLIParser() - args = parser.parse_args() - parser.run(args) + cli() if __name__ == "__main__": diff --git a/src/llama_stack_client/lib/cli/memory_banks/__init__.py b/src/llama_stack_client/lib/cli/memory_banks/__init__.py index 5eefd7b7..eb25ec3d 100644 --- a/src/llama_stack_client/lib/cli/memory_banks/__init__.py +++ b/src/llama_stack_client/lib/cli/memory_banks/__init__.py @@ -4,4 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .memory_banks import MemoryBanksParser # noqa +from .memory_banks import memory_banks diff --git a/src/llama_stack_client/lib/cli/memory_banks/list.py b/src/llama_stack_client/lib/cli/memory_banks/list.py index f13c9828..a44d7de0 100644 --- a/src/llama_stack_client/lib/cli/memory_banks/list.py +++ b/src/llama_stack_client/lib/cli/memory_banks/list.py @@ -4,49 +4,25 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import argparse +import click -from llama_stack_client import LlamaStackClient from llama_stack_client.lib.cli.common.utils import print_table_from_response -from llama_stack_client.lib.cli.configure import get_config -from llama_stack_client.lib.cli.subcommand import Subcommand -class MemoryBanksList(Subcommand): - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "list", - prog="llama-stack-client memory_banks list", - description="Show available memory banks type on distribution endpoint", - formatter_class=argparse.RawTextHelpFormatter, - ) - self._add_arguments() - self.parser.set_defaults(func=self._run_memory_banks_list_cmd) +@click.command("list") +@click.pass_context +def list_memory_banks(ctx): + """Show available memory banks on distribution endpoint""" - def _add_arguments(self): - self.parser.add_argument( - "--endpoint", - type=str, - help="Llama Stack distribution endpoint", - ) + client = ctx.obj["client"] - def _run_memory_banks_list_cmd(self, args: argparse.Namespace): - args.endpoint = get_config().get("endpoint") or args.endpoint + headers = [ + "identifier", + "provider_id", + "description", + "type", + ] - client = LlamaStackClient( - base_url=args.endpoint, - ) - - headers = [ - "identifier", - "provider_id", - "type", - "embedding_model", - "chunk_size_in_tokens", - "overlap_size_in_tokens", - ] - - memory_banks_list_response = client.memory_banks.list() - if memory_banks_list_response: - print_table_from_response(memory_banks_list_response, headers) + memory_banks_list_response = client.memory_banks.list() + if memory_banks_list_response: + print_table_from_response(memory_banks_list_response, headers) diff --git a/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py b/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py index a70fe640..088d647e 100644 --- a/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py +++ b/src/llama_stack_client/lib/cli/memory_banks/memory_banks.py @@ -4,25 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import click -import argparse +from .list import list_memory_banks -from llama_stack_client.lib.cli.memory_banks.list import MemoryBanksList -from llama_stack_client.lib.cli.subcommand import Subcommand +@click.group() +def memory_banks(): + """Query details about available memory banks type on distribution.""" + pass -class MemoryBanksParser(Subcommand): - """List details about available memory banks type on distribution.""" - - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "memory_banks", - prog="llama-stack-client memory_banks", - description="Query details about available memory banks type on distribution.", - formatter_class=argparse.RawTextHelpFormatter, - ) - - subparsers = self.parser.add_subparsers(title="memory_banks_subcommands") - MemoryBanksList.create(subparsers) +# Register subcommands +memory_banks.add_command(list_memory_banks) diff --git a/src/llama_stack_client/lib/cli/models/__init__.py b/src/llama_stack_client/lib/cli/models/__init__.py index 55b202d2..c23692d7 100644 --- a/src/llama_stack_client/lib/cli/models/__init__.py +++ b/src/llama_stack_client/lib/cli/models/__init__.py @@ -4,4 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .models import ModelsParser # noqa +from .models import models diff --git a/src/llama_stack_client/lib/cli/models/get.py b/src/llama_stack_client/lib/cli/models/get.py index 89e4d3da..c242fb22 100644 --- a/src/llama_stack_client/lib/cli/models/get.py +++ b/src/llama_stack_client/lib/cli/models/get.py @@ -4,60 +4,28 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import argparse - +import click from tabulate import tabulate -from llama_stack_client import LlamaStackClient -from llama_stack_client.lib.cli.configure import get_config -from llama_stack_client.lib.cli.subcommand import Subcommand - - -class ModelsGet(Subcommand): - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "get", - prog="llama-stack-client models get", - description="Show available llama models at distribution endpoint", - formatter_class=argparse.RawTextHelpFormatter, - ) - self._add_arguments() - self.parser.set_defaults(func=self._run_models_list_cmd) - - def _add_arguments(self): - self.parser.add_argument( - "model_id", - type=str, - help="Model ID to query information about", - ) - self.parser.add_argument( - "--endpoint", - type=str, - help="Llama Stack distribution endpoint", - ) +@click.command(name="get") +@click.argument("model_id") +@click.pass_context +def get_model(ctx, model_id: str): + """Show available llama models at distribution endpoint""" + client = ctx.obj["client"] - def _run_models_list_cmd(self, args: argparse.Namespace): - config = get_config() - if config: - args.endpoint = config.get("endpoint") + models_get_response = client.models.retrieve(identifier=model_id) - client = LlamaStackClient( - base_url=args.endpoint, + if not models_get_response: + click.echo( + f"Model {model_id} is not found at distribution endpoint. " + "Please ensure endpoint is serving specified model." ) + return - models_get_response = client.models.retrieve(identifier=args.model_id) - - if not models_get_response: - print( - f"Model {args.model_id} is not found at distribution endpoint {args.endpoint}. Please ensure endpoint is serving specified model. " - ) - return - - headers = sorted(models_get_response.__dict__.keys()) - - rows = [] - rows.append([models_get_response.__dict__[headers[i]] for i in range(len(headers))]) + headers = sorted(models_get_response.__dict__.keys()) + rows = [] + rows.append([models_get_response.__dict__[headers[i]] for i in range(len(headers))]) - print(tabulate(rows, headers=headers, tablefmt="grid")) + click.echo(tabulate(rows, headers=headers, tablefmt="grid")) diff --git a/src/llama_stack_client/lib/cli/models/list.py b/src/llama_stack_client/lib/cli/models/list.py index 42d665e2..7d335224 100644 --- a/src/llama_stack_client/lib/cli/models/list.py +++ b/src/llama_stack_client/lib/cli/models/list.py @@ -4,43 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import argparse +import click -from llama_stack_client import LlamaStackClient from llama_stack_client.lib.cli.common.utils import print_table_from_response -from llama_stack_client.lib.cli.configure import get_config -from llama_stack_client.lib.cli.subcommand import Subcommand -class ModelsList(Subcommand): - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "list", - prog="llama-stack-client models list", - description="Show available llama models at distribution endpoint", - formatter_class=argparse.RawTextHelpFormatter, - ) - self._add_arguments() - self.parser.set_defaults(func=self._run_models_list_cmd) +@click.command(name="list", help="Show available llama models at distribution endpoint") +@click.pass_context +def list_models(ctx): + client = ctx.obj["client"] - def _add_arguments(self): - self.parser.add_argument( - "--endpoint", - type=str, - help="Llama Stack distribution endpoint", - ) - - def _run_models_list_cmd(self, args: argparse.Namespace): - config = get_config() - if config: - args.endpoint = config.get("endpoint") - - client = LlamaStackClient( - base_url=args.endpoint, - ) - - headers = ["identifier", "llama_model", "provider_id", "metadata"] - response = client.models.list() - if response: - print_table_from_response(response, headers) + headers = ["identifier", "provider_id", "provider_resource_id", "metadata"] + response = client.models.list() + if response: + print_table_from_response(response, headers) diff --git a/src/llama_stack_client/lib/cli/models/models.py b/src/llama_stack_client/lib/cli/models/models.py index 3679e529..1655429a 100644 --- a/src/llama_stack_client/lib/cli/models/models.py +++ b/src/llama_stack_client/lib/cli/models/models.py @@ -4,25 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import argparse +import click +from llama_stack_client.lib.cli.models.get import get_model +from llama_stack_client.lib.cli.models.list import list_models -from llama_stack_client.lib.cli.models.get import ModelsGet -from llama_stack_client.lib.cli.models.list import ModelsList -from llama_stack_client.lib.cli.subcommand import Subcommand +@click.group() +def models(): + """Query details about available models on Llama Stack distribution.""" + pass -class ModelsParser(Subcommand): - """List details about available models on distribution.""" - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "models", - prog="llama-stack-client models", - description="Query details about available models on Llama Stack distributiom. ", - formatter_class=argparse.RawTextHelpFormatter, - ) - - subparsers = self.parser.add_subparsers(title="models_subcommands") - ModelsList.create(subparsers) - ModelsGet.create(subparsers) +# Register subcommands +models.add_command(list_models) +models.add_command(get_model) diff --git a/src/llama_stack_client/lib/cli/providers.py b/src/llama_stack_client/lib/cli/providers.py deleted file mode 100644 index 3be0ba6e..00000000 --- a/src/llama_stack_client/lib/cli/providers.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import argparse - -from tabulate import tabulate - -from llama_stack_client import LlamaStackClient -from llama_stack_client.lib.cli.configure import get_config -from llama_stack_client.lib.cli.subcommand import Subcommand - - -class ProvidersParser(Subcommand): - """Configure Llama Stack Client CLI""" - - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "providers", - prog="llama-stack-client providers", - description="List available providers Llama Stack Client CLI", - formatter_class=argparse.RawTextHelpFormatter, - ) - self._add_arguments() - self.parser.set_defaults(func=self._run_providers_cmd) - - def _add_arguments(self): - self.parser.add_argument( - "--endpoint", - type=str, - help="Llama Stack distribution endpoint", - ) - - def _run_providers_cmd(self, args: argparse.Namespace): - config = get_config() - if config: - args.endpoint = config.get("endpoint") - - client = LlamaStackClient( - base_url=args.endpoint, - ) - - headers = [ - "API", - "Provider ID", - "Provider Type", - ] - - providers_response = client.providers.list() - rows = [] - - for k, v in providers_response.items(): - for provider_info in v: - rows.append([k, provider_info.provider_id, provider_info.provider_type]) - - print(tabulate(rows, headers=headers, tablefmt="grid")) diff --git a/src/llama_stack_client/lib/cli/providers/__init__.py b/src/llama_stack_client/lib/cli/providers/__init__.py new file mode 100644 index 00000000..82538fbb --- /dev/null +++ b/src/llama_stack_client/lib/cli/providers/__init__.py @@ -0,0 +1 @@ +from .providers import providers diff --git a/src/llama_stack_client/lib/cli/providers/list.py b/src/llama_stack_client/lib/cli/providers/list.py new file mode 100644 index 00000000..2babd42d --- /dev/null +++ b/src/llama_stack_client/lib/cli/providers/list.py @@ -0,0 +1,20 @@ +import click +from tabulate import tabulate + + +@click.command("list") +@click.pass_context +def list_providers(ctx): + """Show available providers on distribution endpoint""" + client = ctx.obj["client"] + + headers = ["API", "Provider ID", "Provider Type"] + + providers_response = client.providers.list() + rows = [] + + for k, v in providers_response.items(): + for provider_info in v: + rows.append([k, provider_info.provider_id, provider_info.provider_type]) + + click.echo(tabulate(rows, headers=headers, tablefmt="grid")) diff --git a/src/llama_stack_client/lib/cli/providers/providers.py b/src/llama_stack_client/lib/cli/providers/providers.py new file mode 100644 index 00000000..2486de29 --- /dev/null +++ b/src/llama_stack_client/lib/cli/providers/providers.py @@ -0,0 +1,13 @@ +import click + +from .list import list_providers + + +@click.group() +def providers(): + """Query details about available providers on Llama Stack distribution.""" + pass + + +# Register subcommands +providers.add_command(list_providers) diff --git a/src/llama_stack_client/lib/cli/scoring_functions/__init__.py b/src/llama_stack_client/lib/cli/scoring_functions/__init__.py index 98dea0bf..c3faacbd 100644 --- a/src/llama_stack_client/lib/cli/scoring_functions/__init__.py +++ b/src/llama_stack_client/lib/cli/scoring_functions/__init__.py @@ -4,6 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .scoring_functions import ScoringFunctionsParser +from .list import list_scoring_functions -__all__ = ["ScoringFunctionsParser"] +__all__ = ["list_scoring_functions"] diff --git a/src/llama_stack_client/lib/cli/scoring_functions/list.py b/src/llama_stack_client/lib/cli/scoring_functions/list.py index 2de0acfb..73abd0bc 100644 --- a/src/llama_stack_client/lib/cli/scoring_functions/list.py +++ b/src/llama_stack_client/lib/cli/scoring_functions/list.py @@ -4,47 +4,25 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import argparse +import click -from llama_stack_client import LlamaStackClient from llama_stack_client.lib.cli.common.utils import print_table_from_response -from llama_stack_client.lib.cli.configure import get_config -from llama_stack_client.lib.cli.subcommand import Subcommand -class ScoringFunctionsList(Subcommand): - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "list", - prog="llama-stack-client scoring_functions list", - description="Show available scoring functions on distribution endpoint", - formatter_class=argparse.RawTextHelpFormatter, - ) - self._add_arguments() - self.parser.set_defaults(func=self._run_scoring_functions_list_cmd) +@click.command("list") +@click.pass_context +def list_scoring_functions(ctx): + """Show available scoring functions on distribution endpoint""" - def _add_arguments(self): - self.parser.add_argument( - "--endpoint", - type=str, - help="Llama Stack distribution endpoint", - ) + client = ctx.obj["client"] - def _run_scoring_functions_list_cmd(self, args: argparse.Namespace): - args.endpoint = get_config().get("endpoint") or args.endpoint + headers = [ + "identifier", + "provider_id", + "description", + "type", + ] - client = LlamaStackClient( - base_url=args.endpoint, - ) - - headers = [ - "identifier", - "provider_id", - "description", - "type", - ] - - scoring_functions_list_response = client.scoring_functions.list() - if scoring_functions_list_response: - print_table_from_response(scoring_functions_list_response, headers) + scoring_functions_list_response = client.scoring_functions.list() + if scoring_functions_list_response: + print_table_from_response(scoring_functions_list_response, headers) diff --git a/src/llama_stack_client/lib/cli/scoring_functions/scoring_functions.py b/src/llama_stack_client/lib/cli/scoring_functions/scoring_functions.py index 80593420..56cf26db 100644 --- a/src/llama_stack_client/lib/cli/scoring_functions/scoring_functions.py +++ b/src/llama_stack_client/lib/cli/scoring_functions/scoring_functions.py @@ -4,24 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import argparse +import click -from llama_stack_client.lib.cli.subcommand import Subcommand -from .list import ScoringFunctionsList +from .list import list_scoring_functions -class ScoringFunctionsParser(Subcommand): - """Parser for scoring functions commands""" +@click.group() +@click.pass_context +def scoring_functions(ctx): + """Manage scoring functions""" + pass - @classmethod - def create(cls, subparsers: argparse._SubParsersAction): - parser = subparsers.add_parser( - "scoring_functions", - help="Manage scoring functions", - formatter_class=argparse.RawTextHelpFormatter, - ) - parser.set_defaults(func=lambda _: parser.print_help()) - # Create subcommands - scoring_functions_subparsers = parser.add_subparsers(title="subcommands") - ScoringFunctionsList(scoring_functions_subparsers) +# Register subcommands +scoring_functions.add_command(list_scoring_functions) diff --git a/src/llama_stack_client/lib/cli/shields/__init__.py b/src/llama_stack_client/lib/cli/shields/__init__.py index 19c9ce7d..5966f69f 100644 --- a/src/llama_stack_client/lib/cli/shields/__init__.py +++ b/src/llama_stack_client/lib/cli/shields/__init__.py @@ -4,4 +4,4 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .shields import ShieldsParser # noqa +from .shields import shields diff --git a/src/llama_stack_client/lib/cli/shields/list.py b/src/llama_stack_client/lib/cli/shields/list.py index 092dd633..ce112204 100644 --- a/src/llama_stack_client/lib/cli/shields/list.py +++ b/src/llama_stack_client/lib/cli/shields/list.py @@ -4,48 +4,20 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import argparse +import click -from llama_stack_client import LlamaStackClient from llama_stack_client.lib.cli.common.utils import print_table_from_response -from llama_stack_client.lib.cli.configure import get_config -from llama_stack_client.lib.cli.subcommand import Subcommand -class ShieldsList(Subcommand): - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "list", - prog="llama-stack-client shields list", - description="Show available llama models at distribution endpoint", - formatter_class=argparse.RawTextHelpFormatter, - ) - self._add_arguments() - self.parser.set_defaults(func=self._run_shields_list_cmd) +@click.command("list") +@click.pass_context +def list_shields(ctx): + """Show available safety shields on distribution endpoint""" - def _add_arguments(self): - self.parser.add_argument( - "--endpoint", - type=str, - help="Llama Stack distribution endpoint", - default="", - ) + client = ctx.obj["client"] - def _run_shields_list_cmd(self, args: argparse.Namespace): - config = get_config() - if config: - args.endpoint = config.get("endpoint") + headers = ["identifier", "provider_id", "description", "type"] - if not args.endpoint: - self.parser.error( - "A valid endpoint is required. Please run llama-stack-client configure first or pass in a valid endpoint with --endpoint. " - ) - client = LlamaStackClient( - base_url=args.endpoint, - ) - - shields_list_response = client.shields.list() - - if shields_list_response: - print_table_from_response(shields_list_response) + shields_list_response = client.shields.list() + if shields_list_response: + print_table_from_response(shields_list_response, headers) diff --git a/src/llama_stack_client/lib/cli/shields/shields.py b/src/llama_stack_client/lib/cli/shields/shields.py index 00a556a0..a4a0373b 100644 --- a/src/llama_stack_client/lib/cli/shields/shields.py +++ b/src/llama_stack_client/lib/cli/shields/shields.py @@ -4,24 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import argparse +import click -from llama_stack_client.lib.cli.shields.list import ShieldsList +from .list import list_shields -from llama_stack_client.lib.cli.subcommand import Subcommand +@click.group() +def shields(): + """Query details about available safety shields on distribution.""" + pass -class ShieldsParser(Subcommand): - """List details about available safety shields on distribution.""" - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "shields", - prog="llama-stack-client shields", - description="Query details about available safety shields on distribution.", - formatter_class=argparse.RawTextHelpFormatter, - ) - - subparsers = self.parser.add_subparsers(title="shields_subcommands") - ShieldsList.create(subparsers) +# Register subcommands +shields.add_command(list_shields) From a651f22cf6b80407e91258571c9b0f53b7c989a2 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 13 Nov 2024 10:07:20 -0800 Subject: [PATCH 2/4] remove configure --- src/llama_stack_client/lib/cli/configure.py | 94 ------------------- src/llama_stack_client/lib/cli/constants.py | 14 --- .../lib/cli/llama_stack_client.py | 1 - 3 files changed, 109 deletions(-) delete mode 100644 src/llama_stack_client/lib/cli/configure.py delete mode 100644 src/llama_stack_client/lib/cli/constants.py diff --git a/src/llama_stack_client/lib/cli/configure.py b/src/llama_stack_client/lib/cli/configure.py deleted file mode 100644 index 2b312b62..00000000 --- a/src/llama_stack_client/lib/cli/configure.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import argparse -import os - -import yaml - -from llama_stack_client.lib.cli.constants import get_config_file_path, LLAMA_STACK_CLIENT_CONFIG_DIR -from llama_stack_client.lib.cli.subcommand import Subcommand - - -def get_config(): - config_file = get_config_file_path() - if config_file.exists(): - with open(config_file, "r") as f: - return yaml.safe_load(f) - return None - - -class ConfigureParser(Subcommand): - """Configure Llama Stack Client CLI""" - - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "configure", - prog="llama-stack-client configure", - description="Configure Llama Stack Client CLI", - formatter_class=argparse.RawTextHelpFormatter, - ) - self._add_arguments() - self.parser.set_defaults(func=self._run_configure_cmd) - - def _add_arguments(self): - self.parser.add_argument( - "--host", - type=str, - help="Llama Stack distribution host", - ) - self.parser.add_argument( - "--port", - type=str, - help="Llama Stack distribution port number", - ) - self.parser.add_argument( - "--endpoint", - type=str, - help="Llama Stack distribution endpoint", - ) - - def _run_configure_cmd(self, args: argparse.Namespace): - from prompt_toolkit import prompt - from prompt_toolkit.validation import Validator - - os.makedirs(LLAMA_STACK_CLIENT_CONFIG_DIR, exist_ok=True) - config_path = get_config_file_path() - - if args.endpoint: - endpoint = args.endpoint - else: - if args.host and args.port: - endpoint = f"http://{args.host}:{args.port}" - else: - host = prompt( - "> Enter the host name of the Llama Stack distribution server: ", - validator=Validator.from_callable( - lambda x: len(x) > 0, - error_message="Host cannot be empty, please enter a valid host", - ), - ) - port = prompt( - "> Enter the port number of the Llama Stack distribution server: ", - validator=Validator.from_callable( - lambda x: x.isdigit(), - error_message="Please enter a valid port number", - ), - ) - endpoint = f"http://{host}:{port}" - - with open(config_path, "w") as f: - f.write( - yaml.dump( - { - "endpoint": endpoint, - }, - sort_keys=True, - ) - ) - - print(f"Done! You can now use the Llama Stack Client CLI with endpoint {endpoint}") diff --git a/src/llama_stack_client/lib/cli/constants.py b/src/llama_stack_client/lib/cli/constants.py deleted file mode 100644 index 22595747..00000000 --- a/src/llama_stack_client/lib/cli/constants.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import os -from pathlib import Path - -LLAMA_STACK_CLIENT_CONFIG_DIR = Path(os.path.expanduser("~/.llama/client")) - - -def get_config_file_path(): - return LLAMA_STACK_CLIENT_CONFIG_DIR / "config.yaml" diff --git a/src/llama_stack_client/lib/cli/llama_stack_client.py b/src/llama_stack_client/lib/cli/llama_stack_client.py index a56be509..4d436598 100644 --- a/src/llama_stack_client/lib/cli/llama_stack_client.py +++ b/src/llama_stack_client/lib/cli/llama_stack_client.py @@ -36,7 +36,6 @@ def cli(ctx, endpoint: str): cli.add_command(memory_banks, "memory_banks") cli.add_command(shields, "shields") cli.add_command(eval_tasks, "eval_tasks") -# cli.add_command(configure) cli.add_command(providers, "providers") cli.add_command(datasets, "datasets") cli.add_command(scoring_functions, "scoring_functions") From 26a6236c2b6aafa53567a08544f0bc9f1c69861b Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 13 Nov 2024 10:46:00 -0800 Subject: [PATCH 3/4] add back configure --- src/llama_stack_client/lib/cli/configure.py | 64 +++++++++++++++++++ src/llama_stack_client/lib/cli/constants.py | 14 ++++ .../lib/cli/llama_stack_client.py | 35 ++++++++-- .../lib/cli/scoring_functions/__init__.py | 4 +- .../scoring_functions/scoring_functions.py | 3 +- 5 files changed, 110 insertions(+), 10 deletions(-) create mode 100644 src/llama_stack_client/lib/cli/configure.py create mode 100644 src/llama_stack_client/lib/cli/constants.py diff --git a/src/llama_stack_client/lib/cli/configure.py b/src/llama_stack_client/lib/cli/configure.py new file mode 100644 index 00000000..84cc78a0 --- /dev/null +++ b/src/llama_stack_client/lib/cli/configure.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +import click +import yaml +from llama_stack_client.lib.cli.constants import get_config_file_path, LLAMA_STACK_CLIENT_CONFIG_DIR +from prompt_toolkit import prompt +from prompt_toolkit.validation import Validator + + +def get_config(): + config_file = get_config_file_path() + if config_file.exists(): + with open(config_file, "r") as f: + return yaml.safe_load(f) + return None + + +@click.command() +@click.option("--host", type=str, help="Llama Stack distribution host") +@click.option("--port", type=str, help="Llama Stack distribution port number") +@click.option("--endpoint", type=str, help="Llama Stack distribution endpoint") +def configure(host: str | None, port: str | None, endpoint: str | None): + """Configure Llama Stack Client CLI""" + os.makedirs(LLAMA_STACK_CLIENT_CONFIG_DIR, exist_ok=True) + config_path = get_config_file_path() + + if endpoint: + final_endpoint = endpoint + else: + if host and port: + final_endpoint = f"http://{host}:{port}" + else: + host = prompt( + "> Enter the host name of the Llama Stack distribution server: ", + validator=Validator.from_callable( + lambda x: len(x) > 0, + error_message="Host cannot be empty, please enter a valid host", + ), + ) + port = prompt( + "> Enter the port number of the Llama Stack distribution server: ", + validator=Validator.from_callable( + lambda x: x.isdigit(), + error_message="Please enter a valid port number", + ), + ) + final_endpoint = f"http://{host}:{port}" + + with open(config_path, "w") as f: + f.write( + yaml.dump( + { + "endpoint": final_endpoint, + }, + sort_keys=True, + ) + ) + + print(f"Done! You can now use the Llama Stack Client CLI with endpoint {final_endpoint}") diff --git a/src/llama_stack_client/lib/cli/constants.py b/src/llama_stack_client/lib/cli/constants.py new file mode 100644 index 00000000..22595747 --- /dev/null +++ b/src/llama_stack_client/lib/cli/constants.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +from pathlib import Path + +LLAMA_STACK_CLIENT_CONFIG_DIR = Path(os.path.expanduser("~/.llama/client")) + + +def get_config_file_path(): + return LLAMA_STACK_CLIENT_CONFIG_DIR / "config.yaml" diff --git a/src/llama_stack_client/lib/cli/llama_stack_client.py b/src/llama_stack_client/lib/cli/llama_stack_client.py index 4d436598..49bdef8c 100644 --- a/src/llama_stack_client/lib/cli/llama_stack_client.py +++ b/src/llama_stack_client/lib/cli/llama_stack_client.py @@ -5,9 +5,11 @@ # the root directory of this source tree. import click +import yaml from llama_stack_client import LlamaStackClient -# from .configure import configure +from .constants import get_config_file_path +from .configure import configure from .datasets import datasets from .eval_tasks import eval_tasks from .memory_banks import memory_banks @@ -20,14 +22,34 @@ @click.group() -@click.option("--endpoint", type=str, help="Llama Stack distribution endpoint", default="http://localhost:5000") +@click.option("--endpoint", type=str, help="Llama Stack distribution endpoint", default="") +@click.option("--config", type=str, help="Path to config file", default=None) @click.pass_context -def cli(ctx, endpoint: str): +def cli(ctx, endpoint: str, config: str | None): """Welcome to the LlamaStackClient CLI""" ctx.ensure_object(dict) - client = LlamaStackClient( - base_url=endpoint, - ) + + # If no config provided, check default location + if config is None: + if endpoint != "": + raise ValueError("Cannot use both config and endpoint") + default_config = get_config_file_path() + if default_config.exists(): + config = str(default_config) + + if config: + try: + with open(config, "r") as f: + config_dict = yaml.safe_load(f) + endpoint = config_dict.get("endpoint", endpoint) + except Exception as e: + click.echo(f"Error loading config from {config}: {str(e)}", err=True) + click.echo("Falling back to HTTP client with endpoint", err=True) + + if endpoint == "": + endpoint = "http://localhost:5000" + + client = LlamaStackClient(base_url=endpoint) ctx.obj = {"client": client} @@ -38,6 +60,7 @@ def cli(ctx, endpoint: str): cli.add_command(eval_tasks, "eval_tasks") cli.add_command(providers, "providers") cli.add_command(datasets, "datasets") +cli.add_command(configure, "configure") cli.add_command(scoring_functions, "scoring_functions") diff --git a/src/llama_stack_client/lib/cli/scoring_functions/__init__.py b/src/llama_stack_client/lib/cli/scoring_functions/__init__.py index c3faacbd..9699df68 100644 --- a/src/llama_stack_client/lib/cli/scoring_functions/__init__.py +++ b/src/llama_stack_client/lib/cli/scoring_functions/__init__.py @@ -4,6 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .list import list_scoring_functions +from .scoring_functions import scoring_functions -__all__ = ["list_scoring_functions"] +__all__ = ["scoring_functions"] diff --git a/src/llama_stack_client/lib/cli/scoring_functions/scoring_functions.py b/src/llama_stack_client/lib/cli/scoring_functions/scoring_functions.py index 56cf26db..4dba3fb5 100644 --- a/src/llama_stack_client/lib/cli/scoring_functions/scoring_functions.py +++ b/src/llama_stack_client/lib/cli/scoring_functions/scoring_functions.py @@ -10,8 +10,7 @@ @click.group() -@click.pass_context -def scoring_functions(ctx): +def scoring_functions(): """Manage scoring functions""" pass From 5282b7b697e7525cd7f43c52211ac85f6bcb29b8 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 13 Nov 2024 10:46:52 -0800 Subject: [PATCH 4/4] doc update --- docs/cli_reference.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/cli_reference.md b/docs/cli_reference.md index 809a6035..19e6a49e 100644 --- a/docs/cli_reference.md +++ b/docs/cli_reference.md @@ -26,9 +26,9 @@ Done! You can now use the Llama Stack Client CLI with endpoint http://localhost: ``` -#### `llama-stack-client providers` +#### `llama-stack-client providers list` ```bash -$ llama-stack-client providers +$ llama-stack-client providers list ``` ``` +-----------+----------------+-----------------+