In [2]:
from typing import List
from haystack.components.builders import PromptBuilder
from pydantic import BaseModel

In [3]:
class City(BaseModel):
    country: str
    name: str
    population: int
    
class CityData(BaseModel):
    cities: List[City]    

In [5]:
json_schema = CityData.schema_json(indent=2)

In [6]:
import json
import pydantic
from pydantic import ValidationError
from typing import Optional,List
from colorama import Fore, Style
from haystack import component

In [12]:

# Define the component input parameters
@component
class OutputValidator:
    def __init__(self, pydantic_model: pydantic.BaseModel):
        self.pydantic_model = pydantic_model
        self.iteration_counter = 0

    # Define the component output
    @component.output_types(valid_replies=List[str], invalid_replies=Optional[List[str]], error_message=Optional[str])
    def run(self, replies: List[str]):

        self.iteration_counter += 1

        ## Try to parse the LLM's reply ##
        # If the LLM's reply is a valid object, return `"valid_replies"`
        try:
            output_dict = json.loads(replies[0])
            self.pydantic_model.parse_obj(output_dict)
            print(
                Fore.GREEN
                + f"OutputValidator at Iteration {self.iteration_counter}: Valid JSON from LLM - No need for looping: {replies[0]}"
            )
            return {"valid_replies": replies}

        # If the LLM's reply is corrupted or not valid, return "invalid_replies" and the "error_message" for LLM to try again
        except (ValueError, ValidationError) as e:
            print(
                Fore.RED
                + f"OutputValidator at Iteration {self.iteration_counter}: Invalid JSON from LLM - Let's try again.\n"
                f"Output from LLM:\n {replies[0]} \n"
                f"Error from OutputValidator: {e}"
            )
            return {"invalid_replies": replies, "error_message": str(e)}

In [13]:
output_validator = OutputValidator(pydantic_model=CityData)

In [14]:
prompt_template = """
Create a JSON object from the information present in this passage: {{passage}}.
Only use information that is present in the passage. Follow this JSON schema, but only return the actual instances without any additional schema definition:
{{schema}}
Make sure your response is a dict and not a list.
{% if invalid_replies and error_message %}
  You already created the following output in a previous attempt: {{invalid_replies}}
  However, this doesn't comply with the format requirements from above and triggered this Python exception: {{error_message}}
  Correct the output and try again. Just return the corrected output without any extra explanations.
{% endif %}
"""

In [20]:
prompt_builder = PromptBuilder(template=prompt_template)

In [16]:
import os
from getpass import getpass
from haystack.components.generators import OpenAIGenerator


In [18]:
os.environ["OPENAI_API_KEY"] = "sk-v0NS"
generator = OpenAIGenerator()

In [19]:
from haystack import Pipeline
pipeline = Pipeline(max_loops_allowed=5)

In [21]:
pipeline.add_component(instance=prompt_builder, name="prompt_builder")
pipeline.add_component(instance=generator, name="generatorLLM")
pipeline.add_component(instance=output_validator, name="output_validator")


In [22]:
pipeline.connect("prompt_builder", "generatorLLM")
pipeline.connect("generatorLLM", "output_validator")

<haystack.core.pipeline.pipeline.Pipeline object at 0x00000201C1C238F0>
🚅 Components
  - prompt_builder: PromptBuilder
  - generatorLLM: OpenAIGenerator
  - output_validator: OutputValidator
🛤️ Connections
  - prompt_builder.prompt -> generatorLLM.prompt (str)
  - generatorLLM.replies -> output_validator.replies (List[str])

In [23]:
passage = "Berlin is the capital of Germany. It has a population of 3,850,809. Paris, France's capital, has 2.161 million residents. Lisbon is the capital and the largest city of Portugal with the population of 504,718."


In [24]:
result = pipeline.run({"prompt_builder": {"passage": passage, "schema": json_schema}})

[32mOutputValidator at Iteration 1: Valid JSON from LLM - No need for looping: {
  "cities": [
    {
      "country": "Germany",
      "name": "Berlin",
      "population": 3850809
    },
    {
      "country": "France",
      "name": "Paris",
      "population": 2161000
    },
    {
      "country": "Portugal",
      "name": "Lisbon",
      "population": 504718
    }
  ]
}
