diff --git a/guardrails/cli/hub/install.py b/guardrails/cli/hub/install.py index 3ad9a07d1..b105517a2 100644 --- a/guardrails/cli/hub/install.py +++ b/guardrails/cli/hub/install.py @@ -25,6 +25,9 @@ def install( "--quiet", help="Run the command in quiet mode to reduce output verbosity.", ), + upgrade: bool = typer.Option( + False, "--upgrade", help="Upgrade the package to the latest version." + ), ): try: trace_if_enabled("hub/install") @@ -41,6 +44,7 @@ def confirm(): package_uri, install_local_models=local_models, quiet=quiet, + upgrade=upgrade, install_local_models_confirm=confirm, ) except Exception as e: diff --git a/guardrails/hub/install.py b/guardrails/hub/install.py index 6ce9e16cf..005a5306d 100644 --- a/guardrails/hub/install.py +++ b/guardrails/hub/install.py @@ -35,6 +35,7 @@ def install( package_uri: str, install_local_models=None, quiet: bool = True, + upgrade: bool = False, install_local_models_confirm: Callable = default_local_models_confirm, ) -> ValidatorModuleType: """Install a validator package from a hub URI. @@ -84,7 +85,11 @@ def install( dl_deps_msg = "Downloading dependencies" with loader(dl_deps_msg, spinner="bouncingBar"): ValidatorPackageService.install_hub_module( - module_manifest, site_packages, quiet=quiet, logger=cli_logger + module_manifest, + site_packages, + quiet=quiet, + upgrade=upgrade, + logger=cli_logger, ) use_remote_endpoint = False diff --git a/guardrails/hub/validator_package_service.py b/guardrails/hub/validator_package_service.py index 526b9cb5b..e90825d27 100644 --- a/guardrails/hub/validator_package_service.py +++ b/guardrails/hub/validator_package_service.py @@ -260,6 +260,7 @@ def install_hub_module( module_manifest: Manifest, site_packages: str, quiet: bool = False, + upgrade: bool = False, logger=guardrails_logger, ): install_url = ValidatorPackageService.get_install_url(module_manifest) @@ -268,6 +269,10 @@ def install_hub_module( ) pip_flags = [f"--target={install_directory}", "--no-deps"] + + if upgrade: + pip_flags.append("--upgrade") + if quiet: pip_flags.append("-q") diff --git a/tests/unit_tests/cli/hub/test_install.py b/tests/unit_tests/cli/hub/test_install.py index 21eb5abe9..6fd186322 100644 --- a/tests/unit_tests/cli/hub/test_install.py +++ b/tests/unit_tests/cli/hub/test_install.py @@ -29,6 +29,7 @@ def test_install_local_models__false(self, mocker): "hub://guardrails/test-validator", install_local_models=False, quiet=False, + upgrade=False, install_local_models_confirm=ANY, ) @@ -45,6 +46,7 @@ def test_install_local_models__true(self, mocker): "hub://guardrails/test-validator", install_local_models=True, quiet=False, + upgrade=False, install_local_models_confirm=ANY, ) @@ -61,6 +63,7 @@ def test_install_local_models__none(self, mocker): "hub://guardrails/test-validator", install_local_models=None, quiet=False, + upgrade=False, install_local_models_confirm=ANY, ) @@ -77,6 +80,7 @@ def test_install_quiet(self, mocker): "hub://guardrails/test-validator", install_local_models=None, quiet=True, + upgrade=False, install_local_models_confirm=ANY, ) @@ -205,6 +209,23 @@ def test_other_exception(self, mocker): sys_exit_spy.assert_called_once_with(1) + def test_install_with_upgrade_flag(self, mocker): + mock_install = mocker.patch("guardrails.hub.install.install") + runner = CliRunner() + result = runner.invoke( + hub_command, ["install", "--upgrade", "hub://guardrails/test-validator"] + ) + + mock_install.assert_called_once_with( + "hub://guardrails/test-validator", + install_local_models=None, + quiet=False, + install_local_models_confirm=ANY, + upgrade=True, + ) + + assert result.exit_code == 0 + def test_get_site_packages_location(mocker): mock_pip_process = mocker.patch("guardrails.cli.hub.utils.pip_process") diff --git a/tests/unit_tests/hub/test_hub_install.py b/tests/unit_tests/hub/test_hub_install.py index 8677c9b54..4b35feec9 100644 --- a/tests/unit_tests/hub/test_hub_install.py +++ b/tests/unit_tests/hub/test_hub_install.py @@ -99,7 +99,7 @@ def test_install_local_models__false(self, mocker, use_remote_inferencing): ) mock_pip_install_hub_module.assert_called_once_with( - self.manifest, self.site_packages, quiet=ANY, logger=ANY + self.manifest, self.site_packages, quiet=ANY, upgrade=ANY, logger=ANY ) mock_add_to_hub_init.assert_called_once_with(self.manifest, self.site_packages) @@ -160,7 +160,7 @@ def test_install_local_models__true(self, mocker, use_remote_inferencing): ) mock_pip_install_hub_module.assert_called_once_with( - self.manifest, self.site_packages, quiet=ANY, logger=ANY + self.manifest, self.site_packages, quiet=ANY, upgrade=ANY, logger=ANY ) mock_add_to_hub_init.assert_called_once_with(self.manifest, self.site_packages) @@ -221,7 +221,7 @@ def test_install_local_models__none(self, mocker, use_remote_inferencing): ) mock_pip_install_hub_module.assert_called_once_with( - self.manifest, self.site_packages, quiet=ANY, logger=ANY + self.manifest, self.site_packages, quiet=ANY, upgrade=ANY, logger=ANY ) mock_add_to_hub_init.assert_called_once_with(self.manifest, self.site_packages) @@ -278,7 +278,7 @@ def test_happy_path(self, mocker, use_remote_inferencing): ) mock_pip_install_hub_module.assert_called_once_with( - self.manifest, self.site_packages, quiet=ANY, logger=ANY + self.manifest, self.site_packages, quiet=ANY, upgrade=ANY, logger=ANY ) mock_add_to_hub_init.assert_called_once_with(self.manifest, self.site_packages)