# In [None]:
# scripts/evaluate_against_sroie.py
import json
from pathlib import Path
import csv
from math import isclose

# Input/output locations. All are relative to the current working directory,
# so run the script from the folder these "../data" paths were written for.
EXTRACTED_DIR = Path("..", "data", "processed", "extracted")
SROIE_ENT_DIR = Path("..", "data", "raw", "SROIE2019", "test", "entities")  # adjust if different
OUT_REPORT = Path("results", "evaluation_report.csv")
# Ensure the report's folder exists before anything tries to write into it.
OUT_REPORT.parent.mkdir(exist_ok=True, parents=True)

# Abort with exit status 1 unless the extraction step has already written
# its CSV; there is nothing to evaluate otherwise.
extracted_csv = EXTRACTED_DIR.joinpath("test_extracted.csv")
if not extracted_csv.exists():
    print("No extracted CSV found at", extracted_csv)
    raise SystemExit(1)

# Build the ground-truth map from the SROIE entity annotations:
#     gt_map[<entity file stem>] = {"total": float | None, "date": "YYYY-MM-DD" | None}
# Each file is parsed as a JSON object with "total"/"date" keys when possible;
# anything else falls back to a line-by-line scan for total/date lines.
import re

from dateutil import parser as date_parser

# A money amount with exactly two decimals, thousands separators allowed
# ("1,234.50"). The previous pattern used r'\\.', which in a raw string is a
# literal backslash followed by any character, so it never matched an amount.
_AMOUNT_RE = re.compile(r"([0-9,]+\.[0-9]{2})")
# Whole-word "date" so words like "update"/"dated" do not trigger a date
# parse; still covers "invoice date" and "transaction date".
_DATE_WORD_RE = re.compile(r"\bdate\b")
# dateutil reads ambiguous dates such as 05/12/2018 month-first by default.
# NOTE(review): if the receipts are written dd/mm/yyyy, set this to True -- but
# only if the extractor that produced test_extracted.csv parses the same way,
# otherwise previously-agreeing dates will stop matching.
GT_DATE_DAYFIRST = False


def _parse_amount(text):
    """Return the first two-decimal amount found in *text* as a float, or None."""
    m = _AMOUNT_RE.search(text)
    if m is None:
        return None
    # Only digits, commas, one dot and two digits can match, so float() is safe.
    return float(m.group(1).replace(",", ""))


def _parse_gt_total(value):
    """Convert a JSON "total" value (number or string) to a float, or None.

    Plain numbers like "9" are accepted; otherwise a two-decimal amount is
    searched for inside the string (e.g. "RM 9.00").
    """
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return float(value)
    text = str(value or "").strip()
    try:
        return float(text.replace(",", ""))
    except ValueError:
        return _parse_amount(text)


def _parse_date(text):
    """Fuzzy-parse *text* with dateutil; return an ISO date string or None."""
    if not text:
        return None
    try:
        parsed = date_parser.parse(text, fuzzy=True, dayfirst=GT_DATE_DAYFIRST)
    except (ValueError, OverflowError):  # dateutil's ParserError is a ValueError
        return None
    return parsed.date().isoformat()


def _scan_entity_text(data):
    """Fallback for non-JSON files: the last parsable total/date line wins."""
    found_total = None
    found_date = None
    for line in data.splitlines():
        lo = line.lower()
        if "total" in lo:
            amount = _parse_amount(line)
            if amount is not None:
                found_total = amount
        if _DATE_WORD_RE.search(lo):
            iso = _parse_date(line)
            if iso is not None:
                found_date = iso
    return found_total, found_date


def _parse_entity_file(data):
    """Return (total, date) from one entity file's text, preferring JSON."""
    try:
        obj = json.loads(data)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        obj = None
    if isinstance(obj, dict):
        fields = {str(k).lower(): v for k, v in obj.items()}
        return _parse_gt_total(fields.get("total")), _parse_date(str(fields.get("date") or ""))
    return _scan_entity_text(data)


gt_map = {}
if not SROIE_ENT_DIR.is_dir():
    print("WARNING: SROIE entities directory not found:", SROIE_ENT_DIR)
for ent_file in sorted(SROIE_ENT_DIR.glob("*")):
    if not ent_file.is_file():
        continue
    try:
        # utf-8-sig also accepts files saved with a BOM (json.loads rejects one).
        data = ent_file.read_text(encoding="utf-8-sig")
    except (OSError, UnicodeDecodeError) as exc:
        print(f"Skipping unreadable entity file {ent_file}: {exc}")
        continue
    total, date = _parse_entity_file(data)
    gt_map[ent_file.stem] = {"total": total, "date": date}

# Load the extractor's output: one dict per data row, keyed by the CSV header.
with open(extracted_csv, newline="", encoding="utf-8") as fh:
    rows = list(csv.DictReader(fh))

# Evaluate: compare every extracted row with its ground truth exactly once,
# print summary accuracies, then write the per-image CSV report from the same
# comparison results so the summary and the report cannot disagree.

# Tolerances for comparing totals with math.isclose.
# NOTE(review): rel_tol=1e-3 lets large totals differ by more than a cent
# (1000.00 vs 1000.90 counts as a match); kept for backward compatibility.
TOTAL_REL_TOL = 1e-3
TOTAL_ABS_TOL = 0.01


def _lookup_gt(image):
    """Return the ground-truth dict for *image*, or {} when there is none.

    gt_map is keyed by entity-file stem; when *image* does not match directly
    (e.g. it carries an extension like "X51.jpg") its stem is tried instead.
    """
    return gt_map.get(image) or gt_map.get(Path(image).stem) or {}


def _totals_match(ext_total, gt_total):
    """Compare totals: True/False when both are present, None when not comparable.

    Thousands separators in the extracted value are ignored; an extracted
    total that still cannot be parsed as a number counts as a mismatch.
    """
    if gt_total is None or not ext_total:
        return None
    try:
        ext_value = float(ext_total.replace(",", ""))
    except ValueError:
        return False
    return isclose(ext_value, float(gt_total), rel_tol=TOTAL_REL_TOL, abs_tol=TOTAL_ABS_TOL)


def _dates_match(ext_date, gt_date):
    """Compare dates as exact strings; None when either side is missing.

    NOTE(review): assumes the extractor writes ISO "YYYY-MM-DD" dates, the
    format used in gt_map -- TODO confirm.
    """
    if not gt_date or not ext_date:
        return None
    return gt_date == ext_date


REPORT_FIELDS = ["image", "extracted_total", "gt_total", "extracted_date", "gt_date", "total_match", "date_match"]

results = []
for row in rows:
    # DictReader yields None for missing trailing cells; normalise to "".
    img = row["image"] or ""
    ext_total = (row["total"] or "").strip()
    ext_date = (row["date"] or "").strip()
    gt = _lookup_gt(img)
    gt_total = gt.get("total")
    gt_date = gt.get("date")
    results.append({
        "image": img,
        "extracted_total": ext_total,
        "gt_total": gt_total,
        "extracted_date": ext_date,
        "gt_date": gt_date,
        # None means "not comparable" and is written as an empty cell.
        "total_match": _totals_match(ext_total, gt_total),
        "date_match": _dates_match(ext_date, gt_date),
    })

valid_total = sum(r["total_match"] is not None for r in results)
total_matches = sum(r["total_match"] is True for r in results)
valid_date = sum(r["date_match"] is not None for r in results)
date_matches = sum(r["date_match"] is True for r in results)

print(f"Totals: matched {total_matches} / {valid_total} ({(total_matches/valid_total*100) if valid_total else 0:.2f}%)")
print(f"Dates: matched {date_matches} / {valid_date} ({(date_matches/valid_date*100) if valid_date else 0:.2f}%)")

# write basic CSV report (csv writes None values as empty cells)
with open(OUT_REPORT, "w", newline="", encoding="utf-8") as out:
    writer = csv.DictWriter(out, fieldnames=REPORT_FIELDS)
    writer.writeheader()
    writer.writerows(results)

print("Wrote evaluation report to", OUT_REPORT)