In [None]:
!pip install guardrails-ai peft transformers jinja2 pytest flash_attn

In [35]:
from typing import Dict, Optional, Union

import torch
from guardrails.validator_base import (
    FailResult,
    PassResult,
    ValidationResult,
    Validator,
    register_validator,
)
from jinja2 import Template
from peft import PeftConfig, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

# Jinja2 prompt template for the hallucination judge.
# It is rendered with three variables: `reference`, `query` and `response`.
# The `{% set %}` line builds an optional "[Knowledge]: ..." prefix: it is empty
# when `reference == ""`, otherwise the reference text is concatenated in, so
# `reference` must be a string (None would fail the `+` concatenation).
# The whitespace inside this literal is part of the prompt sent to the model.
HALLUCINATION_EVAL_BASE = """
       {% set knowledge_line = "" if reference == "" else "[Knowledge]: " + reference + "\n        " %}
        Your job is to evaluate whether a machine learning model has hallucinated or not.
        A hallucination occurs when the response is coherent but factually incorrect or nonsensical
        outputs that are not grounded in the provided context.
        You are given the following information:
        ####INFO####
        {{ knowledge_line }}[User Input]: {{ query }}
        [Model Response]: {{ response }}
        ####END INFO####
        Based on the information provided is the model output a hallucination? Respond with only "yes" or "no"
"""


@register_validator(name="guardrails/grounded-ai-hallucination", data_type="string")
class GroundedaiHallucination(Validator):
    """Validates whether a given response is a hallucination based on the provided query, response
    and optional reference.

    This validator uses a fine-tuned language model by GroundedAI to determine if the response
    is grounded in the given context or if it's a hallucination.

    **Key Properties**

    | Property                      | Description                           |
    | ----------------------------- | ------------------------------------- |
    | Name for `format` attribute   | `guardrails/grounded-ai-hallucination`|
    | Supported data types          | `string`                              |
    | Programmatic fix              | `None`                                |

    The value passed to `validate` may be either:

    * a dict with the keys ``"query"``, ``"response"`` and (optionally)
      ``"reference"``; or
    * the response string itself, in which case ``"query"`` (required) and
      ``"reference"`` (optional) are read from the ``metadata`` dict.

    Args:
        quant (bool): Whether to use quantization for the model.
        base_prompt (Optional[str]): The base prompt template for hallucination evaluation.
            Defaults to HALLUCINATION_EVAL_BASE (also used when ``None`` is passed).
    """  # noqa

    BASE_MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"
    GROUNDEDAI_EVAL_ID = "grounded-ai/phi3.5-hallucination-judge"

    def __init__(
        self,
        quant: bool,
        base_prompt: Optional[str] = HALLUCINATION_EVAL_BASE,
    ):
        super().__init__()
        self._quantize = quant
        # The parameter is Optional, so fall back to the default template
        # instead of handing None to jinja2 later on.
        self._base_prompt = (
            HALLUCINATION_EVAL_BASE if base_prompt is None else base_prompt
        )
        self._base_model = None
        self._tokenizer = None
        self._merged_model = None
        # Built lazily on the first run_model() call and then reused.
        self._pipeline = None

        # Loading happens eagerly so the first validation does not pay for it.
        self.warmup()

    def load_model(self):
        """Loads the base model and tokenizer, with or without quantization.

        The results are stored on ``self._base_model`` and ``self._tokenizer``.
        """
        # Check availability first so the bf16 probe is never run without a
        # CUDA device.
        cuda_available = torch.cuda.is_available()
        bf16_supported = cuda_available and torch.cuda.is_bf16_supported()

        if bf16_supported:
            compute_dtype = torch.bfloat16
            attn_implementation = "flash_attention_2"
        else:
            # NOTE(review): float32 is used on CPU on the assumption that
            # half-precision kernels are not reliably available there.
            compute_dtype = torch.float16 if cuda_available else torch.float32
            attn_implementation = "sdpa"

        tokenizer = AutoTokenizer.from_pretrained(self.BASE_MODEL_ID)
        model_kwargs = {
            "attn_implementation": attn_implementation,
            "torch_dtype": compute_dtype,
        }
        if self._quantize:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        base_model = AutoModelForCausalLM.from_pretrained(
            self.BASE_MODEL_ID, **model_kwargs
        )

        self._base_model = base_model
        self._tokenizer = tokenizer

    def merge_adapter(self, groundedai_eval_id: str):
        """Merges the PEFT adapter into the base model.

        Args:
            groundedai_eval_id: Identifier of the PEFT adapter to load and merge.
        """
        # TODO Error handling for adapter merging could be added here
        config = PeftConfig.from_pretrained(groundedai_eval_id)
        model_peft = PeftModel.from_pretrained(
            self._base_model, groundedai_eval_id, config=config
        )
        self._merged_model = model_peft.merge_and_unload()
        # Quantized models are not moved explicitly; the unquantized model is
        # moved to the GPU only when one is actually present.
        if not self._quantize and torch.cuda.is_available():
            self._merged_model.to("cuda")
        # Any cached pipeline wraps the previous model, so drop it.
        self._pipeline = None

    def warmup(self):
        """Warmup the model by loading it and merging the adapter"""
        self.load_model()
        self.merge_adapter(self.GROUNDEDAI_EVAL_ID)

    def format_input(
        self, query: str, response: str, reference: Optional[str] = None
    ) -> str:
        """Renders the prompt template for one query/response pair.

        Args:
            query: The user input shown to the judged model.
            response: The model response to evaluate.
            reference: Optional grounding text. ``None`` is treated like ``""``
                because the default template concatenates it as a string.

        Returns:
            The rendered prompt.
        """
        if reference is None:
            reference = ""
        template = Template(self._base_prompt)
        return template.render(reference=reference, query=query, response=response)

    def _get_pipeline(self):
        """Returns the text-generation pipeline, building it on first use."""
        pipe = getattr(self, "_pipeline", None)
        if pipe is None:
            pipe = pipeline(
                "text-generation",
                model=self._merged_model,
                tokenizer=self._tokenizer,
            )
            self._pipeline = pipe
        return pipe

    def run_model(self, query: str, response: str, reference: str = "") -> str:
        """Asks the judge model whether ``response`` is a hallucination.

        Args:
            query: The user input shown to the judged model.
            response: The model response to evaluate.
            reference: Optional grounding text.

        Returns:
            The generated text, stripped and lower-cased.
        """
        prompt = self.format_input(query, response, reference)
        messages = [{"role": "user", "content": prompt}]

        generation_args = {
            "max_new_tokens": 2,
            "return_full_text": False,
            "temperature": 0.01,
            "do_sample": True,
        }

        output = self._get_pipeline()(messages, **generation_args)
        torch.cuda.empty_cache()
        return output[0]["generated_text"].strip().lower()

    def validate(
        self, value: Union[str, Dict], metadata: Optional[Dict] = None
    ) -> ValidationResult:
        """Validates that the response is not classified as a hallucination.

        Args:
            value: Either a dict with ``"query"``, ``"response"`` and optional
                ``"reference"`` keys, or the response string itself.
            metadata: Used only when ``value`` is a string; must then contain
                ``"query"`` and may contain ``"reference"``.

        Returns:
            ``FailResult`` when the judge output contains "yes", otherwise
            ``PassResult``.

        Raises:
            KeyError: If a dict ``value`` lacks ``"query"`` or ``"response"``.
            ValueError: If ``value`` is a string and ``metadata`` has no ``"query"``.
        """
        metadata = metadata or {}

        if isinstance(value, dict):
            query = value["query"]
            response = value["response"]
            # The reference is optional, so a missing key means "no reference".
            reference = value.get("reference", "")
        else:
            response = value
            query = metadata.get("query")
            if query is None:
                raise ValueError(
                    "metadata must contain a 'query' entry when the value "
                    "being validated is the response string"
                )
            reference = metadata.get("reference", "")

        hallucination = self.run_model(query, response, reference)
        if "yes" in hallucination:
            return FailResult(
                error_message="{The provided input was classified as a hallucination}",
            )
        return PassResult()


In [30]:
import pytest
from guardrails import Guard

# Shared Guard used by the cells below; building the validator loads the model.
guard = Guard().use(GroundedaiHallucination(quant=False))

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [31]:
# Manual smoke check with a grounded answer.
# NOTE(review): the output recorded for this cell is a pydantic `string_type`
# error for `llm_output`, i.e. the Guard rejected the dict input - confirm the
# intended input shape.
guard.validate(
    {
        "query": "What is the capital of France?",
        "response": "The capital of France is Paris.",
        "reference": "Paris is the capital and most populous city of France.",
    }
)

ValidationError: 1 validation error for Inputs
llm_output
  Input should be a valid string [type=string_type, input_value={'query': 'What is the ca...pulous city of France.'}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.9/v/string_type

In [28]:
def test_hallucination_pass():
    """A response supported by the reference should pass and be returned unchanged."""
    payload = {
        "query": "What is the capital of France?",
        "response": "The capital of France is Paris.",
        "reference": "Paris is the capital and most populous city of France.",
    }

    outcome = guard.validate(payload)

    assert outcome.validation_passed is True
    assert outcome.validated_output == payload

def test_hallucination_fail():
    """A response contradicting the reference should raise with the validator's message."""
    payload = {
        "query": "What is the capital of France?",
        "response": "The capital of France is London.",
        "reference": "Paris is the capital and most populous city of France.",
    }
    expected_message = "Validation failed for field with errors: {The provided input was classified as a hallucination}"

    with pytest.raises(Exception) as raised:
        guard.parse(payload)

    # The raised exception must carry the validator's error_message.
    assert str(raised.value) == expected_message

def test_no_reference():
    """A correct response should pass even when the reference is an empty string."""
    payload = {
        "query": "Who was the first person to walk on the moon?",
        "response": "Neil Armstrong was the first person to walk on the moon.",
        "reference": "",
    }

    outcome = guard.parse(payload)

    assert outcome.validation_passed is True
    assert outcome.validated_output == payload

def test_complex_query():
    """A longer explanatory response consistent with its reference should pass."""
    payload = {
        "query": "Explain the process of photosynthesis.",
        "response": "Photosynthesis is the process by which plants use sunlight, water, and carbon dioxide to produce oxygen and energy in the form of sugar.",
        "reference": "Photosynthesis is a process used by plants and other organisms to convert light energy into chemical energy that can later be released to fuel the organism's activities.",
    }

    outcome = guard.parse(payload)

    assert outcome.validation_passed is True
    assert outcome.validated_output == payload