Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions guardrails/cli/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,18 @@ def configure(
help="Opt out of anonymous metrics collection.",
prompt="Enable anonymous metrics reporting?",
),
token: Optional[str] = typer.Option(
None,
"--token",
help="API Key for Guardrails. If not provided, you will be prompted for it.",
),
remote_inferencing: Optional[bool] = typer.Option(
DEFAULT_USE_REMOTE_INFERENCING,
"--enable-remote-inferencing/--disable-remote-inferencing",
help="Opt in to remote inferencing. "
"If not provided, you will be prompted for it.",
prompt="Do you wish to use remote inferencing?",
),
clear_token: Optional[bool] = typer.Option(
False,
"--clear-token",
Expand All @@ -68,7 +80,7 @@ def configure(
existing_token = _get_default_token()
last4 = existing_token[-4:] if existing_token else ""

if not clear_token:
if not clear_token and token is None:
console.print("\nEnter API Key below", style="bold", end=" ")

if last4:
Expand All @@ -86,21 +98,10 @@ def configure(
token = typer.prompt("\nAPI Key", existing_token, show_default=False)

else:
token = DEFAULT_TOKEN

# Ask about remote inferencing
use_remote_inferencing = (
typer.prompt(
"Do you wish to use remote inferencing? (Y/N)",
type=str,
default="Y",
show_default=False,
).lower()
== "y"
)
token = token or DEFAULT_TOKEN

try:
save_configuration_file(token, enable_metrics, use_remote_inferencing)
save_configuration_file(token, enable_metrics, remote_inferencing)
logger.info("Configuration saved.")

if not token:
Expand Down
33 changes: 21 additions & 12 deletions guardrails/cli/hub/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ def install(
help="URI to the package to install.\
Example: hub://guardrails/regex_match."
),
local_models: bool = typer.Option(
None,
"--install-local-models/--no-install-local-models",
help="Install local models",
),
quiet: bool = typer.Option(
False,
"--quiet",
Expand Down Expand Up @@ -237,18 +242,22 @@ def do_nothing_context(*args, **kwargs):
with loader(dl_deps_msg, spinner="bouncingBar"):
install_hub_module(module_manifest, site_packages, quiet=quiet)

try:
if module_manifest.tags and module_manifest.tags.has_guardrails_endpoint:
install_local_models = typer.confirm(
"This validator has a Guardrails AI inference endpoint available. "
"Would you still like to install the local models for local inference?",
)
else:
install_local_models = typer.confirm(
"Would you like to install the local models?", default=True
)
except AttributeError:
install_local_models = False
if local_models is True or local_models is False:
install_local_models = local_models
else:
try:
if module_manifest.tags and module_manifest.tags.has_guardrails_endpoint:
install_local_models = typer.confirm(
"This validator has a Guardrails AI inference endpoint available. "
"Would you still like to install the"
" local models for local inference?",
)
else:
install_local_models = typer.confirm(
"Would you like to install the local models?", default=True
)
except AttributeError:
install_local_models = False

# Post-install
if install_local_models:
Expand Down
1 change: 1 addition & 0 deletions tests/unit_tests/cli/hub/test_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def test_happy_path(self, mocker, monkeypatch):
msg="Skipping post install, models will not be downloaded for local inference.", # noqa
), # noqa
]

assert mock_logger_log.call_count == 3
mock_logger_log.assert_has_calls(log_calls)

Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/cli/test_configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_configure(mocker, runner, expected_token, enable_metrics, clear_token):
assert mock_logger_info.call_count == 2
mock_logger_info.assert_has_calls(expected_calls)
mock_save_configuration_file.assert_called_once_with(
expected_token, enable_metrics, False
expected_token, enable_metrics, True
)


Expand Down