# Document Understanding with LLaVA-NeXT and Structured Generation

In [1]:
import os
import json
import base64
import requests
from io import BytesIO

import numpy as np
import pandas as pd
import huggingface_hub
from PIL import Image, ImageOps
from PIL.Image import Image as PILImage
from transformers import LlavaNextProcessor
from transformers.image_processing_utils import select_best_resolution

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Prompt templates filled in by run_inference: <json_format> receives the JSON
# schema properties of the extraction tool; <system_prompt>, <image> and <user>
# receive the system prompt, the inline image and the question text.
SYSTEM_PROMPT_FORMAT = """You are a document information extractor. You will receive an image and you must answer the user's question from the data in the image. Be exact, concise, and don't yap. Sample final answers: "INV392834", "Jollibee", "05/11/2024". Output in the following json format: <json_format>."""
# A real newline separates the image from the question; an escaped "\\n" here
# would put a literal backslash followed by "n" into the prompt.
PROMPT_FORMAT = "<system_prompt> USER: <image>\n<user> ASSISTANT: "

In [3]:
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
# Take a copy of the candidate grid resolutions: the list is extended in place
# afterwards, and without a copy that would silently mutate the processor's own
# image_grid_pinpoints configuration.
possible_resolutions = list(processor.image_processor.image_grid_pinpoints)

In [4]:
# Add larger candidate resolutions (every side a multiple of 336), presumably
# so that bigger document pages keep more detail when resized.
possible_resolutions.extend(
    [[672, 1008], [1008, 672], [1008, 1008], [1008, 1344], [1344, 1008]]
)

In [13]:
def resize_and_pad_image(image: PILImage) -> PILImage:
    """Letterbox ``image`` onto the best-matching candidate resolution.

    The target size is chosen from ``possible_resolutions`` with transformers'
    ``select_best_resolution``. The image is then scaled (aspect ratio
    preserved) to fit inside that size, and the remaining area is filled with
    white (alpha 0 for images that carry an alpha channel).

    Args:
        image: the PIL image to resize and pad.

    Returns:
        A new PIL image whose size equals the selected resolution.
    """
    # NOTE(review): transformers documents select_best_resolution as taking and
    # returning (height, width), whereas PIL's image.size is (width, height).
    # Both the input and the result are used in PIL order here, and the added
    # candidates come in both orientations, so this is presumably harmless --
    # confirm before relying on orientation-sensitive tie-breaking.
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    print(f"{best_resolution = }")
    # ImageOps.pad already scales the image (up or down) to fit inside the
    # target before padding. A separate resize pass beforehand would only add a
    # second, lossy resampling step.
    return ImageOps.pad(
        image,
        best_resolution,
        method=processor.image_processor.resample,
        color=(255, 255, 255, 0),
    )


def encode_local_image(image_path, resize_and_pad: bool=True):
    """Load an image from disk and return it as a base64 PNG ``data:`` URL.

    Args:
        image_path: path of the image file to load.
        resize_and_pad: when True, letterbox the image with
            ``resize_and_pad_image`` before encoding.

    Returns:
        A string of the form ``"data:image/png;base64,<payload>"``, the
        format required by the endpoint.
    """
    with Image.open(image_path) as source:
        image = source
        # Normalise the pixel mode. GIFs (detected from the decoded file
        # format rather than a substring of the path) and any mode other than
        # RGB/RGBA -- e.g. palette "P", grayscale "L" or "CMYK" -- become RGB.
        # This keeps the 4-tuple pad colour used by resize_and_pad_image valid
        # for the image mode and avoids modes that cannot be written as PNG.
        if source.format == "GIF" or source.mode not in ("RGB", "RGBA"):
            image = source.convert("RGB")
        if resize_and_pad:
            image = resize_and_pad_image(image)
            print(f"New size: {image.size}")

        # Encode while the file is still open: when no conversion or resize
        # happened, ``image`` is still the lazily-loaded, file-backed object.
        buffer = BytesIO()
        image.save(buffer, format="PNG")

    base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{base64_image}"

In [14]:
def extract_key_from_question(question: str) -> str:
    """Turn a templated dataset question into a snake_case field name.

    The question must be an ``<image>`` tag, a newline, then
    ``What is the <field> in the image?``. The ``<field>`` part is extracted,
    spaces become underscores, and the abbreviations ``no`` -> ``number`` and
    ``$`` -> ``dollars`` are expanded when they follow an underscore.

    Args:
        question: the raw question text from the dataset.

    Returns:
        The field name, e.g. ``"invoice_number"``. May be empty.

    Raises:
        ValueError: if ``question`` does not follow the expected template.
    """
    import re

    prefix = "<image>\nWhat is the "
    suffix = " in the image?"
    # Squash runs of spaces into one so doubled spaces neither break the
    # template match nor leak into the key. The newline after "<image>" is
    # left untouched.
    question = re.sub(r" {2,}", " ", question)
    if not question.startswith(prefix) or not question.endswith(suffix):
        raise ValueError(f"Unexpected question format: {question!r}")

    key = question[len(prefix):-len(suffix)]
    # Strip a leading "[" or "‘" and a trailing "?" from the extracted field.
    if key.startswith(("[", "‘")):
        key = key[1:]
    if key.endswith("?"):
        key = key[:-1]

    key = key.replace(" ", "_")
    # Expand "no" only when it is a whole "_"-separated token, so that e.g.
    # "delivery_note" is not mangled into "delivery_numberte".
    key = re.sub(r"_no(?![A-Za-z0-9])", "_number", key)
    return key.replace("_$", "_dollars")

In [15]:
def build_doc_extraction_tool(key: str, max_length: int=100):
    """Build a function-tool spec whose JSON schema asks for reasoning + one answer.

    The schema has two required properties: ``"1_reasoning"`` (a string) and
    ``"2_<key>"`` (the answer). The numeric prefixes fix the property order
    when the generation backend re-orders keys alphabetically.

    Args:
        key: field name, as produced by ``extract_key_from_question``.
        max_length: maximum number of characters for a string answer. It is not
            applied to the integer answer used for ``key == "page"``, because
            JSON Schema's ``maxLength`` keyword only constrains strings.

    Returns:
        A dict in the ``{"type": "function", "function": {...}}`` tool format.
        The JSON schema is under ``["function"]["parameters"]``.
    """
    answer_schema = {
        "type": "integer" if key == "page" else "string",
        "description": "The answer, exactly as it appears in the document.",
    }
    if answer_schema["type"] == "string":
        answer_schema["maxLength"] = max_length

    answer_field = f"2_{key}"
    return {
        "type": "function",
        "function": {
            "name": "doc_extraction_tool",
            "description": "Extract information from a document",
            "parameters": {
                "type": "object",
                "properties": {
                    "1_reasoning": {
                        "type": "string",
                    },
                    answer_field: answer_schema,
                },
                "required": ["1_reasoning", answer_field],
            },
        },
    }

In [16]:
API_URL = "https://bkliyhzstf7g5dyz.us-east-1.aws.endpoints.huggingface.cloud"
headers = {
    "Accept": "application/json",
    "Authorization": f"Bearer {huggingface_hub.get_token()}",
    "Content-Type": "application/json",
}


def query(payload, timeout=600):
    """POST ``payload`` as JSON to the inference endpoint and return the decoded body.

    Args:
        payload: JSON-serialisable request body.
        timeout: seconds passed to ``requests`` as the connect/read timeout, so
            a stalled endpoint cannot hang the caller indefinitely.

    Returns:
        The endpoint's JSON response, decoded.

    Raises:
        requests.HTTPError: if the endpoint answers with a 4xx/5xx status. The
            message includes the response body, which usually explains the error.
    """
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    if not response.ok:
        raise requests.HTTPError(
            f"{response.status_code} error from endpoint: {response.text}",
            response=response,
        )
    return response.json()

In [17]:
def run_inference(image_path, question, seed=0, max_length: int=100):
    """Answer one templated extraction question about one local image.

    The question is turned into a field name, and a JSON schema for
    ``{"1_reasoning": ..., "2_<field>": ...}`` is built. The endpoint is then
    asked to generate JSON constrained by that schema.

    Args:
        image_path: path to the document image on disk.
        question: dataset question: ``<image>``, a newline, then
            ``What is the <field> in the image?``.
        seed: forwarded to the endpoint as the ``seed`` generation parameter.
        max_length: maximum length of a string answer (passed to
            ``build_doc_extraction_tool``).

    Returns:
        The value generated for the ``2_<field>`` key, or the string
        ``"This question is unanswerable."`` when the extracted field is blank.

    Raises:
        RuntimeError: if the endpoint response is not a non-empty list.
        Exceptions from key extraction, image encoding, the HTTP request, or
        JSON decoding of the generated text propagate unchanged.
    """
    key = extract_key_from_question(question)
    print(f"{key = }")
    if key.strip() == "":
        return "This question is unanswerable."

    # Encode the image only once the question is known to be answerable.
    image_base64 = encode_local_image(image_path)

    tool = build_doc_extraction_tool(key, max_length)
    schema = tool["function"]["parameters"]

    system_prompt = SYSTEM_PROMPT_FORMAT.replace(
        "<json_format>",
        json.dumps(schema["properties"]),
    )
    prompt = (
        PROMPT_FORMAT
        .replace("<system_prompt>", system_prompt)
        .replace("<image>", f"![]({image_base64})")
        .replace("<user>", question[len("<image>\n"):])
    )

    # This version of TGI uses an older version of Outlines, which re-orders
    # the keys in the JSON in alphabetical order; hence the numeric prefixes
    # on the keys in the grammar.
    response = query({
        "inputs": prompt,
        "parameters": {
            "return_full_text": False,
            "max_new_tokens": 2048,
            "top_p": 0.95,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            # Forward the caller's seed. Without it the backend picks its own,
            # and the ``seed`` argument would have no effect at all.
            "seed": seed,
            "grammar": {
                "type": "json",
                "value": schema,
            },
        },
    })
    print(f"{response = }")

    # Anything other than a non-empty list (presumably an error object from
    # the endpoint) is reported explicitly instead of failing on response[0].
    if not isinstance(response, list) or not response:
        raise RuntimeError(f"Unexpected response from endpoint: {response!r}")

    return json.loads(response[0]["generated_text"])[f"2_{key}"]

In [18]:
# Sample document image used by the single-image preview and inference cells.
image_path = "data/raw_datasets/mydoc/images/56d7d0711831b8fda7e7c7f272d407d2dd0fd4e578090c1d74761089733d6813.png"

In [None]:
# Preview the resized-and-padded version of the sample image. The source file
# is opened in a context manager so its handle is closed; resize_and_pad_image
# returns a new, independent image.
with Image.open(image_path) as source_image:
    image = resize_and_pad_image(source_image)
image

In [None]:
# Smoke test: ask a single question about the sample image, with string
# answers capped at 30 characters.
run_inference(
    image_path=image_path,
    question="<image>\nWhat is the credit status in the image?",
    max_length=30,
)

In [None]:
# Load the question annotations for the evaluation dataset.
dataset_name = "mydoc"
dataset_path = os.path.join("data/raw_datasets", dataset_name, "annot_wo_answer.json")
print(dataset_path)
# Raise explicitly rather than assert: asserts are stripped under `python -O`.
if not os.path.exists(dataset_path):
    raise FileNotFoundError(f"Dataset annotations not found: {dataset_path}")

df_data = pd.read_json(dataset_path)

In [22]:
# Create the per-row answer cache directory. This is portable Python rather
# than an IPython shell escape, and is a no-op if the directory already exists.
os.makedirs("inference_results/llava-1-6-vicuna-13b-hf", exist_ok=True)

In [23]:
# Index labels of rows whose inference failed or produced an unusable answer.
# It is initialised in its own cell, so re-running the loops that add to it
# accumulates into the same set instead of resetting it.
failed_idx = set()

In [None]:
# Run inference for every row that has no cached answer yet. Answers are cached
# as one text file per row id, so the loop can be interrupted and resumed.
# Failures are recorded in failed_idx, which is not reset here.
for idx, row in df_data.iterrows():
    row_id = row["id"]  # not named `id`, to avoid shadowing the builtin
    answer_txt_path = f"inference_results/llava-1-6-vicuna-13b-hf/{row_id}.txt"

    image_path = f"data/raw_datasets/{dataset_name}/images/{row['image']}"
    question = row["conversations"][0]["value"]

    if os.path.exists(answer_txt_path):
        continue
    print(image_path)
    print(row_id, idx, question)

    try:
        answer = run_inference(image_path, question, seed=7283703)
        print(f"{answer = }")
        # Explicit UTF-8 so the cache files do not depend on the platform's
        # default encoding.
        with open(answer_txt_path, "w", encoding="utf-8") as f:
            f.write(str(answer))
    except Exception as e:
        print(">>>>>>> ERROR", idx, row, e, "<<<<<<<")
        failed_idx.add(idx)
    print("---------")

In [None]:
# Re-scan the answer cache and flag rows for a retry when the answer file is
# missing, the cached answer is empty, or it is suspiciously long.
for idx, row in df_data.iterrows():
    row_id = row["id"]  # not named `id`, to avoid shadowing the builtin
    answer_txt_path = f"inference_results/llava-1-6-vicuna-13b-hf/{row_id}.txt"
    if not os.path.exists(answer_txt_path):
        failed_idx.add(idx)
        continue

    with open(answer_txt_path, "r", encoding="utf-8") as f:
        answer = f.read()

    # Empty answers (e.g. a file created but never written) and answers of
    # 50+ characters are treated as failed extractions.
    if not answer.strip() or len(answer) >= 50:
        failed_idx.add(idx)

In [None]:
# Display the index labels of rows flagged as failed so far.
failed_idx

In [None]:
# Show the flagged rows. failed_idx holds index labels (from iterrows), so
# select with label-based .loc rather than positional .iloc.
df_data.loc[sorted(failed_idx)]

In [None]:
# Retry every flagged row once with a different seed, replacing any cached
# answer. Rows that fail again are collected in failed_idx_2.
failed_idx_2 = set()
# failed_idx holds index labels, so select with .loc (not positional .iloc).
for idx, row in df_data.loc[sorted(failed_idx)].iterrows():
    row_id = row["id"]  # not named `id`, to avoid shadowing the builtin
    answer_txt_path = f"inference_results/llava-1-6-vicuna-13b-hf/{row_id}.txt"
    if os.path.exists(answer_txt_path):
        os.remove(answer_txt_path)
    print(row_id, idx)

    image_path = f"data/raw_datasets/{dataset_name}/images/{row['image']}"
    question = row["conversations"][0]["value"]

    try:
        answer = run_inference(image_path, question, seed=42)
        # str(): the answer can be an int (the "page" field is integer-typed).
        # Writing it directly would raise after the file was already opened,
        # leaving an empty cache file that later looks like a valid answer.
        with open(answer_txt_path, "w", encoding="utf-8") as f:
            f.write(str(answer))
    except Exception as e:
        print(idx, row, e)
        failed_idx_2.add(idx)

In [None]:
# Display the index labels of rows that still failed after the retry pass.
failed_idx_2

In [None]:
# Inspect the image file name of the row at position 220. The position is
# hard-coded; presumably it is a row of interest from the failure sets above.
# Confirm before reusing.
df_data.iloc[220]["image"]

In [None]:
# Inspect the question/conversation entries of the row at position 220. The
# position is hard-coded and presumably matches the cell above. Confirm.
df_data.iloc[220]["conversations"]