This notebook is used for benchmarking purposes; it is not an optimized pipeline for LLM extraction.

This is a two-stage LLM pipeline: we first extract the key information as free text, and then transform it into JSON format. Splitting the task into these two stages improves extraction performance.

In [None]:
# pip install accelerate
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image
import requests
import torch

model_id = "google/gemma-3-4b-it"

# Load the Gemma 3 checkpoint; device_map="auto" delegates weight placement to
# accelerate. The model is only used for generation, so switch it to eval mode
# (Module.eval() returns the same object, so `model` is unchanged otherwise).
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
model.eval()

# The processor provides the chat template, tokenization and decoding.
processor = AutoProcessor.from_pretrained(model_id)


In [None]:
import os
from pymongo import UpdateOne
from tqdm import tqdm
from pymongo import MongoClient
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime

client = MongoClient(host="localhost", port=29012)
db = client["test-database"]

# Collection handles used below: `collection_json` and `collection_txt` are
# sampled by the benchmark loops, `collection_CNN` is the per-page OCR source
# that `get_text` aggregates.
collection_json, collection_txt, collection_CNN = (
    db["collection-json"],
    db["collection-txt2"],
    db["CNN_DE"],
)

def get_text(publication_number, collection=None):
    """Fetch and concatenate the per-page OCR text of one publication.

    The raw publication number is reformatted to the dotted form that is
    matched in the collection (first two characters, middle part, last
    character; e.g. ``"DE123456"`` -> ``"DE.12345.6"``). All matching documents
    are grouped, and their ``OCR`` fields are joined with single spaces.

    Args:
        publication_number: Publication number without separators.
        collection: Collection to aggregate over. Defaults to the module-level
            ``collection_CNN``. Exposed so other collections (or a stub) can be
            used without changing existing callers.

    Returns:
        A dict with keys ``"Publication Number"``, ``"pages"`` and ``"OCR"``,
        or ``None`` when nothing matches or no page carries non-empty OCR text.
    """
    if collection is None:
        collection = collection_CNN
    formatted = f"{publication_number[:2]}.{publication_number[2:-1]}.{publication_number[-1:]}"
    pipeline = [
        {"$match": {"Publication Number": formatted}},  # only this publication's documents
        # $push appends in processing order, which is unspecified without an
        # explicit sort; sort first so the joined text follows page order.
        # NOTE(review): if `page` is stored as a string this is lexicographic
        # order ("10" < "2") -- confirm the field type.
        {"$sort": {"page": 1}},
        {"$group": {
            "_id": "$Publication Number",
            "pages": {"$push": "$page"},   # page identifiers, in sorted order
            "text": {"$push": "$OCR"},     # OCR text, aligned with `pages`
        }},
    ]

    result_db = list(collection.aggregate(pipeline))
    if not result_db:
        return None
    group = result_db[0]

    # Skip null/empty OCR values and decode bytes. A single None or bytes
    # element would otherwise make str.join raise. generate_query also accepts
    # both str and bytes.
    texts = [
        text.decode("utf-8") if isinstance(text, bytes) else text
        for text in group["text"]
        if text
    ]
    if not texts:
        return None
    return {
        "Publication Number": group["_id"],
        "pages": group["pages"],
        "OCR": " ".join(texts),
    }

def generate_query(item):
    """Return the item's OCR text on a single line, with outer whitespace trimmed.

    Args:
        item: dict holding an ``OCR`` entry of type str or bytes (bytes are
            decoded as UTF-8).

    Returns:
        The OCR text with newlines replaced by spaces, stripped at both ends.

    Raises:
        ValueError: ``item`` is not a dict (the offending value is printed first).
        KeyError: ``item`` has no ``OCR`` key.
        TypeError: ``OCR`` is neither str nor bytes.
    """
    if not isinstance(item, dict):
        print(item)
        raise ValueError("Expected 'item' to be a dictionary.")
    if 'OCR' not in item:
        raise KeyError("'OCR' key is missing in the item dictionary.")

    raw_text = item['OCR']
    if isinstance(raw_text, bytes):
        raw_text = raw_text.decode('utf-8')
    elif not isinstance(raw_text, str):
        raise TypeError("'OCR' must be either a string or bytes.")

    return raw_text.replace('\n', ' ').strip()



In [None]:
import json
from typing import List, Optional
from pydantic import BaseModel
# NOTE: this rebinds `datetime` (imported earlier as the class) to the module;
# later cells rely on the `datetime.datetime` spelling.
import datetime


class Patent(BaseModel):
    """Target schema for the second (JSON-formatting) LLM stage.

    Its JSON schema is embedded in ``prompt`` below; the field names are the
    keys the model is asked to produce.
    """

    title: str
    Application_Date: Optional[datetime.datetime]
    Publication_Date: Optional[datetime.datetime]
    KeyEntity: list[str]  # per the prompt: names of applicants, inventors and assignees

    class Config:
        # Forbid unknown keys. This is also what emits
        # "additionalProperties": false in the generated JSON schema, in both
        # pydantic v1 and v2.
        # The previous `schema_extra` dict did not achieve this: pydantic v2
        # renamed the key to `json_schema_extra` and ignores the old name (with a
        # warning). In v1, its nested `json_encoders` entry (a type used as dict
        # key) made `schema_json()` fail to serialize.
        # datetimes serialize to ISO 8601 by default, so no custom encoder is needed.
        extra = "forbid"


# Serialize the model's JSON schema for the prompt: pydantic v2 API, with a
# fallback for v1 (same compact json.dumps output in both).
if hasattr(Patent, "model_json_schema"):
    pydantic_schema = json.dumps(Patent.model_json_schema())
else:
    pydantic_schema = Patent.schema_json()

prompt = f"You are a helpful assistant that transform historical scans of patents to a JSON format. Make sure to get the dates and names correct and only include keys 'title', 'Application_Date', 'Publication_Date', 'KeyEntity' with the names of Applicants, Inventors and Assignees. Take a second to think about your answer. Here's the json schema you must adhere to:\n{pydantic_schema}\n"

In [None]:
# Pick one US document that has OCR text and every ground-truth field, for
# trying the pipeline interactively.
item = collection_json.find_one({
    'Country': "US",
    **{
        field: {'$exists': True}
        for field in (
            'OCR',
            'Title',
            'C_Application Date',
            'C_Publication Date',
            'clean_applicants',
            'clean_inventor',
        )
    },
})

In [6]:
def _chat_generate(messages, max_new_tokens):
    """Run one greedy chat completion with the global `model` / `processor`.

    Args:
        messages: chat messages in the processor's chat-template format.
        max_new_tokens: generation budget.

    Returns:
        The decoded text of the newly generated tokens only (prompt sliced off).
    """
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt"
    ).to(model.device, dtype=torch.bfloat16)

    input_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        generation = generation[0][input_len:]  # keep only the generated continuation

    decoded = processor.decode(generation, skip_special_tokens=True)
    # Drop the device tensors immediately so memory can be reused by the next call.
    del inputs
    del generation
    return decoded


def first_stage(item, max_chars=40000, head_chars=34000, tail_chars=1000, max_new_tokens=150):
    """Stage 1: have the model extract the key patent fields as free text.

    Args:
        item: dict with an ``OCR`` entry (str or bytes), as accepted by ``generate_query``.
        max_chars: OCR texts longer than this many characters are truncated.
        head_chars: leading characters kept when truncating.
        tail_chars: trailing characters kept when truncating (joined to the
            head with a single space).
        max_new_tokens: generation budget for the answer.

    Returns:
        The decoded model answer.
    """
    og_query = generate_query(item)
    # The reported length is that of the full, untruncated OCR text.
    print(f"Length of query is {len(og_query)}")
    if len(og_query) > max_chars:
        # Keep the beginning and the end of the document. Guard tail_chars == 0,
        # because og_query[-0:] would return the whole string.
        tail = og_query[-tail_chars:] if tail_chars > 0 else ""
        og_query = og_query[:head_chars] + " " + tail
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "Identify the patent title, application date, publication date, applicants, inventors and assignee as key entity (people or company) from the text. If data is missing, interpret it based on the information you have, values can be None. Please don't add any comments or explanations."}]
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": og_query},
            ]
        }
    ]
    return _chat_generate(messages, max_new_tokens)


def second_stage(decoded, prompt, max_new_tokens=300):
    """Stage 2: reformat the stage-1 answer into JSON following ``prompt``.

    Args:
        decoded: free-text output of ``first_stage``.
        prompt: system prompt (contains the target JSON schema).
        max_new_tokens: generation budget for the JSON answer.

    Returns:
        The decoded model answer (JSON text, possibly malformed).
    """
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": prompt}]
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": decoded},
            ]
        }
    ]
    return _chat_generate(messages, max_new_tokens)

## German patents (txt collection, OCR text from `CNN_DE`)

In [None]:
from tqdm import tqdm
import json_repair
import json
import os
import random

query = {'Country': "DE"}
training_data = []
total_count = collection_txt.count_documents(query)
# Private, seeded RNG: every model benchmarked with this notebook is evaluated
# on the same sample, and the global `random` state is left untouched.
sample_rng = random.Random(42)
# random.sample raises ValueError when asked for more items than exist.
random_indexes = sample_rng.sample(range(total_count), min(500, total_count))
out_dir = "/scratch/students/ndillenb/metadata/processing/llm/json_compare/de2_gemma-3-4b-it_json_compare"
os.makedirs(out_dir, exist_ok=True)
for doc_index in tqdm(random_indexes):
    # Without an explicit sort, the document reached by a skip offset is not
    # guaranteed to be stable between runs; sorting by _id pins it.
    item = collection_txt.find_one(query, skip=doc_index, sort=[('_id', 1)])
    if item is None:  # e.g. documents removed since count_documents
        continue
    publication_number = item['Publication Number']
    json_ocr = get_text(publication_number)
    if json_ocr is None:
        print(f"Publication Number {item['Publication Number']} not found")
        continue
    summary = first_stage(json_ocr)
    print("---- summary ----")
    print(summary)
    json_llm = second_stage(summary, prompt)
    print("---- json_llm ----")
    print(json_llm)
    json_llm = json_repair.loads(json_llm)
    # json_repair can return a non-dict (list/str) for badly broken output;
    # key remapping only makes sense for a dict.
    if isinstance(json_llm, dict):
        # The schema asks for 'KeyEntity', but the model sometimes answers with
        # role-specific keys. Fold them into 'KeyPeople', merging all roles
        # present instead of letting each one overwrite the previous.
        role_values = [json_llm.pop(key) for key in ('Applicant', 'Assignee', 'Inventor') if key in json_llm]
        if len(role_values) == 1:
            json_llm['KeyPeople'] = role_values[0]
        elif role_values:
            merged = []
            for value in role_values:
                if isinstance(value, list):
                    merged.extend(value)
                elif value is not None:
                    merged.append(value)
            json_llm['KeyPeople'] = merged
    # Run the model evaluator
    with open(os.path.join(out_dir, f"json_llm_{item['_id']}.json"), "w") as f:
        json.dump({'predicted': json_llm}, f)

## French patents (txt collection)

In [None]:
from tqdm import tqdm
import json_repair
import json
import os
import random

query = {'Country': "FR"}
training_data = []
total_count = collection_txt.count_documents(query)
# Private, seeded RNG: every model benchmarked with this notebook is evaluated
# on the same sample, and the global `random` state is left untouched.
sample_rng = random.Random(42)
# random.sample raises ValueError when asked for more items than exist.
random_indexes = sample_rng.sample(range(total_count), min(500, total_count))
out_dir = "/scratch/students/ndillenb/metadata/processing/llm/json_compare/fr_gemma-3-4b-it_json_compare"
os.makedirs(out_dir, exist_ok=True)
for doc_index in tqdm(random_indexes):
    # Without an explicit sort, the document reached by a skip offset is not
    # guaranteed to be stable between runs; sorting by _id pins it.
    item = collection_txt.find_one(query, skip=doc_index, sort=[('_id', 1)])
    if item is None:  # e.g. documents removed since count_documents
        continue
    # French documents carry their OCR text directly (no get_text lookup).
    if not item.get('OCR'):
        print(f"No OCR for {item['Publication Number']}")
        continue
    summary = first_stage(item)
    print("---- summary ----")
    print(summary)
    json_llm = second_stage(summary, prompt)
    print("---- json_llm ----")
    print(json_llm)
    json_llm = json_repair.loads(json_llm)
    # json_repair can return a non-dict (list/str) for badly broken output;
    # key remapping only makes sense for a dict.
    if isinstance(json_llm, dict):
        # The schema asks for 'KeyEntity', but the model sometimes answers with
        # role-specific keys. Fold them into 'KeyPeople', merging all roles
        # present instead of letting each one overwrite the previous.
        role_values = [json_llm.pop(key) for key in ('Applicant', 'Assignee', 'Inventor') if key in json_llm]
        if len(role_values) == 1:
            json_llm['KeyPeople'] = role_values[0]
        elif role_values:
            merged = []
            for value in role_values:
                if isinstance(value, list):
                    merged.extend(value)
                elif value is not None:
                    merged.append(value)
            json_llm['KeyPeople'] = merged
    # Run the model evaluator
    with open(os.path.join(out_dir, f"json_llm_{item['_id']}.json"), "w") as f:
        json.dump({'predicted': json_llm}, f)

## US patents (json collection, compared against ground-truth metadata)

In [None]:
from tqdm import tqdm
import json_repair
import json
import os
# Module import so `datetime.date` / `datetime.datetime` resolve even when the
# schema cell (which also rebinds `datetime` to the module) was not run first.
import datetime

json_filter = {
    'Country': "US",
    'OCR': {'$exists': True},
    'Title': {'$exists': True},
    'C_Application Date': {'$exists': True},
    'C_Publication Date': {'$exists': True},
    'clean_applicants': {'$exists': True},
    'clean_inventor': {'$exists': True},
}
out_dir = "/scratch/students/ndillenb/metadata/processing/llm/json_compare/gemma-3-4b-it_json_compare"
os.makedirs(out_dir, exist_ok=True)
# Sort by _id so limit(100) selects the same subset on every run and for every model.
for item in tqdm(list(collection_json.find(json_filter).sort('_id', 1).limit(100))):
    # $exists also matches explicit nulls; generate_query would raise on those.
    if not item.get('OCR'):
        print(f"No OCR for {item['_id']}")
        continue
    summary = first_stage(item)
    print("---- summary ----")
    print(summary)
    json_llm = second_stage(summary, prompt)
    print("---- json_llm ----")
    print(json_llm)
    json_llm = json_repair.loads(json_llm)
    data = {'Title': item['Title'], 'Application_Date': item['C_Application Date'], 'Publication_Date': item['C_Publication Date'], 'Applicants': item['clean_applicants'], 'Inventors': item['clean_inventor']}
    # JSON cannot encode dates; convert date and datetime values (datetime is
    # a subclass of date) to ISO-8601 strings.
    data_serializable = {
        key: (value.isoformat() if isinstance(value, datetime.date) else value)
        for key, value in data.items()
    }
    # Run the model evaluator
    with open(os.path.join(out_dir, f"json_llm_{item['_id']}.json"), "w") as f:
        json.dump({'predicted': json_llm, 'expected': data_serializable}, f)