# LLM-Based Citation Extraction

This notebook handles extraction of citations from scientific documents:
1. Loading the processed document from JSON
2. Using an LLM to extract citations section-by-section
3. For each citation, identifying:
   - The scientific claim being supported
   - The citation text
   - Expanded citation keys
   - The paragraph context
4. Extracting the bibliography into structured references (with PMCIDs when present)
5. Saving the structured citation data for further analysis
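
The code in this notebook reads the keys `sections`, `section_id`, `title`, `content`, `subsections`, `figures`, `figure_id`, `figure_number`, and `caption` from the processed JSON. As a rough orientation, a document with that shape might look like the sketch below. All values are invented placeholders, and `iter_sections` is a hypothetical helper that is not part of the notebook.

```python
# Illustrative input shape only; every value here is an invented placeholder.
sample_document = {
    "sections": [
        {
            "section_id": "section_0",
            "title": "# A Background",
            "content": "Replication is tightly regulated.$^{1}$",
            "subsections": [
                {
                    "section_id": "subsection_0",
                    "title": "## A. 1 Example subsection",
                    "content": "Several factors are involved.$^{2-4}$",
                    "subsections": [],
                }
            ],
        }
    ],
    "figures": [
        {"figure_id": "figure_0", "figure_number": 1, "caption": "Model.$^{5}$"}
    ],
}


def iter_sections(sections):
    """Yield every section and nested subsection, depth-first (hypothetical helper)."""
    for section in sections:
        yield section
        # 'subsections' may be missing or empty, so fall back to an empty list.
        yield from iter_sections(section.get("subsections") or [])


section_ids = [s["section_id"] for s in iter_sections(sample_document["sections"])]
print(section_ids)
```

Walking the tree this way mirrors the recursive traversal used later for extraction, which is why subsections carry the same keys as top-level sections.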

## Setup

First, let's import the necessary libraries and set up our environment.

In [1]:
# Standard library
import json
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

# Third-party
import pandas as pd
from openai import OpenAI
from pydantic import BaseModel, Field

## Define Data Models

We'll use Pydantic models to structure our citation data.

In [2]:
class ExtractedCitation(BaseModel):
    """A citation extracted by the LLM with expanded citation keys"""
    claim: str = Field(..., description="The scientific claim being made and supported by the citation")
    citation_text: str = Field(..., description="The exact LaTeX citation text (e.g., '$^{1}$' or '$^{2-5}$')")
    citation_keys: List[int] = Field(..., description="Expanded list of citation keys as integers")
    paragraph: str = Field(..., description="The paragraph containing the citation for context")
    section_id: str = Field(..., description="ID of the section containing the citation")

class Reference(BaseModel):
    """A reference from the bibliography"""
    reference_id: int = Field(..., description="The numeric ID of the reference (e.g., 1 from [1])")
    reference_text: str = Field(..., description="The full text of the reference")
    pmcid: Optional[str] = Field(None, description="PMCID if present in the reference")

class ReferenceList(BaseModel):
    """List of references extracted from the bibliography"""
    references: List[Reference] = Field(..., description="List of references extracted from the bibliography")
    
class CitationAnalysisResults(BaseModel):
    """Results of the citation extraction process"""
    document_id: str = Field(..., description="ID of the processed document")
    citations: List[ExtractedCitation] = Field(..., description="Citations extracted from the document")
    total_citations: int = Field(..., description="Total number of citations extracted")
    citations_by_section: Dict[str, int] = Field(..., description="Count of citations by section ID")
    references: Optional[List[Reference]] = Field(None, description="References extracted from the bibliography")
    processed_date: str = Field(..., description="Date the analysis was performed")
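
The `citation_keys` on each extracted citation are meant to line up with `reference_id` values in the bibliography. A minimal sketch of that relationship, using plain dicts in place of the Pydantic models so it runs on its own, is shown below. The function name `find_unresolved_keys` and the sample data are assumptions made for illustration.

```python
def find_unresolved_keys(citations, references):
    """Return cited keys that have no matching reference_id, sorted ascending."""
    known_ids = {ref["reference_id"] for ref in references}
    cited_keys = {key for c in citations for key in c["citation_keys"]}
    return sorted(cited_keys - known_ids)


# Placeholder data shaped like ExtractedCitation and Reference records.
citations = [
    {"citation_text": "$^{1}$", "citation_keys": [1]},
    {"citation_text": "$^{2-4}$", "citation_keys": [2, 3, 4]},
]
references = [
    {"reference_id": 1, "reference_text": "Placeholder reference one."},
    {"reference_id": 2, "reference_text": "Placeholder reference two."},
    {"reference_id": 3, "reference_text": "Placeholder reference three."},
]

unresolved = find_unresolved_keys(citations, references)
print(unresolved)
```

A non-empty result suggests either a missed bibliography entry or a mis-expanded citation key, so it is a cheap check to run once both extractions finish.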

In [3]:
class LLMCitationExtractor:
    """
    Uses LLM to extract citations from document sections with expanded citation keys
    """
    
    def __init__(self, api_key=None, model="o3-mini"):
        """Initialize with OpenAI API key"""
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("OpenAI API key is required")
        
        self.client = OpenAI(api_key=self.api_key)
        self.model = model
        
    def extract_citations_from_section(self, section_id, section_title, section_content):
        """
        Extract citations from a single section using LLM with expanded citation keys
        
        Args:
            section_id: ID of the section
            section_title: Title of the section
            section_content: Content of the section
            
        Returns:
            List of ExtractedCitation objects
        """
        # DEBUG: Print section info to verify data is being sent to the LLM
        print(f"DEBUG: Processing section_id={section_id}, title='{section_title}'")
        print(f"DEBUG: Content sample: {section_content[:100]}...")
        
        # Define a model for the response format
        class CitationResponse(BaseModel):
            citations: List[ExtractedCitation] = Field(
                default_factory=list,
                description="Extracted citations with their claims, citation text, expanded keys, and paragraph context"
            )
        
        # Define the prompt for citation extraction
        prompt = f"""
        You are an expert at identifying citations in scientific papers. Analyze the following section of text and extract all citations with their context.

        SECTION: {section_title}

        TEXT:
        ```
        {section_content}
        ```

        TASK:
        1. Identify ALL citation callouts that match these LaTeX patterns:
           - $^{{n}}$ (Single citation, e.g., $^{{1}}$)
           - $^{{m-n}}$ (Range citation, e.g., $^{{2-5}}$)
           - $^{{a,b,c}}$ (List citation, e.g., $^{{6,7,8}}$)
           - $^{{a-b,c,d-e}}$ (Complex citation, e.g., $^{{2-4,7,9-10}}$)
           - Any variations with spaces like ${{ }}^{{1}}$

        2. For each citation, extract:
           - The scientific claim being supported by the citation
           - The exact citation text as it appears (e.g., "$^{{2-5}}$")
           - The citation keys, expanded to individual numbers:
              * For a range like "2-5", expand to [2,3,4,5]
              * For a list like "6,7,8", expand to [6,7,8]
              * For complex citations like "2-4,7,9-10", expand to [2,3,4,7,9,10]
           - The full paragraph containing the citation for context
           
        IMPORTANT: Be sure to check figure captions for citations as well.
        """
        
        try:
            # Use the parse method to get structured output directly into our Pydantic model
            completion = self.client.beta.chat.completions.parse(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a citation extraction specialist. Extract all citation callouts from scientific texts with high precision."},
                    {"role": "user", "content": prompt}
                ],
                response_format=CitationResponse
            )
            
            # Get the extracted citations directly from the parsed response
            extracted_citations = completion.choices[0].message.parsed.citations
            
            # Set the section_id for all extracted citations
            for citation in extracted_citations:
                citation.section_id = section_id
            
            print(f"Extracted {len(extracted_citations)} citations from section '{section_title}'")
            return extracted_citations
            
        except Exception as e:
            print(f"Error extracting citations from section '{section_title}': {e}")
            return []
    
    def extract_all_citations(self, document_data):
        """
        Extract all citations from document sections
        
        Args:
            document_data: Processed document data from JSON
            
        Returns:
            Tuple of (list of ExtractedCitation objects, set of processed section and figure IDs)
        """
        all_citations = []
        processed_section_ids = set()  # Track sections we've processed
        
        # Process all sections and their subsections
        def process_section(section, parent_id=""):
            section_id = section['section_id']
            
            # Skip the bibliography/references section
            if any(ref_term in section['title'].lower() for ref_term in ['literature cited', 'references', 'bibliography']):
                print(f"Skipping bibliography section: {section['title']}")
                return
            
            # Record the section only once we know it will actually be processed,
            # so the skipped bibliography does not show up with a zero count
            processed_section_ids.add(section_id)
            
            # Extract citations from this section
            section_citations = self.extract_citations_from_section(
                section_id, 
                section['title'], 
                section['content']
            )
            all_citations.extend(section_citations)
            
            # Process subsections recursively
            if 'subsections' in section and section['subsections']:
                for subsection in section['subsections']:
                    process_section(subsection, section_id)
        
        # Process all top-level sections
        for section in document_data['sections']:
            process_section(section)
        
        # Process figure captions
        if 'figures' in document_data:
            for figure in document_data['figures']:
                figure_id = figure['figure_id']
                caption = figure['caption']
                
                # Extract citations from the figure caption
                caption_citations = self.extract_citations_from_section(
                    figure_id,
                    f"Figure {figure.get('figure_number', '')}",
                    caption
                )
                all_citations.extend(caption_citations)
                processed_section_ids.add(figure_id)
        
        print(f"Extracted {len(all_citations)} total citations from document using LLM")
        return all_citations, processed_section_ids

In [4]:
def analyze_citations_with_llm(document_json_path, openai_api_key=None):
    """
    Extract and analyze citations in a document using LLM.
    
    Args:
        document_json_path: Path to the processed document JSON file
        openai_api_key: OpenAI API key (optional, will use environment variable if not provided)
        
    Returns:
        CitationAnalysisResults object
    """
    # Load the document data
    with open(document_json_path, 'r', encoding='utf-8') as f:
        document_data = json.load(f)
    
    document_id = Path(document_json_path).stem
    
    print(f"Processing document: {document_id}")
    
    # Extract citations using LLM
    print("\n=== EXTRACTING CITATIONS USING LLM ===")
    extractor = LLMCitationExtractor(api_key=openai_api_key)
    citations, processed_section_ids = extractor.extract_all_citations(document_data)
    
    # Count citations by section
    citations_by_section = {}
    
    # Initialize counts for all processed sections (even those with zero citations)
    for section_id in processed_section_ids:
        citations_by_section[section_id] = 0
    
    # Count citations for each section
    for citation in citations:
        section_id = citation.section_id
        citations_by_section[section_id] = citations_by_section.get(section_id, 0) + 1
    
    # Extract bibliography references
    references = extract_bibliography(document_data, openai_api_key)
    
    # Create the results object
    results = CitationAnalysisResults(
        document_id=document_id,
        citations=citations,
        total_citations=len(citations),
        citations_by_section=citations_by_section,
        references=references,
        processed_date=datetime.now().isoformat()
    )
    
    # Print summary statistics
    print("\n===== CITATION ANALYSIS SUMMARY =====")
    print(f"Total citations extracted: {results.total_citations}")
    print(f"Total references extracted: {len(references)}")
    
    print("\nCitation distribution by section:")
    # Get section titles from document data
    section_titles = {}
    
    def collect_section_titles(section, parent_title=""):
        section_id = section['section_id']
        section_title = section['title']
        section_titles[section_id] = section_title
        
        if 'subsections' in section and section['subsections']:
            for subsection in section['subsections']:
                collect_section_titles(subsection, section_title)
    
    # Collect section titles from document data
    for section in document_data['sections']:
        collect_section_titles(section)
    
    # Add figure captions to section titles
    if 'figures' in document_data:
        for figure in document_data['figures']:
            figure_id = figure['figure_id']
            section_titles[figure_id] = f"Figure {figure.get('figure_number', '')}"
    
    # Print citation counts by section with titles
    for section_id, count in sorted(citations_by_section.items(), key=lambda x: x[1], reverse=True):
        section_title = section_titles.get(section_id, "Unknown")
        print(f"  - {section_title}: {count} citations")
    
    return results

In [5]:
def save_citation_results(results, output_dir='citation_analysis'):
    """
    Save the citation analysis results to files.
    
    Args:
        results: CitationAnalysisResults object
        output_dir: Directory to save the files
    """
    # Create directory if it doesn't exist
    output_path = Path(output_dir)
    output_path.mkdir(exist_ok=True, parents=True)
    
    # Save the complete results as JSON
    with open(output_path / 'citations.json', 'w', encoding='utf-8') as f:
        f.write(results.model_dump_json(indent=2))
    
    # Save citations as CSV
    citations_df = pd.DataFrame([{
        'claim': citation.claim,
        'citation_text': citation.citation_text,
        'citation_keys': ','.join(map(str, citation.citation_keys)),
        'section_id': citation.section_id,
        'paragraph_length': len(citation.paragraph) 
    } for citation in results.citations])
    
    citations_df.to_csv(output_path / 'citations.csv', index=False)
    
    # Save references as CSV if available
    if results.references:
        references_df = pd.DataFrame([{
            'reference_id': ref.reference_id,
            'reference_text': ref.reference_text,
            'pmcid': ref.pmcid or ''
        } for ref in results.references])
        
        references_df.to_csv(output_path / 'bibliography.csv', index=False)
        print(f"Saved {len(results.references)} references to {output_path}/bibliography.csv")
    
    print(f"\nAnalysis results saved to {output_path}/")
    
    return citations_df

In [6]:
def extract_bibliography(document_data, openai_api_key=None):
    """
    Extract structured bibliography from document.
    
    Args:
        document_data: Processed document data from JSON
        openai_api_key: OpenAI API key (optional, will use environment variable if not provided)
        
    Returns:
        List of Reference objects
    """
    # Find the bibliography section
    bibliography_section = None
    for section in document_data['sections']:
        if any(ref_term in section['title'].lower() for ref_term in ['literature cited', 'references', 'bibliography']):
            bibliography_section = section
            break
    
    if not bibliography_section:
        print("Bibliography section not found")
        return []
    
    # Get the bibliography text
    bibliography_text = bibliography_section['content']
    
    print(f"\n=== EXTRACTING BIBLIOGRAPHY ===")
    print(f"Processing bibliography section: {bibliography_section['title']}")
    
    # Use the LLM to extract references
    client = OpenAI(api_key=openai_api_key or os.environ.get("OPENAI_API_KEY"))
    
    prompt = f"""
    Parse this bibliography into a structured format with the following fields:
    
    1. reference_id: The numeric ID of the reference (e.g., 1 from [1])
    2. reference_text: The full text of the reference
    3. pmcid: Extract any PMCID if present (e.g., "PMC1234567"), otherwise leave empty
    
    Bibliography:
    ```
    {bibliography_text}
    ```
    
    IMPORTANT:
    - Make sure to extract ALL references
    - Capture the entire reference text, including any URLs or DOIs
    - Pay special attention to extracting PMCIDs correctly
    - If there are no PMCIDs, it's fine to leave them empty
    - Verify that the reference numbers form a complete sequence (1,2,3,...) with no gaps
    """
    
    try:
        completion = client.beta.chat.completions.parse(
            model="gpt-4.1",
            messages=[
                {"role": "system", "content": "You are an expert at parsing academic bibliographies into structured data."},
                {"role": "user", "content": prompt}
            ],
            response_format=ReferenceList
        )
        
        references = completion.choices[0].message.parsed.references
        
        # Verify we have all references in sequence.
        # default=0 keeps max() from raising when no references were extracted.
        extracted_ids = {ref.reference_id for ref in references}
        expected_ids = set(range(1, max(extracted_ids, default=0) + 1))
        missing_ids = sorted(expected_ids - extracted_ids)
        
        if missing_ids:
            print(f"WARNING: Missing references: {missing_ids}")
        
        print(f"Extracted {len(references)} references from bibliography")
        return references
        
    except Exception as e:
        print(f"Error extracting bibliography: {e}")
        return []

In [7]:
# Enter the path to your processed document JSON
document_json_path = "R35_MIRA_document.json"

# Enter your OpenAI API key here or set it as an environment variable
openai_api_key = os.environ.get("OPENAI_API_KEY")

try:
    # Check if the document JSON exists
    if not Path(document_json_path).exists():
        print(f"Document JSON not found at {document_json_path}")
        print("Run the ocr.ipynb notebook first to create the document JSON.")
    elif not openai_api_key:
        print("OpenAI API key not provided. Set the OPENAI_API_KEY environment variable or provide it directly.")
    else:
        # Analyze the document's citations using LLM
        results = analyze_citations_with_llm(document_json_path, openai_api_key)
        
        # Save the results
        citations_df = save_citation_results(results)
        
        # Display a sample of the results
        print("\n=== SAMPLE OF EXTRACTED CITATIONS ===")
        print(citations_df.head())
        
        # Print some example claims and their citation keys
        print("\n=== EXAMPLE CLAIMS WITH EXPANDED CITATION KEYS ===")
        for i, citation in enumerate(results.citations[:5]):  # Show first 5 citations
            print(f"\n{i+1}. Claim: {citation.claim[:100]}..." if len(citation.claim) > 100 else f"\n{i+1}. Claim: {citation.claim}")
            print(f"   Citation: {citation.citation_text}")
            print(f"   Expanded keys: {citation.citation_keys}")
        
        # Print sample of extracted references if available
        if results.references:
            print("\n=== SAMPLE OF EXTRACTED REFERENCES ===")
            for i, ref in enumerate(results.references[:5]):  # Show first 5 references
                print(f"\n{i+1}. Reference ID: {ref.reference_id}")
                print(f"   Text: {ref.reference_text[:100]}..." if len(ref.reference_text) > 100 else f"   Text: {ref.reference_text}")
                if ref.pmcid:
                    print(f"   PMCID: {ref.pmcid}")
        
except Exception as e:
    print(f"Error analyzing document: {e}")
    import traceback
    traceback.print_exc()

Processing document: R35_MIRA_document

=== EXTRACTING CITATIONS USING LLM ===
DEBUG: Processing section_id=section_0, title='# A Background'
DEBUG: Content sample: A Background 

Our research is focused on identifying and understanding the mechanisms that ensure t...
Extracted 22 citations from section '# A Background'
DEBUG: Processing section_id=section_1, title='# B Recent Research Progress'
DEBUG: Content sample: B Recent Research Progress 

Our work over the last four years has described, at nucleotide resoluti...
Extracted 19 citations from section '# B Recent Research Progress'
DEBUG: Processing section_id=section_2, title='# C Overview of Future Research'
DEBUG: Content sample: C Overview of Future Research 

We will continue to work on and address major questions in DNA repli...
Extracted 0 citations from section '# C Overview of Future Research'
DEBUG: Processing section_id=subsection_0, title='## C. 1 Chromatin assembly behind the fork'
DEBUG: Content sample: C. 1 Chromatin

## Conclusion

In this notebook, we:
1. Loaded the processed document from JSON
2. Used an LLM to extract citations section-by-section
3. For each citation, the LLM:
   - Identified the scientific claim being supported
   - Extracted the exact citation text
   - Expanded citation keys (ranges and lists) into individual numbers
   - Captured the surrounding paragraph for context
4. Extracted bibliography references with PMCIDs when available
5. Saved the structured citation data for further analysis

The LLM-based approach extracts each citation together with the claim it supports, and it expands citation ranges and lists as part of the same structured response. This avoids maintaining complex regex patterns and a separate post-processing step, which can make it easier to pair each claim with its supporting citations.
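
Because the expansion rules are simple (a range like "2-5" becomes [2,3,4,5], a list like "6,7,8" stays as is), the keys returned by the LLM can be cross-checked deterministically. The sketch below is one possible way to do that. `expand_citation_text` is a hypothetical helper, not part of the notebook, and it assumes the callout contains only numbers, commas, and range dashes once the LaTeX wrapper is removed.

```python
import re


def expand_citation_text(citation_text):
    """Expand a callout such as '$^{2-4,7,9-10}$' into a list of integer keys.

    Raises ValueError on malformed parts such as a dangling dash.
    """
    # Treat an en dash as a plain hyphen, then drop the LaTeX wrapper and spaces.
    normalized = citation_text.replace("\u2013", "-")
    inner = re.sub(r"[^0-9,\-]", "", normalized)

    keys = []
    for part in inner.split(","):
        if not part:
            continue
        if "-" in part:
            start, end = part.split("-", 1)
            keys.extend(range(int(start), int(end) + 1))
        else:
            keys.append(int(part))
    return keys


print(expand_citation_text("$^{2-4,7,9-10}$"))
print(expand_citation_text("${ }^{1}$"))
```

Comparing this result with `citation.citation_keys` for each extracted citation would flag cases where the model expanded a range incorrectly.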

### Future Improvements

- **Document Structure Enhancement**: When processing markdown for the document structure JSON, add a specific flag for bibliography sections to make them easier to identify and skip during citation extraction.
- **Subsection Processing Optimization**: Improve how subsections are processed to ensure more efficient citation extraction throughout the document hierarchy.
- **Citation Pattern Recognition**: Further refine the LLM prompt to handle various citation formats including uncommon LaTeX variations.
- **Bibliography Chunking**: If bibliography extraction completeness is an issue, implement a chunking strategy to process larger bibliographies.
- **PMCID Validation**: Add additional validation for PMCIDs, including format checking and verification.
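
As a starting point for the PMCID validation item above, a format check could look like the following sketch. It assumes a PMCID is the prefix "PMC" followed by digits, as in the "PMC1234567" example used in the bibliography prompt. The name `is_well_formed_pmcid` is hypothetical, and the check only covers format; it does not confirm that the ID exists.

```python
import re

# Assumed format: literal "PMC" followed by one or more digits.
PMCID_PATTERN = re.compile(r"PMC\d+")


def is_well_formed_pmcid(value):
    """Return True if value looks like a PMCID; None and empty strings return False."""
    return bool(value) and PMCID_PATTERN.fullmatch(value.strip()) is not None


for candidate in ["PMC1234567", "PMC", "pmc123", None]:
    print(candidate, is_well_formed_pmcid(candidate))
```

Applying it to each `Reference.pmcid` before saving would catch values the LLM copied incorrectly, such as a DOI or PMID placed in the PMCID field.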