https://haystack.deepset.ai/tutorials/28_structured_output_with_loop

Loop-based autocorrection

This tutorial uses gpt-4o-mini to turn unstructured passages into JSON output that follows a Pydantic schema. A custom OutputValidator component validates the JSON and, if necessary, loops back to the LLM so it can correct its output.

In [25]:
import logging
# Attach a stderr handler to the root logger so records emitted by the
# pipeline loggers below are actually printed.
logging.basicConfig()
# Haystack 2.x modules log under their module paths (Pipeline lives under
# haystack.core.pipeline). The previous name, 'canls.pipeline.pipeline', matched
# no Haystack logger, so the DEBUG level never took effect. Setting the level on
# the package logger also applies to its submodules via the logger hierarchy.
logging.getLogger('haystack.core.pipeline').setLevel(logging.DEBUG)

from typing import List,Optional
from pydantic import BaseModel

#Haystack
from haystack import component
from haystack.dataclasses import ChatMessage
import json

from pprint import pprint


In [15]:
import os
# Load credentials from a local "secrets" file (one KEY=VALUE per line) into
# os.environ, then fail fast if either required key was not found.
openai_key = False
hf_token = False
with open("secrets", encoding="utf-8") as file:
    for line_no, raw_line in enumerate(file, start=1):
        line = raw_line.strip()
        # Tolerate blank lines and "#" comments instead of crashing on unpacking.
        if not line or line.startswith("#"):
            continue
        # partition() splits on the FIRST "=" only, so values that themselves
        # contain "=" (e.g. base64 padding) are kept whole.
        key, sep, value = line.partition("=")
        if not sep:
            # Report the line number only -- never echo a possible secret.
            raise ValueError(f'malformed line {line_no} in secrets file (expected KEY=VALUE)')
        key, value = key.strip(), value.strip()
        if key == 'OPENAI_API_KEY':
            openai_key = True
            os.environ[key] = value
        elif key == 'HF_TOKEN':
            hf_token = True
            os.environ[key] = value
# Explicit raises instead of `assert`, which is stripped under `python -O`.
if not openai_key:
    raise RuntimeError('OPENAI_API_KEY not found')
if not hf_token:
    raise RuntimeError('HF_TOKEN not found')


In [5]:
# Pydantic schema the LLM output must follow. Models are documented with
# comments rather than class docstrings because Pydantic copies a model's
# docstring into its JSON schema as "description", which would change the
# schema text sent in the prompt.

# One city extracted from the passage.
class City(BaseModel):
    name: str
    country: str
    population: int


# Top-level object the LLM must return: {"cities": [City, ...]}.
class CitiesData(BaseModel):
    cities: List[City]


# Create a JSON schema string for the prompt. `schema_json()` is deprecated in
# Pydantic v2 (this notebook already uses the v2 API `model_validate`);
# json.dumps(model_json_schema()) is its replacement and yields the same text
# for this schema without a deprecation warning.
json_schema = json.dumps(CitiesData.model_json_schema(), indent=2)

In [42]:
# Validate the output generated by LLM
@component
class OutputValidator:
    """Check that the first LLM reply is JSON matching a Pydantic model.

    On success the replies are emitted on ``valid_replies``. On failure the
    replies are emitted on ``invalid_replies`` and the exception text on
    ``error_msg``, so the pipeline can route them back to the prompt builder
    for a corrected attempt.
    """

    def __init__(self, pydantic_model: "type[BaseModel]"):
        # The model *class* (e.g. CitiesData), not an instance: run() calls
        # its ``model_validate`` classmethod.
        self.pydantic_model = pydantic_model
        # Count of run() calls on this instance, used only in the log lines.
        # It is never reset, so it keeps growing across pipeline.run() calls.
        self.iter_count = 0

    @component.output_types(
        valid_replies=List[ChatMessage],
        invalid_replies=Optional[List[ChatMessage]],
        error_msg=Optional[str],
    )
    def run(self, replies: List[ChatMessage]):
        """Validate ``replies[0].text`` against ``self.pydantic_model``.

        :param replies: Chat messages produced by the generator; only the
            first one is checked.
        :returns: ``{'valid_replies': replies}`` when the text parses and
            validates, otherwise ``{'invalid_replies': replies,
            'error_msg': <exception text>}``.
        """
        self.iter_count += 1
        text = replies[0].text

        try:
            output_dict = json.loads(text)
            self.pydantic_model.model_validate(output_dict)
        # ValueError covers json.JSONDecodeError, and Pydantic v2's
        # ValidationError also subclasses ValueError. TypeError covers a reply
        # with no text (json.loads(None)), which is also an invalid reply.
        except (ValueError, TypeError) as e:
            print(f'[NOT OK] error from validator {e}\niteration={self.iter_count}\nLLM output:{text}')
            return {'invalid_replies': replies, 'error_msg': str(e)}

        print(f'[OK], iter={self.iter_count} valid json')
        return {'valid_replies': replies}
        

In [43]:
# Making the prompt

from haystack.components.builders import ChatPromptBuilder


# Single user message rendered by ChatPromptBuilder. On the first pass only
# `passage` and `schema` are supplied. On a retry, OutputValidator's
# invalid_replies / error_msg outputs arrive as `invalid_replies` /
# `error_message` and enable the correction section. The previous reply is
# rendered via `.text` so the LLM sees its own output rather than the Python
# repr of a list of ChatMessage objects.
prompt_template = [
    ChatMessage.from_user(
        """
Create a JSON object from the information present in this passage: {{passage}}.
Only use information that is present in the passage. Follow this JSON schema, but only return the actual instances without any additional schema definition:
{{schema}}
Make sure your response is a dict and not a list.
{% if invalid_replies and error_message %}
  You already created the following output in a previous attempt: {{invalid_replies[0].text}}
  However, this doesn't comply with the format requirements from above and triggered this Python exception: {{error_message}}
  Correct the output and try again. Just return the corrected output without any extra explanations.
{% endif %}
"""
    )
]

In [44]:
# init chat generator
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack import Pipeline

In [45]:
# Self-correcting loop:
#   prompt_builder -> llm -> output_validator
#   output_validator --(invalid_replies, error_msg)--> prompt_builder
# max_runs_per_component caps how often each component may run, which bounds
# the number of correction attempts.
pipeline = Pipeline(max_runs_per_component=5)

pipeline.add_component(name="prompt_builder", instance=ChatPromptBuilder(template=prompt_template))
# Pin the model named in the tutorial text instead of relying on the
# generator's built-in default, which may differ between Haystack versions.
pipeline.add_component(name="llm", instance=OpenAIChatGenerator(model="gpt-4o-mini"))
pipeline.add_component(name="output_validator", instance=OutputValidator(pydantic_model=CitiesData))

# Forward path.
pipeline.connect('prompt_builder.prompt', 'llm.messages')
pipeline.connect('llm.replies', 'output_validator.replies')
# Feedback path, taken only when validation fails.
pipeline.connect('output_validator.invalid_replies', "prompt_builder.invalid_replies")
pipeline.connect('output_validator.error_msg', "prompt_builder.error_message")

# Save a diagram of the pipeline graph. NOTE(review): rendering presumably uses
# a remote Mermaid service and needs network access -- confirm for offline use.
pipeline.draw('sample.png')


In [46]:
# Hand-written example of the structure the pipeline should produce for the
# example passage used later, written directly as Python objects.
test = {
    "cities": [
        {"name": "Berlin", "country": "Germany", "population": 3850809},
        {"name": "Paris", "country": "France", "population": 2161000},
        {"name": "Lisbon", "country": "Portugal", "population": 504718},
    ]
}

pprint(test)

{'cities': [{'country': 'Germany', 'name': 'Berlin', 'population': 3850809},
            {'country': 'France', 'name': 'Paris', 'population': 2161000},
            {'country': 'Portugal', 'name': 'Lisbon', 'population': 504718}]}


In [47]:
# Unstructured example text to extract city records from.
passage = (
    "Berlin is the capital of Germany. It has a population of 3,850,809. "
    "Paris, France's capital, has 2.161 million residents. "
    "Lisbon is the capital and the largest city of Portugal with the population of 504,718."
)
# Template variables for the prompt builder on the first pass.
prompt_builder_inputs = {"passage": passage, "schema": json_schema}
result = pipeline.run({"prompt_builder": prompt_builder_inputs})


[OK], iter=1 valid json




In [48]:
# Show the raw text of the first reply that passed validation.
first_valid_reply = result["output_validator"]["valid_replies"][0]
pprint(first_valid_reply.text)

('{\n'
 '  "cities": [\n'
 '    {\n'
 '      "name": "Berlin",\n'
 '      "country": "Germany",\n'
 '      "population": 3850809\n'
 '    },\n'
 '    {\n'
 '      "name": "Paris",\n'
 '      "country": "France",\n'
 '      "population": 2161000\n'
 '    },\n'
 '    {\n'
 '      "name": "Lisbon",\n'
 '      "country": "Portugal",\n'
 '      "population": 504718\n'
 '    }\n'
 '  ]\n'
 '}')


In [49]:
# Decode the validated JSON text back into Python objects and print them.
validated_json_text = result["output_validator"]["valid_replies"][0].text
print(json.loads(validated_json_text))

{'cities': [{'name': 'Berlin', 'country': 'Germany', 'population': 3850809}, {'name': 'Paris', 'country': 'France', 'population': 2161000}, {'name': 'Lisbon', 'country': 'Portugal', 'population': 504718}]}
