4 changes: 2 additions & 2 deletions guardrails/classes/history/call.py
@@ -99,8 +99,8 @@ def reask_prompts(self) -> Stack[Optional[str]]:

    @property
    def instructions(self) -> Optional[str]:
-        """The instructions as provided by the user when initializing or calling
-        the Guard."""
+        """The instructions as provided by the user when initializing or
+        calling the Guard."""
        return self.inputs.instructions

    @property
9 changes: 9 additions & 0 deletions guardrails/guard.py
@@ -1366,6 +1366,10 @@ def validate(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]:
    # def __call__(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]:
    #     return self.validate(llm_output, *args, **kwargs)

    @deprecated(
        """'Guard.invoke' is deprecated and will be removed in \
versions 0.5.x and beyond. Use Guard.to_runnable() instead."""
    )
    def invoke(
        self, input: InputType, config: Optional[RunnableConfig] = None
    ) -> InputType:
@@ -1545,3 +1549,8 @@ def _call_server(
            )
        else:
            raise ValueError("Guard does not have an api client!")

    def to_runnable(self) -> Runnable:
        from guardrails.integrations.langchain.guard_runnable import GuardRunnable

        return GuardRunnable(self)
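For orientation, a minimal sketch of the intended migration: rather than piping the `Guard` itself into a LangChain (LCEL) chain through the now-deprecated `invoke`, you wrap it with `to_runnable()`. The prompt, the stub model, and the `RegexMatch` validator below are illustrative placeholders, not part of this diff:

```python
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda

from guardrails.guard import Guard
from guardrails.validators import RegexMatch

guard = Guard().use(RegexMatch("Ice cream", match_type="search"), on="output")

# Stand-in for a real chat model so the sketch runs offline.
fake_model = RunnableLambda(lambda _: AIMessage(content="Ice cream is frozen."))

chain = (
    ChatPromptTemplate.from_template("ELIF: {topic}")
    | fake_model
    | guard.to_runnable()  # previously the Guard itself sat here via the deprecated invoke
    | StrOutputParser()
)
print(chain.invoke({"topic": "ice cream"}))  # -> "Ice cream is frozen."
```

The test added at the bottom of this PR exercises the same chain shape with a mocked model.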
48 changes: 48 additions & 0 deletions guardrails/integrations/langchain/guard_runnable.py
@@ -0,0 +1,48 @@
import json
from copy import deepcopy
from typing import Dict, Optional, cast
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig
from guardrails.classes.input_type import InputType
from guardrails.errors import ValidationError
from guardrails.guard import Guard


class GuardRunnable(Runnable):
    guard: Guard

    def __init__(self, guard: Guard):
        self.name = guard.name
        self.guard = guard

    def invoke(
        self, input: InputType, config: Optional[RunnableConfig] = None
    ) -> InputType:
        output = BaseMessage(content="", type="")
        str_input = None
        input_is_chat_message = False
        if isinstance(input, BaseMessage):
            input_is_chat_message = True
            str_input = str(input.content)
            output = deepcopy(input)
        else:
            str_input = str(input)

        response = self.guard.validate(str_input)

        validated_output = response.validated_output
        if not validated_output:
            raise ValidationError(
                (
                    "The response from the LLM failed validation! "
                    "See `guard.history` for more details."
                )
            )

        if isinstance(validated_output, Dict):
            validated_output = json.dumps(validated_output)

        if input_is_chat_message:
            output.content = validated_output
            return cast(InputType, output)
        return cast(InputType, validated_output)
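A minimal usage sketch of the new `GuardRunnable`, invoked directly rather than inside a chain: a plain string in yields the validated string back, while a chat `BaseMessage` in yields a copy of the message whose content has been validated (dict outputs are re-serialized with `json.dumps`). The `RegexMatch` validator is just an illustrative choice mirroring the test below:

```python
from langchain_core.messages import AIMessage

from guardrails.guard import Guard
from guardrails.validators import RegexMatch

runnable = (
    Guard()
    .use(RegexMatch("Ice cream", match_type="search"), on="output")
    .to_runnable()
)

# Plain string in -> validated string out.
print(runnable.invoke("Ice cream is frozen."))

# Chat message in -> a copy of the message out, with validated content.
msg = AIMessage(content="Ice cream is frozen.")
print(runnable.invoke(msg).content)
```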
43 changes: 43 additions & 0 deletions guardrails/integrations/langchain/validator_runnable.py
@@ -0,0 +1,43 @@
from copy import deepcopy
from typing import Optional, cast
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig
from guardrails.classes.input_type import InputType
from guardrails.errors import ValidationError
from guardrails.validator_base import FailResult, Validator


class ValidatorRunnable(Runnable):
    validator: Validator

    def __init__(self, validator: Validator):
        self.name = validator.rail_alias
        self.validator = validator

    def invoke(
        self, input: InputType, config: Optional[RunnableConfig] = None
    ) -> InputType:
        output = BaseMessage(content="", type="")
        str_input = None
        input_is_chat_message = False
        if isinstance(input, BaseMessage):
            input_is_chat_message = True
            str_input = str(input.content)
            output = deepcopy(input)
        else:
            str_input = str(input)

        response = self.validator.validate(str_input, self.validator._metadata)

        if isinstance(response, FailResult):
            raise ValidationError(
                (
                    "The response from the LLM failed validation!"
                    f" {response.error_message}"
                )
            )

        if input_is_chat_message:
            output.content = str_input
            return cast(InputType, output)
        return cast(InputType, str_input)
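`ValidatorRunnable` follows the same pattern for a single validator; note that `invoke` passes `self.validator._metadata` to `validate`, so validators that need metadata should have it attached via `with_metadata()` before `to_runnable()` is called. A small sketch, with `RegexMatch` again chosen purely for illustration:

```python
from guardrails.validators import RegexMatch

# Wrap a single validator as a Runnable; metadata (if the validator needs any)
# is attached first, since ValidatorRunnable.invoke reads validator._metadata.
validator_runnable = (
    RegexMatch("Ice cream", match_type="search")
    .with_metadata({})  # this validator needs no metadata; shown for the pattern
    .to_runnable()
)

# A passing input is returned unchanged; a FailResult raises ValidationError instead.
print(validator_runnable.invoke("Ice cream is frozen."))
```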
12 changes: 12 additions & 0 deletions guardrails/validator_base.py
@@ -15,6 +15,7 @@
    Union,
    cast,
)
from typing_extensions import deprecated
from warnings import warn

from langchain_core.messages import BaseMessage
@@ -541,6 +542,10 @@ def __stringify__(self):
            }
        )

    @deprecated(
        """'Validator.invoke' is deprecated and will be removed in \
versions 0.5.x and beyond. Use Validator.to_runnable() instead."""
    )
    def invoke(
        self, input: InputType, config: Optional[RunnableConfig] = None
    ) -> InputType:
@@ -592,5 +597,12 @@ def with_metadata(self, metadata: Dict[str, Any]):
        self._metadata = metadata
        return self

    def to_runnable(self) -> Runnable:
        from guardrails.integrations.langchain.validator_runnable import (
            ValidatorRunnable,
        )

        return ValidatorRunnable(self)


ValidatorSpec = Union[Validator, Tuple[Union[Validator, str, Callable], str]]
41 changes: 38 additions & 3 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -53,7 +53,7 @@ pydoc-markdown = "4.8.2"
opentelemetry-sdk = "1.20.0"
opentelemetry-exporter-otlp-proto-grpc = "1.20.0"
opentelemetry-exporter-otlp-proto-http = "1.20.0"
- langchain-core = "^0.1.18"
+ langchain-core = ">=0.1,<0.3"
coloredlogs = "^15.0.1"
requests = "^2.31.0"
guardrails-api-client = "^0.1.1"
@@ -0,0 +1,61 @@
from typing import Optional

import pytest

from guardrails.guard import Guard


@pytest.mark.parametrize(
"output,throws",
[
("Ice cream is frozen.", False),
("Ice cream is a frozen dairy product that is consumed in many places.", True),
("This response isn't relevant.", True),
],
)
def test_guard_as_runnable(output: str, throws: bool):
    from langchain_core.language_models import LanguageModelInput
    from langchain_core.messages import AIMessage, BaseMessage
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.runnables import Runnable, RunnableConfig

    from guardrails.errors import ValidationError
    from guardrails.validators import ReadingTime, RegexMatch

    class MockModel(Runnable):
        def invoke(
            self, input: LanguageModelInput, config: Optional[RunnableConfig] = None
        ) -> BaseMessage:
            return AIMessage(content=output)

    prompt = ChatPromptTemplate.from_template("ELIF: {topic}")
    model = MockModel()
    guard = (
        Guard()
        .use(
            RegexMatch("Ice cream", match_type="search", on_fail="refrain"), on="output"
        )
        .use(ReadingTime(0.05, on_fail="refrain"))  # 3 seconds
    )
    output_parser = StrOutputParser()

    chain = prompt | model | guard.to_runnable() | output_parser

    topic = "ice cream"
    if throws:
        with pytest.raises(ValidationError) as exc_info:
            chain.invoke({"topic": topic})

        assert str(exc_info.value) == (
            "The response from the LLM failed validation! "
            "See `guard.history` for more details."
        )

        assert guard.history.last.status == "fail"
    else:
        result = chain.invoke({"topic": topic})

        assert result == output