diff --git a/guardrails/cli/configure.py b/guardrails/cli/configure.py index b56fcf43f..1b0d669db 100644 --- a/guardrails/cli/configure.py +++ b/guardrails/cli/configure.py @@ -59,6 +59,18 @@ def configure( help="Opt out of anonymous metrics collection.", prompt="Enable anonymous metrics reporting?", ), + token: Optional[str] = typer.Option( + None, + "--token", + help="API Key for Guardrails. If not provided, you will be prompted for it.", + ), + remote_inferencing: Optional[bool] = typer.Option( + DEFAULT_USE_REMOTE_INFERENCING, + "--enable-remote-inferencing/--disable-remote-inferencing", + help="Opt in to remote inferencing. " + "If not provided, you will be prompted for it.", + prompt="Do you wish to use remote inferencing?", + ), clear_token: Optional[bool] = typer.Option( False, "--clear-token", @@ -68,7 +80,7 @@ def configure( existing_token = _get_default_token() last4 = existing_token[-4:] if existing_token else "" - if not clear_token: + if not clear_token and token is None: console.print("\nEnter API Key below", style="bold", end=" ") if last4: @@ -86,21 +98,10 @@ def configure( token = typer.prompt("\nAPI Key", existing_token, show_default=False) else: - token = DEFAULT_TOKEN - - # Ask about remote inferencing - use_remote_inferencing = ( - typer.prompt( - "Do you wish to use remote inferencing? (Y/N)", - type=str, - default="Y", - show_default=False, - ).lower() - == "y" - ) + token = token or DEFAULT_TOKEN try: - save_configuration_file(token, enable_metrics, use_remote_inferencing) + save_configuration_file(token, enable_metrics, remote_inferencing) logger.info("Configuration saved.") if not token: diff --git a/guardrails/cli/hub/install.py b/guardrails/cli/hub/install.py index ac6da49a2..d0dbd7509 100644 --- a/guardrails/cli/hub/install.py +++ b/guardrails/cli/hub/install.py @@ -194,6 +194,11 @@ def install( help="URI to the package to install.\ Example: hub://guardrails/regex_match." ), + local_models: bool = typer.Option( + None, + "--install-local-models/--no-install-local-models", + help="Install local models", + ), quiet: bool = typer.Option( False, "--quiet", @@ -237,18 +242,22 @@ def do_nothing_context(*args, **kwargs): with loader(dl_deps_msg, spinner="bouncingBar"): install_hub_module(module_manifest, site_packages, quiet=quiet) - try: - if module_manifest.tags and module_manifest.tags.has_guardrails_endpoint: - install_local_models = typer.confirm( - "This validator has a Guardrails AI inference endpoint available. " - "Would you still like to install the local models for local inference?", - ) - else: - install_local_models = typer.confirm( - "Would you like to install the local models?", default=True - ) - except AttributeError: - install_local_models = False + if local_models is True or local_models is False: + install_local_models = local_models + else: + try: + if module_manifest.tags and module_manifest.tags.has_guardrails_endpoint: + install_local_models = typer.confirm( + "This validator has a Guardrails AI inference endpoint available. " + "Would you still like to install the" + " local models for local inference?", + ) + else: + install_local_models = typer.confirm( + "Would you like to install the local models?", default=True + ) + except AttributeError: + install_local_models = False # Post-install if install_local_models: diff --git a/tests/unit_tests/cli/hub/test_install.py b/tests/unit_tests/cli/hub/test_install.py index b5232dd90..50a2b0330 100644 --- a/tests/unit_tests/cli/hub/test_install.py +++ b/tests/unit_tests/cli/hub/test_install.py @@ -126,6 +126,7 @@ def test_happy_path(self, mocker, monkeypatch): msg="Skipping post install, models will not be downloaded for local inference.", # noqa ), # noqa ] + assert mock_logger_log.call_count == 3 mock_logger_log.assert_has_calls(log_calls) diff --git a/tests/unit_tests/cli/test_configure.py b/tests/unit_tests/cli/test_configure.py index 11f189d5f..fc043462f 100644 --- a/tests/unit_tests/cli/test_configure.py +++ b/tests/unit_tests/cli/test_configure.py @@ -56,7 +56,7 @@ def test_configure(mocker, runner, expected_token, enable_metrics, clear_token): assert mock_logger_info.call_count == 2 mock_logger_info.assert_has_calls(expected_calls) mock_save_configuration_file.assert_called_once_with( - expected_token, enable_metrics, False + expected_token, enable_metrics, True )