In [None]:
import os
from openai import OpenAI
from dotenv import load_dotenv
import time
import csv
import pandas as pd
import pymupdf
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
import base64
from pathlib import Path
import PyPDF2
import re
import json

In [None]:
# Load environment variables from a local .env file (if present) before any
# of them are read below.
load_dotenv()

# Value of OPENAI_API_KEY, or None when the variable is not set.
# NOTE(review): `api_key` does not appear to be referenced anywhere else in
# the visible code (the OpenAI client reads os.environ directly) - confirm
# whether it is still needed.
api_key = os.getenv("OPENAI_API_KEY")

In [None]:
# Extract text from PDFs and process it with GPT.
# Module-level OpenAI client used by process_with_gpt(); the API key is taken
# from the OPENAI_API_KEY environment variable (None is passed if it is unset).
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

def extract_text_from_pdf(pdf_path):
    """Return the text of every page of a single PDF, separated by newlines.

    Pages that yield no text are skipped. Any error raised while opening or
    reading the file is printed, and whatever text had been collected up to
    that point (possibly the empty string) is returned.
    """
    page_texts = []
    try:
        with open(pdf_path, 'rb') as handle:
            for pdf_page in PyPDF2.PdfReader(handle).pages:
                content = pdf_page.extract_text()
                if not content:
                    continue
                page_texts.append(content)
    except Exception as e:
        print(f"Error extracting text from {pdf_path}: {e}")
    # Joining with "\n" and stripping matches appending "\n" after every page.
    return "\n".join(page_texts).strip()


def chunk_text(text, max_words=1500):
    """Break *text* into consecutive pieces of at most *max_words* words.

    Words are whitespace-delimited and each piece is re-joined with single
    spaces. Text without any words yields an empty list.
    """
    tokens = text.split()
    starts = range(0, len(tokens), max_words)
    return [' '.join(tokens[start:start + max_words]) for start in starts]


def process_with_gpt(prompt, model="gpt-5"):
    """Send *prompt* as a single user message and return the model's reply text.

    Never raises: any failure is reported as a string starting with
    ``"API Error: "``. If the reply does not have the expected chat-completion
    shape, ``output_text`` is tried next and finally ``str(response)``.
    """
    # Accessors tried in order against the response object.
    accessors = (
        lambda resp: resp.choices[0].message.content,
        lambda resp: resp.output_text,
    )
    try:
        reply = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
        )
        for read in accessors:
            try:
                return read(reply)
            except Exception:
                continue
        return str(reply)
    except Exception as exc:
        return f"API Error: {str(exc)}"


def extract_data_from_pdf_text(pdf_text, prompt):
    """Run the extraction prompt against *pdf_text* and return the model reply.

    If *prompt* contains the ``{text_chunk}`` placeholder the text is
    substituted there; otherwise the text is appended after the prompt under
    a ``Document:`` heading.
    """
    placeholder = "{text_chunk}"
    full_prompt = (
        prompt.replace(placeholder, pdf_text)
        if placeholder in prompt
        else f"{prompt}\n\nDocument:\n{pdf_text}"
    )
    return process_with_gpt(full_prompt)


# Cell values the model uses to say "no value" (compared case-insensitively).
_NULL_TOKENS = frozenset({'', 'null', 'none', 'unknown', 'n/a', 'na'})

# Multipliers for magnitude words / suffixes accepted in size cells.
_SCALE_FACTORS = {
    'thousand': 10**3, 'k': 10**3, 'kb': 10**3,
    'million': 10**6, 'm': 10**6, 'mb': 10**6,
    'billion': 10**9, 'b': 10**9, 'gb': 10**9,
    'tb': 10**12,
    'pb': 10**15,
}

# A number followed by a magnitude. The trailing (?![a-z]) keeps a single
# letter from matching the start of a word, e.g. the "b" of "bytes".
_SCALED_NUMBER_RE = re.compile(
    r'(?<![\d.])(\d+(?:\.\d+)?)\s*'
    r'(million|billion|thousand|[kmgtp]b|[mbk])(?![a-z])'
)


def _textual_to_int(value):
    """Convert a size cell ('1,200', '1.2e6', '3 million', '500 MB') to int or None."""
    if value is None:
        return None
    s = str(value).strip()
    if s.lower() in _NULL_TOKENS:
        return None
    s_clean = s.replace(',', '').lower()

    # 1) number + magnitude word/suffix
    scaled = _SCALED_NUMBER_RE.search(s_clean)
    if scaled:
        try:
            return int(round(float(scaled.group(1)) * _SCALE_FACTORS[scaled.group(2)]))
        except (ValueError, OverflowError):
            pass

    # 2) the whole cell is a plain number (integer, decimal or scientific)
    try:
        number = float(s_clean)
        if number >= 0:
            return int(number)
    except (ValueError, OverflowError):
        pass

    # 3) last resort: keep every ASCII digit (handles '1 200 000', '~1200000 events')
    digits = re.sub(r'[^0-9]', '', s)
    return int(digits) if digits else None


def parse_csv_response(response_text):
    """Parse a GPT CSV reply (header + one data row) into a dict.

    Tolerant to markdown fences, header case/whitespace differences, magnitude
    words and byte units in the size columns, and the literal "null" the
    prompt asks the model to use for unknown values.

    Returns a dict with keys Title, Authors, Date, Dataset_Name, Size_Events,
    Size_Files, Size_Bytes, Data_Format, Dataset_DOI (unknown values are
    None; sizes are ints), or None if the reply cannot be parsed.
    """
    try:
        if not response_text:
            return None
        # remove markdown fences and common wrappers
        response_text = response_text.replace('```csv', '').replace('```', '').strip()
        lines = [line.strip() for line in response_text.split('\n') if line.strip()]
        if len(lines) < 2:
            return None
        row = next(csv.DictReader(lines))

        # Index non-empty cells by normalised header so "title", " Title " and
        # "Title" are the same column. Surplus cells (stored by DictReader
        # under a None key) and null-like values are ignored.
        cells = {}
        for header, value in row.items():
            if not isinstance(header, str) or not isinstance(value, str):
                continue
            text = value.strip()
            if text.lower() in _NULL_TOKENS:
                continue
            cells.setdefault(header.replace('\ufeff', '').strip().lower(), text)

        def first_cell(*headers):
            """Return the first populated cell among the candidate headers."""
            for header in headers:
                text = cells.get(header.lower())
                if text:
                    return text
            return None

        return {
            'Title': first_cell('Title'),
            'Authors': first_cell('Authors'),
            'Date': first_cell('Date'),
            'Dataset_Name': first_cell('Dataset name (collision or MC)', 'Dataset_Name', 'Dataset name'),
            'Size_Events': _textual_to_int(first_cell('Size (events)', 'Size_Events')),
            'Size_Files': _textual_to_int(first_cell('Size (files)', 'Size_Files')),
            'Size_Bytes': _textual_to_int(first_cell('Size (bytes)', 'Size_Bytes')),
            'Data_Format': first_cell('Data format', 'Data_Format'),
            'Dataset_DOI': first_cell('Dataset DOI', 'Dataset_DOI'),
        }
    except Exception as e:
        print(f"Error parsing CSV response: {e}")
        print(str(response_text)[:400])
        return None


def parse_json_like_response(text):
    """Find and parse the first JSON object embedded in the model output.

    Markdown fences are removed, then each ``{`` is tried in turn as the
    start of an object. Decoding stops at the end of that object, so prose
    or stray braces after the JSON do not break parsing.

    Returns the decoded dict, or None when *text* is empty, is not a string,
    or contains no decodable JSON object.
    """
    if not text or not isinstance(text, str):
        return None
    cleaned = text.replace('```', '')
    decoder = json.JSONDecoder()
    last_error = None
    start = cleaned.find('{')
    while start != -1:
        try:
            # raw_decode parses exactly one JSON value beginning at `start`
            # and ignores whatever follows it.
            obj, _end = decoder.raw_decode(cleaned, start)
            return obj
        except (ValueError, RecursionError) as e:
            last_error = e
        start = cleaned.find('{', start + 1)
    if last_error is not None:
        print(f"parse_json_like_response error: {last_error}")
    return None


# "10.<registrant>/<suffix>"; the lookbehind stops a match from starting in
# the middle of a longer number such as "210.1234/5".
_DOI_RE = re.compile(r'(?<![\d.])(10\.\d{4,9}/[^\s,\)\]]+)')


def find_doi(text):
    """Return the first DOI-looking token in *text*, or None.

    Sentence punctuation and quotes that directly follow the DOI (for example
    the full stop in "... doi:10.1234/abc.") are not part of the result.
    """
    if not text:
        return None
    for match in _DOI_RE.finditer(text):
        doi = match.group(1).rstrip('.;:\'"')
        # Skip candidates whose suffix consisted only of punctuation.
        if doi.split('/', 1)[1]:
            return doi
    return None


def find_date(text):
    """Return the first ``20YY-MM`` date in *text*, else the first ``20YY`` year.

    The month must be 01-12 and neither pattern may be part of a longer run
    of digits, so a year range such as "2011-2012" gives "2011" rather than
    "2011-20". Returns None when nothing matches.
    """
    if not text:
        return None
    year_month = re.search(r'(?<!\d)(20\d{2}-(?:0[1-9]|1[0-2]))(?!\d)', text)
    if year_month:
        return year_month.group(1)
    year = re.search(r'(?<!\d)(20\d{2})(?!\d)', text)
    return year.group(1) if year else None


# "<number> [magnitude] <events|files|bytes>", e.g. '1.2 million events',
# '1,200,000 events', '1.2e6 events'.
#  - the lookbehind stops a match from starting inside a token ("1.2e6" -> "6")
#  - the trailing \b rejects longer words such as "filesystems"
_SIZE_MENTION_RE = re.compile(
    r'(?<![\w.,])(\d[\d,]*(?:\.\d+)?(?:e[+-]?\d+)?)\s*'
    r'(million|billion|thousand|m|b|k)?\s*'
    r'(events?|files?|bytes?)\b',
    flags=re.I,
)

_SIZE_MAGNITUDES = {
    'thousand': 1e3, 'k': 1e3,
    'million': 1e6, 'm': 1e6,
    'billion': 1e9, 'b': 1e9,
}

_SIZE_KEY_BY_NOUN = {'event': 'Size_Events', 'file': 'Size_Files', 'byte': 'Size_Bytes'}


def guess_sizes_from_text(text):
    """Heuristically pull event/file/byte counts out of free text.

    Returns a dict with keys Size_Events, Size_Files and Size_Bytes. Each is
    the first matching mention converted to int, or None if the text has no
    such mention.
    """
    out = {'Size_Events': None, 'Size_Files': None, 'Size_Bytes': None}
    if not text:
        return out
    for m in _SIZE_MENTION_RE.finditer(text):
        key = _SIZE_KEY_BY_NOUN[m.group(3).lower().rstrip('s')]
        if out[key] is not None:
            continue  # the first mention of each kind wins
        factor = _SIZE_MAGNITUDES.get((m.group(2) or '').lower())
        try:
            val = float(m.group(1).replace(',', ''))
            # round() avoids truncation from float error (0.29 * 1e6 etc.)
            out[key] = int(round(val * factor)) if factor else int(val)
        except (ValueError, OverflowError):
            continue
        if all(v is not None for v in out.values()):
            break
    return out


_EXPECTED_FIELDS = (
    'Title', 'Authors', 'Date', 'Dataset_Name', 'Size_Events',
    'Size_Files', 'Size_Bytes', 'Data_Format', 'Dataset_DOI',
)
_SIZE_FIELDS = ('Size_Events', 'Size_Files', 'Size_Bytes')

# CMS-style dataset path "/<primary>/<processed>/<tier>" for the AOD family.
# Longer tier names come first so "/AODSIM" is not cut short to "/AOD".
# NOTE(review): assumes the three-part CMS naming convention - confirm that
# no four-part names need to be supported.
_CMS_DATASET_NAME_RE = re.compile(
    r'(?<![\w/])/[\w.\-]+/[\w.\-]+/'
    r'(?:NANOAODSIM|NANOAOD|MINIAODSIM|MINIAOD|AODSIM|AOD)\b'
)


def fill_missing_fields(parsed, llm_text, full_text):
    """Fill empty fields of *parsed* from the model output and the PDF text.

    *parsed* is updated in place and returned (a new dict is created when it
    is None). Only fields whose current value is falsy are touched.

    Steps:
      - regex search for a DOI and a date
      - heuristic scan for sizes (only when at least one size is missing)
      - regex search for a CMS dataset name
      - if any expected field is still empty, one LLM call asking for JSON;
        only the still-missing expected fields are taken from its answer
    """
    if parsed is None:
        parsed = {}
    combined = (llm_text or '') + '\n' + (full_text or '')

    # DOI
    if not parsed.get('Dataset_DOI'):
        doi = find_doi(combined)
        if doi:
            parsed['Dataset_DOI'] = doi

    # Date
    if not parsed.get('Date'):
        d = find_date(combined)
        if d:
            parsed['Date'] = d

    # sizes - skip the scan of the whole document when nothing is missing
    missing_sizes = [k for k in _SIZE_FIELDS if not parsed.get(k)]
    if missing_sizes:
        sizes = guess_sizes_from_text(combined)
        for k in missing_sizes:
            if sizes.get(k) is not None:
                parsed[k] = sizes[k]

    # dataset name
    if not parsed.get('Dataset_Name'):
        m = _CMS_DATASET_NAME_RE.search(combined)
        if m:
            parsed['Dataset_Name'] = m.group(0)

    # if anything is still missing, do an LLM JSON fallback
    needed = [k for k in _EXPECTED_FIELDS if not parsed.get(k)]
    if needed:
        fallback_prompt = (
            "Given the following document text, return a compact JSON object with these keys: "
            "Title, Authors, Date, Dataset_Name, Size_Events, Size_Files, Size_Bytes, Data_Format, Dataset_DOI. "
            "Use null for unknown values. Numbers must be integers (no commas). Return ONLY valid JSON.\n\n" + (full_text or '')[:3000]
        )
        try:
            fb_resp = process_with_gpt(fallback_prompt)
            fb_json = parse_json_like_response(fb_resp)
        except Exception as e:
            print(f"LLM fallback error: {e}")
            fb_json = None
        if isinstance(fb_json, dict):
            for k in needed:
                v = fb_json.get(k)
                # JSON null, or "null"/"" sent as a string, means "unknown".
                if v is None or (isinstance(v, str) and v.strip().lower() in ('', 'null')):
                    continue
                if k in _SIZE_FIELDS:
                    try:
                        number = float(v.replace(',', '')) if isinstance(v, str) else v
                        parsed[k] = int(number)
                    except (TypeError, ValueError, OverflowError):
                        parsed[k] = None
                else:
                    parsed[k] = v

    return parsed


def process_multiple_pdfs(pdf_directory, prompt, output_csv="extracted_results.csv", chunk_threshold=1500):
    """Extract dataset metadata from every ``*.pdf`` in a directory.

    For each PDF (processed in sorted filename order) the text is extracted
    and sent to the model with *prompt*. Documents longer than
    *chunk_threshold* words are sent chunk by chunk and the usable chunk
    replies are merged by a second model call. The reply is parsed as CSV,
    with a JSON fallback prompt when that fails, and remaining gaps are
    filled by fill_missing_fields().

    Args:
        pdf_directory: directory searched (non-recursively) for PDF files.
        prompt: extraction prompt; may contain the ``{text_chunk}`` placeholder.
        output_csv: path of the CSV to write; falsy to skip writing.
        chunk_threshold: maximum number of words sent in a single request.

    Returns:
        List with one dict per successfully processed PDF (each includes
        ``Source_File``). Per-file errors are printed and that file is skipped.
    """

    results = []
    # sorted() makes processing order (and thus the CSV row order) deterministic
    pdf_files = sorted(Path(pdf_directory).glob("*.pdf"))
    print(f"Found {len(pdf_files)} PDF files in {pdf_directory}")

    for i, pdf_path in enumerate(pdf_files, 1):
        print(f"Processing {i}/{len(pdf_files)}: {pdf_path.name}")
        try:
            pdf_text = extract_text_from_pdf(pdf_path)
            if not pdf_text:
                print(f" - No text extracted from {pdf_path.name}")
                continue
            word_count = len(pdf_text.split())
            print(f"  Extracted {word_count} words")

            if word_count > chunk_threshold:
                print(f"  Chunking document (>{chunk_threshold} words)")
                chunks = chunk_text(pdf_text, max_words=chunk_threshold)
                chunk_outputs = []
                for idx, chunk in enumerate(chunks, 1):
                    print(f"    Processing chunk {idx}/{len(chunks)}")
                    if "{text_chunk}" in prompt:
                        chunk_prompt = prompt.replace("{text_chunk}", chunk)
                    else:
                        chunk_prompt = f"{prompt}\n\nDocument chunk:\n{chunk}"
                    out = process_with_gpt(chunk_prompt)
                    # process_with_gpt reports failures as "API Error: ..."
                    # strings and may return None; neither belongs in the
                    # merge prompt (None would also make join() raise).
                    if isinstance(out, str) and out.strip() and not out.startswith("API Error:"):
                        chunk_outputs.append(out)
                    else:
                        print(f"    ⚠ Skipping unusable output for chunk {idx}")
                if chunk_outputs:
                    combined_prompt = (
                        "Combine the following chunked extractions into a single "
                        "coherent CSV row. Remove duplicates and consolidate the data. Return ONLY the CSV with header and one data row:\n\n" + "\n\n---\n\n".join(chunk_outputs)
                    )
                    final = process_with_gpt(combined_prompt)
                else:
                    # Nothing usable: fall through to the JSON fallback below.
                    final = None
            else:
                final = extract_data_from_pdf_text(pdf_text, prompt)

            parsed_data = parse_csv_response(final)
            if parsed_data:
                parsed_data = fill_missing_fields(parsed_data, final, pdf_text)
                parsed_data['Source_File'] = pdf_path.name
                results.append(parsed_data)
                print(f"✓ Successfully processed {pdf_path.name}")
            else:
                print(f"⚠ Could not parse response for {pdf_path.name}, trying fallback")
                fallback_prompt = (
                    "You previously attempted to produce a CSV row but parsing failed. "
                    "Given the following document text, return a JSON object with keys: "
                    "Title, Authors, Date, Dataset_Name, Size_Events, Size_Files, Size_Bytes, Data_Format, Dataset_DOI. "
                    "Use null for unknown values, numbers as plain integers (no commas), and try to approximate missing sizes. "
                    "Return ONLY valid JSON.\n\n" + pdf_text[:2000]
                )
                fallback_resp = process_with_gpt(fallback_prompt)
                parsed_data = parse_json_like_response(fallback_resp)
                if parsed_data:
                    parsed_data = fill_missing_fields(parsed_data, fallback_resp, pdf_text)
                    parsed_data['Source_File'] = pdf_path.name
                    results.append(parsed_data)
                    print(f"✓ Successfully processed {pdf_path.name} (fallback)")
                else:
                    print(f"✗ Fallback also failed for {pdf_path.name}")

        except Exception as e:
            print(f"✗ Error processing {pdf_path.name}: {str(e)}")

    if results and output_csv:
        column_order = ['Source_File', 'Title', 'Authors', 'Date', 'Dataset_Name', 'Size_Events', 'Size_Files', 'Size_Bytes', 'Data_Format', 'Dataset_DOI']
        # reindex() creates any column no row supplied (rows from the JSON
        # fallback may lack keys), so the CSV header is stable and the
        # summary below cannot raise KeyError.
        df = pd.DataFrame(results).reindex(columns=column_order)
        df.to_csv(output_csv, index=False, encoding='utf-8')
        print(f"\n✓ Results saved to {output_csv}")
        print(f"✓ Processed {len(results)} PDFs successfully")
        print("\nData Summary:")
        print(df[['Source_File', 'Title', 'Size_Events', 'Size_Files', 'Size_Bytes']].to_string())

    return results


# Example usage (kept for convenience; you can call process_multiple_pdfs elsewhere)
if __name__ == "__main__":
    # Extraction prompt. The literal {text_chunk} marker near the end is
    # filled in with the document (or chunk) text via str.replace(), not
    # str.format(). The CSV header requested here uses the column names that
    # parse_csv_response() looks up.
    PROMPT_TEMPLATE = """
    You are an expert at high energy particle physics and you understand jargon like \"events\" and datasets. You are also very, very careful and a good explainer.
    I need your help reading documents and extracting information about datasets used.

    Extract the following information:
    * Title of the paper
    * Authors of the paper (comma-separated if multiple)
    * Date of publication in YYYY-MM format (e.g., 2023-06)
    * Dataset name (collision or MC)
    * Size in number of events (as a plain integer, no commas)
    * Size in number of files (as a plain integer, no commas)
    * Size in bytes (as a plain integer, no commas or units)
    * Data format (AOD, miniAOD, nanoAOD, etc)
    * DOI of datasets used

    If exact numbers are not given, provide approximations without noting them as such.
    If you can't find data in the paper, look it up from cited CMS Open Data records or DOIs.
    Use regular hyphens (-) not em dashes (—).

    Return ONLY a CSV with this exact header:
    Title,Authors,Date,Dataset name (collision or MC),Size (events),Size (files),Size (bytes),Data format,Dataset DOI

    Then one data row with the extracted values. Use \"null\" for unknown values.
    NO explanations, NO markdown formatting, JUST the CSV.

    Document Text:
    "{text_chunk}"
    """

    # NOTE(review): pdf_directory is a hard-coded, machine-specific path -
    # change it before running on another machine.
    results = process_multiple_pdfs(
        pdf_directory=r"C:/Users/ejren/OneDrive/DPOA_papers",
        prompt=PROMPT_TEMPLATE,
        output_csv="extracted_results.csv",
        chunk_threshold=1500
    )
    # Final banner; `results` holds one dict per PDF that was parsed successfully.
    print("\n" + "="*60)
    print("PROCESSING COMPLETE")
    print("="*60)
    print(f"Total papers processed: {len(results)}")
