# Installation of required library



## Installation required for OCR

In [None]:
%pip install datasets
%pip install easyocr
%pip install matplotlib
%pip install opencv-python

In [2]:
!sudo apt-get install tesseract-ocr
!pip install pytesseract

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  tesseract-ocr-eng tesseract-ocr-osd
The following NEW packages will be installed:
  tesseract-ocr tesseract-ocr-eng tesseract-ocr-osd
0 upgraded, 3 newly installed, 0 to remove and 49 not upgraded.
Need to get 4,816 kB of archives.
After this operation, 15.6 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 tesseract-ocr-eng all 1:4.00~git30-7274cfa-1.1 [1,591 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 tesseract-ocr-osd all 1:4.00~git30-7274cfa-1.1 [2,990 kB]
Get:3 http://archive.ubuntu.com/ubuntu jammy/universe amd64 tesseract-ocr amd64 4.1.1-2.1build1 [236 kB]
Fetched 4,816 kB in 1s (4,595 kB/s)
debconf: unable to initialize frontend: Dialog
debconf: (No usable dialog-like program is installed, so the dialog based frontend cannot be used. at /usr/share/perl5/Debc

# Importing libraries

In [3]:
import os
import json
import numpy as np
import pandas as pd
import re
from difflib import get_close_matches
from datasets import load_dataset
from PIL import Image, ImageDraw
import cv2
import matplotlib.pyplot as plt

## Importing library for OCR

In [None]:
# If we use easyocr
import easyocr

In [4]:
# if we use tesseract for ocr
import pytesseract
from pytesseract import Output

# HuggingFace login

In [5]:
from google.colab import userdata
secret_hf = userdata.get('HUGGINGFACE_TOKEN')
!huggingface-cli login --token $secret_hf

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /root/.cache/huggingface/token
Login successful


# Loading the Dataset from Huggingface.

In [6]:
ds = load_dataset('katanaml-org/invoices-donut-data-v1')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading readme:   0%|          | 0.00/1.04k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/167M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/19.9M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/10.4M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/425 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/50 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/26 [00:00<?, ? examples/s]

In [7]:
print(ds.keys())

dict_keys(['train', 'validation', 'test'])


In [None]:
example = ds["train"][1]
print(example.keys())

dict_keys(['image', 'ground_truth'])


In [None]:
print(type(example['ground_truth']))

<class 'str'>


# OCR Implementaation

## OCR implementaion using easyocr

In [None]:
def load_ocr_model():
  """Loads the EasyOCR model into memory.

  Returns:
    The EasyOCR reader object.
  """

  reader = easyocr.Reader(['en'])
  return reader

In [None]:
def read_text_from_image(image_path):
  """Reads text from an image using EasyOCR.

  Args:
    image_path: The path to the image file.

  Returns:
    A list of detected text results, each containing the bounding box coordinates,
    text, and confidence.
  """

  img = cv2.imread(image_path)
  results = reader.readtext(img)
  return results

In [None]:
def visualize_text_detection(image_path, results):
  """Visualizes the detected text on the image.

  Args:
    image_path: The path to the image file.
    results: The list of detected text results.
  """

  img = cv2.imread(image_path)
  plt.imshow(img)
  for res in results:
      xy = res[0]
      xy1, xy2, xy3, xy4 = xy[0], xy[1], xy[2], xy[3]
      det, conf = res[1], res[2]
      plt.plot([xy1[0], xy2[0], xy3[0], xy4[0], xy1[0]], [xy1[1], xy2[1], xy3[1], xy4[1], xy1[1]], 'r-')
      plt.text(xy1[0], xy1[1], f'{det} [{round(conf, 2)}]')
  plt.show()

In [None]:
reader = load_ocr_model()
image_path = "image.png"
results = read_text_from_image(image_path)
visualize_text_detection(image_path, results)

## Result and Conclusion:

## OCR implementaion using tesseract

In [19]:
def ocr_using_teseract(image):
    """
    Extracts text from an image using Tesseract OCR.

    Args:
        image: The input image to process.

    Returns:
        A list of paragraphs, where each paragraph is a dictionary containing:
            - coordinates: A list of four integers representing the bounding box (x1, y1, x2, y2) of the paragraph.
            - text: The extracted text from the paragraph.
    """
    data = pytesseract.image_to_data(image, output_type=Output.DICT)

    paragraphs = []
    current_paragraph = {'coordinates': [float('inf'), float('inf'), 0, 0], 'text': ""}

    for i in range(len(data['text'])):
        text = data['text'][i].strip()
        if text and int(data['conf'][i]) > 0:
            x, y, w, h = data['left'][i], data['top'][i], data['width'][i], data['height'][i]
            current_paragraph['coordinates'][0] = min(current_paragraph['coordinates'][0], x)  # x1
            current_paragraph['coordinates'][1] = min(current_paragraph['coordinates'][1], y)  # y1
            current_paragraph['coordinates'][2] = max(current_paragraph['coordinates'][2], x + w)  # x2
            current_paragraph['coordinates'][3] = max(current_paragraph['coordinates'][3], y + h)  # y2
            current_paragraph['text'] += " " + text

            if i < len(data['text']) - 1:
                next_y = data['top'][i + 1]
                if abs(next_y - y) > 10:
                    paragraphs.append(current_paragraph)
                    current_paragraph = {'coordinates': [float('inf'), float('inf'), 0, 0], 'text': ""}

    if current_paragraph['text']:
        paragraphs.append(current_paragraph)
    return paragraphs


In [20]:
def draw_cordinates_and_save_image(paragraphs, image, filename):
    """
    Draws bounding boxes around paragraphs on an image and saves the result.

    This function takes three arguments:

    * `paragraphs`: A list of dictionaries where each dictionary represents a paragraph
    and contains a key 'coordinates' with a list of four integers representing the
    top-left and bottom-right coordinates of the paragraph's bounding box (x1, y1, x2, y2).
    * `image`: A PIL Image object representing the image on which to draw the boxes.
    * `filename`: The desired filename (without extension) to save the modified image.

    The function iterates through each paragraph in the `paragraphs` list and draws a blue
    rectangle with a width of 2 pixels around its bounding box coordinates. Finally, it
    saves the modified image as a PNG file with the provided filename and extension.

    Optionally, the function also displays the modified image using `image.show()`. You
    may want to comment out this line if you only want to save the image.
    """
    draw = ImageDraw.Draw(image)

    # Draw lines around each paragraph
    for para in paragraphs:
        x1, y1, x2, y2 = para['coordinates']
        draw.rectangle([x1, y1, x2, y2], outline="blue", width=2)

    # Save or display the image
    image.show()
    image.save(f'{filename}.png')

In [21]:
def create_and_save_dataframe(input1, input2, input3, filename):
    """Creates or updates a CSV file with a new row of data.

    Args:
      input1: The value for the 'index' column.
      input2: The value for the 'context' column.
      input3: The value for the 'context_with_coordinates' column.
      filename: The name of the CSV file.

    Returns:
      None
    """

    if os.path.exists(filename):
      df = pd.read_csv(filename)
    else:
      df = pd.DataFrame(columns=['index', 'context', 'context_with_coordinates'])
    df.loc[len(df)] = [input1, input2, input3]
    df.to_csv(filename, index=False)

In [22]:
data_part = "train"
start_row = 0
end_row = 100

In [None]:

dataset = ds[data_part].skip(start_row).take(end_row - start_row)
# dataset = ds[data_part].skip(4).take(1)
for id, example in enumerate(dataset):
    image_np = np.array(example["image"])
    image = Image.fromarray(image_np.astype(np.uint8))
    paragraphs = ocr_using_teseract(image)
    # draw_cordinates_and_save_image(paragraphs, image, "image")

    context = ""
    dict_of_paragraphs = {}
    for idx, para in enumerate(paragraphs):
        para_text = para['text'].strip()
        coordinates = para['coordinates']
        # context += f"Paragraph {idx + 1}:\n{para_text}\nCoordinates: {coordinates}\n\n"
        dict_of_paragraphs[para_text] = coordinates
        context += f"{para_text} \t\t"
    # draw_cordinates_and_save_image(paragraphs, image, "image")
    create_and_save_dataframe(id, context, dict_of_paragraphs, "ocr_results.csv")



In [23]:
def extract_invoice_date(context, context_with_coordinates):
    """
    Extracts the invoice date from the provided context string and returns it along with its coordinates from a separate dictionary.

    Args:
        context (str): The invoice data string to search for the date.
        context_with_coordinates (dict[str, Any]): A dictionary containing invoice dates (as keys) and their corresponding coordinates (as values).

    Returns:
        tuple[str, str]: A tuple containing two elements:
            - The extracted invoice date in DD/MM/YYYY format (empty string if not found).
            - The coordinates associated with the extracted date from the context_with_coordinates dictionary (or "NA" if not found or no close match exists).
    """
    match = re.search(r"\d{2}/\d{2}/\d{4}", context)

    if match:
        extracted_date = match.group()
        print("Extracted date:", extracted_date)
        coord = context_with_coordinates.get(extracted_date)
        if coord is None:
            close_matches = get_close_matches(extracted_date, context_with_coordinates.keys(), n=1)
            coord = context_with_coordinates.get(close_matches[0], "NA") if close_matches else "NA"

        return (extracted_date, coord)

    else:
        print("Date not found in the invoice data.")
        return ("", "")

In [24]:
def extract_invoice_number(context, context_with_coordinates):
    """Extracts invoice number and coordinate (if available).

    Uses regex to find an invoice number in context. If found, returns the number
    and its coordinate from context_with_coordinates (if present). Otherwise, returns
    empty strings.

    Args:
    context: The string containing the invoice data.
    context_with_coordinates: A dictionary mapping invoice numbers (strings)
    to their corresponding coordinates. May be None.

    Returns:
    A tuple of (extracted_invoice_number, coordinate).
    """
    match = re.search(r"\d+", context)
    if match:
        extracted_invoice_number = match.group()
        print("Extracted invoice number:", extracted_invoice_number)
        invoice = f"Invoice no: {extracted_invoice_number}"
        coord = context_with_coordinates.get(invoice, "NA")
        return (extracted_invoice_number, coord)

    else:
        print("Invoice number not found in the invoice data.")
        return ("" , "")

In [25]:
def extract_client_name(context, context_with_coordinates):
    """
    Extracts the client name from a given context string and retrieves its corresponding coordinates from another dictionary.

    Args:
        context (str): The text to search for the client name.
        context_with_coordinates (dict): A dictionary mapping client names to their coordinates.

    Returns:
        tuple[str, str]: A tuple containing the extracted client name (empty string if not found) 
                         and the corresponding coordinates (or "NA" if not found).
    """

    pattern = r"Client:\s*([A-Za-z, \-]+)"
    match = re.search(pattern, context)

    if match:
        client_name = match.group(1).strip()
        print(f"Client name: {client_name}")
        coord = context_with_coordinates.get(client_name, "NA")
        return (client_name, coord)
    else:
        print("Client name not found.")
        return ("" , "")

In [30]:
def extract_gross_worth(context, context_with_coordinates):
    """Extracts the maximum gross worth from a given text context and returns its formatted string and corresponding coordinate.

    Args:
        context (str): The text context to extract the gross worth from.
        context_with_coordinates (dict): A dictionary mapping gross worth strings to their corresponding coordinates.

    Returns:
        tuple(str, str): A tuple containing the formatted string of the maximum gross worth and its coordinate. If no valid gross worth is found, both values will be empty strings.
    """
    pattern = r"\$\s*\d{1,3}(?:[.,]\d{3})*(?:[.,]\d{2})"
    matches = re.findall(pattern, context)

    if matches:
        gross_values = []
        for match in matches:
            # Remove dollar sign and spaces
            cleaned_value = match.replace('$', '').replace(' ', '').strip()

            # Replace commas with dots if they are in decimal place format (e.g., "58667,76")
            if cleaned_value.count(',') == 1 and cleaned_value.count('.') == 0:
                cleaned_value = cleaned_value.replace(',', '.')

            # Remove commas used as thousand separators (e.g., "1,000,000" -> "1000000")
            cleaned_value = cleaned_value.replace(',', '')

            try:
                # Convert to float
                numeric_value = float(cleaned_value)
                gross_values.append((match, numeric_value))
            except ValueError:
                continue  # Skip if conversion fails due to bad format

        # Find the max gross value
        if gross_values:
            max_gross_value = max(gross_values, key=lambda x: x[1])[0]  # Get the original formatted string of the max value
            # print(f"Maximum gross worth: {max_gross_value}")
            coord = context_with_coordinates.get(max_gross_value, "NA")
            return (max_gross_value, coord)

    # In case no valid matches are found
    # print("Gross worth not found.")
    return ("", "")

In [36]:
import re

def extract_gross_worth(context, context_with_coordinates):
    """Extracts the maximum gross worth from a given text context and its associated coordinates.

    Args:
        context (str): The text context to search for gross worth values.
        context_with_coordinates (dict): A dictionary mapping gross worth values to their corresponding coordinates.

    Returns:
        tuple: A tuple containing the maximum gross worth value (original formatted string) and its associated coordinate.
    """
    # Updated pattern to match "$36 946,22", "$ 36 946,22", "$36,946.22", etc., without extra tabs around
    pattern = r"\$\s*\d{1,3}(?:[\s,]\d{3})*(?:[.,]\d{2})"  # Matches $ followed by space-separated or comma-separated thousands
    matches = re.findall(pattern, context)

    if matches:
        gross_values = []
        for match in matches:
            # Keep the original format
            original_value = match.strip()

            # Remove dollar sign and spaces
            cleaned_value = match.replace('$', '').replace(' ', '').strip()

            # Replace commas with dots if they are in decimal place format (e.g., "36 946,22")
            if cleaned_value.count(',') == 1 and cleaned_value.count('.') == 0:
                cleaned_value = cleaned_value.replace(',', '.')

            # Remove spaces or commas used as thousand separators (e.g., "36 946" -> "36946" or "1,000,000" -> "1000000")
            cleaned_value = cleaned_value.replace(' ', '').replace(',', '')

            try:
                # Convert to float
                numeric_value = float(cleaned_value)
                gross_values.append((original_value, numeric_value))
            except ValueError:
                continue  # Skip if conversion fails due to bad format

        # Find the max gross value
        if gross_values:
            # Get the original formatted string of the max value
            max_gross_value_original = max(gross_values, key=lambda x: x[1])[0]
            print(f"Maximum gross worth: {max_gross_value_original}")
            coord = context_with_coordinates.get(max_gross_value_original, "NA")
            return (max_gross_value_original, coord)

    # In case no valid matches are found
    print("Gross worth not found.")
    return ("", "")


In [27]:
import json

def create_json(client_name, invoice_number, invoice_date,   gross_worth):
  """Creates a JSON object where one key has two child keys,
  corresponding to the first and second elements of a tuple.

  Args:
    tuple_data: A tuple containing the data to be added to the JSON object.

  Returns:
    A JSON object with the specified structure.
  """

  json_data = {}
  json_data["Customer Name"] = {
    "value": client_name[0],
    "Coordinates": client_name[1]
  }
  json_data["Invoice Number"] = {
    "value": invoice_number[0],
    "Coordinates": invoice_number[1]
  }
  json_data["Invoice Date"] = {
    "value": invoice_date[0],
    "Coordinates": invoice_date[1]
  }

  json_data["Total Amount"] = {
    "value": gross_worth[0],
    "Coordinates": gross_worth[1]
  }
  return json.dumps(json_data)


In [28]:
def create_and_save_json_df(input1, input2, filename):
    """Reads a CSV file, processes each row starting from the specified start_row, and extracts JSON data from the 'context' column.

    Args:
      filename: The name of the CSV file to read.
      start_row: The starting row index from which to begin processing.

    Returns:
      None
    """

    if os.path.exists(filename):
      df = pd.read_csv(filename)
    else:
      df = pd.DataFrame(columns=['index', 'json_data'])
    df.loc[len(df)] = [input1, input2]
    df.to_csv(filename, index=False)

In [43]:
def read_csv_with_dict(filename):
  """Reads a CSV file, processes each row starting from the specified start_row, and extracts JSON data from the 'context' column.

  Args:
    filename: The name of the CSV file to read.
    start_row: The starting row index from which to begin processing.

  Returns:
    None
  """
  df = pd.read_csv(filename)
  for index, row in df.iterrows():
    index_value = row['index']
    context = row['context']
    context_with_coordinates = eval(row['context_with_coordinates'])

    invoice_date = extract_invoice_date(context, context_with_coordinates)
    invoice_number = extract_invoice_number(context, context_with_coordinates)
    client_name = extract_client_name(context, context_with_coordinates)
    gross_worth = extract_gross_worth(context, context_with_coordinates)
    json_data = create_json(client_name, invoice_number, invoice_date, gross_worth)
    create_and_save_json_df(index_value, json_data, "ocr_re_results.csv")

# Example usage:
read_csv_with_dict("ocr_results.csv")


Extracted date: 10/15/2012
Extracted invoice number: 40378170
Client name: Jackson, Odonnell and Jackson
Maximum gross worth: $ 8,25
Extracted date: 09/06/2012
Extracted invoice number: 61356291
Client name: Rodriguez-Stevens
Maximum gross worth: $ 212,09
Extracted date: 10/28/2019
Extracted invoice number: 49565075
Client name: Garcia Inc
Maximum gross worth: $ 96,73
Extracted date: 07/19/2016
Extracted invoice number: 95611677
Client name: Johnson Group
Maximum gross worth: $ 1 054,10
Extracted date: 11/19/2019
Extracted invoice number: 26020078
Client name: Ochoa, Crane and Johnston
Maximum gross worth: $ 116,52
Extracted date: 03/15/2012
Extracted invoice number: 42485588
Client name: Knight-Brown
Maximum gross worth: $ 214,41
Extracted date: 01/12/2015
Extracted invoice number: 94689364
Client name: Wilson PLC
Maximum gross worth: $ 37 153,67
Extracted date: 02/06/2020
Extracted invoice number: 48402876
Client name: Chapman-Pineda
Maximum gross worth: $ 4 618,75
Extracted date: 03

In [44]:
df_data = ds['train'].to_pandas()
df_data.drop('image', axis=1, inplace=True)
df_data = df_data.iloc[start_row:end_row]
df_data.reset_index(inplace=True)
merged_df = pd.merge(df_data, pd.read_csv('ocr_re_json_results.csv'), on='index')
merged_df['ground_truth'] = merged_df['ground_truth'].apply(json.loads)
merged_df['json_data'] = merged_df['json_data'].apply(json.loads)
merged_df.to_csv('final_results_re.csv', index=False)


In [45]:
merged_df.head()

Unnamed: 0,index,ground_truth,json_data
0,0,{'gt_parse': {'header': {'invoice_no': '403781...,"{'Customer Name': {'value': 'Jackson, Odonnell..."
1,1,{'gt_parse': {'header': {'invoice_no': '613562...,{'Customer Name': {'value': 'Rodriguez-Stevens...
2,2,{'gt_parse': {'header': {'invoice_no': '495650...,"{'Customer Name': {'value': 'Garcia Inc', 'Coo..."
3,3,{'gt_parse': {'header': {'invoice_no': '956116...,"{'Customer Name': {'value': 'Johnson Group', '..."
4,4,{'gt_parse': {'header': {'invoice_no': '260200...,"{'Customer Name': {'value': 'Ochoa, Crane and ..."


In [46]:
type(merged_df['json_data'])

In [47]:
comparison_list = []

for index, row in merged_df.iterrows():
    json_data = row['json_data']
    if json_data:
        invoice_number_1 = json_data.get('Invoice Number', {}).get('value', '')
        invoice_date_1 = json_data.get('Invoice Date', {}).get('value', '')
        client_name_1 = json_data.get('Customer Name', {}).get('value', '')
        gross_worth_1 = json_data.get('Total Amount', {}).get('value', '')
    else:
        invoice_number_1 = ''
        invoice_date_1 = ''
        client_name_1 = ''
        gross_worth_1 = ''

    ground_truth = row['ground_truth']
    if ground_truth:
        invoice_number_2 = ground_truth.get('gt_parse', {}).get('header', {}).get('invoice_no', '')
        invoice_date_2 = ground_truth.get('gt_parse', {}).get('header', {}).get('invoice_date', '')
        client_name_2 = ground_truth.get('gt_parse', {}).get('header', {}).get('client', '')
        gross_worth_2 = ground_truth.get('gt_parse', {}).get('summary', {}).get('total_gross_worth', '')
    else:
        invoice_number_2 = ''
        invoice_date_2 = ''
        client_name_2 = ''
        gross_worth_2 = ''

    comparison_results = {
        'index': index,
        'Customer Name Matches': client_name_1 in client_name_2,
        'Invoice Number Matches': invoice_number_1 == invoice_number_2,
        'Invoice Date Matches': invoice_date_1 == invoice_date_2,
        'Total Amount Matches': gross_worth_1.strip().replace('$ ', '').replace('$','').replace(' ', '') == gross_worth_2.strip().replace('$ ', '').replace('$','').replace(' ', '')
    }

    # Add comparison results to the list
    comparison_list.append(comparison_results)

    # Output the details
    print(f"Invoice Number ocr, {index}:", invoice_number_1)
    print(f"Invoice Date ocr, {index}:", invoice_date_1)
    print(f"Client Name ocr, {index}:", client_name_1)
    print(f"Gross Worth ocr, {index}:", gross_worth_1)
    print(f"Invoice Number, {index}:", invoice_number_2)
    print(f"Invoice Date, {index}:", invoice_date_2)
    print(f"Client Name, {index}:", client_name_2)
    print(f"Gross Worth, {index}:", gross_worth_2)
    print("----" * 30)

# Convert comparison results to DataFrame
comparison_df = pd.DataFrame(comparison_list)

# Calculate percentage matches for each field
total_comparisons = len(comparison_df)
percentage_matches = {
    'Customer Name Matches (%)': (comparison_df['Customer Name Matches'].sum() / total_comparisons) * 100,
    'Invoice Number Matches (%)': (comparison_df['Invoice Number Matches'].sum() / total_comparisons) * 100,
    'Invoice Date Matches (%)': (comparison_df['Invoice Date Matches'].sum() / total_comparisons) * 100,
    'Total Amount Matches (%)': (comparison_df['Total Amount Matches'].sum() / total_comparisons) * 100
}

print("\nComparison Results Summary:")
for key, value in percentage_matches.items():
    print(f"{key}: {value:.2f}%")

Invoice Number ocr, 0: 40378170
Invoice Date ocr, 0: 10/15/2012
Client Name ocr, 0: Jackson, Odonnell and Jackson
Gross Worth ocr, 0: $ 8,25
Invoice Number, 0: 40378170
Invoice Date, 0: 10/15/2012
Client Name, 0: Jackson, Odonnell and Jackson 267 John Track Suite 841 Jenniferville, PA 98601
Gross Worth, 0: $8,25
------------------------------------------------------------------------------------------------------------------------
Invoice Number ocr, 1: 61356291
Invoice Date ocr, 1: 09/06/2012
Client Name ocr, 1: Rodriguez-Stevens
Gross Worth ocr, 1: $ 212,09
Invoice Number, 1: 61356291
Invoice Date, 1: 09/06/2012
Client Name, 1: Rodriguez-Stevens 2280 Angela Plain Hortonshire, MS 93248
Gross Worth, 1: $ 212,09
------------------------------------------------------------------------------------------------------------------------
Invoice Number ocr, 2: 49565075
Invoice Date ocr, 2: 10/28/2019
Client Name ocr, 2: Garcia Inc
Gross Worth ocr, 2: $ 96,73
Invoice Number, 2: 49565075
Invoice

In [49]:

from sklearn.metrics import confusion_matrix
from typing import Dict

def flatten_json(json_obj: Dict) -> Dict:
    """
    Flattens a nested JSON object into a single-level dictionary.

    Args:
        json_obj (Dict): The nested JSON object to flatten.

    Returns:
        Dict: The flattened JSON object as a single-level dictionary.
    """
    flat_dict = {}

    def flatten(data, parent_key=''):
        if isinstance(data, dict):
            for k, v in data.items():
                new_key = f"{parent_key}{k}"
                if isinstance(v, dict) or isinstance(v, list):
                    flatten(v, new_key + '_')
                else:
                    flat_dict[new_key] = v
        elif isinstance(data, list):
            for i, item in enumerate(data):
                flatten(item, parent_key + str(i) + '_')

    flatten(json_obj)
    return flat_dict

def evaluate_json(ground_truth: Dict, predicted: Dict) -> Dict:
    """Evaluates the accuracy of a predicted JSON object against a ground truth JSON object.

    Args:
        ground_truth (Dict): The ground truth JSON object.
        predicted (Dict): The predicted JSON object.

    Returns:
        Dict: A dictionary containing the following metrics:
            - True Positive: Number of correctly predicted values.
            - False Positive: Number of incorrectly predicted values.
            - True Negative: Number of correctly predicted absences.
            - False Negative: Number of incorrectly predicted absences.
            - value_accuracy: Overall accuracy based on the number of matching values.
    """

    metrics = {'True Positive': 0, 'False Positive': 0, 'True Negative': 0, 'False Negative': 0}

    gt_flat = flatten_json(ground_truth)
    # print(gt_flat)
    pred_flat = flatten_json(predicted)
    # print(pred_flat)
    # Define matching criteria
    matches = {
        'Customer Name Matches': pred_flat.get('Customer Name_value', 'NA') in gt_flat.get('gt_parse_header_client', 'NA'),
        'Invoice Number Matches': pred_flat.get('Invoice Number_value', 'NA') == gt_flat.get('gt_parse_header_invoice_no', 'NA'),
        'Invoice Date Matches': pred_flat.get('Invoice Date_value', 'NA') == gt_flat.get('gt_parse_header_invoice_date', 'NA'),
        'Total Amount Matches': (
            pred_flat.get('Total Amount_value', 'NA').strip().replace('$ ', '').replace('$', '').replace(' ', '').replace('.', ',') ==
            gt_flat.get('gt_parse_summary_total_gross_worth', 'NA').strip().replace('$ ', '').replace('$', '').replace(' ', '').replace('.', ',')
        )
    }


    groundt_truth = {'Customer Name Matches':'gt_parse_header_client', 'Invoice Number Matches':'gt_parse_header_invoice_no','Invoice Date Matches': 'gt_parse_header_invoice_date', 'Total Amount Matches':'gt_parse_summary_total_gross_worth'}
    predict_val = {'Customer Name Matches':'Customer Name_value', 'Invoice Number Matches':'Invoice Number_value','Invoice Date Matches': 'Invoice Date_value', 'Total Amount Matches':'Total Amount_value'}
    # Calculate confusion matrix components based on matching criteria
    for match_type, is_match in matches.items():
        gt_value_present = gt_flat.get(groundt_truth.get(match_type), 'NA') != 'NA'
        pred_value_present = pred_flat.get(predict_val.get(match_type), 'NA') != 'NA'

        if gt_value_present and pred_value_present and is_match:
            # True Positive: Both values are present and match
            metrics['True Positive'] += 1
        elif not pred_value_present and gt_value_present or not is_match:
            # False Positive: Predicted value not present, but ground truth value is present
            metrics['False Positive'] += 1
        elif pred_value_present and not gt_value_present:
            # False Negative: Predicted value present, but ground truth value is not present
            metrics['False Negative'] += 1
        elif not gt_value_present and not pred_value_present:
            # True Negative: Both values are not present
            metrics['True Negative'] += 1

    # Compute accuracy
    total_matches = len(matches)
    value_accuracy = (metrics['True Positive'] / total_matches) * 100 if total_matches > 0 else 0
    metrics['value_accuracy'] = value_accuracy

    # print(metrics)
    return metrics

# Initialize accumulators
total_confusion = {'True Positive': 0, 'False Positive': 0, 'True Negative': 0, 'False Negative': 0}
total_values = 0
total_correct = 0

# Iterate over the DataFrame rows
for index, row in merged_df.iterrows():
    ground_truth = row['ground_truth']
    predicted = row['json_data']
    metrics = evaluate_json(ground_truth, predicted)

    # Accumulate results
    total_confusion['True Positive'] += metrics['True Positive']
    total_confusion['False Positive'] += metrics['False Positive']
    total_confusion['True Negative'] += metrics['True Negative']
    total_confusion['False Negative'] += metrics['False Negative']

    total_values += len(flatten_json(row['json_data']))
    total_correct += metrics['value_accuracy'] * len(flatten_json(row['json_data'])) / 100

# Compute aggregate metrics
aggregated_metrics = {
    'True Positive': total_confusion['True Positive'],
    'False Positive': total_confusion['False Positive'],
    'True Negative': total_confusion['True Negative'],
    'False Negative': total_confusion['False Negative'],
    'Total Values': total_values,
    'Total Correct': total_correct,
    'Overall Accuracy (%)': (total_correct / total_values) * 100
}

# Display aggregated results
aggregated_results_df = pd.DataFrame([aggregated_metrics])
print(aggregated_results_df)


   True Positive  False Positive  True Negative  False Negative  Total Values  \
0            367              33              0               0           435   

   Total Correct  Overall Accuracy (%)  
0         394.75             90.747126  
