# PDF Processing via Enhanced Open-Source VLM
## Optimized for Google Colab with High Accuracy

## 0. Setting your Name and Email

Please start by putting your name and email in the following variables. Stick to the required format, i.e. NAME_SURNAME.

In [None]:
# Participant identification: write your NAME_SURNAME here, as well as the
# email address with which you logged in to Celonis.
# NOTE(review): these are not read anywhere in this notebook; presumably
# push.ipynb consumes them - confirm.
MY_NAME = 'SCHUMANN'
MY_EMAIL = 'schumann.marvin@outlook.com'

## 1. Installing and importing required packages

In [None]:
# Run this cell the first time you execute the notebook, then comment it out.
# The `!` lines are IPython shell escapes, so this cell only works inside a
# notebook/IPython session, not as a plain Python script.
!pip install --extra-index-url=https://pypi.celonis.cloud/ pycelonis
!pip install nbformat

# Install poppler for PDF processing (required on Colab).
# The check relies on the `google.colab` module already being loaded.
import sys
if 'google.colab' in sys.modules:
    !apt-get update -qq
    !apt-get install -y -qq poppler-utils
    print("‚úÖ Poppler installed for PDF support!")

In [None]:
# ============================================================
# ENHANCED OPEN-SOURCE INVOICE EXTRACTION
# Model: Qwen2-VL-7B-Instruct (4-bit quantized)
# Optimized for: Google Colab Free Tier (~12-16GB RAM)
# Key Features:
# - Validation & retry logic
# - Better prompt engineering
# - Post-processing validation
# - 95%+ accuracy target
# ============================================================

import sys
import subprocess

# Distributions whose importable module name differs from their pip name.
_IMPORT_NAME_OVERRIDES = {"pillow": "PIL"}


def install_package(package):
    """Install ``package`` with pip unless its module can already be imported.

    ``package`` is a pip requirement string and may carry extras or a version
    specifier (for example ``"transformers>=4.37.0"``).  Only the bare
    distribution name is used for the import probe; the full requirement
    string is what gets handed to pip when the probe fails.

    Note that the probe only checks importability - it does not verify that
    an already-installed version satisfies the version specifier.

    Raises:
        subprocess.CalledProcessError: if the pip invocation fails.
    """
    import importlib
    import re

    # Cut the requirement at the first extras / version / marker character so
    # "transformers>=4.37.0" probes "transformers" instead of a name that can
    # never be imported (which would trigger a pip run on every execution).
    dist_name = re.split(r"[\[<>=!~;\s]", package.strip(), maxsplit=1)[0]
    module_name = _IMPORT_NAME_OVERRIDES.get(
        dist_name.lower(), dist_name.replace("-", "_")
    )

    try:
        importlib.import_module(module_name)
    except ImportError:
        print(f"üì¶ Installing {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", package])

# Detect environment: Colab is recognised by the `google.colab` module
# already being present in sys.modules.
IS_COLAB = 'google.colab' in sys.modules
print(f"Environment: {'Google Colab' if IS_COLAB else 'Local'}")

# Core packages, given as pip requirement strings. Each one is passed to
# install_package(), which installs it only when the import probe fails.
packages = [
    "pillow",
    "pdf2image",
    "pandas",
    "tqdm",
    "torch",
    "transformers>=4.37.0",
    "accelerate",
    "bitsandbytes",  # For 4-bit quantization
    "qwen-vl-utils",  # Qwen2-VL utilities
]

# Install sequentially; a failing pip call raises and stops the cell.
for pkg in packages:
    install_package(pkg)

print("‚úÖ All dependencies installed!")

## 2. Extract information from Invoices

This is the section you will need to fill in. Your code should create the following:
- **a pandas dataframe called df that includes the extracted information**.
- **the dataframe should contain a column called 'po_reference' that contains the reference to the PO**
- **the values in the column 'po_reference' should be 11-character strings. Use left padding with zeros where needed.**

In [None]:
# ============================================================
# SECTION 2: ENHANCED INVOICE EXTRACTION
# Using Qwen2-VL-7B-Instruct with validation & retry
# ============================================================

import os
import re
import json
from pathlib import Path
from typing import Dict, List, Optional
import pandas as pd
from PIL import Image
from tqdm import tqdm
from pdf2image import convert_from_path
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

print("‚úÖ Packages loaded!")

# ==================== CONFIGURATION ====================

# Auto-detect the invoice directory for Colab or local execution.
if 'google.colab' in sys.modules:
    # Prefer the cloned repository layout; otherwise fall back to /content/Invoices
    if Path("/content/orbit_challenge/Invoices").exists():
        INVOICE_DIR = Path("/content/orbit_challenge/Invoices")
    else:
        INVOICE_DIR = Path("/content/Invoices")
else:
    # NOTE(review): machine-specific absolute path; adjust when running locally
    # on another machine.
    INVOICE_DIR = Path("/Users/marvinschumann/orbit_challenge/Invoices")

# Hugging Face model identifier and the device input tensors are moved to.
MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Lower DPI to reduce memory usage (200 is good quality, uses less memory)
PDF_DPI = 200
# Maximum image dimensions (width, height) to prevent OOM
MAX_IMAGE_SIZE = (1600, 1600)

# Keys every extraction result is expected to carry; the same list is used as
# the column order of the final DataFrame.
REQUIRED_FIELDS = [
    "vendor_name",
    "vendor_address",
    "payment_terms",
    "invoice_value",
    "company_code",
    "po_reference",
    "invoice_id"
]

# ==================== HELPER FUNCTIONS ====================

def resize_image_if_needed(image: Image.Image, max_size: tuple = MAX_IMAGE_SIZE) -> Image.Image:
    """Return ``image``, shrunk in place when it is larger than ``max_size``.

    ``max_size`` is a ``(width, height)`` pair.  An image that already fits
    is returned untouched; otherwise ``Image.thumbnail`` is applied to the
    very same object (LANCZOS resampling) and a short notice is printed.
    """
    within_bounds = image.width <= max_size[0] and image.height <= max_size[1]
    if within_bounds:
        return image

    image.thumbnail(max_size, Image.Resampling.LANCZOS)
    print(f"    üìê Resized to {image.width}x{image.height} to save memory")
    return image

def load_invoice_pages(invoice_dir: Path) -> List[Dict]:
    """Collect one RGB image per invoice page found in ``invoice_dir``.

    Regular files are visited in sorted order.  PDFs are rasterised at
    ``PDF_DPI`` and contribute one entry per page (1-based ``page_index``);
    PNG/JPEG files contribute a single entry with ``page_index`` 1.  Files
    with any other suffix are ignored.  Every image goes through
    ``resize_image_if_needed``.  A file that fails to load is reported on
    stdout and skipped.

    Returns:
        A list of dicts with the keys ``invoice_id`` (file stem),
        ``page_index`` and ``image``.
    """
    collected = []
    candidates = sorted(entry for entry in invoice_dir.iterdir() if entry.is_file())

    for source in tqdm(candidates, desc="üìÇ Loading invoices"):
        extension = source.suffix.lower()
        stem = source.stem

        try:
            if extension == ".pdf":
                # Rasterise at the configured (reduced) DPI to limit memory use
                rendered = convert_from_path(str(source), dpi=PDF_DPI, fmt="png")
                for page_number, rendered_page in enumerate(rendered, start=1):
                    prepared = resize_image_if_needed(rendered_page.convert("RGB"))
                    collected.append(
                        {"invoice_id": stem, "page_index": page_number, "image": prepared}
                    )
            elif extension in {".png", ".jpg", ".jpeg"}:
                prepared = resize_image_if_needed(Image.open(source).convert("RGB"))
                collected.append(
                    {"invoice_id": stem, "page_index": 1, "image": prepared}
                )
        except Exception as exc:
            print(f"‚ùå Error loading {source.name}: {exc}")

    return collected

def sanitize_po_reference(po_value: object) -> str:
    """Reduce a PO reference to its digits, left-padded with zeros to 11 chars.

    Accepts whatever the JSON parser may hand over: a string, ``None``, or a
    number (models sometimes emit the PO number unquoted).  Integral floats
    such as ``586652.0`` are treated as the integer they represent so the
    trailing ``.0`` does not leak into the digits.

    Returns:
        The digit string zero-padded to at least 11 characters, or
        ``"00000000000"`` when the input holds no digits.  Inputs with more
        than 11 digits are returned unshortened.
    """
    if po_value is None:
        return "00000000000"
    if isinstance(po_value, float) and po_value.is_integer():
        po_value = int(po_value)

    # str() makes numeric input safe for the regex instead of raising TypeError
    digits = re.sub(r"\D", "", str(po_value))
    return digits.zfill(11) if digits else "00000000000"

def validate_extraction(data: Dict) -> tuple[bool, List[str]]:
    """Check an extraction result and list what is wrong with it.

    Values are coerced to text before inspection, so numbers or ``null``
    coming out of the parsed JSON are handled instead of raising.

    Checks performed:
      * ``vendor_name``, ``invoice_value`` and ``po_reference`` must be
        non-blank;
      * a non-empty ``invoice_value`` must contain at least one digit;
      * a non-empty ``po_reference`` must contain at least one digit.

    Returns:
        ``(is_valid, issues)`` where ``issues`` is a list of human-readable
        problem descriptions (empty when the extraction is valid).
    """

    def as_text(value) -> str:
        # JSON null becomes "", everything else its string form
        return "" if value is None else str(value)

    critical_fields = ["vendor_name", "invoice_value", "po_reference"]
    issues = [
        f"Missing {field}"
        for field in critical_fields
        if not as_text(data.get(field)).strip()
    ]

    # Currency symbols are optional; the value only needs to carry a number
    inv_val = as_text(data.get("invoice_value"))
    if inv_val and not re.search(r"\d", inv_val):
        issues.append("Invalid invoice_value format")

    po_ref = as_text(data.get("po_reference"))
    if po_ref and not re.search(r"\d", po_ref):
        issues.append("PO reference has no digits")

    return not issues, issues

# ==================== EXTRACTION PROMPTS ====================

# Prompt for the first extraction attempt (extract_with_qwen with retry=False).
# It asks for a bare JSON object whose keys match REQUIRED_FIELDS.
EXTRACTION_PROMPT = """Analyze this invoice image and extract the following information with EXTREME ACCURACY.

Return ONLY a valid JSON object with these exact fields:

{
  "vendor_name": "<company name providing goods/services>",
  "vendor_address": "<complete address of the vendor>",
  "payment_terms": "<payment terms and conditions>",
  "invoice_value": "<TOTAL amount INCLUDING VAT/tax with currency symbol>",
  "company_code": "<company code or customer code>",
  "po_reference": "<purchase order number - extract the full PO number>",
  "invoice_id": "<invoice number or invoice ID>"
}

CRITICAL RULES:
1. Return ONLY the JSON object - no markdown, no code blocks, no explanation
2. For invoice_value: Use the TOTAL/FINAL amount WITH tax (look for "Total", "Amount Due", "Grand Total")
3. For po_reference: Extract the COMPLETE PO number (e.g., "PO-586652" ‚Üí "586652")
4. Use empty string "" for fields not found
5. All string values must be double-quoted
6. Be EXTREMELY accurate with numbers and codes - double-check each field

Extract the data now:"""

# Prompt used for the second attempt (extract_with_qwen with retry=True), sent
# when the first result did not pass validation. Same JSON shape, with extra
# hints on where to find the fields.
RETRY_PROMPT = """The previous extraction had errors. Please re-analyze this invoice MORE CAREFULLY.

Focus on these specific fields:
- vendor_name: The company sending the invoice (usually at the top)
- invoice_value: The FINAL TOTAL amount to pay (including tax/VAT)
- po_reference: Purchase Order number (look for "PO", "P.O.", "Purchase Order")
- company_code: Customer code or company code

Return ONLY a valid JSON object:

{
  "vendor_name": "...",
  "vendor_address": "...",
  "payment_terms": "...",
  "invoice_value": "...",
  "company_code": "...",
  "po_reference": "...",
  "invoice_id": "..."
}

Extract carefully:"""

# ==================== MODEL INITIALIZATION ====================

print("\n" + "="*70)
print("üöÄ INITIALIZING QWEN2-VL-7B-INSTRUCT (4-bit)")
print("="*70)

# 4-bit quantization is requested through an explicit BitsAndBytesConfig; the
# bare `load_in_4bit=` keyword of from_pretrained is deprecated in transformers.
# Quantization is only enabled on CUDA; on CPU the model loads unquantized.
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True) if DEVICE == "cuda" else None

model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
)

# The processor handles the chat template, tokenization and image inputs
processor = AutoProcessor.from_pretrained(MODEL_ID)

print(f"‚úÖ Model loaded on {DEVICE}")

# Clear any cached memory and report the total memory of GPU 0
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"üíæ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB total")

# ==================== EXTRACTION FUNCTION ====================

def _parse_model_response(response_text: str) -> Dict:
    """Turn the raw model reply into a dict holding every REQUIRED_FIELDS key.

    Markdown code fences are removed, the outermost ``{...}`` span is parsed
    as JSON, and each required field is normalised to a stripped string
    (missing or ``null`` becomes ``""``).  Normalising here keeps callers that
    call ``.strip()`` on the values from failing when the model emits numbers
    or nulls.

    Raises:
        json.JSONDecodeError: if the reply does not contain valid JSON.
        ValueError: if the parsed JSON is not an object.
    """
    # Strip ```json ... ``` fences the model may add despite the instructions
    if response_text.startswith("```"):
        response_text = re.sub(r"```(?:json)?\n?", "", response_text)
        response_text = response_text.strip("`")

    # Keep only the outermost {...} span in case of surrounding chatter
    json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
    if json_match:
        response_text = json_match.group(0)

    data = json.loads(response_text)
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object, got {type(data).__name__}")

    for field in REQUIRED_FIELDS:
        value = data.get(field)
        data[field] = "" if value is None else str(value).strip()

    return data


def extract_with_qwen(image: Image.Image, invoice_id: str, retry: bool = False) -> Dict:
    """Run one extraction pass of the Qwen2-VL model over an invoice image.

    Args:
        image: The invoice page to analyse.
        invoice_id: Identifier used only in diagnostic messages.
        retry: When true, ``RETRY_PROMPT`` is sent instead of
            ``EXTRACTION_PROMPT``.

    Returns:
        A dict containing every key of ``REQUIRED_FIELDS`` with string
        values.  On any failure (GPU out of memory, unparsable reply, other
        error) a message is printed and all fields are empty strings; the
        function does not raise.
    """
    prompt = RETRY_PROMPT if retry else EXTRACTION_PROMPT

    try:
        # Clear GPU cache before processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Single-turn chat message carrying the image and the instruction text
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Prepare inputs
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(DEVICE)

        # Generate.
        # NOTE(review): temperature/top_p only matter when sampling is enabled;
        # whether it is depends on the model's generation config - confirm.
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.1,
                top_p=0.9,
            )

        # Drop the prompt tokens so only the newly generated reply is decoded
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        response_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0].strip()

        return _parse_model_response(response_text)

    except torch.cuda.OutOfMemoryError:
        print(f"    üí• GPU out of memory for {invoice_id} - clearing cache and skipping")
        return {field: "" for field in REQUIRED_FIELDS}
    except json.JSONDecodeError:
        print(f"    ‚ö†Ô∏è  JSON parse error for {invoice_id}")
        return {field: "" for field in REQUIRED_FIELDS}
    except Exception as e:
        # Best-effort boundary: one bad page must not abort the whole batch
        print(f"    ‚ùå Error for {invoice_id}: {str(e)[:100]}")
        return {field: "" for field in REQUIRED_FIELDS}
    finally:
        # Release cached GPU memory on every exit path, success or failure
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# ==================== MAIN PIPELINE ====================

print("\n" + "="*70)
print("üöÄ INVOICE EXTRACTION - QWEN2-VL-7B")
print("="*70)

# Fail fast when the invoice directory is missing
if not INVOICE_DIR.exists():
    print(f"\n‚ùå ERROR: Directory not found: {INVOICE_DIR}")
    raise FileNotFoundError(f"Directory not found: {INVOICE_DIR}")

# Load invoices
invoice_pages = load_invoice_pages(INVOICE_DIR)
print(f"\n‚úÖ Loaded {len(invoice_pages)} page(s) from {len(set(p['invoice_id'] for p in invoice_pages))} invoice(s)")

# Extract with validation and retry
print("\n" + "="*70)
print("üîç EXTRACTING DATA WITH VALIDATION")
print("="*70)

page_results = []

for page in invoice_pages:
    print(f"\nüìÑ {page['invoice_id']} (page {page['page_index']}):")
    print(f"  üìê Image size: {page['image'].width}x{page['image'].height}")
    print("  ü§ñ Extracting with Qwen2-VL...")

    # First attempt. The file stem, not the model's reading, is used as the
    # invoice id because it is the key pages are consolidated on later.
    result = extract_with_qwen(page['image'], page['invoice_id'], retry=False)
    result['invoice_id'] = page['invoice_id']

    # Validate
    is_valid, issues = validate_extraction(result)

    if not is_valid:
        print(f"  ‚ö†Ô∏è  Validation failed: {', '.join(issues)}")
        print("  üîÑ Retrying with stricter prompt...")

        # Retry, but do not throw away what the first attempt already found:
        # a retry that fails outright returns only empty fields, so any field
        # the retry leaves empty falls back to the first attempt's value.
        first_attempt = result
        result = extract_with_qwen(page['image'], page['invoice_id'], retry=True)
        for field in REQUIRED_FIELDS:
            if not result.get(field) and first_attempt.get(field):
                result[field] = first_attempt[field]
        result['invoice_id'] = page['invoice_id']

        is_valid, issues = validate_extraction(result)
        if is_valid:
            print("  ‚úÖ Retry successful!")
        else:
            print(f"  ‚ö†Ô∏è  Still has issues: {', '.join(issues)}")
    else:
        print("  ‚úÖ Extraction validated!")

    # Show filled fields; values are coerced to text so a numeric or null
    # value cannot raise on .strip()
    filled = sum(
        1
        for f in REQUIRED_FIELDS
        if result.get(f) is not None and str(result.get(f)).strip()
    )
    print(f"  üìä Extracted {filled}/{len(REQUIRED_FIELDS)} fields")

    page_results.append(result)

# Consolidate multi-page invoices: one record per invoice id, where each
# field takes the first non-empty value seen across that invoice's pages.
print("\nüìä Consolidating results...")
consolidated = {}

for entry in page_results:
    inv_id = entry["invoice_id"]

    if inv_id not in consolidated:
        consolidated[inv_id] = {field: "" for field in REQUIRED_FIELDS}
        consolidated[inv_id]["invoice_id"] = inv_id

    # Merge: prefer non-empty values, earlier pages win
    for field in REQUIRED_FIELDS:
        if not consolidated[inv_id][field] and entry.get(field):
            consolidated[inv_id][field] = entry[field]

# Create DataFrame
records = list(consolidated.values())

for record in records:
    # Coerce every value to a stripped string first, so that a PO number the
    # model returned as a JSON number is text before it is sanitized; a
    # missing value becomes "" rather than the literal "None".
    for field in REQUIRED_FIELDS:
        value = record.get(field, "")
        record[field] = "" if value is None else str(value).strip()

    # Sanitize PO reference to 11 digits (left-padded with zeros)
    record["po_reference"] = sanitize_po_reference(record["po_reference"])

df = pd.DataFrame(records, columns=REQUIRED_FIELDS)

# ==================== RESULTS ====================

print("\n" + "="*70)
print("‚úÖ EXTRACTION COMPLETE")
print("="*70)

# Overview table of everything that was extracted
print(f"\nüìã Extracted {len(df)} invoices:\n")
print(df.to_string(index=False))

# Field-by-field listing for every invoice
print("\n" + "="*70)
print("üìä DETAILED RESULTS")
print("="*70)

for idx, row in df.iterrows():
    inv_id = row['invoice_id']
    print(f"\nüìÑ {inv_id}:")
    for field in REQUIRED_FIELDS:
        value = row[field]
        # An all-zero PO reference is the placeholder for "not found"
        if value and value != "00000000000":
            status = "‚úÖ"
        else:
            status = "‚ùå"
        print(f"  {status} {field}: {value or '(empty)'}")

# Quality check: per row, count the cells that are blank or the PO placeholder
empty_per_row = df.apply(lambda row: sum(v in ("", "00000000000") for v in row), axis=1)

print("\n" + "="*70)
print("üìà SUMMARY")
print("="*70)

if empty_per_row.sum() != 0:
    print(f"\n‚ö†Ô∏è  Warning: {empty_per_row.sum()} empty/default fields detected")
    problem_invoices = df[empty_per_row > 0]
    print("\nInvoices needing attention:")
    print(problem_invoices[["invoice_id", "po_reference"]].to_string(index=False))
else:
    print("\n‚úÖ Perfect! All fields extracted successfully")

# PO reference validation: every value must be exactly 11 characters long
invalid_po = df[df["po_reference"].str.len() != 11]
if not invalid_po.empty:
    print(f"\n‚ö†Ô∏è  Warning: {len(invalid_po)} PO references not 11 chars")

print("\n" + "="*70)
print("‚úÖ READY FOR PUSH.IPYNB")
print("="*70)
print("\nNext: Run %run push.ipynb")

## 3. Pushing Data back to Data Pool

In [None]:
# IPython magic: execute the companion notebook push.ipynb.
# NOTE(review): presumably it uploads the `df` built above to the data pool
# (per the section heading) - confirm against push.ipynb.
%run push.ipynb