From 31cb5a46c6071a13d621e775d73d00e9893225c3 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 17 Jun 2024 14:21:46 -0500 Subject: [PATCH 1/2] ensure validators are only initialized once --- guardrails/guard.py | 15 ++- tests/integration_tests/test_guard.py | 186 +++++++++++++++++++++++++- 2 files changed, 197 insertions(+), 4 deletions(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index f1035bc17..c113d4162 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -70,6 +70,7 @@ set_tracer, set_tracer_context, ) +from guardrails.types.on_fail import OnFailAction from guardrails.types.pydantic import ModelOrListOfModels from guardrails.utils.naming_utils import random_id from guardrails.utils.safe_get import safe_get @@ -256,8 +257,18 @@ def _fill_validator_map(self): for v in entry if ( v.rail_alias == ref.id - and v.on_fail_descriptor == ref.on_fail - and v.get_args() == ref.kwargs + and ( + v.on_fail_descriptor == ref.on_fail + or ( + v.on_fail_descriptor == OnFailAction.NOOP + and not ref.on_fail + ) + ) + and ( + v.get_args() == ref.kwargs + or not v.get_args() + and not ref.kwargs + ) ) ], 0, diff --git a/tests/integration_tests/test_guard.py b/tests/integration_tests/test_guard.py index 20945d7f4..586197932 100644 --- a/tests/integration_tests/test_guard.py +++ b/tests/integration_tests/test_guard.py @@ -5,8 +5,8 @@ from typing import Optional, Union import pytest -from pydantic import BaseModel -from guardrails_api_client import Guard as IGuard, GuardHistory +from pydantic import BaseModel, Field +from guardrails_api_client import Guard as IGuard, GuardHistory, ValidatorReference import guardrails as gd from guardrails.actions.reask import SkeletonReAsk @@ -23,6 +23,7 @@ RegexMatch, ValidLength, ValidChoices, + LowerCase, ) from .mock_llm_outputs import ( @@ -1186,3 +1187,184 @@ def test_guard_from_pydantic_with_mock_hf_model(): tokenizer=tokenizer, prompt="Don't care about the output. Just don't crash.", ) + + +class TestValidatorInitializedOnce: + def test_guard_init(self, mocker): + init_spy = mocker.spy(LowerCase, "__init__") + + guard = Guard(validators=[ValidatorReference(id="lower-case", on="$")]) + + # Validator is not initialized until the guard is used + assert init_spy.call_count == 0 + + guard.parse("some-name") + + assert init_spy.call_count == 1 + + # Validator is not initialized again + guard.parse("some-other-name") + + assert init_spy.call_count == 1 + + def test_from_rail(self, mocker): + init_spy = mocker.spy(LowerCase, "__init__") + + guard = Guard.from_rail_string( + """ + + + + """ + ) + + assert init_spy.call_count == 1 + + # Validator is not initialized again + guard.parse("some-name") + + assert init_spy.call_count == 1 + + def test_from_pydantic_validator_instance(self, mocker): + init_spy = mocker.spy(LowerCase, "__init__") + + class MyModel(BaseModel): + name: str = Field(..., validators=[LowerCase()]) + + guard = Guard().from_pydantic(MyModel) + + assert init_spy.call_count == 1 + + # Validator is not initialized again + guard.parse('{ "name": "some-name" }') + + assert init_spy.call_count == 1 + + def test_from_pydantic_str(self, mocker): + init_spy = mocker.spy(LowerCase, "__init__") + + class MyModel(BaseModel): + name: str = Field(..., validators=[("lower-case", "noop")]) + + guard = Guard().from_pydantic(MyModel) + + assert init_spy.call_count == 1 + + # Validator is not initialized again + guard.parse('{ "name": "some-name" }') + + assert init_spy.call_count == 1 + + def test_from_pydantic_same_instance_on_two_models(self, mocker): + init_spy = mocker.spy(LowerCase, "__init__") + + lower_case = LowerCase() + + class MyModel(BaseModel): + name: str = Field(..., validators=[lower_case]) + + class MyOtherModel(BaseModel): + name: str = Field(..., validators=[lower_case]) + + guard_1 = Guard.from_pydantic(MyModel) + guard_2 = Guard.from_pydantic(MyOtherModel) + + assert init_spy.call_count == 1 + + # Validator is not initialized again + guard_1.parse("some-name") + + assert init_spy.call_count == 1 + + guard_2.parse("some-other-name") + + assert init_spy.call_count == 1 + + def test_guard_use_instance(self, mocker): + init_spy = mocker.spy(LowerCase, "__init__") + + guard = Guard().use(LowerCase()) + + assert init_spy.call_count == 1 + + # Validator is not initialized again + guard.parse("some-name") + + assert init_spy.call_count == 1 + + def test_guard_use_class(self, mocker): + init_spy = mocker.spy(LowerCase, "__init__") + + guard = Guard().use(LowerCase) + + assert init_spy.call_count == 1 + + # Validator is not initialized again + guard.parse("some-name") + + assert init_spy.call_count == 1 + + def test_guard_use_same_instance_on_two_guards(self, mocker): + init_spy = mocker.spy(LowerCase, "__init__") + + lower_case = LowerCase() + + guard_1 = Guard().use(lower_case) + guard_2 = Guard().use(lower_case) + + assert init_spy.call_count == 1 + + # Validator is not initialized again + guard_1.parse("some-name") + + assert init_spy.call_count == 1 + + guard_2.parse("some-other-name") + + assert init_spy.call_count == 1 + + def test_guard_use_many_instance(self, mocker): + init_spy = mocker.spy(LowerCase, "__init__") + + guard = Guard().use_many(LowerCase()) + + assert init_spy.call_count == 1 + + # Validator is not initialized again + guard.parse("some-name") + + assert init_spy.call_count == 1 + + def test_guard_use_many_class(self, mocker): + init_spy = mocker.spy(LowerCase, "__init__") + + guard = Guard().use_many(LowerCase) + + assert init_spy.call_count == 1 + + # Validator is not initialized again + guard.parse("some-name") + + assert init_spy.call_count == 1 + + def test_guard_use_many_same_instance_on_two_guards(self, mocker): + init_spy = mocker.spy(LowerCase, "__init__") + + lower_case = LowerCase() + + guard_1 = Guard().use_many(lower_case) + guard_2 = Guard().use_many(lower_case) + + assert init_spy.call_count == 1 + + # Validator is not initialized again + guard_1.parse("some-name") + + assert init_spy.call_count == 1 + + guard_2.parse("some-other-name") + + assert init_spy.call_count == 1 From 72d61465ba7c54382421a3d648d1915c051127fb Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 17 Jun 2024 15:29:09 -0500 Subject: [PATCH 2/2] refactor list comprehension --- guardrails/guard.py | 37 +++++++++++++------------------------ 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/guardrails/guard.py b/guardrails/guard.py index c113d4162..6c2b108c7 100644 --- a/guardrails/guard.py +++ b/guardrails/guard.py @@ -73,7 +73,6 @@ from guardrails.types.on_fail import OnFailAction from guardrails.types.pydantic import ModelOrListOfModels from guardrails.utils.naming_utils import random_id -from guardrails.utils.safe_get import safe_get from guardrails.utils.api_utils import extract_serializeable_metadata from guardrails.utils.hub_telemetry_utils import HubTelemetry from guardrails.classes.llm.llm_response import LLMResponse @@ -251,29 +250,19 @@ def _fill_validator_map(self): entry: List[Validator] = self._validator_map.get(ref.on, []) # type: ignore # Check if the validator from the reference # has an instance in the validator_map - v = safe_get( - [ - v - for v in entry - if ( - v.rail_alias == ref.id - and ( - v.on_fail_descriptor == ref.on_fail - or ( - v.on_fail_descriptor == OnFailAction.NOOP - and not ref.on_fail - ) - ) - and ( - v.get_args() == ref.kwargs - or not v.get_args() - and not ref.kwargs - ) - ) - ], - 0, - ) - if not v: + existing_instance: Optional[Validator] = None + for v in entry: + same_id = v.rail_alias == ref.id + same_on_fail = v.on_fail_descriptor == ref.on_fail or ( # is default + v.on_fail_descriptor == OnFailAction.NOOP and not ref.on_fail + ) + same_args = v.get_args() == ref.kwargs or ( # Both are empty + not v.get_args() and not ref.kwargs + ) + if same_id and same_on_fail and same_args: + existing_instance = v + break + if not existing_instance: validator = parse_validator_reference(ref) if validator: entry.append(validator)