4 changes: 2 additions & 2 deletions guardrails/api_client.py
@@ -1,6 +1,6 @@
import json
import os
from typing import Any, Iterable, Optional
from typing import Any, Iterator, Optional

import requests
from guardrails_api_client.configuration import Configuration
@@ -80,7 +80,7 @@ def stream_validate(
guard: Guard,
payload: ValidatePayload,
openai_api_key: Optional[str] = None,
) -> Iterable[Any]:
) -> Iterator[Any]:
_openai_api_key = (
openai_api_key
if openai_api_key is not None
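A note on the typing fix in this file: a generator function returns a generator object, which is an Iterator, not merely an Iterable (every Iterator is an Iterable, but only an Iterator supports next()), so the narrower return annotation is the accurate one for stream_validate. A minimal standalone sketch, independent of the Guardrails API:

from typing import Iterator

def stream_chunks(text: str, size: int = 4) -> Iterator[str]:
    # A generator function's return value is an Iterator: it supports
    # both iter() and next().
    for start in range(0, len(text), size):
        yield text[start:start + size]

chunks = stream_chunks("hello world")
print(next(chunks))  # "hell" -- next() is legal because this is an Iterator
print(list(chunks))  # ["o wo", "rld"] -- the remaining chunks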
52 changes: 10 additions & 42 deletions guardrails/async_guard.py
@@ -4,7 +4,7 @@
from opentelemetry import context as otel_context
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Dict,
@@ -37,6 +37,7 @@
set_tracer,
set_tracer_context,
)
from guardrails.hub_telemetry.hub_tracing import async_trace
from guardrails.types.pydantic import ModelOrListOfModels
from guardrails.types.validator import UseManyValidatorSpec, UseValidatorSpec
from guardrails.telemetry import trace_async_guard_execution, wrap_with_otel_context
@@ -187,7 +188,7 @@ async def _execute(
) -> Union[
ValidationOutcome[OT],
Awaitable[ValidationOutcome[OT]],
AsyncIterable[ValidationOutcome[OT]],
AsyncIterator[ValidationOutcome[OT]],
]:
self._fill_validator_map()
self._fill_validators()
@@ -219,49 +220,13 @@ async def __exec(
) -> Union[
ValidationOutcome[OT],
Awaitable[ValidationOutcome[OT]],
AsyncIterable[ValidationOutcome[OT]],
AsyncIterator[ValidationOutcome[OT]],
]:
prompt_params = prompt_params or {}
metadata = metadata or {}
if full_schema_reask is None:
full_schema_reask = self._base_model is not None

if self._allow_metrics_collection:
llm_api_str = ""
if llm_api:
llm_api_module_name = (
llm_api.__module__ if hasattr(llm_api, "__module__") else ""
)
llm_api_name = (
llm_api.__name__
if hasattr(llm_api, "__name__")
else type(llm_api).__name__
)
llm_api_str = f"{llm_api_module_name}.{llm_api_name}"
# Create a new span for this guard call
self._hub_telemetry.create_new_span(
span_name="/guard_call",
attributes=[
("guard_id", self.id),
("user_id", self._user_id),
("llm_api", llm_api_str),
(
"custom_reask_prompt",
self._exec_opts.reask_prompt is not None,
),
(
"custom_reask_instructions",
self._exec_opts.reask_instructions is not None,
),
(
"custom_reask_messages",
self._exec_opts.reask_messages is not None,
),
],
is_parent=True, # It will have children
has_parent=False, # Has no parents
)

set_call_kwargs(kwargs)
set_tracer(self._tracer)
set_tracer_context(self._tracer_context)
@@ -369,7 +334,7 @@ async def _exec(
) -> Union[
ValidationOutcome[OT],
Awaitable[ValidationOutcome[OT]],
AsyncIterable[ValidationOutcome[OT]],
AsyncIterator[ValidationOutcome[OT]],
]:
"""Call the LLM asynchronously and validate the output.

@@ -435,6 +400,7 @@ async def _exec(
)
return ValidationOutcome[OT].from_guard_history(call)

@async_trace(name="/guard_call", origin="AsyncGuard.__call__")
async def __call__(
self,
llm_api: Optional[Callable[..., Awaitable[Any]]] = None,
@@ -450,7 +416,7 @@ async def __call__(
) -> Union[
ValidationOutcome[OT],
Awaitable[ValidationOutcome[OT]],
AsyncIterable[ValidationOutcome[OT]],
AsyncIterator[ValidationOutcome[OT]],
]:
"""Call the LLM and validate the output. Pass an async LLM API to
return a coroutine.
@@ -501,6 +467,7 @@ async def __call__(
**kwargs,
)

@async_trace(name="/guard_call", origin="AsyncGuard.parse")
async def parse(
self,
llm_output: str,
@@ -567,7 +534,7 @@ async def parse(

async def _stream_server_call(
self, *, payload: Dict[str, Any]
) -> AsyncIterable[ValidationOutcome[OT]]:
) -> AsyncIterator[ValidationOutcome[OT]]:
# TODO: Once server side supports async streaming, this function will need to
# yield async generators, not generators
if self._api_client:
@@ -609,6 +576,7 @@ async def _stream_server_call(
else:
raise ValueError("AsyncGuard does not have an api client!")

@async_trace(name="/guard_call", origin="AsyncGuard.validate")
async def validate(
self, llm_output: str, *args, **kwargs
) -> Awaitable[ValidationOutcome[OT]]:
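The main refactor in this file replaces the inline create_new_span bookkeeping inside __exec with the @async_trace decorator on __call__, parse, and validate. The sketch below shows the general shape of such a decorator; the span handling is an illustrative assumption, not the actual guardrails.hub_telemetry.hub_tracing implementation:

import asyncio
import functools
from typing import Any, Awaitable, Callable

def async_trace(name: str, origin: str):
    # Illustrative sketch: wrap an async callable in a named tracing span.
    def decorator(fn: Callable[..., Awaitable[Any]]):
        @functools.wraps(fn)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            # A real implementation would open a telemetry span here and
            # attach attributes (guard id, llm_api, reask options, ...).
            print(f"span open: {name} (origin={origin})")
            try:
                return await fn(*args, **kwargs)
            finally:
                print(f"span close: {name}")
        return wrapper
    return decorator

@async_trace(name="/guard_call", origin="demo")
async def call_llm() -> str:
    return "ok"

print(asyncio.run(call_llm()))  # prints span open/close around "ok"

Centralizing the span logic in a decorator also removes the per-call llm_api name introspection from the guard body, which is why the large telemetry block above could be deleted.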
6 changes: 4 additions & 2 deletions guardrails/classes/__init__.py
@@ -1,4 +1,5 @@
from guardrails.classes.credentials import Credentials
from guardrails.classes.credentials import Credentials # type: ignore
from guardrails.classes.rc import RC
from guardrails.classes.input_type import InputType
from guardrails.classes.output_type import OT
from guardrails.classes.validation.validation_result import (
@@ -10,7 +11,8 @@
from guardrails.classes.validation_outcome import ValidationOutcome

__all__ = [
"Credentials",
"Credentials", # type: ignore
"RC",
"ErrorSpan",
"InputType",
"OT",
76 changes: 22 additions & 54 deletions guardrails/classes/credentials.py
@@ -1,21 +1,24 @@
import logging
import os
from dataclasses import dataclass
from os.path import expanduser
from typing import Optional
from typing_extensions import deprecated

from guardrails.classes.generic.serializeable import Serializeable
from guardrails.classes.generic.serializeable import SerializeableJSONEncoder
from guardrails.classes.rc import RC

BOOL_CONFIGS = set(["no_metrics", "enable_metrics", "use_remote_inferencing"])


@deprecated(
(
"The `Credentials` class is deprecated and will be removed in version 0.6.x."
" Use the `RC` class instead."
),
category=DeprecationWarning,
)
@dataclass
class Credentials(Serializeable):
id: Optional[str] = None
token: Optional[str] = None
class Credentials(RC):
no_metrics: Optional[bool] = False
enable_metrics: Optional[bool] = True
use_remote_inferencing: Optional[bool] = True

@staticmethod
def _to_bool(value: str) -> Optional[bool]:
@@ -27,51 +30,16 @@ def _to_bool(value: str) -> Optional[bool]:

@staticmethod
def has_rc_file() -> bool:
home = expanduser("~")
guardrails_rc = os.path.join(home, ".guardrailsrc")
return os.path.exists(guardrails_rc)
return RC.exists()

@staticmethod
def from_rc_file(logger: Optional[logging.Logger] = None) -> "Credentials":
try:
if not logger:
logger = logging.getLogger()
home = expanduser("~")
guardrails_rc = os.path.join(home, ".guardrailsrc")
with open(guardrails_rc, encoding="utf-8") as rc_file:
lines = rc_file.readlines()
filtered_lines = list(filter(lambda l: l.strip(), lines))
creds = {}
for line in filtered_lines:
line_content = line.split("=", 1)
if len(line_content) != 2:
logger.warning(
"""
Invalid line found in .guardrailsrc file!
All lines in this file should follow the format: key=value
Ignoring line contents...
"""
)
logger.debug(f".guardrailsrc file location: {guardrails_rc}")
else:
key, value = line_content
key = key.strip()
value = value.strip()
if key in BOOL_CONFIGS:
value = Credentials._to_bool(value)

creds[key] = value

rc_file.close()

# backfill no_metrics, handle defaults
# remove in 0.5.0
no_metrics_val = creds.pop("no_metrics", None)
if no_metrics_val is not None and creds.get("enable_metrics") is None:
creds["enable_metrics"] = not no_metrics_val

creds_dict = Credentials.from_dict(creds)
return creds_dict

except FileNotFoundError:
return Credentials.from_dict({}) # type: ignore
def from_rc_file(logger: Optional[logging.Logger] = None) -> "Credentials": # type: ignore
rc = RC.load(logger)
return Credentials( # type: ignore
id=rc.id,
token=rc.token,
enable_metrics=rc.enable_metrics,
use_remote_inferencing=rc.use_remote_inferencing,
no_metrics=(not rc.enable_metrics),
encoder=SerializeableJSONEncoder(),
)
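The net effect in this file: Credentials becomes a thin, deprecated subclass of RC, and from_rc_file delegates to RC.load while backfilling the legacy no_metrics flag as the inverse of enable_metrics. Below is a sketch of the same shim pattern with typing_extensions.deprecated; Config and LegacyConfig are stand-in names, not the Guardrails classes:

import warnings
from dataclasses import dataclass
from typing import Optional
from typing_extensions import deprecated

@dataclass
class Config:  # stands in for RC
    token: Optional[str] = None
    enable_metrics: Optional[bool] = True

@deprecated("Use Config instead.", category=DeprecationWarning)
@dataclass
class LegacyConfig(Config):  # stands in for Credentials
    no_metrics: Optional[bool] = False

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cfg = LegacyConfig(token="abc", enable_metrics=False, no_metrics=True)

print(caught[0].category.__name__)  # DeprecationWarning, emitted at instantiation
print(cfg.no_metrics)               # True

Keeping the old class as a subclass means existing isinstance checks and attribute access keep working for one more release while the warning steers callers toward RC.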
12 changes: 6 additions & 6 deletions guardrails/classes/llm/llm_response.py
@@ -1,6 +1,6 @@
import asyncio
from itertools import tee
from typing import Any, Dict, Iterable, Optional, AsyncIterable
from typing import Any, Dict, Iterator, Optional, AsyncIterator

from guardrails_api_client import LLMResponse as ILLMResponse
from pydantic.config import ConfigDict
@@ -19,9 +19,9 @@ class LLMResponse(ILLMResponse):

Attributes:
output (str): The output from the LLM.
stream_output (Optional[Iterable]): A stream of output from the LLM.
stream_output (Optional[Iterator]): A stream of output from the LLM.
Default None.
async_stream_output (Optional[AsyncIterable]): An async stream of output
async_stream_output (Optional[AsyncIterator]): An async stream of output
from the LLM. Default None.
prompt_token_count (Optional[int]): The number of tokens in the prompt.
Default None.
@@ -35,8 +35,8 @@
prompt_token_count: Optional[int] = None
response_token_count: Optional[int] = None
output: str
stream_output: Optional[Iterable] = None
async_stream_output: Optional[AsyncIterable] = None
stream_output: Optional[Iterator] = None
async_stream_output: Optional[AsyncIterator] = None

def to_interface(self) -> ILLMResponse:
stream_output = None
@@ -73,7 +73,7 @@ def to_dict(self) -> Dict[str, Any]:
def from_interface(cls, i_llm_response: ILLMResponse) -> "LLMResponse":
stream_output = None
if i_llm_response.stream_output:
stream_output = [so for so in i_llm_response.stream_output]
stream_output = iter([so for so in i_llm_response.stream_output])

async_stream_output = None
if i_llm_response.async_stream_output:
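The iter(...) change in from_interface pairs with the annotation change: a materialized list is Iterable but not an Iterator, so wrapping it makes the value actually satisfy the declared stream_output: Optional[Iterator] field. For illustration:

from typing import Iterator

chunks = ["a", "b", "c"]
stream: Iterator[str] = iter(chunks)  # a bare list is not an iterator
print(next(stream))  # "a" -- next() now works
print(list(stream))  # ["b", "c"] -- iteration consumes what remains
# next(chunks) would raise TypeError: 'list' object is not an iterator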
71 changes: 71 additions & 0 deletions guardrails/classes/rc.py
@@ -0,0 +1,71 @@
import logging
import os
from dataclasses import dataclass
from os.path import expanduser
from typing import Optional

from guardrails.classes.generic.serializeable import Serializeable
from guardrails.utils.casting_utils import to_bool

BOOL_CONFIGS = set(["no_metrics", "enable_metrics", "use_remote_inferencing"])


@dataclass
class RC(Serializeable):
id: Optional[str] = None
token: Optional[str] = None
enable_metrics: Optional[bool] = True
use_remote_inferencing: Optional[bool] = True

@staticmethod
def exists() -> bool:
home = expanduser("~")
guardrails_rc = os.path.join(home, ".guardrailsrc")
return os.path.exists(guardrails_rc)

@classmethod
def load(cls, logger: Optional[logging.Logger] = None) -> "RC":
try:
if not logger:
logger = logging.getLogger()
home = expanduser("~")
guardrails_rc = os.path.join(home, ".guardrailsrc")
with open(guardrails_rc, encoding="utf-8") as rc_file:
lines = rc_file.readlines()
filtered_lines = list(filter(lambda l: l.strip(), lines))
config = {}
for line in filtered_lines:
line_content = line.split("=", 1)
if len(line_content) != 2:
logger.warning(
"""
Invalid line found in .guardrailsrc file!
All lines in this file should follow the format: key=value
Ignoring line contents...
"""
)
logger.debug(f".guardrailsrc file location: {guardrails_rc}")
else:
key, value = line_content
key = key.strip()
value = value.strip()
if key in BOOL_CONFIGS:
value = to_bool(value)

config[key] = value

rc_file.close()

# backfill no_metrics, handle defaults
# We missed this comment in the 0.5.0 release
# Making it a TODO for 0.6.0
# TODO: remove in 0.6.0
no_metrics_val = config.pop("no_metrics", None)
if no_metrics_val is not None and config.get("enable_metrics") is None:
config["enable_metrics"] = not no_metrics_val

rc = cls.from_dict(config)
return rc

except FileNotFoundError:
return cls.from_dict({}) # type: ignore
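RC.load reads ~/.guardrailsrc as key=value lines, coerces the keys in BOOL_CONFIGS to booleans, and backfills enable_metrics from a legacy no_metrics entry. A hedged standalone sketch of the same parsing logic; to_bool here is a stand-in with assumed behavior for guardrails.utils.casting_utils.to_bool:

from typing import Dict, Optional, Union

BOOL_KEYS = {"no_metrics", "enable_metrics", "use_remote_inferencing"}

def to_bool(value: str) -> Optional[bool]:
    # Stand-in: assumed to map "true"/"false" (any case) to bool, else None.
    return {"true": True, "false": False}.get(value.strip().lower())

def parse_rc(text: str) -> Dict[str, Union[str, bool, None]]:
    # Parse key=value lines; blank lines and lines without "=" are skipped,
    # mirroring the logged-and-ignored case above.
    config: Dict[str, Union[str, bool, None]] = {}
    for line in filter(str.strip, text.splitlines()):
        key, sep, value = line.partition("=")
        if not sep:
            continue
        key, value = key.strip(), value.strip()
        config[key] = to_bool(value) if key in BOOL_KEYS else value
    # Legacy backfill: no_metrics implies enable_metrics unless set explicitly.
    no_metrics = config.pop("no_metrics", None)
    if no_metrics is not None and config.get("enable_metrics") is None:
        config["enable_metrics"] = not no_metrics
    return config

print(parse_rc("id=abc\nno_metrics=true\n"))  # {'id': 'abc', 'enable_metrics': False}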
7 changes: 4 additions & 3 deletions guardrails/cli/configure.py
@@ -6,7 +6,7 @@

import typer

from guardrails.classes.credentials import Credentials
from guardrails.settings import settings
from guardrails.cli.guardrails import guardrails
from guardrails.cli.logger import LEVELS, logger
from guardrails.cli.hub.console import console
@@ -46,7 +46,7 @@ def save_configuration_file(

def _get_default_token() -> str:
"""Get the default token from the configuration file."""
file_token = Credentials.from_rc_file(logger).token
file_token = settings.rc.token
if file_token is None:
return ""
return file_token
@@ -79,7 +79,8 @@ def configure(
),
):
version_warnings_if_applicable(console)
trace_if_enabled("configure")
if settings.rc.exists():
trace_if_enabled("configure")
existing_token = _get_default_token()
last4 = existing_token[-4:] if existing_token else ""

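The configure change gates telemetry on the rc file actually existing, so a first-time guardrails configure run (no ~/.guardrailsrc yet) skips tracing instead of consulting settings that are not there. A small sketch of the guard pattern with stand-in implementations:

import os
from os.path import expanduser

def rc_exists() -> bool:
    # Mirrors RC.exists(): check for ~/.guardrailsrc.
    return os.path.exists(os.path.join(expanduser("~"), ".guardrailsrc"))

def trace_if_enabled(step: str) -> None:
    # Stand-in: a real implementation would consult the rc file's
    # metrics preference before emitting a trace.
    print(f"tracing {step}")

if rc_exists():  # first-time runs fall through silently
    trace_if_enabled("configure")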