diff --git a/guardrails/api_client.py b/guardrails/api_client.py index 7b825d537..6c16f8725 100644 --- a/guardrails/api_client.py +++ b/guardrails/api_client.py @@ -1,6 +1,6 @@ import json import os -from typing import Any, Iterable, Optional +from typing import Any, Iterator, Optional import requests from guardrails_api_client.configuration import Configuration @@ -80,7 +80,7 @@ def stream_validate( guard: Guard, payload: ValidatePayload, openai_api_key: Optional[str] = None, - ) -> Iterable[Any]: + ) -> Iterator[Any]: _openai_api_key = ( openai_api_key if openai_api_key is not None diff --git a/guardrails/async_guard.py b/guardrails/async_guard.py index e3f7ce658..9af740711 100644 --- a/guardrails/async_guard.py +++ b/guardrails/async_guard.py @@ -4,7 +4,7 @@ from opentelemetry import context as otel_context from typing import ( Any, - AsyncIterable, + AsyncIterator, Awaitable, Callable, Dict, @@ -37,6 +37,7 @@ set_tracer, set_tracer_context, ) +from guardrails.hub_telemetry.hub_tracing import async_trace from guardrails.types.pydantic import ModelOrListOfModels from guardrails.types.validator import UseManyValidatorSpec, UseValidatorSpec from guardrails.telemetry import trace_async_guard_execution, wrap_with_otel_context @@ -187,7 +188,7 @@ async def _execute( ) -> Union[ ValidationOutcome[OT], Awaitable[ValidationOutcome[OT]], - AsyncIterable[ValidationOutcome[OT]], + AsyncIterator[ValidationOutcome[OT]], ]: self._fill_validator_map() self._fill_validators() @@ -219,49 +220,13 @@ async def __exec( ) -> Union[ ValidationOutcome[OT], Awaitable[ValidationOutcome[OT]], - AsyncIterable[ValidationOutcome[OT]], + AsyncIterator[ValidationOutcome[OT]], ]: prompt_params = prompt_params or {} metadata = metadata or {} if full_schema_reask is None: full_schema_reask = self._base_model is not None - if self._allow_metrics_collection: - llm_api_str = "" - if llm_api: - llm_api_module_name = ( - llm_api.__module__ if hasattr(llm_api, "__module__") else "" - ) - llm_api_name = ( - llm_api.__name__ - if hasattr(llm_api, "__name__") - else type(llm_api).__name__ - ) - llm_api_str = f"{llm_api_module_name}.{llm_api_name}" - # Create a new span for this guard call - self._hub_telemetry.create_new_span( - span_name="/guard_call", - attributes=[ - ("guard_id", self.id), - ("user_id", self._user_id), - ("llm_api", llm_api_str), - ( - "custom_reask_prompt", - self._exec_opts.reask_prompt is not None, - ), - ( - "custom_reask_instructions", - self._exec_opts.reask_instructions is not None, - ), - ( - "custom_reask_messages", - self._exec_opts.reask_messages is not None, - ), - ], - is_parent=True, # It will have children - has_parent=False, # Has no parents - ) - set_call_kwargs(kwargs) set_tracer(self._tracer) set_tracer_context(self._tracer_context) @@ -369,7 +334,7 @@ async def _exec( ) -> Union[ ValidationOutcome[OT], Awaitable[ValidationOutcome[OT]], - AsyncIterable[ValidationOutcome[OT]], + AsyncIterator[ValidationOutcome[OT]], ]: """Call the LLM asynchronously and validate the output. @@ -435,6 +400,7 @@ async def _exec( ) return ValidationOutcome[OT].from_guard_history(call) + @async_trace(name="/guard_call", origin="AsyncGuard.__call__") async def __call__( self, llm_api: Optional[Callable[..., Awaitable[Any]]] = None, @@ -450,7 +416,7 @@ async def __call__( ) -> Union[ ValidationOutcome[OT], Awaitable[ValidationOutcome[OT]], - AsyncIterable[ValidationOutcome[OT]], + AsyncIterator[ValidationOutcome[OT]], ]: """Call the LLM and validate the output. Pass an async LLM API to return a coroutine. 
@@ -501,6 +467,7 @@ async def __call__( **kwargs, ) + @async_trace(name="/guard_call", origin="AsyncGuard.parse") async def parse( self, llm_output: str, @@ -567,7 +534,7 @@ async def parse( async def _stream_server_call( self, *, payload: Dict[str, Any] - ) -> AsyncIterable[ValidationOutcome[OT]]: + ) -> AsyncIterator[ValidationOutcome[OT]]: # TODO: Once server side supports async streaming, this function will need to # yield async generators, not generators if self._api_client: @@ -609,6 +576,7 @@ async def _stream_server_call( else: raise ValueError("AsyncGuard does not have an api client!") + @async_trace(name="/guard_call", origin="AsyncGuard.validate") async def validate( self, llm_output: str, *args, **kwargs ) -> Awaitable[ValidationOutcome[OT]]: diff --git a/guardrails/classes/__init__.py b/guardrails/classes/__init__.py index 16bf9ac0e..0b0a1700b 100644 --- a/guardrails/classes/__init__.py +++ b/guardrails/classes/__init__.py @@ -1,4 +1,5 @@ -from guardrails.classes.credentials import Credentials +from guardrails.classes.credentials import Credentials # type: ignore +from guardrails.classes.rc import RC from guardrails.classes.input_type import InputType from guardrails.classes.output_type import OT from guardrails.classes.validation.validation_result import ( @@ -10,7 +11,8 @@ from guardrails.classes.validation_outcome import ValidationOutcome __all__ = [ - "Credentials", + "Credentials", # type: ignore + "RC", "ErrorSpan", "InputType", "OT", diff --git a/guardrails/classes/credentials.py b/guardrails/classes/credentials.py index 3c251476c..b09206d71 100644 --- a/guardrails/classes/credentials.py +++ b/guardrails/classes/credentials.py @@ -1,21 +1,24 @@ import logging -import os from dataclasses import dataclass -from os.path import expanduser from typing import Optional +from typing_extensions import deprecated -from guardrails.classes.generic.serializeable import Serializeable +from guardrails.classes.generic.serializeable import SerializeableJSONEncoder +from guardrails.classes.rc import RC BOOL_CONFIGS = set(["no_metrics", "enable_metrics", "use_remote_inferencing"]) +@deprecated( + ( + "The `Credentials` class is deprecated and will be removed in version 0.6.x." + " Use the `RC` class instead." + ), + category=DeprecationWarning, +) @dataclass -class Credentials(Serializeable): - id: Optional[str] = None - token: Optional[str] = None +class Credentials(RC): no_metrics: Optional[bool] = False - enable_metrics: Optional[bool] = True - use_remote_inferencing: Optional[bool] = True @staticmethod def _to_bool(value: str) -> Optional[bool]: @@ -27,51 +30,16 @@ def _to_bool(value: str) -> Optional[bool]: @staticmethod def has_rc_file() -> bool: - home = expanduser("~") - guardrails_rc = os.path.join(home, ".guardrailsrc") - return os.path.exists(guardrails_rc) + return RC.exists() @staticmethod - def from_rc_file(logger: Optional[logging.Logger] = None) -> "Credentials": - try: - if not logger: - logger = logging.getLogger() - home = expanduser("~") - guardrails_rc = os.path.join(home, ".guardrailsrc") - with open(guardrails_rc, encoding="utf-8") as rc_file: - lines = rc_file.readlines() - filtered_lines = list(filter(lambda l: l.strip(), lines)) - creds = {} - for line in filtered_lines: - line_content = line.split("=", 1) - if len(line_content) != 2: - logger.warning( - """ - Invalid line found in .guardrailsrc file! - All lines in this file should follow the format: key=value - Ignoring line contents... 
- """ - ) - logger.debug(f".guardrailsrc file location: {guardrails_rc}") - else: - key, value = line_content - key = key.strip() - value = value.strip() - if key in BOOL_CONFIGS: - value = Credentials._to_bool(value) - - creds[key] = value - - rc_file.close() - - # backfill no_metrics, handle defaults - # remove in 0.5.0 - no_metrics_val = creds.pop("no_metrics", None) - if no_metrics_val is not None and creds.get("enable_metrics") is None: - creds["enable_metrics"] = not no_metrics_val - - creds_dict = Credentials.from_dict(creds) - return creds_dict - - except FileNotFoundError: - return Credentials.from_dict({}) # type: ignore + def from_rc_file(logger: Optional[logging.Logger] = None) -> "Credentials": # type: ignore + rc = RC.load(logger) + return Credentials( # type: ignore + id=rc.id, + token=rc.token, + enable_metrics=rc.enable_metrics, + use_remote_inferencing=rc.use_remote_inferencing, + no_metrics=(not rc.enable_metrics), + encoder=SerializeableJSONEncoder(), + ) diff --git a/guardrails/classes/llm/llm_response.py b/guardrails/classes/llm/llm_response.py index a5d4db5dd..94ef033f4 100644 --- a/guardrails/classes/llm/llm_response.py +++ b/guardrails/classes/llm/llm_response.py @@ -1,6 +1,6 @@ import asyncio from itertools import tee -from typing import Any, Dict, Iterable, Optional, AsyncIterable +from typing import Any, Dict, Iterator, Optional, AsyncIterator from guardrails_api_client import LLMResponse as ILLMResponse from pydantic.config import ConfigDict @@ -19,9 +19,9 @@ class LLMResponse(ILLMResponse): Attributes: output (str): The output from the LLM. - stream_output (Optional[Iterable]): A stream of output from the LLM. + stream_output (Optional[Iterator]): A stream of output from the LLM. Default None. - async_stream_output (Optional[AsyncIterable]): An async stream of output + async_stream_output (Optional[AsyncIterator]): An async stream of output from the LLM. Default None. prompt_token_count (Optional[int]): The number of tokens in the prompt. Default None. 
@@ -35,8 +35,8 @@ class LLMResponse(ILLMResponse): prompt_token_count: Optional[int] = None response_token_count: Optional[int] = None output: str - stream_output: Optional[Iterable] = None - async_stream_output: Optional[AsyncIterable] = None + stream_output: Optional[Iterator] = None + async_stream_output: Optional[AsyncIterator] = None def to_interface(self) -> ILLMResponse: stream_output = None @@ -73,7 +73,7 @@ def to_dict(self) -> Dict[str, Any]: def from_interface(cls, i_llm_response: ILLMResponse) -> "LLMResponse": stream_output = None if i_llm_response.stream_output: - stream_output = [so for so in i_llm_response.stream_output] + stream_output = iter([so for so in i_llm_response.stream_output]) async_stream_output = None if i_llm_response.async_stream_output: diff --git a/guardrails/classes/rc.py b/guardrails/classes/rc.py new file mode 100644 index 000000000..0d3543480 --- /dev/null +++ b/guardrails/classes/rc.py @@ -0,0 +1,71 @@ +import logging +import os +from dataclasses import dataclass +from os.path import expanduser +from typing import Optional + +from guardrails.classes.generic.serializeable import Serializeable +from guardrails.utils.casting_utils import to_bool + +BOOL_CONFIGS = set(["no_metrics", "enable_metrics", "use_remote_inferencing"]) + + +@dataclass +class RC(Serializeable): + id: Optional[str] = None + token: Optional[str] = None + enable_metrics: Optional[bool] = True + use_remote_inferencing: Optional[bool] = True + + @staticmethod + def exists() -> bool: + home = expanduser("~") + guardrails_rc = os.path.join(home, ".guardrailsrc") + return os.path.exists(guardrails_rc) + + @classmethod + def load(cls, logger: Optional[logging.Logger] = None) -> "RC": + try: + if not logger: + logger = logging.getLogger() + home = expanduser("~") + guardrails_rc = os.path.join(home, ".guardrailsrc") + with open(guardrails_rc, encoding="utf-8") as rc_file: + lines = rc_file.readlines() + filtered_lines = list(filter(lambda l: l.strip(), lines)) + config = {} + for line in filtered_lines: + line_content = line.split("=", 1) + if len(line_content) != 2: + logger.warning( + """ + Invalid line found in .guardrailsrc file! + All lines in this file should follow the format: key=value + Ignoring line contents... 
+ """ + ) + logger.debug(f".guardrailsrc file location: {guardrails_rc}") + else: + key, value = line_content + key = key.strip() + value = value.strip() + if key in BOOL_CONFIGS: + value = to_bool(value) + + config[key] = value + + rc_file.close() + + # backfill no_metrics, handle defaults + # We missed this comment in the 0.5.0 release + # Making it a TODO for 0.6.0 + # TODO: remove in 0.6.0 + no_metrics_val = config.pop("no_metrics", None) + if no_metrics_val is not None and config.get("enable_metrics") is None: + config["enable_metrics"] = not no_metrics_val + + rc = cls.from_dict(config) + return rc + + except FileNotFoundError: + return cls.from_dict({}) # type: ignore diff --git a/guardrails/cli/configure.py b/guardrails/cli/configure.py index f1b9a8d08..a8739227a 100644 --- a/guardrails/cli/configure.py +++ b/guardrails/cli/configure.py @@ -6,7 +6,7 @@ import typer -from guardrails.classes.credentials import Credentials +from guardrails.settings import settings from guardrails.cli.guardrails import guardrails from guardrails.cli.logger import LEVELS, logger from guardrails.cli.hub.console import console @@ -46,7 +46,7 @@ def save_configuration_file( def _get_default_token() -> str: """Get the default token from the configuration file.""" - file_token = Credentials.from_rc_file(logger).token + file_token = settings.rc.token if file_token is None: return "" return file_token @@ -79,7 +79,8 @@ def configure( ), ): version_warnings_if_applicable(console) - trace_if_enabled("configure") + if settings.rc.exists(): + trace_if_enabled("configure") existing_token = _get_default_token() last4 = existing_token[-4:] if existing_token else "" diff --git a/guardrails/cli/create.py b/guardrails/cli/create.py index 183d74256..26ea830ca 100644 --- a/guardrails/cli/create.py +++ b/guardrails/cli/create.py @@ -10,12 +10,13 @@ from guardrails.cli.guardrails import guardrails as gr_cli from guardrails.cli.hub.template import get_template -from guardrails.cli.telemetry import trace_if_enabled +from guardrails.hub_telemetry.hub_tracing import trace console = Console() @gr_cli.command(name="create") +@trace(name="guardrails-cli/create") def create_command( validators: Optional[str] = typer.Option( default="", @@ -46,7 +47,6 @@ def create_command( help="Print out the validators to be installed without making any changes.", ), ): - trace_if_enabled("create") # fix pyright typing issue validators = cast(str, validators) filepath = check_filename(filepath) diff --git a/guardrails/cli/hub/create_validator.py b/guardrails/cli/hub/create_validator.py index 902461f36..613a4bb60 100644 --- a/guardrails/cli/hub/create_validator.py +++ b/guardrails/cli/hub/create_validator.py @@ -8,7 +8,7 @@ from guardrails.cli.hub.hub import hub_command from guardrails.cli.logger import LEVELS, logger -from guardrails.cli.telemetry import trace_if_enabled +from guardrails.hub_telemetry.hub_tracing import trace validator_template = Template( """ @@ -148,6 +148,7 @@ def test_failure_case(self): @hub_command.command(name="create-validator") +@trace(name="guardrails-cli/hub/create-validator") def create_validator( name: str = typer.Argument(help="The name for your validator."), filepath: str = typer.Argument( @@ -170,7 +171,6 @@ def create_validator( The template repository can be found here:\ https://github.com/guardrails-ai/validator-template """ - trace_if_enabled("hub/create-validator") logger.log(level=LEVELS.get("NOTICE") or 0, msg=disclaimer) package_name = snake_case(name) diff --git a/guardrails/cli/hub/install.py 
b/guardrails/cli/hub/install.py index 844b05a87..1a0cb5e16 100644 --- a/guardrails/cli/hub/install.py +++ b/guardrails/cli/hub/install.py @@ -5,12 +5,13 @@ from guardrails.cli.hub.hub import hub_command from guardrails.cli.logger import logger +from guardrails.hub_telemetry.hub_tracing import trace from guardrails.cli.hub.console import console -from guardrails.cli.telemetry import trace_if_enabled from guardrails.cli.version import version_warnings_if_applicable @hub_command.command() +@trace(name="guardrails-cli/hub/install") def install( package_uris: List[str] = typer.Argument( ..., @@ -32,7 +33,6 @@ def install( ), ): try: - trace_if_enabled("hub/install") from guardrails.hub.install import install_multiple def confirm(): diff --git a/guardrails/cli/hub/list.py b/guardrails/cli/hub/list.py index c11cf788e..2dcf0cd3d 100644 --- a/guardrails/cli/hub/list.py +++ b/guardrails/cli/hub/list.py @@ -3,14 +3,14 @@ from guardrails.cli.hub.hub import hub_command from guardrails.cli.hub.utils import get_site_packages_location -from guardrails.cli.telemetry import trace_if_enabled +from guardrails.hub_telemetry.hub_tracing import trace from .console import console @hub_command.command(name="list") +@trace(name="guardrails-cli/hub/list") def list(): """List all installed validators.""" - trace_if_enabled("hub/list") site_packages = get_site_packages_location() hub_init_file = os.path.join(site_packages, "guardrails", "hub", "__init__.py") diff --git a/guardrails/cli/hub/submit.py b/guardrails/cli/hub/submit.py index 9fa48a632..20af32af3 100644 --- a/guardrails/cli/hub/submit.py +++ b/guardrails/cli/hub/submit.py @@ -8,10 +8,11 @@ from guardrails.cli.hub.hub import hub_command from guardrails.cli.logger import LEVELS, logger from guardrails.cli.server.hub_client import HttpError, post_validator_submit -from guardrails.cli.telemetry import trace_if_enabled +from guardrails.hub_telemetry.hub_tracing import trace @hub_command.command(name="submit") +@trace(name="guardrails-cli/hub/submit") def submit( package_name: str = typer.Argument(help="The package name for your validator."), filepath: str = typer.Argument( @@ -21,7 +22,6 @@ def submit( """Submit a validator to the Guardrails AI team for review and publishing.""" try: - trace_if_enabled("hub/submit") if not filepath or filepath == "./{package_name}.py": filepath = f"./{package_name}.py" diff --git a/guardrails/cli/hub/uninstall.py b/guardrails/cli/hub/uninstall.py index 15313ac00..bd3722701 100644 --- a/guardrails/cli/hub/uninstall.py +++ b/guardrails/cli/hub/uninstall.py @@ -13,7 +13,7 @@ from guardrails.cli.hub.utils import get_site_packages_location from guardrails.cli.hub.utils import get_org_and_package_dirs from guardrails.cli.hub.utils import get_hub_directory -from guardrails.cli.telemetry import trace_if_enabled +from guardrails.hub_telemetry.hub_tracing import trace from .console import console @@ -75,13 +75,13 @@ def uninstall_hub_module(manifest: Manifest, site_packages: str): @hub_command.command() +@trace(name="guardrails-cli/hub/uninstall") def uninstall( package_uri: str = typer.Argument( help="URI to the package to uninstall. Example: hub://guardrails/regex_match." 
), ): """Uninstall a validator from the Hub.""" - trace_if_enabled("hub/uninstall") if not package_uri.startswith("hub://"): logger.error("Invalid URI!") sys.exit(1) diff --git a/guardrails/cli/server/auth.py b/guardrails/cli/server/auth.py deleted file mode 100644 index 4394b2949..000000000 --- a/guardrails/cli/server/auth.py +++ /dev/null @@ -1,26 +0,0 @@ -# import http.client -# import json - -# from guardrails.classes.credentials import Credentials - - -# unused - for now -# def get_auth_token(creds: Credentials) -> str: -# if creds.client_id and creds.client_secret: -# audience = "https://validator-hub-service.guardrailsai.com" -# conn = http.client.HTTPSConnection("guardrailsai.us.auth0.com") -# payload = json.dumps( -# { -# "client_id": creds.client_id, -# "client_secret": creds.client_secret, -# "audience": audience, -# "grant_type": "client_credentials", -# } -# ) -# headers = {"content-type": "application/json"} -# conn.request("POST", "/oauth/token", payload, headers) - -# res = conn.getresponse() -# data = json.loads(res.read().decode("utf-8")) -# return data.get("access_token", "") -# return "" diff --git a/guardrails/cli/server/hub_client.py b/guardrails/cli/server/hub_client.py index 6e5950676..05ada5a5e 100644 --- a/guardrails/cli/server/hub_client.py +++ b/guardrails/cli/server/hub_client.py @@ -9,7 +9,8 @@ from jwt import ExpiredSignatureError, DecodeError -from guardrails.classes.credentials import Credentials +from guardrails.settings import settings +from guardrails.classes.rc import RC from guardrails.cli.logger import logger from guardrails.version import GUARDRAILS_VERSION @@ -86,8 +87,8 @@ def fetch_module_manifest( return fetch(manifest_url, token, anonymousUserId) -def get_jwt_token(creds: Credentials) -> Optional[str]: - token = creds.token +def get_jwt_token(rc: RC) -> Optional[str]: + token = rc.token # check for jwt expiration if token: @@ -101,23 +102,21 @@ def get_jwt_token(creds: Credentials) -> Optional[str]: def fetch_module(module_name: str) -> Optional[Manifest]: - creds = Credentials.from_rc_file(logger) - token = get_jwt_token(creds) + token = get_jwt_token(settings.rc) - module_manifest_json = fetch_module_manifest(module_name, token, creds.id) + module_manifest_json = fetch_module_manifest(module_name, token, settings.rc.id) return Manifest.from_dict(module_manifest_json) def fetch_template(template_address: str) -> Dict[str, Any]: - creds = Credentials.from_rc_file(logger) - token = get_jwt_token(creds) + token = get_jwt_token(settings.rc) namespace, template_name = template_address.replace("hub:template://", "").split( "/", 1 ) template_path = f"guard-templates/{namespace}/{template_name}" template_url = f"{VALIDATOR_HUB_SERVICE}/{template_path}" - return fetch(template_url, token, creds.id) + return fetch(template_url, token, settings.rc.id) # GET /guard-templates/{namespace}/{guardTemplateName} @@ -164,10 +163,9 @@ def get_validator_manifest(module_name: str): # GET /auth def get_auth(): try: - creds = Credentials.from_rc_file(logger) - token = get_jwt_token(creds) + token = get_jwt_token(settings.rc) auth_url = f"{VALIDATOR_HUB_SERVICE}/auth" - response = fetch(auth_url, token, creds.id) + response = fetch(auth_url, token, settings.rc.id) if not response: raise AuthenticationError("Failed to authenticate!") except HttpError as http_error: @@ -182,8 +180,7 @@ def get_auth(): def post_validator_submit(package_name: str, content: str): try: - creds = Credentials.from_rc_file(logger) - token = get_jwt_token(creds) + token = 
get_jwt_token(settings.rc) submission_url = f"{VALIDATOR_HUB_SERVICE}/validator/submit" headers = { diff --git a/guardrails/cli/telemetry.py b/guardrails/cli/telemetry.py index 0023d90d4..1d66adeaa 100644 --- a/guardrails/cli/telemetry.py +++ b/guardrails/cli/telemetry.py @@ -1,24 +1,13 @@ import platform -from typing import Optional -from guardrails.classes.credentials import Credentials +from guardrails.settings import settings from guardrails.utils.hub_telemetry_utils import HubTelemetry from guardrails.version import GUARDRAILS_VERSION -from guardrails.cli.logger import logger - -config: Optional[Credentials] = None - - -def load_config_file() -> Credentials: - global config - if not config: - config = Credentials.from_rc_file(logger) - return config def trace_if_enabled(command_name: str): - config = load_config_file() - if config.enable_metrics is True: + if settings.rc.enable_metrics is True: telemetry = HubTelemetry() + telemetry._enabled = True telemetry.create_new_span( f"guardrails-cli/{command_name}", [ @@ -30,6 +19,4 @@ def trace_if_enabled(command_name: str): ("machine", platform.machine()), ("processor", platform.processor()), ], - True, - False, ) diff --git a/guardrails/cli/validate.py b/guardrails/cli/validate.py index 6f8900200..958b4083a 100644 --- a/guardrails/cli/validate.py +++ b/guardrails/cli/validate.py @@ -5,7 +5,7 @@ from guardrails import Guard from guardrails.cli.guardrails import guardrails -from guardrails.cli.telemetry import trace_if_enabled +from guardrails.hub_telemetry.hub_tracing import trace def validate_llm_output(rail: str, llm_output: str) -> Union[str, Dict, List, None]: @@ -16,6 +16,7 @@ def validate_llm_output(rail: str, llm_output: str) -> Union[str, Dict, List, No @guardrails.command() +@trace(name="guardrails-cli/validate") def validate( rail: str = typer.Argument( ..., help="Path to the rail spec.", exists=True, file_okay=True, dir_okay=False @@ -29,7 +30,6 @@ def validate( ), ): """Validate the output of an LLM against a `rail` spec.""" - trace_if_enabled("validate") result = validate_llm_output(rail, llm_output) # Result is a dictionary, log it to a file print(result) diff --git a/guardrails/guard.py b/guardrails/guard.py index ac8198e5e..8f3fd4332 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -7,7 +7,7 @@ Callable, Dict, Generic, - Iterable, + Iterator, List, Optional, Sequence, @@ -32,9 +32,9 @@ from guardrails.api_client import GuardrailsApiClient from guardrails.classes.output_type import OT +from guardrails.classes.rc import RC from guardrails.classes.validation.validation_result import ErrorSpan from guardrails.classes.validation_outcome import ValidationOutcome -from guardrails.classes.credentials import Credentials from guardrails.classes.execution import GuardExecutionOptions from guardrails.classes.generic import Stack from guardrails.classes.history import Call @@ -64,6 +64,7 @@ set_tracer, set_tracer_context, ) +from guardrails.hub_telemetry.hub_tracing import trace from guardrails.types.on_fail import OnFailAction from guardrails.types.pydantic import ModelOrListOfModels from guardrails.utils.naming_utils import random_id @@ -259,6 +260,7 @@ def configure( self._set_num_reasks(num_reasks) if tracer: self._set_tracer(tracer) + self._load_rc() self._configure_hub_telemtry(allow_metrics_collection) def _set_num_reasks(self, num_reasks: Optional[int] = None) -> None: @@ -285,24 +287,28 @@ def _set_tracer(self, tracer: Optional[Tracer] = None) -> None: set_tracer_context() self._tracer_context = get_tracer_context() 
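+    # Read ~/.guardrailsrc via RC.load and cache it on the shared settings
+    # object, so telemetry and hub clients consume one source of configuration.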
+ def _load_rc(self) -> None: + rc = RC.load(logger) + settings.rc = rc + def _configure_hub_telemtry( self, allow_metrics_collection: Optional[bool] = None ) -> None: - credentials = None - if allow_metrics_collection is None: - credentials = Credentials.from_rc_file(logger) - # TODO: Check credentials.enable_metrics after merge from main - allow_metrics_collection = credentials.enable_metrics is True + allow_metrics_collection = ( + settings.rc.enable_metrics is True + if allow_metrics_collection is None + else allow_metrics_collection + ) self._allow_metrics_collection = allow_metrics_collection - if allow_metrics_collection: - if not credentials: - credentials = Credentials.from_rc_file(logger) - # Get unique id of user from credentials - self._user_id = credentials.id or "" - # Initialize Hub Telemetry singleton and get the tracer - self._hub_telemetry = HubTelemetry() + # Initialize Hub Telemetry singleton and get the tracer + self._hub_telemetry = HubTelemetry() + self._hub_telemetry._enabled = allow_metrics_collection + + if allow_metrics_collection is True: + # Get unique id of user from rc file + self._user_id = settings.rc.id or "" def _fill_validator_map(self): # dont init validators if were going to call the server @@ -692,7 +698,7 @@ def _execute( metadata: Optional[Dict], full_schema_reask: Optional[bool] = None, **kwargs, - ) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]: + ) -> Union[ValidationOutcome[OT], Iterator[ValidationOutcome[OT]]]: self._fill_validator_map() self._fill_validators() self._fill_exec_opts( @@ -735,42 +741,6 @@ def __exec( if full_schema_reask is None: full_schema_reask = self._base_model is not None - if self._allow_metrics_collection and self._hub_telemetry: - # Create a new span for this guard call - llm_api_str = "" - if llm_api: - llm_api_module_name = ( - llm_api.__module__ if hasattr(llm_api, "__module__") else "" - ) - llm_api_name = ( - llm_api.__name__ - if hasattr(llm_api, "__name__") - else type(llm_api).__name__ - ) - llm_api_str = f"{llm_api_module_name}.{llm_api_name}" - self._hub_telemetry.create_new_span( - span_name="/guard_call", - attributes=[ - ("guard_id", self.id), - ("user_id", self._user_id), - ("llm_api", llm_api_str if llm_api_str else "None"), - ( - "custom_reask_prompt", - self._exec_opts.reask_prompt is not None, - ), - ( - "custom_reask_instructions", - self._exec_opts.reask_instructions is not None, - ), - ( - "custom_reask_messages", - self._exec_opts.reask_messages is not None, - ), - ], - is_parent=True, # It will have children - has_parent=False, # Has no parents - ) - set_call_kwargs(kwargs) set_tracer(self._tracer) set_tracer_context(self._tracer_context) @@ -869,7 +839,7 @@ def _exec( instructions: Optional[str] = None, msg_history: Optional[List[Dict]] = None, **kwargs, - ) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]: + ) -> Union[ValidationOutcome[OT], Iterator[ValidationOutcome[OT]]]: api = None if llm_api is not None or kwargs.get("model") is not None: @@ -920,6 +890,7 @@ def _exec( call = runner(call_log=call_log, prompt_params=prompt_params) return ValidationOutcome[OT].from_guard_history(call) + @trace(name="/guard_call", origin="Guard.__call__") def __call__( self, llm_api: Optional[Callable] = None, @@ -932,7 +903,7 @@ def __call__( metadata: Optional[Dict] = None, full_schema_reask: Optional[bool] = None, **kwargs, - ) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]: + ) -> Union[ValidationOutcome[OT], Iterator[ValidationOutcome[OT]]]: """Call the 
LLM and validate the output. Args: @@ -978,6 +949,7 @@ def __call__( **kwargs, ) + @trace(name="/guard_call", origin="Guard.parse") def parse( self, llm_output: str, @@ -1154,6 +1126,7 @@ def use_many( self._save() return self + @trace(name="/guard_call", origin="Guard.validate") def validate(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[OT]: return self.parse(llm_output=llm_output, *args, **kwargs) @@ -1216,7 +1189,7 @@ def _stream_server_call( self, *, payload: Dict[str, Any], - ) -> Iterable[ValidationOutcome[OT]]: + ) -> Iterator[ValidationOutcome[OT]]: if settings.use_server and self._api_client: validation_output: Optional[IValidationOutcome] = None response = self._api_client.stream_validate( @@ -1266,7 +1239,7 @@ def _call_server( metadata: Optional[Dict] = {}, full_schema_reask: Optional[bool] = True, **kwargs, - ) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]: + ) -> Union[ValidationOutcome[OT], Iterator[ValidationOutcome[OT]]]: if settings.use_server and self._api_client: payload: Dict[str, Any] = { "args": list(args), diff --git a/guardrails/hub/install.py b/guardrails/hub/install.py index bb104ae9d..046ee073f 100644 --- a/guardrails/hub/install.py +++ b/guardrails/hub/install.py @@ -6,7 +6,7 @@ ValidatorPackageService, ValidatorModuleType, ) -from guardrails.classes.credentials import Credentials +from guardrails.classes.rc import RC from guardrails.cli.hub.console import console from guardrails.cli.logger import LEVELS, logger as cli_logger @@ -61,7 +61,7 @@ def install( quiet_printer = console.print if not quiet else lambda x: None # 1. Validation - has_rc_file = Credentials.has_rc_file() + rc_file_exists = RC.exists() module_name = ValidatorPackageService.get_module_name(package_uri) installing_msg = f"Installing {package_uri}..." 
@@ -98,11 +98,10 @@ def install( ) try: - if has_rc_file: + if rc_file_exists: # if we do want to remote then we don't want to install local models use_remote_endpoint = ( - Credentials.from_rc_file(cli_logger).use_remote_inferencing - and module_has_endpoint + RC.load(cli_logger).use_remote_inferencing and module_has_endpoint ) elif install_local_models is None and module_has_endpoint: install_local_models = install_local_models_confirm() diff --git a/guardrails/hub_telemetry/__init__.py b/guardrails/hub_telemetry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/guardrails/hub_telemetry/hub_tracing.py b/guardrails/hub_telemetry/hub_tracing.py new file mode 100644 index 000000000..a02b1b688 --- /dev/null +++ b/guardrails/hub_telemetry/hub_tracing.py @@ -0,0 +1,261 @@ +from functools import wraps +from typing import ( + Any, + Dict, + Optional, +) + +from opentelemetry.trace import Span +from opentelemetry.trace.propagation import set_span_in_context + +from guardrails.classes.validation.validation_result import ValidationResult +from guardrails.hub_token.token import VALIDATOR_HUB_SERVICE +from guardrails.types.primitives import PrimitiveTypes +from guardrails.utils.safe_get import safe_get +from guardrails.utils.hub_telemetry_utils import HubTelemetry + + +def get_guard_call_attributes( + attrs: Dict[str, Any], origin: str, *args, **kwargs +) -> Dict[str, Any]: + attrs["stream"] = kwargs.get("stream", False) + + guard_self = safe_get(args, 0) + if guard_self is not None: + attrs["guard_id"] = guard_self.id + attrs["user_id"] = guard_self._user_id + attrs["custom_reask_prompt"] = guard_self._exec_opts.reask_prompt is not None + attrs["custom_reask_instructions"] = ( + guard_self._exec_opts.reask_instructions is not None + ) + attrs["custom_reask_messages"] = ( + guard_self._exec_opts.reask_messages is not None + ) + attrs["output_type"] = ( + "unstructured" + if PrimitiveTypes.is_primitive( + guard_self.output_schema.type.actual_instance + ) + else "structured" + ) + return attrs + + llm_api_str = "" # noqa + llm_api = kwargs.get("llm_api") + if origin in ["Guard.__call__", "AsyncGuard.__call__"]: + llm_api = safe_get(args, 1, llm_api) + + if llm_api: + llm_api_module_name = ( + llm_api.__module__ if hasattr(llm_api, "__module__") else "" + ) + llm_api_name = ( + llm_api.__name__ if hasattr(llm_api, "__name__") else type(llm_api).__name__ + ) + llm_api_str = f"{llm_api_module_name}.{llm_api_name}" + attrs["llm_api"] = llm_api_str if llm_api_str else "None" + + return attrs + + +def get_validator_inference_attributes( + attrs: Dict[str, Any], *args, **kwargs +) -> Dict[str, Any]: + validator_self = safe_get(args, 0) + if validator_self is not None: + used_guardrails_endpoint = ( + VALIDATOR_HUB_SERVICE in validator_self.validation_endpoint + and not validator_self.use_local + ) + used_custom_endpoint = ( + not validator_self.use_local and not used_guardrails_endpoint + ) + attrs["validator_name"] = validator_self.rail_alias + attrs["used_remote_inference"] = not validator_self.use_local + attrs["used_local_inference"] = validator_self.use_local + attrs["used_guardrails_endpoint"] = used_guardrails_endpoint + attrs["used_custom_endpoint"] = used_custom_endpoint + return attrs + + +def get_validator_usage_attributes( + attrs: Dict[str, Any], response, *args, **kwargs +) -> Dict[str, Any]: + # We're wrapping a wrapped function, + # so the first arg is the validator service + validator_self = safe_get(args, 1) + if validator_self is not None: + attrs["validator_name"] = 
validator_self.rail_alias + attrs["validator_on_fail"] = validator_self.on_fail_descriptor + + if response is not None: + attrs["validator_result"] = ( + response.outcome if isinstance(response, ValidationResult) else None + ) + + return attrs + + +def add_attributes( + span: Span, + attrs: Dict[str, Any], + name: str, + origin: str, + *args, + response=None, + **kwargs, +): + attrs["origin"] = origin + if name == "/guard_call": + attrs = get_guard_call_attributes(attrs, origin, *args, **kwargs) + elif name == "/reasks": + if response is not None and hasattr(response, "iterations"): + attrs["reask_count"] = len(response.iterations) - 1 + else: + attrs["reask_count"] = 0 + elif name == "/validator_inference": + attrs = get_validator_inference_attributes(attrs, *args, **kwargs) + elif name == "/validator_usage": + attrs = get_validator_usage_attributes(attrs, response, *args, **kwargs) + + for key, value in attrs.items(): + if value is not None: + span.set_attribute(key, value) + + +def trace( + *, + name: str, + origin: Optional[str] = None, + **attrs, +): + def decorator(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + hub_telemetry = HubTelemetry() + if hub_telemetry._enabled and hub_telemetry._tracer is not None: + with hub_telemetry._tracer.start_span( + name, + context=hub_telemetry.extract_current_context(), + set_status_on_exception=True, + ) as span: # noqa + context = set_span_in_context(span) + hub_telemetry.inject_current_context(context=context) + nonlocal origin + origin = origin if origin is not None else name + + resp = fn(*args, **kwargs) + add_attributes( + span, attrs, name, origin, *args, response=resp, **kwargs + ) + return resp + else: + return fn(*args, **kwargs) + + return wrapper + + return decorator + + +def async_trace( + *, + name: str, + origin: Optional[str] = None, +): + def decorator(fn): + @wraps(fn) + async def async_wrapper(*args, **kwargs): + hub_telemetry = HubTelemetry() + if hub_telemetry._enabled and hub_telemetry._tracer is not None: + with hub_telemetry._tracer.start_span( + name, + context=hub_telemetry.extract_current_context(), + set_status_on_exception=True, + ) as span: # noqa + context = set_span_in_context(span) + hub_telemetry.inject_current_context(context=context) + + nonlocal origin + origin = origin if origin is not None else name + add_attributes(span, {"async": True}, name, origin, *args, **kwargs) + return await fn(*args, **kwargs) + else: + return await fn(*args, **kwargs) + + return async_wrapper + + return decorator + + +def _run_gen(fn, *args, **kwargs): + gen = fn(*args, **kwargs) + for item in gen: + yield item + + +def trace_stream( + *, + name: str, + origin: Optional[str] = None, + **attrs, +): + def decorator(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + hub_telemetry = HubTelemetry() + if hub_telemetry._enabled and hub_telemetry._tracer is not None: + with hub_telemetry._tracer.start_span( + name, + context=hub_telemetry.extract_current_context(), + set_status_on_exception=True, + ) as span: # noqa + context = set_span_in_context(span) + hub_telemetry.inject_current_context(context=context) + + nonlocal origin + origin = origin if origin is not None else name + add_attributes(span, attrs, name, origin, *args, **kwargs) + return _run_gen(fn, *args, **kwargs) + else: + return fn(*args, **kwargs) + + return wrapper + + return decorator + + +async def _run_async_gen(fn, *args, **kwargs): + gen = fn(*args, **kwargs) + async for item in gen: + yield item + + +def async_trace_stream( + *, + name: str, + origin: 
Optional[str] = None, + **attrs, +): + def decorator(fn): + @wraps(fn) + async def wrapper(*args, **kwargs): + hub_telemetry = HubTelemetry() + if hub_telemetry._enabled and hub_telemetry._tracer is not None: + with hub_telemetry._tracer.start_span( + name, + context=hub_telemetry.extract_current_context(), + set_status_on_exception=True, + ) as span: # noqa + context = set_span_in_context(span) + hub_telemetry.inject_current_context(context=context) + + nonlocal origin + origin = origin if origin is not None else name + add_attributes(span, attrs, name, origin, *args, **kwargs) + return _run_async_gen(fn, *args, **kwargs) + else: + return fn(*args, **kwargs) + + return wrapper + + return decorator diff --git a/guardrails/hub_token/token.py b/guardrails/hub_token/token.py index 2b413e5d8..b447804d9 100644 --- a/guardrails/hub_token/token.py +++ b/guardrails/hub_token/token.py @@ -1,9 +1,10 @@ import os -from guardrails.classes.credentials import Credentials import jwt from jwt import ExpiredSignatureError, DecodeError from typing import Optional +from guardrails.classes.rc import RC + FIND_NEW_TOKEN = "You can find a new token at https://hub.guardrailsai.com/keys" TOKEN_EXPIRED_MESSAGE = f"""Your token has expired. Please run `guardrails configure`\ @@ -36,8 +37,8 @@ class HttpError(Exception): ) -def get_jwt_token(creds: Credentials) -> Optional[str]: - token = creds.token +def get_jwt_token(rc: RC) -> Optional[str]: + token = rc.token # check for jwt expiration if token: diff --git a/guardrails/integrations/databricks/ml_flow_instrumentor.py b/guardrails/integrations/databricks/ml_flow_instrumentor.py index 38d458f46..8bed264ba 100644 --- a/guardrails/integrations/databricks/ml_flow_instrumentor.py +++ b/guardrails/integrations/databricks/ml_flow_instrumentor.py @@ -3,12 +3,11 @@ import sys from typing import ( Any, - AsyncIterable, + AsyncIterator, Awaitable, Callable, Coroutine, - Generator, - Iterable, + Iterator, Union, ) @@ -107,13 +106,13 @@ def instrument(self): def _instrument_guard( self, guard_execute: Callable[ - ..., Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]] + ..., Union[ValidationOutcome[OT], Iterator[ValidationOutcome[OT]]] ], ): @wraps(guard_execute) def _guard_execute_wrapper( *args, **kwargs - ) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]: + ) -> Union[ValidationOutcome[OT], Iterator[ValidationOutcome[OT]]]: with mlflow.start_span( name="guardrails/guard", span_type="guard", @@ -131,7 +130,7 @@ def _guard_execute_wrapper( try: result = guard_execute(*args, **kwargs) - if isinstance(result, Iterable) and not isinstance( + if isinstance(result, Iterator) and not isinstance( result, ValidationOutcome ): return trace_stream_guard(guard_span, result, history) # type: ignore @@ -153,7 +152,7 @@ def _instrument_async_guard( Union[ ValidationOutcome[OT], Awaitable[ValidationOutcome[OT]], - AsyncIterable[ValidationOutcome[OT]], + AsyncIterator[ValidationOutcome[OT]], ], ], ], @@ -164,7 +163,7 @@ async def _async_guard_execute_wrapper( ) -> Union[ ValidationOutcome[OT], Awaitable[ValidationOutcome[OT]], - AsyncIterable[ValidationOutcome[OT]], + AsyncIterator[ValidationOutcome[OT]], ]: with mlflow.start_span( name="guardrails/guard", @@ -184,7 +183,7 @@ async def _async_guard_execute_wrapper( try: result = await guard_execute(*args, **kwargs) - if isinstance(result, AsyncIterable): + if isinstance(result, AsyncIterator): return trace_async_stream_guard(guard_span, result, history) # type: ignore res = result if inspect.isawaitable(result): 
@@ -220,12 +219,12 @@ def trace_step_wrapper(*args, **kwargs) -> Iteration: return trace_step_wrapper def _instrument_stream_runner_step( - self, runner_step: Callable[..., Generator[ValidationOutcome[OT], None, None]] + self, runner_step: Callable[..., Iterator[ValidationOutcome[OT]]] ): @wraps(runner_step) def trace_stream_step_wrapper( *args, **kwargs - ) -> Generator[ValidationOutcome[OT], None, None]: + ) -> Iterator[ValidationOutcome[OT]]: with mlflow.start_span( name="guardrails/guard/step", span_type="step", @@ -283,12 +282,12 @@ async def trace_async_step_wrapper(*args, **kwargs) -> Iteration: return trace_async_step_wrapper def _instrument_async_stream_runner_step( - self, runner_step: Callable[..., AsyncIterable[ValidationOutcome[OT]]] - ) -> Callable[..., AsyncIterable[ValidationOutcome[OT]]]: + self, runner_step: Callable[..., AsyncIterator[ValidationOutcome[OT]]] + ) -> Callable[..., AsyncIterator[ValidationOutcome[OT]]]: @wraps(runner_step) async def trace_async_stream_step_wrapper( *args, **kwargs - ) -> AsyncIterable[ValidationOutcome[OT]]: + ) -> AsyncIterator[ValidationOutcome[OT]]: with mlflow.start_span( name="guardrails/guard/step", span_type="step", diff --git a/guardrails/llm_providers.py b/guardrails/llm_providers.py index b48bf031f..be71a62ec 100644 --- a/guardrails/llm_providers.py +++ b/guardrails/llm_providers.py @@ -6,7 +6,7 @@ Awaitable, Callable, Dict, - Iterable, + Iterator, List, Optional, Type, @@ -506,7 +506,7 @@ def _invoke_llm( if kwargs.get("stream", False): # If stream is defined and set to True, # the callable returns a generator object - llm_response = cast(Iterable[str], response) + llm_response = cast(Iterator[str], response) return LLMResponse( output="", stream_output=llm_response, @@ -774,7 +774,7 @@ def _invoke_llm(self, *args, **kwargs) -> LLMResponse: if kwargs.get("stream", False): # If stream is defined and set to True, # the callable returns a generator object - llm_response = cast(Iterable[str], llm_response) + llm_response = cast(Iterator[str], llm_response) return LLMResponse( output="", stream_output=llm_response, @@ -1101,7 +1101,7 @@ async def invoke_llm( if kwargs.get("stream", False): # If stream is defined and set to True, # the callable returns a generator object - # response = cast(AsyncIterable[str], response) + # response = cast(AsyncIterator[str], response) return LLMResponse( output="", async_stream_output=response.completion_stream, # pyright: ignore[reportGeneralTypeIssues] diff --git a/guardrails/remote_inference/remote_inference.py b/guardrails/remote_inference/remote_inference.py index 3303ef63d..78e911e85 100644 --- a/guardrails/remote_inference/remote_inference.py +++ b/guardrails/remote_inference/remote_inference.py @@ -1,19 +1,19 @@ from typing import Optional -from guardrails.classes.credentials import Credentials +from guardrails.classes.rc import RC # TODO: Consolidate with telemetry switches -def get_use_remote_inference(creds: Credentials) -> Optional[bool]: - """Load the use_remote_inferencing setting from the credentials. +def get_use_remote_inference(rc: RC) -> Optional[bool]: + """Load the use_remote_inferencing setting from the rc file. Args: - creds (Credentials): The credentials object. + rc (RC): The rc settings. Returns: Optional[bool]: The use_remote_inferencing setting, or None if not found. 
""" try: - use_remote_inferencing = creds.use_remote_inferencing + use_remote_inferencing = rc.use_remote_inferencing if isinstance(use_remote_inferencing, str): return use_remote_inferencing.lower() == "true" elif isinstance(use_remote_inferencing, bool): diff --git a/guardrails/run/async_runner.py b/guardrails/run/async_runner.py index e71d68b00..420b2be37 100644 --- a/guardrails/run/async_runner.py +++ b/guardrails/run/async_runner.py @@ -1,6 +1,6 @@ import copy from functools import partial -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Tuple, cast from guardrails import validator_service @@ -9,12 +9,14 @@ from guardrails.classes.output_type import OutputTypes from guardrails.constants import fail_status from guardrails.errors import ValidationError -from guardrails.llm_providers import AsyncPromptCallableBase, PromptCallableBase +from guardrails.llm_providers import AsyncPromptCallableBase from guardrails.logger import set_scope from guardrails.prompt import Instructions, Prompt from guardrails.run.runner import Runner from guardrails.run.utils import msg_history_source, msg_history_string from guardrails.schema.validator import schema_validation +from guardrails.hub_telemetry.hub_tracing import async_trace +from guardrails.types.inputs import MessageHistory from guardrails.types.pydantic import ModelOrListOfModels from guardrails.types.validator import ValidatorMap from guardrails.utils.exception_utils import UserFacingException @@ -60,10 +62,11 @@ def __init__( disable_tracer=disable_tracer, exec_options=exec_options, ) - self.api: Optional[AsyncPromptCallableBase] = api + self.api = api # TODO: Refactor this to use inheritance and overrides # Why are we using a different method here instead of just overriding? 
+ @async_trace(name="/reasks", origin="AsyncRunner.async_run") async def async_run( self, call_log: Call, prompt_params: Optional[Dict] = None ) -> Call: @@ -129,15 +132,6 @@ async def async_run( include_instructions=include_instructions, ) - # Log how many times we reasked - # Use the HubTelemetry singleton - if not self._disable_tracer: - self._hub_telemetry.create_new_span( - span_name="/reasks", - attributes=[("reask_count", index)], - is_parent=False, # This span has no children - has_parent=True, # This span has a parent - ) except UserFacingException as e: # Because Pydantic v1 doesn't respect property setters call_log.exception = e.original_exception @@ -150,6 +144,7 @@ async def async_run( return call_log # TODO: Refactor this to use inheritance and overrides + @async_trace(name="/step", origin="AsyncRunner.async_step") @trace_async_step async def async_step( self, @@ -247,6 +242,7 @@ async def async_step( return iteration # TODO: Refactor this to use inheritance and overrides + @async_trace(name="/llm_call", origin="AsyncRunner.async_call") @trace_async_call async def async_call( self, @@ -285,6 +281,7 @@ async def async_call( return llm_response # TODO: Refactor this to use inheritance and overrides + @async_trace(name="/validation", origin="AsyncRunner.async_validate") async def async_validate( self, iteration: Iteration, @@ -323,6 +320,7 @@ async def async_validate( return validated_output # TODO: Refactor this to use inheritance and overrides + @async_trace(name="/input_prep", origin="AsyncRunner.async_prepare") async def async_prepare( self, call_log: Call, @@ -332,7 +330,7 @@ async def async_prepare( prompt: Optional[Prompt], msg_history: Optional[List[Dict]], prompt_params: Optional[Dict] = None, - api: Optional[Union[PromptCallableBase, AsyncPromptCallableBase]], + api: Optional[AsyncPromptCallableBase], ) -> Tuple[Optional[Instructions], Optional[Prompt], Optional[List[Dict]]]: """Prepare by running pre-processing and input validation. @@ -340,6 +338,8 @@ async def async_prepare( The instructions, prompt, and message history. """ prompt_params = prompt_params or {} + if api is None: + raise UserFacingException(ValueError("API must be provided.")) has_prompt_validation = "prompt" in self.validation_map has_instructions_validation = "instructions" in self.validation_map @@ -356,45 +356,12 @@ async def async_prepare( prompt, instructions = None, None # Runner.prepare_msg_history - formatted_msg_history = [] - - # Format any variables in the message history with the prompt params. 
- for msg in msg_history: - msg_copy = copy.deepcopy(msg) - msg_copy["content"] = msg_copy["content"].format(**prompt_params) - formatted_msg_history.append(msg_copy) - - if "msg_history" in self.validation_map: - # Runner.validate_msg_history - msg_str = msg_history_string(formatted_msg_history) - inputs = Inputs( - llm_output=msg_str, - ) - iteration = Iteration( - call_id=call_log.id, index=attempt_number, inputs=inputs - ) - call_log.iterations.insert(0, iteration) - value, _metadata = await validator_service.async_validate( - value=msg_str, - metadata=self.metadata, - validator_map=self.validation_map, - iteration=iteration, - disable_tracer=self._disable_tracer, - path="msg_history", - ) - validated_msg_history = validator_service.post_process_validation( - value, attempt_number, iteration, OutputTypes.STRING - ) - validated_msg_history = cast(str, validated_msg_history) - - iteration.outputs.validation_response = validated_msg_history - if isinstance(validated_msg_history, ReAsk): - raise ValidationError( - f"Message history validation failed: " - f"{validated_msg_history}" - ) - if validated_msg_history != msg_str: - raise ValidationError("Message history validation failed") + msg_history = await self.prepare_msg_history( + call_log=call_log, + msg_history=msg_history, + prompt_params=prompt_params, + attempt_number=attempt_number, + ) elif prompt is not None: if has_msg_history_validation: raise UserFacingException( @@ -405,87 +372,164 @@ async def async_prepare( ) msg_history = None - use_xml = prompt_uses_xml(prompt._source) - # Runner.prepare_prompt - prompt = prompt.format(**prompt_params) - - # TODO(shreya): should there be any difference - # to parsing params for prompt? - if instructions is not None and isinstance(instructions, Instructions): - instructions = instructions.format(**prompt_params) - - instructions, prompt = preprocess_prompt( - prompt_callable=api, # type: ignore - instructions=instructions, - prompt=prompt, - output_type=self.output_type, - use_xml=use_xml, + instructions, prompt = await self.prepare_prompt( + call_log, instructions, prompt, prompt_params, api, attempt_number ) - # validate prompt - if "prompt" in self.validation_map and prompt is not None: - # Runner.validate_prompt - inputs = Inputs( - llm_output=prompt.source, - ) - iteration = Iteration( - call_id=call_log.id, index=attempt_number, inputs=inputs - ) - call_log.iterations.insert(0, iteration) - value, _metadata = await validator_service.async_validate( - value=prompt.source, - metadata=self.metadata, - validator_map=self.validation_map, - iteration=iteration, - disable_tracer=self._disable_tracer, - path="prompt", - ) - validated_prompt = validator_service.post_process_validation( - value, attempt_number, iteration, OutputTypes.STRING - ) - - iteration.outputs.validation_response = validated_prompt - if isinstance(validated_prompt, ReAsk): - raise ValidationError( - f"Prompt validation failed: {validated_prompt}" - ) - elif not validated_prompt or iteration.status == fail_status: - raise ValidationError("Prompt validation failed") - prompt = Prompt(cast(str, validated_prompt)) - - # validate instructions - if "instructions" in self.validation_map and instructions is not None: - # Runner.validate_instructions - inputs = Inputs( - llm_output=instructions.source, - ) - iteration = Iteration( - call_id=call_log.id, index=attempt_number, inputs=inputs - ) - call_log.iterations.insert(0, iteration) - value, _metadata = await validator_service.async_validate( - value=instructions.source, - 
metadata=self.metadata, - validator_map=self.validation_map, - iteration=iteration, - disable_tracer=self._disable_tracer, - path="instructions", - ) - validated_instructions = validator_service.post_process_validation( - value, attempt_number, iteration, OutputTypes.STRING - ) - - iteration.outputs.validation_response = validated_instructions - if isinstance(validated_instructions, ReAsk): - raise ValidationError( - f"Instructions validation failed: {validated_instructions}" - ) - elif not validated_instructions or iteration.status == fail_status: - raise ValidationError("Instructions validation failed") - instructions = Instructions(cast(str, validated_instructions)) else: raise UserFacingException( ValueError("'prompt' or 'msg_history' must be provided.") ) return instructions, prompt, msg_history + + async def prepare_msg_history( + self, + call_log: Call, + msg_history: MessageHistory, + prompt_params: Dict, + attempt_number: int, + ) -> MessageHistory: + formatted_msg_history = [] + + # Format any variables in the message history with the prompt params. + for msg in msg_history: + msg_copy = copy.deepcopy(msg) + msg_copy["content"] = msg_copy["content"].format(**prompt_params) + formatted_msg_history.append(msg_copy) + + if "msg_history" in self.validation_map: + await self.validate_msg_history( + call_log, formatted_msg_history, attempt_number + ) + + return formatted_msg_history + + @async_trace(name="/input_validation", origin="AsyncRunner.validate_msg_history") + async def validate_msg_history( + self, call_log: Call, msg_history: MessageHistory, attempt_number: int + ): + msg_str = msg_history_string(msg_history) + inputs = Inputs( + llm_output=msg_str, + ) + iteration = Iteration(call_id=call_log.id, index=attempt_number, inputs=inputs) + call_log.iterations.insert(0, iteration) + value, _metadata = await validator_service.async_validate( + value=msg_str, + metadata=self.metadata, + validator_map=self.validation_map, + iteration=iteration, + disable_tracer=self._disable_tracer, + path="msg_history", + ) + validated_msg_history = validator_service.post_process_validation( + value, attempt_number, iteration, OutputTypes.STRING + ) + validated_msg_history = cast(str, validated_msg_history) + + iteration.outputs.validation_response = validated_msg_history + if isinstance(validated_msg_history, ReAsk): + raise ValidationError( + f"Message history validation failed: " f"{validated_msg_history}" + ) + if validated_msg_history != msg_str: + raise ValidationError("Message history validation failed") + + async def prepare_prompt( + self, + call_log: Call, + instructions: Optional[Instructions], + prompt: Prompt, + prompt_params: Dict, + api: AsyncPromptCallableBase, + attempt_number: int, + ): + use_xml = prompt_uses_xml(prompt._source) + prompt = prompt.format(**prompt_params) + + # TODO(shreya): should there be any difference + # to parsing params for prompt? 
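+        # Interpolate the same prompt params into the instructions.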
+ if instructions is not None and isinstance(instructions, Instructions): + instructions = instructions.format(**prompt_params) + + instructions, prompt = preprocess_prompt( + prompt_callable=api, # type: ignore + instructions=instructions, + prompt=prompt, + output_type=self.output_type, + use_xml=use_xml, + ) + + # validate prompt + if "prompt" in self.validation_map and prompt is not None: + prompt = await self.validate_prompt(call_log, prompt, attempt_number) + + # validate instructions + if "instructions" in self.validation_map and instructions is not None: + instructions = await self.validate_instructions( + call_log, instructions, attempt_number + ) + + return instructions, prompt + + @async_trace(name="/input_validation", origin="AsyncRunner.validate_prompt") + async def validate_prompt( + self, + call_log: Call, + prompt: Prompt, + attempt_number: int, + ): + inputs = Inputs( + llm_output=prompt.source, + ) + iteration = Iteration(call_id=call_log.id, index=attempt_number, inputs=inputs) + call_log.iterations.insert(0, iteration) + value, _metadata = await validator_service.async_validate( + value=prompt.source, + metadata=self.metadata, + validator_map=self.validation_map, + iteration=iteration, + disable_tracer=self._disable_tracer, + path="prompt", + ) + validated_prompt = validator_service.post_process_validation( + value, attempt_number, iteration, OutputTypes.STRING + ) + + iteration.outputs.validation_response = validated_prompt + if isinstance(validated_prompt, ReAsk): + raise ValidationError(f"Prompt validation failed: {validated_prompt}") + elif not validated_prompt or iteration.status == fail_status: + raise ValidationError("Prompt validation failed") + return Prompt(cast(str, validated_prompt)) + + @async_trace(name="/input_validation", origin="AsyncRunner.validate_instructions") + async def validate_instructions( + self, call_log: Call, instructions: Instructions, attempt_number: int + ): + inputs = Inputs( + llm_output=instructions.source, + ) + iteration = Iteration(call_id=call_log.id, index=attempt_number, inputs=inputs) + call_log.iterations.insert(0, iteration) + value, _metadata = await validator_service.async_validate( + value=instructions.source, + metadata=self.metadata, + validator_map=self.validation_map, + iteration=iteration, + disable_tracer=self._disable_tracer, + path="instructions", + ) + validated_instructions = validator_service.post_process_validation( + value, attempt_number, iteration, OutputTypes.STRING + ) + + iteration.outputs.validation_response = validated_instructions + if isinstance(validated_instructions, ReAsk): + raise ValidationError( + f"Instructions validation failed: {validated_instructions}" + ) + elif not validated_instructions or iteration.status == fail_status: + raise ValidationError("Instructions validation failed") + return Instructions(cast(str, validated_instructions)) diff --git a/guardrails/run/async_stream_runner.py b/guardrails/run/async_stream_runner.py index 5285a9634..8f39c21f2 100644 --- a/guardrails/run/async_stream_runner.py +++ b/guardrails/run/async_stream_runner.py @@ -1,6 +1,6 @@ from typing import ( Any, - AsyncIterable, + AsyncIterator, Dict, List, Optional, @@ -27,12 +27,14 @@ from guardrails.run import StreamRunner from guardrails.run.async_runner import AsyncRunner from guardrails.telemetry import trace_async_stream_step +from guardrails.hub_telemetry.hub_tracing import async_trace_stream class AsyncStreamRunner(AsyncRunner, StreamRunner): + # @async_trace_stream(name="/reasks", 
origin="AsyncStreamRunner.async_run") async def async_run( self, call_log: Call, prompt_params: Optional[Dict] = None - ) -> AsyncIterable[ValidationOutcome]: + ) -> AsyncIterator[ValidationOutcome]: prompt_params = prompt_params or {} ( @@ -63,6 +65,7 @@ async def async_run( async for call in result: yield call + @async_trace_stream(name="/step", origin="AsyncStreamRunner.async_step") @trace_async_stream_step async def async_step( self, @@ -76,7 +79,7 @@ async def async_step( msg_history: Optional[List[Dict]] = None, prompt_params: Optional[Dict] = None, output: Optional[str] = None, - ) -> AsyncIterable[ValidationOutcome]: + ) -> AsyncIterator[ValidationOutcome]: prompt_params = prompt_params or {} inputs = Inputs( llm_api=api, diff --git a/guardrails/run/runner.py b/guardrails/run/runner.py index b62bb64ca..fe80f7941 100644 --- a/guardrails/run/runner.py +++ b/guardrails/run/runner.py @@ -19,6 +19,7 @@ from guardrails.run.utils import msg_history_source, msg_history_string from guardrails.schema.rail_schema import json_schema_to_rail_output from guardrails.schema.validator import schema_validation +from guardrails.hub_telemetry.hub_tracing import trace from guardrails.types import ModelOrListOfModels, ValidatorMap, MessageHistory from guardrails.utils.exception_utils import UserFacingException from guardrails.utils.hub_telemetry_utils import HubTelemetry @@ -154,10 +155,11 @@ def __init__( # Get metrics opt-out from credentials self._disable_tracer = disable_tracer - if not self._disable_tracer: - # Get the HubTelemetry singleton - self._hub_telemetry = HubTelemetry() + # Get the HubTelemetry singleton + self._hub_telemetry = HubTelemetry() + self._hub_telemetry._enabled = not self._disable_tracer + @trace(name="/reasks", origin="Runner.__call__") def __call__(self, call_log: Call, prompt_params: Optional[Dict] = None) -> Call: """Execute the runner by repeatedly calling step until the reask budget is exhausted. 
@@ -222,16 +224,6 @@ def __call__(self, call_log: Call, prompt_params: Optional[Dict] = None) -> Call ) ) - # Log how many times we reasked - # Use the HubTelemetry singleton - if not self._disable_tracer: - self._hub_telemetry.create_new_span( - span_name="/reasks", - attributes=[("reask_count", index)], - is_parent=False, # This span has no children - has_parent=True, # This span has a parent - ) - except UserFacingException as e: # Because Pydantic v1 doesn't respect property setters call_log.exception = e.original_exception @@ -242,6 +234,7 @@ def __call__(self, call_log: Call, prompt_params: Optional[Dict] = None) -> Call raise e return call_log + @trace(name="/step", origin="Runner.step") @trace_step def step( self, @@ -335,6 +328,7 @@ def step( raise e return iteration + @trace(name="/input_validation", origin="Runner.validate_msg_history") def validate_msg_history( self, call_log: Call, msg_history: MessageHistory, attempt_number: int ) -> None: @@ -384,6 +378,7 @@ def prepare_msg_history( return formatted_msg_history + @trace(name="/input_validation", origin="Runner.validate_prompt") def validate_prompt(self, call_log: Call, prompt: Prompt, attempt_number: int): inputs = Inputs( llm_output=prompt.source, @@ -411,6 +406,7 @@ def validate_prompt(self, call_log: Call, prompt: Prompt, attempt_number: int): raise ValidationError("Prompt validation failed") return Prompt(cast(str, validated_prompt)) + @trace(name="/input_validation", origin="Runner.validate_instructions") def validate_instructions( self, call_log: Call, instructions: Instructions, attempt_number: int ): @@ -477,6 +473,7 @@ def prepare_prompt( return instructions, prompt + @trace(name="/input_prep", origin="Runner.prepare") def prepare( self, call_log: Call, @@ -527,6 +524,7 @@ def prepare( return instructions, prompt, msg_history + @trace(name="/llm_call", origin="Runner.call") @trace_call def call( self, @@ -572,6 +570,7 @@ def parse(self, output: str, output_schema: Dict[str, Any], **kwargs): parsed_output = coerce_types(parsed_output, output_schema) return parsed_output, error + @trace(name="/validation", origin="Runner.validate") def validate( self, iteration: Iteration, diff --git a/guardrails/run/stream_runner.py b/guardrails/run/stream_runner.py index 2e041abf4..2f1272c6b 100644 --- a/guardrails/run/stream_runner.py +++ b/guardrails/run/stream_runner.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union, cast +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast from guardrails import validator_service from guardrails.classes.history import Call, Inputs, Iteration, Outputs @@ -12,6 +12,7 @@ ) from guardrails.prompt import Instructions, Prompt from guardrails.run.runner import Runner +from guardrails.hub_telemetry.hub_tracing import trace_stream from guardrails.utils.parsing_utils import ( coerce_types, parse_llm_output, @@ -30,9 +31,10 @@ class StreamRunner(Runner): similar. """ + @trace_stream(name="/reasks", origin="StreamRunner.__call__") def __call__( self, call_log: Call, prompt_params: Optional[Dict] = {} - ) -> Generator[ValidationOutcome[OT], None, None]: + ) -> Iterator[ValidationOutcome[OT]]: """Execute the StreamRunner. 
Args: @@ -74,6 +76,7 @@ def __call__( call_log=call_log, ) + @trace_stream(name="/step", origin="StreamRunner.step") @trace_stream_step def step( self, @@ -86,7 +89,7 @@ def step( output_schema: Dict[str, Any], call_log: Call, output: Optional[str] = None, - ) -> Generator[ValidationOutcome[OT], None, None]: + ) -> Iterator[ValidationOutcome[OT]]: """Run a full step.""" inputs = Inputs( llm_api=api, @@ -148,7 +151,7 @@ def step( # for now, handle string and json schema differently if self.output_type == OutputTypes.STRING: - def prepare_chunk_generator(stream) -> Iterable[Tuple[Any, bool]]: + def prepare_chunk_generator(stream) -> Iterator[Tuple[Any, bool]]: for chunk in stream: chunk_text = self.get_chunk_text(chunk, api) nonlocal fragment @@ -292,11 +295,7 @@ def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> st return chunk_text def parse( - self, - output: str, - output_schema: Dict[str, Any], - *, - verified: set, + self, output: str, output_schema: Dict[str, Any], *, verified: set, **kwargs ): """Parse the output.""" parsed_output, error = parse_llm_output( diff --git a/guardrails/settings.py b/guardrails/settings.py index ad21d724c..e3c2bcfab 100644 --- a/guardrails/settings.py +++ b/guardrails/settings.py @@ -1,10 +1,13 @@ import threading from typing import Optional +from guardrails.classes.rc import RC + class Settings: _instance = None _lock = threading.Lock() + _rc: RC """Whether to use a local server for running Guardrails.""" use_server: Optional[bool] """Whether to disable tracing. @@ -25,6 +28,17 @@ def __new__(cls) -> "Settings": def _initialize(self): self.use_server = None self.disable_tracing = None + self._rc = RC.load() + + @property + def rc(self) -> RC: + if self._rc is None: + self._rc = RC.load() + return self._rc + + @rc.setter + def rc(self, value: RC): + self._rc = value settings = Settings() diff --git a/guardrails/telemetry/guard_tracing.py b/guardrails/telemetry/guard_tracing.py index 57e5a4614..31e074b1b 100644 --- a/guardrails/telemetry/guard_tracing.py +++ b/guardrails/telemetry/guard_tracing.py @@ -1,11 +1,11 @@ import inspect from typing import ( Any, - AsyncIterable, + AsyncIterator, Awaitable, Callable, Coroutine, - Iterable, + Iterator, Optional, Union, ) @@ -132,9 +132,9 @@ def add_guard_attributes( def trace_stream_guard( guard_span: Span, - result: Iterable[ValidationOutcome[OT]], + result: Iterator[ValidationOutcome[OT]], history: Stack[Call], -) -> Iterable[ValidationOutcome[OT]]: +) -> Iterator[ValidationOutcome[OT]]: next_exists = True while next_exists: try: @@ -152,12 +152,12 @@ def trace_guard_execution( guard_name: str, history: Stack[Call], _execute_fn: Callable[ - ..., Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]] + ..., Union[ValidationOutcome[OT], Iterator[ValidationOutcome[OT]]] ], tracer: Optional[Tracer] = None, *args, **kwargs, -) -> Union[ValidationOutcome[OT], Iterable[ValidationOutcome[OT]]]: +) -> Union[ValidationOutcome[OT], Iterator[ValidationOutcome[OT]]]: if not settings.disable_tracing: current_otel_context = context.get_current() tracer = tracer or trace.get_tracer("guardrails-ai", GUARDRAILS_VERSION) @@ -172,7 +172,7 @@ def trace_guard_execution( try: result = _execute_fn(*args, **kwargs) - if isinstance(result, Iterable) and not isinstance( + if isinstance(result, Iterator) and not isinstance( result, ValidationOutcome ): return trace_stream_guard(guard_span, result, history) @@ -188,9 +188,9 @@ def trace_guard_execution( async def trace_async_stream_guard( guard_span: Span, - 
result: AsyncIterable[ValidationOutcome[OT]], + result: AsyncIterator[ValidationOutcome[OT]], history: Stack[Call], -) -> AsyncIterable[ValidationOutcome[OT]]: +) -> AsyncIterator[ValidationOutcome[OT]]: next_exists = True while next_exists: try: @@ -215,7 +215,7 @@ async def trace_async_guard_execution( Union[ ValidationOutcome[OT], Awaitable[ValidationOutcome[OT]], - AsyncIterable[ValidationOutcome[OT]], + AsyncIterator[ValidationOutcome[OT]], ], ], ], @@ -225,7 +225,7 @@ async def trace_async_guard_execution( ) -> Union[ ValidationOutcome[OT], Awaitable[ValidationOutcome[OT]], - AsyncIterable[ValidationOutcome[OT]], + AsyncIterator[ValidationOutcome[OT]], ]: if not settings.disable_tracing: current_otel_context = context.get_current() @@ -241,7 +241,7 @@ async def trace_async_guard_execution( try: result = await _execute_fn(*args, **kwargs) - if isinstance(result, AsyncIterable): + if isinstance(result, AsyncIterator): return trace_async_stream_guard(guard_span, result, history) res = result diff --git a/guardrails/telemetry/runner_tracing.py b/guardrails/telemetry/runner_tracing.py index 9226f9fb5..5a6ef6f85 100644 --- a/guardrails/telemetry/runner_tracing.py +++ b/guardrails/telemetry/runner_tracing.py @@ -1,10 +1,10 @@ import json from functools import wraps from typing import ( - AsyncIterable, + AsyncIterator, Awaitable, Callable, - Generator, + Iterator, Optional, ) @@ -86,8 +86,8 @@ def trace_step_wrapper(*args, **kwargs) -> Iteration: def trace_stream_step_generator( - fn: Callable[..., Generator[ValidationOutcome[OT], None, None]], *args, **kwargs -) -> Generator[ValidationOutcome[OT], None, None]: + fn: Callable[..., Iterator[ValidationOutcome[OT]]], *args, **kwargs +) -> Iterator[ValidationOutcome[OT]]: current_otel_context = context.get_current() tracer = get_tracer() tracer = tracer or trace.get_tracer("guardrails-ai", GUARDRAILS_VERSION) @@ -119,12 +119,10 @@ def trace_stream_step_generator( def trace_stream_step( - fn: Callable[..., Generator[ValidationOutcome[OT], None, None]], -) -> Callable[..., Generator[ValidationOutcome[OT], None, None]]: + fn: Callable[..., Iterator[ValidationOutcome[OT]]], +) -> Callable[..., Iterator[ValidationOutcome[OT]]]: @wraps(fn) - def trace_stream_step_wrapper( - *args, **kwargs - ) -> Generator[ValidationOutcome[OT], None, None]: + def trace_stream_step_wrapper(*args, **kwargs) -> Iterator[ValidationOutcome[OT]]: if not settings.disable_tracing: return trace_stream_step_generator(fn, *args, **kwargs) else: @@ -163,8 +161,8 @@ async def trace_async_step_wrapper(*args, **kwargs) -> Iteration: async def trace_async_stream_step_generator( - fn: Callable[..., AsyncIterable[ValidationOutcome[OT]]], *args, **kwargs -) -> AsyncIterable[ValidationOutcome[OT]]: + fn: Callable[..., AsyncIterator[ValidationOutcome[OT]]], *args, **kwargs +) -> AsyncIterator[ValidationOutcome[OT]]: current_otel_context = context.get_current() tracer = get_tracer() tracer = tracer or trace.get_tracer("guardrails-ai", GUARDRAILS_VERSION) @@ -197,12 +195,12 @@ async def trace_async_stream_step_generator( def trace_async_stream_step( - fn: Callable[..., AsyncIterable[ValidationOutcome[OT]]], + fn: Callable[..., AsyncIterator[ValidationOutcome[OT]]], ): @wraps(fn) async def trace_async_stream_step_wrapper( *args, **kwargs - ) -> AsyncIterable[ValidationOutcome[OT]]: + ) -> AsyncIterator[ValidationOutcome[OT]]: if not settings.disable_tracing: return trace_async_stream_step_generator(fn, *args, **kwargs) else: diff --git a/guardrails/types/primitives.py 
b/guardrails/types/primitives.py index 3f07678fc..28f3cab41 100644 --- a/guardrails/types/primitives.py +++ b/guardrails/types/primitives.py @@ -3,7 +3,11 @@ class PrimitiveTypes(str, Enum): - BOOLEAN = SimpleTypes.BOOLEAN - INTEGER = SimpleTypes.INTEGER - NUMBER = SimpleTypes.NUMBER - STRING = SimpleTypes.STRING + BOOLEAN = SimpleTypes.BOOLEAN.value + INTEGER = SimpleTypes.INTEGER.value + NUMBER = SimpleTypes.NUMBER.value + STRING = SimpleTypes.STRING.value + + @staticmethod + def is_primitive(value: str) -> bool: + return value in [member.value for member in PrimitiveTypes] diff --git a/guardrails/utils/casting_utils.py b/guardrails/utils/casting_utils.py index 7e4ee2d2f..79ee00171 100644 --- a/guardrails/utils/casting_utils.py +++ b/guardrails/utils/casting_utils.py @@ -1,4 +1,5 @@ from typing import Any, Optional +import warnings def to_int(v: Any) -> Optional[int]: @@ -23,3 +24,12 @@ def to_string(v: Any) -> Optional[str]: return str_value except Exception: return None + + +def to_bool(value: str) -> Optional[bool]: + if value.lower() == "true": + return True + if value.lower() == "false": + return False + warnings.warn(f"Could not cast {value} to bool. Returning None.") + return None diff --git a/guardrails/utils/hub_telemetry_utils.py b/guardrails/utils/hub_telemetry_utils.py index 877457c2b..d8e5d2dc9 100644 --- a/guardrails/utils/hub_telemetry_utils.py +++ b/guardrails/utils/hub_telemetry_utils.py @@ -1,6 +1,8 @@ # Imports import logging +from typing import Optional +from guardrails.settings import settings from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( # HTTP Exporter OTLPSpanExporter, ) @@ -8,6 +10,7 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from opentelemetry.trace.propagation import set_span_in_context class HubTelemetry: @@ -22,19 +25,23 @@ class HubTelemetry: _processor = None _tracer = None _prop = None - _carrier = {} + _enabled = False def __new__( cls, service_name: str = "guardrails-hub", tracer_name: str = "gr_hub", export_locally: bool = False, + *, + enabled: Optional[bool] = None, ): if cls._instance is None: logging.debug("Creating HubTelemetry instance...") cls._instance = super(HubTelemetry, cls).__new__(cls) logging.debug("Initializing HubTelemetry instance...") - cls._instance.initialize_tracer(service_name, tracer_name, export_locally) + cls._instance.initialize_tracer( + service_name, tracer_name, export_locally, enabled=enabled + ) else: logging.debug("Returning existing HubTelemetry instance...") return cls._instance @@ -44,11 +51,17 @@ def initialize_tracer( self, service_name: str, tracer_name: str, export_locally: bool, + *, + enabled: Optional[bool] = None, ): """Initializes a tracer for Guardrails Hub.""" + if enabled is None: + enabled = settings.rc.enable_metrics or False + self._enabled = enabled + self._carrier = {} self._service_name = service_name - # self._endpoint = "http://localhost:4318/v1/traces" + # self._endpoint = "http://localhost:5318/v1/traces" self._endpoint = ( "https://hty0gc1ok3.execute-api.us-east-1.amazonaws.com/v1/traces" ) @@ -76,11 +89,14 @@ def initialize_tracer( self._prop = TraceContextTextMapPropagator() - def inject_current_context(self) -> None: + def inject_current_context(self, context=None) -> None: """Injects the current context into the carrier.""" if not
self._prop: return - self._prop.inject(carrier=self._carrier) + if context is not None: + self._prop.inject(carrier=self._carrier, context=context) + else: + self._prop.inject(carrier=self._carrier) def extract_current_context(self): """Extracts the current context from the carrier.""" @@ -89,13 +105,7 @@ def extract_current_context(self): context = self._prop.extract(carrier=self._carrier) return context - def create_new_span( - self, - span_name: str, - attributes: list, - is_parent: bool, # Inject current context if IS a parent span - has_parent: bool, # Extract current context if HAS a parent span - ): + def create_new_span(self, span_name: str, attributes: list): """Creates a new span within the tracer with the given name and attributes. @@ -112,13 +122,12 @@ def create_new_span( """ if self._tracer is None: return - with self._tracer.start_as_current_span( + with self._tracer.start_span( span_name, # type: ignore (Fails in Python 3.9 for invalid reason) - context=self.extract_current_context() if has_parent else None, + context=self.extract_current_context(), ) as span: - if is_parent: - # Inject the current context - self.inject_current_context() + context = set_span_in_context(span) + self.inject_current_context(context=context) for attribute in attributes: span.set_attribute(attribute[0], attribute[1]) diff --git a/guardrails/utils/openai_utils/v1.py b/guardrails/utils/openai_utils/v1.py index 93e24042c..45316a72c 100644 --- a/guardrails/utils/openai_utils/v1.py +++ b/guardrails/utils/openai_utils/v1.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncIterable, Callable, Dict, Iterable, List, Optional, cast +from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, cast import openai @@ -105,7 +105,7 @@ def construct_nonchat_response( if stream: # If stream is defined and set to True, # openai returns a generator - openai_response = cast(Iterable[Dict[str, Any]], openai_response) + openai_response = cast(Iterator[Dict[str, Any]], openai_response) # Simply return the generator wrapped in an LLMResponse return LLMResponse(output="", stream_output=openai_response) @@ -180,7 +180,7 @@ def construct_chat_response( if stream: # If stream is defined and set to True, # openai returns a generator object - openai_response = cast(Iterable[Dict[str, Any]], openai_response) + openai_response = cast(Iterator[Dict[str, Any]], openai_response) # Simply return the generator wrapped in an LLMResponse return LLMResponse(output="", stream_output=openai_response) @@ -266,7 +266,7 @@ async def construct_nonchat_response( # If stream is defined and set to True, # openai returns a generator object complete_output = "" - openai_response = cast(AsyncIterable[Dict[str, Any]], openai_response) + openai_response = cast(AsyncIterator[Dict[str, Any]], openai_response) async for response in openai_response: complete_output += response["choices"][0]["text"] @@ -330,7 +330,7 @@ async def construct_chat_response( # If stream is defined and set to True, # openai returns a generator object collected_messages = [] - openai_response = cast(AsyncIterable[Dict[str, Any]], openai_response) + openai_response = cast(AsyncIterator[Dict[str, Any]], openai_response) async for chunk in openai_response: chunk_message = chunk["choices"][0]["delta"] collected_messages.append(chunk_message) # save the message diff --git a/guardrails/validator_base.py b/guardrails/validator_base.py index f0b1689d0..d2c8e53db 100644 --- a/guardrails/validator_base.py +++ b/guardrails/validator_base.py @@ -11,6 +11,7 @@ from 
dataclasses import dataclass from string import Template from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from typing_extensions import deprecated from warnings import warn import warnings @@ -18,14 +19,15 @@ import requests from langchain_core.runnables import Runnable +from guardrails.settings import settings from guardrails.classes import ErrorSpan # noqa from guardrails.classes import PassResult # noqa from guardrails.classes import FailResult, ValidationResult -from guardrails.classes.credentials import Credentials from guardrails.constants import hub from guardrails.hub_token.token import VALIDATOR_HUB_SERVICE, get_jwt_token from guardrails.logger import logger from guardrails.remote_inference import remote_inference +from guardrails.hub_telemetry.hub_tracing import trace from guardrails.types.on_fail import OnFailAction from guardrails.utils.safe_get import safe_get from guardrails.utils.hub_telemetry_utils import HubTelemetry @@ -83,22 +85,25 @@ def __init__( on_fail: Optional[Union[Callable[[Any, FailResult], Any], OnFailAction]] = None, **kwargs, ): - self.creds = Credentials.from_rc_file() - self._disable_telemetry = self.creds.enable_metrics is not True + self._disable_telemetry = settings.rc.enable_metrics is not True if not self._disable_telemetry: - self._hub_telemetry = HubTelemetry() + self._hub_telemetry = HubTelemetry(enabled=settings.rc.enable_metrics) self.use_local = kwargs.get("use_local", None) self.validation_endpoint = kwargs.get("validation_endpoint", None) - if not self.creds: + # NOTE: I think this is an evergreen check + # We should test w/o an rc file, + # and if this doesn't raise then we should remove this. + if not settings.rc: raise ValueError( - "No credentials found. Please run `guardrails configure` and try again." + "No .guardrailsrc file found." + " Please run `guardrails configure` and try again." ) - self.hub_jwt_token = get_jwt_token(self.creds) + self.hub_jwt_token = get_jwt_token(settings.rc) # If use_local is not set, we can fall back to the setting determined in CLI if self.use_local is None: - self.use_local = not remote_inference.get_use_remote_inference(self.creds) + self.use_local = not remote_inference.get_use_remote_inference(settings.rc) if not self.validation_endpoint: validator_id = self.rail_alias.split("/")[-1] @@ -138,6 +143,18 @@ def __init__( self.rail_alias in validators_registry ), f"Validator {self.__class__.__name__} is not registered. " + @property + @deprecated( + ( + "The `creds` attribute is deprecated and will be removed in version 0.6.x." + " Use `settings.rc` instead." + ) + ) + def creds(self): + from guardrails.classes.credentials import Credentials # type: ignore + + return Credentials.from_rc_file() # type: ignore + def _set_on_fail_method(self, on_fail: Callable[[Any, FailResult], Any]): """Set the on_fail method for the validator.""" on_fail_args = inspect.getfullargspec(on_fail) @@ -201,7 +218,6 @@ def validate(self, value: Any, metadata: Dict[str, Any]) -> ValidationResult: validation requirements, logic, or pre/post processing. 
""" validation_result = self._validate(value, metadata) - self._log_telemetry() return validation_result async def async_validate( @@ -217,6 +233,7 @@ async def async_validate( loop = asyncio.get_event_loop() return await loop.run_in_executor(None, self.validate, value, metadata) + @trace(name="/validator_inference", origin="Validator._inference") def _inference(self, model_input: Any) -> Any: """Calls either a local or remote inference engine for use in the validation call. @@ -463,29 +480,6 @@ def to_runnable(self) -> Runnable: return ValidatorRunnable(self) - def _log_telemetry(self) -> None: - """Logs telemetry after the validator is called.""" - - if not self._disable_telemetry: - # Get HubTelemetry singleton and create a new span to - # log the validator inference - used_guardrails_endpoint = ( - VALIDATOR_HUB_SERVICE in self.validation_endpoint and not self.use_local - ) - used_custom_endpoint = not self.use_local and not used_guardrails_endpoint - self._hub_telemetry.create_new_span( - span_name="/validator_inference", - attributes=[ - ("validator_name", self.rail_alias), - ("used_remote_inference", not self.use_local), - ("used_local_inference", self.use_local), - ("used_guardrails_endpoint", used_guardrails_endpoint), - ("used_custom_endpoint", used_custom_endpoint), - ], - is_parent=False, # This span will have no children - has_parent=True, # This span has a parent - ) - V = TypeVar("V", bound=Validator, covariant=True) validators_registry: Dict[str, Type[Validator]] = {} diff --git a/guardrails/validator_service/__init__.py b/guardrails/validator_service/__init__.py index 1ea4ef9c7..c5f54efe2 100644 --- a/guardrails/validator_service/__init__.py +++ b/guardrails/validator_service/__init__.py @@ -1,6 +1,6 @@ import asyncio import os -from typing import Any, Iterable, Optional, Tuple +from typing import Any, Iterator, Optional, Tuple import warnings from guardrails.actions.filter import apply_filters @@ -12,6 +12,9 @@ ) from guardrails.types import ValidatorMap from guardrails.telemetry.legacy_validator_tracing import trace_validation_result + +# Keep this imported for backwards compatibility +from guardrails.validator_service.validator_service_base import ValidatorServiceBase # noqa from guardrails.validator_service.async_validator_service import AsyncValidatorService from guardrails.validator_service.sequential_validator_service import ( SequentialValidatorService, @@ -98,14 +101,14 @@ def validate( def validate_stream( - value_stream: Iterable[Tuple[Any, bool]], + value_stream: Iterator[Tuple[Any, bool]], metadata: dict, validator_map: ValidatorMap, iteration: Iteration, disable_tracer: Optional[bool] = True, path: Optional[str] = None, **kwargs, -) -> Iterable[StreamValidationResult]: +) -> Iterator[StreamValidationResult]: if path is None: path = "$" sequential_validator_service = SequentialValidatorService(disable_tracer) diff --git a/guardrails/validator_service/async_validator_service.py b/guardrails/validator_service/async_validator_service.py index 323e6c410..03ef3d5ba 100644 --- a/guardrails/validator_service/async_validator_service.py +++ b/guardrails/validator_service/async_validator_service.py @@ -9,6 +9,7 @@ PassResult, ValidationResult, ) +from guardrails.hub_telemetry.hub_tracing import async_trace from guardrails.telemetry.validator_tracing import trace_async_validator from guardrails.types import ValidatorMap, OnFailAction from guardrails.classes.validation.validator_logs import ValidatorLogs @@ -23,6 +24,9 @@ class AsyncValidatorService(ValidatorServiceBase): 
+ @async_trace( + name="/validator_usage", origin="AsyncValidatorService.execute_validator" + ) async def execute_validator( self, validator: Validator, diff --git a/guardrails/validator_service/sequential_validator_service.py b/guardrails/validator_service/sequential_validator_service.py index e86598277..a07615172 100644 --- a/guardrails/validator_service/sequential_validator_service.py +++ b/guardrails/validator_service/sequential_validator_service.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Dict, Iterable, List, Optional, Tuple, cast +from typing import Any, Dict, Iterator, List, Optional, Tuple, cast from guardrails.actions.filter import Filter from guardrails.actions.refrain import Refrain @@ -80,12 +80,12 @@ def run_validators_stream( self, iteration: Iteration, validator_map: ValidatorMap, - value_stream: Iterable[Tuple[Any, bool]], + value_stream: Iterator[Tuple[Any, bool]], metadata: Dict[str, Any], absolute_property_path: str, reference_property_path: str, **kwargs, - ) -> Iterable[StreamValidationResult]: + ) -> Iterator[StreamValidationResult]: validators = validator_map.get(reference_property_path, []) for validator in validators: if validator.on_fail_descriptor == OnFailAction.FIX: @@ -122,12 +122,12 @@ def run_validators_stream_fix( self, iteration: Iteration, validator_map: ValidatorMap, - value_stream: Iterable[Tuple[Any, bool]], + value_stream: Iterator[Tuple[Any, bool]], metadata: Dict[str, Any], absolute_property_path: str, reference_property_path: str, **kwargs, - ) -> Iterable[StreamValidationResult]: + ) -> Iterator[StreamValidationResult]: validators = validator_map.get(reference_property_path, []) acc_output = "" validator_partial_acc: dict[int, str] = {} @@ -271,12 +271,12 @@ def run_validators_stream_noop( self, iteration: Iteration, validator_map: ValidatorMap, - value_stream: Iterable[Tuple[Any, bool]], + value_stream: Iterator[Tuple[Any, bool]], metadata: Dict[str, Any], absolute_property_path: str, reference_property_path: str, **kwargs, - ) -> Iterable[StreamValidationResult]: + ) -> Iterator[StreamValidationResult]: validators = validator_map.get(reference_property_path, []) # Validate the field # TODO: Under what conditions do we yield? @@ -481,14 +481,14 @@ def validate( def validate_stream( self, - value_stream: Iterable[Tuple[Any, bool]], + value_stream: Iterator[Tuple[Any, bool]], metadata: dict, validator_map: ValidatorMap, iteration: Iteration, absolute_path: str, reference_path: str, **kwargs, - ) -> Iterable[StreamValidationResult]: + ) -> Iterator[StreamValidationResult]: # I assume validate stream doesn't need validate_dependents # because right now we're only handling StringSchema diff --git a/guardrails/validator_service/validator_service_base.py b/guardrails/validator_service/validator_service_base.py index d580a9ff0..0ea2e120b 100644 --- a/guardrails/validator_service/validator_service_base.py +++ b/guardrails/validator_service/validator_service_base.py @@ -12,8 +12,8 @@ ) from guardrails.errors import ValidationError from guardrails.merge import merge +from guardrails.hub_telemetry.hub_tracing import trace from guardrails.types import OnFailAction -from guardrails.utils.hub_telemetry_utils import HubTelemetry from guardrails.classes.validation.validator_logs import ValidatorLogs from guardrails.actions.reask import FieldReAsk from guardrails.telemetry import trace_validator @@ -43,6 +43,7 @@ def __init__(self, disable_tracer: Optional[bool] = True): # This is a well known issue without any real solutions. 
# Using `fork` instead of `spawn` may alleviate the symptom for POSIX systems, # but is relatively unsupported on Windows. + @trace(name="/validator_usage", origin="ValidatorServiceBase.execute_validator") def execute_validator( self, validator: Validator, @@ -152,26 +153,6 @@ def after_run_validator( validator_logs.validation_result = result validator_logs.end_time = end_time - if not self._disable_tracer: - # Get HubTelemetry singleton and create a new span to - # log the validator usage - _hub_telemetry = HubTelemetry() - _hub_telemetry.create_new_span( - span_name="/validator_usage", - attributes=[ - ("validator_name", validator.rail_alias), - ("validator_on_fail", validator.on_fail_descriptor), - ( - "validator_result", - result.outcome - if isinstance(result, ValidationResult) - else None, - ), - ], - is_parent=False, # This span will have no children - has_parent=True, # This span has a parent - ) - return validator_logs def run_validator( diff --git a/tests/conftest.py b/tests/conftest.py index 5db54f65b..c822037e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,20 +37,17 @@ def mock_validator_base_hub_telemetry(): @pytest.fixture(autouse=True) -def mock_validator_service_hub_telemetry(): - with patch( - "guardrails.validator_service.validator_service_base.HubTelemetry" - ) as MockHubTelemetry: +def mock_runner_hub_telemetry(): + with patch("guardrails.run.runner.HubTelemetry") as MockHubTelemetry: MockHubTelemetry.return_value = MagicMock() MockHubTelemetry.return_value.to_dict = None yield MockHubTelemetry @pytest.fixture(autouse=True) -def mock_runner_hub_telemetry(): - with patch("guardrails.run.runner.HubTelemetry") as MockHubTelemetry: +def mock_hub_tracing(): + with patch("guardrails.hub_telemetry.hub_tracing.HubTelemetry") as MockHubTelemetry: MockHubTelemetry.return_value = MagicMock() - MockHubTelemetry.return_value.to_dict = None yield MockHubTelemetry @@ -59,5 +56,5 @@ def pytest_collection_modifyitems(items): if "no_hub_telemetry_mock" in item.keywords: item.fixturenames.remove("mock_guard_hub_telemetry") item.fixturenames.remove("mock_validator_base_hub_telemetry") - item.fixturenames.remove("mock_validator_service_hub_telemetry") item.fixturenames.remove("mock_runner_hub_telemetry") + item.fixturenames.remove("mock_hub_tracing") diff --git a/tests/integration_tests/test_telemetry.py b/tests/integration_tests/test_telemetry.py index 9685aa7bb..4bb70bebb 100644 --- a/tests/integration_tests/test_telemetry.py +++ b/tests/integration_tests/test_telemetry.py @@ -95,10 +95,18 @@ def test_hub_traces_go_to_hub_telem_sink(self, mocker): hub_spans = hub_exporter.get_finished_spans() assert len(private_spans) == 0 - assert len(hub_spans) == 3 + assert len(hub_spans) == 6 - for span in hub_spans: - assert span.name in ["/guard_call", "/validator_usage", "/reasks"] + span_names = sorted([span.name for span in hub_spans]) + + assert span_names == [ + "/guard_call", + "/llm_call", + "/reasks", + "/step", + "/validation", + "/validator_usage", + ] @pytest.mark.no_hub_telemetry_mock def test_no_cross_contamination(self, mocker): @@ -143,6 +151,15 @@ def test_no_cross_contamination(self, mocker): "guardrails/guard/step/validator", ] - assert len(hub_spans) == 3 - for span in hub_spans: - assert span.name in ["/guard_call", "/validator_usage", "/reasks"] + assert len(hub_spans) == 6 + + span_names = sorted([span.name for span in hub_spans]) + + assert span_names == [ + "/guard_call", + "/llm_call", + "/reasks", + "/step", + "/validation", + "/validator_usage", + ] diff 
--git a/tests/unit_tests/classes/test_credentials.py b/tests/unit_tests/classes/test_credentials.py deleted file mode 100644 index 4e54030a5..000000000 --- a/tests/unit_tests/classes/test_credentials.py +++ /dev/null @@ -1,80 +0,0 @@ -import pytest -from tests.unit_tests.mocks.mock_file import MockFile - - -def test_from_rc_file(mocker): - # TODO: Re-enable this once we move nltk.download calls to individual validator repos. # noqa - # Right now, it fires during our import chain, causing this to blow up - mocker.patch("nltk.data.find") - mocker.patch("nltk.download") - - expanduser_mock = mocker.patch("guardrails.classes.credentials.expanduser") - expanduser_mock.return_value = "/Home" - - import os - - join_spy = mocker.spy(os.path, "join") - - mock_file = MockFile() - mock_open = mocker.patch("guardrails.classes.credentials.open") - mock_open.return_value = mock_file - - readlines_spy = mocker.patch.object(mock_file, "readlines") - readlines_spy.return_value = ["key1=val1", "key2=val2"] - close_spy = mocker.spy(mock_file, "close") - - from guardrails.classes.credentials import Credentials - - mock_from_dict = mocker.patch.object(Credentials, "from_dict") - - Credentials.from_rc_file() - - assert expanduser_mock.called is True - join_spy.assert_called_once_with("/Home", ".guardrailsrc") - - assert mock_open.call_count == 1 - assert readlines_spy.call_count == 1 - assert close_spy.call_count == 1 - # This is supposed to look wrong; since this method is on the super, - # it doesn't care if the key values are actually correct. - # Something to watch out for. - mock_from_dict.assert_called_once_with({"key1": "val1", "key2": "val2"}) - - -@pytest.mark.parametrize("no_metrics", [True, False, None]) -def test_from_rc_file_backfill_no_metrics_true(mocker, no_metrics): - expanduser_mock = mocker.patch("guardrails.classes.credentials.expanduser") - expanduser_mock.return_value = "/Home" - - import os - - join_spy = mocker.spy(os.path, "join") - - mock_file = MockFile() - mock_open = mocker.patch("guardrails.classes.credentials.open") - mock_open.return_value = mock_file - readlines_spy = mocker.patch.object(mock_file, "readlines") - readlines_spy.return_value = [f"no_metrics={no_metrics}"] - close_spy = mocker.spy(mock_file, "close") - - from guardrails.classes.credentials import Credentials - - mock_from_dict = mocker.patch.object(Credentials, "from_dict") - - Credentials.from_rc_file() - - assert expanduser_mock.called is True - join_spy.assert_called_once_with("/Home", ".guardrailsrc") - - assert mock_open.call_count == 1 - assert readlines_spy.call_count == 1 - assert close_spy.call_count == 1 - # This is supposed to look wrong; since this method is on the super, - # it doesn't care if the key values are actually correct. - # Something to watch out for. - - expected_dict = {"enable_metrics": not no_metrics} - if no_metrics is None: - expected_dict = {} - - mock_from_dict.assert_called_once_with(expected_dict) diff --git a/tests/unit_tests/classes/test_rc.py b/tests/unit_tests/classes/test_rc.py new file mode 100644 index 000000000..fb71471f0 --- /dev/null +++ b/tests/unit_tests/classes/test_rc.py @@ -0,0 +1,80 @@ +import pytest +from tests.unit_tests.mocks.mock_file import MockFile + + +class TestRC: + def test_load(self, mocker): + # TODO: Re-enable this once we move nltk.download calls to individual validator repos. 
# noqa + # Right now, it fires during our import chain, causing this to blow up + mocker.patch("nltk.data.find") + mocker.patch("nltk.download") + + expanduser_mock = mocker.patch("guardrails.classes.rc.expanduser") + expanduser_mock.return_value = "/Home" + + import os + + join_spy = mocker.spy(os.path, "join") + + mock_file = MockFile() + mock_open = mocker.patch("guardrails.classes.rc.open") + mock_open.return_value = mock_file + + readlines_spy = mocker.patch.object(mock_file, "readlines") + readlines_spy.return_value = ["key1=val1", "key2=val2"] + close_spy = mocker.spy(mock_file, "close") + + from guardrails.classes.rc import RC + + mock_from_dict = mocker.patch.object(RC, "from_dict") + + RC.load() + + assert expanduser_mock.called is True + join_spy.assert_called_once_with("/Home", ".guardrailsrc") + + assert mock_open.call_count == 1 + assert readlines_spy.call_count == 1 + assert close_spy.call_count == 1 + # This is supposed to look wrong; since this method is on the super, + # it doesn't care if the key values are actually correct. + # Something to watch out for. + mock_from_dict.assert_called_once_with({"key1": "val1", "key2": "val2"}) + + @pytest.mark.parametrize("no_metrics", [True, False, None]) + def test_load_backfill_no_metrics_true(self, mocker, no_metrics): + expanduser_mock = mocker.patch("guardrails.classes.rc.expanduser") + expanduser_mock.return_value = "/Home" + + import os + + join_spy = mocker.spy(os.path, "join") + + mock_file = MockFile() + mock_open = mocker.patch("guardrails.classes.rc.open") + mock_open.return_value = mock_file + readlines_spy = mocker.patch.object(mock_file, "readlines") + readlines_spy.return_value = [f"no_metrics={no_metrics}"] + close_spy = mocker.spy(mock_file, "close") + + from guardrails.classes.rc import RC + + mock_from_dict = mocker.patch.object(RC, "from_dict") + + RC.load() + + assert expanduser_mock.called is True + join_spy.assert_called_once_with("/Home", ".guardrailsrc") + + assert mock_open.call_count == 1 + assert readlines_spy.call_count == 1 + assert close_spy.call_count == 1 + # This is supposed to look wrong; since this method is on the super, + # it doesn't care if the key values are actually correct. + # Something to watch out for. 
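
These two tests pin down the observable contract of `RC.load`, whose implementation (`guardrails/classes/rc.py`) is outside this diff: parse `~/.guardrailsrc` key=value lines into a dict, hand it to `RC.from_dict`, and backfill the legacy `no_metrics` flag into `enable_metrics`, as the assertions just below verify. A sketch consistent with those assertions, purely illustrative rather than the real code:

    import os
    from os.path import expanduser


    def load_rc_kwargs() -> dict:
        """Parse ~/.guardrailsrc into the dict handed to RC.from_dict (sketch)."""
        path = os.path.join(expanduser("~"), ".guardrailsrc")
        rc_file = open(path, encoding="utf-8")
        kwargs = {}
        for line in rc_file.readlines():
            if not line.strip():
                continue  # skip blank lines
            key, _, value = line.strip().partition("=")
            kwargs[key] = value
        rc_file.close()
        # Legacy backfill: no_metrics=True -> enable_metrics=False (and vice
        # versa); any other value, e.g. "None", is dropped entirely.
        no_metrics = kwargs.pop("no_metrics", None)
        if no_metrics == "True":
            kwargs["enable_metrics"] = False
        elif no_metrics == "False":
            kwargs["enable_metrics"] = True
        return kwargs
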
+ + expected_dict = {"enable_metrics": not no_metrics} + if no_metrics is None: + expected_dict = {} + + mock_from_dict.assert_called_once_with(expected_dict) diff --git a/tests/unit_tests/cli/server/test_hub_client.py b/tests/unit_tests/cli/server/test_hub_client.py index 7af761fc5..998a76408 100644 --- a/tests/unit_tests/cli/server/test_hub_client.py +++ b/tests/unit_tests/cli/server/test_hub_client.py @@ -5,7 +5,7 @@ from datetime import timezone -from guardrails.classes.credentials import Credentials +from guardrails.classes.rc import RC from guardrails.cli.server.hub_client import ( TOKEN_EXPIRED_MESSAGE, TOKEN_INVALID_MESSAGE, @@ -51,22 +51,22 @@ def test_get_jwt_token(): timedelta = datetime.timedelta(seconds=1000) expiration = datetime.datetime.now(tz=timezone.utc) + timedelta valid_jwt = jwt.encode({"exp": expiration}, secret_key, algorithm="HS256") - creds = Credentials.from_dict({"token": valid_jwt}) + rc = RC.from_dict({"token": valid_jwt}) # Test valid token - assert get_jwt_token(creds) == valid_jwt + assert get_jwt_token(rc) == valid_jwt # Test with an expired JWT with pytest.raises(ExpiredTokenError) as e: expired = datetime.datetime.now(tz=timezone.utc) - timedelta expired_jwt = jwt.encode({"exp": expired}, secret_key, algorithm="HS256") - get_jwt_token(Credentials.from_dict({"token": expired_jwt})) + get_jwt_token(RC.from_dict({"token": expired_jwt})) assert str(e.value) == TOKEN_EXPIRED_MESSAGE # Test with an invalid token format with pytest.raises(InvalidTokenError) as e: invalid_jwt = "invalid" - get_jwt_token(Credentials.from_dict({"token": invalid_jwt})) + get_jwt_token(RC.from_dict({"token": invalid_jwt})) assert str(e.value) == TOKEN_INVALID_MESSAGE diff --git a/tests/unit_tests/hub/test_hub_install.py b/tests/unit_tests/hub/test_hub_install.py index 4b35feec9..3aea52899 100644 --- a/tests/unit_tests/hub/test_hub_install.py +++ b/tests/unit_tests/hub/test_hub_install.py @@ -1,7 +1,7 @@ import pytest from unittest.mock import ANY, call, MagicMock -from guardrails.classes.credentials import Credentials +from guardrails.classes.rc import RC from guardrails_hub_types import Manifest from guardrails.hub.validator_package_service import ( InvalidHubInstallURL, @@ -35,7 +35,7 @@ def setup_method(self): def test_exits_early_if_uri_is_not_valid(self, mocker, use_remote_inferencing): mocker.patch( - "guardrails.hub.install.Credentials.has_rc_file", + "guardrails.hub.install.RC.exists", return_value=True, ) with pytest.raises(InvalidHubInstallURL): @@ -43,12 +43,12 @@ def test_exits_early_if_uri_is_not_valid(self, mocker, use_remote_inferencing): def test_install_local_models__false(self, mocker, use_remote_inferencing): mocker.patch( - "guardrails.hub.install.Credentials.has_rc_file", + "guardrails.hub.install.RC.exists", return_value=True, ) mocker.patch( - "guardrails.hub.install.Credentials.from_rc_file", - return_value=Credentials.from_dict( + "guardrails.hub.install.RC.load", + return_value=RC.from_dict( {"use_remote_inferencing": use_remote_inferencing} ), ) @@ -105,12 +105,12 @@ def test_install_local_models__false(self, mocker, use_remote_inferencing): def test_install_local_models__true(self, mocker, use_remote_inferencing): mocker.patch( - "guardrails.hub.install.Credentials.has_rc_file", + "guardrails.hub.install.RC.exists", return_value=True, ) mocker.patch( - "guardrails.hub.install.Credentials.from_rc_file", - return_value=Credentials.from_dict( + "guardrails.hub.install.RC.load", + return_value=RC.from_dict( {"use_remote_inferencing": use_remote_inferencing} 
), ) @@ -166,12 +166,12 @@ def test_install_local_models__true(self, mocker, use_remote_inferencing): def test_install_local_models__none(self, mocker, use_remote_inferencing): mocker.patch( - "guardrails.hub.install.Credentials.has_rc_file", + "guardrails.hub.install.RC.exists", return_value=True, ) mocker.patch( - "guardrails.hub.install.Credentials.from_rc_file", - return_value=Credentials.from_dict( + "guardrails.hub.install.RC.load", + return_value=RC.from_dict( {"use_remote_inferencing": use_remote_inferencing} ), ) @@ -227,12 +227,12 @@ def test_install_local_models__none(self, mocker, use_remote_inferencing): def test_happy_path(self, mocker, use_remote_inferencing): mocker.patch( - "guardrails.hub.install.Credentials.has_rc_file", + "guardrails.hub.install.RC.exists", return_value=True, ) mocker.patch( - "guardrails.hub.install.Credentials.from_rc_file", - return_value=Credentials.from_dict( + "guardrails.hub.install.RC.load", + return_value=RC.from_dict( {"use_remote_inferencing": use_remote_inferencing} ), ) @@ -284,7 +284,7 @@ def test_happy_path(self, mocker, use_remote_inferencing): def test_install_local_models_confirmation(self, mocker, use_remote_inferencing): mocker.patch( - "guardrails.hub.install.Credentials.has_rc_file", + "guardrails.hub.install.RC.exists", return_value=False, ) mocker.patch("guardrails.hub.install.cli_logger.log") @@ -335,7 +335,7 @@ def test_install_local_models_confirmation_raises_exception( self, mocker, use_remote_inferencing ): mocker.patch( - "guardrails.hub.install.Credentials.has_rc_file", + "guardrails.hub.install.RC.exists", return_value=False, ) mocker.patch("guardrails.hub.install.cli_logger.log") @@ -381,12 +381,12 @@ def test_install_local_models_confirmation_raises_exception( def test_use_remote_endpoint(self, mocker, use_remote_inferencing: bool): mocker.patch( - "guardrails.hub.install.Credentials.has_rc_file", + "guardrails.hub.install.RC.exists", return_value=True, ) mocker.patch( - "guardrails.hub.install.Credentials.from_rc_file", - return_value=Credentials.from_dict( + "guardrails.hub.install.RC.load", + return_value=RC.from_dict( {"use_remote_inferencing": use_remote_inferencing} ), ) diff --git a/tests/unit_tests/integrations/databricks/test_ml_flow_instrumentor.py b/tests/unit_tests/integrations/databricks/test_ml_flow_instrumentor.py index 0e0593383..1419979f8 100644 --- a/tests/unit_tests/integrations/databricks/test_ml_flow_instrumentor.py +++ b/tests/unit_tests/integrations/databricks/test_ml_flow_instrumentor.py @@ -194,9 +194,9 @@ def test__instrument_guard_stream(self, mocker): m = MlFlowInstrumentor("mock experiment") - mock_result = [ - ValidationOutcome(call_id="mock call id", validation_passed=True) - ] + mock_result = iter( + [ValidationOutcome(call_id="mock call id", validation_passed=True)] + ) mock_execute = MagicMock() mock_execute.return_value = mock_result mock_guard = MagicMock(spec=Guard)
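
A closing note on this last change: wrapping the mocked result in `iter(...)` is needed presumably because the instrumented stream path, like `trace_guard_execution` earlier in this diff, now detects streams with `isinstance(result, Iterator)`, and a bare list no longer qualifies.

    from collections.abc import Iterator
    from unittest.mock import MagicMock

    mock_execute = MagicMock()
    mock_execute.return_value = iter(["chunk-1", "chunk-2"])  # a real iterator
    # A bare list would fail the stricter isinstance(result, Iterator) check
    # and be treated as a non-streaming result.
    assert isinstance(mock_execute.return_value, Iterator)
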