In [2]:
import re
import fitz  # PyMuPDF
import requests
import os
import xml.etree.ElementTree as ET
from collections import defaultdict
import pandas as pd
import nltk
from nltk.tokenize import sent_tokenize
import unicodedata
import google.generativeai as genai
import time
import json

# Ensure NLTK's punkt tokenizer is available
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')


# Set your GenAI API key
genai_api_key = 'AIzaSyCp4B3RemoOh--Fqekd_sExgVLwKOMXsh0'  # Replace 'YOUR_API_KEY' with your actual API key
genai.configure(api_key=genai_api_key)

# Set the path to the input directory containing PDFs
input_dir = r"pp"

# Set the path to the output directory for CSV files
output_dir = r"out"

# Set the path to a specific PDF file to process, or set to None
pdf_file = r"pp/grover19a.pdf"  # Set to None to process all PDFs in the input_dir
# pdf_file = None  # Uncomment and set to None to process all PDFs in the input_dir


def read_pdf_file(file_path):
    """
    Reads a PDF file and extracts its text content using PyMuPDF.

    Parameters:
        file_path (str): Path to the input PDF file.

    Returns:
        str: Extracted text from the PDF.

    Raises:
        Exception: If there is an error reading the PDF file.
    """
    try:
        with fitz.open(file_path) as doc:
            text = ""
            for page in doc:
                page_text = page.get_text()
                if page_text:
                    text += page_text + "\n"
        return text
    except Exception as e:
        raise Exception(f"Error reading PDF file: {e}")


def pre_process_text(text):
    """
    Preprocesses the extracted text by removing hyphenations, extra spaces,
    correcting common OCR errors, and trimming whitespace.
    Also normalizes dash characters to standard hyphen '-'.

    Parameters:
        text (str): Raw text extracted from the PDF.

    Returns:
        str: Cleaned and preprocessed text.
    """
    # Remove hyphenation at line breaks
    text = re.sub(r'-\s*\n\s*', '', text)
    # Normalize various dash characters to standard hyphen
    text = re.sub(r'[–—−‑‒―]', '-', text)
    # Replace multiple spaces and tabs with a single space
    text = re.sub(r'[ \t]+', ' ', text)
    # Replace multiple newlines with double newline to preserve paragraphs
    text = re.sub(r'\n{2,}', '\n\n', text)
    # Correct common OCR errors (e.g., 'l' misread for '1')
    text = re.sub(r'\[l\]', '[1]', text, flags=re.IGNORECASE)
    text = re.sub(r'\[I\]', '[1]', text, flags=re.IGNORECASE)
    # Normalize Unicode characters
    text = unicodedata.normalize('NFKC', text)
    # Strip leading and trailing whitespace
    text = text.strip()
    return text


def find_reference_section(text):
    """
    Identifies the references section in the text based on common headings.
    If not found, assumes references start after 70% of the text.
    Returns a tuple of (reference_section, main_text).

    Parameters:
        text (str): The preprocessed text extracted from the PDF.

    Returns:
        tuple: (reference_section (str), main_text (str))
    """
    # Patterns to detect the "References" heading
    reference_section_patterns = [
        r'(?i)^\s*References\s*$',
        r'(?i)^\s*Bibliography\s*$',
        r'(?i)^\s*Works Cited\s*$',
        r'(?i)^\s*Literature Cited\s*$',
    ]
    ref_start = None
    for pattern in reference_section_patterns:
        match = re.search(pattern, text, re.MULTILINE)
        if match:
            ref_start = match.end()
            break
    if ref_start is None:
        # Assume references start after 70% of the text
        ref_start = int(len(text) * 0.7)
    # Reference section is from ref_start to the end
    reference_section = text[ref_start:].strip()
    main_text = text[:ref_start].strip()
    return reference_section, main_text


def detect_reference_style(reference_section):
    """
    Detects the style of references used in the reference section.
    Returns one of 'numbered_brackets', 'numbered', or 'unknown'.

    Parameters:
        reference_section (str): The extracted references section from the text.

    Returns:
        str: The detected reference style.
    """
    # Check for numbered references with brackets [1]
    if re.search(r'(?m)^\s*\[\d+\]', reference_section):
        return 'numbered_brackets'
    # Check for numbered references without brackets 1. or 1)
    elif re.search(r'(?m)^\s*\d+[\.\)]\s+', reference_section):
        return 'numbered'
    else:
        return 'unknown'


def segment_references_numbered_brackets(reference_section):
    """
    Segments references that start with [number].
    Returns a list of tuples (number, reference).

    Parameters:
        reference_section (str): The extracted references section from the text.

    Returns:
        list of tuple: List containing tuples of (reference_number, reference_text).
    """
    references = []
    # Correct common OCR errors in reference numbers
    reference_section = re.sub(r'\[l\]', '[1]', reference_section, flags=re.IGNORECASE)
    reference_section = re.sub(r'\[I\]', '[1]', reference_section, flags=re.IGNORECASE)
    # Split references based on [number]
    split_refs = re.split(r'\[\d+\]', reference_section)
    numbers = re.findall(r'\[(\d+)\]', reference_section)
    for num, ref in zip(numbers, split_refs[1:]):  # first split_refs[0] is before first [number]
        ref = ref.replace('\n', ' ').strip()
        references.append((num, ref))
    return references


def segment_references_numbered(reference_section):
    """
    Segments references that start with number. or number)
    Returns a list of tuples (number, reference).

    Parameters:
        reference_section (str): The extracted references section from the text.

    Returns:
        list of tuple: List containing tuples of (reference_number, reference_text).
    """
    references = []
    # Split references based on numbers followed by dot or parenthesis
    pattern = r'(?m)^\s*(\d+)[\.\)]\s+(.+?)(?=^\s*\d+[\.\)]\s+|\Z)'
    matches = re.findall(pattern, reference_section, re.DOTALL)
    for num, ref in matches:
        ref = ref.replace('\n', ' ').strip()
        references.append((num, ref))
    return references


def segment_references(reference_section, reference_style):
    """
    Segments the references section into individual references based on the reference style.

    Parameters:
        reference_section (str): The extracted references section from the text.
        reference_style (str): Detected reference style ('numbered_brackets', 'numbered', or 'unknown').

    Returns:
        list of tuple: List containing tuples of (reference_number, reference_text).
    """
    if not reference_section:
        return []
    if reference_style == 'numbered_brackets':
        return segment_references_numbered_brackets(reference_section)
    elif reference_style == 'numbered':
        return segment_references_numbered(reference_section)
    else:
        # Attempt both segmentation methods if reference style is unknown
        segmented_refs = segment_references_numbered_brackets(reference_section)
        if not segmented_refs:
            segmented_refs = segment_references_numbered(reference_section)
        return segmented_refs


def handle_multiple_citations(citation_numbers):
    """
    Expands citation ranges and lists like [2-4,6] or [8]-[12] into ['2', '3', '4', '6', '8', '9', '10', '11', '12']
    Handles various dash types.

    Parameters:
        citation_numbers (list of str): List of citation strings to expand.

    Returns:
        list of str: Expanded list of individual citation numbers as strings.
    """
    expanded = []
    for citation in citation_numbers:
        citation = citation.replace(' ', '')  # Remove any spaces
        # Split by comma
        parts = citation.split(',')
        for part in parts:
            # Check if the part is a range (e.g., '8-12')
            range_match = re.match(r'^(\d+)[\-–—](\d+)$', part)
            if range_match:
                start, end = int(range_match.group(1)), int(range_match.group(2))
                if start <= end:
                    expanded.extend([str(num) for num in range(start, end + 1)])
            else:
                # Single citation
                if part.isdigit():
                    expanded.append(part)
    return expanded


def map_citations_to_references_numbered_unified(main_text, references):
    """
    Maps numerical citations (with possible multiple citations within brackets) to their corresponding references.
    Returns a dictionary mapping reference numbers to their contexts and counts.

    Parameters:
        main_text (str): The main text of the PDF excluding the references section.
        references (list of dict): List of references with 'id', 'title', 'authors', and 'raw_reference'.

    Returns:
        dict: Mapping from reference number (str) to a dictionary containing 'contexts' (list) and 'count' (int).
    """
    contexts = defaultdict(lambda: {'contexts': set(), 'count': 0})
    sentences = sent_tokenize(main_text)
    # Precompile the citation pattern regex
    citation_pattern = re.compile(r'\[(\d+(?:\s*[,-–—]\s*\d+)*)\]')

    for idx, sentence in enumerate(sentences):
        matches = citation_pattern.findall(sentence)
        if matches:
            # Build context sentences: four before, current, four after
            start_idx = max(0, idx - 5)
            end_idx = min(len(sentences), idx + 6)
            context_sentences = sentences[start_idx:end_idx]
            context_text = ' '.join(context_sentences)
            for match in matches:
                # Expand the citation numbers
                citation_nums = handle_multiple_citations([match])
                for num in citation_nums:
                    contexts[num]['contexts'].add(context_text.strip())
                    contexts[num]['count'] += 1

    # Convert sets to lists
    return {num: {'contexts': list(data['contexts']), 'count': data['count']} for num, data in contexts.items()}


def save_citation_contexts_to_csv(references, citation_map, output_csv_path):
    """
    Saves the citation contexts mapped to references into a CSV file.

    Parameters:
        references (list of dict): List of references with 'id', 'title', 'authors', and 'raw_reference'.
        citation_map (dict): Mapping from reference number to their contexts and counts.
        output_csv_path (str): Path to save the output CSV file.
    """
    data = []
    for ref in references:
        ref_id = ref['id']
        title = ref['title'] if ref['title'] else ref['raw_reference']  # Use raw reference if title not found
        authors = ref.get('authors', [])
        authors_str = '; '.join(authors) if authors else 'No authors extracted.'

        contexts_data = citation_map.get(ref_id, {'contexts': [], 'count': 0})
        contexts = contexts_data.get('contexts', [])
        count = contexts_data.get('count', 0)
        # Remove duplicates by converting to set and back to list
        unique_contexts = list(set(contexts))
        if unique_contexts:
            # For the first context, include citation count
            for idx, context in enumerate(unique_contexts):
                if idx == 0:
                    data.append({
                        'Reference ID': ref_id,
                        'Title': title,
                        'Authors': authors_str,
                        'Citation Count': count,
                        'Citation Context': context.strip()
                    })
                else:
                    data.append({
                        'Reference ID': ref_id,
                        'Title': title,
                        'Authors': authors_str,
                        'Citation Count': '',  # Empty or zero for subsequent contexts
                        'Citation Context': context.strip()
                    })
        else:
            # No contexts found for this reference
            data.append({
                'Reference ID': ref_id,
                'Title': title,
                'Authors': authors_str,
                'Citation Count': count,
                'Citation Context': 'No citation contexts found.'
            })

    # Sort data by Reference ID numerically
    try:
        data_sorted = sorted(data, key=lambda x: int(x['Reference ID']))
    except ValueError:
        # If Reference ID is not purely numeric, sort as strings
        data_sorted = sorted(data, key=lambda x: x['Reference ID'])

    df = pd.DataFrame(data_sorted)
    df.to_csv(output_csv_path, index=False)


def parse_reference_string_with_llm(ref_text):
    """
    Parses a reference string using LLM to extract the title and authors.

    Parameters:
        ref_text (str): The reference string.

    Returns:
        dict: Parsed reference data, including 'title' and 'authors'.
    """
    # The genai API should already be configured in the main function

    # Construct the prompt
    prompt = f"""
Extract the title and authors from the following reference:

{ref_text}

Return the result in JSON format, with keys 'title' and 'authors'. The 'authors' should be a list of author names.

Example Output:
{{
    "title": "Title of the paper",
    "authors": ["Author One", "Author Two", "Author Three"]
}}
"""

    model = genai.GenerativeModel(model_name="gemini-1.5-flash")
    try:
        response = model.generate_content(prompt)
        # Sleep to respect rate limits
        time.sleep(13)  # Adjust based on your rate limit requirements
        # Attempt to find the JSON in the response
        response_text = response.text.strip()
        # Extract the JSON part from the response
        match = re.search(r'\{.*\}', response_text, re.DOTALL)
        if match:
            json_text = match.group(0)
            parsed_ref = json.loads(json_text)
            return parsed_ref  # Should contain 'title' and 'authors'
        else:
            print(f"No JSON found in LLM response: {response_text}")
            return {'title': '', 'authors': []}
    except Exception as e:
        print(f"Error parsing reference with LLM: {e}")
        return {'title': '', 'authors': []}


def map_references(actual_references):
    """
    Parses each reference string using LLM to extract titles and authors.

    Parameters:
        actual_references (list of tuple): List of tuples (ref_num, ref_text) from actual references.

    Returns:
        list of dict: References with IDs, raw_reference, 'title', and 'authors'.
    """
    merged_references = []
    for ref_num, ref_text in actual_references:
        parsed_ref = parse_reference_string_with_llm(ref_text)
        title = parsed_ref.get('title', "No title extracted.")
        authors = parsed_ref.get('authors', [])
        if not title:
            title = ref_text.strip()  # Use the entire reference string as the title
        merged_references.append({
            'id': ref_num,
            'raw_reference': ref_text,
            'title': title,
            'authors': authors
        })
    return merged_references


def main():
    # Create output directory if it doesn't exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if pdf_file:
        # Process specific PDF file
        if not os.path.exists(pdf_file):
            print(f"Specified PDF file {pdf_file} does not exist.")
            return
        pdf_files = [pdf_file]
    else:
        # Get list of PDF files in the input directory
        pdf_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.lower().endswith('.pdf')]

    if not pdf_files:
        print("No PDF files found to process.")
        return

    no_citation_context_pdfs = []

    for pdf_path in pdf_files:
        pdf_file_name = os.path.basename(pdf_path)
        try:
            # Step 1: Extract main text from PDF using PyMuPDF
            raw_text = read_pdf_file(pdf_path)

            # Step 2: Preprocess the extracted text
            cleaned_text = pre_process_text(raw_text)

            # Step 3: Find the reference section
            reference_section, main_text = find_reference_section(cleaned_text)

            # Step 4: Detect reference style
            reference_style = detect_reference_style(reference_section)

            # Step 5: Segment references
            segmented_references = segment_references(reference_section, reference_style)
            segmented_references = [ref for ref in segmented_references if ref[1].strip()]

            # Proceed even if reference_style is 'unknown' and segmented_references is empty

            if not segmented_references:
                # Attempt alternative segmentation if possible
                # Here, you can add more segmentation strategies if needed
                # For now, if no references are found, consider citation contexts not extracted
                no_citation_context_pdfs.append(pdf_file_name)
                continue  # Skip to next PDF

            # Print statement before using Gemini to get titles and authors
            print("Using Gemini to get titles and authors")

            # Step 6: Map references (process each reference string with LLM to extract titles and authors)
            merged_references = map_references(segmented_references)

            # Step 7: Extract citations from main text
            citation_mapping = map_citations_to_references_numbered_unified(main_text, merged_references)

            # Print statement after citation context extraction
            print("OK, citation context extraction is done")

            # Check if citation contexts are found
            if not citation_mapping:
                no_citation_context_pdfs.append(pdf_file_name)
                continue  # Skip to next PDF

            # Step 8: Save to CSV
            # Name the CSV file the same as the PDF file but with .csv extension
            csv_file_name = os.path.splitext(pdf_file_name)[0] + '.csv'
            output_csv = os.path.join(output_dir, csv_file_name)
            save_citation_contexts_to_csv(merged_references, citation_mapping, output_csv)

            # Print statement after mapping and storing
            print("Mapping and storing done")

        except Exception as e:
            # If any error occurs during processing, consider citation contexts not extracted
            print(f"Error processing {pdf_file_name}: {e}")
            no_citation_context_pdfs.append(pdf_file_name)
            continue  # Continue to next PDF

    # After processing all PDFs, print the list of PDFs with missing citation contexts
    if no_citation_context_pdfs:
        print("PDFs for which citation contexts were not extracted:")
        for pdf in no_citation_context_pdfs:
            print(f"- {pdf}")
    else:
        print("Citation contexts were successfully extracted for all processed PDFs.")


if __name__ == '__main__':
    main()


Using Gemini to get titles and authors
Error parsing reference with LLM: Extra data: line 4 column 4 (char 120)
Error parsing reference with LLM: Extra data: line 4 column 4 (char 116)
Error parsing reference with LLM: Extra data: line 4 column 4 (char 136)
Error parsing reference with LLM: Extra data: line 4 column 4 (char 151)
OK, citation context extraction is done
PDFs for which citation contexts were not extracted:
- grover19a.pdf
