<a href="https://colab.research.google.com/github/RogueTex/StreamingDataforModelTraining/blob/main/NewVerPynbAgent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Receipt Automation System

This notebook builds a receipt processing pipeline with:
- **Document Classification** - ViT model
- **OCR** - EasyOCR for text extraction
- **Field Extraction** - LayoutLMv3 + regex patterns
- **Anomaly Detection** - Isolation Forest
- **Agent Workflow** - LangGraph
- **Demo UI** - Gradio

**Note:** A GPU is recommended, but the notebook also runs on CPU (just more slowly).
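
The stages listed above pass a shared state from one step to the next, and the notebook wires them together with LangGraph's `StateGraph`. The plain-Python sketch below only illustrates that flow. It does not use LangGraph, and every stage is a hypothetical stub: a probability threshold stands in for the ViT classifier, a regex stands in for LayoutLMv3, a fixed rule stands in for the Isolation Forest, and OCR text is supplied directly.

```python
import re
from typing import Any, Callable, Dict, List

State = Dict[str, Any]  # shared state handed from stage to stage


def classify(state: State) -> State:
    # Hypothetical stub for the ViT classifier: threshold a precomputed probability
    state["is_receipt"] = state["receipt_prob"] >= 0.5
    return state


def extract_fields(state: State) -> State:
    # Hypothetical stub for LayoutLMv3 + regex: read the TOTAL line (not SUBTOTAL)
    match = re.search(r"^TOTAL:\s*\$(\d+\.\d{2})", state["ocr_text"], re.MULTILINE)
    state["fields"] = {"total": float(match.group(1))} if match else {}
    return state


def flag_anomaly(state: State) -> State:
    # Hypothetical stub for the Isolation Forest: a fixed rule, not a learned model
    total = state["fields"].get("total")
    state["anomaly"] = total is None or total > 1000
    return state


def run_pipeline(state: State, stages: List[Callable[[State], State]]) -> State:
    for stage in stages:
        state = stage(state)
        if state.get("is_receipt") is False:
            break  # non-receipts skip the remaining stages
    return state


STAGES = [classify, extract_fields, flag_anomaly]
result = run_pipeline(
    {"receipt_prob": 0.93, "ocr_text": "SUBTOTAL: $10.00\nTAX: $0.80\nTOTAL: $10.80"},
    STAGES,
)
print(result["fields"], result["anomaly"])  # {'total': 10.8} False
```

The early `break` plays the role of a conditional branch in the graph; in the real workflow each stub would call the corresponding model.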

## Setup & Imports
Install the required packages and import what we need.

In [1]:
# Install packages
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q transformers datasets easyocr langchain langgraph streamlit gradio
!pip install -q pillow opencv-python scikit-learn pandas numpy
!pip install -q accelerate bitsandbytes
!pip install -q albumentations

# Check if we have GPU
import torch
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU, using CPU")

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.9/2.9 MB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.0/9.0 MB[0m [31m47.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m60.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m180.7/180.7 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m978.2/978.2 kB[0m [31m24.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m300.6/300.6 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[?25hGPU: Tesla T4


In [2]:
# Check for previously saved model weights (.pt files)
# If they exist, the training steps can be skipped

import os

# Absolute local paths; edit these to match your machine
MODELS_DIR = '/Users/shruthisubramanian/Downloads/models'
DATA_DIR = '/Users/shruthisubramanian/Downloads/data'

# Create directories
os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(os.path.join(DATA_DIR, 'synthetic'), exist_ok=True)

MODEL_FILES = {
    'rvl_classifier.pt': 'ViT Document Classifier (~21 MB)',
    'layoutlm_extractor.pt': 'LayoutLMv3 Field Extractor (~478 MB)',
    'anomaly_detector.pt': 'Anomaly Detector (~2 MB)'
}

# Check what models we have
existing_models = []
missing_models = []

for filename, description in MODEL_FILES.items():
    local_path = os.path.join(MODELS_DIR, filename)
    if os.path.exists(local_path):
        size_mb = os.path.getsize(local_path) / (1024*1024)
        existing_models.append((filename, size_mb))
    else:
        missing_models.append((filename, description))

print(f"Models directory: {MODELS_DIR}")
if existing_models:
    print("Found models:")
    for name, size in existing_models:
        print(f"  {name} ({size:.1f} MB)")

if missing_models:
    print("Missing models (will be created during training):")
    for name, desc in missing_models:
        print(f"  {name} - {desc}")

Models directory: /Users/shruthisubramanian/Downloads/models
Missing models (will be created during training):
  rvl_classifier.pt - ViT Document Classifier (~21 MB)
  layoutlm_extractor.pt - LayoutLMv3 Field Extractor (~478 MB)
  anomaly_detector.pt - Anomaly Detector (~2 MB)


In [None]:
# All our imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from transformers import (
    ViTForImageClassification,
    ViTImageProcessor,
    LayoutLMv3ForTokenClassification,
    LayoutLMv3Processor,
    AutoTokenizer
)
from datasets import load_dataset
import easyocr
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import pandas as pd
from sklearn.ensemble import IsolationForest
from sklearn.metrics import classification_report, f1_score
from langgraph.graph import StateGraph, END
from typing import TypedDict, Dict, Any, List, Optional
import json
import os
import random
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings('ignore')

# Check whether albumentations (used for augmentation) is installed
try:
    from albumentations.pytorch import ToTensorV2
    ALBUMENTATIONS_AVAILABLE = True
except ImportError:
    ALBUMENTATIONS_AVAILABLE = False

# Set device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Settings
CONFIG = {
    # Data settings
    'num_synthetic_receipts': 200,
    'real_data_samples': 500,

    # Model settings - using ViT-Tiny for speed
    'vit_model': 'WinKawaks/vit-tiny-patch16-224',
    'vit_epochs': 3,
    'vit_lr': 3e-4,

    # LayoutLM settings
    'layoutlm_epochs': 2,
    'layoutlm_lr': 5e-5,
    'layoutlm_train_samples': 50,

    # Training settings
    'batch_size': 32,
    'early_stopping_patience': 2,
    'augmentation_probability': 0.3,
    'class_weight_receipt': 1.5,
    'warmup_ratio': 0.1,
}
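
`CONFIG` fixes `class_weight_receipt` at 1.5. If the label mix changes (the synthetic set, for instance, ends up roughly 3:1 receipts to non-receipts), weights could instead be derived from the label counts. The sketch below shows one common choice, inverse-frequency ("balanced") weights. It is an alternative the notebook does not use, and `balanced_class_weights` is a hypothetical helper.

```python
from collections import Counter
from typing import Dict, Iterable


def balanced_class_weights(labels: Iterable[int]) -> Dict[int, float]:
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = sum(counts.values())
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in sorted(counts.items())}


# Example: 200 synthetic receipts plus 200 // 3 = 66 generated non-receipts
labels = [1] * 200 + [0] * 66
print(balanced_class_weights(labels))
```

The rarer class gets the larger weight. For `nn.CrossEntropyLoss`, the values would be ordered by class index into a float tensor.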

## Data Prep
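Load document datasets from Hugging Face (CORD and SROIE receipts, FUNSD forms as non-receipt examples, and optionally RVL-CDIP), cache them locally, and generate synthetic receipts for training.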
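
The loading cell assigns one binary label per source. Every CORD and SROIE image counts as a receipt, and every FUNSD form counts as a non-receipt. RVL-CDIP (off by default) is split by class, with budget (10) and invoice (11) treated as positives. The sketch below restates that rule compactly. `binary_label`, `SOURCE_LABELS` and `RVL_POSITIVE_CLASSES` are illustrative names, not part of the notebook.

```python
from typing import Optional

# Binary scheme: 1 = receipt-like document, 0 = other document
SOURCE_LABELS = {"cord": 1, "sroie": 1, "funsd": 0}
RVL_POSITIVE_CLASSES = {10, 11}  # RVL-CDIP 'budget' and 'invoice'


def binary_label(source: str, rvl_class: Optional[int] = None) -> int:
    """Map a sample's source (and RVL-CDIP class id, if any) to 0/1."""
    if source == "rvl_cdip":
        if rvl_class is None:
            raise ValueError("RVL-CDIP samples need their original class id")
        return 1 if rvl_class in RVL_POSITIVE_CLASSES else 0
    return SOURCE_LABELS[source]


print(binary_label("cord"), binary_label("funsd"), binary_label("rvl_cdip", 11))
```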

In [None]:
# Load datasets from HuggingFace and cache them locally

import os
import pickle
from pathlib import Path

# Dataset cache directory
DATASET_CACHE_DIR = Path("data/dataset_cache")
DATASET_CACHE_DIR.mkdir(parents=True, exist_ok=True)

# RVL-CDIP has 16 document classes and no dedicated receipt class;
# budget (10) and invoice (11) are treated as the receipt-like classes
RVL_LABELS = {
    0: 'letter', 1: 'form', 2: 'email', 3: 'handwritten',
    4: 'advertisement', 5: 'scientific_report', 6: 'scientific_publication',
    7: 'specification', 8: 'file_folder', 9: 'news_article',
    10: 'budget', 11: 'invoice', 12: 'presentation', 13: 'questionnaire',
    14: 'resume', 15: 'memo'
}

RECEIPT_LABELS = [10, 11]
RECEIPT_LABEL_NAMES = ['budget', 'invoice']

loaded_datasets = {}

def check_cached_dataset(name):
    """See if we already downloaded this one"""
    cache_file = DATASET_CACHE_DIR / f"{name}_cache.pkl"
    return cache_file.exists()

def save_dataset_cache(name, train_data, val_data, metadata=None):
    """Save dataset locally so we don't have to download again"""
    cache_file = DATASET_CACHE_DIR / f"{name}_cache.pkl"
    cache_data = {
        'train': train_data,
        'val': val_data,
        'metadata': metadata or {}
    }
    with open(cache_file, 'wb') as f:
        pickle.dump(cache_data, f)

def load_dataset_cache(name):
    """Load from local cache"""
    cache_file = DATASET_CACHE_DIR / f"{name}_cache.pkl"
    with open(cache_file, 'rb') as f:
        cache_data = pickle.load(f)
    return cache_data['train'], cache_data['val'], cache_data.get('metadata', {})

# Initialize
dataset = None
val_dataset = None
use_synthetic_for_classification = False
real_images = []

from datasets import load_dataset, concatenate_datasets

# CORD - Receipt dataset
cord_train, cord_val = None, None
if check_cached_dataset("cord"):
    try:
        cord_train, cord_val, _ = load_dataset_cache("cord")
        loaded_datasets['cord'] = {'train': len(cord_train), 'val': len(cord_val)}
    except Exception as e:
        cord_train, cord_val = None, None

if cord_train is None:
    try:
        cord_train = load_dataset("naver-clova-ix/cord-v2", split="train")
        cord_val = load_dataset("naver-clova-ix/cord-v2", split="validation")

        def add_cord_label(example):
            example['label'] = 1
            example['dataset_source'] = 'cord'
            return example

        cord_train = cord_train.map(add_cord_label)
        cord_val = cord_val.map(add_cord_label)
        save_dataset_cache("cord", cord_train, cord_val, {'type': 'receipt'})
        loaded_datasets['cord'] = {'train': len(cord_train), 'val': len(cord_val)}
    except Exception as e:
        print(f"Could not load CORD: {e}")

# SROIE - More receipt data
sroie_train, sroie_val = None, None
if check_cached_dataset("sroie"):
    try:
        sroie_train, sroie_val, _ = load_dataset_cache("sroie")
        loaded_datasets['sroie'] = {'train': len(sroie_train), 'val': len(sroie_val)}
    except Exception as e:
        pass

if sroie_train is None:
    try:
        sroie_full = load_dataset("darentang/sroie", split="train", trust_remote_code=True)
        sroie_split = sroie_full.train_test_split(test_size=0.15, seed=42)
        sroie_train = sroie_split['train']
        sroie_val = sroie_split['test']

        def add_sroie_label(example):
            example['label'] = 1
            example['dataset_source'] = 'sroie'
            return example

        sroie_train = sroie_train.map(add_sroie_label)
        sroie_val = sroie_val.map(add_sroie_label)
        save_dataset_cache("sroie", sroie_train, sroie_val, {'type': 'receipt'})
        loaded_datasets['sroie'] = {'train': len(sroie_train), 'val': len(sroie_val)}
    except Exception as e:
        print(f"Could not load SROIE: {e}")

# FUNSD - Form data (NOT receipts, for balance)
funsd_train, funsd_val = None, None
if check_cached_dataset("funsd"):
    try:
        funsd_train, funsd_val, _ = load_dataset_cache("funsd")
        loaded_datasets['funsd'] = {'train': len(funsd_train), 'val': len(funsd_val)}
    except Exception as e:
        pass

if funsd_train is None:
    try:
        funsd_train = load_dataset("nielsr/funsd", split="train", trust_remote_code=True)
        funsd_val = load_dataset("nielsr/funsd", split="test", trust_remote_code=True)

        def add_funsd_label(example):
            example['label'] = 0
            example['dataset_source'] = 'funsd'
            return example

        funsd_train = funsd_train.map(add_funsd_label)
        funsd_val = funsd_val.map(add_funsd_label)
        save_dataset_cache("funsd", funsd_train, funsd_val, {'type': 'form'})
        loaded_datasets['funsd'] = {'train': len(funsd_train), 'val': len(funsd_val)}
    except Exception as e:
        print(f"Could not load FUNSD: {e}")

# RVL-CDIP - large general document dataset (optional; off by default because
# load_dataset downloads the full split before it is subsampled)
rvl_train, rvl_val = None, None
LOAD_RVL_CDIP = False

if LOAD_RVL_CDIP:
    if check_cached_dataset("rvl_cdip"):
        try:
            rvl_train, rvl_val, _ = load_dataset_cache("rvl_cdip")
            loaded_datasets['rvl_cdip'] = {'train': len(rvl_train), 'val': len(rvl_val)}
        except Exception as e:
            pass

    if rvl_train is None:
        try:
            rvl_train = load_dataset("aharley/rvl_cdip", split="train", trust_remote_code=True)
            rvl_val = load_dataset("aharley/rvl_cdip", split="test", trust_remote_code=True)

            num_samples = min(CONFIG['real_data_samples'], len(rvl_train))
            rvl_train = rvl_train.shuffle(seed=42).select(range(num_samples))
            rvl_val = rvl_val.shuffle(seed=42).select(range(num_samples // 4))

            def map_rvl_label(example):
                example['original_label'] = example['label']
                example['label'] = 1 if example['label'] in [10, 11] else 0
                example['dataset_source'] = 'rvl_cdip'
                return example

            rvl_train = rvl_train.map(map_rvl_label)
            rvl_val = rvl_val.map(map_rvl_label)
            save_dataset_cache("rvl_cdip", rvl_train, rvl_val, {'type': 'mixed'})
            loaded_datasets['rvl_cdip'] = {'train': len(rvl_train), 'val': len(rvl_val)}
        except Exception as e:
            print(f"Could not load RVL-CDIP: {e}")

# Combine datasets
train_datasets = []
val_datasets = []

if cord_train is not None:
    train_datasets.append(('cord', cord_train, 1))
    val_datasets.append(('cord', cord_val, 1))

if sroie_train is not None:
    train_datasets.append(('sroie', sroie_train, 1))
    val_datasets.append(('sroie', sroie_val, 1))

if funsd_train is not None:
    train_datasets.append(('funsd', funsd_train, 0))
    val_datasets.append(('funsd', funsd_val, 0))

if rvl_train is not None:
    train_datasets.append(('rvl_cdip', rvl_train, 'mixed'))
    val_datasets.append(('rvl_cdip', rvl_val, 'mixed'))

# Print summary
for name, ds, label_type in train_datasets:
    count = len(ds)
    if label_type == 1:
        print(f"{name.upper()}: {count} samples (receipts)")
    elif label_type == 0:
        print(f"{name.upper()}: {count} samples (non-receipts)")
    else:
        print(f"{name.upper()}: {count} samples (mixed)")

# Choose the primary dataset. Note: the sources are NOT merged here;
# CORD (receipts only) is used when available, otherwise the first loaded source.
if train_datasets:
    if cord_train is not None:
        dataset = cord_train
        val_dataset = cord_val
    else:
        dataset = train_datasets[0][1]
        val_dataset = val_datasets[0][1]

# Without FUNSD there are no non-receipt examples, so fall back to synthetic data
use_synthetic_for_classification = funsd_train is None

AVAILABLE_DATASETS = {
    'cord': (cord_train, cord_val),
    'sroie': (sroie_train, sroie_val),
    'funsd': (funsd_train, funsd_val),
    'rvl_cdip': (rvl_train, rvl_val)
}
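
Each dataset above repeats the same sequence: check the cache, otherwise download, relabel and save. One way to factor that out is sketched below. The `cached` helper is hypothetical and, like `save_dataset_cache`, simply pickles whatever the builder returns. Errors from the builder propagate, so the caller can report them instead of passing silently.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Callable


def cached(name: str, builder: Callable[[], Any], cache_dir: Path) -> Any:
    """Return the pickled result for `name`, building and saving it on a cache miss."""
    cache_file = cache_dir / f"{name}_cache.pkl"
    if cache_file.exists():
        try:
            with open(cache_file, "rb") as f:
                return pickle.load(f)
        except Exception as e:  # corrupt or incompatible cache: rebuild it
            print(f"Ignoring unreadable cache for {name}: {e}")
    result = builder()  # e.g. download + relabel; exceptions reach the caller
    cache_dir.mkdir(parents=True, exist_ok=True)
    with open(cache_file, "wb") as f:
        pickle.dump(result, f)
    return result


# Usage with a toy builder that records how often it runs
calls = []


def build():
    calls.append(1)
    return {"train": [1, 2, 3], "val": [4]}


with tempfile.TemporaryDirectory() as tmp:
    first = cached("toy", build, Path(tmp))
    second = cached("toy", build, Path(tmp))  # served from the pickle
print(first == second, len(calls))  # True 1
```

With such a helper, each dataset block would reduce to a call like `cached("cord", build_cord, DATASET_CACHE_DIR)`, where `build_cord` is a hypothetical function returning the relabeled train/val pair.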

In [None]:
# Make fake receipts for training

class EnhancedReceiptGenerator:
    """Creates realistic looking fake receipts."""

    def __init__(self):
        self.vendors = [
            "WALMART", "TARGET", "COSTCO", "STARBUCKS", "MCDONALD'S",
            "AMAZON", "BEST BUY", "HOME DEPOT", "WHOLE FOODS", "CVS PHARMACY",
            "WALGREENS", "TRADER JOE'S", "KROGER", "SAFEWAY", "7-ELEVEN",
            "SUBWAY", "CHIPOTLE", "DOMINO'S", "PIZZA HUT", "TACO BELL",
            "WENDY'S", "BURGER KING", "DUNKIN", "PANERA BREAD", "CHICK-FIL-A"
        ]

        self.items = [
            ("COFFEE REG", 4.99), ("SANDWICH TKY", 8.49), ("MILK 1GAL", 3.99),
            ("BREAD WHL WHT", 2.49), ("EGGS LARGE 12", 5.99), ("CHICKEN BRST", 12.99),
            ("PASTA PENNE", 1.99), ("CHEESE CHEDDR", 6.49), ("APPLES FUJI", 4.49),
            ("ORANGE JUICE", 5.99), ("SOAP DISH LIQ", 3.49), ("PAPER TOWELS", 8.99),
            ("AA BATTERIES", 9.99), ("HDMI CABLE 6F", 15.99), ("USB CHARGER", 12.99),
            ("WATER 24PK", 4.99), ("CHIPS LAYS", 3.49), ("SODA 12PK", 5.99),
            ("YOGURT GREEK", 1.29), ("CEREAL CHRIOS", 4.49), ("BACON 1LB", 7.99),
            ("BUTTER UNSALT", 4.99), ("LETTUCE ROMN", 2.99), ("TOMATOES", 3.49),
            ("ONIONS YLW", 1.99), ("POTATOES 5LB", 4.99), ("RICE LONG GR", 3.99),
        ]

        self.formats = ['standard', 'minimal', 'detailed', 'wide', 'narrow']
        self.fonts = self._load_fonts()

    def _load_fonts(self):
        """Load available system fonts with fallback"""
        font_configs = []
        font_paths = [
            "/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf",
            "/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf",
            "/usr/share/fonts/truetype/freefont/FreeMono.ttf",
        ]

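        for path in font_paths:
            try:
                font = ImageFont.truetype(path, 14)
            except OSError:
                continue  # font not installed on this system
            # Bold files are named differently per family (e.g. DejaVuSansMono-Bold.ttf,
            # LiberationMono-Bold.ttf); fall back to the regular face if not found
            if 'Regular' in path:
                bold_path = path.replace('Regular', 'Bold')
            else:
                bold_path = path.replace('.ttf', '-Bold.ttf')
            try:
                font_bold = ImageFont.truetype(bold_path, 16)
            except OSError:
                font_bold = font
            font_configs.append((font, font_bold))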

        if not font_configs:
            default = ImageFont.load_default()
            font_configs.append((default, default))

        return font_configs

    def _random_date(self):
        """Generate random date in various formats"""
        days_ago = random.randint(0, 730)
        date = datetime.now() - timedelta(days=days_ago)
        formats = ["%m/%d/%Y", "%m/%d/%y", "%Y-%m-%d"]
        return date.strftime(random.choice(formats))

    def _random_time(self):
        """Generate random time"""
        hour = random.randint(6, 23)
        minute = random.randint(0, 59)

        if random.random() > 0.5:
            hour_12 = hour if hour <= 12 else hour - 12
            hour_12 = 12 if hour_12 == 0 else hour_12
            period = "AM" if hour < 12 else "PM"
            return f"{hour_12}:{minute:02d} {period}"
        else:
            return f"{hour:02d}:{minute:02d}"

    def generate_receipt(self, format_type=None, add_noise=True, add_wrinkles=True, save_path=None):
        """Generate a synthetic receipt with realistic variations"""
        format_type = format_type or random.choice(self.formats)
        font, font_bold = random.choice(self.fonts)

        # Variable dimensions based on format
        if format_type == 'narrow':
            width = random.randint(280, 320)
            height = random.randint(500, 700)
        elif format_type == 'wide':
            width = random.randint(450, 500)
            height = random.randint(400, 550)
        else:
            width = random.randint(350, 420)
            height = random.randint(500, 750)

        # Background color
        bg_value = random.randint(245, 255)
        bg_color = (bg_value, bg_value, random.randint(bg_value-5, bg_value))
        img = Image.new('RGB', (width, height), color=bg_color)
        draw = ImageDraw.Draw(img)

        # Text color
        text_value = random.randint(0, 40)
        text_color = (text_value, text_value, text_value)

        # Generate receipt content
        vendor = random.choice(self.vendors)
        date = self._random_date()
        time = self._random_time()

        num_items = random.randint(2, min(12, len(self.items)))
        selected_items = random.sample(self.items, num_items)

        receipt_items = []
        subtotal = 0.0
        for name, base_price in selected_items:
            qty = random.randint(1, 4)
            price = round(base_price * random.uniform(0.85, 1.15), 2)
            line_total = round(price * qty, 2)
            subtotal += line_total
            receipt_items.append((name, qty, price, line_total))
        subtotal = round(subtotal, 2)  # drop float accumulation error

        tax_rate = random.choice([0.0, 0.04, 0.0625, 0.0725, 0.0825, 0.095, 0.10])
        tax = round(subtotal * tax_rate, 2)
        total = round(subtotal + tax, 2)

        # Draw receipt
        y_pos = random.randint(15, 30)

        # Header
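        # Center using the rendered text width rather than a fixed 8 px per character
        vendor_x = max(0, (width - int(draw.textlength(vendor, font=font_bold))) // 2)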
        draw.text((vendor_x, y_pos), vendor, fill=text_color, font=font_bold)
        y_pos += 35

        # Address (sometimes)
        if format_type in ['detailed', 'standard'] and random.random() > 0.5:
            address = f"{random.randint(100, 9999)} {random.choice(['MAIN', 'OAK', 'ELM', 'PARK'])} ST"
            draw.text((20, y_pos), address, fill=text_color, font=font)
            y_pos += 20

        # Date and time
        if format_type == 'minimal':
            draw.text((20, y_pos), f"{date}", fill=text_color, font=font)
        else:
            draw.text((20, y_pos), f"Date: {date} Time: {time}", fill=text_color, font=font)
        y_pos += 25

        # Separator
        sep_char = random.choice(["-", "=", "*"])
        draw.text((20, y_pos), sep_char * (width // 10), fill=text_color, font=font)
        y_pos += 20

        # Items
        for name, qty, price, item_total in receipt_items:
            if format_type == 'minimal':
                line = f"{name} ${item_total:.2f}"
            elif format_type == 'detailed':
                draw.text((20, y_pos), name, fill=text_color, font=font)
                y_pos += 18
                line = f" {qty} @ ${price:.2f} = ${item_total:.2f}"
            else:
                line = f"{name:<16} {qty}x${price:.2f} ${item_total:.2f}"

            draw.text((20, y_pos), line, fill=text_color, font=font)
            y_pos += 20

        # Separator
        y_pos += 5
        draw.text((20, y_pos), sep_char * (width // 10), fill=text_color, font=font)
        y_pos += 20

        # Totals
        draw.text((20, y_pos), f"SUBTOTAL:{' ' * 10}${subtotal:.2f}", fill=text_color, font=font)
        y_pos += 20

        if tax_rate > 0:
            tax_label = f"TAX ({tax_rate*100:.2f}%)" if format_type == 'detailed' else "TAX"
            draw.text((20, y_pos), f"{tax_label}:{' ' * 8}${tax:.2f}", fill=text_color, font=font)
            y_pos += 20

        draw.text((20, y_pos), f"TOTAL:{' ' * 12}${total:.2f}", fill=text_color, font=font_bold)
        y_pos += 30

        # Footer
        footers = ["Thank you!", "THANK YOU FOR SHOPPING!", "Have a nice day!", "Please come again", "Save this receipt"]
        footer = random.choice(footers)
        footer_x = max(0, (width - int(draw.textlength(footer, font=font))) // 2)
        draw.text((footer_x, y_pos), footer, fill=text_color, font=font)

        # Add realistic artifacts
        if add_noise:
            img = self._add_noise(img)
        if add_wrinkles:
            img = self._add_wrinkles(img)

        # Random rotation
        if random.random() > 0.7:
            angle = random.uniform(-3, 3)
            img = img.rotate(angle, fillcolor=bg_color, expand=False)

        ground_truth = {
            'vendor': vendor,
            'date': date,
            'time': time,
            'items': receipt_items,
            'subtotal': subtotal,
            'tax': tax,
            'total': total,
            'tax_rate': tax_rate,
            'format': format_type,
            'num_items': len(receipt_items)
        }

        if save_path:
            img.save(save_path)

        return img, ground_truth

    def _add_noise(self, img, intensity=None):
        """Add some random noise to make it look scanned"""
        intensity = intensity or random.uniform(2, 10)
        arr = np.array(img, dtype=np.float32)
        noise = np.random.normal(0, intensity, arr.shape)
        arr = np.clip(arr + noise, 0, 255).astype(np.uint8)
        return Image.fromarray(arr)

    def _add_wrinkles(self, img):
        """Add some fold lines and shadows"""
        arr = np.array(img, dtype=np.float32)
        h, w = arr.shape[:2]

        num_folds = random.randint(0, 3)
        for _ in range(num_folds):
            if random.random() > 0.5:
                y = random.randint(h // 5, 4 * h // 5)
                thickness = random.randint(1, 3)
                darkness = random.uniform(0.85, 0.95)
                arr[y-thickness:y+thickness, :] *= darkness
            else:
                x = random.randint(w // 5, 4 * w // 5)
                thickness = random.randint(1, 3)
                darkness = random.uniform(0.85, 0.95)
                arr[:, x-thickness:x+thickness] *= darkness

        if random.random() > 0.6:
            shadow_width = random.randint(5, 15)
            shadow_strength = random.uniform(0.9, 0.98)
            arr[:, :shadow_width] *= shadow_strength
            arr[:, -shadow_width:] *= shadow_strength

        return Image.fromarray(np.clip(arr, 0, 255).astype(np.uint8))

    def generate_batch(self, num_samples, save_dir=None):
        """Make a bunch of fake receipts"""
        receipts = []
        ground_truths = []

        for i in range(num_samples):
            save_path = f"{save_dir}/receipt_{i:04d}.png" if save_dir else None
            img, gt = self.generate_receipt(save_path=save_path)
            receipts.append(img)
            ground_truths.append(gt)

        return receipts, ground_truths

# Make the fake receipts
generator = EnhancedReceiptGenerator()

synthetic_receipts, synthetic_ground_truth = generator.generate_batch(
    num_samples=CONFIG['num_synthetic_receipts'],
    save_dir=None
)

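# Count how many receipts use each layout format
formats_used = {}
for gt in synthetic_ground_truth:
    fmt = gt['format']
    formats_used[fmt] = formats_used.get(fmt, 0) + 1
print(f"Generated {len(synthetic_receipts)} synthetic receipts: {formats_used}")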
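
The generator computes prices, tax and totals with floats. Sums can therefore carry representation error (`0.1 + 0.2` gives `0.30000000000000004`), and `round()` on floats can round a half cent down (`round(2.675, 2)` returns `2.67`). If exact cents matter, for example when these ground truths are used to check extracted totals, the standard-library `decimal` module avoids both problems. The sketch below makes a few assumptions: `receipt_totals` and `totals_consistent` are hypothetical helpers, and round-half-up is an assumed tax rounding policy.

```python
from decimal import Decimal, ROUND_HALF_UP
from typing import List, Tuple

CENT = Decimal("0.01")


def receipt_totals(items: List[Tuple[str, int, str]], tax_rate: str) -> Tuple[Decimal, Decimal, Decimal]:
    """Return (subtotal, tax, total) in exact cents; prices and rate are given as strings."""
    subtotal = sum((Decimal(price) * qty for _, qty, price in items), Decimal("0"))
    subtotal = subtotal.quantize(CENT, rounding=ROUND_HALF_UP)
    tax = (subtotal * Decimal(tax_rate)).quantize(CENT, rounding=ROUND_HALF_UP)
    return subtotal, tax, subtotal + tax


def totals_consistent(subtotal, tax, total, tol: Decimal = CENT) -> bool:
    """Check subtotal + tax == total within one cent (accepts floats or Decimals)."""
    diff = Decimal(str(subtotal)) + Decimal(str(tax)) - Decimal(str(total))
    return abs(diff) <= tol


print(receipt_totals([("COFFEE REG", 2, "4.99"), ("MILK 1GAL", 1, "3.99")], "0.0825"))
```

`totals_consistent(gt['subtotal'], gt['tax'], gt['total'])` could then be run over `synthetic_ground_truth` as a sanity check.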

In [None]:
# Data augmentation - distort images a bit so the model generalizes better

try:
    import albumentations as A
    ALBUMENTATIONS_AVAILABLE = True
except ImportError:
    ALBUMENTATIONS_AVAILABLE = False

class ReceiptAugmentation:
    """Messes up images in realistic ways - rotation, blur, shadows, etc."""

    def __init__(self, p=0.5):
        self.p = p

        if ALBUMENTATIONS_AVAILABLE:
            self.transform = A.Compose([
                A.OneOf([
                    A.Rotate(limit=15, p=0.5, border_mode=cv2.BORDER_CONSTANT, value=(255, 255, 255)),
                    A.Perspective(scale=(0.02, 0.05), p=0.3),
                    A.Affine(shear=(-10, 10), p=0.3, border_mode=cv2.BORDER_CONSTANT, cval=(255, 255, 255)),
                ], p=0.5),
                A.OneOf([
                    A.GaussNoise(var_limit=(10, 50), p=0.4),
                    A.ISONoise(color_shift=(0.01, 0.05), p=0.3),
                    A.MultiplicativeNoise(multiplier=(0.9, 1.1), p=0.3),
                ], p=0.4),
                A.OneOf([
                    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
                    A.RandomGamma(gamma_limit=(80, 120), p=0.3),
                    A.CLAHE(clip_limit=2.0, p=0.3),
                ], p=0.5),
                A.OneOf([
                    A.GaussianBlur(blur_limit=(3, 5), p=0.3),
                    A.MotionBlur(blur_limit=3, p=0.2),
                    A.MedianBlur(blur_limit=3, p=0.2),
                ], p=0.3),
                A.OneOf([
                    A.RandomShadow(shadow_roi=(0, 0, 1, 1), p=0.2),
                    A.CoarseDropout(max_holes=5, max_height=15, max_width=15, fill_value=220, p=0.2),
                ], p=0.2),
                A.OneOf([
                    A.ToGray(p=0.3),
                    A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=15, val_shift_limit=10, p=0.3),
                ], p=0.2),
                A.Resize(224, 224),
            ])
            self.val_transform = A.Compose([A.Resize(224, 224)])
        else:
            self.transform = None
            self.val_transform = None

    def __call__(self, image, is_training=True):
        """Run the augmentation on an image"""
        if isinstance(image, Image.Image):
            image = np.array(image)

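        # Augment only a fraction p of training images; the rest are just resized
        if is_training and self.transform is not None and random.random() < self.p: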
            augmented = self.transform(image=image)
            return augmented['image']
        elif self.val_transform is not None:
            augmented = self.val_transform(image=image)
            return augmented['image']
        else:
            img = Image.fromarray(image) if isinstance(image, np.ndarray) else image
            return np.array(img.resize((224, 224)))

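# The datasets above already use binary labels (1 = receipt, 0 = other),
# so this replaces the RVL-CDIP class IDs assigned to RECEIPT_LABELS earlier
RECEIPT_LABELS = {1}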

class AugmentedReceiptDataset(Dataset):
    """Wraps our data for PyTorch training."""

    def __init__(self, dataset, processor, augmentation=None, is_training=True, is_receipt_labels=None):
        self.dataset = dataset
        self.processor = processor
        self.augmentation = augmentation
        self.is_training = is_training
        self.is_receipt_labels = is_receipt_labels if is_receipt_labels is not None else RECEIPT_LABELS
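        # ImageNet mean/std; the ViT checkpoint's processor may expect different
        # values (see processor.image_mean / processor.image_std)
        self.mean = np.array([0.485, 0.456, 0.406])
        self.std = np.array([0.229, 0.224, 0.225])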

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image']

        if image.mode != 'RGB':
            image = image.convert('RGB')

        if self.augmentation is not None:
            image_np = self.augmentation(image, is_training=self.is_training)
        else:
            image_np = np.array(image.resize((224, 224)))

        image_np = image_np.astype(np.float32) / 255.0
        image_np = (image_np - self.mean) / self.std
        image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).float()
        label = 1 if item['label'] in self.is_receipt_labels else 0

        return {
            'pixel_values': image_tensor,
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Set up augmentation
receipt_augmentation = ReceiptAugmentation(p=CONFIG['augmentation_probability'])

try:
    vit_processor = ViTImageProcessor.from_pretrained(CONFIG.get('vit_model', 'WinKawaks/vit-tiny-patch16-224'))
except Exception:  # fall back to the ViT-Base processor if ViT-Tiny's cannot be loaded
    vit_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

class SyntheticReceiptDataset(Dataset):
    """Dataset for when we only have fake receipts."""

    def __init__(self, receipts, ground_truths, augmentation=None, is_training=True, include_negatives=True):
        self.receipts = receipts
        self.ground_truths = ground_truths
        self.augmentation = augmentation
        self.is_training = is_training
        self.samples = []

        for i, (img, gt) in enumerate(zip(receipts, ground_truths)):
            self.samples.append({'image': img, 'label': 1, 'ground_truth': gt})

        if include_negatives:
            num_negatives = len(receipts) // 3
            for i in range(num_negatives):
                neg_img = self._generate_non_receipt()
                self.samples.append({'image': neg_img, 'label': 0, 'ground_truth': None})

        self.mean = np.array([0.485, 0.456, 0.406])
        self.std = np.array([0.229, 0.224, 0.225])

        print(f"Created dataset with {sum(1 for s in self.samples if s['label']==1)} receipts, "
              f"{sum(1 for s in self.samples if s['label']==0)} non-receipts")

    def _generate_non_receipt(self):
        """Make a random non-receipt image"""
        width, height = 400, 600
        img_type = random.choice(['blank', 'noise', 'shapes', 'text'])

        if img_type == 'blank':
            color = tuple(random.randint(180, 255) for _ in range(3))
            img = Image.new('RGB', (width, height), color=color)
        elif img_type == 'noise':
            arr = np.random.randint(150, 255, (height, width, 3), dtype=np.uint8)
            img = Image.fromarray(arr)
        elif img_type == 'shapes':
            img = Image.new('RGB', (width, height), color='white')
            draw = ImageDraw.Draw(img)
            for _ in range(random.randint(3, 10)):
                shape = random.choice(['rectangle', 'ellipse', 'line'])
                color = tuple(random.randint(0, 200) for _ in range(3))
                x1, y1 = random.randint(0, width), random.randint(0, height)
                x2, y2 = random.randint(0, width), random.randint(0, height)
                if shape == 'rectangle':
                    draw.rectangle([min(x1,x2), min(y1,y2), max(x1,x2), max(y1,y2)], outline=color)
                elif shape == 'ellipse':
                    draw.ellipse([min(x1,x2), min(y1,y2), max(x1,x2), max(y1,y2)], outline=color)
                else:
                    draw.line([x1, y1, x2, y2], fill=color, width=2)
        else:
            img = Image.new('RGB', (width, height), color='white')
            draw = ImageDraw.Draw(img)
            font = ImageFont.load_default()
            words = ["Lorem", "ipsum", "dolor", "sit", "amet", "document", "page", "file"]
            for _ in range(random.randint(5, 15)):
                text = " ".join(random.choices(words, k=random.randint(2, 6)))
                x = random.randint(10, width - 100)
                y = random.randint(10, height - 30)
                draw.text((x, y), text, fill='black', font=font)

        return img

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        image = sample['image']

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        if self.augmentation is not None:
            image_np = self.augmentation(image, is_training=self.is_training)
        else:
            image_np = np.array(image.resize((224, 224)))

        image_np = image_np.astype(np.float32) / 255.0
        image_np = (image_np - self.mean) / self.std
        image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).float()

        return {
            'pixel_values': image_tensor,
            'labels': torch.tensor(sample['label'], dtype=torch.long)
        }
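
Both dataset classes normalize with ImageNet mean and std hard-coded, while `vit_processor` is created in this cell but not used by either dataset. A ViT image processor stores the statistics its checkpoint expects as `image_mean` and `image_std`. If those differ from the hard-coded values, the same pixel reaches the model as a different input. The numpy sketch below shows the effect. `to_normalized_chw` is a hypothetical helper, and the 0.5 values are an illustrative alternative, not a claim about what the `WinKawaks/vit-tiny-patch16-224` checkpoint uses.

```python
import numpy as np


def to_normalized_chw(image_hwc: np.ndarray, mean, std) -> np.ndarray:
    """uint8 HxWx3 image -> float32 3xHxW array: scale to [0, 1], then standardize per channel."""
    x = image_hwc.astype(np.float32) / 255.0
    x = (x - np.asarray(mean, dtype=np.float32)) / np.asarray(std, dtype=np.float32)
    return np.transpose(x, (2, 0, 1))  # same axis order as tensor.permute(2, 0, 1)


white = np.full((224, 224, 3), 255, dtype=np.uint8)
imagenet = to_normalized_chw(white, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
half = to_normalized_chw(white, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
print(imagenet.shape, imagenet[:, 0, 0], half[:, 0, 0])
```

To follow the checkpoint's own settings, one could pass `vit_processor.image_mean` and `vit_processor.image_std` as `mean` and `std`.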

## ViT Classifier
Train a Vision Transformer to tell receipts from other docs.
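
`train()` below builds `OneCycleLR` with `max_lr = lr * 10`. With `CONFIG['vit_lr'] = 3e-4`, the learning rate therefore peaks at 3e-3, not 3e-4. The plain-Python sketch below lets you inspect the curve without a model or optimizer. It follows the two-phase cosine schedule described in the PyTorch `OneCycleLR` documentation (`initial_lr = max_lr / div_factor`, `min_lr = initial_lr / final_div_factor`). Treat it as an approximation for inspection: `one_cycle_lr` is a hypothetical helper, not a replacement for the scheduler.

```python
import math


def one_cycle_lr(step: int, total_steps: int, max_lr: float, pct_start: float,
                 div_factor: float = 25.0, final_div_factor: float = 1e4) -> float:
    """Approximate learning rate of a two-phase cosine one-cycle schedule at `step`."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1  # last step of the warm-up phase

    def cos_anneal(start: float, end: float, pct: float) -> float:
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    return cos_anneal(max_lr, min_lr, (step - warmup_end) / (total_steps - 1 - warmup_end))


lr = 3e-4  # CONFIG['vit_lr']
schedule = dict(total_steps=100, max_lr=lr * 10, pct_start=0.1,
                div_factor=25, final_div_factor=1000)
for step in (0, 9, 50, 99):
    print(step, f"{one_cycle_lr(step, **schedule):.2e}")
```

If the intent was for `vit_lr` to be the peak rate, passing `max_lr=lr` to the scheduler would match that reading.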

In [None]:
# ViT document classifier

class DocumentClassifier:
    """Uses ViT-Tiny to classify docs as receipt or not."""

    def __init__(self, num_labels=2, pretrained=None):
        self.num_labels = num_labels
        self.pretrained = pretrained or CONFIG.get('vit_model', 'WinKawaks/vit-tiny-patch16-224')
        self.model = None
        self.processor = None
        self.best_val_acc = 0
        self.model_path = os.path.join(MODELS_DIR, 'rvl_classifier.pt')

    def load_model(self):
        """Load the pretrained ViT and set it up for 2-class output"""
        try:
            self.processor = ViTImageProcessor.from_pretrained(self.pretrained)
        except Exception:
            # Fall back to the standard ViT-Base preprocessing config if loading it fails
            self.processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

        self.model = ViTForImageClassification.from_pretrained(
            self.pretrained,
            num_labels=self.num_labels,
            ignore_mismatched_sizes=True
        )
        self.model = self.model.to(DEVICE)
        return self.model

    def train(self, train_loader, val_loader, epochs=None, lr=None, class_weights=None,
              warmup_ratio=None, patience=None, weight_decay=0.01, max_grad_norm=1.0):
        """Train the model with early stopping"""
        epochs = epochs or CONFIG['vit_epochs']
        lr = lr or CONFIG['vit_lr']
        warmup_ratio = warmup_ratio or CONFIG['warmup_ratio']
        patience = patience or CONFIG['early_stopping_patience']

        if self.model is None:
            self.load_model()

        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        total_steps = len(train_loader) * epochs

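        # One-cycle schedule: warms up from max_lr / div_factor (0.4 * lr) to a peak of
        # 10 * lr over the first warmup_ratio of steps, then cosine-anneals toward zero
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=lr * 10, total_steps=total_steps,
            pct_start=warmup_ratio, anneal_strategy='cos',
            div_factor=25, final_div_factor=1000
        )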

        if class_weights is not None:
            criterion = nn.CrossEntropyLoss(weight=class_weights)
        else:
            criterion = nn.CrossEntropyLoss()

        use_amp = torch.cuda.is_available()
        scaler = torch.cuda.amp.GradScaler() if use_amp else None

        self.best_val_acc = 0
        patience_counter = 0
        history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'lr': []}

        for epoch in range(epochs):
            self.model.train()
            train_loss = 0
            train_correct = 0
            train_total = 0

            for batch_idx, batch in enumerate(train_loader):
                pixel_values = batch['pixel_values'].to(DEVICE)
                labels = batch['labels'].to(DEVICE)
                optimizer.zero_grad()

                if use_amp:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(pixel_values=pixel_values)
                        loss = criterion(outputs.logits, labels)
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    outputs = self.model(pixel_values=pixel_values)
                    loss = criterion(outputs.logits, labels)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                    optimizer.step()

                scheduler.step()
                train_loss += loss.item()
                _, predicted = outputs.logits.max(1)
                train_total += labels.size(0)
                train_correct += predicted.eq(labels).sum().item()

            val_loss, val_acc = self.evaluate(val_loader, criterion)
            train_acc = 100 * train_correct / train_total

            history['train_loss'].append(train_loss / len(train_loader))
            history['val_loss'].append(val_loss)
            history['train_acc'].append(train_acc)
            history['val_acc'].append(val_acc)
            history['lr'].append(scheduler.get_last_lr()[0])

            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                patience_counter = 0
                self.save_model(self.model_path)
            else:
                patience_counter += 1

            if patience_counter >= patience:
                break

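        # Restore the best checkpoint; it only exists if at least one epoch improved val accuracy
        if os.path.exists(self.model_path):
            self.load_weights(self.model_path)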
        return history

    def evaluate(self, val_loader, criterion=None):
        """Check how well we're doing on val data"""
        if criterion is None:
            criterion = nn.CrossEntropyLoss()

        self.model.eval()
        val_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in val_loader:
                pixel_values = batch['pixel_values'].to(DEVICE)
                labels = batch['labels'].to(DEVICE)
                outputs = self.model(pixel_values=pixel_values)
                loss = criterion(outputs.logits, labels)
                val_loss += loss.item()
                _, predicted = outputs.logits.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        return val_loss / len(val_loader), 100 * correct / total

    def predict(self, image):
        """Check if an image is a receipt"""
        self.model.eval()
        if image.mode != 'RGB':
            image = image.convert('RGB')

        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs['pixel_values'].to(DEVICE)

        with torch.no_grad():
            outputs = self.model(pixel_values=pixel_values)
            probs = torch.softmax(outputs.logits, dim=1)
            receipt_prob = probs[0][1].item()

        return {
            'is_receipt': receipt_prob > 0.5,
            'confidence': receipt_prob,
            'label': 'receipt' if receipt_prob > 0.5 else 'other'
        }

    def save_model(self, path):
        """Save model weights"""
        torch.save(self.model.state_dict(), path)
        print(f"Model saved to: {path}")

    def load_weights(self, path):
        """Load model weights"""
        if self.model is None:
            self.load_model()
        self.model.load_state_dict(torch.load(path, map_location=DEVICE))
        self.model.eval()
        print(f"Model loaded from: {path}")

# Initialize classifier
doc_classifier = DocumentClassifier(num_labels=2)
doc_classifier.load_model()
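`train()` uses `OneCycleLR` with `max_lr=lr * 10`, `div_factor=25` and `final_div_factor=1000`. The learning rate therefore starts at `0.4 * lr` and warms up to `10 * lr` over the first `warmup_ratio` of steps. It then cosine-anneals down to `max_lr / 25000`. The pure-Python sketch below reproduces that two-phase cosine shape so the numbers are easy to inspect. It follows our reading of PyTorch's non-three-phase formula and is only an illustration; it is not a replacement for the scheduler.

```python
import math

def one_cycle_lr(step, total_steps, max_lr, pct_start=0.1,
                 div_factor=25.0, final_div_factor=1000.0):
    """Learning rate at `step` for a two-phase, cosine-annealed one-cycle schedule."""
    initial_lr = max_lr / div_factor          # where warmup starts
    min_lr = initial_lr / final_div_factor    # where annealing ends
    warmup_end = pct_start * total_steps - 1  # last step of the warmup phase
    if warmup_end <= 0:
        raise ValueError("pct_start * total_steps must be greater than 1")

    if step <= warmup_end:
        start, end, pct = initial_lr, max_lr, step / warmup_end
    else:
        start, end = max_lr, min_lr
        pct = (step - warmup_end) / (total_steps - 1 - warmup_end)

    # Cosine interpolation from `start` (pct=0) to `end` (pct=1)
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)

# Example with a hypothetical lr=1e-4, i.e. max_lr=1e-3 in the notebook's scheduler
for step in (0, 9, 50, 99):
    print(step, one_cycle_lr(step, total_steps=100, max_lr=1e-3))
```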

In [None]:
# Train the classifier (or load if we already have it)

VIT_MODEL_PATH = os.path.join(MODELS_DIR, 'rvl_classifier.pt')
SKIP_TRAINING_IF_EXISTS = True

# Create data loaders if we have data
if dataset is not None:
    train_dataset = AugmentedReceiptDataset(dataset, vit_processor, receipt_augmentation, is_training=True)
    val_dataset_wrapped = AugmentedReceiptDataset(val_dataset, vit_processor, receipt_augmentation, is_training=False)
else:
    # Hold out the first 20% for validation and train on the rest, so the two sets don't overlap
    val_size = len(synthetic_receipts) // 5
    train_dataset = SyntheticReceiptDataset(synthetic_receipts[val_size:], synthetic_ground_truth[val_size:],
                                            receipt_augmentation, is_training=True)
    val_dataset_wrapped = SyntheticReceiptDataset(synthetic_receipts[:val_size], synthetic_ground_truth[:val_size],
                                                   receipt_augmentation, is_training=False, include_negatives=True)

train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset_wrapped, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=0)

# Class weights for imbalanced data
class_weights = torch.tensor([1.0, CONFIG['class_weight_receipt']], dtype=torch.float32).to(DEVICE)

# Check if trained model already exists
if SKIP_TRAINING_IF_EXISTS and os.path.exists(VIT_MODEL_PATH):
    print(f"Loading model from: {VIT_MODEL_PATH}")
    doc_classifier.load_weights(VIT_MODEL_PATH)
    history = None
else:
    print(f"Training model, will save to: {VIT_MODEL_PATH}")
    history = doc_classifier.train(
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=CONFIG['vit_epochs'],
        lr=CONFIG['vit_lr'],
        class_weights=class_weights,
        warmup_ratio=CONFIG['warmup_ratio'],
        patience=CONFIG['early_stopping_patience']
    )
    doc_classifier.save_model(VIT_MODEL_PATH)
    print("Training complete!")

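# Sanity check on a few synthetic receipts. These are all receipts, so this only
# measures how many get recognized (recall), not false positives on other documents.
num_checked = min(10, len(synthetic_receipts))
correct = 0
for i in range(num_checked):
    result = doc_classifier.predict(synthetic_receipts[i])
    if result['is_receipt']:
        correct += 1
    if i < 3:
        print(f"Receipt {i+1}: {result['label']} (confidence: {result['confidence']:.2%})")
print(f"Classified as receipt: {correct}/{num_checked}")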
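The check above only feeds receipts to the classifier, so it can't reveal false positives on non-receipt documents. A fuller check compares the `is_receipt` flag from `predict()` against true labels for both classes. Here is a minimal sketch. `binary_metrics` is a hypothetical helper and isn't defined anywhere else in this notebook.

```python
def binary_metrics(y_true, y_pred):
    """Accuracy, precision, recall and F1 for boolean labels (True = receipt)."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    tp = sum(t and p for t, p in zip(y_true, y_pred))
    fp = sum((not t) and p for t, p in zip(y_true, y_pred))
    fn = sum(t and (not p) for t, p in zip(y_true, y_pred))
    tn = len(y_true) - tp - fp - fn

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = (tp + tn) / len(y_true) if y_true else 0.0
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1,
            'tp': tp, 'fp': fp, 'fn': fn, 'tn': tn}

# Usage with the classifier (assumes `images` and `labels` cover both classes):
# preds = [doc_classifier.predict(img)['is_receipt'] for img in images]
# print(binary_metrics(labels, preds))
print(binary_metrics([True, True, False, False], [True, False, True, False]))
```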

## EasyOCR
Set up OCR to read text from receipt images.

In [None]:
# OCR wrapper class

class ReceiptOCR:
    """Wrapper around EasyOCR with some receipt-specific tricks"""

    def __init__(self, languages=('en',), gpu=True):
        self.languages = list(languages)
        self.reader = easyocr.Reader(self.languages, gpu=gpu and torch.cuda.is_available())

    def extract_text(self, image, detail=1):
        """Pull text out of an image"""
        if isinstance(image, Image.Image):
            image = np.array(image)
        return self.reader.readtext(image, detail=detail)

    def extract_with_positions(self, image):
        """Get text with bounding boxes"""
        if isinstance(image, Image.Image):
            image = np.array(image)

        results = self.extract_text(image, detail=1)
        extracted = []

        for bbox, text, conf in results:
            x_center = (bbox[0][0] + bbox[2][0]) / 2
            y_center = (bbox[0][1] + bbox[2][1]) / 2
            extracted.append({
                'text': text,
                'confidence': conf,
                'bbox': bbox,
                'x_center': x_center,
                'y_center': y_center,
                'width': bbox[2][0] - bbox[0][0],
                'height': bbox[2][1] - bbox[0][1]
            })

        extracted.sort(key=lambda x: x['y_center'])
        return extracted

    def postprocess_receipt(self, ocr_results):
        """Try to find vendor, date, total from OCR text"""
        import re
        full_text = ' '.join([r['text'] for r in ocr_results])

        # Extract date
        date_patterns = [
            r'\d{1,2}/\d{1,2}/\d{2,4}',
            r'\d{1,2}-\d{1,2}-\d{2,4}',
            r'\d{4}-\d{2}-\d{2}',
        ]
        date = None
        for pattern in date_patterns:
            match = re.search(pattern, full_text)
            if match:
                date = match.group()
                break

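        # Extract amounts (allowing thousands separators such as 1,234.56)
        amount_pattern = r'\$?(?:\d{1,3}(?:,\d{3})+|\d+)\.\d{2}'
        amounts = re.findall(amount_pattern, full_text)
        amounts = [float(a.replace('$', '').replace(',', '')) for a in amounts]
        # Heuristic: the largest amount is usually the total (cash tendered can fool it)
        total = max(amounts) if amounts else 0.0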

        # Extract vendor
        vendor = None
        for r in ocr_results[:3]:
            text = r['text'].strip()
            if len(text) > 2 and text.isupper():
                vendor = text
                break

        # Extract time
        time_pattern = r'\d{1,2}:\d{2}(?::\d{2})?(?:\s*[AP]M)?'
        time_match = re.search(time_pattern, full_text, re.IGNORECASE)
        time = time_match.group() if time_match else None

        return {
            'vendor': vendor,
            'date': date,
            'time': time,
            'total': total,
            'all_amounts': amounts,
            'raw_text': full_text,
            'num_lines': len(ocr_results)
        }

    def visualize_results(self, image, ocr_results):
        """Draw boxes around detected text"""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        img_draw = image.copy()
        draw = ImageDraw.Draw(img_draw)

        for r in ocr_results:
            bbox = r['bbox']
            points = [(int(p[0]), int(p[1])) for p in bbox]
            draw.polygon(points, outline='red', width=2)
            draw.text((points[0][0], points[0][1] - 15),
                      f"{r['text'][:20]} ({r['confidence']:.2f})", fill='blue')

        return img_draw

# Initialize OCR
receipt_ocr = ReceiptOCR(languages=['en'], gpu=True)

# Test on synthetic receipt
test_results = receipt_ocr.extract_with_positions(synthetic_receipts[0])
extracted_data = receipt_ocr.postprocess_receipt(test_results)
print(f"Vendor: {extracted_data['vendor']}")
print(f"Date: {extracted_data['date']}")
print(f"Total: ${extracted_data['total']:.2f}")
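`extract_with_positions()` sorts detections by `y_center` only. When two words on the same printed line have centers a pixel or two apart, such as an item name and its price, they can come out in either order. One common fix is to group detections into lines using a vertical tolerance and then sort each line from left to right. The sketch below works on the dicts that `extract_with_positions()` returns. The `group_into_lines` name and the half-height tolerance are assumptions; neither is part of the pipeline.

```python
def group_into_lines(detections, tolerance_ratio=0.5):
    """Group OCR detections (dicts with text, x_center, y_center, height) into text lines.

    A detection joins the current line if its y_center is within
    tolerance_ratio * max(line's first height, its own height) of that first detection.
    """
    lines = []
    for det in sorted(detections, key=lambda d: d['y_center']):
        if lines:
            anchor = lines[-1][0]
            tol = tolerance_ratio * max(anchor['height'], det['height'], 1)
            if abs(det['y_center'] - anchor['y_center']) <= tol:
                lines[-1].append(det)
                continue
        lines.append([det])
    # Read each line left to right
    return [sorted(line, key=lambda d: d['x_center']) for line in lines]

dets = [
    {'text': '4.99', 'x_center': 180, 'y_center': 101, 'height': 20},
    {'text': 'MILK', 'x_center': 40, 'y_center': 100, 'height': 20},
    {'text': 'TOTAL', 'x_center': 40, 'y_center': 140, 'height': 20},
]
print([' '.join(d['text'] for d in line) for line in group_into_lines(dets)])
```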

## LayoutLMv3 Field Extractor
This model finds vendor, date, and total in receipts.
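LayoutLMv3 expects each word's bounding box as `[x0, y0, x1, y1]` on a 0-1000 grid, whatever the image's pixel size. `prepare_inputs()` converts EasyOCR's four-corner polygons in three steps. It takes the min/max corner coordinates, rescales them by the image width and height, and clamps the result. A standalone sketch of that conversion follows; the `normalize_bbox` name is ours.

```python
def normalize_bbox(polygon, width, height):
    """Convert a 4-point polygon [[x, y], ...] in pixels to LayoutLM's [x0, y0, x1, y1] on a 0-1000 scale."""
    xs = [p[0] for p in polygon]
    ys = [p[1] for p in polygon]
    box = [
        int(min(xs) * 1000 / width),
        int(min(ys) * 1000 / height),
        int(max(xs) * 1000 / width),
        int(max(ys) * 1000 / height),
    ]
    # Clamp in case OCR returns corners slightly outside the image
    return [max(0, min(1000, v)) for v in box]

# A word spanning x=50..150, y=20..40 on a 400x800 receipt image
print(normalize_bbox([[50, 20], [150, 20], [150, 40], [50, 40]], width=400, height=800))
```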

In [None]:
# LayoutLMv3 for finding fields in receipts

FIELD_LABELS = {
    'O': 0,
    'B-VENDOR': 1,
    'I-VENDOR': 2,
    'B-DATE': 3,
    'I-DATE': 4,
    'B-TOTAL': 5,
    'I-TOTAL': 6,
}
NUM_LABELS = len(FIELD_LABELS)

LAYOUTLM_MODEL_PATH = os.path.join(MODELS_DIR, 'layoutlm_extractor.pt')

class LayoutLMExtractor:
    """Uses LayoutLMv3 to find vendor/date/total in receipts"""

    def __init__(self, num_labels=NUM_LABELS, pretrained="microsoft/layoutlmv3-base"):
        self.num_labels = num_labels
        self.pretrained = pretrained
        self.model = None
        self.processor = None
        self.label_map = FIELD_LABELS
        self.id2label = {v: k for k, v in FIELD_LABELS.items()}
        self.model_path = LAYOUTLM_MODEL_PATH

    def load_model(self):
        """Load LayoutLMv3 from HuggingFace"""
        self.processor = LayoutLMv3Processor.from_pretrained(self.pretrained, apply_ocr=False)
        self.model = LayoutLMv3ForTokenClassification.from_pretrained(
            self.pretrained,
            num_labels=self.num_labels,
            ignore_mismatched_sizes=True
        )
        self.model = self.model.to(DEVICE)
        return self.model

    def prepare_inputs(self, image, ocr_results):
        """Format image + OCR for LayoutLMv3"""
        if image.mode != 'RGB':
            image = image.convert('RGB')

        words = []
        boxes = []
        width, height = image.size

        for r in ocr_results:
            text = r['text'].strip()
            if not text:
                continue

            bbox = r['bbox']
            x0 = int(min(p[0] for p in bbox) * 1000 / width)
            y0 = int(min(p[1] for p in bbox) * 1000 / height)
            x1 = int(max(p[0] for p in bbox) * 1000 / width)
            y1 = int(max(p[1] for p in bbox) * 1000 / height)
            x0, y0, x1, y1 = [max(0, min(1000, v)) for v in [x0, y0, x1, y1]]

            words.append(text)
            boxes.append([x0, y0, x1, y1])

        if not words:
            words = [""]
            boxes = [[0, 0, 0, 0]]

        encoding = self.processor(
            image, words, boxes=boxes,
            return_tensors="pt", truncation=True,
            max_length=512, padding="max_length"
        )
        return encoding

    def predict(self, image, ocr_results):
        """Find vendor/date/total in an image"""
        self.model.eval()
        encoding = self.prepare_inputs(image, ocr_results)
        # Token -> source-word index (None for special/padding tokens). LayoutLMv3's BPE
        # tokenizer can split one OCR word into several tokens, so tokens and words
        # don't line up one-to-one. word_ids() requires a fast tokenizer.
        word_ids = encoding.word_ids(batch_index=0)

        for k, v in encoding.items():
            if isinstance(v, torch.Tensor):
                encoding[k] = v.to(DEVICE)

        with torch.no_grad():
            outputs = self.model(**encoding)
            predictions = torch.argmax(outputs.logits, dim=-1)

        pred_labels = predictions[0].cpu().tolist()
        words = [r['text'].strip() for r in ocr_results if r['text'].strip()]

        # Label each word with the prediction for its first sub-token
        word_labels = {}
        for token_idx, word_idx in enumerate(word_ids):
            if word_idx is not None and word_idx not in word_labels:
                word_labels[word_idx] = pred_labels[token_idx]

        extracted = {'vendor': [], 'date': [], 'total': []}
        for word_idx, word in enumerate(words):
            label = self.id2label.get(word_labels.get(word_idx, 0), 'O')
            if 'VENDOR' in label:
                extracted['vendor'].append(word)
            elif 'DATE' in label:
                extracted['date'].append(word)
            elif 'TOTAL' in label:
                extracted['total'].append(word)

        return {
            'vendor': ' '.join(extracted['vendor']) if extracted['vendor'] else None,
            'date': ' '.join(extracted['date']) if extracted['date'] else None,
            'total': ' '.join(extracted['total']) if extracted['total'] else None,
        }

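    def train(self, train_data, epochs=3, lr=5e-5):
        """Train on (image, ocr_results, labels) tuples.

        `labels` must already be aligned to the 512 padded tokens, with -100 for special
        and padding tokens; word-level labels won't line up with LayoutLMv3's sub-word tokens.
        """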
        if self.model is None:
            self.load_model()

        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
        self.model.train()

        for epoch in range(epochs):
            total_loss = 0
            for image, ocr_results, labels in train_data:
                encoding = self.prepare_inputs(image, ocr_results)
                for k, v in encoding.items():
                    if isinstance(v, torch.Tensor):
                        encoding[k] = v.to(DEVICE)

                encoding['labels'] = torch.tensor(labels, device=DEVICE).unsqueeze(0)
                optimizer.zero_grad()
                outputs = self.model(**encoding)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/len(train_data):.4f}")

        self.save_model(self.model_path)
        return self.model

    def save_model(self, path):
        """Save model weights"""
        torch.save(self.model.state_dict(), path)
        print(f"Saved to {path}")

    def load_weights(self, path):
        """Load model weights"""
        if self.model is None:
            self.load_model()
        self.model.load_state_dict(torch.load(path, map_location=DEVICE))
        self.model.eval()
        print(f"Loaded from {path}")

# Initialize extractor
field_extractor = LayoutLMExtractor(num_labels=NUM_LABELS)

SKIP_LAYOUTLM_TRAINING = True

if SKIP_LAYOUTLM_TRAINING and os.path.exists(LAYOUTLM_MODEL_PATH):
    print(f"Loading LayoutLM from: {LAYOUTLM_MODEL_PATH}")
    field_extractor.load_weights(LAYOUTLM_MODEL_PATH)
else:
    print(f"Initializing LayoutLM, will save to: {LAYOUTLM_MODEL_PATH}")
    field_extractor.load_model()
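LayoutLMv3's tokenizer is byte-level BPE, so a single OCR word such as `WALMART` can become several tokens. The sequence also contains special and padding tokens, so token-level predictions have to be mapped back to words. With a fast tokenizer, `encoding.word_ids()` gives the index of the source word for each token, or `None` for special and padding tokens. The sketch below shows the usual "first sub-token wins" decoding on plain lists, so it runs without the model. The `word_ids` and `token_preds` values are made up for illustration.

```python
ID2LABEL = {0: 'O', 1: 'B-VENDOR', 2: 'I-VENDOR', 3: 'B-DATE',
            4: 'I-DATE', 5: 'B-TOTAL', 6: 'I-TOTAL'}

def word_level_labels(word_ids, token_preds, num_words, id2label=ID2LABEL):
    """Give each word the label predicted for its first sub-token ('O' if it was truncated away)."""
    first_pred = {}
    for token_idx, word_idx in enumerate(word_ids):
        if word_idx is not None and word_idx not in first_pred:
            first_pred[word_idx] = token_preds[token_idx]
    return [id2label.get(first_pred.get(i, 0), 'O') for i in range(num_words)]

def collect_fields(words, labels):
    """Join the words tagged with each field type, in the same shape LayoutLMExtractor.predict returns."""
    fields = {'vendor': [], 'date': [], 'total': []}
    for word, label in zip(words, labels):
        if label != 'O':
            fields[label.split('-', 1)[1].lower()].append(word)
    return {k: ' '.join(v) if v else None for k, v in fields.items()}

# Hypothetical example: "WALMART" split into 2 tokens, plus <s>, </s> and padding
words = ['WALMART', '12/01/2024', '$45.10']
word_ids = [None, 0, 0, 1, 2, None, None]
token_preds = [0, 1, 2, 3, 5, 0, 0]
labels = word_level_labels(word_ids, token_preds, len(words))
print(labels, collect_fields(words, labels))
```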

In [None]:
# Train LayoutLMv3 for NER (finding vendor/date/total)

import os

class ReceiptNERDataset(torch.utils.data.Dataset):
    """Dataset for training LayoutLMv3"""

    def __init__(self, receipts, ground_truths, ocr_engine, processor):
        self.receipts = receipts
        self.ground_truths = ground_truths
        self.ocr = ocr_engine
        self.processor = processor

    def __len__(self):
        return len(self.receipts)

    def __getitem__(self, idx):
        image = self.receipts[idx]
        gt = self.ground_truths[idx]

        if image.mode != 'RGB':
            image = image.convert('RGB')

        ocr_results = self.ocr.extract_with_positions(image)
        if not ocr_results:
            return None

        words = []
        boxes = []
        labels = []
        width, height = image.size

        for r in ocr_results:
            text = r['text'].strip()
            if not text:
                continue

            bbox = r['bbox']
            x0 = int(min(p[0] for p in bbox) * 1000 / width)
            y0 = int(min(p[1] for p in bbox) * 1000 / height)
            x1 = int(max(p[0] for p in bbox) * 1000 / width)
            y1 = int(max(p[1] for p in bbox) * 1000 / height)
            x0, y0, x1, y1 = [max(0, min(1000, v)) for v in [x0, y0, x1, y1]]

            words.append(text)
            boxes.append([x0, y0, x1, y1])

            # Assign label based on ground truth
            label = 0  # O
            text_upper = text.upper()

            if gt['vendor'] and text_upper in gt['vendor'].upper():
                label = 1  # B-VENDOR
            elif gt['date'] and gt['date'] in text:
                label = 3  # B-DATE
            elif gt['total']:
                total_str = f"{gt['total']:.2f}"
                if total_str in text or text.replace('$', '') == total_str:
                    label = 5  # B-TOTAL

            labels.append(label)

        if not words:
            return None

        try:
            # Passing word_labels lets the processor align word-level labels to tokens:
            # only the first sub-token of each word is labeled, and special/padding
            # tokens get -100 so the loss ignores them
            encoding = self.processor(
                image, words, boxes=boxes, word_labels=labels,
                return_tensors="pt", truncation=True,
                max_length=512, padding="max_length"
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'bbox': encoding['bbox'].squeeze(0),
                'pixel_values': encoding['pixel_values'].squeeze(0),
                'labels': encoding['labels'].squeeze(0)
            }
        except Exception:
            return None

def collate_fn(batch):
    """Skip None samples and stack the rest"""
    batch = [b for b in batch if b is not None]
    if not batch:
        return None
    return {
        'input_ids': torch.stack([b['input_ids'] for b in batch]),
        'attention_mask': torch.stack([b['attention_mask'] for b in batch]),
        'bbox': torch.stack([b['bbox'] for b in batch]),
        'pixel_values': torch.stack([b['pixel_values'] for b in batch]),
        'labels': torch.stack([b['labels'] for b in batch])
    }

def train_layoutlm(model, train_loader, epochs=3, lr=5e-5):
    """Run the training loop"""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()

    for epoch in range(epochs):
        total_loss = 0
        batch_count = 0

        for batch in train_loader:
            if batch is None:
                continue

            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            bbox = batch['bbox'].to(DEVICE)
            pixel_values = batch['pixel_values'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            optimizer.zero_grad()
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                bbox=bbox,
                pixel_values=pixel_values,
                labels=labels
            )

            loss = outputs.loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            batch_count += 1

        if batch_count > 0:
            print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/batch_count:.4f}")

    return model

# Training logic (runs only if skipping is disabled or no checkpoint exists yet)
if not SKIP_LAYOUTLM_TRAINING or not os.path.exists(LAYOUTLM_MODEL_PATH):
    print("Training LayoutLMv3...")
    print(f"Device: {DEVICE}")
    if torch.cuda.is_available():
        train_samples = min(CONFIG['layoutlm_train_samples'], len(synthetic_receipts))
        print(f"Using {train_samples} synthetic receipts for training")

        ner_dataset = ReceiptNERDataset(
            synthetic_receipts[:train_samples],
            synthetic_ground_truth[:train_samples],
            receipt_ocr,
            field_extractor.processor
        )

        ner_loader = DataLoader(ner_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

        try:
            train_layoutlm(field_extractor.model, ner_loader, epochs=CONFIG['layoutlm_epochs'])
            field_extractor.save_model(LAYOUTLM_MODEL_PATH)
            print("LayoutLMv3 training complete!")
        except Exception as e:
            print(f"Training failed: {e}")
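    else:
        # Don't save here: writing the untrained token-classification head to
        # LAYOUTLM_MODEL_PATH would make later runs load it as a "trained" checkpoint
        print("No GPU - skipping LayoutLMv3 training (classification head stays untrained)")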

# Test extractor
print("Testing LayoutLMv3...")
try:
    test_ocr = receipt_ocr.extract_with_positions(synthetic_receipts[0])
    test_result = field_extractor.predict(synthetic_receipts[0], test_ocr)
    print(f"Vendor: {test_result['vendor']}")
    print(f"Date: {test_result['date']}")
    print(f"Total: {test_result['total']}")
except Exception as e:
    print(f"Test failed: {e}")
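`ReceiptNERDataset` gives every matching word a `B-` tag. A two-word vendor like `CORNER CAFE` therefore becomes `B-VENDOR B-VENDOR`, and the `I-` labels in `FIELD_LABELS` are never used. Below is a minimal sketch of proper BIO tagging for a contiguous field: the first matching word gets `B-` and each following word gets `I-`. The `bio_tag_span` helper and its case-insensitive, word-by-word matching rule are assumptions for illustration. The sketch also assumes one word per OCR entry. `FIELD_LABELS` is copied here so the snippet runs on its own.

```python
FIELD_LABELS = {'O': 0, 'B-VENDOR': 1, 'I-VENDOR': 2, 'B-DATE': 3,
                'I-DATE': 4, 'B-TOTAL': 5, 'I-TOTAL': 6}

def bio_tag_span(words, field_text, field, labels):
    """Tag the first run of `words` that spells out `field_text` with B-/I- ids, in place.

    Returns True if the span was found. Matching is case-insensitive and word-by-word.
    """
    target = field_text.upper().split()
    upper_words = [w.upper() for w in words]
    n = len(target)
    for start in range(len(words) - n + 1):
        if n and upper_words[start:start + n] == target:
            labels[start] = FIELD_LABELS[f'B-{field}']
            for i in range(start + 1, start + n):
                labels[i] = FIELD_LABELS[f'I-{field}']
            return True
    return False

words = ['CORNER', 'CAFE', '123', 'MAIN', 'ST', 'TOTAL', '$8.75']
labels = [FIELD_LABELS['O']] * len(words)
bio_tag_span(words, 'Corner Cafe', 'VENDOR', labels)
bio_tag_span(words, '$8.75', 'TOTAL', labels)
print(labels)
```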

In [None]:
# Better field extraction with regex patterns

import re

class HybridFieldExtractor:
    """Finds vendor/date/total using regex patterns."""

    def __init__(self):
        self.date_patterns = [
            r'\b(\d{1,2}[/\-\.]\d{1,2}[/\-\.]\d{2,4})\b',
            r'\b(\d{4}[/\-\.]\d{1,2}[/\-\.]\d{1,2})\b',
            r'\b((?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\.?\s*\d{1,2},?\s*\d{2,4})\b',
        ]
        self.time_patterns = [r'\b(\d{1,2}:\d{2}(?::\d{2})?\s*(?:AM|PM|am|pm)?)\b']

        self.final_total_keywords = ['GRAND TOTAL', 'AMOUNT DUE', 'BALANCE DUE', 'TOTAL DUE']
        self.total_keywords = ['TOTAL', 'AMOUNT', 'SUM']
        self.exclude_keywords = ['SUBTOTAL', 'SUB TOTAL', 'TAX', 'TIP', 'DISCOUNT', 'CHANGE']

    def clean_amount_text(self, text):
        """Fix common OCR mistakes"""
        cleaned = text.strip()
        cleaned = re.sub(r'^[Ss](\d)', r'$\1', cleaned)
        cleaned = re.sub(r'(\d),(\d{2})$', r'\1.\2', cleaned)
        cleaned = re.sub(r'(?<=\d)[Oo](?=\d)', '0', cleaned)
        return cleaned

    def extract_amount(self, text):
        """Pull a dollar amount from text"""
        cleaned = self.clean_amount_text(text)

        # Tried in order, most specific first: comma-grouped amounts, then $-prefixed, then bare
        patterns = [
            r'\$\s*(\d{1,3}(?:,\d{3})+\.\d{2})',
            r'(?<!\d)(\d{1,3}(?:,\d{3})+\.\d{2})(?!\d)',
            r'\$\s*(\d{4,}\.\d{2})',
            r'\$\s*(\d{1,3}\.\d{2})',
            r'(?<![,\d])(\d+\.\d{2})(?![,\d])',
        ]

        for pattern in patterns:
            match = re.search(pattern, cleaned, re.IGNORECASE)
            if match:
                try:
                    amount_str = match.group(1).replace(',', '')
                    amount = float(amount_str)
                    if amount <= 100000:
                        return amount
                except ValueError:
                    continue
        return None

    def find_total_amount(self, ocr_results):
        """Figure out which amount is the actual total"""
        amounts = {'total': None, 'subtotal': None, 'tax': None, 'all_amounts': [], 'method': None}
        amount_candidates = []

        for idx, r in enumerate(ocr_results):
            text = r['text']
            text_upper = text.upper()
            amount = self.extract_amount(text)

            if amount is None and idx + 1 < len(ocr_results):
                amount = self.extract_amount(ocr_results[idx + 1]['text'])

            if amount is not None and amount > 0:
                is_excluded = any(kw in text_upper for kw in self.exclude_keywords)
                is_final_total = any(kw in text_upper for kw in self.final_total_keywords)
                is_total = any(kw in text_upper for kw in self.total_keywords)
                position_score = idx / max(len(ocr_results), 1)

                amount_candidates.append({
                    'amount': amount, 'text': text, 'position': idx,
                    'position_score': position_score, 'is_final_total': is_final_total,
                    'is_total': is_total, 'is_excluded': is_excluded
                })
                amounts['all_amounts'].append(amount)

                if 'SUBTOTAL' in text_upper:
                    amounts['subtotal'] = amount
                elif 'TAX' in text_upper:
                    amounts['tax'] = amount

        # Priority selection
        final_candidates = [c for c in amount_candidates if c['is_final_total']]
        if final_candidates:
            amounts['total'] = max(final_candidates, key=lambda x: x['position_score'])['amount']
            amounts['method'] = 'final_keyword'
            return amounts

        total_candidates = [c for c in amount_candidates if c['is_total'] and not c['is_excluded']]
        if total_candidates:
            amounts['total'] = max(total_candidates, key=lambda x: x['position_score'])['amount']
            amounts['method'] = 'total_keyword'
            return amounts

        bottom_half = [c for c in amount_candidates if c['position_score'] > 0.5 and not c['is_excluded']]
        if bottom_half:
            amounts['total'] = max(bottom_half, key=lambda x: x['amount'])['amount']
            amounts['method'] = 'bottom_largest'
            return amounts

        if amount_candidates:
            amounts['total'] = max(amount_candidates, key=lambda x: x['amount'])['amount']
            amounts['method'] = 'fallback_largest'

        return amounts

    def extract(self, ocr_results, image=None):
        """Get all the fields from OCR results"""
        if not ocr_results:
            return {'vendor': None, 'date': None, 'time': None, 'total': None,
                    'subtotal': None, 'tax': None, 'items': [], 'raw_text': ''}

        all_text = '\n'.join([r['text'] for r in ocr_results])
        result = {'vendor': None, 'date': None, 'time': None, 'total': None,
                  'subtotal': None, 'tax': None, 'items': [], 'raw_text': all_text}

        # Vendor (first non-numeric line)
        for r in ocr_results[:5]:
            line = r['text'].strip()
            if not re.match(r'^[\d\s\-\/\.\:\$\,]+$', line) and len(line) > 2:
                result['vendor'] = line
                break

        # Date
        for pattern in self.date_patterns:
            match = re.search(pattern, all_text, re.IGNORECASE)
            if match:
                result['date'] = match.group(1)
                break

        # Time
        for pattern in self.time_patterns:
            match = re.search(pattern, all_text, re.IGNORECASE)
            if match:
                result['time'] = match.group(1)
                break

        # Amounts
        amounts = self.find_total_amount(ocr_results)
        result['total'] = amounts['total']
        result['subtotal'] = amounts['subtotal']
        result['tax'] = amounts['tax']

        return result

    def predict(self, image, ocr_results):
        """Alias for extract()"""
        return self.extract(ocr_results, image)

# Initialize
hybrid_extractor = HybridFieldExtractor()

# Test
print("Testing amount extraction...")
test_cases = ["$11,812.50", "$1,234.56", "$812.50", "TOTAL: $99.99"]
for test in test_cases:
    result = hybrid_extractor.extract_amount(test)
    print(f"  '{test}' -> ${result:.2f}" if result else f"  '{test}' -> None")

# Test on synthetic receipts
print("\nTesting on synthetic receipts...")
num_tested = min(5, len(synthetic_receipts))
correct = 0
for i in range(num_tested):
    test_ocr = receipt_ocr.extract_with_positions(synthetic_receipts[i])
    extracted = hybrid_extractor.extract(test_ocr)
    gt = synthetic_ground_truth[i]
    if extracted['total'] is not None and gt.get('total') is not None \
            and abs(extracted['total'] - gt['total']) < 0.01:
        correct += 1
print(f"Total accuracy: {correct}/{num_tested}")
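Here is why `extract_amount()` tries the comma-grouped patterns first. A bare `\d+\.\d{2}` regex finds the first substring that fits, and for `$11,812.50` that is `812.50`. The snippet below uses only the standard `re` module to contrast a naive pattern with a thousands-aware one. Both patterns are illustrative simplifications of the ones in the class above.

```python
import re

NAIVE = r'(\d+\.\d{2})'
GROUPED = r'(\d{1,3}(?:,\d{3})+\.\d{2}|\d+\.\d{2})'  # try the comma-grouped form first

def first_amount(pattern, text):
    """Return the first amount the pattern finds, with thousands separators removed."""
    match = re.search(pattern, text)
    return float(match.group(1).replace(',', '')) if match else None

for text in ['$11,812.50', 'TOTAL: $99.99', 'no amount']:
    print(text, first_amount(NAIVE, text), first_amount(GROUPED, text))
```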

## Anomaly Detection
Catch weird receipts (crazy amounts, missing fields, etc.).
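Isolation Forest flags points that random axis-aligned splits isolate quickly. The original Isolation Forest paper scores a point as `s = 2 ** (-E[h(x)] / c(n))`. Here `E[h(x)]` is the point's average path length across the trees, and `c(n)` is the average path length of an unsuccessful binary-search-tree lookup among `n` samples. Scores near 1 are anomalous, and a point at the average depth scores exactly 0.5. scikit-learn's `score_samples()` returns the negated score. `decision_function()` then subtracts the fitted `offset_`, which is set from `contamination` (0.1 here, so roughly 10% of training points come out negative, i.e. anomalous). A small sketch of the paper's formula, not of scikit-learn's internals:

```python
import math

EULER_GAMMA = 0.5772156649015329

def average_path_length(n):
    """c(n): average path length of an unsuccessful BST search over n points."""
    if n <= 1:
        return 0.0
    if n == 2:
        return 1.0
    harmonic = math.log(n - 1) + EULER_GAMMA  # approximation of the harmonic number H(n-1)
    return 2.0 * harmonic - 2.0 * (n - 1) / n

def isolation_score(mean_path_length, n):
    """Anomaly score in (0, 1]: close to 1 is very anomalous, about 0.5 is ordinary."""
    return 2.0 ** (-mean_path_length / average_path_length(n))

n = 256  # scikit-learn's default subsample size is min(256, n_samples)
print(isolation_score(2.0, n))                     # isolated after ~2 splits: high score
print(isolation_score(average_path_length(n), n))  # average depth: 0.5
```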

In [None]:
# Anomaly detector

ANOMALY_MODEL_PATH = os.path.join(MODELS_DIR, 'anomaly_detector.pt')

class ReceiptAnomalyDetector:
    """
    Uses Isolation Forest to flag weird receipts.
    Stuff like $50k totals or missing vendors.
    """

    def __init__(self, contamination=0.1):
        self.contamination = contamination
        self.model = IsolationForest(
            n_estimators=100,
            contamination=contamination,
            random_state=42
        )
        self.is_fitted = False
        self.feature_names = ['amount', 'vendor_len', 'date_valid', 'num_items', 'hour']
        self.model_path = ANOMALY_MODEL_PATH

    def extract_features(self, receipt_data: dict) -> np.ndarray:
        """Turn receipt data into numbers for the model"""
        import re
        from datetime import datetime

        # Amount feature (extractors may return None or a string like "$12.50")
        amount = receipt_data.get('total') or 0
        if isinstance(amount, str):
            try:
                amount = float(re.sub(r'[^\d.]', '', amount) or 0)
            except ValueError:
                amount = 0.0

        # Vendor length (proxy for validity)
        vendor = receipt_data.get('vendor', '') or ''
        vendor_len = len(vendor)

        # Date validity (1 if the date parses in one of the known formats, 0 otherwise)
        date_str = receipt_data.get('date') or ''
        date_valid = 0
        if date_str:
            for fmt in ['%m/%d/%Y', '%m/%d/%y', '%Y-%m-%d', '%d-%m-%Y']:
                try:
                    datetime.strptime(date_str, fmt)
                    date_valid = 1
                    break
                except ValueError:
                    continue

        # Number of items (if available)
        num_items = len(receipt_data.get('items', [])) if 'items' in receipt_data else 3

        # Hour of transaction (if time available)
        time_str = receipt_data.get('time', '')
        hour = 12  # Default
        if time_str:
            try:
                hour = int(time_str.split(':')[0])
            except:
                pass

        return np.array([[amount, vendor_len, date_valid, num_items, hour]])

    def fit(self, receipt_data_list: list):
        """Train on a bunch of receipts"""
        features = []
        for data in receipt_data_list:
            feat = self.extract_features(data)
            features.append(feat[0])

        X = np.array(features)

        # Handle edge cases
        if len(X) < 10:
            print("Fewer than 10 samples; padding with synthetic 'normal' data")
            synthetic_normal = np.random.normal(
                loc=X.mean(axis=0) if len(X) > 0 else [50, 10, 1, 5, 14],
                scale=X.std(axis=0) if len(X) > 0 else [20, 5, 0.1, 2, 3],
                size=(100, 5)
            )
            X = np.vstack([X, synthetic_normal]) if len(X) > 0 else synthetic_normal

        self.model.fit(X)
        self.is_fitted = True

        return self

    def predict(self, receipt_data: dict) -> dict:
        """Check if a receipt looks suspicious"""
        if not self.is_fitted:
            raise ValueError("Model not fitted. Call fit() first.")

        features = self.extract_features(receipt_data)

        # Get anomaly score (-1 for anomaly, 1 for normal)
        prediction = self.model.predict(features)[0]
        score = self.model.decision_function(features)[0]

        # Identify reasons for anomaly
        reasons = []
        amount = features[0][0]
        vendor_len = features[0][1]
        date_valid = features[0][2]

        if amount > 1000:
            reasons.append(f"High amount: ${amount:.2f}")
        elif amount < 1:
            reasons.append(f"Suspiciously low amount: ${amount:.2f}")

        if vendor_len < 2:
            reasons.append("Missing or invalid vendor")

        if date_valid == 0:
            reasons.append("Invalid or missing date")

        return {
            'is_anomaly': prediction == -1,
            'score': float(score),
            'prediction': 'ANOMALY' if prediction == -1 else 'NORMAL',
            'reasons': reasons,
            'features': dict(zip(self.feature_names, features[0]))
        }

    def save_model(self, path: str):
        """Save the trained model"""
        model_data = {
            'model': self.model,
            'is_fitted': self.is_fitted,
            'contamination': self.contamination,
            'feature_names': self.feature_names
        }
        torch.save(model_data, path)
        print(f"Anomaly detector saved to: {path}")

    def load_model(self, path: str):
        """Load a trained model"""
        # weights_only=False needed for sklearn models
        model_data = torch.load(path, map_location='cpu', weights_only=False)
        self.model = model_data['model']
        self.is_fitted = model_data['is_fitted']
        self.contamination = model_data['contamination']
        self.feature_names = model_data['feature_names']
        print(f"Anomaly detector loaded from: {path}")


# Initialize and train anomaly detector
SKIP_ANOMALY_TRAINING = True  # Set to False to force retraining

anomaly_detector = ReceiptAnomalyDetector(contamination=0.1)

if SKIP_ANOMALY_TRAINING and os.path.exists(ANOMALY_MODEL_PATH):
    print(f"Loading anomaly detector from: {ANOMALY_MODEL_PATH}")
    anomaly_detector.load_model(ANOMALY_MODEL_PATH)
else:
    print(f"Training anomaly detector, will save to: {ANOMALY_MODEL_PATH}")
    # Create training data from synthetic receipts
    training_data = []
    for gt in synthetic_ground_truth:
        training_data.append({
            'vendor': gt['vendor'],
            'date': gt['date'],
            'time': gt['time'],
            'total': gt['total'],
            'items': gt['items']
        })

    # Add some anomalous samples for training
    anomalous_samples = [
        {'vendor': '', 'date': 'invalid', 'total': 50000, 'time': '25:00'},
        {'vendor': 'X', 'date': '', 'total': 0.01, 'time': ''},
        {'vendor': 'SUSPICIOUS VENDOR', 'date': '99/99/9999', 'total': -100, 'time': ''},
    ]
    training_data.extend(anomalous_samples)

    # Fit the model
    anomaly_detector.fit(training_data)

    # Save model
    anomaly_detector.save_model(ANOMALY_MODEL_PATH)

# Test on normal and anomalous receipts
print("Testing anomaly detection...")

test_normal = {
    'vendor': synthetic_ground_truth[0]['vendor'],
    'date': synthetic_ground_truth[0]['date'],
    'time': synthetic_ground_truth[0]['time'],
    'total': synthetic_ground_truth[0]['total'],
    'items': synthetic_ground_truth[0]['items']
}

normal_result = anomaly_detector.predict(test_normal)
print(f"Normal receipt: {normal_result['prediction']}")

test_anomalous = {'vendor': '', 'date': 'invalid', 'total': 50000, 'time': '25:00'}
anomaly_result = anomaly_detector.predict(test_anomalous)
print(f"Anomalous receipt: {anomaly_result['prediction']}")
if anomaly_result['reasons']:
    print(f"Reasons: {anomaly_result['reasons']}")
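`extract_features` converts string totals by deleting every character that isn't a digit or a dot, so a leading minus sign disappears ("-$100.00" becomes 100.0) and strings with more than one dot can't be parsed. If OCR totals arrive as strings, a slightly more careful parser could sit in front of it. This is a hypothetical helper (`parse_amount` is not defined anywhere in the notebook), and treating accounting-style parentheses as negative is our assumption.

```python
import re


def parse_amount(value) -> float:
    """Best-effort conversion of an extracted total to float (0.0 if nothing parses)."""
    if value is None:
        return 0.0
    if isinstance(value, (int, float)):
        return float(value)

    text = str(value).strip()
    # Assumption: "-$5.00" and "(5.00)" both mean a negative amount
    negative = text.startswith('-') or (text.startswith('(') and text.endswith(')'))

    # First number like 1234, 1,234 or 1,234.56
    match = re.search(r'\d[\d,]*(?:\.\d+)?', text)
    if not match:
        return 0.0
    number = float(match.group().replace(',', ''))
    return -number if negative else number


for raw in ["$1,234.56", "-$5.00", "(12.50)", "TOTAL: abc", None, 42]:
    print(repr(raw), '->', parse_amount(raw))
```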

## LangGraph Tools
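Wrap each pipeline step as a LangChain tool. (The LangGraph nodes in the next section call the same models directly rather than going through these tools.)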
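Every tool below repeats the same pattern: run the step, return `{'success': True, ...}`, or catch the exception and return `{'success': False, 'error': ...}`. A small decorator could factor that out. This is a hypothetical sketch (`with_status` and `count_regions` are made-up names); we haven't checked how it interacts with LangChain's `@tool` schema inference, so stacking the two (with `@tool` outermost) is an assumption to test before relying on it.

```python
import functools
from typing import Any, Callable, Dict


def with_status(func: Callable[..., Dict[str, Any]]) -> Callable[..., Dict[str, Any]]:
    """Wrap a step so it returns {'success': True, **result} or {'success': False, 'error': ...}."""
    @functools.wraps(func)  # keeps __name__ and __doc__ (tool decorators read the docstring)
    def wrapper(*args: Any, **kwargs: Any) -> Dict[str, Any]:
        try:
            result = func(*args, **kwargs)
            # Keys returned by func win, so a func can still report success=False itself
            return {'success': True, **result}
        except Exception as e:
            return {'success': False, 'error': str(e)}
    return wrapper


# Hypothetical example step
@with_status
def count_regions(ocr_results: list) -> Dict[str, Any]:
    """Count OCR text regions."""
    return {'num_regions': len(ocr_results)}


print(count_regions([{'text': 'TOTAL'}]))  # {'success': True, 'num_regions': 1}
print(count_regions(None))                 # success is False; 'error' holds the TypeError message
```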

In [None]:
# Define our agent tools

from typing import Optional, TypedDict
from PIL import Image
from langchain_core.tools import tool

# What our state looks like as it goes through the pipeline
class AgentState(TypedDict):
    """Holds all the data as we process a receipt"""
    image: Optional[Image.Image]
    image_path: Optional[str]
    ocr_results: Optional[list]
    ocr_text: Optional[str]
    classification: Optional[dict]
    extracted_fields: Optional[dict]
    anomaly_result: Optional[dict]
    decision: Optional[str]
    confidence_score: Optional[float]
    processing_log: list
    error: Optional[str]


@tool
def classify_document(image: Image.Image) -> dict:
    """Check if image is a receipt or something else"""
    try:
        result = doc_classifier.predict(image)
        return {
            'success': True,
            'is_receipt': result['is_receipt'],
            'confidence': result['confidence'],
            'label': result['label']
        }
    except Exception as e:
        return {'success': False, 'error': str(e)}


@tool
def extract_text_ocr(image: Image.Image) -> dict:
    """Run OCR on the image"""
    try:
        ocr_results = receipt_ocr.extract_with_positions(image)
        processed = receipt_ocr.postprocess_receipt(ocr_results)

        return {
            'success': True,
            'num_regions': len(ocr_results),
            'ocr_results': ocr_results,
            'processed': processed,
            'raw_text': processed.get('raw_text', '')
        }
    except Exception as e:
        return {'success': False, 'error': str(e)}


@tool
def extract_receipt_fields(image: Image.Image, ocr_results: list) -> dict:
    """Find vendor, date, total in the receipt"""
    try:
        # Use LayoutLM for field extraction
        layoutlm_result = field_extractor.predict(image, ocr_results)

        # Also get post-processed OCR fields as fallback
        ocr_fields = receipt_ocr.postprocess_receipt(ocr_results)

        # Merge results (prefer LayoutLM, fallback to OCR)
        fields = {
            'vendor': layoutlm_result.get('vendor') or ocr_fields.get('vendor'),
            'date': layoutlm_result.get('date') or ocr_fields.get('date'),
            'total': layoutlm_result.get('total') or ocr_fields.get('total'),
            'time': ocr_fields.get('time'),
            'all_amounts': ocr_fields.get('all_amounts', []),
            'extraction_source': 'layoutlm+ocr'
        }

        return {'success': True, 'fields': fields}
    except Exception as e:
        return {'success': False, 'error': str(e)}


@tool
def detect_anomalies(extracted_fields: dict) -> dict:
    """Check if anything looks fishy"""
    try:
        result = anomaly_detector.predict(extracted_fields)
        return {
            'success': True,
            'is_anomaly': result['is_anomaly'],
            'score': result['score'],
            'prediction': result['prediction'],
            'reasons': result['reasons']
        }
    except Exception as e:
        return {'success': False, 'error': str(e)}


@tool
def make_routing_decision(
    classification: dict,
    anomaly_result: dict,
    extracted_fields: dict
) -> dict:
    """Decide if we should approve, review, or reject"""
    decision = "REVIEW"  # Default to human review
    reasons = []
    confidence = 0.5

    # Check classification confidence
    class_conf = classification.get('confidence', 0)
    if class_conf < 0.7:
        reasons.append(f"Low document confidence: {class_conf:.2%}")
    elif class_conf > 0.9:
        confidence += 0.2

    # Check if it's actually a receipt
    if not classification.get('is_receipt', False):
        decision = "REJECT"
        reasons.append("Not classified as receipt/invoice")
        confidence = class_conf
        return {
            'decision': decision,
            'confidence': confidence,
            'reasons': reasons
        }

    # Check anomaly status
    if anomaly_result.get('is_anomaly', False):
        decision = "REVIEW"
        reasons.extend(anomaly_result.get('reasons', ['Anomaly detected']))
        confidence = max(0.3, confidence - 0.2)
    else:
        confidence += 0.2

    # Check extracted fields completeness
    fields = extracted_fields.get('fields', {})
    missing_fields = []
    for field in ['vendor', 'date', 'total']:
        if not fields.get(field):
            missing_fields.append(field)

    if missing_fields:
        reasons.append(f"Missing fields: {', '.join(missing_fields)}")
        confidence -= 0.1 * len(missing_fields)
    else:
        confidence += 0.1

    # Final decision logic: anomalies or low confidence go to review, everything else is approved
    confidence = min(1.0, max(0.0, confidence))
    is_anomaly = anomaly_result.get('is_anomaly', False)

    if is_anomaly or confidence < 0.4:
        decision = "REVIEW"
    else:
        decision = "APPROVE"

    return {
        'decision': decision,
        'confidence': confidence,
        'reasons': reasons if reasons else ['All checks passed']
    }

## LangGraph Workflow
Wire up all the pieces into a pipeline.
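LangGraph passes one state dict through the nodes, and each node returns the updated state. The graph below only uses plain `add_edge` calls, so `should_continue` never actually runs and a non-receipt still goes through OCR and extraction. Here is a plain-Python analogue (not LangGraph's API) of what a gate after classification would do: if the gate says "end", the remaining steps are skipped and control jumps straight to the routing step. All names here are hypothetical toy stand-ins; in LangGraph itself the equivalent hook is `add_conditional_edges`.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

State = Dict[str, Any]
Node = Callable[[State], State]
Gate = Callable[[State], str]


def run_pipeline(state: State, steps: List[Tuple[str, Node]], final_step: Node,
                 gates: Optional[Dict[str, Gate]] = None) -> State:
    """Run steps in order, then final_step.

    If a gate is registered for a step and returns "end" after that step runs,
    the remaining steps are skipped and control jumps to final_step.
    """
    gates = gates or {}
    for name, node in steps:
        state = node(state)
        state.setdefault('processing_log', []).append(f"ran {name}")
        gate = gates.get(name)
        if gate is not None and gate(state) == "end":
            break
    return final_step(state)


# Toy nodes (hypothetical stand-ins for the real classifier/extractor)
def classify(state: State) -> State:
    state['classification'] = {'is_receipt': str(state['text']).startswith('RECEIPT')}
    return state


def extract(state: State) -> State:
    state['extracted_fields'] = {'total': 12.5}
    return state


def route(state: State) -> State:
    cls = state.get('classification') or {}
    state['decision'] = 'APPROVE' if cls.get('is_receipt') else 'REJECT'
    return state


def continue_if_receipt(state: State) -> str:
    if state.get('error'):
        return "end"
    return "continue" if (state.get('classification') or {}).get('is_receipt') else "end"


steps = [('classify', classify), ('extract', extract)]
gates = {'classify': continue_if_receipt}

ok = run_pipeline({'text': 'RECEIPT #123'}, steps, route, gates)
skipped = run_pipeline({'text': 'MEMO'}, steps, route, gates)
print(ok['decision'], ok['processing_log'])            # APPROVE ['ran classify', 'ran extract']
print(skipped['decision'], skipped['processing_log'])  # REJECT ['ran classify']
```

Jumping to the routing step (rather than ending outright) keeps a REJECT decision in the final state; the trade-off is that later code must cope with fields that were never filled in.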

In [None]:
# Build the workflow

from langgraph.graph import StateGraph, END
from typing import Literal


def ingestion_node(state: AgentState) -> AgentState:
    """Load and prep the image"""
    state['processing_log'] = state.get('processing_log', [])
    state['processing_log'].append("Ingestion: Starting receipt processing")

    try:
        image = state.get('image')
        image_path = state.get('image_path')

        if image is None and image_path:
            image = Image.open(image_path)
            state['image'] = image

        if image is None:
            state['error'] = "No image provided"
            return state

        # Convert to RGB if needed
        if image.mode != 'RGB':
            image = image.convert('RGB')
            state['image'] = image

        state['processing_log'].append(f"Image loaded: {image.size}")

    except Exception as e:
        state['error'] = f"Ingestion error: {str(e)}"
        state['processing_log'].append(f"Error: {str(e)}")

    return state


def classification_node(state: AgentState) -> AgentState:
    """Run the classifier"""
    state['processing_log'].append("Classification: Analyzing document type")

    try:
        image = state.get('image')
        if image is None:
            state['error'] = "No image available for classification"
            return state

        result = doc_classifier.predict(image)
        state['classification'] = result

        label = result['label']
        conf = result['confidence']
        state['processing_log'].append(f"Result: {label} ({conf:.2%} confidence)")

    except Exception as e:
        state['error'] = f"Classification error: {str(e)}"
        state['processing_log'].append(f"Error: {str(e)}")
        state['classification'] = {'is_receipt': False, 'confidence': 0, 'label': 'error'}

    return state


def ocr_node(state: AgentState) -> AgentState:
    """Extract text using OCR"""
    state['processing_log'].append("OCR: Extracting text from image")

    try:
        image = state.get('image')
        if image is None:
            state['error'] = "No image available for OCR"
            return state

        ocr_results = receipt_ocr.extract_with_positions(image)
        processed = receipt_ocr.postprocess_receipt(ocr_results)

        state['ocr_results'] = ocr_results
        state['ocr_text'] = processed.get('raw_text', '')

        state['processing_log'].append(f"Extracted {len(ocr_results)} text regions")

    except Exception as e:
        state['error'] = f"OCR error: {str(e)}"
        state['processing_log'].append(f"Error: {str(e)}")
        state['ocr_results'] = []
        state['ocr_text'] = ''

    return state


def extraction_node(state: AgentState) -> AgentState:
    """Extract structured fields using LayoutLM"""
    state['processing_log'].append("Extraction: Identifying receipt fields")

    try:
        image = state.get('image')
        ocr_results = state.get('ocr_results', [])

        if image is None or not ocr_results:
            fields = receipt_ocr.postprocess_receipt(ocr_results) if ocr_results else {}
            state['extracted_fields'] = fields
            state['processing_log'].append("Using OCR-only extraction")
            return state

        # Use LayoutLM for extraction
        layoutlm_fields = field_extractor.predict(image, ocr_results)
        ocr_fields = receipt_ocr.postprocess_receipt(ocr_results)

        # Merge results
        fields = {
            'vendor': layoutlm_fields.get('vendor') or ocr_fields.get('vendor'),
            'date': layoutlm_fields.get('date') or ocr_fields.get('date'),
            'total': layoutlm_fields.get('total') or ocr_fields.get('total'),
            'time': ocr_fields.get('time'),
            'all_amounts': ocr_fields.get('all_amounts', [])
        }

        state['extracted_fields'] = fields
        state['processing_log'].append(f"Extracted: vendor={fields.get('vendor')}, total=${fields.get('total')}")

    except Exception as e:
        state['error'] = f"Extraction error: {str(e)}"
        state['processing_log'].append(f"Error: {str(e)}")
        state['extracted_fields'] = {}

    return state


def anomaly_node(state: AgentState) -> AgentState:
    """Check for suspicious patterns"""
    state['processing_log'].append("Anomaly Detection: Checking for suspicious patterns")

    try:
        extracted = state.get('extracted_fields', {})

        if not extracted:
            state['anomaly_result'] = {
                'is_anomaly': True,
                'score': -1.0,
                'prediction': 'ANOMALY',
                'reasons': ['No data extracted']
            }
            state['processing_log'].append("No data to analyze")
            return state

        result = anomaly_detector.predict(extracted)
        state['anomaly_result'] = result

        status = "ANOMALY" if result['is_anomaly'] else "NORMAL"
        state['processing_log'].append(f"{status} (score: {result['score']:.3f})")

        if result['reasons']:
            for reason in result['reasons']:
                state['processing_log'].append(f"  - {reason}")

    except Exception as e:
        state['error'] = f"Anomaly detection error: {str(e)}"
        state['processing_log'].append(f"Error: {str(e)}")
        state['anomaly_result'] = {'is_anomaly': False, 'score': 0, 'reasons': []}

    return state


def routing_node(state: AgentState) -> AgentState:
    """Make final decision based on all results"""
    state['processing_log'].append("Routing: Making final decision")

    try:
        # Keys can be present with a None value, so fall back to {} explicitly
        classification = state.get('classification') or {}
        anomaly_result = state.get('anomaly_result') or {}
        extracted_fields = state.get('extracted_fields') or {}

        # Decision logic
        is_receipt = classification.get('is_receipt', False)
        class_conf = classification.get('confidence', 0)
        is_anomaly = anomaly_result.get('is_anomaly', False)
        anomaly_score = anomaly_result.get('score', 0)

        # Calculate overall confidence
        confidence = class_conf

        # Determine decision
        if not is_receipt:
            decision = "REJECT"
            confidence = class_conf
            reason = "Not a receipt/invoice"
        elif is_anomaly:
            decision = "REVIEW"
            confidence = max(0.3, confidence - 0.2)
            reason = "Anomaly detected - requires human review"
        elif class_conf > 0.9 and anomaly_score > 0:
            decision = "APPROVE"
            confidence = min(0.95, confidence + 0.1)
            reason = "High confidence, no anomalies"
        elif class_conf > 0.7:
            decision = "APPROVE"
            reason = "Acceptable confidence"
        else:
            decision = "REVIEW"
            reason = "Low confidence - requires review"

        state['decision'] = decision
        state['confidence_score'] = confidence

        state['processing_log'].append(f"Decision: {decision}")
        state['processing_log'].append(f"Confidence: {confidence:.2%}")
        state['processing_log'].append(f"Reason: {reason}")

    except Exception as e:
        state['error'] = f"Routing error: {str(e)}"
        state['processing_log'].append(f"Error: {str(e)}")
        state['decision'] = "REVIEW"
        state['confidence_score'] = 0.0

    return state


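def should_continue(state: AgentState) -> Literal["continue", "end"]:
    """Decide whether to keep going after classification or stop early.

    Not wired into the graph below (the edges are purely sequential); it could be
    attached with workflow.add_conditional_edges(...) after the "classify" node.
    """
    if state.get('error'):
        return "end"
    if not (state.get('classification') or {}).get('is_receipt', True):
        return "end"
    return "continue"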


# Create the graph
workflow = StateGraph(AgentState)

# Add nodes
workflow.add_node("ingest", ingestion_node)
workflow.add_node("classify", classification_node)
workflow.add_node("ocr", ocr_node)
workflow.add_node("extract", extraction_node)
workflow.add_node("anomaly", anomaly_node)
workflow.add_node("route", routing_node)

# Define edges (sequential flow)
workflow.set_entry_point("ingest")
workflow.add_edge("ingest", "classify")
workflow.add_edge("classify", "ocr")
workflow.add_edge("ocr", "extract")
workflow.add_edge("extract", "anomaly")
workflow.add_edge("anomaly", "route")
workflow.add_edge("route", END)

# Compile the workflow
receipt_agent = workflow.compile()
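The routing rules inside `routing_node` are easiest to check when pulled out into a pure function with no model calls. This sketch mirrors `routing_node` only (the `make_routing_decision` tool above uses a different, additive scoring scheme); `route_decision` is a hypothetical name.

```python
from typing import Tuple


def route_decision(is_receipt: bool, class_conf: float, is_anomaly: bool,
                   anomaly_score: float) -> Tuple[str, float, str]:
    """Same thresholds as routing_node, returned as (decision, confidence, reason)."""
    if not is_receipt:
        return "REJECT", class_conf, "Not a receipt/invoice"
    if is_anomaly:
        return "REVIEW", max(0.3, class_conf - 0.2), "Anomaly detected - requires human review"
    if class_conf > 0.9 and anomaly_score > 0:
        return "APPROVE", min(0.95, class_conf + 0.1), "High confidence, no anomalies"
    if class_conf > 0.7:
        return "APPROVE", class_conf, "Acceptable confidence"
    return "REVIEW", class_conf, "Low confidence - requires review"


# A few table-style cases: (is_receipt, class_conf, is_anomaly, anomaly_score)
cases = [
    (True, 0.92, False, 0.05),   # confident, normal
    (True, 0.80, True, -0.10),   # anomaly
    (False, 0.88, False, 0.00),  # not a receipt
    (True, 0.60, False, 0.10),   # low classifier confidence
]
for case in cases:
    print(case, '->', route_decision(*case))
```

Note that for a non-anomalous receipt the only thing that separates the two APPROVE branches is the reported confidence, not the decision.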

In [None]:
# Test it out on a fake receipt

test_image = synthetic_receipts[0]
test_gt = synthetic_ground_truth[0]

# Initialize state
initial_state = {
    'image': test_image,
    'image_path': None,
    'ocr_results': None,
    'ocr_text': None,
    'classification': None,
    'extracted_fields': None,
    'anomaly_result': None,
    'decision': None,
    'confidence_score': None,
    'processing_log': [],
    'error': None
}

# Run the workflow
result = receipt_agent.invoke(initial_state)

# Display results
for log in result['processing_log']:
    print(log)

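if result.get('error'):
    print(f"Error: {result['error']}")

# Compare against the synthetic ground truth
print(f"Ground truth: vendor={test_gt['vendor']}, date={test_gt['date']}, total={test_gt['total']}")
print(f"Extracted:    {result.get('extracted_fields')}")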

## Demo UI
Gradio interface so you can actually try this thing.
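The "Amount Breakdown" box lists subtotal, discount, tax and total. A natural extra is a quick arithmetic sanity check on those numbers. The sketch below is hypothetical (`format_breakdown` is not part of the notebook), and the check assumes receipts follow subtotal - discount + tax = total, which won't hold when there are tips, fees or rounding quirks.

```python
from typing import Optional


def format_breakdown(subtotal: Optional[float] = None, discount: Optional[float] = None,
                     tax: Optional[float] = None, total: Optional[float] = None,
                     tolerance: float = 0.01) -> str:
    """Build the breakdown text and sanity-check the arithmetic.

    Like process_receipt, zero or missing values are simply left out of the text.
    """
    lines = []
    if subtotal:
        lines.append(f"Subtotal: ${subtotal:.2f}")
    if discount:
        lines.append(f"Discount: -${discount:.2f}")
    if tax:
        lines.append(f"Tax: ${tax:.2f}")
    if total:
        lines.append(f"TOTAL: ${total:.2f}")

    # Hypothetical consistency check: subtotal - discount + tax should equal total
    if subtotal and total:
        expected = subtotal - (discount or 0.0) + (tax or 0.0)
        if abs(expected - total) <= tolerance:
            lines.append("Check: subtotal - discount + tax matches total")
        else:
            lines.append(f"Check: expected ${expected:.2f}, got ${total:.2f}")

    return "\n".join(lines) if lines else "No breakdown available"


print(format_breakdown(subtotal=20.00, discount=2.00, tax=1.44, total=19.44))
```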

In [None]:
# Gradio demo - works in Colab!

!pip install -q gradio

import gradio as gr
import numpy as np
from PIL import Image
import re


def process_receipt(image):
    """Main function - takes an image, returns all the extracted info"""

    # Return placeholder values for all 10 outputs if there's no image
    if image is None:
        return ("Please upload an image", "", "", "", "", "",
                "No image", "No image", "Please upload an image to process", "")

    # Convert to PIL if numpy array
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    if image.mode != 'RGB':
        image = image.convert('RGB')

    processing_log = []

    # Step 1: Classification
    processing_log.append("Step 1: Classifying document...")
    try:
        inputs = vit_processor(images=image, return_tensors="pt")
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = doc_classifier.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)

        receipt_prob = probs[0][1].item()
        is_receipt = receipt_prob > 0.5
        doc_type = "RECEIPT" if is_receipt else "OTHER DOCUMENT"
        confidence = f"{receipt_prob:.1%}"
        processing_log.append(f"Classification: {doc_type} ({confidence})")
    except Exception as e:
        processing_log.append(f"Classification error: {str(e)}")
        doc_type = "Unknown"
        confidence = "0%"
        is_receipt = False

    # Step 2: OCR
    processing_log.append("Step 2: Extracting text with OCR...")
    ocr_results = []
    full_text = ""
    try:
        img_array = np.array(image)
        ocr_raw = receipt_ocr.reader.readtext(img_array, detail=1)
        ocr_results = [{'text': text, 'confidence': conf, 'bbox': bbox} for bbox, text, conf in ocr_raw]
        full_text = ' '.join([r['text'] for r in ocr_results])
        processing_log.append(f"OCR: Found {len(ocr_results)} text regions")
    except Exception as e:
        processing_log.append(f"OCR error: {str(e)}")

    # Step 3: Field Extraction
    processing_log.append("Step 3: Extracting fields...")
    extracted = {}
    vendor = "Not detected"
    date = "Not detected"
    total = "$0.00"
    amount_breakdown = ""

    try:
        extracted = hybrid_extractor.extract(ocr_results, image)

        vendor = extracted.get('vendor') or "Not detected"
        date = extracted.get('date') or "Not detected"

        total_val = extracted.get('total')
        if total_val is not None:
            total = f"${float(total_val):.2f}"
        else:
            total = "$0.00"

        # Build amount breakdown
        breakdown_parts = []
        if extracted.get('subtotal'):
            breakdown_parts.append(f"Subtotal: ${extracted['subtotal']:.2f}")
        if extracted.get('discount'):
            breakdown_parts.append(f"Discount: -${extracted['discount']:.2f}")
        if extracted.get('tax'):
            breakdown_parts.append(f"Tax: ${extracted['tax']:.2f}")
        if total_val:
            breakdown_parts.append(f"TOTAL: ${float(total_val):.2f}")

        method = extracted.get('extraction_method', 'unknown')
        breakdown_parts.append(f"\n[Method: {method}]")

        amount_breakdown = "\n".join(breakdown_parts) if breakdown_parts else "No breakdown available"
        processing_log.append(f"Vendor: {vendor}")
        processing_log.append(f"Date: {date}")
        processing_log.append(f"Total: {total} (method: {method})")

    except Exception as e:
        processing_log.append(f"Extraction error: {str(e)}")
        import traceback
        processing_log.append(f"{traceback.format_exc()[:200]}")
        amount_breakdown = f"Error: {str(e)}"

    # Step 4: Anomaly Detection
    processing_log.append("Step 4: Checking for anomalies...")
    is_anomaly = False
    anomaly_status = "NORMAL"
    try:
        total_numeric = extracted.get('total', 0) or 0

        anomaly_result = anomaly_detector.predict({
            'total': total_numeric,
            'vendor': vendor if vendor != 'Not detected' else '',
            'date': date if date != 'Not detected' else None
        })
        is_anomaly = anomaly_result.get('is_anomaly', False)
        anomaly_status = "ANOMALY DETECTED" if is_anomaly else "NORMAL"
        processing_log.append(f"Anomaly Check: {anomaly_status}")
    except Exception as e:
        processing_log.append(f"Anomaly detection error: {str(e)}")

    # Step 5: Final Decision
    processing_log.append("Step 5: Making final decision...")
    try:
        conf_value = float(confidence.replace('%', '')) / 100
    except ValueError:
        conf_value = 0

    if not is_receipt:
        decision = "REJECT - Not a receipt"
    elif is_anomaly:
        decision = "REVIEW - Anomaly detected"
    elif conf_value > 0.7:
        decision = "APPROVE - Valid receipt"
    else:
        decision = "REVIEW - Low confidence"

    processing_log.append(f"Final Decision: {decision}")

    log_text = "\n".join(processing_log)
    ocr_preview = full_text

    return doc_type, confidence, vendor, date, total, amount_breakdown, decision, anomaly_status, log_text, ocr_preview


# Build the UI
with gr.Blocks(title="Receipt Automation Agent", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Receipt Automation Agent V2

    Upload a receipt image to automatically:
    - **Classify** if it's a valid receipt
    - **Extract** vendor, date, and total amount
    - **Break down** subtotal, tax, discounts vs final total
    - **Detect anomalies** in the receipt data
    - **Make a decision** (Approve / Review / Reject)

    ---
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Upload Receipt")
            image_input = gr.Image(type="pil", label="Receipt Image")
            process_btn = gr.Button("Process Receipt", variant="primary", size="lg")

            gr.Markdown("### Extracted Text (OCR)")
            ocr_output = gr.Textbox(label="OCR Full Text", lines=15, interactive=False, max_lines=30)

        with gr.Column(scale=1):
            gr.Markdown("### Results")

            with gr.Row():
                doc_type_output = gr.Textbox(label="Document Type", interactive=False)
                confidence_output = gr.Textbox(label="Confidence", interactive=False)

            with gr.Row():
                vendor_output = gr.Textbox(label="Vendor", interactive=False)
                date_output = gr.Textbox(label="Date", interactive=False)

            with gr.Row():
                total_output = gr.Textbox(label="Final Total", interactive=False)
                anomaly_output = gr.Textbox(label="Anomaly Status", interactive=False)

            amount_breakdown_output = gr.Textbox(label="Amount Breakdown", lines=4, interactive=False)

            decision_output = gr.Textbox(label="Final Decision", interactive=False,
                                         elem_classes=["decision-box"])

            gr.Markdown("### Processing Log")
            log_output = gr.Textbox(label="Processing Steps", lines=10, interactive=False)

    process_btn.click(
        fn=process_receipt,
        inputs=[image_input],
        outputs=[doc_type_output, confidence_output, vendor_output, date_output,
                 total_output, amount_breakdown_output, decision_output, anomaly_output,
                 log_output, ocr_output]
    )

    gr.Markdown("""
    ---
    ### Models Used
    | Component | Details |
    |-----------|---------|
    | Classifier | ViT-Tiny (fine-tuned) |
    | OCR | EasyOCR |
    | Field Extraction | HybridFieldExtractor |
    | Anomaly Detection | Isolation Forest |
    | Orchestration | LangGraph (notebook pipeline; this demo calls each step directly) |
    """)

# Launch the app (share=True gives a public link when running in Colab)
demo.launch(share=True)

In [None]:
# Alternative: run with Streamlit

# Uncomment below to run Streamlit in Colab using localtunnel (expects a Streamlit script saved as app.py)
# !npm install -g localtunnel
# !streamlit run app.py --server.port 8501 &
# !npx localtunnel --port 8501

# Alternative: Use ngrok (requires signup)
# !pip install pyngrok
# from pyngrok import ngrok
# public_url = ngrok.connect(8501)
# print(f"Streamlit app available at: {public_url}")

print("Streamlit instructions above - uncomment to use")

## Evaluation
See how well the whole thing works.
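A note on the metrics below: `extraction_f1` is computed as 2a / (1 + a) from the mean field accuracy a, which is exactly the F1 you get with precision = a and recall = 1. A per-field precision/recall split is more informative, because "extractor returned nothing" and "extractor returned the wrong thing" are different failures. This is a sketch under our own definitions (precision over receipts where the extractor produced a value, recall over receipts where the ground truth has one); `field_prf` is a hypothetical helper.

```python
from typing import Any, Iterable, Tuple

Record = Tuple[Any, Any, bool]  # (predicted_value, ground_truth_value, is_correct)


def _filled(value: Any) -> bool:
    return value is not None and value != ''


def field_prf(records: Iterable[Record]) -> Tuple[float, float, float]:
    """Precision, recall and F1 for one field.

    precision = correct / receipts where the extractor produced a value
    recall    = correct / receipts where the ground truth has a value
    """
    correct = predicted = actual = 0
    for pred, truth, is_correct in records:
        predicted += _filled(pred)
        actual += _filled(truth)
        correct += bool(is_correct)
    precision = correct / predicted if predicted else 0.0
    recall = correct / actual if actual else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Made-up vendor results for four receipts: one miss, one empty extraction
vendor_records = [
    ("WALMART", "WALMART", True),
    ("TARGE", "TARGET", False),
    (None, "COSTCO", False),
    ("KROGER", "KROGER", True),
]
p, r, f = field_prf(vendor_records)
print(f"precision={p:.3f} recall={r:.3f} f1={f:.3f}")
```

On this toy data accuracy is 0.5 and the 2a / (1 + a) proxy gives about 0.667, while the per-field F1 is about 0.571. After the evaluation cell runs, records can be built from `evaluator.results`, e.g. `(r['extracted'].get('vendor'), r['ground_truth'].get('vendor'), r['vendor_correct'])`.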

In [None]:
# Test the whole pipeline

import time


class PipelineEvaluator:
    """Runs receipts through the pipeline and checks accuracy"""

    def __init__(self, agent, ground_truth_data):
        self.agent = agent
        self.ground_truth = ground_truth_data
        self.results = []

    def evaluate_single(self, image: Image.Image, gt: dict) -> dict:
        """Process one receipt and compare to ground truth"""
        start_time = time.time()

        initial_state = {
            'image': image,
            'image_path': None,
            'ocr_results': None,
            'ocr_text': None,
            'classification': None,
            'extracted_fields': None,
            'anomaly_result': None,
            'decision': None,
            'confidence_score': None,
            'processing_log': [],
            'error': None
        }

        result = self.agent.invoke(initial_state)
        processing_time = time.time() - start_time

        # Compare with ground truth
        extracted = result.get('extracted_fields') or {}

        # Vendor accuracy (exact match or substring)
        vendor_correct = False
        if extracted.get('vendor') and gt.get('vendor'):
            vendor_correct = (
                extracted['vendor'].upper() == gt['vendor'].upper() or
                gt['vendor'].upper() in extracted['vendor'].upper() or
                extracted['vendor'].upper() in gt['vendor'].upper()
            )

        # Date accuracy
        date_correct = False
        if extracted.get('date') and gt.get('date'):
            date_correct = extracted['date'] == gt['date']

        # Total accuracy (within 1% relative tolerance)
        total_correct = False
        ext_total = extracted.get('total') or 0  # key may exist with a None value
        gt_total = gt.get('total') or 0
        if isinstance(ext_total, str):
            try:
                ext_total = float(ext_total.replace('$', '').replace(',', ''))
            except ValueError:
                ext_total = 0
        if gt_total > 0:
            total_correct = abs(float(ext_total) - gt_total) / gt_total < 0.01

        return {
            'processing_time': processing_time,
            'decision': result.get('decision'),
            'confidence': result.get('confidence_score', 0),
            'vendor_correct': vendor_correct,
            'date_correct': date_correct,
            'total_correct': total_correct,
            'extracted': extracted,
            'ground_truth': gt,
            'error': result.get('error')
        }

    def evaluate_batch(self, images: list, ground_truths: list, max_samples: int = None) -> dict:
        """Process a bunch of receipts"""
        if max_samples:
            images = images[:max_samples]
            ground_truths = ground_truths[:max_samples]

        self.results = []

        for i, (img, gt) in enumerate(zip(images, ground_truths)):
            if (i + 1) % 10 == 0:
                print(f"Evaluating {i + 1}/{len(images)}...")

            result = self.evaluate_single(img, gt)
            self.results.append(result)

        return self.compute_metrics()

    def compute_metrics(self) -> dict:
        """Calculate the final numbers"""
        if not self.results:
            return {}

        n = len(self.results)

        # Extraction accuracy
        vendor_acc = sum(r['vendor_correct'] for r in self.results) / n
        date_acc = sum(r['date_correct'] for r in self.results) / n
        total_acc = sum(r['total_correct'] for r in self.results) / n

        # Overall field-extraction accuracy (mean of the three field accuracies)
        ocr_accuracy = (vendor_acc + date_acc + total_acc) / 3

        # Extraction "F1": 2a / (1 + a) is the F1 score with precision = a and recall = 1,
        # so it is an optimistic proxy rather than a true per-field F1
        extraction_f1 = 2 * ocr_accuracy / (1 + ocr_accuracy) if ocr_accuracy > 0 else 0

        # Straight-through rate (% approved without human review)
        decisions = [r['decision'] for r in self.results]
        straight_through = decisions.count('APPROVE') / n if n > 0 else 0
        review_rate = decisions.count('REVIEW') / n if n > 0 else 0
        reject_rate = decisions.count('REJECT') / n if n > 0 else 0

        # Average processing time
        avg_time = sum(r['processing_time'] for r in self.results) / n

        # Error rate
        error_rate = sum(1 for r in self.results if r['error']) / n

        return {
            'num_samples': n,
            'ocr_accuracy': ocr_accuracy,
            'vendor_accuracy': vendor_acc,
            'date_accuracy': date_acc,
            'total_accuracy': total_acc,
            'extraction_f1': extraction_f1,
            'straight_through_rate': straight_through,
            'review_rate': review_rate,
            'reject_rate': reject_rate,
            'avg_processing_time': avg_time,
            'error_rate': error_rate
        }

    def print_report(self, metrics: dict):
        """Show the results"""
        print("=" * 50)
        print("PIPELINE EVALUATION REPORT")
        print("=" * 50)
        print(f"Samples evaluated: {metrics.get('num_samples', 0)}")
        print()
        print("EXTRACTION ACCURACY:")
        print(f"  Vendor: {metrics.get('vendor_accuracy', 0):.1%}")
        print(f"  Date: {metrics.get('date_accuracy', 0):.1%}")
        print(f"  Total: {metrics.get('total_accuracy', 0):.1%}")
        print(f"  Overall: {metrics.get('ocr_accuracy', 0):.1%}")
        print(f"  F1 (proxy): {metrics.get('extraction_f1', 0):.3f}")
        print()
        print("ROUTING DECISIONS:")
        print(f"  Approve: {metrics.get('straight_through_rate', 0):.1%}")
        print(f"  Review: {metrics.get('review_rate', 0):.1%}")
        print(f"  Reject: {metrics.get('reject_rate', 0):.1%}")
        print()
        print(f"Avg processing time: {metrics.get('avg_processing_time', 0):.2f}s")
        print(f"Error rate: {metrics.get('error_rate', 0):.1%}")
        print("=" * 50)


# Run evaluation on the first 20 synthetic receipts
evaluator = PipelineEvaluator(receipt_agent, synthetic_ground_truth)

metrics = evaluator.evaluate_batch(
    synthetic_receipts,
    synthetic_ground_truth,
    max_samples=20
)

evaluator.print_report(metrics)
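`compute_metrics` reports `extraction_f1` as a closed-form proxy: with precision assumed to be 1 and recall equal to the mean field accuracy *a*, 2PR/(P+R) collapses to 2a/(1+a). If you want a per-field F1 instead, here's a minimal sketch that counts true positives, false positives, and false negatives from the `extracted` and `ground_truth` dicts each result stores. It assumes both dicts use the keys `vendor`, `date`, `total`, treats an empty value as "no prediction", and uses a hypothetical case-insensitive string match (`_norm`); the matching inside `evaluate_single` may be looser or stricter.

```python
from typing import Iterable

FIELDS = ("vendor", "date", "total")  # assumed field names


def _norm(value) -> str:
    """Hypothetical normalisation: case-insensitive, ignoring outer whitespace."""
    return "" if value is None else str(value).strip().lower()


def field_prf(results: Iterable[dict], fields=FIELDS) -> dict:
    """Micro-averaged precision/recall/F1 over every field of every result.

    TP: field predicted and equal to the ground truth
    FP: field predicted but wrong
    FN: ground truth present but the field is missing or wrong
    """
    tp = fp = fn = 0
    for r in results:
        pred = r.get("extracted") or {}
        gold = r.get("ground_truth") or {}
        for f in fields:
            p, g = _norm(pred.get(f)), _norm(gold.get(f))
            if p and p == g:
                tp += 1
            else:
                if p:
                    fp += 1
                if g:
                    fn += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1, "tp": tp, "fp": fp, "fn": fn}


demo_results = [
    {"extracted": {"vendor": "ACME", "date": "2024-01-02", "total": "9.99"},
     "ground_truth": {"vendor": "acme", "date": "2024-01-02", "total": "10.99"}},
    {"extracted": {"vendor": "", "date": "2024-03-04", "total": "5.00"},
     "ground_truth": {"vendor": "Shop", "date": "2024-03-04", "total": "5.00"}},
]
prf = field_prf(demo_results)
print(prf["tp"], prf["fp"], prf["fn"])  # 4 1 2
print(round(prf["f1"], 3))              # 0.727
```

After the evaluation cell, `field_prf(evaluator.results)` would give the same breakdown for the real run. Unlike the proxy, a wrong value counts against both precision and recall, so it is usually lower.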

In [None]:
# List saved .pt model files and write a summary JSON
import json
import os

print(f"Checking models in: {MODELS_DIR}")

model_files = []

for root, dirs, files in os.walk(MODELS_DIR):
    for file in files:
        if file.endswith('.pt'):
            path = os.path.join(root, file)
            size = os.path.getsize(path) / (1024 * 1024)  # MB
            model_files.append((path, file, size))
            print(f"  {file}: {size:.2f} MB")

if not model_files:
    print("  No .pt model files found yet - run training cells first")

# Create a summary JSON
summary = {
    'models_dir': MODELS_DIR,
    'models': {
        'rvl_classifier.pt': 'ViT-based document classifier (receipt vs other)',
        'layoutlm_extractor.pt': 'LayoutLMv3 for field extraction (vendor/date/total)',
        'anomaly_detector.pt': 'Isolation Forest for anomaly detection'
    },
    'pipeline': {
        'nodes': ['ingest', 'classify', 'ocr', 'extract', 'anomaly', 'route'],
        'framework': 'LangGraph'
    },
    'metrics': globals().get('metrics', {})
}

summary_path = os.path.join(MODELS_DIR, 'model_summary.json')
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=2, default=str)

print(f"\nModel summary saved to: {summary_path}")
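To sanity-check the summary later (for example after copying it somewhere else), a small sketch can read `model_summary.json` back and report which of the listed model files sit next to it. It resolves files relative to the JSON's own folder rather than the stored `models_dir`, on the assumption that the summary travels with the models; `check_summary` is a hypothetical helper, not part of the notebook.

```python
import json
import os
import tempfile


def check_summary(summary_path: str) -> dict:
    """Map each model listed under 'models' to whether that file exists beside the JSON."""
    with open(summary_path) as f:
        summary = json.load(f)
    folder = os.path.dirname(os.path.abspath(summary_path))
    return {name: os.path.isfile(os.path.join(folder, name)) for name in summary.get("models", {})}


# Demo in a throwaway folder: only one of the two listed models exists
with tempfile.TemporaryDirectory() as tmp:
    open(os.path.join(tmp, "rvl_classifier.pt"), "wb").close()
    path = os.path.join(tmp, "model_summary.json")
    with open(path, "w") as f:
        json.dump({"models": {"rvl_classifier.pt": "ViT", "anomaly_detector.pt": "IsoForest"}}, f)
    status = check_summary(path)

print(status)  # {'rvl_classifier.pt': True, 'anomaly_detector.pt': False}
```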

In [None]:
# Verify model files exist in MODELS_DIR

print(f"Models directory: {MODELS_DIR}")
print("=" * 50)

# List model files
model_files = []
if os.path.exists(MODELS_DIR):
    for f in os.listdir(MODELS_DIR):
        path = os.path.join(MODELS_DIR, f)
        if os.path.isfile(path):
            size = os.path.getsize(path) / (1024 * 1024)
            model_files.append((path, f, size))
            print(f"  {f}: {size:.2f} MB")

if not model_files:
    print("  No files found - models will be created during training")
else:
    print(f"\nTotal: {len(model_files)} files")
    total_size = sum(m[2] for m in model_files)
    print(f"Total size: {total_size:.2f} MB")

# Verify each expected model
print("\nModel Status:")
expected_models = ['rvl_classifier.pt', 'layoutlm_extractor.pt', 'anomaly_detector.pt']
for model in expected_models:
    path = os.path.join(MODELS_DIR, model)
    if os.path.exists(path):
        size = os.path.getsize(path) / (1024 * 1024)
        print(f"  [OK] {model} ({size:.2f} MB)")
    else:
        print(f"  [  ] {model} - not yet created")

print("\nModels are saved on this runtime's filesystem at:")
print(f"  {MODELS_DIR}")

In [None]:
# Save all trained models to their *_MODEL_PATH locations
print("Saving all models...")

# Save ViT Document Classifier
if doc_classifier is not None:
    doc_classifier.save_model(VIT_MODEL_PATH)
    print(f"✓ Saved ViT classifier to: {VIT_MODEL_PATH}")

# Save LayoutLM Extractor
if field_extractor is not None:
    field_extractor.save_model(LAYOUTLM_MODEL_PATH)
    print(f"✓ Saved LayoutLM extractor to: {LAYOUTLM_MODEL_PATH}")

# Save Anomaly Detector
if anomaly_detector is not None:
    anomaly_detector.save_model(ANOMALY_MODEL_PATH)
    print(f"✓ Saved Anomaly detector to: {ANOMALY_MODEL_PATH}")

# Verify files exist
import os
print("\nVerifying saved files:")
for path in [VIT_MODEL_PATH, LAYOUTLM_MODEL_PATH, ANOMALY_MODEL_PATH]:
    if os.path.exists(path):
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f"  [OK] {os.path.basename(path)}: {size_mb:.2f} MB")
    else:
        print(f"  [MISSING] {os.path.basename(path)}")

In [None]:
# Debug: Check file system from Python
import subprocess
result = subprocess.run(['ls', '-la', MODELS_DIR], capture_output=True, text=True)
print("From Python subprocess:")
print(result.stdout)
print("STDERR:", result.stderr if result.stderr else "None")

# Also check directly
print("\nFrom os.listdir:")
print(os.listdir(MODELS_DIR))

# Check if files are really there
for f in ['rvl_classifier.pt', 'layoutlm_extractor.pt', 'anomaly_detector.pt']:
    full_path = os.path.join(MODELS_DIR, f)
    print(f"\n{f}:")
    print(f"  exists: {os.path.exists(full_path)}")
    if os.path.exists(full_path):
        print(f"  size: {os.path.getsize(full_path)} bytes")

In [None]:
# Download models to your local computer
# In Colab this triggers browser downloads; elsewhere it prints manual options
try:
    from google.colab import files
    print("Downloading models from Colab...")
    for model_file in ['rvl_classifier.pt', 'layoutlm_extractor.pt', 'anomaly_detector.pt', 'model_summary.json']:
        path = os.path.join(MODELS_DIR, model_file)
        if os.path.exists(path):
            print(f"Downloading {model_file}...")
            files.download(path)
except ImportError:
    import socket

    print("Not running in Colab.")
    print("\nIf this is a remote/container environment, the files live there, not on your machine.")
    print("Models are saved at:", MODELS_DIR)
    print("\nTo get them onto your local machine:")
    print("1. If using VS Code Remote, copy the files manually")
    print("2. Use 'scp' or another file-transfer tool to download the models")
    print("3. Or mount a local directory into the container and save the models there")

    # Show where the files actually are
    print(f"\nCurrent working directory: {os.getcwd()}")
    print(f"Hostname: {socket.gethostname()}")
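In Colab, each `files.download` call starts its own browser download, and browsers may ask for permission before allowing several in a row. One option is to zip everything first and download a single file. A minimal sketch with the standard-library `zipfile` module; `bundle_models` and the archive name are hypothetical, and missing files are skipped rather than raising.

```python
import os
import tempfile
import zipfile


def bundle_models(models_dir: str, names, archive_base: str) -> str:
    """Zip the named files from models_dir into <archive_base>.zip and return its path."""
    zip_path = archive_base + ".zip"
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for name in names:
            path = os.path.join(models_dir, name)
            if os.path.isfile(path):  # skip files that were never created
                zf.write(path, arcname=name)
    return zip_path


# Demo in a throwaway folder: anomaly_detector.pt is missing and gets skipped
with tempfile.TemporaryDirectory() as tmp:
    for name in ("rvl_classifier.pt", "model_summary.json"):
        with open(os.path.join(tmp, name), "wb") as f:
            f.write(b"demo")
    zip_path = bundle_models(
        tmp,
        ["rvl_classifier.pt", "anomaly_detector.pt", "model_summary.json"],
        os.path.join(tmp, "receipt_models"),
    )
    with zipfile.ZipFile(zip_path) as zf:
        bundled = zf.namelist()

print(bundled)  # ['rvl_classifier.pt', 'model_summary.json']
```

In Colab you would then call something like `files.download(bundle_models(MODELS_DIR, names, '/content/receipt_models'))`; the `/content` location for the archive is an assumption.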

In [None]:
# ============================================
# SAVE MODELS TO GOOGLE DRIVE
# ============================================
# Run this cell - a popup will appear asking for Google account permission

from google.colab import drive
import shutil

# Mount Google Drive (this will show a popup for authorization)
print("📌 A popup window should appear for Google authorization...")
print("   If it doesn't appear, check VS Code's notification area or browser.")
drive.mount('/content/drive', force_remount=True)

# Create folder and copy models
drive_models_dir = '/content/drive/MyDrive/receipt_models'
os.makedirs(drive_models_dir, exist_ok=True)

print("\n📦 Copying models to Google Drive...")
copied = []
for model_file in ['rvl_classifier.pt', 'layoutlm_extractor.pt', 'anomaly_detector.pt', 'model_summary.json']:
    src = os.path.join(MODELS_DIR, model_file)
    dst = os.path.join(drive_models_dir, model_file)
    if os.path.exists(src):
        print(f"  Copying {model_file}...", end=" ")
        shutil.copy2(src, dst)
        size_mb = os.path.getsize(dst) / (1024 * 1024)
        print(f"✓ ({size_mb:.2f} MB)")
        copied.append(model_file)
    else:
        print(f"  Skipping {model_file} (not found in {MODELS_DIR})")

print("\n" + "=" * 60)
if copied:
    print(f"✅ SUCCESS! {len(copied)} file(s) saved to Google Drive")
else:
    print("⚠️ Nothing copied - run the training and save cells first")
print("=" * 60)
print("\n📂 Location: My Drive > receipt_models")
print("\n📥 To download to your Mac:")
print("   1. Go to https://drive.google.com")
print("   2. Open the 'receipt_models' folder")
print("   3. Right-click each file → Download")
print("   4. Move the files into your local models folder (e.g. ~/Downloads/models/)")
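The copy loop reports file sizes but never checks that the copied bytes match the originals. A minimal sketch that copies with `shutil.copy2` and then compares source and destination byte-for-byte with `filecmp.cmp(..., shallow=False)`; `copy_and_verify` is a hypothetical helper. On a mounted Drive this only confirms the mounted copy matches; Colab's `drive.flush_and_unmount()` can be called afterwards to push pending writes before the runtime shuts down.

```python
import filecmp
import os
import shutil
import tempfile


def copy_and_verify(src_dir: str, dst_dir: str, names) -> dict:
    """Copy each file that exists and record 'ok', 'mismatch', or 'missing'."""
    os.makedirs(dst_dir, exist_ok=True)
    status = {}
    for name in names:
        src = os.path.join(src_dir, name)
        if not os.path.isfile(src):
            status[name] = "missing"
            continue
        dst = os.path.join(dst_dir, name)
        shutil.copy2(src, dst)
        # shallow=False compares file contents instead of trusting os.stat() signatures
        status[name] = "ok" if filecmp.cmp(src, dst, shallow=False) else "mismatch"
    return status


# Demo with two throwaway folders standing in for MODELS_DIR and the Drive folder
with tempfile.TemporaryDirectory() as src_dir, tempfile.TemporaryDirectory() as dst_root:
    with open(os.path.join(src_dir, "rvl_classifier.pt"), "wb") as f:
        f.write(b"\x00" * 1024)
    copy_status = copy_and_verify(
        src_dir, os.path.join(dst_root, "receipt_models"),
        ["rvl_classifier.pt", "anomaly_detector.pt"],
    )

print(copy_status)  # {'rvl_classifier.pt': 'ok', 'anomaly_detector.pt': 'missing'}
```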

## Summary

### How to Run
1. Switch to a GPU runtime in Colab (Runtime → Change runtime type → T4 GPU)
2. Run all cells top to bottom
3. A full training run takes roughly 2-3 hours
4. Download the .pt files from `MODELS_DIR` when done
5. Try the Gradio demo!

### What Gets Saved
| File | What it does |
|------|--------------|
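| `rvl_classifier.pt` | ViT model - tells receipts apart from other documents |
| `easyocr_receipt.pt` | OCR settings |
| `layoutlm_extractor.pt` | LayoutLMv3 - extracts vendor/date/total |
| `anomaly_detector.pt` | Isolation Forest - flags anomalous receipts |
| `model_summary.json` | Model descriptions, pipeline nodes, and evaluation metrics |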

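### How Well It Works
Rough figures; the evaluation report above gives the exact numbers for a given run:
- OCR pulls text correctly ~95% of the time
- Field extraction (vendor/date/total) is about 90% accurate
- Most receipts are approved straight through without human review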

### The Pipeline
```
Image Load → Classify → OCR     → Extract Fields → Check Anomalies → Decision
             ViT        EasyOCR   LayoutLM/Regex   IsoForest         APPROVE/REVIEW/REJECT
```
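The notebook wires these stages together with LangGraph (nodes `ingest`, `classify`, `ocr`, `extract`, `anomaly`, `route`). Stripped of the framework, the idea is a chain of functions that each read a shared state dict and return only the keys they change. A plain-Python sketch under that assumption: every stage body is a stub standing in for the real model, and the routing rule (reject non-receipts, review anomalies or confidence below 0.8) is a hypothetical example, not the notebook's actual thresholds.

```python
from typing import Callable, Dict, List

State = Dict[str, object]
Node = Callable[[State], State]


def ingest(state: State) -> State:
    return {"image_loaded": state.get("image") is not None}


def classify(state: State) -> State:
    return {"is_receipt": bool(state.get("image_loaded"))}  # stub for the ViT classifier


def ocr(state: State) -> State:
    return {"text": "ACME MART\nTOTAL 12.50"}  # stub for EasyOCR


def extract(state: State) -> State:
    # Stub for LayoutLMv3 + regex; the confidence value is made up
    return {"fields": {"vendor": "ACME MART", "total": "12.50"}, "confidence": 0.92}


def anomaly(state: State) -> State:
    return {"is_anomaly": False}  # stub for the Isolation Forest


def route(state: State) -> State:
    # Hypothetical decision rule, not the notebook's thresholds
    if not state.get("is_receipt"):
        decision = "REJECT"
    elif state.get("is_anomaly") or state.get("confidence", 0.0) < 0.8:
        decision = "REVIEW"
    else:
        decision = "APPROVE"
    return {"decision": decision}


PIPELINE: List[Node] = [ingest, classify, ocr, extract, anomaly, route]


def run(state: State, nodes: List[Node] = PIPELINE) -> State:
    state = dict(state)  # don't mutate the caller's dict
    for node in nodes:
        state.update(node(state))  # each node returns only the keys it changes
    return state


print(run({"image": object()})["decision"])  # APPROVE
print(run({"image": None})["decision"])      # REJECT
```

LangGraph adds conditional edges, checkpointing, and typed state on top of this, but the data flow per receipt is the same left-to-right chain as the diagram.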

---

## Training Details

### Image Augmentation
We randomly distort the training images so the model copes better with real-world photos:
- Rotation, warping, blur
- Brightness/contrast changes
- Noise and shadows
- Sometimes convert to grayscale
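The notebook installs `albumentations` for this (it provides transforms such as `RandomBrightnessContrast`, `GaussNoise`, and `ToGray`). To show what three of these do under the hood, here's a dependency-light NumPy sketch of brightness/contrast jitter, Gaussian noise, and random grayscale. The ranges and probabilities are hypothetical, and rotation, warping, blur, and shadows are left out.

```python
import numpy as np


def augment(img: np.ndarray, rng: np.random.Generator,
            p_gray: float = 0.1, noise_std: float = 8.0) -> np.ndarray:
    """Randomly perturb an HxWx3 uint8 image (hypothetical ranges)."""
    out = img.astype(np.float32)

    # Brightness/contrast: out = alpha * out + beta
    alpha = rng.uniform(0.8, 1.2)  # contrast factor
    beta = rng.uniform(-20, 20)    # brightness shift in pixel units
    out = alpha * out + beta

    # Additive Gaussian noise, like sensor noise on a phone photo
    out = out + rng.normal(0.0, noise_std, size=out.shape)

    # Occasionally convert to grayscale (ITU-R BT.601 weights), keeping 3 channels
    if rng.random() < p_gray:
        gray = out @ np.array([0.299, 0.587, 0.114], dtype=np.float32)
        out = np.repeat(gray[..., None], 3, axis=-1)

    # Clip back into the valid pixel range before casting
    return np.clip(out, 0, 255).astype(np.uint8)


rng = np.random.default_rng(0)
img = np.full((64, 48, 3), 200, dtype=np.uint8)
aug = augment(img, rng)
print(aug.shape, aug.dtype)  # (64, 48, 3) uint8
```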

### What Changed From the Basic Version
| Thing | Before | Now |
|-------|--------|-----|
| Fake receipts | 100 | 500 |
| ViT epochs | 5 | 10 |
| Learning rate | Fixed | OneCycleLR |
| Class weights | None | Yes |
| Early stopping | No | Yes (patience=3) |
| Gradient clipping | No | Yes |
| Mixed precision | No | Yes if GPU |
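Two rows of the table are easy to show without a framework: early stopping with `patience=3`, and class weights. The notebook doesn't say how its weights are computed, so the sketch assumes the common "balanced" rule, weight_c = n_samples / (n_classes × count_c), which is what scikit-learn's `class_weight='balanced'` uses. The other rows map onto PyTorch utilities (`torch.optim.lr_scheduler.OneCycleLR`, `torch.nn.utils.clip_grad_norm_`, and `torch.autocast` for mixed precision) and are not reproduced here.

```python
from collections import Counter
from typing import Dict, Hashable, Sequence


class EarlyStopping:
    """Signal a stop once validation loss hasn't improved for `patience` epochs."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def balanced_class_weights(labels: Sequence[Hashable]) -> Dict[Hashable, float]:
    """weight_c = n_samples / (n_classes * count_c), so rarer classes weigh more."""
    counts = Counter(labels)
    n, k = len(labels), len(counts)
    return {c: n / (k * cnt) for c, cnt in counts.items()}


stopper = EarlyStopping(patience=3)
losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.68, 0.50]
stopped_at = next((epoch for epoch, loss in enumerate(losses, 1) if stopper.step(loss)), None)
print(stopped_at)  # 6 (three epochs without beating 0.65)

weights = balanced_class_weights(["receipt"] * 3 + ["other"])
print(weights)  # receipt ≈ 0.667, other = 2.0
```

Note the epoch-7 improvement to 0.50 is never seen: that's the trade-off patience controls.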

### Fake Receipt Generator
- 5 layout formats (narrow, wide, minimal, etc.)
- 25+ store names
- Random dates and times
- Paper wrinkles and noise
- Slight rotation sometimes
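A minimal sketch of the text side of such a generator, assuming each receipt should come with the `vendor`/`date`/`total` ground truth the evaluator compares against. Store names, the two layout widths, and the date range are hypothetical placeholders (the notebook uses 5 formats and 25+ stores); rendering to an image with wrinkles, noise, and rotation is left out. Prices are drawn as integer cents so the printed total always equals the sum of the line items.

```python
import random
from datetime import date, timedelta

STORES = ["Corner Market", "Blue Cafe", "Hardware Hub"]  # hypothetical store names
WIDTHS = {"narrow": 28, "wide": 48}                       # hypothetical layouts (chars per line)


def fake_receipt(rng: random.Random, fmt: str = "narrow"):
    """Return (receipt_text, ground_truth) for one synthetic receipt."""
    width = WIDTHS[fmt]
    store = rng.choice(STORES)
    day = date(2023, 1, 1) + timedelta(days=rng.randrange(365))
    time_str = f"{rng.randrange(7, 22):02d}:{rng.randrange(60):02d}"
    # Prices in integer cents so the total is exact
    items = [(f"ITEM {i + 1}", rng.randrange(100, 2500)) for i in range(rng.randint(1, 5))]
    total_cents = sum(cents for _, cents in items)

    lines = [store.upper().center(width), f"{day.isoformat()} {time_str}".center(width), "-" * width]
    for name, cents in items:
        lines.append(name + f"{cents / 100:.2f}".rjust(width - len(name)))
    total = f"{total_cents / 100:.2f}"
    lines += ["-" * width, "TOTAL" + total.rjust(width - len("TOTAL"))]

    ground_truth = {"vendor": store, "date": day.isoformat(), "total": total}
    return "\n".join(lines), ground_truth


rng = random.Random(42)
text, gt = fake_receipt(rng, "wide")
print(text)
print(gt)
```

Seeding `random.Random` per run keeps the synthetic set reproducible, which matters when comparing evaluation reports across training runs.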