Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions guardrails/cli/hub/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def install(
"--quiet",
help="Run the command in quiet mode to reduce output verbosity.",
),
upgrade: bool = typer.Option(
False, "--upgrade", help="Upgrade the package to the latest version."
),
):
try:
trace_if_enabled("hub/install")
Expand All @@ -41,6 +44,7 @@ def confirm():
package_uri,
install_local_models=local_models,
quiet=quiet,
upgrade=upgrade,
install_local_models_confirm=confirm,
)
except Exception as e:
Expand Down
7 changes: 6 additions & 1 deletion guardrails/hub/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def install(
package_uri: str,
install_local_models=None,
quiet: bool = True,
upgrade: bool = False,
install_local_models_confirm: Callable = default_local_models_confirm,
) -> ValidatorModuleType:
"""Install a validator package from a hub URI.
Expand Down Expand Up @@ -84,7 +85,11 @@ def install(
dl_deps_msg = "Downloading dependencies"
with loader(dl_deps_msg, spinner="bouncingBar"):
ValidatorPackageService.install_hub_module(
module_manifest, site_packages, quiet=quiet, logger=cli_logger
module_manifest,
site_packages,
quiet=quiet,
upgrade=upgrade,
logger=cli_logger,
)

use_remote_endpoint = False
Expand Down
5 changes: 5 additions & 0 deletions guardrails/hub/validator_package_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def install_hub_module(
module_manifest: Manifest,
site_packages: str,
quiet: bool = False,
upgrade: bool = False,
logger=guardrails_logger,
):
install_url = ValidatorPackageService.get_install_url(module_manifest)
Expand All @@ -268,6 +269,10 @@ def install_hub_module(
)

pip_flags = [f"--target={install_directory}", "--no-deps"]

if upgrade:
pip_flags.append("--upgrade")

if quiet:
pip_flags.append("-q")

Expand Down
21 changes: 21 additions & 0 deletions tests/unit_tests/cli/hub/test_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_install_local_models__false(self, mocker):
"hub://guardrails/test-validator",
install_local_models=False,
quiet=False,
upgrade=False,
install_local_models_confirm=ANY,
)

Expand All @@ -45,6 +46,7 @@ def test_install_local_models__true(self, mocker):
"hub://guardrails/test-validator",
install_local_models=True,
quiet=False,
upgrade=False,
install_local_models_confirm=ANY,
)

Expand All @@ -61,6 +63,7 @@ def test_install_local_models__none(self, mocker):
"hub://guardrails/test-validator",
install_local_models=None,
quiet=False,
upgrade=False,
install_local_models_confirm=ANY,
)

Expand All @@ -77,6 +80,7 @@ def test_install_quiet(self, mocker):
"hub://guardrails/test-validator",
install_local_models=None,
quiet=True,
upgrade=False,
install_local_models_confirm=ANY,
)

Expand Down Expand Up @@ -205,6 +209,23 @@ def test_other_exception(self, mocker):

sys_exit_spy.assert_called_once_with(1)

def test_install_with_upgrade_flag(self, mocker):
mock_install = mocker.patch("guardrails.hub.install.install")
runner = CliRunner()
result = runner.invoke(
hub_command, ["install", "--upgrade", "hub://guardrails/test-validator"]
)

mock_install.assert_called_once_with(
"hub://guardrails/test-validator",
install_local_models=None,
quiet=False,
install_local_models_confirm=ANY,
upgrade=True,
)

assert result.exit_code == 0


def test_get_site_packages_location(mocker):
mock_pip_process = mocker.patch("guardrails.cli.hub.utils.pip_process")
Expand Down
8 changes: 4 additions & 4 deletions tests/unit_tests/hub/test_hub_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_install_local_models__false(self, mocker, use_remote_inferencing):
)

mock_pip_install_hub_module.assert_called_once_with(
self.manifest, self.site_packages, quiet=ANY, logger=ANY
self.manifest, self.site_packages, quiet=ANY, upgrade=ANY, logger=ANY
)
mock_add_to_hub_init.assert_called_once_with(self.manifest, self.site_packages)

Expand Down Expand Up @@ -160,7 +160,7 @@ def test_install_local_models__true(self, mocker, use_remote_inferencing):
)

mock_pip_install_hub_module.assert_called_once_with(
self.manifest, self.site_packages, quiet=ANY, logger=ANY
self.manifest, self.site_packages, quiet=ANY, upgrade=ANY, logger=ANY
)
mock_add_to_hub_init.assert_called_once_with(self.manifest, self.site_packages)

Expand Down Expand Up @@ -221,7 +221,7 @@ def test_install_local_models__none(self, mocker, use_remote_inferencing):
)

mock_pip_install_hub_module.assert_called_once_with(
self.manifest, self.site_packages, quiet=ANY, logger=ANY
self.manifest, self.site_packages, quiet=ANY, upgrade=ANY, logger=ANY
)
mock_add_to_hub_init.assert_called_once_with(self.manifest, self.site_packages)

Expand Down Expand Up @@ -278,7 +278,7 @@ def test_happy_path(self, mocker, use_remote_inferencing):
)

mock_pip_install_hub_module.assert_called_once_with(
self.manifest, self.site_packages, quiet=ANY, logger=ANY
self.manifest, self.site_packages, quiet=ANY, upgrade=ANY, logger=ANY
)
mock_add_to_hub_init.assert_called_once_with(self.manifest, self.site_packages)

Expand Down
Loading