In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
!pip install qwen-vl-utils gradio 

Collecting qwen-vl-utils
  Downloading qwen_vl_utils-0.0.10-py3-none-any.whl.metadata (6.3 kB)
Collecting gradio
  Downloading gradio-5.14.0-py3-none-any.whl.metadata (16 kB)
Collecting av (from qwen-vl-utils)
  Downloading av-14.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.7 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.8-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.7.0 (from gradio)
  Downloading gradio_client-1.7.0-py3-none-any.whl.metadata (7.1 kB)
Collecting markupsafe~=2.0 (from gradio)
  Downloading MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.9.3 (from gradio)
  Downloading ruff-0.9.4-py3-none-manylinux_2_17_x86_64.manyli

# Import libraries

In [2]:
import json
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
import pytesseract
import torch
import re

# Model Class

In [3]:

class InvoiceProcessor:
    """
    A class to handle processing invoices using the Qwen2-VL-2B-Instruct model.
    This class provides functionality to extract information from invoice images and text prompts.
    """

    def __init__(self, model_name="Qwen/Qwen2-VL-2B-Instruct", device=None):
        """
        Initialize the InvoiceProcessor class.

        Args:
            model_name (str): The name of the model to be loaded.
            device (str): Device to load the model on ('cuda', 'cpu', or None for auto-detection).
        """
        min_pixels = 256 * 28 * 28  # Minimum size for visual tokens
        max_pixels = 1280 * 28 * 28  # Maximum size for visual tokens
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.model_name, torch_dtype=torch.float16, device_map="auto").eval()
        self.processor = AutoProcessor.from_pretrained(self.model_name ,min_pixels=min_pixels,max_pixels=max_pixels)
        # Initialize the OCR model (with pretrained weights)
        self.response = ""


    def process_image(self, image_path):
        """
        Prepares the inputs for the model.

        Args:
            image_path (str): Path to the invoice image.
            prompt (str): Prompt guiding the model's behavior.

        Returns:
            dict: Prepared inputs ready for the model.
        """
        torch.cuda.empty_cache()
        processed_image = self.preprocess_image(image_path)
        messages = [
            {
                "role": "system",
                "content": """You are an advanced model tasked with extracting structured data from invoices. An invoice includes the following key sections. Extract the data accurately and return it in the specified JSON format.

                Key Sections to Extract:
                               
                Invoice Information: Extract the information for invoice like invoice number, invoice date, Due date etc.
                Business: Extract the Business name, Business address,business mail,business phone number,business GSTIN,business PAN etc of the issuing business.
                Customer: Extract the Customer name,Customer address,business mail,business phone number,business GSTIN, customer PAN etc of the billed customer.
                Product/Service: Extract the details of each billed item/servic include all subfield mentioned.
                shipment: Extract details related to shipment.
                Bank detail: extract bank related information.
                Taxes Information: Extract details for each Taxable value,Central taxes,State taxes(Total Tax Amount,Tax Rate) .
                Total Amount Information: Extract all kind of amount information.
                if specific section or parameter not present skip those.
            """
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": processed_image},
                    {"type": "text", "text": f"Extract all the important and relevant information from this image"},
                ],
            },
        ]

        # Prepare the chat input and vision info
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt= False)
        image_inputs, video_inputs = process_vision_info(messages)

        # Prepare and return model inputs
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs.to(self.device)
        # Inference: Generation of the output
        generated_ids = self.model.generate(**inputs, max_new_tokens= 1200, temperature = 0.01,top_p=1.0 )
        generated_ids_trimmed = [
           out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
           ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
           )
        del inputs, generated_ids, generated_ids_trimmed
        torch.cuda.empty_cache()

        cleaned_text = ''.join(output_text).strip("```json\n")
        self.response = cleaned_text

        return cleaned_text

    def parse_json(self):
          """
               Parse a JSON string into a flat dictionary of key-value pairs, ignoring intermediate keys if the value contains key-value pairs.
          """
          parsed_dict = {}
          json_string = self.response

          def recursive_parser(data, parent_key=""):
              if isinstance(data, dict):
                 for key, value in data.items():
                     # If the value is a dictionary, recurse without appending the key
                     if isinstance(value, dict):
                       recursive_parser(value, parent_key)
                     else:
                        # Create a new key by appending the parent key (if exists)
                        new_key = f"{parent_key}.{key}" if parent_key else key
                        recursive_parser(value, new_key)
              elif isinstance(data, list):
                 for index, item in enumerate(data):
                     # Include index in key for list items
                     new_key = f"{parent_key}[{index}]"
                     recursive_parser(item, new_key)
              else:
                 # For leaf nodes, store the value in the dictionary
                 parsed_dict[parent_key] = data

          # Parse the JSON string
          try:
              json_data = json.loads(json_string)
              recursive_parser(json_data)
          except json.JSONDecodeError as e:
                 print("JSONDecodeError: Falling back to regex-based parsing.")

                 # Fallback: Parse using regex
                 # Regex pattern to capture key-value pairs inside quotes
                 pattern = r'"([^"]+)"\s*:\s*("[^"]+"|\d+|\[.*?\]|\{.*?\})'
                 data = json_string
                 matches = re.findall(pattern, data)
                 extracted_data = {}

                 for key, value in matches:
                       # Remove surrounding quotes from keys and values
                       key = key.strip('"')
                       value = value.strip('"')

                       # Attempt to interpret lists or nested JSON-like values
                       if value.startswith("{") and value.endswith("}"):
                          try:
                              extracted_data[key] = json.loads(value)
                          except json.JSONDecodeError:
                              extracted_data[key] = value  # Treat as raw string if invalid JSON
                       elif value.startswith("[") and value.endswith("]"):
                          try:
                              extracted_data[key] = json.loads(value)
                          except json.JSONDecodeError:
                              extracted_data[key] = value.split(",")  # Fallback: Treat as list of strings
                       else:
                            extracted_data[key] = value  # Treat as plain string or number

                 return extracted_data

          return parsed_dict
          
    def preprocess_image(self,image_path, max_resolution=(1024, 1024)):
      """
        Resize the image to fit within the max_resolution while maintaining aspect ratio.
      """
      img = Image.open(image_path).convert("RGB")
      original_width, original_height = img.size
      #print(f"Original Image Size: {original_width}x{original_height} pixels")

      if original_width > max_resolution[0] or original_height > max_resolution[1]:
         img.thumbnail(max_resolution, Image.Resampling.LANCZOS)     
      return img    

    def image_QA_processing(self, image_path, user_input):
        """
        Prepares the inputs for the model.

        Args:
            image_path (str): Path to the invoice image.
            prompt (str): Prompt guiding the model's behavior.

        Returns:
            dict: Prepared inputs ready for the model.
        """

        torch.cuda.empty_cache()
        processed_image = self.preprocess_image(image_path)
        messages = [
            {
                "role": "system",
                "content": """You are a highly capable Vision-Language Model (Qwen2-2B VLM) specializing in invoice data extraction. 
                Your task is to accurately answer user questions based on the content of an invoice image.
                -If value not found return "`No Data Present`"
                             
                  """
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": processed_image},
                    {"type": "text", "text": f"{user_input}"},
                ],
            },
        ]

        # Prepare the chat input and vision info
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt= False)
        image_inputs, video_inputs = process_vision_info(messages)

        # Prepare and return model inputs
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs.to(self.device)
        # Inference: Generation of the output
        generated_ids = self.model.generate(**inputs, max_new_tokens= 1200, temperature = 0.01,top_p=1.0 )
        generated_ids_trimmed = [
           out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
           ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
           )
        del inputs, generated_ids, generated_ids_trimmed
        torch.cuda.empty_cache()
        text_string = output_text[0]  # Access the string from the list
        try:
          if "system\n" in text_string:
             # If "system\n" is present, split and extract
             parts = text_string.split('\n')
             extracted_text = parts[1] if len(parts) > 1 else ""  # Handle cases with no text after "system\n"
             return extracted_text
          else:
            # If "system\n" is not present, return the original text
            extracted_text = text_string
            return extracted_text
        except:
            return output_text

# Gradio App

In [None]:
import gradio as gr
# ... (other imports and functions)

VLM_obj = InvoiceProcessor()
torch.cuda.empty_cache()

def parsing_image(image,data_state):
    torch.cuda.empty_cache()
    VLM_obj.process_image(image)
    data_state = VLM_obj.parse_json()
    #data = data_state
    
    return "Image successfully Processed!",data_state

def answer_question(image,question):
    torch.cuda.empty_cache()
    answer = VLM_obj.image_QA_processing(image,question)
    return answer

with gr.Blocks() as demo:
    with gr.Row():
        # Column 1: Image Upload Section
        with gr.Column():
            # ... (image upload components)
            gr.Markdown("### Upload an Image")
            image_input = gr.Image(label="Upload Image", type="filepath")
            image_output = gr.Textbox(label="Image Status")
            image_button = gr.Button("Process")
            data_state = gr.State([])  # Define data_state here
            image_button.click(parsing_image, 
                               inputs=[image_input, data_state], 
                               outputs=[image_output, data_state])

        # Column 2: Tabs Section
        with gr.Column():
            # ... (other components)
            with gr.Tabs():
                with gr.Tab("Fields"):
                    # ... (Fields tab code)
                    # Access and use data_state here, e.g.:
                    
                    # Example: Display the data in a JSON component
                    #json_display = gr.JSON(label="Extracted Data")
                    data_state.change(lambda x: x, inputs=data_state, outputs=None) 
                    
                    with gr.Row():
                       field_name = gr.Textbox(label="Field Name", autofocus=True)
                       field_value = gr.Textbox(label="Field Value")

                    def add_field(data_state, new_field_name, new_field_value):
                       if not new_field_name or not new_field_value:  # Check if either field is blank
                          gr.Warning("Field name or value cannot be blank")
                          return data_state, gr.update(value="", visible=True), gr.update(value="", visible=True) 
                       elif new_field_name in data_state:
                          gr.Warning("Field already existed")
                          return data_state, "", ""
                       else:
                          data_state[new_field_name] = new_field_value
                          return data_state, "", ""

                    add_field_btn = gr.Button("Add Field")
                    add_field_btn.click(add_field, [data_state, field_name, field_value], [data_state, field_name, field_value])

                    @gr.render(inputs=data_state)
                    def render_fields(data_dict):
                        gr.Markdown(f"### Fields ({len(data_dict)})")

                        for field_name, field_value in data_dict.items():
                           with gr.Row():
                              # Make Textboxes editable
                              name_textbox = gr.Textbox(field_name, show_label=False, container=False, label="Field Name", interactive=True,scale=2)
                              value_textbox = gr.Textbox(field_value, show_label=False, container=False, label="Field Value", interactive=True,scale=2)

                              
                             
                              delete_btn = gr.Button("Delete", scale=0, variant="stop")
                              update_btn = gr.Button("Update", scale=0, variant="primary")

                              def delete_field(data_state=data_dict, field_name=field_name):
                                  del data_dict[field_name]
                                  return data_dict

                              delete_btn.click(delete_field, None, [data_state])

                              def create_edit_handler(original_field_name, data_state=data_dict):
                                 def edit_field(new_name, new_value):
                                    if not new_name or not new_value:
                                       gr.Warning("Field Name or Value cannot be blank!")
                                       return gr.update(value=original_field_name),gr.update(value=data_dict[original_field_name]),data_dict
                                    elif new_name in data_dict and original_field_name != new_name:
                                       gr.Warning("Field Name already Exist!")
                                       return gr.update(value=original_field_name),gr.update(value=data_dict[original_field_name]),data_dict
                                    elif original_field_name in data_dict:
                                       if original_field_name != new_name or data_dict[original_field_name] != new_value:
                                           keys = list(data_dict.keys())
                                           values = list(data_dict.values())
                                           for i in range(len(keys)):
                                               if original_field_name == keys[i]:
                                                   keys[i] = new_name
                                                   values[i] = new_value
                                                   break

                                           # Update the dictionary in-place
                                           data_dict.clear()
                                           data_dict.update(dict(zip(keys, values)))

                                           # data_dict = dict(zip(keys, values))
                                           return gr.update(value=new_name), gr.update(value=new_value), data_dict
                                    
                                 return edit_field

                              # ... (Inside render_fields function)
                              update_btn.click(create_edit_handler(field_name), [name_textbox, value_textbox], [name_textbox,value_textbox,data_state])


                        return []

                    # Add a button to get the updated data
                    get_data_btn = gr.Button("Get Data")

                    def get_data(data_state):
                      return data_state  # Return the current data_state

                    get_data_btn.click(get_data, inputs=[data_state], outputs=[gr.JSON(label="Updated Data")])    

                with gr.Tab("Questions-Answer"):
                    # ... (Questions-Answer tab code)
                    # Access data_state if needed in this tab as well
                    # ...
                    gr.Markdown("#### Ask a Question")
                    question_input = gr.Textbox(label="Your Question")
                    answer_output = gr.Textbox(label="Answer")
                    question_button = gr.Button("Get Answer")
                    question_button.click(answer_question, inputs=[image_input,question_input], outputs=answer_output)

demo.launch(debug=True)

* Running on local URL:  http://127.0.0.1:7860
Kaggle notebooks require sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

* Running on public URL: https://bac15a14419e67441d.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/gradio/queueing.py", line 625, in process_events
    response = await route_utils.call_process_api(
  File "/usr/local/lib/python3.10/dist-packages/gradio/route_utils.py", line 322, in call_process_api
    output = await app.get_blocks().process_api(
  File "/usr/local/lib/python3.10/dist-packages/gradio/blocks.py", line 2044, in process_api
    result = await self.call_function(
  File "/usr/local/lib/python3.10/dist-packages/gradio/blocks.py", line 1591, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
  File "/usr/local/lib/python3.10/dist-packages/anyio/to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread
    return await future
  File "/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py", line 8

JSONDecodeError: Falling back to regex-based parsing.


    Output components:
        []
    Output values returned:
        [{'Invoice Number': '783944', 'Invoice Date': '11/5/2019', 'Due Date': '12/5/2019', 'Business Name': 'Bag of Beans Cafe & Restaurant Inc.', 'Business Address': '117 Aguinaldo Highway Crossing Mendez', 'Business Mail': 'West Tagaytay City Cavite 4120', 'Business Phone Number': '008-117-738-000', 'Business GSTIN': 'IN 008 117 738 000', 'Business PAN': 'IN 008 117 738 000', 'Customer Name': 'Guest 3', 'Customer Address': '117 Aguinaldo Highway Crossing Mendez', 'Customer PAN': 'IN 008 117 738 000', 'SubTotal': '610', 'PreTax': '544', 'Serv Charge (10%)': '54', '12% VAT': '65', 'Total Tax Amount': '0', 'Tax Rate': '0', 'Amount Due': '664'}]
