Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions guardrails/classes/credentials.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
from dataclasses import dataclass
from os.path import expanduser
Expand All @@ -14,16 +15,30 @@ class Credentials(Serializeable):
no_metrics: Optional[bool] = False

@staticmethod
def from_rc_file() -> "Credentials":
def from_rc_file(logger: Optional[logging.Logger] = None) -> "Credentials":
try:
if not logger:
logger = logging.getLogger()
home = expanduser("~")
guardrails_rc = os.path.join(home, ".guardrailsrc")
with open(guardrails_rc) as rc_file:
lines = rc_file.readlines()
filtered_lines = list(filter(lambda l: l.strip(), lines))
creds = {}
for line in lines:
key, value = line.split("=", 1)
creds[key.strip()] = value.strip()
for line in filtered_lines:
line_content = line.split("=", 1)
if len(line_content) != 2:
logger.warn(
"""
Invalid line found in .guardrailsrc file!
All lines in this file should follow the format: key=value
Ignoring line contents...
"""
)
logger.debug(f".guardrailsrc file location: {guardrails_rc}")
else:
key, value = line_content
creds[key.strip()] = value.strip()
rc_file.close()
return Credentials.from_dict(creds)

Expand Down
2 changes: 1 addition & 1 deletion guardrails/cli/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def save_configuration_file(
f"id={str(uuid.uuid4())}{os.linesep}",
f"client_id={client_id}{os.linesep}",
f"client_secret={client_secret}{os.linesep}",
f"no_metrics={str(no_metrics).lower()}{os.linesep}",
f"no_metrics={str(no_metrics).lower()}",
]
rc_file.writelines(lines)
rc_file.close()
Expand Down
6 changes: 3 additions & 3 deletions guardrails/cli/server/hub_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def fetch_module_manifest(


def fetch_module(module_name: str) -> ModuleManifest:
creds = Credentials.from_rc_file()
creds = Credentials.from_rc_file(logger)
token = get_auth_token(creds)

module_manifest_json = fetch_module_manifest(module_name, token, creds.id)
Expand All @@ -89,7 +89,7 @@ def get_validator_manifest(module_name: str):
# GET /auth
def get_auth():
try:
creds = Credentials.from_rc_file()
creds = Credentials.from_rc_file(logger)
token = get_auth_token(creds)
auth_url = f"{validator_hub_service}/auth"
response = fetch(auth_url, token, creds.id)
Expand All @@ -105,7 +105,7 @@ def get_auth():

def post_validator_submit(package_name: str, content: str):
try:
creds = Credentials.from_rc_file()
creds = Credentials.from_rc_file(logger)
token = get_auth_token(creds)
submission_url = f"{validator_hub_service}/validator/submit"

Expand Down
11 changes: 9 additions & 2 deletions guardrails/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,13 @@ def __init__(
self.base_model = base_model
self._set_tracer(tracer)

credentials = Credentials.from_rc_file(logger)

# Get unique id of user from credentials
self._user_id = Credentials.from_rc_file().id or ""
self._user_id = credentials.id or ""

# Get metrics opt-out from credentials
self._disable_tracer = Credentials.from_rc_file().no_metrics
self._disable_tracer = credentials.no_metrics

# Get id of guard object (that is unique)
self._guard_id = id(self) # id of guard object; not the class
Expand Down Expand Up @@ -683,6 +685,7 @@ def _call_sync(
metadata=metadata,
base_model=self.base_model,
full_schema_reask=full_schema_reask,
disable_tracer=self._disable_tracer,
)
return runner(call_log=call_log, prompt_params=prompt_params)
else:
Expand All @@ -701,6 +704,7 @@ def _call_sync(
metadata=metadata,
base_model=self.base_model,
full_schema_reask=full_schema_reask,
disable_tracer=self._disable_tracer,
)
call = runner(call_log=call_log, prompt_params=prompt_params)
return ValidationOutcome[OT].from_guard_history(call)
Expand Down Expand Up @@ -760,6 +764,7 @@ async def _call_async(
metadata=metadata,
base_model=self.base_model,
full_schema_reask=full_schema_reask,
disable_tracer=self._disable_tracer,
)
call = await runner.async_run(
call_log=call_log, prompt_params=prompt_params
Expand Down Expand Up @@ -1020,6 +1025,7 @@ def _sync_parse(
output=llm_output,
base_model=self.base_model,
full_schema_reask=full_schema_reask,
disable_tracer=self._disable_tracer,
)
call = runner(call_log=call_log, prompt_params=prompt_params)

Expand Down Expand Up @@ -1062,6 +1068,7 @@ async def _async_parse(
output=llm_output,
base_model=self.base_model,
full_schema_reask=full_schema_reask,
disable_tracer=self._disable_tracer,
)
call = await runner.async_run(
call_log=call_log, prompt_params=prompt_params
Expand Down
25 changes: 19 additions & 6 deletions guardrails/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from eliot import add_destinations, start_action
from pydantic import BaseModel

from guardrails.classes.credentials import Credentials
from guardrails.classes.history import Call, Inputs, Iteration, Outputs
from guardrails.datatypes import verify_metadata_requirements
from guardrails.errors import ValidationError
Expand Down Expand Up @@ -68,6 +67,7 @@ def __init__(
Union[Type[BaseModel], Type[List[Type[BaseModel]]]]
] = None,
full_schema_reask: bool = False,
disable_tracer: Optional[bool] = True,
):
if prompt:
assert api, "Must provide an API if a prompt is provided."
Expand Down Expand Up @@ -109,7 +109,7 @@ def __init__(
self.full_schema_reask = full_schema_reask

# Get metrics opt-out from credentials
self._disable_tracer = Credentials.from_rc_file().no_metrics
self._disable_tracer = disable_tracer

if not self._disable_tracer:
# Get the HubTelemetry singleton
Expand Down Expand Up @@ -354,7 +354,7 @@ def validate_msg_history(
iteration = Iteration(inputs=inputs)
call_log.iterations.insert(0, iteration)
validated_msg_history = msg_history_schema.validate(
iteration, msg_str, self.metadata
iteration, msg_str, self.metadata, disable_tracer=self._disable_tracer
)
iteration.outputs.validation_output = validated_msg_history
if isinstance(validated_msg_history, ReAsk):
Expand Down Expand Up @@ -394,7 +394,10 @@ def validate_prompt(
iteration = Iteration(inputs=inputs)
call_log.iterations.insert(0, iteration)
validated_prompt = prompt_schema.validate(
iteration, prompt.source, self.metadata
iteration,
prompt.source,
self.metadata,
disable_tracer=self._disable_tracer,
)
iteration.outputs.validation_output = validated_prompt
if validated_prompt is None:
Expand All @@ -415,7 +418,10 @@ def validate_instructions(
iteration = Iteration(inputs=inputs)
call_log.iterations.insert(0, iteration)
validated_instructions = instructions_schema.validate(
iteration, instructions.source, self.metadata
iteration,
instructions.source,
self.metadata,
disable_tracer=self._disable_tracer,
)
iteration.outputs.validation_output = validated_instructions
if validated_instructions is None:
Expand Down Expand Up @@ -608,7 +614,12 @@ def validate(
"""Validate the output."""
with start_action(action_type="validate", index=index) as action:
validated_output = output_schema.validate(
iteration, parsed_output, self.metadata, attempt_number=index, **kwargs
iteration,
parsed_output,
self.metadata,
attempt_number=index,
disable_tracer=self._disable_tracer,
**kwargs,
)

action.log(
Expand Down Expand Up @@ -682,6 +693,7 @@ def __init__(
Union[Type[BaseModel], Type[List[Type[BaseModel]]]]
] = None,
full_schema_reask: bool = False,
disable_tracer: Optional[bool] = True,
):
super().__init__(
output_schema=output_schema,
Expand All @@ -697,6 +709,7 @@ def __init__(
output=output,
base_model=base_model,
full_schema_reask=full_schema_reask,
disable_tracer=disable_tracer,
)
self.api: Optional[AsyncPromptCallableBase] = api

Expand Down
2 changes: 2 additions & 0 deletions guardrails/schema/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def validate(
data: Optional[Dict[str, Any]],
metadata: Dict,
attempt_number: int = 0,
disable_tracer: Optional[bool] = True,
**kwargs,
) -> Any:
"""Validate a dictionary of data against the schema.
Expand Down Expand Up @@ -369,6 +370,7 @@ def validate(
metadata=metadata,
validator_setup=validation,
iteration=iteration,
disable_tracer=disable_tracer,
)

if check_refrain(validated_response):
Expand Down
2 changes: 2 additions & 0 deletions guardrails/schema/string_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def validate(
data: Any,
metadata: Dict,
attempt_number: int = 0,
disable_tracer: Optional[bool] = True,
**kwargs,
) -> Any:
"""Validate a dictionary of data against the schema.
Expand Down Expand Up @@ -163,6 +164,7 @@ def validate(
metadata=metadata,
validator_setup=validation,
iteration=iteration,
disable_tracer=disable_tracer,
)

validated_response = {dummy_key: validated_response}
Expand Down
25 changes: 14 additions & 11 deletions guardrails/validator_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from guardrails.classes.credentials import Credentials
from guardrails.classes.history import Iteration
from guardrails.datatypes import FieldValidation
from guardrails.errors import ValidationError
Expand All @@ -32,6 +31,9 @@ def key_not_empty(key: str) -> bool:
class ValidatorServiceBase:
"""Base class for validator services."""

def __init__(self, disable_tracer: Optional[bool] = True):
self._disable_tracer = disable_tracer

# NOTE: This is avoiding an issue with multiprocessing.
# If we wrap the validate methods at the class level or anytime before
# loop.run_in_executor is called, multiprocessing fails with a Pickling error.
Expand Down Expand Up @@ -135,10 +137,7 @@ def run_validator(
# this will have to change.
validator_logs.instance_id = id(validator)

# Get metrics opt-out from credentials
disable_tracer = Credentials.from_rc_file().no_metrics

if not disable_tracer:
if not self._disable_tracer:
# Get HubTelemetry singleton and create a new span to
# log the validator usage
_hub_telemetry = HubTelemetry()
Expand Down Expand Up @@ -296,7 +295,6 @@ async def run_validators(
# wait for the parallel tasks to finish
if parallel_tasks:
parallel_results = await asyncio.gather(*parallel_tasks)
iteration.outputs.validator_logs.extend(parallel_results)
validators_logs.extend(parallel_results)

# process the results, handle failures
Expand Down Expand Up @@ -406,7 +404,11 @@ def validate(


def validate(
value: Any, metadata: dict, validator_setup: FieldValidation, iteration: Iteration
value: Any,
metadata: dict,
validator_setup: FieldValidation,
iteration: Iteration,
disable_tracer: Optional[bool] = True,
):
process_count = int(os.environ.get("GUARDRAILS_PROCESS_COUNT", 10))

Expand All @@ -423,11 +425,11 @@ def validate(
"To run asynchronously, specify a process count"
"greater than 1 or unset this environment variable."
)
validator_service = SequentialValidatorService()
validator_service = SequentialValidatorService(disable_tracer)
elif loop is not None and not loop.is_running():
validator_service = AsyncValidatorService()
validator_service = AsyncValidatorService(disable_tracer)
else:
validator_service = SequentialValidatorService()
validator_service = SequentialValidatorService(disable_tracer)
return validator_service.validate(
value,
metadata,
Expand All @@ -441,8 +443,9 @@ async def async_validate(
metadata: dict,
validator_setup: FieldValidation,
iteration: Iteration,
disable_tracer: Optional[bool] = True,
):
validator_service = AsyncValidatorService()
validator_service = AsyncValidatorService(disable_tracer)
return await validator_service.async_validate(
value,
metadata,
Expand Down
21 changes: 14 additions & 7 deletions tests/integration_tests/test_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def guard_initializer(
),
],
)
@pytest.mark.parametrize("multiprocessing_validators", (True,)) # False))
@pytest.mark.parametrize("multiprocessing_validators", (True, False))
def test_entity_extraction_with_reask(
mocker, rail, prompt, test_full_schema_reask, multiprocessing_validators
):
Expand Down Expand Up @@ -172,14 +172,17 @@ def test_entity_extraction_with_reask(
assert first.validation_output == entity_extraction.VALIDATED_OUTPUT_REASK_1

# For reask validator logs
# TODO: Update once we add json_path to the ValidatorLog class
nested_validator_logs = list(
x for x in first.validator_logs if x.value_before_validation == "my chase plan"
two_words_validator_logs = list(
x
for x in first.validator_logs
if x.property_path == "$.fees.1.name" and x.registered_name == "two-words"
)
nested_validator_log = nested_validator_logs[1]

assert nested_validator_log.value_before_validation == "my chase plan"
assert nested_validator_log.value_after_validation == FieldReAsk(
two_words_validator_log = two_words_validator_logs[0]

assert two_words_validator_log.value_before_validation == "my chase plan"

expected_value_after_validation = FieldReAsk(
incorrect_value="my chase plan",
fail_results=[
FailResult(
Expand All @@ -189,6 +192,10 @@ def test_entity_extraction_with_reask(
],
path=["fees", 1, "name"],
)
assert (
two_words_validator_log.value_after_validation
== expected_value_after_validation
)

# For re-asked prompt and output
# second = call.iterations.at(1)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/cli/test_configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_save_configuration_file(mocker):
f"id=f49354e0-80c7-4591-81db-cc2f945e5f1e{os.linesep}",
f"client_id=id{os.linesep}",
f"client_secret=secret{os.linesep}",
f"no_metrics=true{os.linesep}",
"no_metrics=true",
]
)
assert close_spy.call_count == 1
5 changes: 5 additions & 0 deletions tests/unit_tests/mocks/mock_async_validator_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@


class MockAsyncValidatorService:
initialized: bool

def __init__(self, *args, **kwargs):
self.initialized = True

async def async_validate(self, *args):
await asyncio.sleep(0.1)
# The return value doesn't really matter here.
Expand Down
Loading