Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions guardrails/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@
set_tracer,
set_tracer_context,
)
from guardrails.types.on_fail import OnFailAction
from guardrails.types.pydantic import ModelOrListOfModels
from guardrails.utils.naming_utils import random_id
from guardrails.utils.safe_get import safe_get
from guardrails.utils.api_utils import extract_serializeable_metadata
from guardrails.utils.hub_telemetry_utils import HubTelemetry
from guardrails.classes.llm.llm_response import LLMResponse
Expand Down Expand Up @@ -250,19 +250,19 @@ def _fill_validator_map(self):
entry: List[Validator] = self._validator_map.get(ref.on, []) # type: ignore
# Check if the validator from the reference
# has an instance in the validator_map
v = safe_get(
[
v
for v in entry
if (
v.rail_alias == ref.id
and v.on_fail_descriptor == ref.on_fail
and v.get_args() == ref.kwargs
)
],
0,
)
if not v:
existing_instance: Optional[Validator] = None
for v in entry:
same_id = v.rail_alias == ref.id
same_on_fail = v.on_fail_descriptor == ref.on_fail or ( # is default
v.on_fail_descriptor == OnFailAction.NOOP and not ref.on_fail
)
same_args = v.get_args() == ref.kwargs or ( # Both are empty
not v.get_args() and not ref.kwargs
)
if same_id and same_on_fail and same_args:
existing_instance = v
break
if not existing_instance:
validator = parse_validator_reference(ref)
if validator:
entry.append(validator)
Expand Down
186 changes: 184 additions & 2 deletions tests/integration_tests/test_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import Optional, Union

import pytest
from pydantic import BaseModel
from guardrails_api_client import Guard as IGuard, GuardHistory
from pydantic import BaseModel, Field
from guardrails_api_client import Guard as IGuard, GuardHistory, ValidatorReference

import guardrails as gd
from guardrails.actions.reask import SkeletonReAsk
Expand All @@ -23,6 +23,7 @@
RegexMatch,
ValidLength,
ValidChoices,
LowerCase,
)

from .mock_llm_outputs import (
Expand Down Expand Up @@ -1186,3 +1187,184 @@ def test_guard_from_pydantic_with_mock_hf_model():
tokenizer=tokenizer,
prompt="Don't care about the output. Just don't crash.",
)


class TestValidatorInitializedOnce:
def test_guard_init(self, mocker):
init_spy = mocker.spy(LowerCase, "__init__")

guard = Guard(validators=[ValidatorReference(id="lower-case", on="$")])

# Validator is not initialized until the guard is used
assert init_spy.call_count == 0

guard.parse("some-name")

assert init_spy.call_count == 1

# Validator is not initialized again
guard.parse("some-other-name")

assert init_spy.call_count == 1

def test_from_rail(self, mocker):
init_spy = mocker.spy(LowerCase, "__init__")

guard = Guard.from_rail_string(
"""
<rail version="0.1">
<output
type="string"
validators="lower-case"
/>
</rail>
"""
)

assert init_spy.call_count == 1

# Validator is not initialized again
guard.parse("some-name")

assert init_spy.call_count == 1

def test_from_pydantic_validator_instance(self, mocker):
init_spy = mocker.spy(LowerCase, "__init__")

class MyModel(BaseModel):
name: str = Field(..., validators=[LowerCase()])

guard = Guard().from_pydantic(MyModel)

assert init_spy.call_count == 1

# Validator is not initialized again
guard.parse('{ "name": "some-name" }')

assert init_spy.call_count == 1

def test_from_pydantic_str(self, mocker):
init_spy = mocker.spy(LowerCase, "__init__")

class MyModel(BaseModel):
name: str = Field(..., validators=[("lower-case", "noop")])

guard = Guard().from_pydantic(MyModel)

assert init_spy.call_count == 1

# Validator is not initialized again
guard.parse('{ "name": "some-name" }')

assert init_spy.call_count == 1

def test_from_pydantic_same_instance_on_two_models(self, mocker):
init_spy = mocker.spy(LowerCase, "__init__")

lower_case = LowerCase()

class MyModel(BaseModel):
name: str = Field(..., validators=[lower_case])

class MyOtherModel(BaseModel):
name: str = Field(..., validators=[lower_case])

guard_1 = Guard.from_pydantic(MyModel)
guard_2 = Guard.from_pydantic(MyOtherModel)

assert init_spy.call_count == 1

# Validator is not initialized again
guard_1.parse("some-name")

assert init_spy.call_count == 1

guard_2.parse("some-other-name")

assert init_spy.call_count == 1

def test_guard_use_instance(self, mocker):
init_spy = mocker.spy(LowerCase, "__init__")

guard = Guard().use(LowerCase())

assert init_spy.call_count == 1

# Validator is not initialized again
guard.parse("some-name")

assert init_spy.call_count == 1

def test_guard_use_class(self, mocker):
init_spy = mocker.spy(LowerCase, "__init__")

guard = Guard().use(LowerCase)

assert init_spy.call_count == 1

# Validator is not initialized again
guard.parse("some-name")

assert init_spy.call_count == 1

def test_guard_use_same_instance_on_two_guards(self, mocker):
init_spy = mocker.spy(LowerCase, "__init__")

lower_case = LowerCase()

guard_1 = Guard().use(lower_case)
guard_2 = Guard().use(lower_case)

assert init_spy.call_count == 1

# Validator is not initialized again
guard_1.parse("some-name")

assert init_spy.call_count == 1

guard_2.parse("some-other-name")

assert init_spy.call_count == 1

def test_guard_use_many_instance(self, mocker):
init_spy = mocker.spy(LowerCase, "__init__")

guard = Guard().use_many(LowerCase())

assert init_spy.call_count == 1

# Validator is not initialized again
guard.parse("some-name")

assert init_spy.call_count == 1

def test_guard_use_many_class(self, mocker):
init_spy = mocker.spy(LowerCase, "__init__")

guard = Guard().use_many(LowerCase)

assert init_spy.call_count == 1

# Validator is not initialized again
guard.parse("some-name")

assert init_spy.call_count == 1

def test_guard_use_many_same_instance_on_two_guards(self, mocker):
init_spy = mocker.spy(LowerCase, "__init__")

lower_case = LowerCase()

guard_1 = Guard().use_many(lower_case)
guard_2 = Guard().use_many(lower_case)

assert init_spy.call_count == 1

# Validator is not initialized again
guard_1.parse("some-name")

assert init_spy.call_count == 1

guard_2.parse("some-other-name")

assert init_spy.call_count == 1