In [10]:
# Benchmark: Text Extraction from FUNSD Dataset
# Comparing ChatGPT and Anthropic Claude models

import numpy as np
import pandas as pd
from datasets import DatasetDict
from sklearn.metrics import jaccard_score
from typing import List, Dict
import anthropic
import base64
from openai import OpenAI
import os
from dotenv import load_dotenv
import re
import azure_di


# Load environment variables from .env file
load_dotenv()

# Replace the API key assignments with these lines
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
ANclient = anthropic.Anthropic(api_key=anthropic_api_key)

openai_api_key = os.getenv("OPENAI_API_KEY")
# openai_api_key = "sk-proj-"
OAIclient = OpenAI(api_key=openai_api_key)


In [11]:
# Load the FUNSD dataset
dataset_path = "dataset/prepared/"
dataset = DatasetDict.load_from_disk(dataset_path)

In [12]:
system_message = "You are an OCR bot for extraction information various documents. When given an image representing a document, you will extract ALL the content of the document. Under no circumstance leave out any details no matter how small! Extract the text as seen in the document without modifying any characters, even if it's a typo. ALWAYS respond in this form: RESPONSE: <extracted_text>. Do not add any further comments. The extracted text should be in the same order as it appears in the document."

In [13]:
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def get_model_response(model: str, image_path: str) -> str:
    base64_image = encode_image(image_path)
    if model.startswith("gpt"):
        # OpenAI API call with image
        response = OAIclient.chat.completions.create(
            model=model,
            messages=[
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "text",
                            "text": system_message
                        }
                    ]
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{base64_image}"}
                        },
                    ],
                }
            ],
            max_tokens=2000
        )
        return response.choices[0].message.content
    elif model.startswith("claude"):
        response = ANclient.messages.create(
            model=model,
            max_tokens=2000,
            system=system_message,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/png",
                                "data": f"{base64_image}"
                            }
                        }
                    ]
                }
            ]
        )
        return response.content[0].text
    elif model.startswith("azure"):
        response = azure_di.analyze_read(image_path)
        return "RESPONSE: " + response
    else:
        raise ValueError(f"Unknown model: {model}")

In [5]:
path = dataset["test"][5]["image_path"]

In [None]:
get_model_response("azure-document-intelligence-prebuilt", path)

In [None]:
res = get_model_response("gpt-4-turbo", path)
res

In [None]:
res = get_model_response("claude-3-5-sonnet-20240620", path)
res

In [14]:
def compare_llm_output_to_labels(llm_output, labels):
    # Step 1: Filter out one-letter strings from labels
    filtered_labels = set([label for label in labels if len(label) > 2])
    
    llm_output_lower = llm_output.lower()

    llm_output_no_spaces = re.sub(r'\s*([^\w\s])\s*', r'\1', llm_output_lower)


    matched_labels = set()
    for label in filtered_labels:
        if (label.lower() in llm_output.lower()) or (label.lower() in llm_output_no_spaces):
            matched_labels.add(label)

    # Calculate metrics
    true_positives = len(matched_labels)
    false_negatives = len(filtered_labels) - len(matched_labels)

    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0

    return {
        'llm_output': llm_output,
        'filtered_labels': filtered_labels,
        'matched_labels': matched_labels,
        'unmatched_labels': set(filtered_labels) - matched_labels,
        'recall': recall,
        'true_positives': true_positives,
        'false_negatives': false_negatives
    }

# Example usage
labels = ['AAA', ':', 'BBB', 'CCC', 'DDD', 'EEE ', 'DDD', '132-123/123', 'X@Y.com', '10,000.00', 'A&B Corp.']
llm_output = "AAA is my name BBB CCC EE 132 - 123 / 123 X @ Y . com 10, 000.00 A & B Corp."

result = compare_llm_output_to_labels(llm_output, labels)

print(f"Matched labels: {result['matched_labels']}")
print(f"Unmatched labels: {result['unmatched_labels']}")
print(f"Recall: {result['recall']:.2f}")
print(f"True Positives: {result['true_positives']}")
print(f"False Negatives: {result['false_negatives']}")

Matched labels: {'132-123/123', 'X@Y.com', 'AAA', 'CCC', '10,000.00', 'A&B Corp.', 'BBB'}
Unmatched labels: {'DDD', 'EEE '}
Recall: 0.78
True Positives: 7
False Negatives: 2


In [15]:
# List of models to benchmark
models = [
    "azure-document-intelligence-prebuilt",
    "gpt-4-turbo",
    "gpt-4o-mini",
    "claude-3-haiku-20240307",
    "gpt-4o",
    "claude-3-5-sonnet-20240620",
]

In [None]:
results = {model : [] for model in models}
for i, dat in enumerate(dataset["test"]):
    print(i)
    labels = dat["tokens"]
    image_path = dat["image_path"]
    path = image_path.replace("\\", "/")
    print(f"Image path: {path}")
    try:
        llm_outputs = {model: get_model_response(model, path) for model in models}
    except Exception as e:
        print(f"Error processing image: {path}")
        print(e)
        continue
    for model, llm_output in llm_outputs.items():
        print(f"Model: {model}")
        if "RESPONSE:" not in llm_output:
            print("No response")
            continue
        output = llm_output.split("RESPONSE:")[1]
        result = compare_llm_output_to_labels(llm_output, labels)
        print(result)
        result['file'] = path
        results[model].append(result)
        

In [17]:
# results

In [18]:
for i, dat in enumerate(dataset["train"]):
    print(i)
    labels = dat["tokens"]
    image_path = dat["image_path"]
    path = image_path.replace("\\", "/")
    print(f"Image path: {path}")
    try:
        llm_outputs = {model: get_model_response(model, path) for model in models}
    except Exception as e:
        print(f"Error processing image: {path}")
        print(e)
        continue
    for model, llm_output in llm_outputs.items():
        print(f"Model: {model}")
        if "RESPONSE:" not in llm_output:
            print("No response")
            continue
        output = llm_output.split("RESPONSE:")[1]
        result = compare_llm_output_to_labels(llm_output, labels)
        print(result)
        result['file'] = path
        results[model].append(result)
        

0
Image path: dataset/training_data/images/0000971160.png
Model: azure-document-intelligence-prebuilt
{'llm_output': 'RESPONSE: R&D\nR&D QUALITY IMPROVEMENT SUGGESTION/SOLUTION FORM\nName/Phone Ext .: M. Hamann, P. Harper, P. Martinez\nDate:\n9/3/92\nSupervisor/Manager: J. S. Wigand\nR&D Group:\nLicensee\nSuggestion: Discontinue coal retention analyses on licensee submitted product samples. (Note: Coal Retention testing is not performed by most licensees. Other B&W physical measurements as ends stability and inspection for soft spots in cigarettes are thought to be sufficient measures to assure cigarette physical integrity. The proposed action will increase laboratory productivity. )\nSuggested Solution(s): Delete coal retention from the list of standard analyses performed on licensee submitted product samples. Special requests for coal retention testing could still be submitted on an exception basis.\nHave you contacted your Manager/Supervisor?\nYes\nNo\nManager Comments: Manager, ple

In [19]:
len(results['azure-document-intelligence-prebuilt'])

199

In [20]:
def set_to_list(obj):
    if isinstance(obj, set):
        return list(obj)
    return obj

In [21]:
# Export document-oriented VQA results to json file
import json
with open('funsd_benchmark_results.json', 'w') as f:
    json.dump(results, f, default=set_to_list, indent=4)

## Plot

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np

# Load the JSON data
with open('funsd_benchmark_results.json', 'r') as f:
    data = json.load(f)

# Extract recall values for each model
model_recalls = {}
for model, result in data.items():
    print(model)
    recalls = [example['recall'] for example in result]
    model_recalls[model] = np.mean(recalls)

# Prepare data for plotting
models = list(model_recalls.keys())
recalls = list(model_recalls.values())

# Create the bar plot
plt.figure(figsize=(10, 6))
plt.bar(models, recalls)
plt.title('Average Recall by Model')
plt.xlabel('Model')
plt.ylabel('Average Recall')
plt.ylim(0, 1)  # Set y-axis limit from 0 to 1 for recall values
plt.xticks(rotation=45, ha='right')  # Rotate x-axis labels for better readability

# Add value labels on top of each bar
for i, v in enumerate(recalls):
    plt.text(i, v, f'{v:.3f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()