From b6bf0aeced0756d83814704a2458eb33b642b465 Mon Sep 17 00:00:00 2001 From: David Tam Date: Tue, 9 Jul 2024 12:54:31 -0700 Subject: [PATCH 1/2] headless cli updates --- guardrails/cli/configure.py | 29 +++++++++++++++-------------- guardrails/cli/hub/install.py | 36 +++++++++++++++++++++++------------- 2 files changed, 38 insertions(+), 27 deletions(-) diff --git a/guardrails/cli/configure.py b/guardrails/cli/configure.py index b56fcf43f..1b0d669db 100644 --- a/guardrails/cli/configure.py +++ b/guardrails/cli/configure.py @@ -59,6 +59,18 @@ def configure( help="Opt out of anonymous metrics collection.", prompt="Enable anonymous metrics reporting?", ), + token: Optional[str] = typer.Option( + None, + "--token", + help="API Key for Guardrails. If not provided, you will be prompted for it.", + ), + remote_inferencing: Optional[bool] = typer.Option( + DEFAULT_USE_REMOTE_INFERENCING, + "--enable-remote-inferencing/--disable-remote-inferencing", + help="Opt in to remote inferencing. " + "If not provided, you will be prompted for it.", + prompt="Do you wish to use remote inferencing?", + ), clear_token: Optional[bool] = typer.Option( False, "--clear-token", @@ -68,7 +80,7 @@ def configure( existing_token = _get_default_token() last4 = existing_token[-4:] if existing_token else "" - if not clear_token: + if not clear_token and token is None: console.print("\nEnter API Key below", style="bold", end=" ") if last4: @@ -86,21 +98,10 @@ def configure( token = typer.prompt("\nAPI Key", existing_token, show_default=False) else: - token = DEFAULT_TOKEN - - # Ask about remote inferencing - use_remote_inferencing = ( - typer.prompt( - "Do you wish to use remote inferencing? 
(Y/N)", - type=str, - default="Y", - show_default=False, - ).lower() - == "y" - ) + token = token or DEFAULT_TOKEN try: - save_configuration_file(token, enable_metrics, use_remote_inferencing) + save_configuration_file(token, enable_metrics, remote_inferencing) logger.info("Configuration saved.") if not token: diff --git a/guardrails/cli/hub/install.py b/guardrails/cli/hub/install.py index ac6da49a2..7d6d543c7 100644 --- a/guardrails/cli/hub/install.py +++ b/guardrails/cli/hub/install.py @@ -194,6 +194,11 @@ def install( help="URI to the package to install.\ Example: hub://guardrails/regex_match." ), + local_models: bool = typer.Option( + None, + "--install-local-models/--no-install-local-models", + help="Install local models", + ), quiet: bool = typer.Option( False, "--quiet", @@ -203,6 +208,8 @@ def install( verbose_printer = console.print quiet_printer = console.print if not quiet else lambda x: None """Install a validator from the Hub.""" + print("==== hub installing") + print("local models", local_models) if not package_uri.startswith("hub://"): logger.error("Invalid URI!") sys.exit(1) @@ -236,19 +243,22 @@ def do_nothing_context(*args, **kwargs): dl_deps_msg = "Downloading dependencies" with loader(dl_deps_msg, spinner="bouncingBar"): install_hub_module(module_manifest, site_packages, quiet=quiet) - - try: - if module_manifest.tags and module_manifest.tags.has_guardrails_endpoint: - install_local_models = typer.confirm( - "This validator has a Guardrails AI inference endpoint available. 
" - "Would you still like to install the local models for local inference?", - ) - else: - install_local_models = typer.confirm( - "Would you like to install the local models?", default=True - ) - except AttributeError: - install_local_models = False + if local_models is not None: + install_local_models = local_models + else: + try: + if module_manifest.tags and module_manifest.tags.has_guardrails_endpoint: + install_local_models = typer.confirm( + "This validator has a Guardrails AI inference endpoint available. " + "Would you still like to install the" + " local models for local inference?", + ) + else: + install_local_models = typer.confirm( + "Would you like to install the local models?", default=True + ) + except AttributeError: + install_local_models = False # Post-install if install_local_models: From 1252a091a34b2652ea90273b6e36b5e042206e33 Mon Sep 17 00:00:00 2001 From: David Tam Date: Tue, 9 Jul 2024 13:47:33 -0700 Subject: [PATCH 2/2] tests and cleanup --- guardrails/cli/hub/install.py | 5 ++--- tests/unit_tests/cli/hub/test_install.py | 1 + tests/unit_tests/cli/test_configure.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/guardrails/cli/hub/install.py b/guardrails/cli/hub/install.py index 7d6d543c7..d0dbd7509 100644 --- a/guardrails/cli/hub/install.py +++ b/guardrails/cli/hub/install.py @@ -208,8 +208,6 @@ def install( verbose_printer = console.print quiet_printer = console.print if not quiet else lambda x: None """Install a validator from the Hub.""" - print("==== hub installing") - print("local models", local_models) if not package_uri.startswith("hub://"): logger.error("Invalid URI!") sys.exit(1) @@ -243,7 +241,8 @@ def do_nothing_context(*args, **kwargs): dl_deps_msg = "Downloading dependencies" with loader(dl_deps_msg, spinner="bouncingBar"): install_hub_module(module_manifest, site_packages, quiet=quiet) - if local_models is not None: + + if local_models is True or local_models is False: install_local_models = local_models 
else: try: diff --git a/tests/unit_tests/cli/hub/test_install.py b/tests/unit_tests/cli/hub/test_install.py index b5232dd90..50a2b0330 100644 --- a/tests/unit_tests/cli/hub/test_install.py +++ b/tests/unit_tests/cli/hub/test_install.py @@ -126,6 +126,7 @@ def test_happy_path(self, mocker, monkeypatch): msg="Skipping post install, models will not be downloaded for local inference.", # noqa ), # noqa ] + assert mock_logger_log.call_count == 3 mock_logger_log.assert_has_calls(log_calls) diff --git a/tests/unit_tests/cli/test_configure.py b/tests/unit_tests/cli/test_configure.py index 11f189d5f..fc043462f 100644 --- a/tests/unit_tests/cli/test_configure.py +++ b/tests/unit_tests/cli/test_configure.py @@ -56,7 +56,7 @@ def test_configure(mocker, runner, expected_token, enable_metrics, clear_token): assert mock_logger_info.call_count == 2 mock_logger_info.assert_has_calls(expected_calls) mock_save_configuration_file.assert_called_once_with( - expected_token, enable_metrics, False + expected_token, enable_metrics, True )