diff --git a/guardrails/cli/hub/install.py b/guardrails/cli/hub/install.py index 3ad9a07d1..047f5dae9 100644 --- a/guardrails/cli/hub/install.py +++ b/guardrails/cli/hub/install.py @@ -1,5 +1,5 @@ import sys -from typing import Optional +from typing import Optional, List import typer @@ -10,9 +10,9 @@ @hub_command.command() def install( - package_uri: str = typer.Argument( - help="URI to the package to install.\ -Example: hub://guardrails/regex_match." + package_uris: List[str] = typer.Argument( + ..., + help="URIs to the packages to install. Example: hub://guardrails/regex_match hub://guardrails/toxic_language", ), local_models: Optional[bool] = typer.Option( None, @@ -28,7 +28,7 @@ def install( ): try: trace_if_enabled("hub/install") - from guardrails.hub.install import install + from guardrails.hub.install import install_multiple def confirm(): return typer.confirm( @@ -37,8 +37,8 @@ def confirm(): " local models for local inference?", ) - install( - package_uri, + install_multiple( + package_uris, install_local_models=local_models, quiet=quiet, install_local_models_confirm=confirm, diff --git a/guardrails/hub/install.py b/guardrails/hub/install.py index 6ce9e16cf..ce7d96ab1 100644 --- a/guardrails/hub/install.py +++ b/guardrails/hub/install.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from string import Template -from typing import Callable, cast +from typing import Callable, cast, List from guardrails.hub.validator_package_service import ( ValidatorPackageService, @@ -164,3 +164,35 @@ def install( installed_module.__validator_exports__ = module_manifest.exports return installed_module + + +def install_multiple( + package_uris: List[str], + install_local_models=None, + quiet: bool = True, + install_local_models_confirm: Callable = default_local_models_confirm, +) -> List[ValidatorModuleType]: + """Install multiple validator packages from hub URIs. + + Args: + package_uris (List[str]): List of URIs of the packages to install. + install_local_models (bool): Whether to install local models or not. + quiet (bool): Whether to suppress output or not. + install_local_models_confirm (Callable): A function to confirm the + installation of local models. + + Returns: + List[ValidatorModuleType]: List of installed validator modules. + """ + installed_modules = [] + + for package_uri in package_uris: + installed_module = install( + package_uri, + install_local_models=install_local_models, + quiet=quiet, + install_local_models_confirm=install_local_models_confirm, + ) + installed_modules.append(installed_module) + + return installed_modules diff --git a/tests/unit_tests/cli/hub/test_install.py b/tests/unit_tests/cli/hub/test_install.py old mode 100644 new mode 100755 index 21eb5abe9..cd6263ca0 --- a/tests/unit_tests/cli/hub/test_install.py +++ b/tests/unit_tests/cli/hub/test_install.py @@ -82,6 +82,50 @@ def test_install_quiet(self, mocker): assert result.exit_code == 0 + def test_install_multiple_validators(self, mocker): + mock_install_multiple = mocker.patch("guardrails.hub.install.install_multiple") + runner = CliRunner() + result = runner.invoke( + hub_command, + [ + "install", + "hub://guardrails/validator1", + "hub://guardrails/validator2", + "--no-install-local-models", + ], + ) + + mock_install_multiple.assert_called_once_with( + ["hub://guardrails/validator1", "hub://guardrails/validator2"], + install_local_models=False, + quiet=False, + install_local_models_confirm=ANY, + ) + + assert result.exit_code == 0 + + def test_install_multiple_validators_with_quiet(self, mocker): + mock_install_multiple = mocker.patch("guardrails.hub.install.install_multiple") + runner = CliRunner() + result = runner.invoke( + hub_command, + [ + "install", + "hub://guardrails/validator1", + "hub://guardrails/validator2", + "--quiet", + ], + ) + + mock_install_multiple.assert_called_once_with( + ["hub://guardrails/validator1", "hub://guardrails/validator2"], + install_local_models=None, + quiet=True, + install_local_models_confirm=ANY, + ) + + assert result.exit_code == 0 + class TestPipProcess: def test_no_package_string_format(self, mocker): @@ -208,7 +252,7 @@ def test_other_exception(self, mocker): def test_get_site_packages_location(mocker): mock_pip_process = mocker.patch("guardrails.cli.hub.utils.pip_process") - mock_pip_process.return_value = {"Location": "/site-pacakges"} + mock_pip_process.return_value = {"Location": "/site-packages"} from guardrails.cli.hub.utils import get_site_packages_location @@ -216,4 +260,4 @@ def test_get_site_packages_location(mocker): mock_pip_process.assert_called_once_with("show", "pip", format="json") - assert response == "/site-pacakges" + assert response == "/site-packages"