In [4]:
import base64
import functools
import io
import json
import os
import shutil
import tempfile
from pathlib import Path

import fitz
import pandas as pd
import pytesseract
from dotenv import load_dotenv
from joblib import Parallel, delayed
from openai import OpenAI
from pdf2image import convert_from_path
from PIL import Image
from tqdm import tqdm

# ✅ Load OpenAI: read OPENAI_API_KEY (and other settings) from a local .env file
# and create the client shared by every task below.
load_dotenv()
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# ✅ Poppler location for pdf2image.
# Priority: an explicit POPPLER_PATH environment/.env entry; otherwise None when
# pdftoppm is already on PATH (pdf2image finds it); otherwise the Apple-silicon
# Homebrew directory. Intel Macs / custom installs can set POPPLER_PATH.
POPPLER_PATH = os.getenv("POPPLER_PATH") or (
    "/opt/homebrew/bin" if shutil.which("pdftoppm") is None else None
)

# ✅ Generic text extraction (PyMuPDF first, OCR fallback)
def _front_back_ranges(n_pages, front_n, back_n):
    """Return non-overlapping (front, back) 0-based page ranges.

    ``front`` covers the first ``front_n`` pages and ``back`` the last
    ``back_n``. When the document is short, ``back`` starts where ``front``
    ends, so no page is selected twice. ``back_n=0`` gives an empty back range.
    """
    front_end = min(max(front_n, 0), n_pages)
    back_start = max(front_end, n_pages - max(back_n, 0))
    return range(0, front_end), range(back_start, n_pages)


def _ocr_front_back(pdf_path, front_n, back_n, dpi, page_count):
    """OCR the front/back pages of ``pdf_path`` with poppler + Tesseract.

    When ``page_count`` is known, only the selected pages are rasterized (one
    pdftoppm call per non-empty range). Otherwise every page is rendered and
    the selection is made afterwards.
    """
    with tempfile.TemporaryDirectory() as tmp:
        if page_count is None:
            images = convert_from_path(
                pdf_path, dpi=dpi, output_folder=tmp, poppler_path=POPPLER_PATH
            )
            front, back = _front_back_ranges(len(images), front_n, back_n)
            selected = [images[i] for i in (*front, *back)]
        else:
            selected = []
            for pages in _front_back_ranges(page_count, front_n, back_n):
                if not pages:
                    continue
                # pdf2image pages are 1-based and last_page is inclusive.
                selected += convert_from_path(
                    pdf_path,
                    dpi=dpi,
                    output_folder=tmp,
                    poppler_path=POPPLER_PATH,
                    first_page=pages.start + 1,
                    last_page=pages.stop,
                )
        # OCR inside the `with`: the images are backed by files in `tmp`.
        return "\n".join(pytesseract.image_to_string(img) for img in selected)


@functools.lru_cache(maxsize=8)
def extract_front_back_text(pdf_path, front_n=5, back_n=5, dpi=300):
    """Return the text of the first ``front_n`` and last ``back_n`` PDF pages.

    Extraction order:
    1. PyMuPDF's embedded text layer is tried first.
    2. If the file cannot be read, or the text has fewer than 50 characters
       after stripping (typical of scanned reports), the same pages are
       rasterized with poppler at ``dpi`` and OCR'd with Tesseract.

    Pages are never repeated when the document has fewer than
    ``front_n + back_n`` pages.

    Results are memoized in a small LRU cache keyed on the arguments.
    ``process_pdf`` and ``extract_report_year`` both request the text of the
    same file, and without the cache a scanned PDF would be OCR'd twice. If a
    file changes on disk during a session, call ``cache_clear()``.

    Raises:
        Whatever pdf2image / pytesseract raise if the OCR fallback fails too.
    """
    page_count = None
    try:
        with fitz.open(pdf_path) as doc:
            page_count = len(doc)
            front, back = _front_back_ranges(page_count, front_n, back_n)
            text = "\n".join(doc[i].get_text() for i in (*front, *back))
        if len(text.strip()) >= 50:
            return text
    except Exception:
        # Deliberate best effort: PyMuPDF could not read the file, so let the
        # poppler + OCR path try. Exception (not bare except) keeps Ctrl-C working.
        pass
    return _ocr_front_back(pdf_path, front_n, back_n, dpi, page_count)

# ✅ Task 1: report-type classification
def classify_report_type(text):
    """Classify a corporate report and detect a sustainability section via GPT.

    Args:
        text: Text from the first/last pages of the report.

    Returns:
        dict with keys ``report_type``, ``has_sustainability_section``,
        ``sustainability_section_name`` and ``reasoning``. On any API or
        JSON-parsing failure, ``report_type`` is ``"ERROR"`` and ``reasoning``
        holds the exception message.
    """
    prompt = f"""
From the text (first/last pages of a corporate report), perform:

1. Classify as:
- "sustainability report"
- "annual report"
- "integrated report"
- "other"

2. Check if a sustainability section is included, and name it
   (use null for the name if there is none).

Return JSON:
{{
  "report_type": "...",
  "has_sustainability_section": true/false,
  "sustainability_section_name": "...",
  "reasoning": "..."
}}

Text:
{text}
"""
    try:
        res = client.chat.completions.create(
            model="gpt-4.1-mini",
            messages=[{"role": "user", "content": prompt}],
            # JSON mode: without it the model often wraps the object in ```json
            # fences, json.loads rejects that, and the row becomes "ERROR".
            response_format={"type": "json_object"},
        )
        return json.loads(res.choices[0].message.content)
    except Exception as e:
        return {
            "report_type": "ERROR",
            "has_sustainability_section": None,
            "sustainability_section_name": "",
            "reasoning": str(e),
        }

# ✅ Task 2: reporting-year extraction (text first, vision fallback)
def build_year_prompt(text):
    """Build the GPT prompt that extracts the fiscal/reporting year from ``text``.

    The prompt asks for null fields when no year expression is present.
    ``extract_report_year`` falls back to the vision model only when
    ``normalized_report_year`` is falsy, so a placeholder such as "N/A" would
    otherwise suppress that fallback.
    """
    return f"""
Extract any fiscal/reporting year expressions like:
- "FY2023", "April 2022 – March 2023", "for the year ended 31 Dec 2022"

If no such expression is present, set every field to null (not a placeholder string).

Return JSON:
{{
  "normalized_report_year": "...",
  "original_expression": "...",
  "source": "..."
}}

Text:
{text}
"""

def encode_image_to_base64(pil_img):
    """Return ``pil_img`` encoded as PNG and then base64, as a ``str``.

    The PNG is built in an in-memory buffer rather than a temporary file.
    Re-opening an open ``NamedTemporaryFile`` by name fails on Windows and
    costs needless disk I/O.
    """
    buffer = io.BytesIO()
    pil_img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

def extract_year_from_vision(pdf_path, max_pages=3):
    """Ask ``gpt-4o`` for the reporting year, one rendered page at a time.

    Only the first ``max_pages`` pages are rasterized (200 dpi). Each page is
    sent separately, and the first parsed reply with a truthy
    ``normalized_report_year`` is returned.

    Returns:
        dict with keys ``normalized_report_year``, ``original_expression`` and
        ``source``. The ``source`` value signals failures:
        - ``"Vision ERROR: ..."`` when rendering fails, or when no page
          yielded a year and at least one page's request failed.
        - ``"Vision NOT FOUND"`` when no page yielded a year.
    """
    try:
        # Rasterize only the pages we will actually send, not the whole report.
        images = convert_from_path(
            pdf_path,
            dpi=200,
            poppler_path=POPPLER_PATH,
            first_page=1,
            last_page=max_pages,
        )
    except Exception as e:
        return {"normalized_report_year": None, "original_expression": None, "source": f"Vision ERROR: {e}"}

    # Name the keys explicitly; with only "Return JSON." the model picks its
    # own key names and normalized_report_year is usually missing.
    prompt = (
        "Extract the fiscal/reporting year shown on this report page. "
        "Return JSON with keys normalized_report_year, original_expression and source; "
        "use null for normalized_report_year if no year is shown."
    )
    last_error = None
    for img in images[:max_pages]:
        try:
            b64 = encode_image_to_base64(img)
            res = client.chat.completions.create(
                model="gpt-4o",
                messages=[{
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
                    ],
                }],
                response_format={"type": "json_object"},
            )
            parsed = json.loads(res.choices[0].message.content)
        except Exception as e:
            # One failing page/request should not stop the remaining pages.
            last_error = e
            continue
        if isinstance(parsed, dict) and parsed.get("normalized_report_year"):
            return parsed

    if last_error is not None:
        return {"normalized_report_year": None, "original_expression": None, "source": f"Vision ERROR: {last_error}"}
    return {"normalized_report_year": None, "original_expression": None, "source": "Vision NOT FOUND"}

def extract_report_year(pdf_path):
    """Extract the reporting year: text-based GPT first, vision model as fallback.

    The front/back page text is sent to ``gpt-4.1-mini`` with
    ``build_year_prompt``. If that fails, or yields no truthy
    ``normalized_report_year``, ``extract_year_from_vision`` is used instead.

    Returns:
        dict with keys ``normalized_report_year``, ``original_expression`` and
        ``source``.
    """
    parsed = None
    try:
        text = extract_front_back_text(pdf_path)
        res = client.chat.completions.create(
            model="gpt-4.1-mini",
            messages=[{"role": "user", "content": build_year_prompt(text)}],
            response_format={"type": "json_object"},
        )
        parsed = json.loads(res.choices[0].message.content)
    except Exception:
        # Best effort: any extraction/API/parse failure falls through to the
        # vision path. Catch Exception (not everything) so Ctrl-C still stops
        # the batch loop.
        parsed = None
    if isinstance(parsed, dict) and parsed.get("normalized_report_year"):
        return parsed
    return extract_year_from_vision(pdf_path)

# ✅ Task 3: company-name extraction
def extract_company_info(text):
    """Identify the company that issued a report via GPT.

    Args:
        text: Text from the report's front/back pages.

    Returns:
        dict with these keys:
        - ``company_name``: ``"UNKNOWN"`` if the model could not tell.
        - ``other_names``: always a list of str.
        - ``country``.
        - ``reasoning``.
        - ``publisher``: looked up with a second GPT call only when the
          company is ``"UNKNOWN"``; ``None`` otherwise.
        API/parse failures yield ``"GPT_ERROR"`` markers.
    """
    # The report text must be part of the prompt; without it the model has
    # nothing to extract from.
    prompt = f"""
From the text (report front/back pages), extract:

- "company_name": issuing company (or "UNKNOWN")
- "other_names": abbreviations/nicknames
- "country": where headquartered
- "reasoning": explanation

Return JSON:
{{
  "company_name": "...",
  "other_names": ["..."],
  "country": "...",
  "reasoning": "..."
}}

Text:
{text}
"""
    try:
        res = client.chat.completions.create(
            model="gpt-4.1",
            messages=[{"role": "user", "content": prompt}],
            response_format={"type": "json_object"},
        )
        result = json.loads(res.choices[0].message.content)
    except Exception as e:
        return {
            "company_name": "GPT_ERROR",
            "other_names": [],
            "country": "GPT_ERROR",
            "reasoning": str(e),
            "publisher": None
        }

    # The model may return null / a bare string; normalise so `.upper()` below
    # and the "; ".join in callers cannot crash.
    company = result.get("company_name") or "UNKNOWN"
    result["company_name"] = company
    names = result.get("other_names") or []
    if not isinstance(names, list):
        names = [names]
    result["other_names"] = [str(n) for n in names if n]

    if str(company).strip().upper() == "UNKNOWN":
        publisher_prompt = (
            "From the text (report front/back pages), identify the organisation "
            'that published the report. Return JSON: {"publisher": "..."} '
            '(use "UNKNOWN" if it cannot be determined).\n\n'
            f"Text:\n{text}"
        )
        try:
            res2 = client.chat.completions.create(
                model="gpt-4.1",
                messages=[{"role": "user", "content": publisher_prompt}],
                response_format={"type": "json_object"},
            )
            pub = json.loads(res2.choices[0].message.content)
            result["publisher"] = pub.get("publisher") or "UNKNOWN"
        except Exception:
            result["publisher"] = "GPT_ERROR"
    else:
        result["publisher"] = None

    return result

# ✅ End-to-end pipeline for a single PDF
def process_pdf(pdf_path):
    """Run all three extraction tasks on one PDF and flatten them into one row.

    Args:
        pdf_path: ``pathlib.Path`` (or str) of the PDF to process.

    Returns:
        dict of output-column name -> value, ready for ``pd.DataFrame``.
        - ``reasoning`` is the company-extraction explanation.
        - ``report_type_reasoning`` is the classifier's explanation, or its
          error message when ``report_type`` is ``"ERROR"``.

    Raises:
        Whatever ``extract_front_back_text`` raises when the file cannot be
        read at all.
    """
    pdf_path = Path(pdf_path)
    text = extract_front_back_text(pdf_path)

    rtype = classify_report_type(text)
    year = extract_report_year(pdf_path)
    company = extract_company_info(text)

    # other_names should be a list, but tolerate None or a bare string so one
    # odd model reply cannot crash the batch ("; ".join("IBM") -> "I; B; M").
    other_names = company.get("other_names") or []
    if isinstance(other_names, str):
        other_names = [other_names]

    return {
        "filename": pdf_path.name,
        "report_type": rtype.get("report_type"),
        "has_sustainability_section": rtype.get("has_sustainability_section"),
        "sustainability_section_name": rtype.get("sustainability_section_name"),
        "report_type_reasoning": rtype.get("reasoning"),
        "normalized_report_year": year.get("normalized_report_year"),
        "original_expression": year.get("original_expression"),
        "year_source": year.get("source"),
        "company_name": company.get("company_name"),
        "other_names": "; ".join(str(n) for n in other_names),
        "publisher": company.get("publisher"),
        "country": company.get("country"),
        "reasoning": company.get("reasoning"),
    }

# ✅ Batch run: process every PDF in pdf_folder into one CSV row per file
pdf_dir = Path("pdf_folder")
output_path = Path("results/full_pipeline_output.csv")
output_path.parent.mkdir(parents=True, exist_ok=True)

# Resume mode:
# - True: keep the successful rows already in output_path and process only
#   the remaining PDFs.
# - False: reprocess everything and overwrite the CSV.
RESUME = False
# Rewrite the CSV after this many files, so a crash loses little work.
CHECKPOINT_EVERY = 10

# Sorted so the processing order (and checkpoint contents) is reproducible.
pdf_files = sorted(pdf_dir.glob("*.pdf"))

results = []
if RESUME and output_path.exists():
    previous = pd.read_csv(output_path)
    if "error" in previous.columns:
        previous = previous[previous["error"].isna()]  # retry failed files
    results = previous.to_dict("records")
done = {str(row["filename"]) for row in results}
pending = [pdf for pdf in pdf_files if pdf.name not in done]

# ✅ Run serially to avoid multiprocessing serialization issues
try:
    for count, pdf in enumerate(tqdm(pending, desc="🔍 Processing PDFs"), start=1):
        try:
            results.append(process_pdf(pdf))
        except Exception as e:
            # Record the failure and keep going instead of aborting the whole batch.
            results.append({"filename": pdf.name, "error": f"{type(e).__name__}: {e}"})
        if count % CHECKPOINT_EVERY == 0:
            pd.DataFrame(results).to_csv(output_path, index=False)
finally:
    # Runs on success, on error and on KeyboardInterrupt: always persist
    # whatever has been processed so far.
    df = pd.DataFrame(results)
    df.to_csv(output_path, index=False)

n_failed = int(df["error"].notna().sum()) if "error" in df.columns else 0
print(f"✅ Done! Results saved to: {output_path} ({len(df)} rows, {n_failed} failed)")


[A
[A
[A
[A
[A
🔍 Processing PDFs:   0%|          | 5/1277 [03:32<15:01:14, 42.51s/it]


KeyboardInterrupt: 