In [None]:
%%capture
!pip install gradio PyMuPDF
!pip install -U transformers

In [None]:
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoTokenizer, AutoProcessor, AutoModelForImageTextToText
import torch
import fitz  # PyMuPDF for PDF handling
import os
from bs4 import BeautifulSoup
import pandas as pd
import io
import re
import json
import xml.etree.ElementTree as ET
from xml.dom import minidom
import gc
import time
from datetime import datetime, timedelta
import hashlib
import uuid

# --- Model Loading ---
# Hugging Face model id; the same id is used below for the model, the
# tokenizer and the processor.
model_path = "nanonets/Nanonets-OCR-s"

# Only used as a label in the log messages below; actual placement of the
# weights is decided by the device_map argument of each from_pretrained call.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Set environment variable for better memory management
# NOTE(review): torch is already imported when this is set — presumably the
# CUDA allocator reads it lazily on first use; confirm it takes effect here.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Loading strategy, in order:
#   1. CUDA available -> float16 + 8-bit weights, device_map="auto";
#      otherwise     -> float32 on CPU.
#   2. If step 1 raises -> float16 with device_map="auto", no quantization.
#   3. If step 2 raises -> float32 on CPU.
try:
    if torch.cuda.is_available():
        print(f"[{device}] CUDA is available. Loading model with aggressive optimizations.")
        
        # Load with 8-bit quantization to reduce VRAM use (the "~12GB to ~7GB"
        # figure is the original author's estimate and is not verified here).
        # NOTE(review): load_in_8bit presumably requires the bitsandbytes
        # package, which the install cell above does not install — confirm.
        model = AutoModelForImageTextToText.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
            attn_implementation="eager",
            low_cpu_mem_usage=True,
            load_in_8bit=True
        )
        torch.cuda.empty_cache()
        print(f"Model loaded. VRAM used: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
    else:
        print(f"[{device}] CUDA not available. Loading model on CPU.")
        model = AutoModelForImageTextToText.from_pretrained(
            model_path,
            torch_dtype=torch.float32,
            device_map="cpu",
        )
except Exception as e:
    # NOTE: this handler also catches failures of the CPU branch above, so the
    # "load_in_8bit failed" wording in the message can be misleading.
    print(f"[{device}] Warning: load_in_8bit failed, trying without quantization. Error: {e}")
    try:
        model = AutoModelForImageTextToText.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True
        )
    except Exception as e2:
        # Last resort: full-precision weights on CPU. A failure here is not
        # caught and propagates out of the cell.
        print(f"[{device}] Falling back to CPU. Error: {e2}")
        model = AutoModelForImageTextToText.from_pretrained(
            model_path,
            torch_dtype=torch.float32,
            device_map="cpu",
        )

# Switch the model to evaluation mode; it is only used for generation below.
model.eval()

# The processor builds the chat prompt and decodes generations in
# process_single_image; the tokenizer is not referenced by the functions
# visible in this part of the file.
tokenizer = AutoTokenizer.from_pretrained(model_path)
processor = AutoProcessor.from_pretrained(model_path)


# --- Predefined API Fields ---
# Field labels offered in the "Predefined Fields" checkbox group of the UI,
# where the full list is also the default selection. Each label is matched
# literally (case-insensitively) against the OCR text by extract_api_fields.
# NOTE(review): the list holds 52 entries while the checkbox group label says
# "50 Fields"; "PO Number" and "Purchase Order Number" are separate labels.
PREDEFINED_FIELDS = [
    "Company Name", "Company Address", "Company Phone", "Company Email", "Company Website",
    "Invoice Number", "Invoice Date", "Due Date", "PO Number", "Reference Number",
    "Bill To Name", "Bill To Address", "Bill To Phone", "Bill To Email",
    "Ship To Name", "Ship To Address", "Ship To Phone", "Ship To Email",
    "Subtotal", "Tax Amount", "Tax Rate", "Discount Amount", "Shipping Cost",
    "Total Amount", "Amount Paid", "Amount Due", "Currency",
    "Payment Terms", "Payment Method", "Bank Name", "Account Number", "SWIFT Code",
    "Item Description", "Item Quantity", "Item Unit Price", "Item Total",
    "Customer ID", "Vendor ID", "Department", "Project Code",
    "Notes", "Terms and Conditions", "Signature", "Date of Signature",
    "Sales Person", "Customer Service Rep", "Approval Status", "Document Type",
    "Purchase Order Number", "Contract Number", "License Number", "Registration Number"
]


def clear_memory():
    """Free unreferenced Python objects, then release cached CUDA memory.

    The garbage collector runs first so that tensors kept alive only by
    reference cycles are destroyed and their blocks are returned to the
    CUDA caching allocator *before* ``empty_cache()`` hands unused blocks
    back to the driver. On machines without CUDA only the garbage
    collection step runs.
    """
    gc.collect()
    if torch.cuda.is_available():
        # Wait for queued kernels so the blocks they use are really idle
        # by the time the cache is emptied.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()


def process_single_image(image: Image.Image, prompt: str, max_new_tokens: int) -> str:
    """Run one image plus a text prompt through the model and return the text.

    The image and prompt are wrapped in a system/user chat, rendered with the
    processor's chat template, and decoded greedily (``do_sample=False``).
    Only the tokens generated after the prompt are decoded.

    Failures are not raised: a CUDA out-of-memory error or any other
    exception is turned into a human-readable message that is returned in
    place of the OCR text. Memory is cleared on every exit path.
    """
    try:
        user_content = [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ]
        conversation = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_content},
        ]

        # Render the chat to a prompt string; tokenization happens in the
        # processor call below together with the image.
        chat_prompt = processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        model_inputs = processor(
            text=[chat_prompt],
            images=[image],
            padding=True,
            return_tensors="pt",
        )
        # Move every tensor to the device the model lives on.
        model_inputs = {name: tensor.to(model.device) for name, tensor in model_inputs.items()}

        with torch.no_grad():
            sequences = model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                use_cache=True,
            )

        # generate() returns prompt + continuation; keep the continuation only.
        prompt_length = model_inputs['input_ids'].shape[1]
        continuation = [sequences[0, prompt_length:]]
        decoded = processor.batch_decode(
            continuation,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        output_text = decoded[0]

        del model_inputs, sequences, continuation
        clear_memory()

        return output_text

    except torch.cuda.OutOfMemoryError as oom_error:
        clear_memory()
        return f"CUDA out of memory error. Try reducing max_new_tokens or image resolution. Error: {oom_error}"
    except Exception as error:
        clear_memory()
        return f"Error processing image: {error}"


def resize_image_if_needed(image: Image.Image, max_dimension: int = 1536) -> Image.Image:
    """Downscale *image* so that its longer side is at most *max_dimension*.

    The aspect ratio is preserved and LANCZOS resampling is used. Images that
    already fit are returned unchanged (the same object, not a copy).

    Args:
        image: Source image.
        max_dimension: Largest allowed width or height in pixels.

    Returns:
        The original image, or a resized copy when it was too large.
    """
    width, height = image.size
    if width <= max_dimension and height <= max_dimension:
        return image

    if width > height:
        new_width = max_dimension
        # Clamp to 1: for extreme aspect ratios the scaled side truncates to
        # 0, and resizing to a zero-sized image raises.
        new_height = max(1, int(height * (max_dimension / width)))
    else:
        new_height = max_dimension
        new_width = max(1, int(width * (max_dimension / height)))

    print(f"Resizing image from {width}x{height} to {new_width}x{new_height}")
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)


def create_bounding_box_visualization(image: Image.Image, ocr_text: str) -> Image.Image:
    """Return a copy of *image* annotated with placeholder element boxes.

    This is a simplified visualization: the OCR text carries no coordinates,
    so the rectangles are NOT real bounding boxes. The page is divided into
    three fixed horizontal bands and a band is outlined when the matching
    element type occurs anywhere in *ocr_text*:

    * top third    - an HTML ``<table>`` is present (red)
    * middle third - ``$...$`` or ``$$...$$`` math is present (green)
    * bottom third - an ``<img>...</img>`` tag is present (blue)

    Args:
        image: Page image to annotate; it is not modified.
        ocr_text: OCR output to scan (for multi-page documents the caller
            passes the text of all pages).

    Returns:
        A new image with the outlines and labels drawn on it.
    """
    img_with_boxes = image.copy()
    draw = ImageDraw.Draw(img_with_boxes)

    # Prefer a scalable font; fall back to Pillow's built-in bitmap font when
    # the file is missing/unreadable or FreeType support is unavailable.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    width, height = img_with_boxes.size
    first_third = height // 3
    second_third = 2 * height // 3

    has_table = bool(BeautifulSoup(ocr_text, 'html.parser').find_all('table'))
    has_equation = re.search(r'\$\$[^$]+\$\$|\$[^$]+\$', ocr_text) is not None
    has_image = re.search(r'<img>.*?</img>', ocr_text) is not None

    # (detected, label, RGB outline colour, box corners, label position)
    regions = [
        (has_table, "Table Detected", (255, 0, 0),
         [(10, 10), (width - 10, first_third)], (15, 15)),
        (has_equation, "Equation Detected", (0, 255, 0),
         [(10, first_third), (width - 10, second_third)], (15, first_third + 5)),
        (has_image, "Image Detected", (0, 0, 255),
         [(10, second_third), (width - 10, height - 10)], (15, second_third + 5)),
    ]

    for detected, label, color, box, label_position in regions:
        if detected:
            draw.rectangle(box, outline=color, width=3)
            draw.text(label_position, label, fill=color, font=font)

    return img_with_boxes


def extract_api_fields(ocr_text: str, enabled_fields: list, custom_fields: list) -> dict:
    """Extract ``label: value`` style fields from OCR text by pattern matching.

    For every requested field the text is searched, case-insensitively, for
    the literal field label followed by whitespace and/or colons; the rest of
    that line is taken as the value. Only the first occurrence is used. This
    is a simple heuristic, not a semantic extractor.

    Args:
        ocr_text: Raw OCR output to search.
        enabled_fields: Selected predefined field labels (``None`` is treated
            as empty).
        custom_fields: User-entered labels; ``None`` entries, empty strings
            and whitespace-only strings are ignored, the rest are stripped.

    Returns:
        Dict mapping every requested label to its extracted value, or to
        ``""`` when the label was not found. Keys keep request order.
    """
    text = ocr_text or ""

    requested = list(enabled_fields or [])
    requested.extend(
        label.strip() for label in (custom_fields or []) if label and label.strip()
    )

    extracted_data = {}
    for field in requested:
        if field in extracted_data:
            # Same label requested twice (e.g. predefined and custom).
            continue
        # re.IGNORECASE already covers every casing of the label, so a
        # separate lower-cased pattern is unnecessary.
        match = re.search(rf'{re.escape(field)}[\s:]+([^\n]+)', text, re.IGNORECASE)
        extracted_data[field] = match.group(1).strip() if match else ""

    return extracted_data


def generate_api_request(api_endpoint: str, api_key: str, method: str, extracted_data: dict, 
                        webhook_url: str, confidence_threshold: float, output_format: str) -> dict:
    """Build a simulated API request/response pair for the extracted fields.

    Nothing is sent over the network; the result is a plain dict meant for
    display.

    Args:
        api_endpoint: Endpoint URL echoed into the request section.
        api_key: Key shown masked in the ``Authorization`` header. Keys of 16
            or more characters keep their first 8 characters visible; shorter
            keys are masked completely so that a short key is never echoed in
            full. An empty key is shown as ``********``.
        method: HTTP method echoed into the request section.
        extracted_data: Mapping of field label to extracted value.
        webhook_url: Optional callback URL; when set, ``async_processing`` is
            reported as ``True``.
        confidence_threshold: Echoed into the request parameters.
        output_format: Echoed into the request parameters.

    Returns:
        Dict with a ``request`` and a ``response`` section sharing one
        ``request_id``. ``processing_time_ms`` is 0 and is expected to be
        filled in by the caller.
    """
    request_id = str(uuid.uuid4())
    timestamp = datetime.now().isoformat()

    if not api_key:
        masked_key = "********"
    elif len(api_key) >= 16:
        masked_key = api_key[:8] + "*" * (len(api_key) - 8)
    else:
        masked_key = "*" * len(api_key)

    def simulated_confidence(field: str) -> float:
        # Placeholder score in [0.85, 0.99] derived from the field name; it is
        # not a model confidence. A SHA-256 digest is used instead of the
        # built-in hash(), whose string hashing is salted per process and
        # would give different scores on every run.
        digest = hashlib.sha256(field.encode("utf-8")).digest()
        bucket = int.from_bytes(digest[:4], "big") % 15
        return round(0.85 + bucket / 100, 2)

    filled_fields = sum(1 for value in extracted_data.values() if value)

    api_request = {
        "request": {
            "request_id": request_id,
            "timestamp": timestamp,
            "endpoint": api_endpoint,
            "method": method,
            "headers": {
                "Authorization": f"Bearer {masked_key}",
                "Content-Type": "application/json",
                "User-Agent": "Nanonets-OCR-Client/1.0"
            },
            "parameters": {
                "confidence_threshold": confidence_threshold,
                "output_format": output_format,
                "webhook_url": webhook_url if webhook_url else None,
                "async_processing": bool(webhook_url)
            }
        },
        "response": {
            "status": "success",
            "status_code": 200,
            "request_id": request_id,
            "timestamp": datetime.now().isoformat(),
            "processing_time_ms": 0,  # Will be updated later
            "data": {
                "extracted_fields": extracted_data,
                "metadata": {
                    "total_fields": len(extracted_data),
                    "filled_fields": filled_fields,
                    "empty_fields": len(extracted_data) - filled_fields,
                    "confidence_scores": {
                        field: simulated_confidence(field) for field in extracted_data
                    }
                }
            }
        }
    }

    return api_request


def generate_webhook_payload(extracted_data: dict, request_id: str, document_info: dict) -> dict:
    """Assemble the simulated ``document.processed`` webhook callback body.

    The payload carries a fresh event id and timestamp, the given request id,
    the document metadata and extracted data as passed in, and delivery
    metadata whose signature is the first 32 hex characters of the SHA-256
    digest of the request id.
    """
    signature = hashlib.sha256(request_id.encode()).hexdigest()[:32]
    delivery_metadata = {
        "webhook_delivery_attempt": 1,
        "webhook_signature": signature,
    }

    payload = {"event": "document.processed"}
    payload["event_id"] = str(uuid.uuid4())
    payload["timestamp"] = datetime.now().isoformat()
    payload["request_id"] = request_id
    payload["status"] = "completed"
    payload["document"] = document_info
    payload["data"] = extracted_data
    payload["metadata"] = delivery_metadata
    return payload


def ocr_core(file_path: gr.File, max_new_tokens: int = 2048, max_image_size: int = 1536) -> tuple:
    """Run OCR on an uploaded image or PDF and time the whole operation.

    Images are processed as a single page. PDFs are rendered page by page at
    150 DPI. Every page image is downscaled to *max_image_size* before being
    sent to the model.

    Args:
        file_path: The uploaded file: either an object exposing the path as
            ``.name`` (what the Gradio file component provides) or a plain
            path string / ``os.PathLike``. ``None`` means nothing was uploaded.
        max_new_tokens: Generation budget per page.
        max_image_size: Maximum width/height of a page image in pixels.

    Returns:
        A 4-tuple ``(text, elapsed, page_images, doc_info)``:

        * ``text`` - all pages joined, each preceded by a ``--- Page N ---``
          marker line, or an error/diagnostic message on failure;
        * ``elapsed`` - wall time formatted as ``H:MM:SS``;
        * ``page_images`` - the (resized) PIL image of every processed page;
        * ``doc_info`` - document metadata dict. It is empty on every failure
          path, which callers can use to detect failure.
    """
    start_time = time.time()

    def elapsed() -> str:
        return str(timedelta(seconds=int(time.time() - start_time)))

    if file_path is None:
        return "Error: No file provided.", "0:00:00", [], {}

    # Accept plain paths as well as upload objects. Path-like objects are
    # checked first because pathlib.Path also has a ``.name`` attribute, which
    # is only the basename.
    if isinstance(file_path, (str, os.PathLike)):
        actual_file_path = os.fspath(file_path)
    else:
        actual_file_path = file_path.name

    try:
        file_size = os.path.getsize(actual_file_path) / (1024 * 1024)  # Size in MB
    except OSError as e:
        return f"Error: Cannot read uploaded file: {e}", "0:00:00", [], {}

    prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. If there is an image in the document and image caption is not present, add a small description of the image inside the <img></img> tag; otherwise, add the image caption inside <img></img>. Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. Prefer using ☐ and ☑ for check boxes."""

    output_texts = []
    processed_images = []
    file_extension = os.path.splitext(actual_file_path)[1].lower()
    total_pages = 0

    if file_extension in ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff']:
        try:
            # convert() returns an independent image, so the file handle can
            # be closed as soon as the pixels are loaded.
            with Image.open(actual_file_path) as source:
                image = source.convert("RGB")
            image = resize_image_if_needed(image, max_image_size)
            processed_images.append(image.copy())
            output_text = process_single_image(image, prompt, max_new_tokens)
            output_texts.append(f"\n--- Page 1 ---\n{output_text}")
            total_pages = 1
            del image
            clear_memory()
        except Exception as e:
            return f"Error opening image file: {e}", elapsed(), [], {}

    elif file_extension == '.pdf':
        try:
            doc = fitz.open(actual_file_path)
            try:
                total_pages = doc.page_count
                print(f"Processing PDF with {total_pages} pages...")

                for page_num in range(total_pages):
                    page_start = time.time()
                    print(f"Processing page {page_num + 1}/{total_pages}...")

                    page = doc.load_page(page_num)

                    # 72 points per inch -> scale factor for a 150 DPI render.
                    pix = page.get_pixmap(matrix=fitz.Matrix(150/72, 150/72))
                    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

                    img = resize_image_if_needed(img, max_image_size)
                    processed_images.append(img.copy())

                    output_text = process_single_image(img, prompt, max_new_tokens)
                    output_texts.append(f"\n--- Page {page_num + 1} ---\n{output_text}")

                    page_elapsed = time.time() - page_start
                    print(f"Page {page_num + 1} completed in {page_elapsed:.2f}s")

                    del page, pix, img
                    clear_memory()
            finally:
                # Close the document even when a page fails, so the file
                # handle is not leaked.
                doc.close()
            del doc
            clear_memory()

        except Exception as e:
            return f"Error processing PDF: {e}", elapsed(), [], {}
    else:
        return "Unsupported file type. Please upload an image (JPG, PNG, etc.) or a PDF.", "0:00:00", [], {}

    if not output_texts:
        return "No valid images found to process.", elapsed(), [], {}

    total_time = time.time() - start_time
    elapsed_time = str(timedelta(seconds=int(total_time)))

    # Document metadata
    doc_info = {
        "filename": os.path.basename(actual_file_path),
        "file_size_mb": round(file_size, 2),
        "file_type": file_extension.upper().replace('.', ''),
        "total_pages": total_pages,
        "processing_time_seconds": round(total_time, 2)
    }

    return "\n\n".join(output_texts), elapsed_time, processed_images, doc_info


# Matches, in order: $$display$$, $inline$, equation and align environments.
# Exactly one group is set per match and holds the equation body. The inline
# form requires a non-space right after the opening "$", a non-space right
# before the closing "$" and no digit after it, so that prices such as
# "$5 and $10" are not reported as math.
_LATEX_EQUATION_PATTERN = re.compile(
    r'\$\$([^$]+)\$\$'
    r'|(?<![\\$])\$(?!\s)([^$\n]+?)(?<!\s)\$(?![\d$])'
    r'|\\begin\{equation\*?\}(.*?)\\end\{equation\*?\}'
    r'|\\begin\{align\*?\}(.*?)\\end\{align\*?\}',
    re.DOTALL,
)


def parse_ocr_output_to_structured_data(ocr_result_full_text: str) -> dict:
    """Parse raw OCR output into a structured, per-page dictionary.

    The text is split on ``--- Page N ---`` marker lines. For every non-empty
    page the function collects HTML tables (plus a CSV rendering of each),
    LaTeX equation bodies, ``<img>`` descriptions, ``<watermark>`` contents
    and ``<page_number>`` contents.

    Page numbers are taken from the markers, so a blank page (which is
    skipped) does not shift the numbers of the pages after it. Text without
    any marker is treated as page 1; text found before the first marker is
    kept as a page numbered 0.

    Returns:
        ``{"pages": [page_dict, ...]}`` where each page dict has the keys
        ``page_number``, ``raw_text``, ``tables_html``, ``tables_csv``,
        ``latex_equations``, ``image_descriptions``, ``watermarks`` and
        ``page_numbers_extracted``. ``tables_csv`` is index-aligned with
        ``tables_html``; a table that cannot be converted gets an error
        message in its slot.
    """
    structured_data = {"pages": []}

    # With a capturing group, re.split returns
    # [leading_text, number, text, number, text, ...].
    parts = re.split(r'\n--- Page (\d+) ---\n', ocr_result_full_text)
    numbered_pages = [
        (int(number), text.strip()) for number, text in zip(parts[1::2], parts[2::2])
    ]
    leading_text = parts[0].strip()
    if leading_text:
        numbered_pages.insert(0, (0 if numbered_pages else 1, leading_text))

    for page_number, page_text in numbered_pages:
        if not page_text:
            continue

        page_data = {
            "page_number": page_number,
            "raw_text": page_text,
            "tables_html": [],
            "tables_csv": [],
            "latex_equations": [],
            "image_descriptions": [],
            "watermarks": [],
            "page_numbers_extracted": []
        }

        soup = BeautifulSoup(page_text, 'html.parser')

        # Extract HTML tables and a CSV rendering of each one
        for i, table_tag in enumerate(soup.find_all('table')):
            table_html = str(table_tag)
            page_data["tables_html"].append(table_html)
            try:
                # read_html expects a file-like object; passing literal HTML
                # as a string is deprecated in recent pandas releases.
                df = pd.read_html(io.StringIO(table_html))[0]
                csv_buffer = io.StringIO()
                df.to_csv(csv_buffer, index=False)
                page_data["tables_csv"].append(csv_buffer.getvalue())
            except Exception as e:
                page_data["tables_csv"].append(f"Error converting table {i+1} to CSV: {e}")

        # Extract LaTeX equations: whichever alternative matched, take the
        # single group that is set.
        for match in _LATEX_EQUATION_PATTERN.finditer(page_text):
            body = next((group for group in match.groups() if group is not None), "").strip()
            if body:
                page_data["latex_equations"].append(body)

        # Tagged spans may wrap over several lines, hence re.DOTALL.
        tagged_sections = (
            ("image_descriptions", r'<img>(.*?)</img>'),
            ("watermarks", r'<watermark>(.*?)</watermark>'),
            ("page_numbers_extracted", r'<page_number>(.*?)</page_number>'),
        )
        for key, tag_pattern in tagged_sections:
            found = re.findall(tag_pattern, page_text, re.DOTALL)
            page_data[key].extend(item.strip() for item in found if item.strip())

        structured_data["pages"].append(page_data)

    return structured_data


def convert_to_json(structured_data: dict) -> str:
    """Serialize the structured OCR data as a JSON string indented by two spaces."""
    serialized = json.dumps(obj=structured_data, indent=2)
    return serialized


# Characters that XML 1.0 does not allow in a document, even escaped: most C0
# control characters (form feed, NUL, ...), surrogates and U+FFFE/U+FFFF.
_XML_ILLEGAL_CHARS = re.compile(
    r'[^\x09\x0A\x0D\x20-\uD7FF\uE000-\uFFFD\U00010000-\U0010FFFF]'
)


def _xml_safe_text(value) -> str:
    """Return *value* as a string with XML-illegal characters removed."""
    return _XML_ILLEGAL_CHARS.sub('', str(value))


def convert_to_xml(structured_data: dict) -> str:
    """Convert structured OCR data to a pretty-printed XML document.

    Layout: ``<DocumentOCR>`` contains one ``<Page number="N">`` per page with
    ``<RawText>`` and, only when non-empty, ``<Tables>``, ``<Equations>``,
    ``<Images>``, ``<Watermarks>`` and ``<PageNumbers>``. Items inside these
    containers carry a 1-based ``id`` attribute.

    Characters that are illegal in XML 1.0 are stripped from all text, since
    the serialized document would otherwise fail to re-parse during
    pretty-printing.

    Args:
        structured_data: Dict with a ``pages`` list as produced by
            ``parse_ocr_output_to_structured_data``. Each page needs
            ``page_number`` and ``raw_text``; the list-valued keys are
            optional.

    Returns:
        The XML document as a string, including the XML declaration,
        indented with two spaces.
    """
    root = ET.Element("DocumentOCR")

    # (page key, container tag, item tag) for the plain-text sections.
    text_sections = (
        ("latex_equations", "Equations", "Equation"),
        ("image_descriptions", "Images", "Description"),
        ("watermarks", "Watermarks", "Watermark"),
        ("page_numbers_extracted", "PageNumbers", "PageNumber"),
    )

    for page_data in structured_data.get("pages", []):
        page_elem = ET.SubElement(root, "Page", number=str(page_data["page_number"]))
        ET.SubElement(page_elem, "RawText").text = _xml_safe_text(page_data["raw_text"])

        tables_html = page_data.get("tables_html") or []
        tables_csv = page_data.get("tables_csv") or []
        if tables_html:
            tables_elem = ET.SubElement(page_elem, "Tables")
            for i, html_table in enumerate(tables_html, start=1):
                table_elem = ET.SubElement(tables_elem, "Table", id=str(i))
                ET.SubElement(table_elem, "HTMLContent").text = _xml_safe_text(html_table)
                # The CSV list is index-aligned with the HTML list but may be shorter.
                if i <= len(tables_csv):
                    ET.SubElement(table_elem, "CSVContent").text = _xml_safe_text(tables_csv[i - 1])

        for key, container_tag, item_tag in text_sections:
            items = page_data.get(key) or []
            if not items:
                continue
            container = ET.SubElement(page_elem, container_tag)
            for i, item in enumerate(items, start=1):
                ET.SubElement(container, item_tag, id=str(i)).text = _xml_safe_text(item)

    # Round-trip through minidom purely for indentation.
    rough_string = ET.tostring(root, 'utf-8')
    reparsed = minidom.parseString(rough_string)
    return reparsed.toprettyxml(indent="  ")


def _format_page_items(pages: list, key: str, label: str, empty_message: str) -> str:
    """Flatten one per-page list into '--- Page N, <label> i ---' sections."""
    sections = [
        f"--- Page {page['page_number']}, {label} {index} ---\n{item}"
        for page in pages
        for index, item in enumerate(page[key], start=1)
    ]
    return "\n\n".join(sections) if sections else empty_message


def process_document_for_ui(file: gr.File, max_new_tokens: int, max_image_size: int, 
                            enabled_fields: list, custom_field_1: str, custom_field_2: str, 
                            custom_field_3: str, custom_field_4: str, custom_field_5: str,
                            custom_field_6: str, custom_field_7: str, custom_field_8: str,
                            custom_field_9: str, custom_field_10: str,
                            api_endpoint: str, api_key: str, api_method: str, webhook_url: str,
                            confidence_threshold: float, output_format: str, enable_batch: bool) -> tuple:
    """Orchestrate OCR and post-processing for the Gradio tabbed output.

    Runs ``ocr_core`` on the uploaded file, parses the result into structured
    data and derives every view shown in the UI.

    Returns:
        A 15-tuple, in this order: raw OCR text, processing time, HTML
        preview, raw HTML tables, tables as CSV, LaTeX equations, image
        descriptions, watermarks, page numbers, JSON, XML, annotated
        first-page image (or ``None``), simulated API request/response JSON,
        webhook payload JSON (empty without a webhook URL), statistics text.

        When OCR fails, the message from ``ocr_core`` is returned in the
        first slot, the image slot is ``None`` and all other views are empty
        strings.
    """
    process_start = time.time()

    # Step 1: Perform core OCR with timing
    ocr_result_full_text, processing_time, processed_images, doc_info = ocr_core(file, max_new_tokens, max_image_size)

    # ocr_core returns an empty doc_info on every failure path, and not all of
    # its failure messages start with "Error:" (e.g. "Unsupported file type"),
    # so the empty metadata is the reliable failure signal.
    if not doc_info or ocr_result_full_text.startswith("Error:"):
        # Same arity as the success tuple; None sits in the image slot (index 11).
        return (ocr_result_full_text, processing_time) + ("",) * 9 + (None, "", "", "")

    # Step 2: Parse structured data
    structured_data = parse_ocr_output_to_structured_data(ocr_result_full_text)
    pages = structured_data["pages"]

    # Step 3: Create full HTML preview
    # NOTE: the OCR text is inserted without escaping so that HTML tables in
    # it render. The markup therefore comes from the uploaded document.
    ocr_html = ocr_result_full_text.replace('\n', '<br>')

    full_html_preview = f"""
    <div style="font-family: Arial, sans-serif; padding: 20px; background: #f5f5f5; border-radius: 8px;">
        <h2 style="color: #333; border-bottom: 2px solid #4CAF50; padding-bottom: 10px;">📄 Document Preview</h2>
        <div style="background: white; padding: 15px; border-radius: 5px; margin-top: 15px; line-height: 1.6;">
            {ocr_html}
        </div>
    </div>
    """

    # Step 4: HTML Tables Preview
    all_raw_html_tables = [
        f"\n{table_html}" for page_data in pages for table_html in page_data["tables_html"]
    ]
    html_tables_raw_output = "\n".join(all_raw_html_tables) if all_raw_html_tables else "No HTML tables found."

    # Steps 5-9: per-page listings of tables (CSV), equations, image
    # descriptions, watermarks and page numbers
    tables_csv_output = _format_page_items(
        pages, "tables_csv", "Table", "No tables found or could not convert.")
    latex_equations_output = _format_page_items(
        pages, "latex_equations", "LaTeX Equation", "No LaTeX equations found.")
    image_descriptions_output = _format_page_items(
        pages, "image_descriptions", "Image Description", "No image descriptions/captions found.")
    watermarks_output = _format_page_items(
        pages, "watermarks", "Watermark", "No watermarks found.")
    page_numbers_output = _format_page_items(
        pages, "page_numbers_extracted", "Page Number", "No page numbers found.")

    # Step 10: Generate JSON and XML
    json_output = convert_to_json(structured_data)
    xml_output = convert_to_xml(structured_data)

    # Step 11: Create bounding box visualization for first page
    bbox_image = None
    if processed_images:
        bbox_image = create_bounding_box_visualization(processed_images[0], ocr_result_full_text)

    # Step 12: Extract API fields
    enabled_fields = enabled_fields or []
    custom_fields = [custom_field_1, custom_field_2, custom_field_3, custom_field_4, custom_field_5,
                     custom_field_6, custom_field_7, custom_field_8, custom_field_9, custom_field_10]
    # Normalise missing values so that later .strip() calls cannot fail.
    custom_fields = [field or "" for field in custom_fields]
    api_data = extract_api_fields(ocr_result_full_text, enabled_fields, custom_fields)

    # Step 13: Generate API Request/Response structure
    total_processing_ms = int((time.time() - process_start) * 1000)
    api_request = generate_api_request(api_endpoint, api_key, api_method, api_data, 
                                      webhook_url, confidence_threshold, output_format)
    api_request["response"]["processing_time_ms"] = total_processing_ms

    # Add document info to response
    api_request["response"]["document"] = doc_info

    # Add batch processing info if enabled
    if enable_batch:
        api_request["request"]["parameters"]["batch_processing"] = {
            "enabled": True,
            "batch_id": str(uuid.uuid4()),
            "documents_in_batch": 1
        }

    api_json_output = json.dumps(api_request, indent=2)

    # Step 14: Generate webhook payload if webhook URL is provided
    webhook_output = ""
    if webhook_url:
        request_id = api_request["request"]["request_id"]
        webhook_payload = generate_webhook_payload(api_data, request_id, doc_info)
        webhook_output = json.dumps(webhook_payload, indent=2)

    # Step 15: Generate statistics
    requested_count = len(enabled_fields) + len([f for f in custom_fields if f.strip()])
    filled_count = sum(1 for v in api_data.values() if v)
    empty_count = sum(1 for v in api_data.values() if not v)
    success_rate = round(filled_count / max(len(api_data), 1) * 100, 1)
    seconds_per_page = round(
        doc_info.get('processing_time_seconds', 0) / max(doc_info.get('total_pages', 1), 1), 2)

    stats_output = f"""
    📊 Processing Statistics:
    
    Document Information:
    - Filename: {doc_info.get('filename', 'N/A')}
    - File Size: {doc_info.get('file_size_mb', 0)} MB
    - File Type: {doc_info.get('file_type', 'N/A')}
    - Total Pages: {doc_info.get('total_pages', 0)}
    
    Processing Metrics:
    - Total Processing Time: {processing_time}
    - Average Time per Page: {seconds_per_page}s
    - API Response Time: {total_processing_ms} ms
    
    Extraction Results:
    - Total Fields Requested: {requested_count}
    - Fields Successfully Extracted: {filled_count}
    - Fields Not Found: {empty_count}
    - Extraction Success Rate: {success_rate}%
    
    Content Detection:
    - Tables Found: {sum(len(p['tables_html']) for p in pages)}
    - Equations Found: {sum(len(p['latex_equations']) for p in pages)}
    - Images Found: {sum(len(p['image_descriptions']) for p in pages)}
    - Watermarks Found: {sum(len(p['watermarks']) for p in pages)}
    - Page Numbers Found: {sum(len(p['page_numbers_extracted']) for p in pages)}
    """

    return (ocr_result_full_text, processing_time, full_html_preview, html_tables_raw_output, 
            tables_csv_output, latex_equations_output, image_descriptions_output, 
            watermarks_output, page_numbers_output, json_output, xml_output, bbox_image, 
            api_json_output, webhook_output, stats_output)


# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown(
        """
        # üìÑ Nanonets-OCR-s Document Extractor & API Simulator
        
        **Professional OCR with Real-World API Integration Simulation**
        
        This tool provides enterprise-grade OCR with comprehensive API simulation features including:
        - ‚úÖ Realistic API request/response structure
        - ‚úÖ Webhook callback simulation
        - ‚úÖ Batch processing support
        - ‚úÖ Confidence scoring
        - ‚úÖ Multiple output formats
        - ‚úÖ Processing statistics and analytics
        
        **Optimized for 16GB VRAM** | Processing time and detailed statistics included
        """
    )

    with gr.Row():
        with gr.Column(scale=2):
            file_input = gr.File(label="üìÅ Upload Document", file_types=["image", ".pdf"], interactive=True)
        with gr.Column(scale=1):
            gr.Markdown("### ‚öôÔ∏è Processing Settings")
            max_tokens_slider = gr.Slider(
                minimum=500, maximum=6000, step=250, value=2048,
                label="Max Tokens", interactive=True
            )
            max_image_size_slider = gr.Slider(
                minimum=512, maximum=2048, step=128, value=1536,
                label="Max Image Size (px)", interactive=True
            )
    
    # API Configuration Section
    with gr.Accordion("üîå API Configuration", open=True):
        gr.Markdown("### Configure your API endpoint and authentication")
        
        with gr.Row():
            api_endpoint = gr.Textbox(
                label="API Endpoint URL",
                placeholder="https://api.example.com/v1/ocr/extract",
                value="https://api.nanonets.com/v1/ocr/extract",
                interactive=True
            )
            api_method = gr.Dropdown(
                choices=["POST", "PUT", "PATCH"],
                value="POST",
                label="HTTP Method",
                interactive=True
            )
        
        with gr.Row():
            api_key = gr.Textbox(
                label="üîë API Key",
                placeholder="Enter your API key here",
                type="password",
                value="sk_test_1234567890abcdefghijklmnopqrstuvwxyz",
                interactive=True
            )
            webhook_url = gr.Textbox(
                label="üîî Webhook URL (Optional)",
                placeholder="https://your-domain.com/webhook/callback",
                value="https://your-domain.com/webhook/callback",
                interactive=True
            )
        
        with gr.Row():
            confidence_threshold = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.05, value=0.75,
                label="Confidence Threshold",
                interactive=True
            )
            output_format = gr.Dropdown(
                choices=["JSON", "XML", "CSV", "PDF"],
                value="JSON",
                label="Output Format",
                interactive=True
            )
            enable_batch = gr.Checkbox(
                label="Enable Batch Processing",
                value=False,
                interactive=True
            )
    
    # Field Extraction Section
    # Collapsed-by-default panel in which the user selects the predefined fields
    # to extract and can name up to ten additional custom fields.
    with gr.Accordion("📋 Field Extraction Configuration", open=False):
        gr.Markdown("### Select fields to extract from documents")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("**All predefined fields are enabled by default**")
                field_checkboxes = gr.CheckboxGroup(
                    choices=PREDEFINED_FIELDS,
                    # Derive the count from the list so the label cannot drift
                    # out of sync when fields are added or removed.
                    label=f"Predefined Fields ({len(PREDEFINED_FIELDS)} Fields)",
                    value=PREDEFINED_FIELDS,  # All enabled by default
                    interactive=True
                )

        gr.Markdown("### 🎯 Custom Fields (Add up to 10 custom fields)")
        gr.Markdown("*Define your own fields for specialized document types*")

        # Ten free-text boxes, laid out as rows of 3, 3, 2 and 2. Each one is passed
        # individually to the click handler, so the variable names must stay as they are.
        with gr.Row():
            custom_field_1 = gr.Textbox(label="Custom Field 1", placeholder="e.g., Tax ID")
            custom_field_2 = gr.Textbox(label="Custom Field 2", placeholder="e.g., VAT Number")
            custom_field_3 = gr.Textbox(label="Custom Field 3", placeholder="e.g., GST Number")
        with gr.Row():
            custom_field_4 = gr.Textbox(label="Custom Field 4", placeholder="e.g., Discount Rate")
            custom_field_5 = gr.Textbox(label="Custom Field 5", placeholder="e.g., Service Charge")
            custom_field_6 = gr.Textbox(label="Custom Field 6", placeholder="e.g., Delivery Fee")
        with gr.Row():
            custom_field_7 = gr.Textbox(label="Custom Field 7", placeholder="e.g., Processing Fee")
            custom_field_8 = gr.Textbox(label="Custom Field 8", placeholder="e.g., Insurance")
        with gr.Row():
            custom_field_9 = gr.Textbox(label="Custom Field 9", placeholder="e.g., Warranty Period")
            custom_field_10 = gr.Textbox(label="Custom Field 10", placeholder="e.g., Return Policy")
    
    # Primary action button; its click event is wired to the processing callback below.
    process_button = gr.Button("🚀 Process Document", variant="primary", size="lg")

    # Read-only box that the processing callback fills with the elapsed time.
    with gr.Row():
        processing_time_display = gr.Textbox(label="⏱️ Processing Time", interactive=False, scale=1)

    # Result viewers, one tab per output of the processing callback.
    # All viewers are read-only; they are populated by process_button.click below.
    with gr.Tabs():
        with gr.TabItem("📊 Statistics & Summary"):
            stats_output_viewer = gr.Textbox(
                label="Processing Statistics and Analytics", 
                lines=20, 
                interactive=False
            )

        with gr.TabItem("🔌 API Request/Response"):
            api_json_viewer = gr.Code(
                label="Complete API Request & Response Structure", 
                language="json", 
                lines=25, 
                interactive=False
            )
            gr.Markdown("""
            **API Response Structure:**
            - `request`: Complete request details including headers, parameters, and configuration
            - `response`: API response with extracted data, metadata, and confidence scores
            - `data.extracted_fields`: All extracted field values
            - `data.metadata`: Statistics about the extraction process
            """)

        with gr.TabItem("🔔 Webhook Payload"):
            webhook_output_viewer = gr.Code(
                label="Webhook Callback Payload", 
                language="json", 
                lines=20, 
                interactive=False
            )
            gr.Markdown("""
            **Webhook Event Structure:**
            - Triggered when document processing completes
            - Contains event metadata and extracted data
            - Includes webhook signature for verification
            - Empty if no webhook URL is configured
            """)

        with gr.TabItem("📄 Full OCR Text"):
            full_ocr_text_output = gr.Textbox(
                label="Complete Extracted Text (Raw)", 
                lines=25, 
                max_lines=40, 
                interactive=False
            )

        with gr.TabItem("🌐 Full HTML Preview"):
            full_html_preview_viewer = gr.HTML(label="Full Document HTML Preview")

        with gr.TabItem("📊 Tables (HTML)"):
            html_tables_raw_output_viewer = gr.HTML(label="Extracted Tables (HTML Preview)")

        with gr.TabItem("📈 Tables (CSV)"):
            tables_csv_output_viewer = gr.Textbox(
                label="Extracted Tables (CSV Format)", 
                lines=15, 
                max_lines=30, 
                interactive=False
            )

        with gr.TabItem("🔢 LaTeX Equations"):
            latex_equations_output_viewer = gr.Textbox(
                label="Extracted LaTeX Equations", 
                lines=10, 
                max_lines=20, 
                interactive=False
            )

        with gr.TabItem("🖼️ Image Descriptions"):
            image_descriptions_output_viewer = gr.Textbox(
                label="Extracted Image Descriptions/Captions", 
                lines=5, 
                max_lines=10, 
                interactive=False
            )

        with gr.TabItem("💧 Watermarks"):
            watermarks_output_viewer = gr.Textbox(
                label="Extracted Watermarks", 
                lines=5, 
                max_lines=10, 
                interactive=False
            )

        with gr.TabItem("🔢 Page Numbers"):
            page_numbers_output_viewer = gr.Textbox(
                label="Extracted Page Numbers", 
                lines=5, 
                max_lines=10, 
                interactive=False
            )

        with gr.TabItem("🎯 Bounding Boxes"):
            bbox_image_viewer = gr.Image(label="Bounding Box Visualization (First Page)", type="pil")
            gr.Markdown("""
            **Element Detection Visualization:**
            - 🔴 Red: Tables
            - 🟢 Green: Equations
            - 🔵 Blue: Images
            - 🟡 Yellow: Watermarks
            """)

        with gr.TabItem("📦 Complete JSON"):
            json_output_viewer = gr.Code(
                label="Complete Structured JSON Output", 
                language="json", 
                lines=25, 
                interactive=False
            )

        with gr.TabItem("📋 Complete XML"):
            # The XML output is displayed with the "html" highlighter.
            # NOTE(review): presumably because gr.Code has no dedicated XML mode - confirm.
            xml_output_viewer = gr.Code(
                label="Complete Structured XML Output", 
                language="html", 
                lines=25, 
                interactive=False
            )

    # Button click handler
    # Gradio passes the input components' values to the callback positionally and
    # maps the callback's return values onto the output components in list order,
    # so the ordering of both lists below is significant.
    _click_inputs = [
        # Document and generation limits
        file_input,
        max_tokens_slider,
        max_image_size_slider,
        # Field selection: predefined group, then the ten custom boxes
        field_checkboxes,
        custom_field_1,
        custom_field_2,
        custom_field_3,
        custom_field_4,
        custom_field_5,
        custom_field_6,
        custom_field_7,
        custom_field_8,
        custom_field_9,
        custom_field_10,
        # API / webhook configuration and extraction options
        api_endpoint,
        api_key,
        api_method,
        webhook_url,
        confidence_threshold,
        output_format,
        enable_batch,
    ]
    _click_outputs = [
        full_ocr_text_output,
        processing_time_display,
        full_html_preview_viewer,
        html_tables_raw_output_viewer,
        tables_csv_output_viewer,
        latex_equations_output_viewer,
        image_descriptions_output_viewer,
        watermarks_output_viewer,
        page_numbers_output_viewer,
        json_output_viewer,
        xml_output_viewer,
        bbox_image_viewer,
        api_json_viewer,
        webhook_output_viewer,
        stats_output_viewer,
    ]
    process_button.click(
        fn=process_document_for_ui,
        inputs=_click_inputs,
        outputs=_click_outputs,
    )
    
    # Footer with information
    # Static help text rendered below the result tabs.
    gr.Markdown("""
    ---
    ### 💡 Tips for Best Results:
    - **For large documents**: Reduce max tokens to 2000 and image size to 1280px
    - **For detailed extraction**: Increase max tokens to 4000+ (requires more VRAM)
    - **API Key**: Safely stored as password type (displayed as ******)
    - **Webhook**: Configure for asynchronous processing notifications
    - **Batch Processing**: Enable for processing multiple documents in sequence
    
    ### 🚀 Real-World Features:
    - Request/Response tracking with unique IDs
    - Confidence scoring for each extracted field
    - Processing time metrics (ms)
    - Document metadata (size, pages, type)
    - Webhook callback simulation
    - Multiple output format support
    - Custom field definitions
    - Batch processing capabilities
    """)


if __name__ == "__main__":
    # Report where the model lives and, on CUDA, the memory picture for GPU 0
    # before starting the web UI.
    print(f"Launching Gradio interface. Model loaded on device: {model.device}")
    if torch.cuda.is_available():
        bytes_per_gib = 1024**3
        total_vram_gib = torch.cuda.get_device_properties(0).total_memory / bytes_per_gib
        print(f"Total VRAM: {total_vram_gib:.2f} GB")
        allocated_gib = torch.cuda.memory_allocated(0) / bytes_per_gib
        print(f"Currently allocated: {allocated_gib:.2f} GB")
    # share=True requests a public share link in addition to the local server.
    demo.launch(share=True)