<a href="https://colab.research.google.com/github/hsandaver/hsandaver/blob/main/entity_extractor_wikipedia_experimental_entity_picker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install necessary libraries
import sys
import subprocess

def install(package):
    """Install ``package`` quietly with pip, using the running interpreter.

    :param package: Package name handed to ``pip install``
    :raises subprocess.CalledProcessError: If pip exits with a non-zero status
    """
    pip_command = [sys.executable, "-m", "pip", "install", "--quiet"]
    pip_command.append(package)
    subprocess.check_call(pip_command)

# Install required packages, one pip invocation per name, in the listed order
for required_package in (
    "pymupdf",
    "spacy",
    "SPARQLWrapper",
    "pandas",
    "tqdm",
    "requests",
):
    install(required_package)

# Import libraries
import spacy
from spacy.pipeline import EntityRuler
import fitz  # PyMuPDF
import pandas as pd
from SPARQLWrapper import SPARQLWrapper, JSON
from google.colab import files
from IPython.display import display  # For displaying DataFrame
from tqdm import tqdm  # Progress bar
import re
from collections import Counter
import requests

# Function to set up spaCy NLP model with EntityRuler
def setup_nlp():
    """Build the spaCy pipeline used for entity extraction.

    Loads ``en_core_web_md``; when loading raises ``OSError`` the model is
    downloaded with ``python -m spacy download`` and loaded again. An
    ``entity_ruler`` is then inserted ahead of the ``ner`` component with
    patterns that label common honorifics as ``TITLE``.

    :return: The configured spaCy pipeline object
    """
    try:
        pipeline = spacy.load("en_core_web_md")
    except OSError:
        # Load failed: fetch the model with the current interpreter and retry
        print("Downloading 'en_core_web_md' model...")
        download_command = [sys.executable, "-m", "spacy", "download", "en_core_web_md"]
        subprocess.check_call(download_command)
        pipeline = spacy.load("en_core_web_md")

    # Honorifics that should be tagged TITLE; extend the tuple as needed
    honorifics = ("Mr.", "Ms.", "Dr.", "Prof.", "Sir", "Lady")
    title_patterns = [{"label": "TITLE", "pattern": word} for word in honorifics]

    # The ruler runs before "ner" in the pipeline
    title_ruler = pipeline.add_pipe("entity_ruler", before="ner")
    title_ruler.add_patterns(title_patterns)

    return pipeline

# Build the spaCy pipeline once at module level; process_pdf() reads this
# global `nlp` when it extracts entities.
nlp = setup_nlp()

# Function to upload PDF in Colab
def upload_pdf():
    """Ask the user to upload a file through ``files.upload()`` and return its name.

    Calls ``sys.exit()`` when nothing was uploaded. When several files are
    uploaded only the first key of the returned mapping is used.

    :return: Name of the first uploaded file
    """
    print("Please upload your PDF file.")
    uploaded_files = files.upload()
    if uploaded_files:
        for first_name in uploaded_files:
            return first_name
    print("No file uploaded. Exiting.")
    sys.exit()

# Function to extract text from PDF
def extract_text_from_pdf(doc):
    """Return the text of every page in ``doc``, each page followed by a newline.

    :param doc: Opened document supporting ``len()`` and ``load_page(index)``
    :return: Concatenated page text; an empty string for a zero-page document
    """
    page_texts = (doc.load_page(index).get_text() for index in range(len(doc)))
    # The trailing newline keeps the last line of one page apart from the next
    return "".join(page_text + "\n" for page_text in page_texts)

# Function to extract contextual keywords (optional, advanced)
def extract_contextual_keywords(text, nlp_model):
    """
    Collect the nouns and adjectives that share a sentence with each PERSON entity.

    :param text: Extracted text from PDF
    :param nlp_model: Loaded spaCy NLP model
    :return: Dictionary mapping each stripped entity text to a de-duplicated
             list (in no guaranteed order) of lemmas taken from every sentence
             in which that entity occurs
    """
    wanted_pos = ('NOUN', 'ADJ')
    lemmas_by_person = {}

    people = (span for span in nlp_model(text).ents if span.label_ == "PERSON")
    for person in people:
        # person.sent is the sentence containing this mention
        sentence_lemmas = [tok.lemma_ for tok in person.sent if tok.pos_ in wanted_pos]
        lemmas_by_person.setdefault(person.text.strip(), []).extend(sentence_lemmas)

    # Passing through a set drops lemmas collected more than once
    return {name: list(set(lemmas)) for name, lemmas in lemmas_by_person.items()}

# Function to extract person entities with filtering
def extract_person_entities(text, nlp_model):
    """
    Extract the unique, plausible person names that the NLP model finds in ``text``.

    Every PERSON mention is whitespace-normalised and printed; the mentions
    are then de-duplicated and filtered, and the surviving names are printed
    and returned in the same order.

    A name is kept when it is longer than one character, is not in the
    blacklist, and consists only of letters (of any script), spaces,
    hyphens, periods and apostrophes.

    :param text: Extracted text from PDF
    :param nlp_model: Loaded spaCy NLP model
    :return: List of unique names, ordered by first appearance in ``text``
    """
    doc_nlp = nlp_model(text)

    # PDF text often breaks a name across lines ("Ada\nLovelace"). Collapsing
    # every whitespace run to one space makes such mentions equal to their
    # single-line form, so they de-duplicate and search as one name.
    person_entities = [
        " ".join(ent.text.split()) for ent in doc_nlp.ents if ent.label_ == "PERSON"
    ]

    print("\n=== All Extracted Person Entities ===")
    for idx, ent in enumerate(person_entities, 1):
        print(f"{idx}. {ent}")

    # [^\W\d_] matches a letter in any script, so accented names such as
    # "José" are accepted; digits and underscores are still rejected.
    pattern = re.compile(r"(?:[^\W\d_]|[ \-.'])+")

    # Define a blacklist of common false positives
    blacklist = {'John Doe', 'Jane Smith', 'Mr.', 'Ms.', 'Dr.', 'Prof.', 'Sir', 'Lady'}  # Extend as needed

    # dict.fromkeys de-duplicates while preserving first-appearance order, so
    # the numbering printed below matches the returned list on every run.
    filtered_entities = [
        ent for ent in dict.fromkeys(person_entities)
        if len(ent) > 1 and ent not in blacklist and pattern.fullmatch(ent)
    ]

    print("\n=== Filtered Person Entities ===")
    for idx, ent in enumerate(filtered_entities, 1):
        print(f"{idx}. {ent}")

    return filtered_entities

# Function to search Wikidata using the Search API
def search_wikidata_api(entity, limit=10, timeout=30):
    """
    Search Wikidata using the Search API (``wbsearchentities``).

    Failures are reported on stdout and turned into an empty result so that
    one bad request does not abort a long interactive run.

    :param entity: Name of the entity to search
    :param limit: Maximum number of results to request (default 10)
    :param timeout: Seconds to wait for the HTTP request before giving up
    :return: List of search result dictionaries taken from the ``search`` key
             of the response; an empty list when the request fails, the
             status is not 200, the body is not JSON, or the key is absent
    """
    url = "https://www.wikidata.org/w/api.php"
    params = {
        'action': 'wbsearchentities',
        'search': entity,
        'language': 'en',
        'format': 'json',
        'limit': limit
    }
    print(f"\nSearching Wikidata for: '{entity}' with limit={limit}")

    # Without a timeout a stalled connection would block the notebook forever
    try:
        response = requests.get(url, params=params, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Error: Request to Wikidata API failed: {exc}")
        return []

    if response.status_code != 200:
        print(f"Error: Received status code {response.status_code} from Wikidata API.")
        return []

    try:
        data = response.json()
    except ValueError:
        # requests raises a ValueError subclass when the body is not valid JSON
        print("Error: Wikidata API returned a response that is not valid JSON.")
        return []

    return data.get('search', [])

# Function to fetch entity details by Wikidata ID
def fetch_entity_by_id(wikidata_id):
    """
    Fetch entity details from Wikidata by ID.

    The English label and description are looked up independently, so an
    entity that has a label but no English description still reports its
    label. Any field that cannot be obtained is reported as ``"N/A"``.

    :param wikidata_id: Wikidata ID (e.g., Q42)
    :return: Dictionary with 'Wikidata ID', 'Label', 'Description'
    """
    url = f"https://www.wikidata.org/wiki/Special:EntityData/{wikidata_id}.json"
    details = {
        "Wikidata ID": wikidata_id,
        "Label": "N/A",
        "Description": "N/A"
    }

    # Without a timeout a stalled connection would block the caller forever
    try:
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        print(f"Error: Request for entity {wikidata_id} failed: {exc}")
        return details

    if response.status_code != 200:
        print(f"Error: Received status code {response.status_code} when fetching entity {wikidata_id}.")
        return details

    try:
        data = response.json()
    except ValueError:
        # requests raises a ValueError subclass when the body is not valid JSON
        print(f"Error: Response for entity {wikidata_id} is not valid JSON.")
        return details

    entity = data.get('entities', {}).get(wikidata_id, {})
    # Each field falls back on its own; a missing description must not
    # discard a label that is present (and vice versa).
    details["Label"] = entity.get('labels', {}).get('en', {}).get('value', "N/A")
    details["Description"] = entity.get('descriptions', {}).get('en', {}).get('value', "N/A")
    return details

# Function to rank entities based on contextual keywords
def rank_entities(entities, keywords):
    """
    Pick the entity whose description mentions the most keywords.

    Matching is a case-insensitive substring test of each keyword against the
    entity's 'Description' (a missing description counts as empty). When
    several entities tie for the highest count, the earliest one in
    ``entities`` wins.

    :param entities: List of entities (dictionaries with 'Description')
    :param keywords: List of keywords to match against descriptions
    :return: The best-matching entity, or None when ``keywords`` is empty,
             ``entities`` is empty, or no keyword matches any description
    """
    if not keywords:
        return None

    def _hits(candidate):
        # Number of keywords that occur somewhere in this candidate's description
        text = candidate.get("Description", "").lower()
        return sum(1 for word in keywords if word.lower() in text)

    scored = [(_hits(candidate), candidate) for candidate in entities]

    # max() keeps the first of equal scores, i.e. the earliest entity wins ties
    best_score, best_entity = max(scored, key=lambda pair: pair[0], default=(0, None))
    return best_entity if best_score > 0 else None

# Main processing function
def _not_found_record(name):
    """Return the placeholder row used when ``name`` has no chosen Wikidata match."""
    return {
        "Name": name,
        "Wikidata ID": "N/A",
        "Label": "N/A",
        "Description": "N/A"
    }


def _summarize_result(result):
    """Reduce one search hit to its ID, label and description ('N/A' when absent)."""
    return {
        "Wikidata ID": result['id'],
        # .get() so a hit without a label or description cannot raise KeyError
        "Label": result.get('label', 'N/A'),
        "Description": result.get('description', 'N/A')
    }


def _choose_result(entity, search_results):
    """Select one hit for ``entity``: automatically if unique, otherwise by asking the user.

    :param entity: Name the search was run for (used in the prompts)
    :param search_results: Non-empty list of search hit dictionaries
    :return: Summary dictionary of the chosen hit, or None when the user enters 0
    """
    candidates = [_summarize_result(result) for result in search_results]

    # If only one entity found, select it automatically
    if len(candidates) == 1:
        selected_entity = candidates[0]
        print(f"Automatically selected: {selected_entity['Label']} - {selected_entity['Description']}")
        return selected_entity

    print(f"Multiple Wikidata entries found for '{entity}':")
    for idx, candidate in enumerate(candidates, 1):
        print(f"{idx}. {candidate['Label']} - {candidate['Description']}")

    # Keep asking until the answer is 0 (skip) or a valid list position
    while True:
        try:
            selection = int(input(f"Select the correct entity for '{entity}' (1-{len(candidates)}), or 0 to skip: "))
        except ValueError:
            print("Invalid input. Please enter a number.")
            continue
        if selection == 0:
            print(f"Skipping entity: '{entity}'")
            return None
        if 1 <= selection <= len(candidates):
            selected_entity = candidates[selection - 1]
            print(f"Selected: {selected_entity['Label']} - {selected_entity['Description']}")
            return selected_entity
        print(f"Please enter a number between 0 and {len(candidates)}.")


def process_pdf():
    """Run the interactive workflow: upload a PDF, extract names, match them on Wikidata, export CSV.

    Steps:
      1. Upload a PDF and extract its text (the document is closed afterwards).
      2. Extract person entities with the module-level ``nlp`` pipeline.
      3. After user confirmation, search Wikidata for each name; a single hit
         is taken automatically, several hits are resolved by prompting.
      4. Display the results as a DataFrame and offer them as
         ``person_entities.csv``.

    Calls ``sys.exit()`` when the PDF cannot be opened, no person entities
    are found, or the user declines to query Wikidata.
    """
    pdf_path = upload_pdf()
    print(f"\nProcessing PDF: {pdf_path}")
    try:
        doc = fitz.open(pdf_path)
    except Exception as e:
        print(f"Error opening PDF: {e}")
        sys.exit()

    # Only the text is needed later, so release the document right away
    try:
        pdf_text = extract_text_from_pdf(doc)
    finally:
        doc.close()

    print("\n=== Extracted Text from PDF ===")
    # Display the first 1000 characters to avoid flooding the output
    if len(pdf_text) > 1000:
        print(pdf_text[:1000] + "...")
    else:
        print(pdf_text)

    person_entities = extract_person_entities(pdf_text, nlp)
    print(f"\nFound {len(person_entities)} unique person entities.")

    if not person_entities:
        print("No person entities found in the PDF.")
        sys.exit()

    # Display the list of extracted person entities
    print("\n=== Extracted Person Entities ===")
    for idx, entity in enumerate(person_entities, 1):
        print(f"{idx}. {entity}")

    # Optional: Allow user to inspect entities before querying
    user_input = input("\nDo you want to proceed with querying Wikidata for these entities? (yes/no): ").strip().lower()
    if user_input not in ['yes', 'y']:
        print("Operation cancelled by the user.")
        sys.exit()

    # Cache of search results keyed by name, so a repeated name is not re-queried
    entity_cache = {}

    # Query Wikidata for each person entity and collect data with progress bar
    entity_data = []
    for entity in tqdm(person_entities, desc="Processing Entities", unit="entity"):
        print(f"\nProcessing entity: '{entity}'")

        if entity in entity_cache:
            print(f"Retrieving cached results for entity: '{entity}'")
            search_results = entity_cache[entity]
        else:
            search_results = search_wikidata_api(entity)
            entity_cache[entity] = search_results

        selected_entity = _choose_result(entity, search_results) if search_results else None

        if selected_entity:
            entity_data.append({"Name": entity, **selected_entity})
        else:
            # No hits, or the user chose to skip this name
            entity_data.append(_not_found_record(entity))
        # NOTE: requests are sent back-to-back; add a short sleep here if the
        # API starts rate-limiting.

    # Convert the results to a pandas DataFrame for display
    df = pd.DataFrame(entity_data)

    # Display the DataFrame
    print("\n=== Person Entities Extracted ===")
    display(df)

    # Write the CSV and offer it for download; failures are reported, not raised
    try:
        df.to_csv("person_entities.csv", index=False, encoding='utf-8')
        print("\nDownloading 'person_entities.csv'...")
        files.download('person_entities.csv')
    except Exception as e:
        print(f"Error downloading CSV: {e}")

# Run the interactive workflow as soon as this code executes:
# upload a PDF, extract person names, match them on Wikidata, export a CSV.
process_pdf()