diff --git a/guardrails/classes/history/call.py b/guardrails/classes/history/call.py
index 287d344cc..503cefb9a 100644
--- a/guardrails/classes/history/call.py
+++ b/guardrails/classes/history/call.py
@@ -99,8 +99,8 @@ def reask_prompts(self) -> Stack[Optional[str]]:
 
     @property
     def instructions(self) -> Optional[str]:
-        """The instructions as provided by the user when initializing or calling
-        the Guard."""
+        """The instructions as provided by the user when initializing or
+        calling the Guard."""
         return self.inputs.instructions
 
     @property
diff --git a/guardrails/guard.py b/guardrails/guard.py
index 335c832af..63e9e6da9 100644
--- a/guardrails/guard.py
+++ b/guardrails/guard.py
@@ -1366,6 +1366,10 @@ def validate(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]:
     # def __call__(self, llm_output: str, *args, **kwargs) -> ValidationOutcome[str]:
     #     return self.validate(llm_output, *args, **kwargs)
 
+    @deprecated(
+        """'Guard.invoke' is deprecated and will be removed in \
+versions 0.5.x and beyond. Use Guard.to_runnable() instead."""
+    )
     def invoke(
         self, input: InputType, config: Optional[RunnableConfig] = None
     ) -> InputType:
@@ -1545,3 +1549,8 @@ def _call_server(
             )
         else:
             raise ValueError("Guard does not have an api client!")
+
+    def to_runnable(self) -> Runnable:
+        from guardrails.integrations.langchain.guard_runnable import GuardRunnable
+
+        return GuardRunnable(self)
diff --git a/guardrails/integrations/langchain/guard_runnable.py b/guardrails/integrations/langchain/guard_runnable.py
new file mode 100644
index 000000000..873357bb2
--- /dev/null
+++ b/guardrails/integrations/langchain/guard_runnable.py
@@ -0,0 +1,48 @@
+import json
+from copy import deepcopy
+from typing import Dict, Optional, cast
+from langchain_core.messages import BaseMessage
+from langchain_core.runnables import Runnable, RunnableConfig
+from guardrails.classes.input_type import InputType
+from guardrails.errors import ValidationError
+from guardrails.guard import Guard
+
+
+class GuardRunnable(Runnable):
+    guard: Guard
+
+    def __init__(self, guard: Guard):
+        self.name = guard.name
+        self.guard = guard
+
+    def invoke(
+        self, input: InputType, config: Optional[RunnableConfig] = None
+    ) -> InputType:
+        output = BaseMessage(content="", type="")
+        str_input = None
+        input_is_chat_message = False
+        if isinstance(input, BaseMessage):
+            input_is_chat_message = True
+            str_input = str(input.content)
+            output = deepcopy(input)
+        else:
+            str_input = str(input)
+
+        response = self.guard.validate(str_input)
+
+        validated_output = response.validated_output
+        if not validated_output:
+            raise ValidationError(
+                (
+                    "The response from the LLM failed validation!"
+                    " See `guard.history` for more details."
+                )
+            )
+
+        if isinstance(validated_output, Dict):
+            validated_output = json.dumps(validated_output)
+
+        if input_is_chat_message:
+            output.content = validated_output
+            return cast(InputType, output)
+        return cast(InputType, validated_output)
diff --git a/guardrails/integrations/langchain/validator_runnable.py b/guardrails/integrations/langchain/validator_runnable.py
new file mode 100644
index 000000000..7740bd73a
--- /dev/null
+++ b/guardrails/integrations/langchain/validator_runnable.py
@@ -0,0 +1,43 @@
+from copy import deepcopy
+from typing import Optional, cast
+from langchain_core.messages import BaseMessage
+from langchain_core.runnables import Runnable, RunnableConfig
+from guardrails.classes.input_type import InputType
+from guardrails.errors import ValidationError
+from guardrails.validator_base import FailResult, Validator
+
+
+class ValidatorRunnable(Runnable):
+    validator: Validator
+
+    def __init__(self, validator: Validator):
+        self.name = validator.rail_alias
+        self.validator = validator
+
+    def invoke(
+        self, input: InputType, config: Optional[RunnableConfig] = None
+    ) -> InputType:
+        output = BaseMessage(content="", type="")
+        str_input = None
+        input_is_chat_message = False
+        if isinstance(input, BaseMessage):
+            input_is_chat_message = True
+            str_input = str(input.content)
+            output = deepcopy(input)
+        else:
+            str_input = str(input)
+
+        response = self.validator.validate(str_input, self.validator._metadata)
+
+        if isinstance(response, FailResult):
+            raise ValidationError(
+                (
+                    "The response from the LLM failed validation!"
+                    f" {response.error_message}"
+                )
+            )
+
+        if input_is_chat_message:
+            output.content = str_input
+            return cast(InputType, output)
+        return cast(InputType, str_input)
diff --git a/guardrails/validator_base.py b/guardrails/validator_base.py
index 2b92915d9..fa01573f1 100644
--- a/guardrails/validator_base.py
+++ b/guardrails/validator_base.py
@@ -15,6 +15,7 @@
     Union,
     cast,
 )
+from typing_extensions import deprecated
 from warnings import warn
 
 from langchain_core.messages import BaseMessage
@@ -541,6 +542,10 @@ def __stringify__(self):
             }
         )
 
+    @deprecated(
+        """'Validator.invoke' is deprecated and will be removed in \
+    versions 0.5.x and beyond. Use Validator.to_runnable() instead."""
+    )
     def invoke(
         self, input: InputType, config: Optional[RunnableConfig] = None
     ) -> InputType:
@@ -592,5 +597,12 @@ def with_metadata(self, metadata: Dict[str, Any]):
         self._metadata = metadata
         return self
 
+    def to_runnable(self) -> Runnable:
+        from guardrails.integrations.langchain.validator_runnable import (
+            ValidatorRunnable,
+        )
+
+        return ValidatorRunnable(self)
+
 
 ValidatorSpec = Union[Validator, Tuple[Union[Validator, str, Callable], str]]
diff --git a/poetry.lock b/poetry.lock
index 5b3139c96..228329555 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]] name = "aiohttp" @@ -1662,29 +1662,62 @@ optional = true python-versions = ">=3.7" files = [ {file = "greenlet-3.0.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:9da2bd29ed9e4f15955dd1595ad7bc9320308a3b766ef7f837e23ad4b4aac31a"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d353cadd6083fdb056bb46ed07e4340b0869c305c8ca54ef9da3421acbdf6881"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dca1e2f3ca00b84a396bc1bce13dd21f680f035314d2379c4160c98153b2059b"}, + {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ed7fb269f15dc662787f4119ec300ad0702fa1b19d2135a37c2c4de6fadfd4a"}, {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd4f49ae60e10adbc94b45c0b5e6a179acc1736cf7a90160b404076ee283cf83"}, {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:73a411ef564e0e097dbe7e866bb2dda0f027e072b04da387282b02c308807405"}, + {file = "greenlet-3.0.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7f362975f2d179f9e26928c5b517524e89dd48530a0202570d55ad6ca5d8a56f"}, {file = "greenlet-3.0.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:649dde7de1a5eceb258f9cb00bdf50e978c9db1b996964cd80703614c86495eb"}, + {file = "greenlet-3.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:68834da854554926fbedd38c76e60c4a2e3198c6fbed520b106a8986445caaf9"}, {file = "greenlet-3.0.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:b1b5667cced97081bf57b8fa1d6bfca67814b0afd38208d52538316e9422fc61"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:52f59dd9c96ad2fc0d5724107444f76eb20aaccb675bf825df6435acb7703559"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:afaff6cf5200befd5cec055b07d1c0a5a06c040fe5ad148abcd11ba6ab9b114e"}, + {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe754d231288e1e64323cfad462fcee8f0288654c10bdf4f603a39ed923bef33"}, {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2797aa5aedac23af156bbb5a6aa2cd3427ada2972c828244eb7d1b9255846379"}, {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b7f009caad047246ed379e1c4dbcb8b020f0a390667ea74d2387be2998f58a22"}, + {file = "greenlet-3.0.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c5e1536de2aad7bf62e27baf79225d0d64360d4168cf2e6becb91baf1ed074f3"}, {file = "greenlet-3.0.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:894393ce10ceac937e56ec00bb71c4c2f8209ad516e96033e4b3b1de270e200d"}, + {file = "greenlet-3.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:1ea188d4f49089fc6fb283845ab18a2518d279c7cd9da1065d7a84e991748728"}, {file = "greenlet-3.0.3-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:70fb482fdf2c707765ab5f0b6655e9cfcf3780d8d87355a063547b41177599be"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4d1ac74f5c0c0524e4a24335350edad7e5f03b9532da7ea4d3c54d527784f2e"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:149e94a2dd82d19838fe4b2259f1b6b9957d5ba1b25640d2380bea9c5df37676"}, + {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:15d79dd26056573940fcb8c7413d84118086f2ec1a8acdfa854631084393efcc"}, {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:881b7db1ebff4ba09aaaeae6aa491daeb226c8150fc20e836ad00041bcb11230"}, {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fcd2469d6a2cf298f198f0487e0a5b1a47a42ca0fa4dfd1b6862c999f018ebbf"}, + {file = "greenlet-3.0.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1f672519db1796ca0d8753f9e78ec02355e862d0998193038c7073045899f305"}, {file = "greenlet-3.0.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2516a9957eed41dd8f1ec0c604f1cdc86758b587d964668b5b196a9db5bfcde6"}, + {file = "greenlet-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:bba5387a6975598857d86de9eac14210a49d554a77eb8261cc68b7d082f78ce2"}, {file = "greenlet-3.0.3-cp37-cp37m-macosx_11_0_universal2.whl", hash = "sha256:5b51e85cb5ceda94e79d019ed36b35386e8c37d22f07d6a751cb659b180d5274"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:daf3cb43b7cf2ba96d614252ce1684c1bccee6b2183a01328c98d36fcd7d5cb0"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99bf650dc5d69546e076f413a87481ee1d2d09aaaaaca058c9251b6d8c14783f"}, + {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dd6e660effd852586b6a8478a1d244b8dc90ab5b1321751d2ea15deb49ed414"}, {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3391d1e16e2a5a1507d83e4a8b100f4ee626e8eca43cf2cadb543de69827c4c"}, {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e1f145462f1fa6e4a4ae3c0f782e580ce44d57c8f2c7aae1b6fa88c0b2efdb41"}, + {file = "greenlet-3.0.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1a7191e42732df52cb5f39d3527217e7ab73cae2cb3694d241e18f53d84ea9a7"}, {file = "greenlet-3.0.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:0448abc479fab28b00cb472d278828b3ccca164531daab4e970a0458786055d6"}, + {file = "greenlet-3.0.3-cp37-cp37m-win32.whl", hash = "sha256:b542be2440edc2d48547b5923c408cbe0fc94afb9f18741faa6ae970dbcb9b6d"}, + {file = "greenlet-3.0.3-cp37-cp37m-win_amd64.whl", hash = "sha256:01bc7ea167cf943b4c802068e178bbf70ae2e8c080467070d01bfa02f337ee67"}, {file = "greenlet-3.0.3-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:1996cb9306c8595335bb157d133daf5cf9f693ef413e7673cb07e3e5871379ca"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ddc0f794e6ad661e321caa8d2f0a55ce01213c74722587256fb6566049a8b04"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c9db1c18f0eaad2f804728c67d6c610778456e3e1cc4ab4bbd5eeb8e6053c6fc"}, + {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7170375bcc99f1a2fbd9c306f5be8764eaf3ac6b5cb968862cad4c7057756506"}, {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b66c9c1e7ccabad3a7d037b2bcb740122a7b17a53734b7d72a344ce39882a1b"}, {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:098d86f528c855ead3479afe84b49242e174ed262456c342d70fc7f972bc13c4"}, + {file = "greenlet-3.0.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:81bb9c6d52e8321f09c3d165b2a78c680506d9af285bfccbad9fb7ad5a5da3e5"}, {file = 
"greenlet-3.0.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fd096eb7ffef17c456cfa587523c5f92321ae02427ff955bebe9e3c63bc9f0da"}, + {file = "greenlet-3.0.3-cp38-cp38-win32.whl", hash = "sha256:d46677c85c5ba00a9cb6f7a00b2bfa6f812192d2c9f7d9c4f6a55b60216712f3"}, + {file = "greenlet-3.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:419b386f84949bf0e7c73e6032e3457b82a787c1ab4a0e43732898a761cc9dbf"}, {file = "greenlet-3.0.3-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:da70d4d51c8b306bb7a031d5cff6cc25ad253affe89b70352af5f1cb68e74b53"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:086152f8fbc5955df88382e8a75984e2bb1c892ad2e3c80a2508954e52295257"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d73a9fe764d77f87f8ec26a0c85144d6a951a6c438dfe50487df5595c6373eac"}, + {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7dcbe92cc99f08c8dd11f930de4d99ef756c3591a5377d1d9cd7dd5e896da71"}, {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1551a8195c0d4a68fac7a4325efac0d541b48def35feb49d803674ac32582f61"}, {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:64d7675ad83578e3fc149b617a444fab8efdafc9385471f868eb5ff83e446b8b"}, + {file = "greenlet-3.0.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b37eef18ea55f2ffd8f00ff8fe7c8d3818abd3e25fb73fae2ca3b672e333a7a6"}, {file = "greenlet-3.0.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:77457465d89b8263bca14759d7c1684df840b6811b2499838cc5b040a8b5b113"}, + {file = "greenlet-3.0.3-cp39-cp39-win32.whl", hash = "sha256:57e8974f23e47dac22b83436bdcf23080ade568ce77df33159e019d161ce1d1e"}, + {file = "greenlet-3.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:c5ee858cfe08f34712f548c3c363e807e7186f03ad7a5039ebadb29e8c6be067"}, {file = "greenlet-3.0.3.tar.gz", hash = "sha256:43374442353259554ce33599da8b692d5aa96f8976d567d4badf263371fbe491"}, ] @@ -2807,6 +2840,7 @@ files = [ {file = "lxml-4.9.4-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e8f9f93a23634cfafbad6e46ad7d09e0f4a25a2400e4a64b1b7b7c0fbaa06d9d"}, {file = "lxml-4.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3f3f00a9061605725df1816f5713d10cd94636347ed651abdbc75828df302b20"}, {file = "lxml-4.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:953dd5481bd6252bd480d6ec431f61d7d87fdcbbb71b0d2bdcfc6ae00bb6fb10"}, + {file = "lxml-4.9.4-cp312-cp312-win32.whl", hash = "sha256:266f655d1baff9c47b52f529b5f6bec33f66042f65f7c56adde3fcf2ed62ae8b"}, {file = "lxml-4.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:f1faee2a831fe249e1bae9cbc68d3cd8a30f7e37851deee4d7962b17c410dd56"}, {file = "lxml-4.9.4-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:23d891e5bdc12e2e506e7d225d6aa929e0a0368c9916c1fddefab88166e98b20"}, {file = "lxml-4.9.4-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:e96a1788f24d03e8d61679f9881a883ecdf9c445a38f9ae3f3f193ab6c591c66"}, @@ -4349,8 +4383,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -5343,6 +5377,7 @@ 
files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
+    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -8321,4 +8356,4 @@ vectordb = ["faiss-cpu", "numpy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.8.1"
-content-hash = "ebae16aa66c7e7789668d5f7ec3779ee60055f427df9b4ede98596a3a6bc5d2f"
+content-hash = "fe6b6c42df209d16b013d3b8cd070399bb5804d4c10a28868a92d513c7fc520a"
diff --git a/pyproject.toml b/pyproject.toml
index f8413f823..64d9bd868 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,7 +53,7 @@ pydoc-markdown = "4.8.2"
 opentelemetry-sdk = "1.20.0"
 opentelemetry-exporter-otlp-proto-grpc = "1.20.0"
 opentelemetry-exporter-otlp-proto-http = "1.20.0"
-langchain-core = "^0.1.18"
+langchain-core = ">=0.1,<0.3"
 coloredlogs = "^15.0.1"
 requests = "^2.31.0"
 guardrails-api-client = "^0.1.1"
diff --git a/tests/integration_tests/integrations/langchain/test_guard_runnable.py b/tests/integration_tests/integrations/langchain/test_guard_runnable.py
new file mode 100644
index 000000000..02205fe0c
--- /dev/null
+++ b/tests/integration_tests/integrations/langchain/test_guard_runnable.py
@@ -0,0 +1,60 @@
+from typing import Optional
+
+import pytest
+
+from guardrails.guard import Guard
+
+
+@pytest.mark.parametrize(
+    "output,throws",
+    [
+        ("Ice cream is frozen.", False),
+        ("Ice cream is a frozen dairy product that is consumed in many places.", True),
+        ("This response isn't relevant.", True),
+    ],
+)
+def test_guard_as_runnable(output: str, throws: bool):
+    from langchain_core.language_models import LanguageModelInput
+    from langchain_core.messages import AIMessage, BaseMessage
+    from langchain_core.output_parsers import StrOutputParser
+    from langchain_core.prompts import ChatPromptTemplate
+    from langchain_core.runnables import Runnable, RunnableConfig
+
+    from guardrails.errors import ValidationError
+    from guardrails.validators import ReadingTime, RegexMatch
+
+    class MockModel(Runnable):
+        def invoke(
+            self, input: LanguageModelInput, config: Optional[RunnableConfig] = None
+        ) -> BaseMessage:
+            return AIMessage(content=output)
+
+    prompt = ChatPromptTemplate.from_template("ELIF: {topic}")
+    model = MockModel()
+    guard = (
+        Guard()
+        .use(
+            RegexMatch("Ice cream", match_type="search", on_fail="refrain"), on="output"
+        )
+        .use(ReadingTime(0.05, on_fail="refrain"))  # 3 seconds
+    )
+    output_parser = StrOutputParser()
+
+    chain = prompt | model | guard.to_runnable() | output_parser
+
+    topic = "ice cream"
+    if throws:
+        with pytest.raises(ValidationError) as exc_info:
+            chain.invoke({"topic": topic})
+
+        assert str(exc_info.value) == (
+            "The response from the LLM failed validation!"
+            " See `guard.history` for more details."
+        )
+
+        assert guard.history.last.status == "fail"
+
+    else:
+        result = chain.invoke({"topic": topic})
+
+        assert result == output
diff --git a/tests/integration_tests/integrations/langchain/test_validator_runnable.py b/tests/integration_tests/integrations/langchain/test_validator_runnable.py
new file mode 100644
index 000000000..a42546d4e
--- /dev/null
+++ b/tests/integration_tests/integrations/langchain/test_validator_runnable.py
@@ -0,0 +1,57 @@
+from typing import Optional
+
+import pytest
+
+
+@pytest.mark.parametrize(
+    "output,throws,expected_error",
+    [
+        ("Ice cream is frozen.", False, None),
+        (
+            "Ice cream is a frozen dairy product that is consumed in many places.",
+            True,
+            "String should be readable within 0.05 minutes.",
+        ),
+        ("This response isn't relevant.", True, "Result must match Ice cream"),
+    ],
+)
+def test_guard_as_runnable(output: str, throws: bool, expected_error: Optional[str]):
+    from langchain_core.language_models import LanguageModelInput
+    from langchain_core.messages import AIMessage, BaseMessage
+    from langchain_core.output_parsers import StrOutputParser
+    from langchain_core.prompts import ChatPromptTemplate
+    from langchain_core.runnables import Runnable, RunnableConfig
+
+    from guardrails.errors import ValidationError
+    from guardrails.validators import ReadingTime, RegexMatch
+
+    class MockModel(Runnable):
+        def invoke(
+            self, input: LanguageModelInput, config: Optional[RunnableConfig] = None
+        ) -> BaseMessage:
+            return AIMessage(content=output)
+
+    prompt = ChatPromptTemplate.from_template("ELIF: {topic}")
+    model = MockModel()
+    regex_match = RegexMatch(
+        "Ice cream", match_type="search", on_fail="refrain"
+    ).to_runnable()
+    reading_time = ReadingTime(0.05, on_fail="refrain").to_runnable()
+
+    output_parser = StrOutputParser()
+
+    chain = prompt | model | regex_match | reading_time | output_parser
+
+    topic = "ice cream"
+    if throws:
+        with pytest.raises(ValidationError) as exc_info:
+            chain.invoke({"topic": topic})
+
+        assert str(exc_info.value) == (
+            "The response from the LLM failed validation!" f" {expected_error}"
+        )
+
+    else:
+        result = chain.invoke({"topic": topic})
+
+        assert result == output
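
For reviewers, a minimal usage sketch of the two entry points this patch adds, Guard.to_runnable() and Validator.to_runnable(), composed into a LangChain (LCEL) pipeline. The validators, on_fail behavior, and chain shape mirror the integration tests above; the ChatOpenAI model, the langchain-openai import, and the prompt text are illustrative assumptions, not part of this change.

```python
# Usage sketch only -- assumes langchain-openai is installed; any Runnable chat
# model can stand in for ChatOpenAI.
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI  # assumed dependency, not added by this patch

from guardrails.guard import Guard
from guardrails.validators import ReadingTime, RegexMatch

prompt = ChatPromptTemplate.from_template("ELIF: {topic}")
model = ChatOpenAI()  # illustrative; the tests above use a mocked Runnable instead

# Option 1: validate the model output with a whole Guard as one chain step.
guard = (
    Guard()
    .use(RegexMatch("Ice cream", match_type="search", on_fail="refrain"), on="output")
    .use(ReadingTime(0.05, on_fail="refrain"))
)
guarded_chain = prompt | model | guard.to_runnable() | StrOutputParser()

# Option 2: drop individual validators into the chain as separate steps.
regex_step = RegexMatch("Ice cream", match_type="search", on_fail="refrain").to_runnable()
reading_step = ReadingTime(0.05, on_fail="refrain").to_runnable()
validator_chain = prompt | model | regex_step | reading_step | StrOutputParser()

# Either chain raises guardrails.errors.ValidationError when validation fails.
print(guarded_chain.invoke({"topic": "ice cream"}))
```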