diff --git a/guardrails/classes/validation/validation_result.py b/guardrails/classes/validation/validation_result.py index d6427db88..50a90f9b4 100644 --- a/guardrails/classes/validation/validation_result.py +++ b/guardrails/classes/validation/validation_result.py @@ -7,6 +7,7 @@ ErrorSpan as IErrorSpan, ) from guardrails.classes.generic.arbitrary_model import ArbitraryModel +from pydantic import BaseModel class ValidationResult(IValidationResult, ArbitraryModel): @@ -185,3 +186,9 @@ class ErrorSpan(IErrorSpan, ArbitraryModel): end: int # reason validation failed, specific to this chunk reason: str + + +class StreamValidationResult(BaseModel): + chunk: Any + original_text: str + metadata: Dict[str, Any] diff --git a/guardrails/merge.py b/guardrails/merge.py new file mode 100644 index 000000000..e25acd3ce --- /dev/null +++ b/guardrails/merge.py @@ -0,0 +1,291 @@ +# SOURCE: https://github.com/spyder-ide/three-merge/blob/master/three_merge/merge.py +from diff_match_patch import diff_match_patch + +# Constants +DIFFER = diff_match_patch() +DIFFER.Diff_Timeout = 0.1 +DIFFER.Diff_EditCost = 4 +PRESERVED = 0 +DELETION = -1 +ADDITION = 1 + + +def merge(source: str, target: str, base: str) -> str: + diff1_l = DIFFER.diff_main(base, source) + diff2_l = DIFFER.diff_main(base, target) + + DIFFER.diff_cleanupEfficiency(diff1_l) + DIFFER.diff_cleanupEfficiency(diff2_l) + + diff1 = iter(diff1_l) + diff2 = iter(diff2_l) + + composed_text = [] + + source = next(diff1, None) # type: ignore + target = next(diff2, None) # type: ignore + + prev_source_text = "" + prev_target_text = "" + + while source is not None and target is not None: + source_status, source_text = source + target_status, target_text = target + if source_status == PRESERVED and target_status == PRESERVED: + # Base is preserved for both source and target + if len(source_text) > len(target_text): + # Addition performed by target + advance = True + composed_text.append(target_text) + tempdiff = DIFFER.diff_main(target_text, source_text) + _, invariant = tempdiff[1] + # _, (_, invariant) = DIFFER.diff_main(target_text, source_text) + prev_target_text = target[1] + target = next(diff2, None) # type: ignore + while invariant != "" and target is not None: + # Apply target changes until invariant is preserved + # target = next(diff2, None) + target_status, target_text = target + if target_status == DELETION: + if len(target_text) > len(invariant): + target_text = target_text[len(invariant) :] + invariant = "" + target = (target_status, target_text) # type: ignore + else: + invariant = invariant[len(target_text) :] + prev_target_text = target[1] + target = next(diff2, None) # type: ignore + elif target_status == ADDITION: + composed_text.append(target_text) + prev_target_text = target[1] + target = next(diff2, None) # type: ignore + else: + # Recompute invariant and advance source + if len(invariant) > len(target_text): + assert invariant[: len(target_text)] == target_text + source = (source_status, invariant[len(target_text) :]) # type: ignore + composed_text.append(target_text) + invariant = "" + advance = False + prev_target_text = target[1] + target = next(diff2, None) # type: ignore + else: + target_text = target_text[len(invariant) :] + composed_text.append(invariant) + invariant = "" + target = (target_status, target_text) # type: ignore + if advance: + prev_source_text = source[1] + source = next(diff1, None) # type: ignore + elif len(source_text) < len(target_text): + # Addition performed by source + advance = True + 
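+                # Mirror of the branch above with the roles reversed: source
+                # added text, so replay source-side edits until the rest of
+                # target's preserved run (the invariant) has been consumed.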
composed_text.append(source_text) + tempdiff = DIFFER.diff_main(target_text, source_text) + _, invariant = tempdiff[1] + # _, (_, invariant) = DIFFER.diff_main(source_text, target_text) + prev_source_text = source[1] + source = next(diff1, None) # type: ignore + while invariant != "" and target is not None and source is not None: + # Apply source changes until invariant is preserved + source_status, source_text = source + if source_status == DELETION: + if len(source_text) > len(invariant): + source_text = source_text[len(invariant) :] + invariant = "" + source = (source_status, source_text) # type: ignore + else: + invariant = invariant[len(source_text) :] + prev_source_text = source[1] + source = next(diff1, None) # type: ignore + elif source_status == ADDITION: + composed_text.append(source_text) + prev_source_text = source[1] + source = next(diff1, None) # type: ignore + else: + # Recompute invariant and advance source + # invariant = invariant[:len(source_text)] + if len(invariant) > len(source_text): + assert invariant[: len(source_text)] == source_text + target = (target_status, invariant[len(source_text) :]) # type: ignore + composed_text.append(source_text) + invariant = "" + advance = False + prev_source_text = source[1] + source = next(diff1, None) # type: ignore + else: + source_text = source_text[len(invariant) :] + composed_text.append(invariant) + invariant = "" + source = (source_status, source_text) # type: ignore + if advance: + prev_target_text = target[1] + target = next(diff2, None) # type: ignore + else: + # Source and target are equal + composed_text.append(source_text) + prev_source_text = source[1] + prev_target_text = target[1] + source = next(diff1, None) # type: ignore + target = next(diff2, None) # type: ignore + elif source_status == ADDITION and target_status == PRESERVED: + # Source is adding text + composed_text.append(source_text) + prev_source_text = source[1] + source = next(diff1, None) # type: ignore + elif source_status == PRESERVED and target_status == ADDITION: + # Target is adding text + composed_text.append(target_text) + prev_target_text = target[1] + target = next(diff2, None) # type: ignore + elif source_status == DELETION and target_status == PRESERVED: + if len(target_text) > len(source_text): + # Take target text, remove the corresponding part from source + target_text = target_text[len(source_text) :] + # composed_text.append(target_text) + # source = diff1.pop(0) + target = (target_status, target_text) # type: ignore + prev_source_text = source[1] + source = next(diff1, None) # type: ignore + elif len(target_text) <= len(source_text): + source_text = source_text[len(target_text) :] + source = (source_status, source_text) # type: ignore + prev_target_text = target[1] + target = next(diff2, None) # type: ignore + elif source_status == PRESERVED and target_status == DELETION: + if len(source_text) > len(target_text): + # Take source text, remove the corresponding part from target + source_text = source_text[len(target_text) :] + source = (source_status, source_text) # type: ignore + prev_target_text = target[1] + target = next(diff2, None) # type: ignore + elif len(source_text) <= len(target_text): + # Advance to next source + target_text = target_text[len(source_text) :] + target = (target_status, target_text) # type: ignore + prev_source_text = source[1] + source = next(diff1, None) # type: ignore + elif source_status == DELETION and target_status == ADDITION: + # Merge conflict + # Err on the side of deletion. 
Do not add anything + # composed_text.append("<<<<<<< ++ {0} ".format(target_text)) + # composed_text.append("======= -- {0} ".format(source_text)) + # composed_text.append(">>>>>>>") + prev_source_text = source[1] + prev_target_text = target[1] + source = next(diff1, None) # type: ignore + target = next(diff2, None) # type: ignore + if target is not None: + target_status, target_text = target + if target_text.startswith(source_text): + target_text = target_text[len(source_text) :] + target = (target_status, target_text) # type: ignore + elif source_status == ADDITION and target_status == DELETION: + # Merge conflict + # Err on the side of deletion. Do not add anything + # composed_text.append("<<<<<<< ++ {0} ".format(source_text)) + # composed_text.append("======= -- {0} ".format(target_text)) + # composed_text.append(">>>>>>>") + prev_source_text = source[1] + prev_target_text = target[1] + source = next(diff1, None) # type: ignore + target = next(diff2, None) # type: ignore + if source is not None: + source_status, source_text = source + if source_text.startswith(target_text): + source_text = source_text[len(target_text) :] + source = (source_status, source_text) # type: ignore + elif source_status == ADDITION and target_status == ADDITION: + # Possible merge conflict + if len(source_text) >= len(target_text): + if source_text.startswith(target_text): + composed_text.append(source_text) + else: + # Merge conflict + # Insert text that has highest distance from original + # we assume original is last operation + source_dist = DIFFER.diff_levenshtein( + DIFFER.diff_main(source_text, prev_source_text) + ) + target_dist = DIFFER.diff_levenshtein( + DIFFER.diff_main(target_text, prev_target_text) + ) + if source_dist > target_dist: + composed_text.append(source_text) + else: + composed_text.append(target_text) + else: + if target_text.startswith(source_text): + composed_text.append(target_text) + else: + # Merge conflict + # Insert text that has highest distance from original + source_dist = DIFFER.diff_levenshtein( + DIFFER.diff_main(source_text, prev_source_text) + ) + target_dist = DIFFER.diff_levenshtein( + DIFFER.diff_main(target_text, prev_target_text) + ) + if source_dist > target_dist: + composed_text.append(source_text) + else: + composed_text.append(target_text) + prev_source_text = source[1] + prev_target_text = target[1] + source = next(diff1, None) # type: ignore + target = next(diff2, None) # type: ignore + elif source_status == DELETION and target_status == DELETION: + # Possible merge conflict + merge_conflict = False + if len(source_text) > len(target_text): + if source_text.startswith(target_text): + # Peek target to delete preserved text + source_text = source_text[len(target_text) :] + source = (source_status, source_text) # type: ignore + prev_target_text = target[1] + target = next(diff2, None) # type: ignore + else: + merge_conflict = True + elif len(target_text) > len(source_text): + if target_text.startswith(source_text): + target_text = target_text[len(source_text) :] + target = (target_status, target_text) # type: ignore + prev_source_text = source[1] + source = next(diff1, None) # type: ignore + else: + merge_conflict = True + else: + if target_text == source_text: + # Both source and target remove the same text + prev_source_text = source[1] + prev_target_text = target[1] + source = next(diff1, None) # type: ignore + target = next(diff2, None) # type: ignore + else: + merge_conflict = True + + # Don't handle double deletion scenario + if merge_conflict: + source = 
next(diff1, None) # type: ignore + target = next(diff2, None) # type: ignore + # composed_text.append("<<<<<<< -- {0} ".format(source_text)) + # composed_text.append("======= -- {0} ".format(target_text)) + # composed_text.append(">>>>>>>") + + while source is not None: + source_status, source_text = source + # assert source_status == ADDITION or source_status == PRESERVED + if source_status == ADDITION: + composed_text.append(source_text) + prev_source_text = source[1] + source = next(diff1, None) # type: ignore + + while target is not None: + target_status, target_text = target + # assert target_status == ADDITION or source_status == PRESERVED + if target_status == ADDITION: + composed_text.append(target_text) + prev_target_text = target[1] + target = next(diff2, None) # type: ignore + + return "".join(composed_text) diff --git a/guardrails/run/stream_runner.py b/guardrails/run/stream_runner.py index 4924d294e..2e041abf4 100644 --- a/guardrails/run/stream_runner.py +++ b/guardrails/run/stream_runner.py @@ -1,5 +1,6 @@ -from typing import Any, Dict, Generator, List, Optional, Union, cast +from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union, cast +from guardrails import validator_service from guardrails.classes.history import Call, Inputs, Iteration, Outputs from guardrails.classes.output_type import OT, OutputTypes from guardrails.classes.validation_outcome import ValidationOutcome @@ -138,95 +139,71 @@ def step( "the API is returning a generator." ) - fragment = "" - parsed_fragment, validated_fragment, valid_op = None, None, None + parsed_fragment, validated_fragment, valid_op = "", None, None verified = set() validation_response = "" + fragment = "" # Loop over the stream # and construct "fragments" of concatenated chunks # for now, handle string and json schema differently - if self.output_type == OutputTypes.STRING: - stream_finished = False - last_chunk_text = "" - for chunk in stream: - # 1. Get the text from the chunk and append to fragment - chunk_text = self.get_chunk_text(chunk, api) - last_chunk_text = chunk_text - finished = self.is_last_chunk(chunk, api) - if finished: - stream_finished = True - fragment += chunk_text - # 2. Parse the chunk - parsed_chunk, move_to_next = self.parse( - chunk_text, output_schema, verified=verified - ) - if move_to_next: - # Continue to next chunk - continue - validated_text = self.validate( - iteration, - index, - parsed_chunk, - output_schema, - True, - validate_subschema=True, - # if it is the last chunk, validate everything that's left - remainder=finished, - ) - if isinstance(validated_text, SkeletonReAsk): + def prepare_chunk_generator(stream) -> Iterable[Tuple[Any, bool]]: + for chunk in stream: + chunk_text = self.get_chunk_text(chunk, api) + nonlocal fragment + fragment += chunk_text + finished = self.is_last_chunk(chunk, api) + # 2. 
Parse the chunk
+                    parsed_chunk, move_to_next = self.parse(
+                        chunk_text, output_schema, verified=verified
+                    )
+                    nonlocal parsed_fragment
+                    # ignore types because output schema guarantees a string
+                    parsed_fragment += parsed_chunk  # type: ignore
+                    if move_to_next:
+                        # Continue to next chunk
+                        continue
+                    yield parsed_chunk, finished
+
+            prepped_stream = prepare_chunk_generator(stream)
+            gen = validator_service.validate_stream(
+                prepped_stream,
+                self.metadata,
+                self.validation_map,
+                iteration,
+                self._disable_tracer,
+                "$",
+                validate_subschema=True,
+            )
+
+            for res in gen:
+                chunk = res.chunk
+                original_text = res.original_text
+                if isinstance(chunk, SkeletonReAsk):
                     raise ValueError(
                         "Received fragment schema is an invalid sub-schema "
                         "of the expected output JSON schema."
                     )
                 # 4. Introspect: inspect the validated fragment for reasks
-                reasks, valid_op = self.introspect(validated_text)
+                reasks, valid_op = self.introspect(chunk)
                 if reasks:
                     raise ValueError(
                         "Reasks are not yet supported with streaming. Please "
                         "remove reasks from schema or disable streaming."
                     )
                 # 5. Convert validated fragment to a pretty JSON string
-                validation_response += cast(str, validated_text)
+                validation_response += cast(str, chunk)
                 passed = call_log.status == pass_status
                 yield ValidationOutcome(
                     call_id=call_log.id,  # type: ignore
                     # The chunk or the whole output?
-                    raw_llm_output=chunk_text,
-                    validated_output=validated_text,
+                    raw_llm_output=original_text,
+                    validated_output=chunk,
                     validation_passed=passed,
                 )
-            # handle case where generator doesn't give finished status
-            if not stream_finished:
-                last_result = self.validate(
-                    iteration,
-                    index,
-                    "",
-                    output_schema,
-                    True,
-                    validate_subschema=True,
-                    remainder=True,
-                )
-                if last_result:
-                    passed = call_log.status == pass_status
-
-                    validated_output = None
-                    if passed is True:
-                        validated_output = cast(OT, last_result)
-
-                    reask = None
-                    if isinstance(last_result, ReAsk):
-                        reask = last_result
-
-                    yield ValidationOutcome(
-                        call_id=call_log.id,  # type: ignore
-                        raw_llm_output=last_chunk_text,
-                        validated_output=validated_output,
-                        reask=reask,
-                        validation_passed=passed,
-                    )
+        # handle non-string schema
         else:
             for chunk in stream:
 
@@ -276,10 +253,8 @@ def step(
                     validation_passed=validated_fragment is not None,
                 )
 
         # Finally, add to logs
         iteration.outputs.raw_output = fragment
-        # Do we need to care about the type here?
-        # What happens if parsing continuously fails?
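+        # If parsing never produced anything, fall back to the raw fragment.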
iteration.outputs.parsed_output = parsed_fragment or fragment # type: ignore iteration.outputs.validation_response = validation_response iteration.outputs.guarded_output = valid_op diff --git a/guardrails/validator_base.py b/guardrails/validator_base.py index 67a95b45d..5f3660e35 100644 --- a/guardrails/validator_base.py +++ b/guardrails/validator_base.py @@ -233,8 +233,9 @@ def validate_stream( remainder = kwargs.get("remainder", False) if remainder: split_contents = [accumulated_text, ""] + # if no chunks are returned, we haven't accumulated enough if len(split_contents) == 0: - return PassResult() + return None [chunk_to_validate, new_accumulated_chunks] = split_contents self.accumulated_chunks = [new_accumulated_chunks] # exclude last chunk, because it may not be a complete chunk diff --git a/guardrails/validator_service.py b/guardrails/validator_service.py index 1c81e7e9c..cc9f53cbb 100644 --- a/guardrails/validator_service.py +++ b/guardrails/validator_service.py @@ -3,7 +3,7 @@ import os from concurrent.futures import ProcessPoolExecutor from datetime import datetime -from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Awaitable, Dict, Iterable, List, Optional, Tuple, Union, cast from guardrails.actions.filter import Filter, apply_filters from guardrails.actions.refrain import Refrain, apply_refrain @@ -12,9 +12,11 @@ from guardrails.classes.validation.validation_result import ( FailResult, PassResult, + StreamValidationResult, ValidationResult, ) from guardrails.errors import ValidationError +from guardrails.merge import merge from guardrails.types import ValidatorMap, OnFailAction from guardrails.utils.exception_utils import UserFacingException from guardrails.utils.hub_telemetry_utils import HubTelemetry @@ -144,7 +146,7 @@ def after_run_validator( self, validator: Validator, validator_logs: ValidatorLogs, - result: ValidationResult, + result: Optional[ValidationResult], ): end_time = datetime.now() validator_logs.validation_result = result @@ -196,7 +198,7 @@ def run_validator_sync( *, validation_session_id: str, **kwargs, - ) -> ValidationResult: + ) -> Optional[ValidationResult]: result = self.execute_validator( validator, value, @@ -212,8 +214,8 @@ def run_validator_sync( f"Either use AsyncGuard or remove {validator_logs.validator_name}." 
                )
            )
-        elif result is None:
-            result = PassResult()
+        if result is None:
+            return result
         return cast(ValidationResult, result)
 
     def run_validator(
@@ -242,6 +244,256 @@
         return self.after_run_validator(validator, validator_logs, result)
 
+    def run_validators_stream(
+        self,
+        iteration: Iteration,
+        validator_map: ValidatorMap,
+        value_stream: Iterable[Tuple[Any, bool]],
+        metadata: Dict[str, Any],
+        absolute_property_path: str,
+        reference_property_path: str,
+        **kwargs,
+    ) -> Iterable[StreamValidationResult]:
+        validators = validator_map.get(reference_property_path, [])
+        for validator in validators:
+            if validator.on_fail_descriptor == OnFailAction.FIX:
+                return self.run_validators_stream_fix(
+                    iteration,
+                    validator_map,
+                    value_stream,
+                    metadata,
+                    absolute_property_path,
+                    reference_property_path,
+                    **kwargs,
+                )
+        return self.run_validators_stream_noop(
+            iteration,
+            validator_map,
+            value_stream,
+            metadata,
+            absolute_property_path,
+            reference_property_path,
+            **kwargs,
+        )
+
+    # requires at least 2 validators
+    def multi_merge(self, original: str, new_values: list[str]) -> str:
+        current = new_values.pop()
+        while len(new_values) > 0:
+            nextval = new_values.pop()
+            current = merge(current, nextval, original)
+        return current
+
+    def run_validators_stream_fix(
+        self,
+        iteration: Iteration,
+        validator_map: ValidatorMap,
+        value_stream: Iterable[Tuple[Any, bool]],
+        metadata: Dict[str, Any],
+        absolute_property_path: str,
+        reference_property_path: str,
+        **kwargs,
+    ) -> Iterable[StreamValidationResult]:
+        validators = validator_map.get(reference_property_path, [])
+        acc_output = ""
+        validator_partial_acc: dict[int, str] = {}
+        for validator in validators:
+            validator_partial_acc[id(validator)] = ""
+        last_chunk = None
+        last_chunk_validated = False
+        last_chunk_missing_validators = []
+        refrain_triggered = False
+        for chunk, finished in value_stream:
+            original_text = chunk
+            acc_output += chunk
+            fixed_values = []
+            last_chunk = chunk
+            last_chunk_missing_validators = []
+            if refrain_triggered:
+                break
+            for validator in validators:
+                # reset chunk to original text
+                chunk = original_text
+                validator_logs = self.run_validator(
+                    iteration,
+                    validator,
+                    chunk,
+                    metadata,
+                    absolute_property_path,
+                    True,
+                    remainder=finished,
+                    **kwargs,
+                )
+                result = validator_logs.validation_result
+                if result is None:
+                    last_chunk_missing_validators.append(validator)
+                result = cast(ValidationResult, result)
+                if isinstance(result, FailResult):
+                    is_filter = validator.on_fail_descriptor is OnFailAction.FILTER
+                    is_refrain = validator.on_fail_descriptor is OnFailAction.REFRAIN
+                    if is_filter or is_refrain:
+                        refrain_triggered = True
+                        break
+                    rechecked_value = None
+                    chunk = self.perform_correction(
+                        [result],
+                        chunk,
+                        validator,
+                        validator.on_fail_descriptor,
+                        rechecked_value=rechecked_value,
+                    )
+                    fixed_values.append(chunk)
+                    validator_partial_acc[id(validator)] += chunk  # type: ignore
+                elif isinstance(result, PassResult):
+                    if (
+                        validator.override_value_on_pass
+                        and result.value_override is not result.ValueOverrideSentinel
+                    ):
+                        chunk = result.value_override
+                    else:
+                        chunk = result.validated_chunk
+                    fixed_values.append(chunk)
+                    validator_partial_acc[id(validator)] += chunk  # type: ignore
+                validator_logs.value_after_validation = chunk
+                if result and result.metadata is not None:
+                    metadata = result.metadata
+
+            if refrain_triggered:
+                # if we have a FailResult from a refrain/filter validator, yield empty
+                yield StreamValidationResult(
+                    chunk="", original_text=acc_output, metadata=metadata
+                )
+            else:
+                # only merge and yield once every validator has yielded a concrete value
+                # TODO: check if only 1 validator - then skip merging
+                if len(fixed_values) == len(validators):
+                    last_chunk_validated = True
+                    values_to_merge = []
+                    for validator in validators:
+                        values_to_merge.append(validator_partial_acc[id(validator)])
+                    merged_value = self.multi_merge(acc_output, values_to_merge)
+                    # reset validator_partial_acc
+                    for validator in validators:
+                        validator_partial_acc[id(validator)] = ""
+                    yield StreamValidationResult(
+                        chunk=merged_value, original_text=acc_output, metadata=metadata
+                    )
+                    acc_output = ""
+                else:
+                    last_chunk_validated = False
+        # handle case where LLM doesn't yield finished flag
+        # we need to validate remainder of accumulated chunks
+        if not last_chunk_validated and not refrain_triggered:
+            original_text = last_chunk
+            for validator in last_chunk_missing_validators:
+                last_log = self.run_validator(
+                    iteration,
+                    validator,
+                    # use an empty chunk; the validator has already accumulated
+                    # the text from the first loop
+                    "",
+                    metadata,
+                    absolute_property_path,
+                    True,
+                    remainder=True,
+                    **kwargs,
+                )
+                result = last_log.validation_result
+                if isinstance(result, FailResult):
+                    rechecked_value = None
+                    last_chunk = self.perform_correction(
+                        [result],
+                        last_chunk,
+                        validator,
+                        validator.on_fail_descriptor,
+                        rechecked_value=rechecked_value,
+                    )
+                    validator_partial_acc[id(validator)] += last_chunk  # type: ignore
+                elif isinstance(result, PassResult):
+                    if (
+                        validator.override_value_on_pass
+                        and result.value_override is not result.ValueOverrideSentinel
+                    ):
+                        last_chunk = result.value_override
+                    else:
+                        last_chunk = result.validated_chunk
+                    validator_partial_acc[id(validator)] += last_chunk  # type: ignore
+                last_log.value_after_validation = last_chunk
+                if result and result.metadata is not None:
+                    metadata = result.metadata
+            values_to_merge = []
+            for validator in validators:
+                values_to_merge.append(validator_partial_acc[id(validator)])
+            merged_value = self.multi_merge(acc_output, values_to_merge)
+            yield StreamValidationResult(
+                chunk=merged_value,
+                original_text=original_text,  # type: ignore
+                metadata=metadata,  # type: ignore
+            )
+
+    def run_validators_stream_noop(
+        self,
+        iteration: Iteration,
+        validator_map: ValidatorMap,
+        value_stream: Iterable[Tuple[Any, bool]],
+        metadata: Dict[str, Any],
+        absolute_property_path: str,
+        reference_property_path: str,
+        **kwargs,
+    ) -> Iterable[StreamValidationResult]:
+        validators = validator_map.get(reference_property_path, [])
+        # Validate the field
+        # TODO: Under what conditions do we yield?
+        # When we have at least one non-None value?
+        # When we have all non-None values?
+        # Does this depend on whether we are fix or not?
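+        # Unlike the fix path, nothing is accumulated or merged across chunks:
+        # each chunk is validated as it arrives and yielded immediately, with
+        # any correction applied per the validator's on_fail action.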
+ for chunk, finished in value_stream: + original_text = chunk + for validator in validators: + validator_logs = self.run_validator( + iteration, + validator, + chunk, + metadata, + absolute_property_path, + True, + **kwargs, + ) + result = validator_logs.validation_result + result = cast(ValidationResult, result) + + if isinstance(result, FailResult): + rechecked_value = None + chunk = self.perform_correction( + [result], + chunk, + validator, + validator.on_fail_descriptor, + rechecked_value=rechecked_value, + ) + elif isinstance(result, PassResult): + if ( + validator.override_value_on_pass + and result.value_override is not result.ValueOverrideSentinel + ): + chunk = result.value_override + + validator_logs.value_after_validation = chunk + if result and result.metadata is not None: + metadata = result.metadata + # # TODO: Filter is no longer terminal, so we shouldn't yield, right? + # if isinstance(chunk, (Refrain, Filter, ReAsk)): + # yield chunk, metadata + yield StreamValidationResult( + chunk=chunk, original_text=original_text, metadata=metadata + ) + def run_validators( self, iteration: Iteration, @@ -292,6 +544,7 @@ def run_validators( **kwargs, ) result = validator_logs.validation_result + result = cast(ValidationResult, result) if isinstance(result, FailResult): rechecked_value = None @@ -400,30 +653,28 @@ def validate( def validate_stream( self, - value: Any, + value_stream: Iterable[Tuple[Any, bool]], metadata: dict, validator_map: ValidatorMap, iteration: Iteration, absolute_path: str, reference_path: str, **kwargs, - ) -> Tuple[Any, dict]: + ) -> Iterable[StreamValidationResult]: # I assume validate stream doesn't need validate_dependents # because right now we're only handling StringSchema # Validate the field - value, metadata = self.run_validators( + gen = self.run_validators_stream( iteration, validator_map, - value, + value_stream, metadata, absolute_path, reference_path, - True, **kwargs, ) - - return value, metadata + return gen class MultiprocMixin: @@ -742,18 +993,12 @@ def validate( iteration: Iteration, disable_tracer: Optional[bool] = True, path: Optional[str] = None, - stream: Optional[bool] = False, **kwargs, ): if path is None: path = "$" process_count = int(os.environ.get("GUARDRAILS_PROCESS_COUNT", 10)) - if stream: - sequential_validator_service = SequentialValidatorService(disable_tracer) - return sequential_validator_service.validate_stream( - value, metadata, validator_map, iteration, path, path, **kwargs - ) try: loop = asyncio.get_event_loop() except RuntimeError: @@ -771,6 +1016,24 @@ def validate( ) +def validate_stream( + value_stream: Iterable[Tuple[Any, bool]], + metadata: dict, + validator_map: ValidatorMap, + iteration: Iteration, + disable_tracer: Optional[bool] = True, + path: Optional[str] = None, + **kwargs, +) -> Iterable[StreamValidationResult]: + if path is None: + path = "$" + sequential_validator_service = SequentialValidatorService(disable_tracer) + gen = sequential_validator_service.validate_stream( + value_stream, metadata, validator_map, iteration, path, path, **kwargs + ) + return gen + + async def async_validate( value: Any, metadata: dict, diff --git a/poetry.lock b/poetry.lock index b76f1826a..e66536540 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. 
[[package]] name = "aiohttp" @@ -1055,6 +1055,20 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "diff-match-patch" +version = "20230430" +description = "Diff Match and Patch" +optional = false +python-versions = ">=3.7" +files = [ + {file = "diff-match-patch-20230430.tar.gz", hash = "sha256:953019cdb9c9d2c9e47b5b12bcff3cf4746fc4598eb406076fa1fc27e6a1f15c"}, + {file = "diff_match_patch-20230430-py3-none-any.whl", hash = "sha256:dce43505fb7b1b317de7195579388df0746d90db07015ed47a85e5e44930ef93"}, +] + +[package.extras] +dev = ["attribution (==1.6.2)", "black (==23.3.0)", "flit (==3.8.0)", "mypy (==1.2.0)", "ufmt (==2.1.0)", "usort (==1.0.6)"] + [[package]] name = "distlib" version = "0.3.8" @@ -7374,4 +7388,4 @@ vectordb = ["faiss-cpu", "numpy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "eb8fff0171f3a678dffc37f25f04fa97f1bd13d7879562999e0c4ffc509151e3" +content-hash = "cd01f3a668b819ef4f1328e38e521feb59cffbae549ab2a42d8c17d175f177f3" diff --git a/pyproject.toml b/pyproject.toml index a4d777b05..597a381cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ opentelemetry-sdk = "^1.24.0" opentelemetry-exporter-otlp-proto-grpc = "^1.24.0" opentelemetry-exporter-otlp-proto-http = "^1.24.0" guardrails-api-client = ">=0.3.8" +diff-match-patch = "^20230430" guardrails-api = ">=0.0.1" [tool.poetry.extras] diff --git a/tests/integration_tests/test_assets/validators/__init__.py b/tests/integration_tests/test_assets/validators/__init__.py index 58808893b..cd01d953a 100644 --- a/tests/integration_tests/test_assets/validators/__init__.py +++ b/tests/integration_tests/test_assets/validators/__init__.py @@ -8,6 +8,7 @@ from tests.integration_tests.test_assets.validators.valid_choices import ValidChoices from tests.integration_tests.test_assets.validators.valid_length import ValidLength from tests.integration_tests.test_assets.validators.valid_url import ValidURL +from tests.integration_tests.test_assets.validators.detect_pii import MockDetectPII __all__ = [ "EndsWith", @@ -20,4 +21,5 @@ "ValidChoices", "ValidLength", "ValidURL", + "MockDetectPII", ] diff --git a/tests/integration_tests/test_assets/validators/detect_pii.py b/tests/integration_tests/test_assets/validators/detect_pii.py new file mode 100644 index 000000000..68a491dae --- /dev/null +++ b/tests/integration_tests/test_assets/validators/detect_pii.py @@ -0,0 +1,181 @@ +from typing import Any, Callable, Dict, List, Union +import difflib +import nltk + +from guardrails.validator_base import ( + FailResult, + PassResult, + ValidationResult, + Validator, + register_validator, +) +from guardrails.validator_base import ErrorSpan + + +@register_validator(name="guardrails/detect_pii", data_type="string") +class MockDetectPII(Validator): + """Validates that any text does not contain any PII. + + Instead of using Microsoft Presidio, it accepts a map of PII + text to their replacements, and performs a simple string replacement. + For example, if the map is {"John Doe": "REDACTED"}, then the text "John + Doe is a person" will be replaced with "REDACTED is a person". + + **Key Properties** + + | Property | Description | + | ----------------------------- | --------------------------------------- | + | Name for `format` attribute | `pii` | + | Supported data types | `string` | + | Programmatic fix | Anonymized text with PII filtered out | + + Args: + pii_entities (str | List[str], optional): The PII entities to filter. 
Must be
+            one of `pii` or `spi`. Defaults to None. Can also be set in metadata.
+    """
+
+    PII_ENTITIES_MAP = {
+        "pii": [
+            "EMAIL_ADDRESS",
+            "PHONE_NUMBER",
+            "DOMAIN_NAME",
+            "IP_ADDRESS",
+            "DATE_TIME",
+            "LOCATION",
+            "PERSON",
+            "URL",
+        ],
+        "spi": [
+            "CREDIT_CARD",
+            "CRYPTO",
+            "IBAN_CODE",
+            "NRP",
+            "MEDICAL_LICENSE",
+            "US_BANK_NUMBER",
+            "US_DRIVER_LICENSE",
+            "US_ITIN",
+            "US_PASSPORT",
+            "US_SSN",
+        ],
+    }
+
+    def chunking_function(self, chunk: str):
+        """
+        Use a sentence tokenizer to split the chunk into sentences.
+
+        Because using the tokenizer is expensive, we only use it if there
+        is a period present in the chunk.
+        """
+        # using the sentence tokenizer is expensive
+        # we check for a . to avoid wastefully calling the tokenizer
+        if "." not in chunk:
+            return []
+        sentences = nltk.sent_tokenize(chunk)
+        if len(sentences) == 0:
+            return []
+        if len(sentences) == 1:
+            sentence = sentences[0].strip()
+            # this can still fail if the last chunk ends on the . in an email address
+            if sentence[-1] == ".":
+                return [sentence, ""]
+            else:
+                return []
+
+        # return the sentence
+        # then the remaining chunks that aren't finished accumulating
+        return [sentences[0], "".join(sentences[1:])]
+
+    def __init__(
+        self,
+        pii_entities: Union[str, List[str], None] = None,
+        on_fail: Union[Callable[..., Any], None] = None,
+        replace_map: Union[Dict[str, str], None] = None,
+        **kwargs,
+    ):
+        super().__init__(on_fail, pii_entities=pii_entities, **kwargs)
+        self.pii_entities = pii_entities
+        self.replace_map = replace_map or {}
+
+    def get_anonymized_text(self, text: str, entities: List[str]) -> str:
+        """Analyze and anonymize the text for PII.
+
+        Args:
+            text (str): The text to analyze.
+            entities (List[str]): The PII entities to filter (unused by this mock).
+
+        Returns:
+            anonymized_text (str): The anonymized text.
+        """
+        anonymized_text = text
+        # iterate through keys in replace_map
+        for key in self.replace_map:
+            anonymized_text = anonymized_text.replace(key, self.replace_map[key])
+        return anonymized_text
+
+    def validate(self, value: Any, metadata: Dict[str, Any]) -> ValidationResult:
+        # Entities to filter passed through metadata take precedence
+        pii_entities = metadata.get("pii_entities", self.pii_entities)
+        if pii_entities is None:
+            raise ValueError(
+                "`pii_entities` must be set in order to use the `DetectPII` validator. "
+                "Add this: `pii_entities=['PERSON', 'PHONE_NUMBER']` "
+                "OR pii_entities='pii' or 'spi' "
+                "in init or metadata."
+            )
+
+        pii_keys = list(self.PII_ENTITIES_MAP.keys())
+        # Check that pii_entities is a string OR list of strings
+        if isinstance(pii_entities, str):
+            # A key to the PII_ENTITIES_MAP
+            entities_to_filter = self.PII_ENTITIES_MAP.get(pii_entities, None)
+            if entities_to_filter is None:
+                raise ValueError(f"`pii_entities` must be one of {pii_keys}")
+        elif isinstance(pii_entities, list):
+            entities_to_filter = pii_entities
+        else:
+            raise ValueError(
+                f"`pii_entities` must be one of {pii_keys} or a list of strings."
+ ) + + # Analyze the text, and anonymize it if there is PII + anonymized_text = self.get_anonymized_text( + text=value, entities=entities_to_filter + ) + if anonymized_text == value: + return PassResult() + + # TODO: this should be refactored into a helper method in OSS + # get character indices of differences between two strings + differ = difflib.Differ() + diffs = list(differ.compare(value, anonymized_text)) + start_range = None + diff_ranges = [] + # needs to be tracked separately + curr_index_in_original = 0 + for i in range(len(diffs)): + if start_range is not None and diffs[i][0] != "-": + diff_ranges.append((start_range, curr_index_in_original)) + start_range = None + if diffs[i][0] == "-": + if start_range is None: + start_range = curr_index_in_original + if diffs[i][0] != "+": + curr_index_in_original += 1 + + error_spans = [] + for diff_range in diff_ranges: + error_spans.append( + ErrorSpan( + start=diff_range[0], + end=diff_range[1], + reason=f"PII detected in {value[diff_range[0]:diff_range[1]]}", + ) + ) + + # If anonymized value text is different from original value, then there is PII + error_msg = f"The following text in your response contains PII:\n{value}" + return FailResult( + error_message=(error_msg), + fix_value=anonymized_text, + error_spans=error_spans, + ) diff --git a/tests/integration_tests/test_streaming.py b/tests/integration_tests/test_streaming.py index a74c58fe5..acb24518d 100644 --- a/tests/integration_tests/test_streaming.py +++ b/tests/integration_tests/test_streaming.py @@ -21,7 +21,7 @@ Validator, register_validator, ) -from tests.integration_tests.test_assets.validators import LowerCase +from tests.integration_tests.test_assets.validators import LowerCase, MockDetectPII expected_raw_output = {"statement": "I am DOING well, and I HOPE you aRe too."} expected_fix_output = {"statement": "i am doing well, and i hope you are too."} @@ -324,7 +324,7 @@ def test_streaming_with_openai_chat_callable( STR_LLM_CHUNKS = [ # 38 characters "This sentence is simply just ", - "too long." + "too long.", # 25 characters long "This ", "sentence ", @@ -399,9 +399,327 @@ def test_string_schema_streaming_with_openai_chat(mocker, guard, expected_error_ accumulated_output += op.raw_llm_output error_spans = guard.error_spans_in_output() - # print spans assert len(error_spans) == len(expected_error_spans) for error_span, expected in zip(error_spans, expected_error_spans): assert accumulated_output[error_span.start : error_span.end] == expected[0] assert error_span.reason == expected[1] # TODO assert something about these error spans + + +POETRY_CHUNKS = [ + '"John, under ', + "GOLDEN bridges", + ", roams,\n", + "SAN Francisco's ", + "hills, his HOME.\n", + "Dreams of", + " FOG, and salty AIR,\n", + "In his HEART", + ", he's always THERE.", +] + + +def test_noop_behavior_two_validators(mocker): + mocker.patch( + "openai.resources.chat.completions.Completions.create", + return_value=mock_openai_chat_completion_create(POETRY_CHUNKS), + ) + + guard = gd.Guard().use_many( + MockDetectPII( + on_fail=OnFailAction.NOOP, + pii_entities="pii", + replace_map={"John": "", "SAN Francisco's": ""}, + ), + LowerCase(on_fail=OnFailAction.NOOP), + ) + prompt = """Write me a 4 line poem about John in San Francisco. 
+ Make every third word all caps.""" + gen = guard( + llm_api=openai.chat.completions.create, + prompt=prompt, + model="gpt-4", + stream=True, + ) + text = "" + original = "" + for res in gen: + original = original + res.raw_llm_output + text = text + res.validated_output + assert ( + text + == """"John, under GOLDEN bridges, roams, +SAN Francisco's hills, his HOME. +Dreams of FOG, and salty AIR, +In his HEART, he's always THERE.""" + ) + assert ( + original + == """"John, under GOLDEN bridges, roams, +SAN Francisco's hills, his HOME. +Dreams of FOG, and salty AIR, +In his HEART, he's always THERE.""" + ) + + +def test_fix_behavior_one_validator(mocker): + mocker.patch( + "openai.resources.chat.completions.Completions.create", + return_value=mock_openai_chat_completion_create(POETRY_CHUNKS), + ) + + guard = gd.Guard().use_many( + LowerCase(on_fail=OnFailAction.FIX), + ) + prompt = """Write me a 4 line poem about John in San Francisco. + Make every third word all caps.""" + gen = guard( + llm_api=openai.chat.completions.create, + prompt=prompt, + model="gpt-4", + stream=True, + ) + text = "" + original = "" + for res in gen: + original = original + res.raw_llm_output + text = text + res.validated_output + assert ( + text + == """"john, under golden bridges, roams, +san francisco's hills, his home. +dreams of fog, and salty air, +in his heart, he's always there.""" + ) + assert ( + original + == """"John, under GOLDEN bridges, roams, +SAN Francisco's hills, his HOME. +Dreams of FOG, and salty AIR, +In his HEART, he's always THERE.""" + ) + + +def test_fix_behavior_two_validators(mocker): + mocker.patch( + "openai.resources.chat.completions.Completions.create", + return_value=mock_openai_chat_completion_create(POETRY_CHUNKS), + ) + + guard = gd.Guard().use_many( + MockDetectPII( + on_fail=OnFailAction.FIX, + pii_entities="pii", + replace_map={"John": "", "SAN Francisco's": ""}, + ), + LowerCase(on_fail=OnFailAction.FIX), + ) + prompt = """Write me a 4 line poem about John in San Francisco. + Make every third word all caps.""" + gen = guard( + llm_api=openai.chat.completions.create, + prompt=prompt, + model="gpt-4", + stream=True, + ) + text = "" + original = "" + for res in gen: + original = original + res.raw_llm_output + text = text + res.validated_output + assert ( + text + == """", under golden bridges, roams, + hills, his home. +dreams of fog, and salty air, +in his heart, he's always there.""" + ) + assert ( + original + == """"John, under GOLDEN bridges, roams, +SAN Francisco's hills, his HOME. +Dreams of FOG, and salty AIR, +In his HEART, he's always THERE.""" + ) + + +def test_fix_behavior_three_validators(mocker): + mocker.patch( + "openai.resources.chat.completions.Completions.create", + return_value=mock_openai_chat_completion_create(POETRY_CHUNKS), + ) + + guard = gd.Guard().use_many( + MockDetectPII( + on_fail=OnFailAction.FIX, + pii_entities="pii", + replace_map={"John": "", "SAN Francisco's": ""}, + ), + LowerCase(on_fail=OnFailAction.FIX), + # UpperCase(on_fail=OnFailAction.FIX), + MockDetectPII( + on_fail=OnFailAction.FIX, + pii_entities="pii", + replace_map={ + "John": "REDACTED!!", + "SAN Francisco's": "REDACTED!!", + "GOLDEN": "purple!!", + }, + ), + ) + prompt = """Write me a 4 line poem about John in San Francisco. 
+    Make every third word all caps."""
+    gen = guard(
+        llm_api=openai.chat.completions.create,
+        prompt=prompt,
+        model="gpt-4",
+        stream=True,
+    )
+    text = ""
+    original = ""
+    for res in gen:
+        original = original + res.raw_llm_output
+        text = text + res.validated_output
+    assert (
+        text
+        == """"REDACTED!!, under purple!! bridges, roams,
+ hills, his home.
+dreams of fog, and salty air,
+in his heart, he's always there."""
+    )
+    assert (
+        original
+        == """"John, under GOLDEN bridges, roams,
+SAN Francisco's hills, his HOME.
+Dreams of FOG, and salty AIR,
+In his HEART, he's always THERE."""
+    )
+
+
+# This case does not work!
+# def test_fix_behavior_three_validators_overlap(mocker):
+#     mocker.patch(
+#         "openai.resources.chat.completions.Completions.create",
+#         return_value=mock_openai_chat_completion_create(POETRY_CHUNKS),
+#     )
+
+#     guard = gd.Guard().use_many(
+#         MockDetectPII(
+#             on_fail=OnFailAction.FIX,
+#             pii_entities="pii",
+#             replace_map={"John": "", "SAN Francisco's": ""},
+#         ),
+#         LowerCase(on_fail=OnFailAction.FIX),
+#         # UpperCase(on_fail=OnFailAction.FIX),
+#         MockDetectPII(
+#             on_fail=OnFailAction.FIX,
+#             pii_entities="pii",
+#             replace_map={
+#                 "John, under GOLDEN": "REDACTED!!",
+#                 "SAN Francisco's hills": "REDACTED!!",
+#                 "GOLDEN bridges": "gold!!!!",
+#             },
+#         ),
+#     )
+#     prompt = """Write me a 4 line poem about John in San Francisco.
+#     Make every third word all caps."""
+#     gen = guard(
+#         llm_api=openai.chat.completions.create,
+#         prompt=prompt,
+#         model="gpt-4",
+#         stream=True,
+#     )
+#     text = ""
+#     original = ""
+#     for res in gen:
+#         original = original + res.raw_llm_output
+#         text = text + res.validated_output
+#     assert (
+#         text
+#         == """"REDACTED!!, under gold!!!! bridges, roams,
+#  hills, his home.
+# dreams of fog, and salty air,
+# in his heart, he's always there."""
+#     )
+#     assert (
+#         original
+#         == """"John, under GOLDEN bridges, roams,
+# SAN Francisco's hills, his HOME.
+# Dreams of FOG, and salty AIR,
+# In his HEART, he's always THERE."""
+#     )
+
+
+def test_refrain_behavior(mocker):
+    mocker.patch(
+        "openai.resources.chat.completions.Completions.create",
+        return_value=mock_openai_chat_completion_create(POETRY_CHUNKS),
+    )
+
+    guard = gd.Guard().use_many(
+        MockDetectPII(
+            on_fail=OnFailAction.REFRAIN,
+            pii_entities="pii",
+            replace_map={"John": "", "SAN Francisco's": ""},
+        ),
+        LowerCase(on_fail=OnFailAction.FIX),
+    )
+
+    prompt = """Write me a 4 line poem about John in San Francisco.
+    Make every third word all caps."""
+    gen = guard(
+        llm_api=openai.chat.completions.create,
+        prompt=prompt,
+        model="gpt-4",
+        stream=True,
+    )
+    text = ""
+    original = ""
+    for res in gen:
+        original = original + res.raw_llm_output
+        text = text + res.validated_output
+    assert text == ""
+    assert (
+        original
+        == """"John, under GOLDEN bridges, roams,
+SAN Francisco's hills, his HOME.
+"""
+    )
+
+
+def test_filter_behavior(mocker):
+    mocker.patch(
+        "openai.resources.chat.completions.Completions.create",
+        return_value=mock_openai_chat_completion_create(POETRY_CHUNKS),
+    )
+
+    guard = gd.Guard().use_many(
+        MockDetectPII(
+            on_fail=OnFailAction.FIX,
+            pii_entities="pii",
+            replace_map={"John": "", "SAN Francisco's": ""},
+        ),
+        LowerCase(on_fail=OnFailAction.FILTER),
+    )
+    prompt = """Write me a 4 line poem about John in San Francisco.
+    Make every third word all caps."""
+    gen = guard(
+        llm_api=openai.chat.completions.create,
+        prompt=prompt,
+        model="gpt-4",
+        stream=True,
+    )
+    text = ""
+    original = ""
+    for res in gen:
+        original = original + res.raw_llm_output
+        text = text + res.validated_output
+    assert text == ""
+    assert (
+        original
+        == """"John, under GOLDEN bridges, roams,
+SAN Francisco's hills, his HOME.
+"""
+    )
diff --git a/tests/unit_tests/test_merge.py b/tests/unit_tests/test_merge.py
new file mode 100644
index 000000000..517a8efd1
--- /dev/null
+++ b/tests/unit_tests/test_merge.py
@@ -0,0 +1,56 @@
+import pytest
+from guardrails.validator_service import SequentialValidatorService
+
+
+validator_service = SequentialValidatorService()
+
+
+@pytest.mark.parametrize(
+    "original, new_values, expected",
+    [
+        # test behavior on blank fixes
+        ("hello world", ["", "hello nick"], "nick"),
+        ("hello world", ["", "hello world"], ""),
+        # test behavior on non-overlapping replacements
+        (
+            """John is a shitty person who works at Anthropic on Claude,
+            and lives in San Francisco""",
+            [
+                """ is a shitty person who works at Anthropic on ,
+            and lives in """,
+                """John is a ****** person who works at Anthropic on Claude,
+            and lives in San Francisco""",
+            ],
+            """ is a ****** person who works at Anthropic on ,
+            and lives in """,
+        ),
+        # test behavior with lowercase
+        (
+            """JOE is FUNNY and LIVES in NEW york""",
+            [
+                """ is FUNNY and lives in """,
+                """joe is funny and lives in new york""",
+            ],
+            """ is funny and lives in """,
+        ),
+        # broken!
+        # (
+        #     """JOHN lives IN SAN francisco""",
+        #     [
+        #         """ lives in """,
+        #         """john lives in san francisco""",
+        #     ],
+        #     """ lives in """,
+        # )
+        # (broken) test behavior with a word close to PERSON
+        # ("""Perry is FUNNY and LIVES in NEW york""",
+        # [""" is FUNNY and lives in """,
+        # """perry is funny and lives in new york"""],
+        # """ is funny and lives in """),
+    ],
+)
+def test_merge(original, new_values, expected):
+    res = validator_service.multi_merge(original, new_values)
+    assert res == expected
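For reference, a minimal sketch of the three-way merge these cases pin down, mirroring the lowercase parametrized case above. `merge()` composes two independently fixed strings against their shared base, erring on the side of deletion when the fixes conflict; `multi_merge` reduces the list right-to-left, so the two-value case collapses to a single `merge(source, target, base)` call:

    from guardrails.merge import merge

    base = "JOE is FUNNY and LIVES in NEW york"
    lowercased = "joe is funny and lives in new york"  # LowerCase fix
    redacted = " is FUNNY and lives in "               # PII-redaction fix

    # Both fixes survive: the PII deletions and the lowercasing compose.
    assert merge(lowercased, redacted, base) == " is funny and lives in "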