# Full Pipeline (Original Logic Preserved + Multithreaded Download)

## PDF Finder (Original)

In [None]:
# pdf_finder.py
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.common.by import By
import time
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.common.action_chains import ActionChains
from typing import List, Set

def find_pdf_links(url: str, base_url: str) -> Set[str]:
    """
    Navigate to `url`, dismiss the cookie banner, click 'See all results', and scrape
    the unique PDF links on the resulting page.

    Args:
        url: The starting URL for the search.
        base_url: Prefix prepended to any href that does not start with "http".

    Returns:
        A set of PDF URLs; an empty set if the 'See all results' button could not be
        clicked.

    The Chrome session is closed on every exit path, including unexpected exceptions.
    """

    # Initialize the WebDriver. Everything after this runs under try/finally so the
    # browser process is never leaked, whatever goes wrong.
    driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()))
    try:
        driver.get(url)

        # Making the window fullscreen
        driver.maximize_window()
        time.sleep(2)

        try:
            # --- Handle Cookie Banner ---
            # Focus the page body, then Tab + Enter. NOTE(review): this assumes the
            # accept button is the first focusable element — confirm if the layout changes.
            WebDriverWait(driver, 10).until(
                EC.presence_of_element_located((By.TAG_NAME, "body"))
            )
            driver.find_element(By.TAG_NAME, "body").click()
            actions = ActionChains(driver)
            actions.send_keys(Keys.TAB).send_keys(Keys.ENTER).perform()
            print("Accepted cookies using keyboard actions (Tab + Enter).")

        except Exception as e:
            # Best effort: the banner may simply not have appeared.
            print(f"Could not interact with the cookie banner (maybe it didn't appear). Error: {e}")

        time.sleep(2)

        # --- Click 'See all results' ---
        see_all_results_selector = "a.btn.btn-primary.fit-content"
        try:
            button_element = WebDriverWait(driver, 20).until(
                EC.visibility_of_element_located((By.CSS_SELECTOR, see_all_results_selector))
            )

            # Scroll the button into view, then click it via JavaScript ("force-click").
            driver.execute_script("arguments[0].scrollIntoView({block: 'center'});", button_element)
            time.sleep(0.5)
            driver.execute_script("arguments[0].click();", button_element)

            print("Clicked the 'See all results' link/button using JavaScript force-click.")

        except Exception as e:
            print(f"Could not interact with the 'See all results' button. Error: {e}")
            return set()  # the finally clause still quits the driver

        # Giving the webpage time to load completely
        # NOTE: 35 seconds is quite long; ensure this time is necessary!
        time.sleep(35)

        # --- Find the PDF URLs ---
        # NOTE(review): the :not(.btn-download) filter is intended to drop the
        # non-English agreement files — confirm the selector against the final page.
        link_elements = driver.find_elements(By.CSS_SELECTOR, "a[href$='.pdf']:not(.btn-download)")

        pdf_urls_set = set()
        for link_element in link_elements:
            href = link_element.get_attribute('href')
            if not href:
                continue
            # Absolute links are kept as-is; anything else gets the site prefix.
            pdf_urls_set.add(href if href.startswith("http") else base_url + href)

        print(f"Found {len(pdf_urls_set)} unique PDF URLs.")

        # Return the collected URLs
        return pdf_urls_set
    finally:
        # Always close the browser when finished, on success or failure.
        driver.quit()


if __name__ == '__main__':
    # Manual smoke test: scrape the peaceagreements.org basic search and show a sample
    # of what was collected.
    URL = "https://www.peaceagreements.org/agreements/search/?search_type=basic-search&show_timeline=0&match_any_issues=True"
    BASE_URL = "https://www.peaceagreements.org"

    all_pdf_links = find_pdf_links(URL, BASE_URL)

    if not all_pdf_links:
        print("\nNo links were found.")
    else:
        print("\nFirst 5 URLs collected:")
        # One URL per line, same as printing each individually.
        print(*list(all_pdf_links)[:5], sep="\n")

## Multithreaded PDF Downloader

In [None]:

import concurrent.futures, requests, os
from pathlib import Path

def download_pdf(url, outdir):
    """
    Download one PDF into `outdir`, skipping the request if the file already exists.

    The filename is the last path segment of the URL. The query string and fragment are
    ignored, so ".../a.pdf?v=1" and ".../a.pdf#page=2" both map to "a.pdf". A URL whose
    path ends in "/" gets a stable name derived from the URL's SHA-1, so the output
    directory itself is never returned.

    The body is written to a temporary ".part" file in `outdir` and then atomically moved
    into place. An interrupted download therefore never leaves a truncated PDF that a later
    run would skip as "already downloaded".

    Args:
        url: URL of the PDF.
        outdir: Output directory (str or Path); created, including parents, if missing.

    Returns:
        Path of the saved or pre-existing file, or None if the download failed. The
        error is printed.
    """
    import hashlib
    import tempfile
    from urllib.parse import urlsplit

    out = Path(outdir)
    # exist_ok: many worker threads call this concurrently for the same directory.
    out.mkdir(parents=True, exist_ok=True)

    fname = urlsplit(url).path.rsplit("/", 1)[-1]
    if not fname:
        fname = hashlib.sha1(url.encode("utf-8")).hexdigest() + ".pdf"
    path = out / fname
    if path.exists():
        return path

    tmp_name = None
    try:
        r = requests.get(url, timeout=20)
        r.raise_for_status()
        # mkstemp gives each thread its own temp file, even for colliding filenames.
        fd, tmp_name = tempfile.mkstemp(dir=out, suffix=".part")
        with os.fdopen(fd, "wb") as fh:
            fh.write(r.content)
        os.replace(tmp_name, path)
        tmp_name = None  # moved into place; nothing to clean up
        return path
    except Exception as e:
        print("Fail", url, e)
        return None
    finally:
        if tmp_name is not None:
            try:
                os.unlink(tmp_name)
            except OSError:
                pass

def download_all(urls, outdir="downloads", max_workers=10):
    """Download every URL on a thread pool; returns download_pdf's result per URL, in input order."""
    def fetch(one_url):
        return download_pdf(one_url, outdir)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        outcomes = pool.map(fetch, urls)
        return list(outcomes)


## chunk_docs.py (Original)

In [None]:
import os
import pickle
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document


def chunk_documents_parallel(
    pdf_dir: str,
    txt_dir: str,
    chunk_size: int = 1500,
    chunk_overlap: int = 200,
    output_path="document_chunks.pkl",
    workers=8
):
    """
    Load and chunk PDF files from `pdf_dir` and TXT files from `txt_dir` in parallel.

    Writes the result to `output_path` (pickle) and also returns it:
        [
            {
                "doc_id": filename,
                "document_text": full_text,
                "chunks": [Document, Document, ...]
            },
            ...
        ]

    PDFs come first, then TXT files. Within each group, entries follow sorted filename
    order regardless of which worker thread finishes first, so repeated runs produce the
    same pickle. Files that fail to load are reported and left out.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )

    all_docs = []

    # ----------------------------
    #   Helper to process 1 file
    # ----------------------------
    def process_file(file_path: str, filename: str, is_pdf: bool):
        """Return the doc dict for one file, or None (after printing) if loading/splitting fails."""
        try:
            # Load text
            if is_pdf:
                loader = PyPDFLoader(file_path)
                pages = loader.load()
                full_text = "\n".join([p.page_content for p in pages])
            else:
                with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
                    full_text = f.read()

            doc_id = filename

            # Chunk text
            chunks = text_splitter.split_text(full_text)
            chunk_docs = [
                Document(
                    page_content=chunk,
                    metadata={
                        "source": file_path,
                        "chunk_id": f"{doc_id}-chunk:{i}",
                        "doc_id": doc_id,
                    },
                )
                for i, chunk in enumerate(chunks)
            ]

            return {
                "doc_id": doc_id,
                "document_text": full_text,
                "chunks": chunk_docs,
            }

        except Exception as e:
            print(f"⚠️ Error processing {filename}: {e}")
            return None

    # ----------------------------
    #      Collect file lists
    # ----------------------------
    # sorted(): os.listdir order is arbitrary and filesystem-dependent.
    pdf_files = [
        (os.path.join(pdf_dir, f), f, True)
        for f in sorted(os.listdir(pdf_dir))
        if f.lower().endswith(".pdf")
    ]

    txt_files = [
        (os.path.join(txt_dir, f), f, False)
        for f in sorted(os.listdir(txt_dir))
        if f.lower().endswith(".txt")
    ]

    # ----------------------------
    #     Parallel processing
    # ----------------------------
    def run_parallel(file_list, desc):
        """Process `file_list` on the pool; results keep `file_list` order, failures dropped."""
        # Slot each result by submission index: as_completed yields in finish order,
        # which would otherwise make the output order depend on thread timing.
        results = [None] * len(file_list)
        with ThreadPoolExecutor(max_workers=workers) as executor:
            futures = {
                executor.submit(process_file, path, fname, is_pdf): idx
                for idx, (path, fname, is_pdf) in enumerate(file_list)
            }

            for future in tqdm(as_completed(futures), total=len(futures), desc=desc):
                results[futures[future]] = future.result()

        return [result for result in results if result]

    print("📄 Processing PDFs...")
    all_docs.extend(run_parallel(pdf_files, "PDFs"))

    print("📝 Processing text files...")
    all_docs.extend(run_parallel(txt_files, "TXTs"))

    # ----------------------------
    #        Save results
    # ----------------------------
    with open(output_path, "wb") as f:
        pickle.dump(all_docs, f)

    print(f"✅ Saved {len(all_docs)} documents to {output_path}")
    return all_docs


if __name__ == "__main__":
    import argparse

    # Command-line entry point: every chunking knob is exposed as a flag.
    parser = argparse.ArgumentParser()
    parser.add_argument("--pdf-dir", required=True)
    parser.add_argument("--txt-dir", required=True)
    parser.add_argument("--chunk-size", type=int, default=1500)
    parser.add_argument("--chunk-overlap", type=int, default=200)
    parser.add_argument("--output", default="document_chunks.pkl")
    parser.add_argument("--workers", type=int, default=8)
    args = parser.parse_args()

    chunk_documents_parallel(
        pdf_dir=args.pdf_dir,
        txt_dir=args.txt_dir,
        chunk_size=args.chunk_size,
        chunk_overlap=args.chunk_overlap,
        output_path=args.output,
        workers=args.workers,
    )


## ingest_vectors.py (Original)

In [None]:
import pickle
import json
import os
from pathlib import Path
from tqdm import tqdm
from langchain_community.vectorstores import Neo4jVector
from langchain_core.documents import Document
import torch
import gc
from concurrent.futures import ThreadPoolExecutor, as_completed
from dotenv import load_dotenv

# Load variables from a local .env file into os.environ (python-dotenv) so the
# os.getenv() defaults of the CLI below can pick them up.
load_dotenv()


def ingest_vectors(pkl_path, embeddings, args):
    """
    Embed pre-chunked documents into Neo4j and link each :Document to its :Chunk nodes.

    Args:
        pkl_path: Pickle written by chunk_documents_parallel. It holds a list of dicts
            with "doc_id", "document_text" and "chunks" (Documents whose metadata carries
            "source" and "chunk_id").
        embeddings: Embedding model handed to Neo4jVector.
        args: Namespace read for progress_file, neo4j_url, neo4j_user, neo4j_password,
            index_name and concurrency.

    The run is resumable. After each document finishes, the set of finished doc_ids is
    written to args.progress_file, and documents already listed there are skipped on the
    next run. A per-document error is printed, and that document stays pending for a
    later retry.
    """
    # NOTE: pickle.load can execute arbitrary code — only load files this pipeline wrote.
    with open(pkl_path, "rb") as f:
        dataset = pickle.load(f)

    print(f"Loaded {len(dataset)} documents for vector ingestion")

    # Setup progress tracking
    progress_file = Path(args.progress_file)
    completed_docs = set()

    if progress_file.exists():
        with open(progress_file, "r") as f:
            progress_data = json.load(f)
        completed_docs = set(progress_data.get("completed_docs", []))
        print(f"Resuming: {len(completed_docs)} documents already processed")

    # Filter out already completed documents
    docs_to_process = [d for d in dataset if d["doc_id"] not in completed_docs]
    print(f"Processing {len(docs_to_process)} remaining documents")

    vector_store_config = {
        "embedding": embeddings,
        "url": args.neo4j_url,
        "username": args.neo4j_user,
        "password": args.neo4j_password,
        "index_name": args.index_name,
        "node_label": "Chunk",
        "text_node_property": "text",
        "embedding_node_property": "embedding"
    }

    # Create vector store (index auto-created if needed)
    vector_store = Neo4jVector(**vector_store_config)

    def save_progress():
        """Persist completed_docs via a temp file + os.replace, so a crash mid-write
        cannot leave a truncated JSON that breaks the next resume."""
        tmp_path = progress_file.with_name(progress_file.name + ".tmp")
        with open(tmp_path, "w") as f:
            json.dump({"completed_docs": list(completed_docs)}, f)
        os.replace(tmp_path, progress_file)

    def process_doc(doc_entry):
        """Upsert the Document node, embed/store its chunks, link them; returns doc_id."""
        doc_id = doc_entry["doc_id"]
        chunks = doc_entry["chunks"]
        chunk_ids = [c.metadata["chunk_id"] for c in chunks]

        # Store main document node
        query_doc = """
        MERGE (d:Document {doc_id: $doc_id})
        SET d.source = $source, d.text = $text
        """
        # A document whose text produced no chunks has no chunk metadata to take the
        # source from; store the node anyway instead of failing on chunks[0].
        vector_store._driver.execute_query(
            query_doc,
            doc_id=doc_id,
            source=chunks[0].metadata["source"] if chunks else None,
            text=doc_entry["document_text"]
        )

        if chunks:
            # Embed and store chunks
            vector_store.add_documents(chunks, ids=chunk_ids)

            # Link document to chunks
            link_query = """
        MATCH (d:Document {doc_id: $doc_id})
        UNWIND $chunk_ids AS cid
        MATCH (c:Chunk {id: cid})
        MERGE (d)-[:HAS_CHUNK]->(c)
        """
            vector_store._driver.execute_query(
                link_query,
                doc_id=doc_id,
                chunk_ids=chunk_ids
            )

        # Release memory between documents (matters for local GPU embedding models).
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return doc_id

    with ThreadPoolExecutor(max_workers=args.concurrency) as executor:
        futures = {executor.submit(process_doc, d): d for d in docs_to_process}

        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing docs"):
            try:
                doc_id = future.result()
                completed_docs.add(doc_id)

                # Save progress after every finished document
                save_progress()
            except Exception as e:
                doc_entry = futures[future]
                print(f"⚠️ Error processing {doc_entry['doc_id']}: {e}")

    print("✅ Vector ingestion complete")


if __name__ == "__main__":
    import argparse

    # Flag -> argparse options. Connection and model defaults come from the environment
    # (populated from .env above).
    parser = argparse.ArgumentParser()
    option_table = [
        ("--pkl", {"default": "document_chunks.pkl"}),
        ("--neo4j-url", {"default": os.getenv("NEO4J_URL")}),
        ("--neo4j-user", {"default": os.getenv("NEO4J_USER")}),
        ("--neo4j-password", {"default": os.getenv("NEO4J_PASSWORD")}),
        ("--index-name", {"default": "peace_index"}),
        ("--concurrency", {"type": int, "default": 8}),
        ("--progress-file", {"default": "vector_ingestion_progress.json"}),
        ("--embedding-model", {"default": os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")}),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Initialize embeddings (imported only after argument parsing, as before).
    from langchain_openai import OpenAIEmbeddings
    api_key = os.getenv("OPENAI_API_KEY", "not-needed")
    api_base = os.getenv("OPENAI_API_BASE")
    embeddings = OpenAIEmbeddings(
        model=args.embedding_model,
        openai_api_key=api_key,
        openai_api_base=api_base,
    )

    ingest_vectors(args.pkl, embeddings, args)

## QA_retrieval.py (Original)

In [None]:
import os
import json
import sqlite3
from datetime import datetime
from typing import List, Dict

from neo4j import GraphDatabase
from openai import OpenAI
from dotenv import load_dotenv

# Load variables from a local .env file into os.environ (python-dotenv) so the
# NEO4J_* / OPENAI_API_KEY lookups in answer_question can pick them up.
load_dotenv()

# ------------------------------------------------------------
# LLM RAW RESPONSE LOGGING DB
# ------------------------------------------------------------

# DDL for the table generate_answer appends raw LLM responses to.
_RAW_RESPONSES_DDL = """
CREATE TABLE IF NOT EXISTS llm_raw_responses (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp TEXT,
    question TEXT,
    prompt TEXT,
    raw_response TEXT
)
"""

# Created at import time so the first logged response never hits a missing table.
conn = sqlite3.connect("qa_logs.db")
c = conn.cursor()
c.execute(_RAW_RESPONSES_DDL)
conn.commit()
conn.close()

# ------------------------------------------------------------
# Retrieval / context knobs
# ------------------------------------------------------------

# Hits requested from each retriever. The fulltext and KG results only reach the
# answer when fusion is enabled.
TOP_K_VECTOR = 25
TOP_K_KEYWORD = 25
TOP_K_KG = 15
# Number of (fused or plain vector) results handed to build_context.
TOP_K_FUSED = 20

# Character budgets for the prompt context.
MAX_CONTEXT_CHARS = 24000   # the whole context string passed to the LLM
MAX_CHARS_PER_CHUNK = 2000  # any single chunk is truncated beyond this

# ============================================================
# SQLite logging for final QA answers
# ============================================================

QA_DB_PATH = "qa_answers.db"


def init_qa_db():
    """Create the `answers` table in the QA_DB_PATH database if it is not there yet."""
    schema = """
        CREATE TABLE IF NOT EXISTS answers (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp TEXT,
            question TEXT,
            answer TEXT,
            fusion_enabled INTEGER
        )
    """
    db = sqlite3.connect(QA_DB_PATH)
    db.cursor().execute(schema)
    db.commit()
    db.close()


def log_qa_answer(question: str, answer: str, fusion_enabled: bool):
    """
    Append one question/answer pair to the `answers` table in QA_DB_PATH.

    Args:
        question: The question asked.
        answer: The final answer text.
        fusion_enabled: Whether RAG fusion was used; stored as 1/0.

    The timestamp is naive UTC in ISO format. The connection is closed even if the
    insert fails.
    """
    from datetime import timezone

    # Same string format datetime.utcnow().isoformat() produced (utcnow is deprecated
    # since Python 3.12), so new rows stay comparable with existing ones.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    conn = sqlite3.connect(QA_DB_PATH)
    try:
        cur = conn.cursor()
        cur.execute("""
        INSERT INTO answers (timestamp, question, answer, fusion_enabled)
        VALUES (?, ?, ?, ?)
    """, (
            timestamp,
            question,
            answer,
            1 if fusion_enabled else 0
        ))
        conn.commit()
    finally:
        conn.close()


# ============================================================
# Neo4j helper
# ============================================================

def neo4j_query(driver, query: str, params: dict) -> List[Dict]:
    """Run one Cypher `query` with `params` in a fresh session; return every row as a dict."""
    with driver.session() as session:
        result = session.run(query, params)
        rows = result.data()  # consume fully while the session is still open
    return rows


# ============================================================
# Retrieval helpers (vector, fulltext, KG)
# ============================================================

def get_vector_results(question: str, driver, client: OpenAI,
                       index_name: str = "chunk_embeddings") -> List[Dict]:
    """
    Retrieve the TOP_K_VECTOR chunks most similar to `question`.

    The question is embedded with the model named by the EMBEDDING_MODEL environment
    variable (default "text-embedding-3-small"). That is the same variable and default
    ingest_vectors.py uses for --embedding-model, so query and stored vectors come from
    the same model unless that flag was overridden.

    Args:
        question: Natural-language query.
        driver: Neo4j driver passed to neo4j_query.
        client: OpenAI client used to embed the question.
        index_name: Neo4j vector index to search. NOTE(review): ingest_vectors.py
            defaults --index-name to "peace_index"; make sure this matches the index
            actually built.

    Returns:
        Rows with doc_id, text and score, highest score first.
    """
    model = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")  # must match ingest
    embedding = client.embeddings.create(
        model=model,
        input=question
    ).data[0].embedding

    query = """
    CALL db.index.vector.queryNodes($index_name, $topK, $embedding)
    YIELD node, score
    RETURN node.doc_id AS doc_id, node.text AS text, score
    ORDER BY score DESC
    """

    return neo4j_query(driver, query, {
        "index_name": index_name,
        "embedding": embedding,
        "topK": TOP_K_VECTOR,
    })


# Characters with special meaning in Lucene's classic query syntax, which Neo4j fulltext
# indexes parse. The "&&" / "||" operators are neutralised by escaping each "&" / "|".
_LUCENE_SPECIAL_CHARS = frozenset('+-&|!(){}[]^"~*?:\\/')


def _escape_lucene(text: str) -> str:
    """Backslash-escape Lucene special characters so `text` is matched as plain terms."""
    return "".join("\\" + ch if ch in _LUCENE_SPECIAL_CHARS else ch for ch in text)


def get_fulltext_results(question: str, driver) -> List[Dict]:
    """
    Retrieve the top TOP_K_KEYWORD chunks for `question` from the 'chunk_fulltext' index.

    The question is escaped first. Natural-language questions routinely contain Lucene
    syntax (quotes, parentheses, "/", ":", "?", "*") that either changes the query's
    meaning or makes the query parser reject it. A blank question returns [] without
    querying.
    """
    q = _escape_lucene(question).strip()
    if not q:
        return []

    query = f"""
    CALL db.index.fulltext.queryNodes('chunk_fulltext', $q)
    YIELD node, score
    RETURN node.doc_id AS doc_id, node.text AS text, score
    ORDER BY score DESC
    LIMIT {TOP_K_KEYWORD}
    """
    return neo4j_query(driver, query, {"q": q})


def get_kg_results(question: str, driver) -> List[Dict]:
    """
    Knowledge-graph lookup by entity name.

    Returns up to TOP_K_KG Documents that some node `e` points to via [:DERIVED_FROM],
    where `e.name` contains the question text as a substring (Cypher CONTAINS,
    case-sensitive). Each hit carries the whole document text and a constant score of 1.0.

    NOTE(review): the *entire* question must appear inside an entity name, so this
    rarely matches natural-language questions — confirm this direction is intended.
    """
    cypher = f"""
    MATCH (e)-[:DERIVED_FROM]->(d:Document)
    WHERE e.name CONTAINS $q
    RETURN d.doc_id AS doc_id, d.text AS text, 1.0 AS score
    LIMIT {TOP_K_KG}
    """
    params = {"q": question}
    return neo4j_query(driver, cypher, params)


# ============================================================
# RAG fusion scoring (chunk-level)
# ============================================================

def reciprocal_rank_fusion(result_lists: List[List[Dict]], k: "int | None" = None) -> List[Dict]:
    """
    Reciprocal Rank Fusion over chunk-level results.

    Each result is a dict with at least "doc_id" and "text". Results are keyed by
    (doc_id, text), so several chunks of one document can appear, but identical
    duplicates are merged. Within each list, a key is scored once, at its first (best)
    position. Its fused score is the sum over lists of 1 / (60 + rank), with rank
    counted from 0.

    Args:
        result_lists: Ranked result lists, best first.
        k: Maximum number of fused results to return; defaults to TOP_K_FUSED, read at
            call time.

    Returns:
        Up to `k` dicts {"doc_id", "text", "score"}, highest fused score first. Ties
        keep the order in which keys first appear across `result_lists`, so the
        output is deterministic.
    """
    if k is None:
        k = TOP_K_FUSED
    rrf_constant = 60  # the usual RRF smoothing constant

    # A dict preserves first-seen insertion order, which the stable sort below uses
    # to break ties.
    scores: Dict[tuple, float] = {}
    for results in result_lists:
        seen_in_list = set()
        for rank, r in enumerate(results):
            key = (r["doc_id"], r["text"])
            if key in seen_in_list:
                continue  # keep only the best rank of a repeated hit
            seen_in_list.add(key)
            scores[key] = scores.get(key, 0.0) + 1.0 / (rrf_constant + rank)

    # sorted() is stable, and reverse=True preserves the relative order of equal scores.
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)

    return [
        {"doc_id": doc_id, "text": text, "score": score}
        for (doc_id, text), score in ranked[:k]
    ]


# ============================================================
# Context building
# ============================================================

def build_context(
    results: List[Dict],
    max_context_chars: "int | None" = None,
    max_chars_per_chunk: "int | None" = None,
) -> str:
    """
    Build the context string handed to the LLM from retrieval results.

    Each result with non-empty text becomes a "[DOC: <doc_id>]\\n<text>" section. Text
    longer than `max_chars_per_chunk` is cut and marked " ... [truncated]". Sections are
    joined with a "-----" separator, in the given order. Adding stops at the first
    section that would push the joined string past `max_context_chars`; separator
    lengths are counted exactly, so the result never exceeds the cap.

    Args:
        results: Dicts with optional "doc_id" (default "unknown_doc") and "text".
        max_context_chars: Total budget; defaults to MAX_CONTEXT_CHARS.
        max_chars_per_chunk: Per-chunk cut-off; defaults to MAX_CHARS_PER_CHUNK.

    Returns:
        The joined context, or "" when nothing fits or nothing has text.
    """
    if max_context_chars is None:
        max_context_chars = MAX_CONTEXT_CHARS
    if max_chars_per_chunk is None:
        max_chars_per_chunk = MAX_CHARS_PER_CHUNK

    separator = "\n\n-----\n\n"
    pieces = []
    total_chars = 0

    for r in results:
        doc_id = r.get("doc_id", "unknown_doc")
        text = r.get("text", "") or ""

        if not text:
            continue

        # Truncate overly long chunks for safety
        if len(text) > max_chars_per_chunk:
            text = text[:max_chars_per_chunk] + " ... [truncated]"

        section = f"[DOC: {doc_id}]\n{text}"

        # The separator is only paid between sections, never before the first one.
        added = len(section) + (len(separator) if pieces else 0)
        if total_chars + added > max_context_chars:
            break

        pieces.append(section)
        total_chars += added

    return separator.join(pieces)


# ============================================================
# Answer synthesis using LLM
# ============================================================

def generate_answer(context: str, question: str, client: OpenAI, model: str = "gpt-5"):
    """
    Ask the chat model to answer `question` from `context` only, and log the raw response.

    Every call appends the prompt and the full JSON-serialised API response to the
    `llm_raw_responses` table in qa_logs.db.

    Args:
        context: Retrieval context (as produced by build_context).
        question: The user's question.
        client: OpenAI client used for the chat completion.
        model: Chat model name; defaults to "gpt-5", the previously fixed value.

    Returns:
        The stripped answer text, or "" when the response carries no text content.
    """
    from datetime import timezone

    prompt = f"""
You are an expert analyst answering questions about peace agreements, security arrangements, and related political texts.

You MUST:
- Use ONLY the information in the context below (no outside knowledge).
- Be as EXHAUSTIVE as possible given the context.
- Synthesize across multiple documents and chunks when needed.
- If the context gives only a partial answer, say clearly: "This answer is based only on the retrieved documents and may be incomplete."

CONTEXT:
{context}

QUESTION:
{question}

Answer in clear, structured bullet points (or short paragraphs if more natural).
"""

    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_completion_tokens=10000,
        # temperature is not passed: per the original note, gpt-5 does not accept a custom value.
    )

    raw_json = json.dumps(response.model_dump(), indent=2)

    # Same naive-UTC ISO string datetime.utcnow() produced (deprecated since Python 3.12).
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    # Store in SQLite (raw LLM response); close the connection even if the insert fails.
    conn = sqlite3.connect("qa_logs.db")
    try:
        c = conn.cursor()
        c.execute("""
        INSERT INTO llm_raw_responses (timestamp, question, prompt, raw_response)
        VALUES (?, ?, ?, ?)
    """, (
            timestamp,
            question,
            prompt,
            raw_json
        ))
        conn.commit()
    finally:
        conn.close()

    # message.content is optional in the API response (e.g. when no text is produced);
    # avoid crashing on None.strip().
    return (response.choices[0].message.content or "").strip()


# ============================================================
# Main QA function
# ============================================================

def answer_question(question: str, use_fusion: bool = False):
    """
    Answer `question` with retrieval-augmented generation over the Neo4j store.

    Without fusion, only vector search runs, and its top TOP_K_FUSED hits form the
    context. With `use_fusion`, vector, fulltext and KG hits are merged by
    reciprocal_rank_fusion. The fulltext and KG queries run only in fusion mode, so a
    fulltext failure can no longer break plain vector RAG.

    The Neo4j driver is closed as soon as retrieval finishes, even on error. The
    connection URI is read from NEO4J_URI, falling back to NEO4J_URL (the name
    ingest_vectors.py uses).

    Returns:
        The model's answer, or a fixed message when no context could be retrieved.
    """
    # Connect to Neo4j
    driver = GraphDatabase.driver(
        os.getenv("NEO4J_URI") or os.getenv("NEO4J_URL"),
        auth=(os.getenv("NEO4J_USER"), os.getenv("NEO4J_PASSWORD"))
    )

    # OpenAI client
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    try:
        # ------------------------------------------------------------
        # 1. Retrieval + 2. Fusion or vanilla RAG
        # ------------------------------------------------------------
        vec_results = get_vector_results(question, driver, client)

        if use_fusion:
            kw_results = get_fulltext_results(question, driver)
            kg_results = get_kg_results(question, driver)
            retrieval_used = reciprocal_rank_fusion([vec_results, kw_results, kg_results])
        else:
            # Just vector search; keep the top TOP_K_FUSED hits
            retrieval_used = vec_results[:TOP_K_FUSED]

        # ------------------------------------------------------------
        # 3. Build context string
        # ------------------------------------------------------------
        context = build_context(retrieval_used)
    finally:
        # Release Neo4j connections before the (slow) LLM call; callers may loop over
        # many questions.
        driver.close()

    # If somehow nothing was retrieved
    if not context.strip():
        return (
            "I could not retrieve any relevant context from the database for this question, "
            "so I cannot provide a grounded answer."
        )

    # ------------------------------------------------------------
    # 4. Generate LLM answer
    # ------------------------------------------------------------
    return generate_answer(context, question, client)


# ============================================================
# CLI ENTRYPOINT — with SQLite logging
# ============================================================

if __name__ == "__main__":
    # CLI: answer one question (--q) and/or every line of --questions-file, printing each
    # answer and logging it to QA_DB_PATH.
    import argparse
    import sys

    init_qa_db()

    parser = argparse.ArgumentParser()
    parser.add_argument("--q", help="Single question")
    parser.add_argument("--questions-file", help="File containing one question per line")
    parser.add_argument("--fusion", action="store_true", help="Enable RAG Fusion scoring")
    args = parser.parse_args()

    questions = []

    if args.q:
        questions.append(args.q)

    if args.questions_file:
        if os.path.exists(args.questions_file):
            with open(args.questions_file, "r", encoding="utf-8") as f:
                # One question per line; blank lines are ignored.
                questions.extend(line.strip() for line in f if line.strip())
        else:
            # Report a bad path instead of silently treating it as "no questions".
            print(f"Questions file not found: {args.questions_file}", file=sys.stderr)

    if not questions:
        print("No questions provided. Nothing to do.")
        sys.exit(0)

    for q in questions:
        print("\n====================================================")
        print(f"QUESTION: {q}")
        print("====================================================\n")

        ans = answer_question(q, use_fusion=args.fusion)

        print("ANSWER:\n")
        print(ans)
        print("\n----------------------------------------------------\n")

        # Log into DB
        log_qa_answer(q, ans, args.fusion)


## QA_retrieval_eval.py (Original)

In [None]:
import os
import csv
from typing import List

from dotenv import load_dotenv
from neo4j import GraphDatabase
from openai import OpenAI

# DeepEval imports
from deepeval.metrics import (
    ContextualPrecisionMetric,
    # ContextualRecallMetric,
    ContextualRelevancyMetric,
    AnswerRelevancyMetric,
    FaithfulnessMetric,
)
from deepeval.test_case import LLMTestCase
from deepeval import evaluate

# Import your existing QA pipeline
import QA_retrieval
from QA_retrieval import neo4j_query, reciprocal_rank_fusion

load_dotenv()

# Context caps: reuse QA_retrieval's values when that module defines them, so the
# evaluation clips text the same way; otherwise fall back to the same defaults.
MAX_CONTEXT_CHARS = getattr(QA_retrieval, "MAX_CONTEXT_CHARS", 24000)
MAX_CHARS_PER_CHUNK = getattr(QA_retrieval, "MAX_CHARS_PER_CHUNK", 2000)

def sanitize_lucene(q: str) -> str:
    """
    Replace Lucene query-syntax tokens in `q` with spaces.

    Tokens are processed in this fixed order, each occurrence becoming a single space:
    + - && || ! ( ) { } [ ] ^ " ~ * ? : \\ /
    Note that "&&" and "||" are only removed as pairs; a lone "&" or "|" is kept.
    """
    special_tokens = (
        "+", "-", "&&", "||", "!", "(", ")", "{", "}", "[", "]",
        "^", '"', "~", "*", "?", ":", "\\", "/",
    )
    cleaned = q
    for token in special_tokens:
        if token in cleaned:
            cleaned = cleaned.replace(token, " ")
    return cleaned

# ------------------------------------------------------------
# Context reconstruction that mirrors QA_retrieval logic
# ------------------------------------------------------------

def _clip_texts(results: List[dict], limit: int) -> List[str]:
    """Return the non-empty `text` of the first `limit` results, each cut to MAX_CHARS_PER_CHUNK."""
    clipped = []
    for r in results[:limit]:
        text = r.get("text") or ""
        if text:
            clipped.append(text[:MAX_CHARS_PER_CHUNK])
    return clipped


def _fit_context_budget(texts: List[str], budget: int, sep_len: int = 2) -> List[str]:
    """Trim `texts` so that, joined with a `sep_len`-char separator, they fit in `budget` chars.

    Same effect as truncating "\\n\\n".join(texts) to `budget`, but one entry per text is
    kept, and a text containing blank lines is never split apart.
    """
    kept: List[str] = []
    used = 0
    for text in texts:
        sep = sep_len if kept else 0
        room = budget - used - sep
        if room <= 0:
            break
        if len(text) > room:
            kept.append(text[:room])  # the last entry is cut to the remaining budget
            break
        kept.append(text)
        used += sep + len(text)
    return kept


def get_context_for_question(
    question: str,
    use_fusion: bool,
    driver,
    client: OpenAI,
) -> List[str]:
    """
    Rebuild the retrieval context for `question` so DeepEval can score it.

    Retrieval takes the top-5 vector hits, the top-5 fulltext hits (with the question
    passed through sanitize_lucene) and up to 5 KG document hits. With `use_fusion` these
    are merged by reciprocal_rank_fusion; otherwise only the vector hits are used. The
    first 4 non-empty texts are kept, each clipped to MAX_CHARS_PER_CHUNK. The list is
    then trimmed so that, joined with blank lines, it stays within MAX_CONTEXT_CHARS.

    NOTE(review): QA_retrieval.answer_question requests TOP_K_VECTOR / TOP_K_KEYWORD /
    TOP_K_KG hits and keeps up to TOP_K_FUSED chunks, so this context is a smaller sample
    than what the answering model saw — confirm that is acceptable for these metrics.

    Returns:
        One string per retrieved text, never split on internal blank lines; [] when
        nothing was retrieved.
    """

    # 1) Embed question – intended to use the same model as QA_retrieval
    embedding = client.embeddings.create(
        model="text-embedding-3-small",
        input=question
    ).data[0].embedding

    # 2) Vector search (same index as QA_retrieval, but top-5)
    vec_query = """
    CALL db.index.vector.queryNodes('chunk_embeddings', 5, $embedding)
    YIELD node, score
    RETURN node.doc_id AS doc_id, node.text AS text, score
    """

    vec_results = neo4j_query(driver, vec_query, {"embedding": embedding})

    # 3) Fulltext search (same index as QA_retrieval, but top-5)
    kw_query = """
    CALL db.index.fulltext.queryNodes('chunk_fulltext', $q)
    YIELD node, score
    RETURN node.doc_id AS doc_id, node.text AS text, score
    LIMIT 5
    """

    kw_results = neo4j_query(driver, kw_query, {"q": sanitize_lucene(question)})

    # 4) KG search (same pattern as QA_retrieval, but top-5)
    kg_query = """
    MATCH (e)-[:DERIVED_FROM]->(d:Document)
    WHERE e.name CONTAINS $q
    RETURN d.doc_id AS doc_id, d.text AS text, 1 AS score
    LIMIT 5
    """

    kg_results = neo4j_query(driver, kg_query, {"q": question})

    # 5) Fusion vs vanilla
    if use_fusion:
        fused = reciprocal_rank_fusion([vec_results, kw_results, kg_results])

        if fused and isinstance(fused[0], dict):
            # reciprocal_rank_fusion returns full result dicts
            ranked_results = fused
            print("Fused:", fused)
        else:
            # Fallback for an RRF variant returning doc_ids – map them back to rows
            all_results = vec_results + kw_results + kg_results
            result_map = {str(r["doc_id"]): r for r in all_results}
            ranked_results = [result_map[str(d)] for d in fused if str(d) in result_map]
    else:
        ranked_results = vec_results

    top_contexts = _clip_texts(ranked_results, 4)

    print("VEC:", len(vec_results))
    print("KW:", len(kw_results))
    print("KG:", len(kg_results))
    print("Matched contexts:", len(top_contexts))

    # Hard cap total context length without re-splitting the texts
    return _fit_context_budget(top_contexts, MAX_CONTEXT_CHARS)


# ------------------------------------------------------------
# Load eval dataset
# ------------------------------------------------------------

def load_eval_set(path: str):
    """
    Load evaluation rows from a CSV with `question` and `ground_truth` columns.

    Values are whitespace-stripped. Rows missing either value, whether as an empty
    cell or as a short row with too few fields, are skipped. A leading UTF-8
    byte-order mark (as written by Excel) is tolerated, so the header still matches.

    Returns:
        A list of {"question": str, "ground_truth": str} dicts, in file order.
    """
    rows = []
    # utf-8-sig strips a BOM if present; newline="" is what the csv module expects so
    # quoted fields containing newlines are read correctly.
    with open(path, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            # DictReader fills missing trailing fields with None, so guard before .strip().
            q = (row.get("question") or "").strip()
            gt = (row.get("ground_truth") or "").strip()
            if q and gt:
                rows.append({"question": q, "ground_truth": gt})
    return rows


# ------------------------------------------------------------
# DeepEval metrics
# ------------------------------------------------------------

def build_metrics():
    """
    Return fresh DeepEval metric instances, in reporting order.

    ContextualRecallMetric is currently disabled (its import is commented out above).
    """
    return [
        ContextualPrecisionMetric(),
        ContextualRelevancyMetric(),
        AnswerRelevancyMetric(),
        FaithfulnessMetric(),
    ]


# ------------------------------------------------------------
# Main eval runner
# ------------------------------------------------------------

def _print_metric_scores(eval_result) -> None:
    """Print one ``- <name>: <score>`` line per metric in a DeepEval result.

    In current DeepEval releases ``evaluate()`` returns an object whose
    ``test_results`` list holds per-test-case results, each carrying a
    ``metrics_data`` list whose items expose ``name`` and ``score``. Older
    releases returned the list of test results directly, and an object that
    carries ``metrics_data`` itself is handled as a single test result.
    """
    test_results = getattr(eval_result, "test_results", None)
    if test_results is None:
        test_results = eval_result if isinstance(eval_result, list) else [eval_result]

    for test_result in test_results:
        for metric_data in getattr(test_result, "metrics_data", None) or []:
            name = getattr(metric_data, "name", None) or getattr(
                metric_data, "metric_name", "unknown metric"
            )
            score = getattr(metric_data, "score", None)
            if score is None:
                # A metric that failed has no score; surface its error instead
                # of crashing on the float format below.
                error = getattr(metric_data, "error", None)
                print(f"- {name}: n/a ({error})" if error else f"- {name}: n/a")
            else:
                print(f"- {name}: {score:.3f}")


def main():
    """Evaluate the RAG pipeline question by question with DeepEval.

    For every row of the eval CSV:
      1. generate an answer with ``QA_retrieval.answer_question``;
      2. fetch the retrieval context with ``get_context_for_question``;
      3. append the raw result to ``eval_raw_outputs.csv`` right away;
      4. score the question with the metrics from ``build_metrics`` and print
         the per-metric scores.
    A failure in any step is reported and the loop continues with the next
    question. The Neo4j driver is closed when the run ends, even on error.
    """
    import argparse
    import traceback
    import csv

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--eval-file",
        default="rag_eval_set.csv",
        help="CSV with columns: question,ground_truth"
    )
    parser.add_argument(
        "--fusion",
        action="store_true",
        help="Use the same RAG fusion mode as QA_retrieval (vector+fulltext+KG).",
    )
    args = parser.parse_args()

    eval_rows = load_eval_set(args.eval_file)
    if not eval_rows:
        print(f"No rows found in {args.eval_file}. Exiting.")
        return

    # Setup
    driver = GraphDatabase.driver(
        os.getenv("NEO4J_URI"),
        auth=(os.getenv("NEO4J_USER"), os.getenv("NEO4J_PASSWORD")),
    )
    try:
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        metrics = build_metrics()

        # Start a fresh raw-output log; one row is appended per question below
        # so partial results survive a crash mid-run.
        raw_output_path = "eval_raw_outputs.csv"
        with open(raw_output_path, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(["question", "ground_truth", "model_answer", "error"])

        total = len(eval_rows)
        print(f"Running evaluation on {total} questions.")
        print(f"Fusion mode: {'ON' if args.fusion else 'OFF'}\n")

        # Evaluate each question *individually*
        for i, row in enumerate(eval_rows, start=1):
            q = row["question"]
            gt = row["ground_truth"]

            print("\n============================")
            print(f"[{i}/{total}] QUESTION")
            print(q)

            answer = None
            answer_failed = False
            error_msg = None
            ctx_chunks = []

            # ---- Step 1 — safe answer generation ----
            try:
                answer = QA_retrieval.answer_question(q, use_fusion=args.fusion)
            except Exception as e:
                error_msg = f"Answer error: {repr(e)}"
                print("\n❌ ERROR producing answer:")
                print(error_msg)
                answer = "ERROR"
                answer_failed = True

            # ---- Step 2 — safe context retrieval (only if we have an answer) ----
            if not answer_failed:
                try:
                    ctx_chunks = get_context_for_question(
                        q, use_fusion=args.fusion, driver=driver, client=client
                    )
                except Exception as e:
                    err = f"Context error: {repr(e)}"
                    print("\n⚠️ ERROR retrieving context (skipping context):")
                    print(err)
                    error_msg = error_msg or err
                    ctx_chunks = []

            # ---- Save raw output immediately ----
            with open(raw_output_path, "a", newline="", encoding="utf-8") as f:
                csv.writer(f).writerow([q, gt, answer, error_msg or ""])

            # Scoring the "ERROR" placeholder would only produce meaningless
            # numbers and waste evaluation calls, so skip metrics here.
            if answer_failed:
                print("\nSkipping metrics: no answer was produced.")
                continue

            # ---- Step 3 — run metrics *for this single question* ----
            try:
                tc = LLMTestCase(
                    input=q,
                    actual_output=answer,
                    expected_output=gt,
                    retrieval_context=ctx_chunks,
                )
                per_q_results = evaluate(
                    test_cases=[tc],
                    metrics=metrics,
                )
                print("\n📊 METRICS:")
                _print_metric_scores(per_q_results)
            except Exception:
                print("\n❌ ERROR running DeepEval metrics:")
                print(traceback.format_exc())

        print("\n============================")
        print("Evaluation complete.")
        print(f"Raw outputs saved to: {raw_output_path}")
    finally:
        driver.close()


if __name__ == "__main__":
    main()


## extract_kg.py (Original)

In [None]:
import json
import asyncio
import os
import yaml
import sqlite3
from datetime import datetime
from pathlib import Path
from openai import AsyncOpenAI
from tqdm import tqdm
from neo4j import GraphDatabase
from dotenv import load_dotenv

# Load environment variables (by default from a local .env file) so that the
# os.getenv(...) defaults of the command-line options can pick them up.
load_dotenv()

# SQLite file used by init_db()/log_response() to record every LLM response.
DB_PATH = "kg_raw_responses.db"


def init_db(db_path: "str | None" = None):
    """Create the ``responses`` logging table if needed and migrate old copies.

    Safe to call repeatedly: the table is only created when missing, and the
    ``raw_response`` column is added to tables created before it existed.

    Args:
        db_path: SQLite database file to initialise; defaults to ``DB_PATH``.
    """
    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    try:
        cur = conn.cursor()
        # Create table if it doesn't exist
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS responses (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT,
                doc_id TEXT,
                status TEXT,          -- 'success', 'empty', 'parse_error', 'api_error'
                raw_response TEXT,    -- full JSON response from the API
                content TEXT,         -- message.content
                parsed_json TEXT,     -- parsed JSON (entities/relationships)
                error TEXT            -- error message / finish_reason notes
            )
            """
        )
        # Ensure raw_response column exists even if table was created earlier without it
        cur.execute("PRAGMA table_info(responses)")
        cols = [row[1] for row in cur.fetchall()]
        if "raw_response" not in cols:
            cur.execute("ALTER TABLE responses ADD COLUMN raw_response TEXT")
        conn.commit()
    finally:
        # Close even when a statement fails so the database handle is not leaked.
        conn.close()


def log_response(
    doc_id: str,
    status: str,
    raw_response: str = "",
    content: str = "",
    parsed_json: str = "",
    error: str = "",
    db_path: "str | None" = None,
):
    """Append one row describing an LLM call to the ``responses`` table.

    Args:
        doc_id: Document the call was made for.
        status: Outcome label; callers in this file use 'success', 'empty',
            'parse_error' and 'api_error'.
        raw_response: Serialized API response.
        content: Message content returned by the model.
        parsed_json: Re-serialized parsed JSON, if parsing succeeded.
        error: Error text or finish_reason notes.
        db_path: SQLite database file to write to; defaults to ``DB_PATH``.
    """
    from datetime import timezone

    # Naive ISO-8601 UTC timestamp: the same format datetime.utcnow() gave,
    # without that method's deprecation (Python 3.12+).
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    try:
        conn.execute(
            """
            INSERT INTO responses (timestamp, doc_id, status, raw_response, content, parsed_json, error)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            """,
            (
                timestamp,
                doc_id,
                status,
                raw_response,
                content,
                parsed_json,
                error,
            ),
        )
        conn.commit()
    finally:
        # Close even if the INSERT fails so the database handle is not leaked.
        conn.close()


def load_schema(schema_path: str) -> dict:
    """Load and parse the schema YAML file.

    Args:
        schema_path: Path to a YAML file, expected to hold optional top-level
            ``entities`` and ``relationships`` lists.

    Returns:
        The parsed document, or an empty dict when the file is empty.
    """
    # Read as UTF-8 explicitly instead of the platform default encoding, so
    # non-ASCII descriptions load the same on every OS.
    with open(schema_path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document; normalise to {} so
        # callers can rely on membership tests such as `"entities" in schema`.
        return yaml.safe_load(f) or {}


def build_schema_prompt(schema: dict) -> str:
    """Render the schema's entity and relationship types as prompt text.

    Each of the optional ``entities`` / ``relationships`` keys holds a list of
    single- or multi-key mappings; every ``name: description`` pair becomes a
    ``- name: description`` bullet under its section header. Sections for
    absent keys are omitted, and a blank line separates the two sections.
    """
    section_headers = (
        ("entities", "Extract these entity types:\n"),
        ("relationships", "\nExtract these relationship types:\n"),
    )

    sections = []
    for key, header in section_headers:
        if key not in schema:
            continue
        bullets = [
            f"- {type_name}: {spec}"
            for mapping in schema[key]
            for type_name, spec in mapping.items()
        ]
        sections.append(header + "\n".join(bullets))

    return "\n".join(sections)


def serialize_response(response) -> str:
    """
    Best-effort conversion of the OpenAI response object into a JSON string.

    Tries, in order, ``model_dump()``, ``to_dict()`` and a ``json()`` method
    that returns a JSON string, to cope with different client versions. If
    none of them works, it falls back to ``str(response)``.

    This is used only for logging, so it must not raise on exotic values:
    anything ``json`` cannot encode natively is stringified rather than
    causing the whole (otherwise successful) call to be treated as failed.
    """
    try:
        obj = response.model_dump()
    except AttributeError:
        try:
            obj = response.to_dict()
        except AttributeError:
            try:
                # Some clients have .json() returning a string
                text = response.json()
                # Only keep it if it really is a JSON string
                json.loads(text)
                return text
            except Exception:
                obj = str(response)
    if isinstance(obj, str):
        return obj
    return json.dumps(obj, ensure_ascii=False, default=str)


def _cypher_quote(identifier: str) -> str:
    """Backtick-quote a label or relationship type for use inside Cypher text.

    Labels and relationship types cannot be passed as query parameters, so
    they must be spliced into the query string. Quoting them (with embedded
    backticks doubled) keeps LLM-produced values such as "Legal Act" or
    "part-of" from breaking or altering the query.
    """
    return "`" + identifier.replace("`", "``") + "`"


def _write_entities(session, entities, doc_id) -> None:
    """MERGE extracted entities by (label, name) and link each to its Document.

    Items that are not dicts or have no ``name`` are skipped; a missing or
    empty ``type`` falls back to the ``Entity`` label.
    """
    for ent in entities:
        if not isinstance(ent, dict):
            continue

        name = ent.get("name")
        if not name:
            continue

        entity_type = str(ent.get("type") or "Entity")
        raw_props = ent.get("properties")
        # Copy so the parsed payload is not mutated; ignore non-map properties.
        properties = dict(raw_props) if isinstance(raw_props, dict) else {}
        properties["name"] = name

        # `SET e += $props` takes the whole map as one parameter, so property
        # keys are never interpolated into the query and cannot collide with
        # the other parameter names (doc_id, name).
        session.run(
            f"MERGE (e:{_cypher_quote(entity_type)} {{name: $name}}) "
            "SET e += $props "
            "WITH e MATCH (d:Document {doc_id: $doc_id}) "
            "MERGE (e)-[:DERIVED_FROM]->(d)",
            {"name": name, "props": properties, "doc_id": doc_id},
        )


def _write_relationships(session, relationships) -> None:
    """MERGE extracted relationships between nodes matched on ``name``.

    Items that are not dicts or lack ``source``/``target`` are skipped. The
    type is upper-cased with spaces turned into underscores, and falls back
    to ``RELATED_TO`` when missing or empty.
    """
    for rel in relationships:
        if not isinstance(rel, dict):
            continue

        src = rel.get("source")
        tgt = rel.get("target")
        if not src or not tgt:
            continue

        rel_type = str(rel.get("type") or "RELATED_TO").upper().replace(" ", "_")
        raw_props = rel.get("properties")
        properties = raw_props if isinstance(raw_props, dict) else {}

        # NOTE(review): the MATCH is label-less, so it scans all nodes and
        # links every pair of nodes that share these names — confirm intended.
        # `SET r += {}` is a no-op, so an empty property map is harmless.
        session.run(
            "MATCH (s {name: $src}), (t {name: $tgt}) "
            f"MERGE (s)-[r:{_cypher_quote(rel_type)}]->(t) "
            "SET r += $props",
            {"src": src, "tgt": tgt, "props": properties},
        )


async def extract_kg(args):
    """Extract a knowledge graph from stored document nodes in Neo4j.

    Every ``(:Document)`` node whose ``doc_id`` is not yet recorded in the
    progress file is sent to the LLM together with the schema prompt. The
    returned JSON entities/relationships are MERGEd into Neo4j, every response
    is logged through ``log_response``, and the progress file is rewritten
    after each successfully processed document so runs can be resumed.

    Args:
        args: Parsed CLI namespace providing ``llm_base_url``, ``llm_api_key``,
            ``neo4j_url``, ``neo4j_user``, ``neo4j_password``, ``schema_path``,
            ``progress_file``, ``concurrent_requests`` and ``model``.

    Raises:
        ValueError: if ``args.concurrent_requests`` is less than 1 (a
            semaphore of 0 would make every task wait forever).
    """
    if args.concurrent_requests < 1:
        raise ValueError("concurrent_requests must be at least 1")

    # Load schema and resume state before opening any clients, so a failure
    # here leaves nothing to clean up.
    schema = load_schema(args.schema_path)
    schema_prompt = build_schema_prompt(schema)

    # Progress tracking
    progress_file = Path(args.progress_file)
    completed_docs = set()
    if progress_file.exists():
        with open(progress_file, "r") as f:
            progress_data = json.load(f)
            completed_docs = set(progress_data.get("completed_docs", []))
        print(f"Resuming: {len(completed_docs)} documents already processed")

    client = AsyncOpenAI(
        base_url=args.llm_base_url,
        api_key=args.llm_api_key,
    )
    driver = GraphDatabase.driver(
        args.neo4j_url,
        auth=(args.neo4j_user, args.neo4j_password),
    )

    system_prompt = (
        "You are an information extraction model. "
        "Return only valid JSON with two arrays: 'entities' and 'relationships'. "
        "Do not include any explanations or reasoning text."
    )

    try:
        # Load documents from Neo4j
        with driver.session() as session:
            docs = session.run(
                "MATCH (d:Document) RETURN d.doc_id AS id, d.text AS text"
            ).data()

        docs_to_process = [d for d in docs if d["id"] not in completed_docs]
        print(f"Extracting KG from {len(docs_to_process)} remaining documents")

        semaphore = asyncio.Semaphore(args.concurrent_requests)

        async def process_doc(doc):
            """Process a single document with the LLM and write results to Neo4j."""
            doc_id = doc["id"]
            text = doc["text"] or ""

            user_prompt = f"""Extract entities and relationships from this document according to the following schema:

{schema_prompt}

Return JSON in this exact format:
{{"entities":[{{"type":"","name":"","properties":{{}}}}], "relationships":[{{"source":"","target":"","type":"","properties":{{}}}}]}}

Document text:
{text[:10000]}"""  # truncate if too long

            try:
                response = await client.chat.completions.create(
                    model=args.model,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt},
                    ],
                    # NOTE(review): unusual budget (20048) — confirm intended.
                    max_completion_tokens=20048,
                    response_format={"type": "json_object"},
                    # extra_body={"service_tier": "flex"},
                )

                raw_response_str = serialize_response(response)
                choice = response.choices[0]
                raw_content = choice.message.content or ""
                finish_reason = choice.finish_reason

                # No visible content (e.g. the token budget was exhausted):
                # log it and move on to the next document.
                if not raw_content.strip():
                    log_response(
                        doc_id,
                        status="empty",
                        raw_response=raw_response_str,
                        content=raw_content,
                        parsed_json="",
                        error=f"finish_reason={finish_reason}, no content returned",
                    )
                    return

                try:
                    parsed = json.loads(raw_content)
                except ValueError as e:
                    log_response(
                        doc_id,
                        status="parse_error",
                        raw_response=raw_response_str,
                        content=raw_content,
                        parsed_json="",
                        error=f"{type(e).__name__}: {e}",
                    )
                    return

                # Valid JSON that is not an object (e.g. a bare list) cannot
                # carry the expected 'entities'/'relationships' keys.
                if not isinstance(parsed, dict):
                    log_response(
                        doc_id,
                        status="parse_error",
                        raw_response=raw_response_str,
                        content=raw_content,
                        parsed_json="",
                        error=f"expected a JSON object, got {type(parsed).__name__}",
                    )
                    return

                log_response(
                    doc_id,
                    status="success",
                    raw_response=raw_response_str,
                    content=raw_content,
                    parsed_json=json.dumps(parsed, ensure_ascii=False),
                    error=f"finish_reason={finish_reason}",
                )

                # NOTE(review): the Neo4j driver and sqlite logging here are
                # synchronous, so they block the event loop while they run;
                # only the LLM calls actually overlap.
                with driver.session() as session:
                    _write_entities(session, parsed.get("entities") or [], doc_id)
                    _write_relationships(session, parsed.get("relationships") or [])

                # Mark doc as completed
                completed_docs.add(doc_id)
                with open(progress_file, "w") as f:
                    json.dump({"completed_docs": list(completed_docs)}, f)

            except Exception as e:
                # Catch API / connection / database errors for this document
                log_response(
                    doc_id,
                    status="api_error",
                    raw_response="",
                    content="",
                    parsed_json="",
                    error=f"{type(e).__name__}: {e}",
                )
                print(f"⚠️ Error processing {doc_id}: {e}")

        async def process_with_semaphore(doc):
            async with semaphore:
                await process_doc(doc)

        tasks = [process_with_semaphore(d) for d in docs_to_process]
        for coro in tqdm(
            asyncio.as_completed(tasks), total=len(tasks), desc="KG extraction"
        ):
            await coro
    finally:
        driver.close()
        await client.close()

    print("✅ KG extraction complete")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--neo4j-url", default=os.getenv("NEO4J_URI"))
    parser.add_argument("--neo4j-user", default=os.getenv("NEO4J_USER"))
    parser.add_argument("--neo4j-password", default=os.getenv("NEO4J_PASSWORD"))
    parser.add_argument(
        "--llm-base-url",
        default=os.getenv("LLM_BASE_URL", "https://api.openai.com/v1"),
    )
    parser.add_argument("--llm-api-key", default=os.getenv("OPENAI_API_KEY"))
    parser.add_argument("--model", default=os.getenv("LLM_MODEL", "gpt-5-mini"))
    parser.add_argument("--schema-path", default="schema.yaml")
    parser.add_argument("--progress-file", default="kg_extraction_progress.json")
    parser.add_argument("--concurrent-requests", type=int, default=5)
    args = parser.parse_args()

    # An asyncio.Semaphore of 0 would make every task wait forever (and a
    # negative one raises), so reject such values with a CLI error instead.
    if args.concurrent_requests < 1:
        parser.error("--concurrent-requests must be at least 1")

    # Create the log database only after the CLI is validated, so `--help`
    # or invalid arguments leave no database file behind.
    init_db()
    asyncio.run(extract_kg(args))
