In [None]:
import pickle
import warnings
from pathlib import Path
from typing import TYPE_CHECKING

import cv2
import numpy as np
import pandas as pd
import pytesseract
from PIL import Image

if TYPE_CHECKING:
    from sklearn.ensemble import RandomForestClassifier

# Pickled classifier read by load_is_photo_classifier. The path is relative, so
# it resolves against the process's current working directory, not this file.
MODEL_PATH: Path = Path("model.bin")


def load_is_photo_classifier(
    path: "str | Path | None" = None,
) -> "RandomForestClassifier":
    """Load the pickled "is photo" classifier from disk.

    Args:
        path: File to unpickle. Defaults to ``MODEL_PATH`` when omitted, so
            existing zero-argument calls behave as before.

    Returns:
        The unpickled object. It is annotated as a scikit-learn
        ``RandomForestClassifier``, but no type check is performed here.

    Raises:
        OSError: if the file cannot be opened (e.g. ``FileNotFoundError``).
        pickle.UnpicklingError: if the file is not a valid pickle.

    Security:
        ``pickle.load`` can execute arbitrary code embedded in the file. Only
        load model files that come from a trusted source.
    """
    # The annotations are strings because RandomForestClassifier is imported
    # only for static type checkers. A bare name would raise NameError when
    # this ``def`` statement executes.
    model_file = MODEL_PATH if path is None else path
    with open(model_file, "rb") as f:
        return pickle.load(f)


def _ocr_text_box_area(ocr_data: pd.DataFrame) -> int:
    """Return the summed pixel area of OCR boxes whose text is non-blank."""
    if ocr_data.empty:
        return 0
    text = ocr_data["text"]
    # The frame is built with pandas type inference, so ``text`` is not
    # guaranteed to be a string column (e.g. float when every token is
    # numeric). Cast before using the ``.str`` accessor instead of failing.
    has_text = text.notna() & (text.astype(str).str.strip() != "")
    boxes = ocr_data.loc[has_text]
    return int((boxes["width"] * boxes["height"]).sum())


def image_properties(image_path) -> dict:
    """Compute the feature dict consumed by ``predict_image_quality``.

    Args:
        image_path: Path to an image file (``str`` or ``Path``). It is read by
            PIL (for OCR) and by OpenCV (for edge and colour statistics).

    Returns:
        A dict with keys:
          * ``width``, ``height``, ``total_pixels``: dimensions of the
            OpenCV-decoded image.
          * ``text_density``: total area of non-blank OCR boxes divided by the
            PIL image area (0 when that area is 0).
          * ``color_std``: mean over channels of the per-channel pixel
            standard deviation.
          * ``edge_density``: fraction of pixels marked by ``cv2.Canny`` with
            thresholds 100/200.

    Raises:
        Exceptions from ``Image.open`` (unreadable file) and from
        ``pytesseract`` (e.g. Tesseract not installed) propagate.
    """
    # The context manager releases the file handle PIL keeps open.
    with Image.open(image_path) as pil_image:
        ocr_data = pytesseract.image_to_data(
            pil_image, output_type=pytesseract.Output.DATAFRAME
        )
        pil_width, pil_height = pil_image.size

        cv_image = cv2.imread(str(image_path))
        if cv_image is None:
            # cv2.imread reports failure by returning None rather than raising
            # (e.g. a format OpenCV cannot decode). Reuse PIL's decode instead
            # of failing later inside cvtColor.
            cv_image = cv2.cvtColor(
                np.asarray(pil_image.convert("RGB")), cv2.COLOR_RGB2BGR
            )

    try:
        total_text_area = _ocr_text_box_area(ocr_data)
    except (KeyError, TypeError) as exc:
        # Text density is best-effort: continue with 0, but make it visible.
        warnings.warn(
            f"could not compute OCR text area for {image_path!s}: {exc!r}",
            stacklevel=2,
        )
        total_text_area = 0

    total_image_area = pil_width * pil_height
    text_density = total_text_area / total_image_area if total_image_area > 0 else 0

    gray = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY)
    height, width = gray.shape
    total_pixels = height * width

    edges = cv2.Canny(gray, 100, 200)
    edge_density = np.count_nonzero(edges) / total_pixels

    # Per-channel standard deviation over all pixels, averaged across channels.
    color_std = np.std(cv_image, axis=(0, 1)).mean()

    return {
        "width": width,
        "height": height,
        "total_pixels": total_pixels,
        "text_density": text_density,
        "color_std": color_std,
        "edge_density": edge_density,
    }


def predict_image_quality(model, properties: dict) -> bool:
    """Run ``model`` on a single feature dict and return its prediction as bool.

    Args:
        model: Object exposing a scikit-learn style ``predict`` that accepts a
            DataFrame.
        properties: Feature dict (e.g. from ``image_properties``). It must
            contain every name in ``feature_order``; extra keys are ignored.

    Returns:
        ``bool`` of the first (only) prediction.
        NOTE(review): ``bool()`` of a string label such as ``"not_photo"`` is
        True. This presumes the model was trained on 0/1 or boolean targets;
        confirm.
    """
    # Fixed column order: scikit-learn estimators fitted on a DataFrame check
    # feature names and their order at predict time.
    feature_order = (
        "width",
        "height",
        "total_pixels",
        "text_density",
        "color_std",
        "edge_density",
    )
    single_row = pd.DataFrame(
        {name: properties[name] for name in feature_order},
        index=[0],
    )
    predictions = model.predict(single_row)
    return bool(predictions[0])


def is_photo(model, image_path) -> bool:
    """Extract features from ``image_path`` and classify them with ``model``.

    Thin wrapper over ``image_properties`` followed by ``predict_image_quality``.
    The result is the model's prediction coerced to ``bool``. Presumably True
    means "photo", but that depends on the labels the model was trained with.
    """
    return predict_image_quality(model, image_properties(image_path))