diff --git a/guardrails/cli/hub/install.py b/guardrails/cli/hub/install.py index ea2b49e27..08205c53b 100644 --- a/guardrails/cli/hub/install.py +++ b/guardrails/cli/hub/install.py @@ -3,7 +3,7 @@ import sys from contextlib import contextmanager from string import Template -from typing import List, Literal +from typing import List, Literal, Optional import typer @@ -196,8 +196,8 @@ def install( help="URI to the package to install.\ Example: hub://guardrails/regex_match." ), - local_models: bool = typer.Option( - False, + local_models: Optional[bool] = typer.Option( + None, "--install-local-models/--no-install-local-models", help="Install local models", ), @@ -248,14 +248,19 @@ def do_nothing_context(*args, **kwargs): install_hub_module(module_manifest, site_packages, quiet=quiet) install_local_models = local_models + use_remote_endpoint = False + module_has_endpoint = ( + module_manifest.tags and module_manifest.tags.has_guardrails_endpoint + ) try: if has_rc_file: # if we do want to remote then we don't want to install local models - install_local_models = not Credentials.from_rc_file( - logger - ).use_remote_inferencing - elif module_manifest.tags and module_manifest.tags.has_guardrails_endpoint: + use_remote_endpoint = ( + not Credentials.from_rc_file(logger).use_remote_inferencing + and module_has_endpoint + ) + elif install_local_models is None and module_has_endpoint: install_local_models = typer.confirm( "This validator has a Guardrails AI inference endpoint available. " "Would you still like to install the" @@ -265,7 +270,10 @@ def do_nothing_context(*args, **kwargs): pass # Post-install - if install_local_models: + install_local_models = ( + install_local_models if install_local_models is not None else True + ) + if not use_remote_endpoint and install_local_models is True: logger.log( level=LEVELS.get("SPAM"), # type: ignore msg="Installing models locally!", diff --git a/tests/unit_tests/cli/hub/test_install.py b/tests/unit_tests/cli/hub/test_install.py index 29f516a87..77c2c795d 100644 --- a/tests/unit_tests/cli/hub/test_install.py +++ b/tests/unit_tests/cli/hub/test_install.py @@ -12,7 +12,7 @@ class TestInstall: def test_exits_early_if_uri_is_not_valid(self, mocker): mock_logger_error = mocker.patch("guardrails.cli.hub.install.logger.error") - from guardrails.cli.hub.install import install, sys + from guardrails.cli.hub.install import sys sys_exit_spy = mocker.spy(sys, "exit") @@ -22,7 +22,7 @@ def test_exits_early_if_uri_is_not_valid(self, mocker): mock_logger_error.assert_called_once_with("Invalid URI!") sys_exit_spy.assert_called_once_with(1) - def test_install_local_models(self, mocker, monkeypatch): + def test_install_local_models__false(self, mocker, monkeypatch): mock_logger_log = mocker.patch("guardrails.cli.hub.install.logger.log") mock_get_validator_manifest = mocker.patch( @@ -58,9 +58,12 @@ def test_install_local_models(self, mocker, monkeypatch): monkeypatch.setattr("typer.confirm", lambda prompt, default=True: True) - from guardrails.cli.hub.install import install + runner = CliRunner() - install("hub://guardrails/test-validator", quiet=False) + runner.invoke( + hub_command, + ["install", "hub://guardrails/test-validator", "--no-install-local-models"], + ) log_calls = [ call(level=5, msg="Installing hub://guardrails/test-validator..."), @@ -83,6 +86,129 @@ def test_install_local_models(self, mocker, monkeypatch): mock_add_to_hub_init.assert_called_once_with(manifest, site_packages) + def test_install_local_models__true(self, mocker, monkeypatch): + mock_logger_log = mocker.patch("guardrails.cli.hub.install.logger.log") + + mock_get_validator_manifest = mocker.patch( + "guardrails.cli.hub.install.get_validator_manifest" + ) + manifest = ModuleManifest.from_dict( + { + "id": "id", + "name": "name", + "author": {"name": "me", "email": "me@me.me"}, + "maintainers": [], + "repository": {"url": "some-repo"}, + "namespace": "guardrails", + "package_name": "test-validator", + "module_name": "test_validator", + "exports": ["TestValidator"], + "tags": {"has_guardrails_endpoint": False}, + } + ) + mock_get_validator_manifest.return_value = manifest + + mock_get_site_packages_location = mocker.patch( + "guardrails.cli.hub.install.get_site_packages_location" + ) + site_packages = "./.venv/lib/python3.X/site-packages" + mock_get_site_packages_location.return_value = site_packages + + mocker.patch("guardrails.cli.hub.install.install_hub_module") + + mock_add_to_hub_init = mocker.patch( + "guardrails.cli.hub.install.add_to_hub_inits" + ) + + monkeypatch.setattr("typer.confirm", lambda prompt, default=True: True) + + runner = CliRunner() + + runner.invoke( + hub_command, + ["install", "hub://guardrails/test-validator", "--install-local-models"], + ) + + log_calls = [ + call(level=5, msg="Installing hub://guardrails/test-validator..."), + call( + level=5, + msg="Installing models locally!", + ), + call( + level=5, + msg="✅Successfully installed hub://guardrails/test-validator!\n\nImport validator:\nfrom guardrails.hub import TestValidator\n\nGet more info:\nhttps://hub.guardrailsai.com/validator/id\n", # noqa + ), # noqa + ] + assert mock_logger_log.call_count == 3 + mock_logger_log.assert_has_calls(log_calls) + + mock_get_validator_manifest.assert_called_once_with("guardrails/test-validator") + + assert mock_get_site_packages_location.call_count == 1 + + mock_add_to_hub_init.assert_called_once_with(manifest, site_packages) + + def test_install_local_models__none(self, mocker, monkeypatch): + mock_logger_log = mocker.patch("guardrails.cli.hub.install.logger.log") + + mock_get_validator_manifest = mocker.patch( + "guardrails.cli.hub.install.get_validator_manifest" + ) + manifest = ModuleManifest.from_dict( + { + "id": "id", + "name": "name", + "author": {"name": "me", "email": "me@me.me"}, + "maintainers": [], + "repository": {"url": "some-repo"}, + "namespace": "guardrails", + "package_name": "test-validator", + "module_name": "test_validator", + "exports": ["TestValidator"], + "tags": {"has_guardrails_endpoint": False}, + } + ) + mock_get_validator_manifest.return_value = manifest + + mock_get_site_packages_location = mocker.patch( + "guardrails.cli.hub.install.get_site_packages_location" + ) + site_packages = "./.venv/lib/python3.X/site-packages" + mock_get_site_packages_location.return_value = site_packages + + mocker.patch("guardrails.cli.hub.install.install_hub_module") + + mock_add_to_hub_init = mocker.patch( + "guardrails.cli.hub.install.add_to_hub_inits" + ) + + monkeypatch.setattr("typer.confirm", lambda prompt, default=True: True) + + runner = CliRunner() + + runner.invoke(hub_command, ["install", "hub://guardrails/test-validator"]) + + log_calls = [ + call(level=5, msg="Installing hub://guardrails/test-validator..."), + call( + level=5, + msg="Installing models locally!", + ), + call( + level=5, + msg="✅Successfully installed hub://guardrails/test-validator!\n\nImport validator:\nfrom guardrails.hub import TestValidator\n\nGet more info:\nhttps://hub.guardrailsai.com/validator/id\n", # noqa + ), # noqa + ] + assert mock_logger_log.call_count == 3 + mock_logger_log.assert_has_calls(log_calls) + + mock_get_validator_manifest.assert_called_once_with("guardrails/test-validator") + + assert mock_get_site_packages_location.call_count == 1 + + mock_add_to_hub_init.assert_called_once_with(manifest, site_packages) + def test_happy_path(self, mocker, monkeypatch): mock_logger_log = mocker.patch("guardrails.cli.hub.install.logger.log") @@ -115,15 +241,15 @@ def test_happy_path(self, mocker, monkeypatch): mocker.patch("guardrails.cli.hub.install.run_post_install") mocker.patch("guardrails.cli.hub.install.add_to_hub_inits") - monkeypatch.setattr("typer.confirm", lambda _: False) + runner = CliRunner() - install("hub://guardrails/test-validator", quiet=False) + runner.invoke(hub_command, ["install", "hub://guardrails/test-validator"]) log_calls = [ call(level=5, msg="Installing hub://guardrails/test-validator..."), call( level=5, - msg="Skipping post install, models will not be downloaded for local inference.", # noqa + msg="Installing models locally!", # noqa ), # noqa ]