Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions guardrails/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import guardrails.cli.compile # noqa
import guardrails.cli.configure # noqa
import guardrails.cli.validate # noqa
from guardrails.cli.guardrails import guardrails as cli
from guardrails.cli.hub import hub

cli.add_typer(hub, name="hub")
cli.add_typer(
hub, name="hub", help="Manage validators installed from the Guardrails Hub."
)


if __name__ == "__main__":
Expand Down
25 changes: 0 additions & 25 deletions guardrails/cli/compile.py

This file was deleted.

47 changes: 40 additions & 7 deletions guardrails/cli/configure.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
import sys
import uuid
from os.path import expanduser
from typing import Optional

import typer

from guardrails.cli.guardrails import guardrails
from guardrails.cli.logger import logger
from guardrails.cli.logger import LEVELS, logger
from guardrails.cli.server.hub_client import AuthenticationError, get_auth


def save_configuration_file(
Expand Down Expand Up @@ -38,9 +40,40 @@ def configure(
),
):
"""Set the global configuration for the Guardrails CLI and Hub."""
if not client_id:
client_id = typer.prompt("Client ID")
if not client_secret:
client_secret = typer.prompt("Client secret", hide_input=True)
logger.info("Configuring...")
save_configuration_file(client_id, client_secret, no_metrics) # type: ignore
try:
if not client_id:
client_id = typer.prompt("Client ID")
if not client_secret:
client_secret = typer.prompt("Client secret", hide_input=True)
logger.info("Configuring...")
save_configuration_file(client_id, client_secret, no_metrics) # type: ignore

logger.info("Validating credentials...")
get_auth()
success_message = """

Login successful.

Get started by installing a validator from the Guardrails Hub!

guardrails hub install hub://guardrails/lowercase

Find more validators at https://hub.guardrailsai.com
"""
logger.log(level=LEVELS.get("SUCCESS"), msg=success_message) # type: ignore
except AuthenticationError as auth_error:
logger.error(auth_error)
logger.error(
"""
Check that your Client ID and Client secret are correct and try again.

If you don't have your token credentials you can find them here:

https://hub.guardrailsai.com/tokens
"""
)
sys.exit(1)
except Exception as e:
logger.error("An unexpected error occurred!")
logger.error(e)
sys.exit(1)
15 changes: 12 additions & 3 deletions guardrails/cli/hub/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,18 @@
from guardrails.classes.generic import Stack
from guardrails.cli.hub.hub import hub
from guardrails.cli.logger import LEVELS, logger
from guardrails.cli.server.hub_client import fetch_module
from guardrails.cli.server.hub_client import get_validator_manifest
from guardrails.cli.server.module_manifest import ModuleManifest


def removesuffix(string: str, suffix: str) -> str:
if sys.version_info.minor >= 9:
return string.removesuffix(suffix)
else:
if string.endswith(suffix):
return string[: -len(suffix)]


string_format: Literal["string"] = "string"
json_format: Literal["json"] = "json"

Expand Down Expand Up @@ -125,7 +134,7 @@ def run_post_install(manifest: ModuleManifest):
post_install_script = manifest.post_install
if post_install_script:
module_name = manifest.module_name
post_install_module = post_install_script.removesuffix(".py")
post_install_module = removesuffix(post_install_script, ".py")
relative_path = ".".join([*org_package, module_name])
importlib.import_module(f"guardrails.hub.{relative_path}.{post_install_module}")

Expand Down Expand Up @@ -194,7 +203,7 @@ def install(
module_name = package_uri.replace("hub://", "")

# Prep
module_manifest = fetch_module(module_name)
module_manifest = get_validator_manifest(module_name)
site_packages = get_site_packages_location()

# Install
Expand Down
4 changes: 2 additions & 2 deletions guardrails/cli/server/auth.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import http.client
import json

from guardrails.cli.hub.credentials import Credentials
from guardrails.cli.server.credentials import Credentials


def authenticate(creds: Credentials) -> str:
def get_auth_token(creds: Credentials) -> str:
if creds.client_id and creds.client_secret:
audience = "https://validator-hub-service.guardrailsai.com"
conn = http.client.HTTPSConnection("guardrailsai.us.auth0.com")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def from_rc_file() -> "Credentials":
return Credentials.from_dict(creds)

except FileNotFoundError as e:
logger.error(e)
logger.error(
logger.warning(e)
logger.warning(
"Guardrails Hub credentials not found!"
"You will need to sign up to use any authenticated Validators here:"
"{insert url}"
"https://hub.guardrailsai.com/tokens"
)
return Credentials()
58 changes: 52 additions & 6 deletions guardrails/cli/server/hub_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import requests

from guardrails.cli.hub.credentials import Credentials
from guardrails.cli.logger import logger
from guardrails.cli.server.auth import authenticate
from guardrails.cli.server.auth import get_auth_token
from guardrails.cli.server.credentials import Credentials
from guardrails.cli.server.module_manifest import ModuleManifest

validator_hub_service = "https://so4sg4q4pb.execute-api.us-east-1.amazonaws.com"
Expand All @@ -15,6 +15,15 @@
)


class AuthenticationError(Exception):
pass


class HttpError(Exception):
status: int
message: str


def fetch(url: str, token: Optional[str], anonymousUserId: Optional[str]):
try:
# For Debugging
Expand All @@ -29,9 +38,14 @@ def fetch(url: str, token: Optional[str], anonymousUserId: Optional[str]):
if not req.ok:
logger.error(req.status_code)
logger.error(body.get("message"))
sys.exit(1)
http_error = HttpError()
http_error.status = req.status_code
http_error.message = body.get("message")
raise http_error

return body
except HttpError as http_e:
raise http_e
except Exception as e:
logger.error("An unexpected error occurred!", e)
sys.exit(1)
Expand All @@ -50,8 +64,40 @@ def fetch_module_manifest(

def fetch_module(module_name: str) -> ModuleManifest:
creds = Credentials.from_rc_file()
token = authenticate(creds)
token = get_auth_token(creds)

module_manifest_json = fetch_module_manifest(module_name, token, creds.id)
module_manifest = ModuleManifest.from_dict(module_manifest_json)
return module_manifest
return ModuleManifest.from_dict(module_manifest_json)


# GET /validator-manifests/{namespace}/{validatorName}
def get_validator_manifest(module_name: str):
try:
module_manifest = fetch_module(module_name)
if not module_manifest:
logger.error(f"Failed to install hub://{module_name}")
sys.exit(1)
return module_manifest
except HttpError:
logger.error(f"Failed to install hub://{module_name}")
sys.exit(1)
except Exception as e:
logger.error("An unexpected error occurred!", e)
sys.exit(1)


# GET /auth
def get_auth():
try:
creds = Credentials.from_rc_file()
token = get_auth_token(creds)
auth_url = f"{validator_hub_service}/auth"
response = fetch(auth_url, token, creds.id)
if not response:
raise AuthenticationError("Failed to authenticate!")
except HttpError as http_error:
logger.error(http_error)
raise AuthenticationError("Failed to authenticate!")
except Exception as e:
logger.error("An unexpected error occurred!", e)
raise AuthenticationError("Failed to authenticate!")
28 changes: 17 additions & 11 deletions guardrails/cli/server/module_manifest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from guardrails.cli.server.serializeable import Serializeable
from pydash.strings import snake_case

from guardrails.cli.server.serializeable import Serializeable, SerializeableJSONEncoder


@dataclass
Expand All @@ -25,6 +27,8 @@ class ModuleTags(Serializeable):

@dataclass
class ModuleManifest(Serializeable):
id: str
name: str
author: Contributor
maintainers: List[Contributor]
repository: Repository
Expand All @@ -33,21 +37,23 @@ class ModuleManifest(Serializeable):
module_name: str
exports: List[str]
tags: ModuleTags
requires_auth: Optional[bool] = True
post_install: Optional[str] = None
index: Optional[str] = None

# @override
@classmethod
def from_dict(cls, data: Dict[str, Any]):
init_kwargs = {snake_case(k): data.get(k) for k in data}
init_kwargs["encoder"] = init_kwargs.get("encoder", SerializeableJSONEncoder)
author = init_kwargs.pop("author", {})
maintainers = init_kwargs.pop("maintainers", [])
repository = init_kwargs.pop("repository", {})
tags = init_kwargs.pop("tags", {})
return cls(
Contributor.from_dict(data.get("author", {})),
[Contributor.from_dict(m) for m in data.get("maintainers", [])],
Repository.from_dict(data.get("repository", {})),
data.get("namespace"), # type: ignore
data.get("packageName"), # type: ignore
data.get("moduleName"), # type: ignore
data.get("exports"), # type: ignore
ModuleTags.from_dict(data.get("tags", {})),
data.get("postInstall"),
data.get("index"),
**init_kwargs,
author=Contributor.from_dict(author),
maintainers=[Contributor.from_dict(m) for m in maintainers],
repository=Repository.from_dict(repository),
tags=ModuleTags.from_dict(tags),
)
27 changes: 20 additions & 7 deletions guardrails/cli/server/serializeable.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,48 @@
import inspect
import json
import sys
from dataclasses import InitVar, asdict, dataclass, field, is_dataclass
from json import JSONEncoder
from typing import Any, Dict

from pydash.strings import snake_case


def get_annotations(obj):
if sys.version_info.minor >= 10:
return inspect.get_annotations(obj)
else:
return obj.__annotations__


class SerializeableJSONEncoder(JSONEncoder):
def default(self, o):
if is_dataclass(o):
return asdict(o)
return super().default(o)


encoder_kwargs = {}
if sys.version_info.minor >= 10:
encoder_kwargs["kw_only"] = True
encoder_kwargs["default"] = SerializeableJSONEncoder


@dataclass
class Serializeable:
encoder: InitVar[JSONEncoder] = field(
kw_only=True, default=SerializeableJSONEncoder # type: ignore
)
encoder: InitVar[JSONEncoder] = field(**encoder_kwargs)

@classmethod
def from_dict(cls, data: Dict[str, Any]):
annotations = inspect.get_annotations(cls)
annotations = get_annotations(cls)
attributes = dict.keys(annotations)
kwargs = {k: data.get(k) for k in data if k in attributes}
snake_case_kwargs = {
snake_case(k): data.get(k) for k in data if snake_case(k) in attributes
}
kwargs.update(snake_case_kwargs)
return cls(**kwargs) # type: ignore
snake_case_kwargs["encoder"] = snake_case_kwargs.get(
"encoder", SerializeableJSONEncoder
)
return cls(**snake_case_kwargs) # type: ignore

@property
def __dict__(self) -> Dict[str, Any]:
Expand Down
Loading