From bd664f3f0ad1ecf368aaa29ff5911f1af456e708 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Fri, 26 Jul 2024 12:46:53 -0500 Subject: [PATCH 1/2] only skip local models if explicitly told to do so, or if remote inference is enabled and available --- guardrails/cli/hub/install.py | 21 ++-- tests/unit_tests/cli/hub/test_install.py | 126 ++++++++++++++++++++++- 2 files changed, 136 insertions(+), 11 deletions(-) diff --git a/guardrails/cli/hub/install.py b/guardrails/cli/hub/install.py index ea2b49e27..c1825ddb6 100644 --- a/guardrails/cli/hub/install.py +++ b/guardrails/cli/hub/install.py @@ -3,7 +3,7 @@ import sys from contextlib import contextmanager from string import Template -from typing import List, Literal +from typing import List, Literal, Optional import typer @@ -196,8 +196,8 @@ def install( help="URI to the package to install.\ Example: hub://guardrails/regex_match." ), - local_models: bool = typer.Option( - False, + local_models: Optional[bool] = typer.Option( + None, "--install-local-models/--no-install-local-models", help="Install local models", ), @@ -248,14 +248,19 @@ def do_nothing_context(*args, **kwargs): install_hub_module(module_manifest, site_packages, quiet=quiet) install_local_models = local_models + use_remote_endpoint = False + module_has_endpoint = ( + module_manifest.tags and module_manifest.tags.has_guardrails_endpoint + ) try: if has_rc_file: # if we do want to remote then we don't want to install local models - install_local_models = not Credentials.from_rc_file( - logger - ).use_remote_inferencing - elif module_manifest.tags and module_manifest.tags.has_guardrails_endpoint: + use_remote_endpoint = ( + not Credentials.from_rc_file(logger).use_remote_inferencing + and module_has_endpoint + ) + elif install_local_models is None and module_has_endpoint: install_local_models = typer.confirm( "This validator has a Guardrails AI inference endpoint available. " "Would you still like to install the" @@ -265,7 +270,7 @@ def do_nothing_context(*args, **kwargs): pass # Post-install - if install_local_models: + if not use_remote_endpoint and install_local_models is not False: logger.log( level=LEVELS.get("SPAM"), # type: ignore msg="Installing models locally!", diff --git a/tests/unit_tests/cli/hub/test_install.py b/tests/unit_tests/cli/hub/test_install.py index 29f516a87..86667dba7 100644 --- a/tests/unit_tests/cli/hub/test_install.py +++ b/tests/unit_tests/cli/hub/test_install.py @@ -22,7 +22,7 @@ def test_exits_early_if_uri_is_not_valid(self, mocker): mock_logger_error.assert_called_once_with("Invalid URI!") sys_exit_spy.assert_called_once_with(1) - def test_install_local_models(self, mocker, monkeypatch): + def test_install_local_models__false(self, mocker, monkeypatch): mock_logger_log = mocker.patch("guardrails.cli.hub.install.logger.log") mock_get_validator_manifest = mocker.patch( @@ -60,7 +60,7 @@ def test_install_local_models(self, mocker, monkeypatch): from guardrails.cli.hub.install import install - install("hub://guardrails/test-validator", quiet=False) + install("hub://guardrails/test-validator", quiet=False, local_models=False) log_calls = [ call(level=5, msg="Installing hub://guardrails/test-validator..."), @@ -83,6 +83,126 @@ def test_install_local_models(self, mocker, monkeypatch): mock_add_to_hub_init.assert_called_once_with(manifest, site_packages) + def test_install_local_models__true(self, mocker, monkeypatch): + mock_logger_log = mocker.patch("guardrails.cli.hub.install.logger.log") + + mock_get_validator_manifest = mocker.patch( + "guardrails.cli.hub.install.get_validator_manifest" + ) + manifest = ModuleManifest.from_dict( + { + "id": "id", + "name": "name", + "author": {"name": "me", "email": "me@me.me"}, + "maintainers": [], + "repository": {"url": "some-repo"}, + "namespace": "guardrails", + "package_name": "test-validator", + "module_name": "test_validator", + "exports": ["TestValidator"], + "tags": {"has_guardrails_endpoint": False}, + } + ) + mock_get_validator_manifest.return_value = manifest + + mock_get_site_packages_location = mocker.patch( + "guardrails.cli.hub.install.get_site_packages_location" + ) + site_packages = "./.venv/lib/python3.X/site-packages" + mock_get_site_packages_location.return_value = site_packages + + mocker.patch("guardrails.cli.hub.install.install_hub_module") + + mock_add_to_hub_init = mocker.patch( + "guardrails.cli.hub.install.add_to_hub_inits" + ) + + monkeypatch.setattr("typer.confirm", lambda prompt, default=True: True) + + from guardrails.cli.hub.install import install + + install("hub://guardrails/test-validator", quiet=False, local_models=True) + + log_calls = [ + call(level=5, msg="Installing hub://guardrails/test-validator..."), + call( + level=5, + msg="Installing models locally!", + ), + call( + level=5, + msg="✅Successfully installed hub://guardrails/test-validator!\n\nImport validator:\nfrom guardrails.hub import TestValidator\n\nGet more info:\nhttps://hub.guardrailsai.com/validator/id\n", # noqa + ), # noqa + ] + assert mock_logger_log.call_count == 3 + mock_logger_log.assert_has_calls(log_calls) + + mock_get_validator_manifest.assert_called_once_with("guardrails/test-validator") + + assert mock_get_site_packages_location.call_count == 1 + + mock_add_to_hub_init.assert_called_once_with(manifest, site_packages) + + def test_install_local_models__none(self, mocker, monkeypatch): + mock_logger_log = mocker.patch("guardrails.cli.hub.install.logger.log") + + mock_get_validator_manifest = mocker.patch( + "guardrails.cli.hub.install.get_validator_manifest" + ) + manifest = ModuleManifest.from_dict( + { + "id": "id", + "name": "name", + "author": {"name": "me", "email": "me@me.me"}, + "maintainers": [], + "repository": {"url": "some-repo"}, + "namespace": "guardrails", + "package_name": "test-validator", + "module_name": "test_validator", + "exports": ["TestValidator"], + "tags": {"has_guardrails_endpoint": False}, + } + ) + mock_get_validator_manifest.return_value = manifest + + mock_get_site_packages_location = mocker.patch( + "guardrails.cli.hub.install.get_site_packages_location" + ) + site_packages = "./.venv/lib/python3.X/site-packages" + mock_get_site_packages_location.return_value = site_packages + + mocker.patch("guardrails.cli.hub.install.install_hub_module") + + mock_add_to_hub_init = mocker.patch( + "guardrails.cli.hub.install.add_to_hub_inits" + ) + + monkeypatch.setattr("typer.confirm", lambda prompt, default=True: True) + + from guardrails.cli.hub.install import install + + install("hub://guardrails/test-validator", quiet=False) + + log_calls = [ + call(level=5, msg="Installing hub://guardrails/test-validator..."), + call( + level=5, + msg="Installing models locally!", + ), + call( + level=5, + msg="✅Successfully installed hub://guardrails/test-validator!\n\nImport validator:\nfrom guardrails.hub import TestValidator\n\nGet more info:\nhttps://hub.guardrailsai.com/validator/id\n", # noqa + ), # noqa + ] + assert mock_logger_log.call_count == 3 + mock_logger_log.assert_has_calls(log_calls) + + mock_get_validator_manifest.assert_called_once_with("guardrails/test-validator") + + assert mock_get_site_packages_location.call_count == 1 + + mock_add_to_hub_init.assert_called_once_with(manifest, site_packages) + def test_happy_path(self, mocker, monkeypatch): mock_logger_log = mocker.patch("guardrails.cli.hub.install.logger.log") @@ -123,7 +243,7 @@ def test_happy_path(self, mocker, monkeypatch): call(level=5, msg="Installing hub://guardrails/test-validator..."), call( level=5, - msg="Skipping post install, models will not be downloaded for local inference.", # noqa + msg="Installing models locally!", # noqa ), # noqa ] From 3871cc21acd59b09933a1b00733d5dddf169916e Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Fri, 26 Jul 2024 13:30:59 -0500 Subject: [PATCH 2/2] default install flag to true if is none after other logic is applied --- guardrails/cli/hub/install.py | 5 ++++- tests/unit_tests/cli/hub/test_install.py | 24 +++++++++++++++--------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/guardrails/cli/hub/install.py b/guardrails/cli/hub/install.py index c1825ddb6..08205c53b 100644 --- a/guardrails/cli/hub/install.py +++ b/guardrails/cli/hub/install.py @@ -270,7 +270,10 @@ def do_nothing_context(*args, **kwargs): pass # Post-install - if not use_remote_endpoint and install_local_models is not False: + install_local_models = ( + install_local_models if install_local_models is not None else True + ) + if not use_remote_endpoint and install_local_models is True: logger.log( level=LEVELS.get("SPAM"), # type: ignore msg="Installing models locally!", diff --git a/tests/unit_tests/cli/hub/test_install.py b/tests/unit_tests/cli/hub/test_install.py index 86667dba7..77c2c795d 100644 --- a/tests/unit_tests/cli/hub/test_install.py +++ b/tests/unit_tests/cli/hub/test_install.py @@ -12,7 +12,7 @@ class TestInstall: def test_exits_early_if_uri_is_not_valid(self, mocker): mock_logger_error = mocker.patch("guardrails.cli.hub.install.logger.error") - from guardrails.cli.hub.install import install, sys + from guardrails.cli.hub.install import sys sys_exit_spy = mocker.spy(sys, "exit") @@ -58,9 +58,12 @@ def test_install_local_models__false(self, mocker, monkeypatch): monkeypatch.setattr("typer.confirm", lambda prompt, default=True: True) - from guardrails.cli.hub.install import install + runner = CliRunner() - install("hub://guardrails/test-validator", quiet=False, local_models=False) + runner.invoke( + hub_command, + ["install", "hub://guardrails/test-validator", "--no-install-local-models"], + ) log_calls = [ call(level=5, msg="Installing hub://guardrails/test-validator..."), @@ -119,9 +122,12 @@ def test_install_local_models__true(self, mocker, monkeypatch): monkeypatch.setattr("typer.confirm", lambda prompt, default=True: True) - from guardrails.cli.hub.install import install + runner = CliRunner() - install("hub://guardrails/test-validator", quiet=False, local_models=True) + runner.invoke( + hub_command, + ["install", "hub://guardrails/test-validator", "--install-local-models"], + ) log_calls = [ call(level=5, msg="Installing hub://guardrails/test-validator..."), @@ -179,9 +185,9 @@ def test_install_local_models__none(self, mocker, monkeypatch): monkeypatch.setattr("typer.confirm", lambda prompt, default=True: True) - from guardrails.cli.hub.install import install + runner = CliRunner() - install("hub://guardrails/test-validator", quiet=False) + runner.invoke(hub_command, ["install", "hub://guardrails/test-validator"]) log_calls = [ call(level=5, msg="Installing hub://guardrails/test-validator..."), @@ -235,9 +241,9 @@ def test_happy_path(self, mocker, monkeypatch): mocker.patch("guardrails.cli.hub.install.run_post_install") mocker.patch("guardrails.cli.hub.install.add_to_hub_inits") - monkeypatch.setattr("typer.confirm", lambda _: False) + runner = CliRunner() - install("hub://guardrails/test-validator", quiet=False) + runner.invoke(hub_command, ["install", "hub://guardrails/test-validator"]) log_calls = [ call(level=5, msg="Installing hub://guardrails/test-validator..."),