
# **Extracting Structured Data from Unstructured Text with LLMs**

This notebook demonstrates how to extract data in JSON, YAML, or other structured formats from long text documents, such as transcripts, PDFs, HTML, etc. Not only that, but I demonstrate how to perform data extraction efficiently and scalably, with support for concurrent processing and schema-based validation to ensure accuracy and consistency of the extracted data.  

## Motivation

Extracting specific information from a large body of text is hard. Suppose, for example, that you want a list of all the names mentioned in a long transcript. A hand-written Python script won't get you far: you could try a simple keyword search for common names or invent some heuristic, but these methods are unlikely to yield consistent results. You could also create [embeddings](/handbook/embeddings) for the transcript and the list of names and then run a vector similarity search, but again, this is difficult to do well, especially without missing any names.

Large language models (LLMs), however, are well-suited for this task since they possess a robust world model, whereby we can ask for just "names" without having to be too specific. They are also flexible in allowing us to ask for the results in a variety of formats, like JSON or YAML. From there, we can use these structured outputs much more easily in other programs.  

## Example output

Given the transcript for the Berkshire Hathaway annual shareholders meeting, we'll use an LLM and the code in this notebook to extract the names and organizations mentioned, as follows:  

```
{
    "names": [
        "Judd Zaberski",
        "Sue Decker",
        "Charlie Munger",
        "Jane Frazier",
        "Randy Jeffs",
        "Ken Chenault",
        "Wally Weiss",
        "Ron Olson",
        "Warren Buffett",
        ...
    ],
    "organisations": [
        "Apple",
        "BNSF",
        "Berkshire",
        "Berkshire Hathaway",
        "Gap",
        "Travelers",
        "Kelly Toys",
        "Allstate",
        ...
    ]
}
```

## Why should you read this notebook?

You want to:
- Extract data from long bodies of text  
- Get the results in a useful format, such as JSON or YAML, that you can then feed into another process or application.
- Perform data extraction efficiently and scalably, with support for concurrent processing and schema-based validation to ensure accuracy and consistency of the extracted data.  

## Source Code

The Python scripts used in this notebook are available in the [`ai-cookbook`](https://github.com/gadkins/ai-cookbook/tree/main/data-processing/data-extraction) repo on my GitHub.

**Attribution:** Much of the code here is based on the work of [Trelis Research](https://www.youtube.com/watch?v=zmf1Kujygt8), with modifications for my needs. I've also rewritten parts so that it can be run interactively in a Jupyter notebook.

## Prerequisites

- Access to an LLM, such as an OpenAI model (e.g. gpt-3.5-turbo) or an open-source model like OpenChat 3.5.

# Get Started

The code below extracts structured data (JSON or YAML) from unstructured text. It works as follows:
- Configure the extraction process (chunk length, output format, and file names). The original scripts take these as command-line arguments; in this notebook they are set on a `Config` object.
- Read the input text file and split it into manageable chunks.
- For each chunk, generate a prompt based on a predefined schema.
- Send these prompts to a model via an API (in this example, OpenChat 3.5 hosted on Runpod.io), concurrently if batching is enabled.
- Validate the responses against the schema, aggregate them, and write the result to a final output file in the desired format (JSON or YAML).
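To make the flow above concrete before the individual cells, here is a compressed, dependency-free sketch of the same five steps. It is not the notebook's code: `stub_model` is a hypothetical stand-in for the LLM call, and a simple dict/list check stands in for schema validation.

```python
import json
from typing import Callable, Dict, List, Set, Tuple

# Hypothetical stand-in for the LLM: "finds" entities from a fixed list.
KNOWN_ENTITIES = {
    "names": ["Warren Buffett", "Charlie Munger"],
    "organisations": ["Berkshire Hathaway", "BNSF"],
}


def stub_model(prompt: str) -> str:
    found = {key: [v for v in values if v in prompt] for key, values in KNOWN_ENTITIES.items()}
    return json.dumps(found)


def run_pipeline(
    text: str, chunk_length: int, call_model: Callable[[str], str]
) -> Tuple[Dict[str, List[str]], int]:
    aggregated: Dict[str, Set[str]] = {"names": set(), "organisations": set()}
    failures = 0
    # 1. Split the text into fixed-size character chunks
    chunks = [text[i : i + chunk_length] for i in range(0, len(text), chunk_length)]
    for chunk in chunks:
        # 2. Wrap each chunk in an extraction prompt
        prompt = f"Extract names and organisations as JSON.\n\n[TEXT_START]\n{chunk}\n[TEXT_END]"
        # 3. Send it to the model
        reply = call_model(prompt)
        # 4. Parse and check the reply; anything unusable counts as a failure
        try:
            data = json.loads(reply)
        except json.JSONDecodeError:
            failures += 1
            continue
        if not isinstance(data, dict):
            failures += 1
            continue
        # 5. Merge list values into sets so repeated entities are kept once
        for key in aggregated:
            values = data.get(key, [])
            if isinstance(values, list):
                aggregated[key].update(values)
    return {key: sorted(values) for key, values in aggregated.items()}, failures
```

Sorting the sets at the end only makes the output deterministic; the de-duplication itself comes from the set union.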


## Install dependencies

In [None]:
# Install required packages
!pip install -q -U transformers tqdm jsonschema pyyaml termcolor python-dotenv tenacity

## If using Google Drive to store input/output files

In [5]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Import libraries

In [8]:
import yaml
import json
import os
import argparse
from termcolor import colored
import subprocess
import time
from tenacity import retry, wait_random_exponential, stop_after_attempt
from transformers import AutoTokenizer
from tqdm import tqdm
import concurrent.futures
import jsonschema
from jsonschema import validate


## If using a private model on Hugging Face

In [None]:
## Authenticate to Hugging Face to pull and push models
# !pip install huggingface_hub -q
# from huggingface_hub import notebook_login

# notebook_login()

## Configuration

In [40]:
# Here I'm using a self-hosted model, but you could swap this for gpt-3.5-turbo, etc.
# I've made available a one-click deployment template for Runpod here:  
# https://runpod.io/console/gpu-cloud?template=t6sgcn049x&ref=n2u8jwou
model = "openchat/openchat_3.5" # for extraction
api_endpoint = "https://xd3lef1do5g8d0-8080.proxy.runpod.net" # where it's being served


# Towards the end of this notebook, we'll instantiate this class and use it to
# perform extraction on a text file
class Config:
    def __init__(self, chunk_length=8000, output_format="json", output_file_name="output",
                 batching=True, input_file_name="input.txt"):
        self.chunk_length = chunk_length
        self.output_format = output_format
        self.output_file_name = output_file_name
        self.batching = batching
        self.input_file_name = input_file_name
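The original scripts set these same options from the command line (the notebook still imports `argparse`). A rough sketch of what that front end could look like follows; the flag names and the `parse_config_args` helper are my assumptions, not necessarily what the repo uses.

```python
import argparse
from typing import List, Optional


def parse_config_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse extraction settings whose attribute names mirror Config's keyword arguments."""
    parser = argparse.ArgumentParser(description="Extract structured data from a text file with an LLM.")
    parser.add_argument("input_file_name", help="Path to the input text file")
    parser.add_argument("--chunk-length", type=int, default=8000, help="Characters per chunk")
    parser.add_argument("--output-format", choices=["json", "yaml"], default="json")
    parser.add_argument("--output-file-name", default="output")
    # Batching is on by default; this flag turns it off
    parser.add_argument("--no-batching", dest="batching", action="store_false",
                        help="Send requests one at a time")
    return parser.parse_args(argv)
```

Because `argparse` turns `--chunk-length` into `chunk_length` (and so on), `Config(**vars(parse_config_args()))` would build the same object the notebook creates by hand.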


## Utility Functions

In [29]:
# utils.py
from termcolor import colored
import os


def pretty_print_conversation(messages):
    role_to_color = {
        "system": "red",
        "user": "green",
        "assistant": "blue",
        "tool": "magenta",
    }

    for message in messages:
        if message["role"] == "system":
            print(
                colored(
                    f"system: {message['content']}\n", role_to_color[message["role"]]
                )
            )
        elif message["role"] == "user":
            print(
                colored(f"user: {message['content']}\n", role_to_color[message["role"]])
            )
            with open("user_request.txt", "w") as file:
                file.write(message["content"] + "\n")
        elif message["role"] == "assistant" and message.get("function_call"):
            print(
                colored(
                    f"assistant: {message['function_call']}\n",
                    role_to_color[message["role"]],
                )
            )
        elif message["role"] == "assistant" and not message.get("function_call"):
            print(
                colored(
                    f"assistant: {message['content']}\n", role_to_color[message["role"]]
                )
            )
        elif message["role"] == "tool":
            print(
                colored(
                    f"function ({message['name']}): {message['content']}\n",
                    role_to_color[message["role"]],
                )
            )


def read_text_file(text_file):
    with open(text_file, "r") as file:
        text = file.read()
    return text


def check_output_file_format(output_file_name, output_format):
    # Check if output_file has an extension
    _, file_extension = os.path.splitext(output_file_name)
    if not file_extension:
        # If not, add extension based on output_format
        output_file_name = f"{output_file_name}.{output_format}"

    return output_file_name


## Prompt

This code generates prompts that guide the extraction of structured data (like names and organizations) from unstructured text. The schema that defines the structure of the extracted data can be written in either JSON or YAML.



In [30]:
# prompts.py
import json
import yaml

def read_schema(file_path):
    """Reads a JSON or YAML schema from a given file path."""
    # Pick a loader from the extension; unsupported extensions raise ValueError as intended
    if file_path.endswith(".json"):
        loader = json.load
    elif file_path.endswith((".yaml", ".yml")):
        loader = yaml.safe_load
    else:
        raise ValueError("Unsupported file format. Please use '.json' or '.yaml/.yml'.")

    # A missing file raises FileNotFoundError as-is; parse errors are reported with the file name
    with open(file_path, "r") as f:
        try:
            return loader(f)
        except (json.JSONDecodeError, yaml.YAMLError) as e:
            raise ValueError(f"Could not parse schema file '{file_path}': {e}") from e

def generate_example(schema):
    """Generates an example object based on the provided schema."""
    example = {}
    for key, value in schema["properties"].items():
        data_type = value.get("type", "string")
        if isinstance(data_type, list):
            data_type = data_type[0]

        example[key] = {
            "string": f"sample_string",
            "integer": 1,
            "boolean": True,
            "array": generate_array_example(value)
        }.get(data_type, "sample_value")

    return example

def generate_array_example(value):
    """Generates an example array based on the array type in schema."""
    item_type = value.get("items", {}).get("type", "string")
    if isinstance(item_type, list):
        item_type = item_type[0]

    return {
        "string": [f"sample_string_{i+1}" for i in range(2)],
        "integer": [i+1 for i in range(2)],
        "boolean": [True, False]
    }.get(item_type, ["sample_value"])

def create_extract_prompt(schema, data_format):
    """Creates an extraction prompt based on the provided schema and data format."""
    example = generate_example(schema)
    if data_format.lower() == "json":
        schema_str = json.dumps(schema, indent=4)
        example_str = json.dumps(example, indent=4)
    elif data_format.lower() == "yaml":
        schema_str = yaml.dump(schema, default_flow_style=False, indent=4)
        example_str = yaml.dump(example, default_flow_style=False, indent=4)
    else:
        raise ValueError("Unsupported data format. Please use 'JSON' or 'YAML'.")

    prompt = (
        f"Extract names and organizations from the provided text, and return them in {data_format} format. "
        f"Use the following schema:\n\n{schema_str}\n\n"
        f"Here's an example of a response in {data_format} format:\n\n{example_str}\n\n"
        f"Do not include anything that is not explicitly mentioned in the text. "
        f"Analyse the text carefully to ensure all requested data is extracted. "
        f"Include each name and organization only once. "
        f"Adhere strictly to the response format without adding extra spaces or text."
    )
    return prompt
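`create_extract_prompt` hard-codes "names and organizations" in its instruction, so swapping in a different schema changes the schema text and the example but not what the model is told to extract. One way to keep the two in sync, sketched with hypothetical helpers `describe_fields` and `build_instruction`, is to build that phrase from the schema's top-level properties.

```python
from typing import Any, Dict


def describe_fields(schema: Dict[str, Any]) -> str:
    """Turn the schema's top-level property names into a phrase like 'names and organisations'."""
    fields = list(schema.get("properties", {}).keys())  # dicts keep insertion order
    if not fields:
        raise ValueError("Schema has no top-level properties to extract.")
    if len(fields) == 1:
        return fields[0]
    return ", ".join(fields[:-1]) + " and " + fields[-1]


def build_instruction(schema: Dict[str, Any], data_format: str) -> str:
    """First sentence of the extraction prompt, derived from the schema instead of hard-coded."""
    return (
        f"Extract {describe_fields(schema)} from the provided text, "
        f"and return them in {data_format} format."
    )
```

The rest of `create_extract_prompt` (schema text, example, and rules) could stay as it is.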


## Aggregate (JSON)

This code defines a Python class named `JsonAggregator` that validates JSON data against a provided schema and aggregates it. The class is designed to combine data from multiple JSON documents (one per chunk) into a single structured output, merging only documents that conform to the predefined schema.

#### `json_schema.json`
```json
{
    "type": "object",
    "properties": {
        "names": {
            "type": "array",
            "items": {
                "type": "string"
            }
        },
        "organisations": {
            "type": "array",
            "items": {
                "type": "string"
            }
        }
    },
    "required": [
        "names",
        "organisations"
    ]
}
```

The primary method, `aggregate_json()`, validates the provided JSON data using the `validate_json` method. If the data is valid, it merges the data into `self.aggregated_data`: for each key in the input data whose value is a list, the method updates the set stored under that key, which removes duplicates across chunks. The method increments the `self.success` counter if the data is valid; otherwise, it increments the `self.fail` counter.

In [31]:
# json_validation_aggregation.py
import json
import jsonschema
from jsonschema import validate
from typing import Dict, Any

class JsonAggregator:
    def __init__(self, schema_file: str):
        self.schema = self.load_schema(schema_file)
        self.aggregated_data = {key: set() for key in self.schema["properties"].keys()}
        self.success = 0
        self.fail = 0

    def load_schema(self, schema_file: str) -> Dict[str, Any]:
        with open(schema_file, "r") as file:
            return json.load(file)

    def validate_json(self, data: Dict[str, Any]) -> bool:
        try:
            validate(instance=data, schema=self.schema)
            return True
        except jsonschema.exceptions.ValidationError:
            return False

    def aggregate_json(self, json_data: Dict[str, Any]):
        if self.validate_json(json_data):
            self.success += 1
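            for key, values in json_data.items():
                # Ignore keys the schema doesn't declare (avoids a KeyError if the model adds extras)
                if key in self.aggregated_data and isinstance(values, list):
                    self.aggregated_data[key].update(values)
        else:
            self.fail += 1

    def write_aggregated_data(self, output_file: str):
        # Sort the de-duplicated values so the output file is deterministic
        final_data = {key: sorted(value) for key, value in self.aggregated_data.items()}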
        with open(output_file, "w") as file:
            json.dump(final_data, file, indent=4)
        print(f"Aggregation complete! The aggregated data has been written to '{output_file}'.")

# Example usage
# aggregator = JsonAggregator("your_schema_file.json")
# aggregator.aggregate_json(your_json_data)
# aggregator.write_aggregated_data("output_file.json")
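To see what this aggregation does across chunks, here is a stripped-down, dependency-free sketch of the same merge logic. `looks_valid` is a hypothetical hand-written check standing in for `jsonschema.validate` against `json_schema.json`: valid results are merged into sets (so duplicates collapse), and anything else only increments the failure count.

```python
from typing import Any, Dict, List, Set

EXPECTED_KEYS = ("names", "organisations")


def looks_valid(data: Any) -> bool:
    # Simplified stand-in for schema validation: an object whose
    # required keys each map to a list of strings.
    return isinstance(data, dict) and all(
        isinstance(data.get(key), list) and all(isinstance(v, str) for v in data[key])
        for key in EXPECTED_KEYS
    )


def merge_chunks(chunk_results: List[Any]) -> Dict[str, Any]:
    merged: Dict[str, Set[str]] = {key: set() for key in EXPECTED_KEYS}
    success = fail = 0
    for data in chunk_results:
        if looks_valid(data):
            success += 1
            for key in EXPECTED_KEYS:
                merged[key].update(data[key])  # set union drops repeats across chunks
        else:
            fail += 1
    # Sorting only makes the output deterministic; sets have no stable order
    return {"data": {k: sorted(v) for k, v in merged.items()}, "success": success, "fail": fail}
```

For example, two chunks that both mention "Warren Buffett" contribute that name once, while a chunk whose `names` value is a plain string is rejected and counted as a failure.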

## Aggregate (YAML)

This code defines a Python class named YamlAggregator that performs several operations on YAML data based on a provided schema. The class is designed to aggregate data from multiple YAML documents into a single structured format, ensuring that the aggregated data conforms to a predefined schema.

#### `yaml_schema.yaml`

```yaml
type: object
properties:
  names:
    type: array
    items:
      type: string
  organisations:
    type: array
    items:
      type: string
required: [names, organisations]

```

The primary method, `aggregate_yaml()`, validates a piece of YAML data using the `validate_yaml` method. If the data is valid, it aggregates it into the `aggregated_data` dictionary. For each key in the input data that matches a key in the schema, it either extends the stored list with the new values (if the schema expects an array) or overwrites the stored value (if the schema expects a single value). For arrays, the method also de-duplicates and sorts the list. The method increments the success or fail counter depending on whether the data was valid.

In [32]:
# yaml_validation_aggregation.py
import yaml
from jsonschema import validate
from typing import Dict, Any
import jsonschema

class YamlAggregator:
    """
    A class used to aggregate YAML data based on a provided schema.

    ...

    Attributes
    ----------
    schema : Dict[str, Any]
        a dictionary representing the YAML schema
    aggregated_data : Dict[str, Any]
        a dictionary to store the aggregated data

    Methods
    -------
    load_schema(schema_file: str)
        Loads the YAML schema from a file.
    validate_yaml(data: Dict[str, Any])
        Validates the YAML data against the schema.
    aggregate_yaml(yaml_data: Dict[str, Any])
        Aggregates the YAML data.
    write_aggregated_data(output_file: str)
        Writes the aggregated data to a file.
    """

    def __init__(self, schema_file: str):
        self.schema = self.load_schema(schema_file)
        self.aggregated_data = {
            key: [] if self.schema["properties"][key]["type"] == "array" else None
            for key in self.schema["properties"].keys()
        }
        self.success = 0
        self.fail = 0

    def load_schema(self, schema_file: str) -> Dict[str, Any]:
        """Loads the YAML schema from a file."""
        with open(schema_file, "r") as file:
            return yaml.safe_load(file)

    def validate_yaml(self, data: Dict[str, Any]) -> bool:
        """Validates the YAML data against the schema."""
        try:
            validate(instance=data, schema=self.schema)
            print("YAML validation successful!")
            return True
        except jsonschema.exceptions.ValidationError as ve:
            print(f"Invalid yaml error - {ve}")
            return False

    def aggregate_yaml(self, yaml_data: Dict[str, Any]):
        """Aggregates the YAML data."""
        # Validate the YAML data
        is_valid = self.validate_yaml(yaml_data)
        if is_valid:
            self.success += 1
            # Aggregate the data
            for key, value in yaml_data.items():
                if key in self.aggregated_data:
                    # If the key is in the aggregated data, append or update the value based on its type
                    if self.schema["properties"][key]["type"] == "array":
                        # If the value is a list, extend the existing list
                        self.aggregated_data[key].extend(value)
                        # De-duplicate and sort the list
                        self.aggregated_data[key] = sorted(set(self.aggregated_data[key]))
                    else:
                        # If the value is not a list, update the existing value
                        self.aggregated_data[key] = value
        else:
            self.fail += 1

    def write_aggregated_data(self, output_file: str):
        """Writes the aggregated data to a file."""
        with open(output_file, "w") as file:
            yaml.dump(self.aggregated_data, file)
        print(
            f"Aggregation complete! The aggregated data has been written to '{output_file}'."
        )


## Chat completion request

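Send each prompt to the model. The chat messages are rendered into a single prompt string with the model's chat template (`tokenizer.apply_chat_template`) and POSTed to the `/generate` endpoint of the Text Generation Inference (TGI) server. (In my example, I'm using OpenChat 3.5 deployed on Runpod.io.)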

In [41]:
import os
import subprocess
import json
import time
from tenacity import retry, wait_random_exponential, stop_after_attempt
from transformers import AutoTokenizer

# model = "openchat/openchat_3.5" # for extraction
# api_endpoint = "https://xd3lef1do5g8d0-8080.proxy.runpod.net" # model endpoint

tgi_api_base = api_endpoint + "/generate"

tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)

# # Manual chat template
# tokenizer.chat_template = '''{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{%- set ns = namespace(found=false) -%}{%- for message in messages -%}{%- if message['role'] == 'system' -%}{%- set ns.found = true -%}{%- endif -%}{%- endfor -%}{{bos_token}}{%- if not ns.found -%}{# Suppressed System Message #}{%- endif %}{%- for message in messages %}{%- if message['role'] != 'system' %}{%- if message['role'] == 'user' %}{{'### Instruction:\\n' + message['content'] + '\\n'}}{%- else %}{{'### Response:\\n' + message['content'] + '\\n\\n'}}{%- endif %}{%- endif %}{%- endfor %}{% if add_generation_prompt %}{{'### Response:'}}{% endif %}'''

@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))
def chat_completion_request_runpod(messages):
    # formatted_messages = format_messages(messages)

    formatted_messages = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    print(formatted_messages)

    # Properly escape the string for JSON and for shell execution
    json_payload = json.dumps(
        {
            "inputs": formatted_messages,
            "parameters": {
                "max_new_tokens": 500,
                "do_sample": False,
                # "repetition_penalty": 1.1, #can be useful for json, less so for yaml.
            },
        }
    )
    escaped_json_payload = json_payload.replace(
        "'", "'\\''"
    )  # Escape single quotes for shell

    start_time = time.time()  # Start timing

    try:
        # Execute the curl command
        curl_command = f"curl -s {tgi_api_base} -X POST -d '{escaped_json_payload}' -H 'Content-Type: application/json'"

        response = subprocess.run(
            curl_command, shell=True, check=True, stdout=subprocess.PIPE
        )
        response_time = time.time() - start_time  # Calculate response time

        response = response.stdout.decode()

        # print(response)

        response = json.loads(response).get("generated_text", "No generated text found")

        # Count generated tokens with the model's tokenizer instead of a characters-per-token guess
        tokens_generated = len(tokenizer.encode(response, add_special_tokens=False))
        tokens_per_second = tokens_generated / response_time if response_time > 0 else 0

        # Print time taken and tokens per second
        print(f"Tokens generated: {tokens_generated:.2f}")
        print(f"Total Time Taken: {response_time:.2f} seconds")
        print(f"Tokens per Second: {tokens_per_second:.2f}")
        print(response)

        return response
    except subprocess.CalledProcessError as e:
        print("Unable to generate ChatCompletion response")
        print(f"Exception: {e}")
        return str(e)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
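Passing the payload to `curl` through a shell means every prompt has to survive shell quoting, and `curl -s` exits with status 0 on HTTP error responses, so `check=True` won't flag them. A possible alternative, sketched here with only the standard library's `urllib.request`, posts the same TGI-style payload (`inputs` plus `parameters`) and reads the same `generated_text` field as the function above. The helper names and the 120-second timeout are assumptions.

```python
import json
import urllib.request
from typing import Any, Dict


def build_generate_request(api_endpoint: str, prompt: str, max_new_tokens: int = 500) -> urllib.request.Request:
    """Build a POST request for a TGI-style /generate endpoint (no shell involved)."""
    payload: Dict[str, Any] = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": max_new_tokens, "do_sample": False},
    }
    return urllib.request.Request(
        api_endpoint.rstrip("/") + "/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def generate(api_endpoint: str, prompt: str, timeout: float = 120.0) -> str:
    """Send the request and return the generated text; HTTP errors raise urllib.error.HTTPError."""
    request = build_generate_request(api_endpoint, prompt)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        body = json.loads(response.read().decode("utf-8"))
    return body.get("generated_text", "")
```

Since `urlopen` raises on error responses, decorating `generate` with the same `@retry(...)` used on `chat_completion_request_runpod` would let tenacity retry those failures.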


## Extract

In [64]:
## Instantiate a new configuration
project_dir = "/content/drive/My Drive/data_extraction" # Adjust if using Google Drive
output_dir = "/outputs"

config = Config(
    chunk_length=8000, # Customize this as needed
    output_format="json", # or "yaml"
    output_file_name="output.json", # Adjust based on your preference
    batching=True,
    input_file_name=f"{project_dir}/input_files/berkshire23_60k.txt"
)

# Define schema
json_schema_file = f"{project_dir}/json_files/json_schema.json"
yaml_schema_file = f"{project_dir}/yaml_files/yaml_schema.yaml"
json_schema = read_schema(json_schema_file) # Adjust location as necessary
yaml_schema = read_schema(yaml_schema_file) # Adjust location as necessary

# Create prompts
json_extract_prompt = create_extract_prompt(json_schema, "JSON")
yaml_extract_prompt = create_extract_prompt(yaml_schema, "YAML")

In [43]:
# Read input file
text = read_text_file(config.input_file_name)

In [44]:
# Prepare prompts
prompt = json_extract_prompt if config.output_format == "json" else yaml_extract_prompt
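Which format you pick here also determines how each reply is parsed later: `process_chat_response` hands the raw reply straight to `json.loads` (or `yaml.safe_load`). Chat models sometimes wrap their answer in a Markdown code fence or add a sentence around it despite the prompt's instruction not to, and those replies then count as failures. A possible pre-processing step is sketched below; `strip_code_fences` and `parse_json_reply` are hypothetical helpers, not part of the original code. For YAML, only `strip_code_fences` would apply before `yaml.safe_load`.

```python
import json
import re
from typing import Any

# Matches a fenced block such as ```json ... ``` or ```yaml ... ```
FENCE_RE = re.compile(r"```(?:[A-Za-z]+)?\s*\n(.*?)```", re.DOTALL)


def strip_code_fences(reply: str) -> str:
    """Return the contents of the first fenced code block, or the stripped reply if there is none."""
    match = FENCE_RE.search(reply)
    return match.group(1).strip() if match else reply.strip()


def parse_json_reply(reply: str) -> Any:
    """Parse a reply as JSON after removing fences; fall back to the outermost {...} span."""
    cleaned = strip_code_fences(reply)
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        start, end = cleaned.find("{"), cleaned.rfind("}")
        if start == -1 or end <= start:
            raise  # no object-like span: re-raise the original JSONDecodeError
        return json.loads(cleaned[start : end + 1])
```

Both failure paths still raise `json.JSONDecodeError`, so the existing `except` clause in `process_chat_response` would keep counting unusable replies as failures.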


In [45]:
# Split text into chunks
block_size = config.chunk_length
chunks = [text[i : i + block_size] for i in range(0, len(text), block_size)]
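Slicing every `block_size` characters can cut a name in half at a chunk boundary (for example, "Warren Buf" at the end of one chunk and "fett" at the start of the next), so neither chunk yields the full name. One common mitigation, sketched here as an alternative rather than what this notebook does, is to end each chunk at the last space before the limit and start the next chunk a little earlier, at a word boundary, so that short phrases spanning the cut usually appear whole in at least one chunk. The function name and the `overlap=200` default are assumptions; if the overlap window contains no space, no overlap is applied.

```python
from typing import List


def chunk_with_overlap(text: str, chunk_length: int = 8000, overlap: int = 200) -> List[str]:
    """Split text into chunks of at most chunk_length characters, breaking at spaces where
    possible and re-including up to `overlap` characters of the previous chunk."""
    if overlap >= chunk_length:
        raise ValueError("overlap must be smaller than chunk_length")
    chunks = []
    start = 0
    while start < len(text):
        end = min(start + chunk_length, len(text))
        if end < len(text):
            # Prefer to break at the last space inside the window
            split_at = text.rfind(" ", start, end)
            if split_at > start:
                end = split_at
        chunks.append(text[start:end])
        if end >= len(text):
            break
        # Step back by up to `overlap` characters, aligned to the start of a word;
        # `back` is always past the current start, so the loop makes progress
        back = max(end - overlap, start + 1)
        space = text.find(" ", back, end)
        start = space + 1 if space != -1 else end
    return chunks
```

Entities that appear in two overlapping chunks are extracted twice, but the set-based aggregators keep them only once. Usage would be `chunks = chunk_with_overlap(text, config.chunk_length)`.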


In [65]:
import concurrent.futures

# Define a function to send a request
def send_request(message):
    chat_response = chat_completion_request_runpod([message])
    return chat_response, message


# Define a function to process the chat response
def process_chat_response(chat_response, output_format):
    # Parse the reply in the configured format, then validate and aggregate it;
    # replies that fail to parse count as failures
    try:
        if output_format == "json":
            aggregator.aggregate_json(json.loads(chat_response))
        else:
            aggregator.aggregate_yaml(yaml.safe_load(chat_response.strip()))
    except (json.JSONDecodeError, yaml.YAMLError):
        print(f"Invalid {output_format.upper()} in chat response: {chat_response}")
        aggregator.fail += 1


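# Instantiate the aggregator that matches the output format
# (process_chat_response above relies on this global `aggregator`)
if config.output_format == "json":
    aggregator = JsonAggregator(json_schema_file)
else:
    aggregator = YamlAggregator(yaml_schema_file)

# Create one single-message conversation per chunk
message_lists = [
    [
        {
            "role": "user",
            "content": f"""{prompt}\n\n[TEXT_START]\n\n...{chunk}...\n\n[TEXT_END]\n\nNow, answer immediately and only in {config.output_format} format.""",
        }
    ]
    for chunk in chunks
]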

if config.batching:
    # Initialize a counter
    request_counter = 0

    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Send the requests in parallel
        future_to_chat_response = {
            executor.submit(send_request, messages[0]): messages for messages in message_lists
        }

        for future in concurrent.futures.as_completed(future_to_chat_response):
            messages = future_to_chat_response[future]
            try:
                chat_response, _ = future.result()
            except Exception as exc:
                print(f"{messages[0]} generated an exception: {exc}")
            else:
                # Increment the counter
                request_counter += 1

                # Process the chat response
                process_chat_response(chat_response, config.output_format)

    print(f"Total number of requests: {request_counter}")

else:
    for messages in tqdm(message_lists):
        chat_response = chat_completion_request_runpod(messages)

        # Process the chat response
        process_chat_response(chat_response, config.output_format)

# Write the aggregated data to a file (creating the output folder if needed)
os.makedirs(project_dir + output_dir, exist_ok=True)
aggregator.write_aggregated_data(f"{project_dir+output_dir}/{config.output_file_name}")

# Report how many chunk responses failed parsing or schema validation
total_attempts = aggregator.success + aggregator.fail
if total_attempts == 0:
    print("No attempts were made, so the error rate cannot be calculated.")
elif aggregator.success == 0:
    print("All validations failed")
else:
    error_rate = aggregator.fail / total_attempts
    print(f"Error rate is {error_rate}")



<s>GPT4 Correct User: Extract names and organizations from the provided text, and return them in JSON format. Use the following schema:

{
    "type": "object",
    "properties": {
        "names": {
            "type": "array",
            "items": {
                "type": "string"
            }
        },
        "organisations": {
            "type": "array",
            "items": {
                "type": "string"
            }
        }
    },
    "required": [
        "names",
        "organisations"
    ]
}

Here's an example of a response in JSON format:

{
    "names": [
        "sample_string_1",
        "sample_string_2"
    ],
    "organisations": [
        "sample_string_1",
        "sample_string_2"
    ]
}

Do not include anything that is not explicitly mentioned in the text. Analyse the text carefully to ensure all requested data is extracted. Include each name and organization only once. Adhere strictly to the response format without adding extra spaces or text.

[TEXT