In [None]:
import csv
import os
import re
from typing import Literal, List, Optional, Dict, Any

from dotenv import load_dotenv
from gliner import GLiNER
from unstructured.partition.pdf import partition_pdf
from unstructured.cleaners.core import (clean,
                                        group_broken_paragraphs,
                                        clean_extra_whitespace)

from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForTokenClassification

from ollama import chat
from pydantic import BaseModel

from langchain.chains import RetrievalQA
from langchain_ollama.llms import OllamaLLM
from langchain_ollama import OllamaEmbeddings
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.vectorstores import InMemoryVectorStore
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter

# Load environment variables from a local .env file into the process environment.
# NOTE(review): none of the cells reviewed here read an environment variable
# afterwards — confirm this call is still needed.
load_dotenv()


In [None]:
def clean_text(text):
    """Run a text fragment through the `unstructured` cleaning helpers.

    The fragment is passed, in this order, through ``clean``,
    ``group_broken_paragraphs`` and ``clean_extra_whitespace``; each step
    receives the previous step's output. The final result is returned.
    """
    cleaning_steps = (clean, group_broken_paragraphs, clean_extra_whitespace)
    for step in cleaning_steps:
        text = step(text)
    return text

def extract_sentences(elements):
    """Return the cleaned text of every element that carries text.

    Elements whose ``text`` attribute is falsy (empty or ``None``) are
    skipped. Input order is preserved. Despite the name, each entry is one
    element's full text, not an individual sentence.
    """
    return [clean_text(element.text) for element in elements if element.text]

def get_connections(sentences, model, labels, threshold=0.65):
    """Group the entities that the NER model finds together in each text.

    Args:
        sentences: Iterable of text fragments to scan.
        model: Object exposing ``predict_entities(text, labels, threshold=...)``
            that returns an iterable of dicts with ``"text"`` and ``"label"`` keys.
        labels: Entity labels handed to the model unchanged.
        threshold: Score threshold forwarded to ``predict_entities``
            (defaults to the previously hard-coded 0.65).

    Returns:
        A list with one entry per fragment that produced at least two distinct
        ``(text, label)`` pairs. Each entry lists those pairs without
        duplicates, in the order the model returned them.
    """
    connections = []
    for sentence in sentences:
        entities = model.predict_entities(sentence, labels, threshold=threshold)
        # dict.fromkeys de-duplicates while keeping first-seen order. The
        # previous list(set(...)) produced an arbitrary order, and
        # create_edge_list treats element 0 as the edge source, so the
        # resulting edges could differ from run to run.
        unique_pairs = list(dict.fromkeys(
            (entity["text"], entity["label"]) for entity in entities
        ))
        if len(unique_pairs) > 1:
            connections.append(unique_pairs)
    return connections

def create_edge_list(connections, file_name="edge_list.csv"):
    """Write a star-shaped edge list to CSV and return it.

    For every connection (a sequence of ``(text, label)`` pairs) the first
    entity's text is the source, and one edge is emitted to each remaining
    entity's text. Labels are not written.

    Args:
        connections: Iterable of sequences of ``(text, label)`` pairs.
        file_name: Path of the CSV file to (over)write. The parent directory
            must already exist.

    Returns:
        List of ``(source, target)`` tuples, in the order they were written.
    """
    edge_list = []
    # newline="" hands line-ending control to the csv module; UTF-8 keeps
    # non-ASCII names intact regardless of the platform's default encoding.
    with open(file_name, "w", newline="", encoding="utf-8") as f:
        # csv.writer quotes fields containing commas or quotes, which the
        # previous hand-built f-string rows did not.
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["Source", "Target"])
        for connection in connections:
            if not connection:
                # An empty group has no source and contributes no edges.
                continue
            source = connection[0][0]
            for entity in connection[1:]:
                target = entity[0]
                edge_list.append((source, target))
                writer.writerow([source, target])
    return edge_list


In [None]:
# Load the GLiNER checkpoint used for entity extraction. The active checkpoint
# is "numind/NuNER_Zero-span"; the two urchade/gliner checkpoints below are
# alternatives left commented out.
# NOTE(review): the "Trying different NER models" section further down rebinds
# `model` to a transformers token-classification model. Re-run this cell before
# calling get_connections() again after executing that section.
# model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
# model = GLiNER.from_pretrained("urchade/gliner_large-v2.1")
model = GLiNER.from_pretrained("numind/NuNER_Zero-span")
# Entity labels passed to the model by get_connections().
labels = ["Person", "Company", "Organization"]


In [None]:
# Source article (PDF) and the matching output path stem. The ".pdf" suffix is
# kept in file_name, so the CSV written below ends in ".pdf.csv".
file = "TrumpEpstein/PDF/Former Models for Donald Trump’s Agency Say They Violated Immigration Rules and Worked Illegally.pdf"
file_name = f"TrumpEpstein/edge_lists/{file.split('/')[-1]}"

# PDF -> cleaned element texts -> per-text entity groups -> CSV edge list.
elements = partition_pdf(file)
sentences = extract_sentences(elements)
connections = get_connections(sentences, model, labels)
edge_list = create_edge_list(connections, file_name=file_name + ".csv")

# The whole document as a single space-joined string.
text = " ".join(sentences)


In [None]:
# TODO: Update this to use NER entities from GLiNER 
# TODO: Example: ('Gianfranco Ferré', 'Person'), ('Jean Paul Gaultier', 'Person'), ('Vogue', 'Organization')
def identify_person(name: str, entity_type: str, rag_chain, max_attempts: int = 2) -> str:
    """
    Resolve an extracted entity to its full name through the RAG chain.

    The chain is invoked with ``{"input": name, "entity_type": entity_type}``
    and the ``'answer'`` field of its response is normalised by
    ``format_name_response``. The same request is retried when an attempt
    raises, so retries only help with transient failures.

    Args:
        name: Entity text as extracted by the NER model.
        entity_type: Label of the entity (e.g. "Person", "Organization").
        rag_chain: Object exposing ``invoke(dict)`` that returns a mapping
            with an ``'answer'`` string.
        max_attempts: Maximum number of invocations. Earlier versions silently
            capped this at three because the loop iterated over a leftover
            (and otherwise unused) list of three query strings.

    Returns:
        The formatted name, or ``"<name> (unknown)"`` when every attempt
        failed or ``max_attempts`` is not positive.
    """
    for _attempt in range(max_attempts):
        try:
            # Name and entity type go in as separate inputs; the prompt
            # template has a placeholder for each.
            query = {
                "input": name,
                "entity_type": entity_type
            }
            response = rag_chain.invoke(query)
            answer = response['answer'].strip()
            # Post-process the answer for consistency
            return format_name_response(name, entity_type, answer)
        except Exception as e:
            # Deliberate best-effort: report the failure and retry rather
            # than abort validation of a whole batch of entities.
            print("Exception occurred:", e)

    return f"{name} (unknown)"

def format_name_response(input_name: str, entity_type: str, response: str) -> str:
    """
    Normalise a raw LLM answer into one consistent output format.

    Checks are applied in order and the first match wins:

    1. The answer mentions a pseudonym / fake name / alias / "not real":
       returns ``"<name> (pseudonym)"``, where ``<name>`` is the text before
       the first ``(`` in the answer, or ``input_name`` when there is none.
    2. The answer contains "unclear": returns ``"<input_name> (unclear)"``.
    3. The answer contains "don't know", "unknown" or "not found":
       returns ``"<input_name> (unknown)"``.
    4. The answer has at least two words, or contains "real": treated as a
       full name; a ``(real)`` / ``(real name)`` marker and surrounding
       quotes are removed.
    5. Anything else is reported on stdout and returned as
       ``"<input_name> (<entity_type>)"``.

    All status suffixes use the same ``"Name (status)"`` shape, with no comma.
    This matches the format the system prompt asks the model for and the
    fallback returned by ``identify_person``.
    """
    response_lower = response.lower()

    # 1. Explicit pseudonym indicators.
    if any(indicator in response_lower for indicator in ['pseudonym', 'fake name', 'alias', 'not real']):
        before_paren, paren, _ = response.partition('(')
        # Drop quotes the model may have echoed from the prompt examples, and
        # fall back to the input when nothing precedes the parenthesis.
        clean_name = before_paren.strip().strip("'\"").strip() if paren else ""
        return f"{clean_name or input_name} (pseudonym)"

    # 2. Ambiguous answer.
    if "unclear" in response_lower:
        return f"{input_name} (unclear)"

    # 3. The model could not answer.
    if any(indicator in response_lower for indicator in ["don't know", "unknown", "not found"]):
        return f"{input_name} (unknown)"

    # 4. Looks like a full name. The earlier returns already exclude
    #    'unknown' and 'pseudonym', so the original mixed and/or condition
    #    reduces to this test.
    if len(response.split()) >= 2 or 'real' in response_lower:
        pattern = r'\s*\(real(?:\s+name)?\)'
        cleaned = re.sub(pattern, '', response.strip())
        return cleaned.strip("'\"")  # Remove any leading/trailing quotes

    # 5. Single-word answer with no recognised marker.
    print(f"Unexpected response format: {response}")
    return f"{input_name} ({entity_type})"
    
# Enhanced validation function for GLiNER results
def validate_ner_entities(connections, rag_chain):
    """
    Validate and enhance GLiNER entity extractions using RAG.

    Every entity, whatever its label, is passed through ``identify_person``
    together with its label. The label itself is carried over unchanged.

    Args:
        connections: Iterable of sequences of ``(text, label)`` pairs, as
            produced by ``get_connections``.
        rag_chain: Chain object forwarded to ``identify_person``.

    Returns:
        A list parallel to ``connections`` in which each ``(text, label)``
        pair has become ``(resolved_text, label)``.
    """
    validated_connections = []
    for connection in connections:
        # The original Person / non-Person branches were identical, so a
        # single code path covers both.
        validated_connections.append([
            (identify_person(entity_text, entity_label, rag_chain), entity_label)
            for entity_text, entity_label in connection
        ])
    return validated_connections

def quality_check_validation(original_entities, validated_entities):
    """
    Print a report comparing original GLiNER output with validated results.

    Connections and their entities are paired positionally with ``zip``, so
    extra items on either side are silently ignored. An entity is reported
    as "Enhanced" whenever its text changed, whatever its label. Validation
    rewrites non-person entities too, and labelling those "Unchanged" (as the
    earlier Person-only check did) misreports them.

    Args:
        original_entities: Connections as returned by ``get_connections``.
        validated_entities: Connections as returned by ``validate_ner_entities``.
    """
    print("=== Validation Report ===")
    for number, (orig, val) in enumerate(zip(original_entities, validated_entities), start=1):
        print(f"\nConnection {number}:")
        for (orig_text, _orig_label), (val_text, _val_label) in zip(orig, val):
            if orig_text != val_text:
                print(f"  Enhanced: {orig_text} → {val_text}")
            else:
                print(f"  Unchanged: {orig_text}")

def check_name(name: str):
    """
    Ask the Ollama 'phi3.5' model about ``name``, using retrieved context.

    The text of the documents retrieved for ``name`` is joined into one
    context string and embedded in a single user message. The reply is
    requested in the JSON schema of a one-field ``Person`` model and
    validated against it.

    Args:
        name: Name (or other entity text) to look up.

    Returns:
        A ``Person`` instance whose ``name`` field holds the model's answer.

    Note:
        ``retriever`` is not a parameter. It is read from the notebook's
        global namespace, so the cell that builds the retriever must have
        been run before this function is called.
    """
    # Retrieve documents for this single query ([0] selects the result list of
    # the only query in the batch), trim periods/whitespace from the ends of
    # each chunk and join the chunks with spaces.
    context = " ".join([r.page_content.strip(".").strip() for r in retriever.batch([name])[0]])
    # Schema of the structured reply; redefined on every call.
    class Person(BaseModel):
        name: str

    # NOTE(review): the last two prompt sentences both cover the "not a
    # person" case — looks redundant; confirm before trimming.
    response = chat(
    messages=[
        {
        'role': 'user',
        'content': f"""Tell me about {name}. Here is some context: {context}. 
                        Provide the full name if available. 
                        If the name is not a person, return the name as is.
                        If the name is not a person (e.g., a company or organization), return the name as is without any additional labels.""",
        }
    ],
    model='phi3.5', format=Person.model_json_schema(),
    )

    # Parse and validate the JSON reply against the schema above.
    person = Person.model_validate_json(response.message.content)
    return person



In [None]:
# Define the LLM to use with Ollama
llm = OllamaLLM(model="mistral-nemo", temperature=0)
# llm = OllamaLLM(model="phi3.5", temperature=0)

# Initialize embeddings and vector store
# Note: OllamaEmbeddings requires the Ollama server to be running with the specified model
# Ensure you have the Ollama server running with the model "nomic-embed-text" or "mxbai-embed-large"
# Note: "mxbai-embed-large" is a larger model and may take more time to process but provides better quality embeddings
# embeddings = OllamaEmbeddings(model="nomic-embed-text")
embeddings = OllamaEmbeddings(model="mxbai-embed-large")
# NOTE(review): this instance is only used to call from_documents() below, and
# that call's result is bound to `vs`. Presumably from_documents() builds a
# separate store and `vectorstore` itself stays empty — confirm.
vectorstore = InMemoryVectorStore(embeddings)
# One Document per entry of `sentences`: the list index is the id and
# `file_name` (the edge-list path stem) is recorded as the source.
docs = []
for i, s in enumerate(sentences):
    docs.append(Document(id=str(i), page_content=s, metadata={"source": file_name}))

# text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
# Use smaller chunks for more precise retrieval
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=100,  # Smaller chunks for better precision
    chunk_overlap=55,  # more than half of chunk_size
    separators=["\n\n", "\n", ". ", " "]
)
# text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=20)
all_texts = text_splitter.split_documents(docs)
vs = vectorstore.from_documents(documents=all_texts, embedding=embeddings)
# retriever = vs.as_retriever(k=6)
# Increase retrieval count and add similarity threshold
# NOTE(review): no search_type is passed here. As far as I can tell,
# "score_threshold" is only honoured with
# search_type="similarity_score_threshold", so the 0.95 value may currently
# have no effect — confirm against the LangChain docs.
retriever = vs.as_retriever(search_kwargs={"k": 5, "score_threshold": 0.95}) # 0.3

# System prompt for name resolution. It has two template placeholders:
# {entity_type} (supplied by the caller) and {context} (presumably filled with
# the retrieved documents by the stuff-documents chain — confirm).
system_prompt = (
    "You are a precise name identification assistant. Your task is to identify the full name of a person from the given context.\n"
    "The entity type hint is: {entity_type}\n"
    "Rules:\n"
    "1. If the entity type is 'Person' and you find the person's full name (first + last), return it exactly as written\n"
    "2. If the entity type is 'Person' and the name is explicitly mentioned as a pseudonym, fake name, or alias, return 'Name (pseudonym)'\n"
    "3. If the entity type is 'Person' and you cannot find the full name, return 'Name (unknown)'\n"
    "4. If the entity type is 'Organization' or 'Company', return the name as-is without any additional labels\n"
    "5. If the name is mentioned but unclear, or if there are multiple possibilities, return 'Name (unclear)'\n"
    "6. Only use information explicitly stated in the context\n"
    "7. Be concise - provide only the name and status\n"
    "8. Be factual. Do not invent new names. If you don't know the answer, respond with 'Name (unknown)'\n\n"
    "Examples:\n"
    "- Input: 'Rachel' (Person) → Output: 'Rachel Blais' (if full name found)\n"
    "- Input: 'Kate' (Person) → Output: 'Kate (pseudonym)' (if mentioned as fake name)\n"
    "- Input: 'Vogue' (Organization) → Output: 'Vogue' (return as-is)\n\n"
    "Context: {context}"
)

# The human turn carries the entity text as {input} and its label as
# {entity_type}, so callers must invoke the chain with both keys.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "Name: {input}\nEntity Type: {entity_type}"),
    ]
)

# qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
# Combine the retriever with the prompt + LLM chain. Later cells read the
# 'answer' and 'context' keys of the result of rag_chain.invoke().
qa_chain = create_stuff_documents_chain(llm, prompt)
rag_chain = create_retrieval_chain(retriever, qa_chain)


In [None]:
# Smoke-test the chain on one entity from the GLiNER output: ('Vogue', 'Organization').
# Earlier free-text form of the query, kept for reference:
# query = {"input": f"Who is {name}? Give me their full name."}
query = {"input": "Vogue", "entity_type": "Organization"}
response = rag_chain.invoke(query)
# Print the model's answer, then the text of each document in the response's
# 'context' list, one per line.
print(response['answer'])
print("\n".join([r.page_content for r in response['context']]))


In [None]:
# Inspect the retriever on its own: the chunks returned for a query string that
# includes the entity type, with periods/whitespace trimmed from both ends.
[r.page_content.strip(".").strip() for r in retriever.batch(["Vogue (Organization)"])[0]]


In [None]:
# Earlier experiments, kept for reference. They use the old two-argument
# identify_person(name, rag_chain) signature and would fail as written against
# the current identify_person(name, entity_type, rag_chain).
# identify_person("Trump", rag_chain) # unclear
# identify_person("Melania", rag_chain) # unknown
# res = identify_person("Vogue", rag_chain) # unknown
# print(res) 
# check_name(name=res)
# Example usage
# result = identify_person("Blais", "Person", rag_chain)
result = identify_person("Vogue", "Organization", rag_chain)
print(result) 
# Feed the formatted result (including any "(...)" suffix) to the second,
# structured-output check.
check_name(name=result)


In [None]:
# Display the raw entity groups produced by get_connections() for inspection.
connections


In [None]:
# check_name(name="Gehi, (unknown)")
# check_name(name="Vogue, (unknown)")


In [None]:
# name = "Kate"
# # query = {"input": f"Who is {name}? Give me their full name."}
# query = {"input": name}
# response = rag_chain.invoke(query)
# print(response['answer'])


In [None]:
# # Apply validation to your connections
# validated_connections = validate_ner_entities(connections, rag_chain)

# # Test with specific names
# test_names = ["Kate", "Anna", "Blais", "Melania"]
# for name in test_names:
#     result = identify_person(name, rag_chain)
#     print(f"{name} → {result}")


In [None]:

# # Run quality check
# quality_check_validation(connections, validated_connections)


In [None]:
# identify_person("kate", rag_chain)
# identify_person("gary", rag_chain)
# identify_person("Melania", rag_chain)
# identify_person("anna", rag_chain)
# identify_person("donald", rag_chain)
# identify_person("Lanzano", rag_chain)
# identify_person("Pierre Roussel", rag_chain)
# identify_person("Naresh", rag_chain)
# identify_person("Gehi", rag_chain, max_attempts=4)


# Trying different NER models

In [None]:
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForTokenClassification


In [None]:
# Hugging Face token-classification checkpoint to compare against GLiNER.
# m = "dslim/bert-base-NER"
m ="dbmdz/bert-large-cased-finetuned-conll03-english"

tokenizer = AutoTokenizer.from_pretrained(m)
# NOTE(review): this rebinds the global `model`, which earlier cells use for
# the GLiNER model passed to get_connections(). After running this cell,
# get_connections(sentences, model, labels) receives this transformers model
# instead; consider a distinct name here and in the next cell.
model = AutoModelForTokenClassification.from_pretrained(m)

nlp = pipeline("ner", model=model, tokenizer=tokenizer)
example = "My name is Wolfgang and I live in Berlin"
# example = "Hugging Face is based in New York City."

# Print each prediction's word, entity tag and score (two decimals).
ner_results = nlp(example)
for e in ner_results:
    print(f"{e['word']}: {e['entity']} (score: {e['score']:.2f})")


In [None]:
# Check how the same checkpoint tags a single upper-case token.
# NOTE(review): this rebuilds the pipeline from the same `model` and
# `tokenizer` as the previous cell; reusing that cell's `nlp` looks sufficient
# — confirm.
nlp = pipeline("ner", model=model, tokenizer=tokenizer)
entities = nlp("TRUMP")

# print(entities)
for e in entities:
    print(f"{e['word']}: {e['entity']} (score: {e['score']:.2f})")


### Notes on RAG

Example of how to use `InMemoryVectorStore` (the cells below are kept commented out for reference).

In [None]:
# # Define the LLM to use with Ollama
# llm = OllamaLLM(model="mistral-nemo", temperature=0)

# # Initialize embeddings and vector store
# # Note: OllamaEmbeddings requires the Ollama server to be running with the specified model
# # Ensure you have the Ollama server running with the model "nomic-embed-text"
# embeddings = OllamaEmbeddings(model="nomic-embed-text")
# vectorstore = InMemoryVectorStore(embeddings)
# docs = []
# for i, s in enumerate(sentences):
#     docs.append(Document(id=str(i), page_content=s, metadata={"source": file_name}))

# # text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
# text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
# all_texts = text_splitter.split_documents(docs)
# vs = vectorstore.from_documents(documents=all_texts, embedding=embeddings)
# retriever = vs.as_retriever(k=6)
# qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)

# name = "Blais"
# query = f"Who is {name}? Give me their full name. Be concise. I just want first name and last name."
# response = qa_chain.invoke(query)
# print(response)


In [None]:
# # RAG

# pdf = "How A.I. Assistants Could Supercharge Workplace Software _ Inc.com.pdf"
# docs = PyPDFLoader(pdf).load()

# llm = OllamaLLM(model="phi3.5", temperature=0)
# embeddings = OllamaEmbeddings(model="all-minilm")

# text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
# splits = text_splitter.split_documents(docs)
# vectorstore = InMemoryVectorStore.from_documents(
#     documents=splits, embedding=embeddings
# )

# retriever = vectorstore.as_retriever()

# system_prompt = (
#     "You are an assistant for question-answering tasks. "
#     "Use the following pieces of retrieved context to answer "
#     "the question. If you don't know the answer, say that you "
#     "don't know. Use three sentences maximum and keep the "
#     "answer concise."
#     "\n\n"
#     "{context}"
# )

# prompt = ChatPromptTemplate.from_messages(
#     [
#         ("system", system_prompt),
#         ("human", "{input}"),
#     ]
# )

# question_answer_chain = create_stuff_documents_chain(llm, prompt)
# rag_chain = create_retrieval_chain(retriever, question_answer_chain)

# excerpts = rag_chain.invoke({"input": "Generate the top 5 key themes and phrases. Do not be overly verbose."})
# print(excerpts['answer'])


In [None]:
# # Define the LLM to use with Ollama
# # llm = OllamaLLM(model="mistral-nemo", temperature=0)
# llm = OllamaLLM(model="phi3.5", temperature=0)

# # Initialize embeddings and vector store
# # Note: OllamaEmbeddings requires the Ollama server to be running with the specified model
# # Ensure you have the Ollama server running with the model "nomic-embed-text" or "mxbai-embed-large"
# # Note: "mxbai-embed-large" is a larger model and may take more time to process but provides better quality embeddings
# # embeddings = OllamaEmbeddings(model="nomic-embed-text")
# embeddings = OllamaEmbeddings(model="mxbai-embed-large")
# vectorstore = InMemoryVectorStore(embeddings)
# docs = []
# for i, s in enumerate(sentences):
#     docs.append(Document(id=str(i), page_content=s, metadata={"source": file_name}))

# # text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
# text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
# all_texts = text_splitter.split_documents(docs)
# vs = vectorstore.from_documents(documents=all_texts, embedding=embeddings)
# retriever = vs.as_retriever(k=6)

# system_prompt = (
#     "You are an assistant for question-answering tasks. "
#     "Specifically, your goal is to help identify people and their full names."
#     "If you don't know the person's full name, say that you don't know."
#     "If the name is a pseudonym, indicate that it is a pseudonym: Bob <PSEUDONYM>."
#     "Use the following pieces of retrieved context to answer the question."
#     "Only provide the full name  or name <PSEUDONYM> if you are sure about it, otherwise say you don't know."
#     "Be concise."
#     "\n\n"
#     "{context}"
# )

# prompt = ChatPromptTemplate.from_messages(
#     [
#         ("system", system_prompt),
#         ("human", "{input}"),
#     ]
# )

# # qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
# qa_chain = create_stuff_documents_chain(llm, prompt)
# rag_chain = create_retrieval_chain(retriever, qa_chain)

# name = "Kate"
# query = {"input": f"Who is {name}? Give me their full name. If the name is a pseudonym, indicate that it is a pseudonym like this: Name <PSEUDONYM>. Be concise."}
# response = rag_chain.invoke(query)
# print(response['answer'])


In [None]:
# def identify_person(name: str, rag_chain, max_attempts: int = 2) -> str:
#     """
#     Identify a person's full name using the RAG pipeline with structured output.
#     """
#     queries = [
#         f"What is the full name of {name}? Is {name} a real name or a pseudonym?",
#         f"Find all mentions of {name} in the text. What is their complete name?",
#         f"Who is {name}? Provide their full name and indicate if it's a pseudonym or alias."
#     ]
    
#     for i, query_text in enumerate(queries[:max_attempts]):
#         try:
#             query = {"input": query_text}
#             response = rag_chain.invoke(query)
#             answer = response['answer'].strip()
            
#             # Post-process the answer for consistency
#             return format_name_response(name, answer)
            
#         except Exception as e:
#             print(e)
#             if i == max_attempts - 1:  # Last attempt
#                 return f"{name} (unknown)"
#             continue
#     print("LLM not queried, returning unknown")
#     return f"{name} (unknown)"


In [None]:
# system_prompt = (
#     "You are a precise name identification assistant. Your task is to identify the full name of a person from the given context.\n"
#     "Rules:\n"
#     "1. If you find the person's full name (first + last), return it exactly as written\n"
#     "2. If the name is explicitly mentioned as a pseudonym, fake name, or alias, return 'Name (pseudonym)'\n"
#     "3. If you cannot find the full name, return 'Name (unknown)'\n"
#     "4. If the name is mentioned but unclear, or if there are multiple possibilities for example John Smith and Mary Smith, return 'Name (unclear)'\n"
#     "5. Only use information explicitly stated in the context\n"
#     "6. Be concise - provide only the name and status\n"
#     "7. Be factual. Do not invent new names. If you don't know the answer, respond with 'Name (unknown)'\n"
#     "8. If the name is not a person (e.g., a company or organization), return the name as is without any additional labels.\n\n"
#     "Examples:\n"
#     "- Input: 'Rachel' → Output: 'Rachel Blais' (if full name found)\n"
#     "- Input: 'Kate' → Output: 'Kate (pseudonym)' (if mentioned as fake name or pseudonym)\n"
#     "- Input: 'Bob' → Output: 'Bob (unknown)' (if full name not found)\n\n"
#     "Context: {context}"
# )
