Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions guardrails/cli/hub/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
from contextlib import contextmanager
from string import Template
from typing import List, Literal
from typing import List, Literal, Optional

import typer

Expand Down Expand Up @@ -196,8 +196,8 @@ def install(
help="URI to the package to install.\
Example: hub://guardrails/regex_match."
),
local_models: bool = typer.Option(
False,
local_models: Optional[bool] = typer.Option(
None,
"--install-local-models/--no-install-local-models",
help="Install local models",
),
Expand Down Expand Up @@ -248,14 +248,19 @@ def do_nothing_context(*args, **kwargs):
install_hub_module(module_manifest, site_packages, quiet=quiet)

install_local_models = local_models
use_remote_endpoint = False
module_has_endpoint = (
module_manifest.tags and module_manifest.tags.has_guardrails_endpoint
)

try:
if has_rc_file:
# if we do want to remote then we don't want to install local models
install_local_models = not Credentials.from_rc_file(
logger
).use_remote_inferencing
elif module_manifest.tags and module_manifest.tags.has_guardrails_endpoint:
use_remote_endpoint = (
not Credentials.from_rc_file(logger).use_remote_inferencing
and module_has_endpoint
)
elif install_local_models is None and module_has_endpoint:
install_local_models = typer.confirm(
"This validator has a Guardrails AI inference endpoint available. "
"Would you still like to install the"
Expand All @@ -265,7 +270,10 @@ def do_nothing_context(*args, **kwargs):
pass

# Post-install
if install_local_models:
install_local_models = (
install_local_models if install_local_models is not None else True
)
if not use_remote_endpoint and install_local_models is True:
logger.log(
level=LEVELS.get("SPAM"), # type: ignore
msg="Installing models locally!",
Expand Down
140 changes: 133 additions & 7 deletions tests/unit_tests/cli/hub/test_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class TestInstall:
def test_exits_early_if_uri_is_not_valid(self, mocker):
mock_logger_error = mocker.patch("guardrails.cli.hub.install.logger.error")

from guardrails.cli.hub.install import install, sys
from guardrails.cli.hub.install import sys

sys_exit_spy = mocker.spy(sys, "exit")

Expand All @@ -22,7 +22,7 @@ def test_exits_early_if_uri_is_not_valid(self, mocker):
mock_logger_error.assert_called_once_with("Invalid URI!")
sys_exit_spy.assert_called_once_with(1)

def test_install_local_models(self, mocker, monkeypatch):
def test_install_local_models__false(self, mocker, monkeypatch):
mock_logger_log = mocker.patch("guardrails.cli.hub.install.logger.log")

mock_get_validator_manifest = mocker.patch(
Expand Down Expand Up @@ -58,9 +58,12 @@ def test_install_local_models(self, mocker, monkeypatch):

monkeypatch.setattr("typer.confirm", lambda prompt, default=True: True)

from guardrails.cli.hub.install import install
runner = CliRunner()

install("hub://guardrails/test-validator", quiet=False)
runner.invoke(
hub_command,
["install", "hub://guardrails/test-validator", "--no-install-local-models"],
)

log_calls = [
call(level=5, msg="Installing hub://guardrails/test-validator..."),
Expand All @@ -83,6 +86,129 @@ def test_install_local_models(self, mocker, monkeypatch):

mock_add_to_hub_init.assert_called_once_with(manifest, site_packages)

def test_install_local_models__true(self, mocker, monkeypatch):
mock_logger_log = mocker.patch("guardrails.cli.hub.install.logger.log")

mock_get_validator_manifest = mocker.patch(
"guardrails.cli.hub.install.get_validator_manifest"
)
manifest = ModuleManifest.from_dict(
{
"id": "id",
"name": "name",
"author": {"name": "me", "email": "me@me.me"},
"maintainers": [],
"repository": {"url": "some-repo"},
"namespace": "guardrails",
"package_name": "test-validator",
"module_name": "test_validator",
"exports": ["TestValidator"],
"tags": {"has_guardrails_endpoint": False},
}
)
mock_get_validator_manifest.return_value = manifest

mock_get_site_packages_location = mocker.patch(
"guardrails.cli.hub.install.get_site_packages_location"
)
site_packages = "./.venv/lib/python3.X/site-packages"
mock_get_site_packages_location.return_value = site_packages

mocker.patch("guardrails.cli.hub.install.install_hub_module")

mock_add_to_hub_init = mocker.patch(
"guardrails.cli.hub.install.add_to_hub_inits"
)

monkeypatch.setattr("typer.confirm", lambda prompt, default=True: True)

runner = CliRunner()

runner.invoke(
hub_command,
["install", "hub://guardrails/test-validator", "--install-local-models"],
)

log_calls = [
call(level=5, msg="Installing hub://guardrails/test-validator..."),
call(
level=5,
msg="Installing models locally!",
),
call(
level=5,
msg="✅Successfully installed hub://guardrails/test-validator!\n\nImport validator:\nfrom guardrails.hub import TestValidator\n\nGet more info:\nhttps://hub.guardrailsai.com/validator/id\n", # noqa
), # noqa
]
assert mock_logger_log.call_count == 3
mock_logger_log.assert_has_calls(log_calls)

mock_get_validator_manifest.assert_called_once_with("guardrails/test-validator")

assert mock_get_site_packages_location.call_count == 1

mock_add_to_hub_init.assert_called_once_with(manifest, site_packages)

def test_install_local_models__none(self, mocker, monkeypatch):
mock_logger_log = mocker.patch("guardrails.cli.hub.install.logger.log")

mock_get_validator_manifest = mocker.patch(
"guardrails.cli.hub.install.get_validator_manifest"
)
manifest = ModuleManifest.from_dict(
{
"id": "id",
"name": "name",
"author": {"name": "me", "email": "me@me.me"},
"maintainers": [],
"repository": {"url": "some-repo"},
"namespace": "guardrails",
"package_name": "test-validator",
"module_name": "test_validator",
"exports": ["TestValidator"],
"tags": {"has_guardrails_endpoint": False},
}
)
mock_get_validator_manifest.return_value = manifest

mock_get_site_packages_location = mocker.patch(
"guardrails.cli.hub.install.get_site_packages_location"
)
site_packages = "./.venv/lib/python3.X/site-packages"
mock_get_site_packages_location.return_value = site_packages

mocker.patch("guardrails.cli.hub.install.install_hub_module")

mock_add_to_hub_init = mocker.patch(
"guardrails.cli.hub.install.add_to_hub_inits"
)

monkeypatch.setattr("typer.confirm", lambda prompt, default=True: True)

runner = CliRunner()

runner.invoke(hub_command, ["install", "hub://guardrails/test-validator"])

log_calls = [
call(level=5, msg="Installing hub://guardrails/test-validator..."),
call(
level=5,
msg="Installing models locally!",
),
call(
level=5,
msg="✅Successfully installed hub://guardrails/test-validator!\n\nImport validator:\nfrom guardrails.hub import TestValidator\n\nGet more info:\nhttps://hub.guardrailsai.com/validator/id\n", # noqa
), # noqa
]
assert mock_logger_log.call_count == 3
mock_logger_log.assert_has_calls(log_calls)

mock_get_validator_manifest.assert_called_once_with("guardrails/test-validator")

assert mock_get_site_packages_location.call_count == 1

mock_add_to_hub_init.assert_called_once_with(manifest, site_packages)

def test_happy_path(self, mocker, monkeypatch):
mock_logger_log = mocker.patch("guardrails.cli.hub.install.logger.log")

Expand Down Expand Up @@ -115,15 +241,15 @@ def test_happy_path(self, mocker, monkeypatch):
mocker.patch("guardrails.cli.hub.install.run_post_install")
mocker.patch("guardrails.cli.hub.install.add_to_hub_inits")

monkeypatch.setattr("typer.confirm", lambda _: False)
runner = CliRunner()

install("hub://guardrails/test-validator", quiet=False)
runner.invoke(hub_command, ["install", "hub://guardrails/test-validator"])

log_calls = [
call(level=5, msg="Installing hub://guardrails/test-validator..."),
call(
level=5,
msg="Skipping post install, models will not be downloaded for local inference.", # noqa
msg="Installing models locally!", # noqa
), # noqa
]

Expand Down