diff --git a/guardrails/cli/__init__.py b/guardrails/cli/__init__.py index 4198e7db4..833ada8b4 100644 --- a/guardrails/cli/__init__.py +++ b/guardrails/cli/__init__.py @@ -1,10 +1,11 @@ -import guardrails.cli.compile # noqa import guardrails.cli.configure # noqa import guardrails.cli.validate # noqa from guardrails.cli.guardrails import guardrails as cli from guardrails.cli.hub import hub -cli.add_typer(hub, name="hub") +cli.add_typer( + hub, name="hub", help="Manage validators installed from the Guardrails Hub." +) if __name__ == "__main__": diff --git a/guardrails/cli/compile.py b/guardrails/cli/compile.py deleted file mode 100644 index d759a60d8..000000000 --- a/guardrails/cli/compile.py +++ /dev/null @@ -1,25 +0,0 @@ -import typer - -from guardrails.cli.guardrails import guardrails -from guardrails.cli.logger import logger - - -def compile_rail(rail: str, out: str) -> None: - """Compile guardrails from the guardrails.yml file.""" - raise NotImplementedError("Currently compiling rail is not supported.") - - -@guardrails.command() -def compile( - rail: str = typer.Argument( - ..., help="Path to the rail spec.", exists=True, file_okay=True, dir_okay=False - ), - out: str = typer.Option( - default=".rail_output", - help="Path to the compiled output directory.", - file_okay=False, - dir_okay=True, - ), -): - """Compile guardrails from a `rail` spec.""" - logger.error("Not supported yet. Use `validate` instead.") diff --git a/guardrails/cli/configure.py b/guardrails/cli/configure.py index 74e2adce0..ce674e1e2 100644 --- a/guardrails/cli/configure.py +++ b/guardrails/cli/configure.py @@ -1,4 +1,5 @@ import os +import sys import uuid from os.path import expanduser from typing import Optional @@ -6,7 +7,8 @@ import typer from guardrails.cli.guardrails import guardrails -from guardrails.cli.logger import logger +from guardrails.cli.logger import LEVELS, logger +from guardrails.cli.server.hub_client import AuthenticationError, get_auth def save_configuration_file( @@ -38,9 +40,40 @@ def configure( ), ): """Set the global configuration for the Guardrails CLI and Hub.""" - if not client_id: - client_id = typer.prompt("Client ID") - if not client_secret: - client_secret = typer.prompt("Client secret", hide_input=True) - logger.info("Configuring...") - save_configuration_file(client_id, client_secret, no_metrics) # type: ignore + try: + if not client_id: + client_id = typer.prompt("Client ID") + if not client_secret: + client_secret = typer.prompt("Client secret", hide_input=True) + logger.info("Configuring...") + save_configuration_file(client_id, client_secret, no_metrics) # type: ignore + + logger.info("Validating credentials...") + get_auth() + success_message = """ + + Login successful. + + Get started by installing a validator from the Guardrails Hub! + + guardrails hub install hub://guardrails/lowercase + + Find more validators at https://hub.guardrailsai.com + """ + logger.log(level=LEVELS.get("SUCCESS"), msg=success_message) # type: ignore + except AuthenticationError as auth_error: + logger.error(auth_error) + logger.error( + """ + Check that your Client ID and Client secret are correct and try again. + + If you don't have your token credentials you can find them here: + + https://hub.guardrailsai.com/tokens + """ + ) + sys.exit(1) + except Exception as e: + logger.error("An unexpected error occurred!") + logger.error(e) + sys.exit(1) diff --git a/guardrails/cli/hub/install.py b/guardrails/cli/hub/install.py index 05a2eb6fb..54acac1f3 100644 --- a/guardrails/cli/hub/install.py +++ b/guardrails/cli/hub/install.py @@ -13,9 +13,18 @@ from guardrails.classes.generic import Stack from guardrails.cli.hub.hub import hub from guardrails.cli.logger import LEVELS, logger -from guardrails.cli.server.hub_client import fetch_module +from guardrails.cli.server.hub_client import get_validator_manifest from guardrails.cli.server.module_manifest import ModuleManifest + +def removesuffix(string: str, suffix: str) -> str: + if sys.version_info.minor >= 9: + return string.removesuffix(suffix) + else: + if string.endswith(suffix): + return string[: -len(suffix)] + + string_format: Literal["string"] = "string" json_format: Literal["json"] = "json" @@ -125,7 +134,7 @@ def run_post_install(manifest: ModuleManifest): post_install_script = manifest.post_install if post_install_script: module_name = manifest.module_name - post_install_module = post_install_script.removesuffix(".py") + post_install_module = removesuffix(post_install_script, ".py") relative_path = ".".join([*org_package, module_name]) importlib.import_module(f"guardrails.hub.{relative_path}.{post_install_module}") @@ -194,7 +203,7 @@ def install( module_name = package_uri.replace("hub://", "") # Prep - module_manifest = fetch_module(module_name) + module_manifest = get_validator_manifest(module_name) site_packages = get_site_packages_location() # Install diff --git a/guardrails/cli/server/auth.py b/guardrails/cli/server/auth.py index f0152b352..49cc1feef 100644 --- a/guardrails/cli/server/auth.py +++ b/guardrails/cli/server/auth.py @@ -1,10 +1,10 @@ import http.client import json -from guardrails.cli.hub.credentials import Credentials +from guardrails.cli.server.credentials import Credentials -def authenticate(creds: Credentials) -> str: +def get_auth_token(creds: Credentials) -> str: if creds.client_id and creds.client_secret: audience = "https://validator-hub-service.guardrailsai.com" conn = http.client.HTTPSConnection("guardrailsai.us.auth0.com") diff --git a/guardrails/cli/hub/credentials.py b/guardrails/cli/server/credentials.py similarity index 91% rename from guardrails/cli/hub/credentials.py rename to guardrails/cli/server/credentials.py index 7a67928f1..f51d4306b 100644 --- a/guardrails/cli/hub/credentials.py +++ b/guardrails/cli/server/credentials.py @@ -29,10 +29,10 @@ def from_rc_file() -> "Credentials": return Credentials.from_dict(creds) except FileNotFoundError as e: - logger.error(e) - logger.error( + logger.warning(e) + logger.warning( "Guardrails Hub credentials not found!" "You will need to sign up to use any authenticated Validators here:" - "{insert url}" + "https://hub.guardrailsai.com/tokens" ) return Credentials() diff --git a/guardrails/cli/server/hub_client.py b/guardrails/cli/server/hub_client.py index 2049859bc..50a1cbf52 100644 --- a/guardrails/cli/server/hub_client.py +++ b/guardrails/cli/server/hub_client.py @@ -4,9 +4,9 @@ import requests -from guardrails.cli.hub.credentials import Credentials from guardrails.cli.logger import logger -from guardrails.cli.server.auth import authenticate +from guardrails.cli.server.auth import get_auth_token +from guardrails.cli.server.credentials import Credentials from guardrails.cli.server.module_manifest import ModuleManifest validator_hub_service = "https://so4sg4q4pb.execute-api.us-east-1.amazonaws.com" @@ -15,6 +15,15 @@ ) +class AuthenticationError(Exception): + pass + + +class HttpError(Exception): + status: int + message: str + + def fetch(url: str, token: Optional[str], anonymousUserId: Optional[str]): try: # For Debugging @@ -29,9 +38,14 @@ def fetch(url: str, token: Optional[str], anonymousUserId: Optional[str]): if not req.ok: logger.error(req.status_code) logger.error(body.get("message")) - sys.exit(1) + http_error = HttpError() + http_error.status = req.status_code + http_error.message = body.get("message") + raise http_error return body + except HttpError as http_e: + raise http_e except Exception as e: logger.error("An unexpected error occurred!", e) sys.exit(1) @@ -50,8 +64,40 @@ def fetch_module_manifest( def fetch_module(module_name: str) -> ModuleManifest: creds = Credentials.from_rc_file() - token = authenticate(creds) + token = get_auth_token(creds) module_manifest_json = fetch_module_manifest(module_name, token, creds.id) - module_manifest = ModuleManifest.from_dict(module_manifest_json) - return module_manifest + return ModuleManifest.from_dict(module_manifest_json) + + +# GET /validator-manifests/{namespace}/{validatorName} +def get_validator_manifest(module_name: str): + try: + module_manifest = fetch_module(module_name) + if not module_manifest: + logger.error(f"Failed to install hub://{module_name}") + sys.exit(1) + return module_manifest + except HttpError: + logger.error(f"Failed to install hub://{module_name}") + sys.exit(1) + except Exception as e: + logger.error("An unexpected error occurred!", e) + sys.exit(1) + + +# GET /auth +def get_auth(): + try: + creds = Credentials.from_rc_file() + token = get_auth_token(creds) + auth_url = f"{validator_hub_service}/auth" + response = fetch(auth_url, token, creds.id) + if not response: + raise AuthenticationError("Failed to authenticate!") + except HttpError as http_error: + logger.error(http_error) + raise AuthenticationError("Failed to authenticate!") + except Exception as e: + logger.error("An unexpected error occurred!", e) + raise AuthenticationError("Failed to authenticate!") diff --git a/guardrails/cli/server/module_manifest.py b/guardrails/cli/server/module_manifest.py index 77512cc43..df66fb77e 100644 --- a/guardrails/cli/server/module_manifest.py +++ b/guardrails/cli/server/module_manifest.py @@ -1,7 +1,9 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Optional -from guardrails.cli.server.serializeable import Serializeable +from pydash.strings import snake_case + +from guardrails.cli.server.serializeable import Serializeable, SerializeableJSONEncoder @dataclass @@ -25,6 +27,8 @@ class ModuleTags(Serializeable): @dataclass class ModuleManifest(Serializeable): + id: str + name: str author: Contributor maintainers: List[Contributor] repository: Repository @@ -33,21 +37,23 @@ class ModuleManifest(Serializeable): module_name: str exports: List[str] tags: ModuleTags + requires_auth: Optional[bool] = True post_install: Optional[str] = None index: Optional[str] = None # @override @classmethod def from_dict(cls, data: Dict[str, Any]): + init_kwargs = {snake_case(k): data.get(k) for k in data} + init_kwargs["encoder"] = init_kwargs.get("encoder", SerializeableJSONEncoder) + author = init_kwargs.pop("author", {}) + maintainers = init_kwargs.pop("maintainers", []) + repository = init_kwargs.pop("repository", {}) + tags = init_kwargs.pop("tags", {}) return cls( - Contributor.from_dict(data.get("author", {})), - [Contributor.from_dict(m) for m in data.get("maintainers", [])], - Repository.from_dict(data.get("repository", {})), - data.get("namespace"), # type: ignore - data.get("packageName"), # type: ignore - data.get("moduleName"), # type: ignore - data.get("exports"), # type: ignore - ModuleTags.from_dict(data.get("tags", {})), - data.get("postInstall"), - data.get("index"), + **init_kwargs, + author=Contributor.from_dict(author), + maintainers=[Contributor.from_dict(m) for m in maintainers], + repository=Repository.from_dict(repository), + tags=ModuleTags.from_dict(tags), ) diff --git a/guardrails/cli/server/serializeable.py b/guardrails/cli/server/serializeable.py index 7da328555..8d3252872 100644 --- a/guardrails/cli/server/serializeable.py +++ b/guardrails/cli/server/serializeable.py @@ -1,5 +1,6 @@ import inspect import json +import sys from dataclasses import InitVar, asdict, dataclass, field, is_dataclass from json import JSONEncoder from typing import Any, Dict @@ -7,6 +8,13 @@ from pydash.strings import snake_case +def get_annotations(obj): + if sys.version_info.minor >= 10: + return inspect.get_annotations(obj) + else: + return obj.__annotations__ + + class SerializeableJSONEncoder(JSONEncoder): def default(self, o): if is_dataclass(o): @@ -14,22 +22,27 @@ def default(self, o): return super().default(o) +encoder_kwargs = {} +if sys.version_info.minor >= 10: + encoder_kwargs["kw_only"] = True + encoder_kwargs["default"] = SerializeableJSONEncoder + + @dataclass class Serializeable: - encoder: InitVar[JSONEncoder] = field( - kw_only=True, default=SerializeableJSONEncoder # type: ignore - ) + encoder: InitVar[JSONEncoder] = field(**encoder_kwargs) @classmethod def from_dict(cls, data: Dict[str, Any]): - annotations = inspect.get_annotations(cls) + annotations = get_annotations(cls) attributes = dict.keys(annotations) - kwargs = {k: data.get(k) for k in data if k in attributes} snake_case_kwargs = { snake_case(k): data.get(k) for k in data if snake_case(k) in attributes } - kwargs.update(snake_case_kwargs) - return cls(**kwargs) # type: ignore + snake_case_kwargs["encoder"] = snake_case_kwargs.get( + "encoder", SerializeableJSONEncoder + ) + return cls(**snake_case_kwargs) # type: ignore @property def __dict__(self) -> Dict[str, Any]: diff --git a/tests/unit_tests/cli/hub/test_install.py b/tests/unit_tests/cli/hub/test_install.py index a494a8810..17b9db2c4 100644 --- a/tests/unit_tests/cli/hub/test_install.py +++ b/tests/unit_tests/cli/hub/test_install.py @@ -24,8 +24,10 @@ def test_exits_early_if_uri_is_not_valid(self, mocker): def test_happy_path(self, mocker): mock_logger_log = mocker.patch("guardrails.cli.hub.install.logger.log") - mock_fetch_module = mocker.patch("guardrails.cli.hub.install.fetch_module") + mock_get_validator_manifest = mocker.patch("guardrails.cli.hub.install.get_validator_manifest") manifest = ModuleManifest( + "id", + "name", "me", [], Repository(url="some-repo"), @@ -35,7 +37,7 @@ def test_happy_path(self, mocker): ["TestValidator"], ModuleTags(), ) - mock_fetch_module.return_value = manifest + mock_get_validator_manifest.return_value = manifest mock_get_site_packages_location = mocker.patch( "guardrails.cli.hub.install.get_site_packages_location" @@ -67,7 +69,7 @@ def test_happy_path(self, mocker): assert mock_logger_log.call_count == 2 mock_logger_log.assert_has_calls(log_calls) - mock_fetch_module.assert_called_once_with("guardrails/test-validator") + mock_get_validator_manifest.assert_called_once_with("guardrails/test-validator") assert mock_get_site_packages_location.call_count == 1 @@ -213,6 +215,8 @@ def test_get_site_packages_location(mocker): [ ( ModuleManifest( + "id", + "name", "me", [], Repository(url="some-repo"), @@ -226,6 +230,8 @@ def test_get_site_packages_location(mocker): ), ( ModuleManifest( + "id", + "name", "me", [], Repository(url="some-repo"), @@ -249,6 +255,8 @@ def test_get_org_and_package_dirs(manifest, expected): def test_get_hub_directory(): manifest = ModuleManifest( + "id", + "name", "me", [], Repository(url="some-repo"), @@ -269,6 +277,8 @@ def test_get_hub_directory(): class TestAddToHubInits: def test_closes_early_if_already_added(self, mocker): manifest = ModuleManifest( + "id", + "name", "me", [], Repository(url="some-repo"), @@ -328,6 +338,8 @@ def test_closes_early_if_already_added(self, mocker): def test_appends_import_line_if_not_present(self, mocker): manifest = ModuleManifest( + "id", + "name", "me", [], Repository(url="some-repo"), @@ -406,6 +418,8 @@ def test_appends_import_line_if_not_present(self, mocker): def test_creates_namespace_init_if_not_exists(self, mocker): manifest = ModuleManifest( + "id", + "name", "me", [], Repository(url="some-repo"), @@ -464,6 +478,8 @@ class TestRunPostInstall: "manifest", [ ModuleManifest( + "id", + "name", "me", [], Repository(url="some-repo"), @@ -474,6 +490,8 @@ class TestRunPostInstall: ModuleTags(), ), ModuleManifest( + "id", + "name", "me", [], Repository(url="some-repo"), @@ -503,6 +521,8 @@ def test_runs_script_if_exists(self, mocker): from guardrails.cli.hub.install import run_post_install manifest = ModuleManifest( + "id", + "name", "me", [], Repository(url="some-repo"), @@ -527,6 +547,8 @@ def test_runs_script_if_exists(self, mocker): [ ( ModuleManifest( + "id", + "name", "me", [], Repository(url="some-repo"), @@ -540,6 +562,8 @@ def test_runs_script_if_exists(self, mocker): ), ( ModuleManifest( + "id", + "name", "me", [], Repository(url="git+some-repo"), @@ -554,6 +578,8 @@ def test_runs_script_if_exists(self, mocker): ), ( ModuleManifest( + "id", + "name", "me", [], Repository(url="git+some-repo", branch="prod"), @@ -611,6 +637,8 @@ def test_install_hub_module(mocker): from guardrails.cli.hub.install import install_hub_module manifest = ModuleManifest( + "id", + "name", "me", [], Repository(url="some-repo"), diff --git a/tests/unit_tests/cli/hub/server/test_auth.py b/tests/unit_tests/cli/server/test_auth.py similarity index 69% rename from tests/unit_tests/cli/hub/server/test_auth.py rename to tests/unit_tests/cli/server/test_auth.py index 1b13494c4..aff255c98 100644 --- a/tests/unit_tests/cli/hub/server/test_auth.py +++ b/tests/unit_tests/cli/server/test_auth.py @@ -3,5 +3,5 @@ # TODO @pytest.mark.skip() -def test_authenticate(): +def test_get_auth_token(): assert 1 == 1 diff --git a/tests/unit_tests/cli/hub/test_credentials.py b/tests/unit_tests/cli/server/test_credentials.py similarity index 85% rename from tests/unit_tests/cli/hub/test_credentials.py rename to tests/unit_tests/cli/server/test_credentials.py index 6ad9a4481..3366c12d7 100644 --- a/tests/unit_tests/cli/hub/test_credentials.py +++ b/tests/unit_tests/cli/server/test_credentials.py @@ -7,7 +7,7 @@ def test_from_rc_file(mocker): mocker.patch("nltk.data.find") mocker.patch("nltk.download") - expanduser_mock = mocker.patch("guardrails.cli.hub.credentials.expanduser") + expanduser_mock = mocker.patch("guardrails.cli.server.credentials.expanduser") expanduser_mock.return_value = "/Home" import os @@ -15,14 +15,14 @@ def test_from_rc_file(mocker): join_spy = mocker.spy(os.path, "join") mock_file = MockFile() - mock_open = mocker.patch("guardrails.cli.hub.credentials.open") + mock_open = mocker.patch("guardrails.cli.server.credentials.open") mock_open.return_value = mock_file readlines_spy = mocker.patch.object(mock_file, "readlines") readlines_spy.return_value = ["key1=val1", "key2=val2"] close_spy = mocker.spy(mock_file, "close") - from guardrails.cli.hub.credentials import Credentials + from guardrails.cli.server.credentials import Credentials mock_from_dict = mocker.patch.object(Credentials, "from_dict") diff --git a/tests/unit_tests/cli/hub/server/test_hub_client.py b/tests/unit_tests/cli/server/test_hub_client.py similarity index 60% rename from tests/unit_tests/cli/hub/server/test_hub_client.py rename to tests/unit_tests/cli/server/test_hub_client.py index aa7b8aaf1..a7bafe600 100644 --- a/tests/unit_tests/cli/hub/server/test_hub_client.py +++ b/tests/unit_tests/cli/server/test_hub_client.py @@ -17,3 +17,15 @@ def test_fetch_module_manifest(): @pytest.mark.skip() def test_fetch_module(): assert 1 == 1 + + +# TODO +@pytest.mark.skip() +def test_get_validator_manifest(): + assert 1 == 1 + + +# TODO +@pytest.mark.skip() +def test_get_auth(): + assert 1 == 1 diff --git a/tests/unit_tests/cli/test_compile.py b/tests/unit_tests/cli/test_compile.py deleted file mode 100644 index 85ae19ccb..000000000 --- a/tests/unit_tests/cli/test_compile.py +++ /dev/null @@ -1,18 +0,0 @@ -import pytest - -from guardrails.cli.compile import compile, compile_rail, logger - - -def test_compile_rail(): - with pytest.raises(NotImplementedError) as nie: - compile_rail("my_spec.rail", ".rail_output") - assert nie is not None - assert str(nie) == "Currently compiling rail is not supported." - - -def test_compile(mocker): - error_log_mock = mocker.patch.object(logger, "error") - - compile("my_spec.rail") - - error_log_mock.assert_called_once_with("Not supported yet. Use `validate` instead.") diff --git a/tests/unit_tests/cli/test_configure.py b/tests/unit_tests/cli/test_configure.py index e50fa28b7..7a587d6e3 100644 --- a/tests/unit_tests/cli/test_configure.py +++ b/tests/unit_tests/cli/test_configure.py @@ -19,16 +19,21 @@ def test_configure(mocker, client_id, client_secret, no_metrics): "guardrails.cli.configure.save_configuration_file" ) mock_logger_info = mocker.patch("guardrails.cli.configure.logger.info") + mock_get_auth = mocker.patch("guardrails.cli.configure.get_auth") from guardrails.cli.configure import configure configure(client_id, client_secret, no_metrics) - mock_logger_info.assert_called_once_with("Configuring...") + assert mock_logger_info.call_count == 2 + expected_calls = [call("Configuring..."), call("Validating credentials...")] + mock_logger_info.assert_has_calls(expected_calls) + mock_save_configuration_file.assert_called_once_with( client_id, client_secret, no_metrics ) + assert mock_get_auth.call_count == 1 def test_configure_prompting(mocker): mock_typer_prompt = mocker.patch("typer.prompt") @@ -37,6 +42,7 @@ def test_configure_prompting(mocker): "guardrails.cli.configure.save_configuration_file" ) mock_logger_info = mocker.patch("guardrails.cli.configure.logger.info") + mock_get_auth = mocker.patch("guardrails.cli.configure.get_auth") from guardrails.cli.configure import configure @@ -45,9 +51,15 @@ def test_configure_prompting(mocker): assert mock_typer_prompt.call_count == 2 expected_calls = [call("Client ID"), call("Client secret", hide_input=True)] mock_typer_prompt.assert_has_calls(expected_calls) - mock_logger_info.assert_called_once_with("Configuring...") + + assert mock_logger_info.call_count == 2 + expected_calls = [call("Configuring..."), call("Validating credentials...")] + mock_logger_info.assert_has_calls(expected_calls) + mock_save_configuration_file.assert_called_once_with("id", "secret", False) + assert mock_get_auth.call_count == 1 + def test_save_configuration_file(mocker): # TODO: Re-enable this once we move nltk.download calls to individual validator repos. # noqa