<a href="https://colab.research.google.com/github/lmkmichelle/long-doc-extraction/blob/main/dpr_256.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git remote add origin https://github.com/lmkmichelle/long-doc-extraction.git

In [None]:
!pip install --upgrade huggingface_hub transformers datasets faiss-cpu

In [None]:
from google.colab import drive
from google.colab import userdata
import numpy as np
import faiss
from datasets import Dataset
from transformers import (
    DPRContextEncoder, DPRContextEncoderTokenizer,
    DPRQuestionEncoder, DPRQuestionEncoderTokenizer,
    AutoModelForCausalLM, AutoTokenizer
)
import torch
import re

In [None]:
drive.mount('/content/drive')

In [None]:
!huggingface-cli login

In [None]:
# Inference only: turn autograd off globally so no computation graphs are built.
torch.set_grad_enabled(False)

# DPR bi-encoder checkpoints: the context encoder embeds paragraphs for the
# index, the question encoder embeds queries that are searched against it.
CTX_CHECKPOINT = "facebook/dpr-ctx_encoder-single-nq-base"
Q_CHECKPOINT = "facebook/dpr-question_encoder-single-nq-base"

ctx_encoder = DPRContextEncoder.from_pretrained(CTX_CHECKPOINT)
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(CTX_CHECKPOINT)

q_encoder = DPRQuestionEncoder.from_pretrained(Q_CHECKPOINT)
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(Q_CHECKPOINT)

In [None]:
def read_file(file_path):
    """Return the whole contents of a UTF-8 encoded text file as one string.

    Args:
        file_path: Path of the file to read.

    Returns:
        The decoded file contents.
    """
    with open(file_path, encoding="utf-8") as handle:
        contents = handle.read()
    return contents

def clean_text(text):
    """Split ``text`` into non-empty paragraphs, each prefixed with its 1-based number.

    Paragraphs are separated by a double newline. Chunks that are empty or
    whitespace-only (e.g. produced by runs of three or more newlines) are
    dropped *before* numbering. The ``"k. "`` prefix of the k-th returned
    paragraph is therefore always ``k``. This keeps the number shown to the
    LLMs equal to the 1-based ``paragraph_id`` / list position used for
    retrieval and ground-truth lookup.

    Args:
        text: Raw document text.

    Returns:
        List of strings of the form ``"<k>. <stripped paragraph>"``.
    """
    stripped_chunks = (chunk.strip() for chunk in text.split("\n\n"))
    paragraphs = [chunk for chunk in stripped_chunks if chunk]
    # Number only the surviving paragraphs so the prefixes have no gaps.
    return [f"{number}. {paragraph}" for number, paragraph in enumerate(paragraphs, start=1)]

def embed_text_batch(batch, max_length=512):
    """Embed a batch of paragraphs with the DPR context encoder.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        batch: Mapping whose ``"text"`` entry is an iterable of paragraph strings.
        max_length: Token limit per paragraph; longer inputs are truncated.

    Returns:
        ``{"embeddings": ...}``: one list of floats per input paragraph, taken
        from the encoder's ``pooler_output``.
    """
    encoded_inputs = ctx_tokenizer(
        list(batch["text"]),
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        pooled = ctx_encoder(**encoded_inputs).pooler_output
    return {"embeddings": pooled.numpy().tolist()}


In [None]:
# Load the raw text and split it into numbered paragraphs.
# The path is relative, so it resolves against the current working directory.
file_path = "pandp.txt"
text = read_file(file_path)
cleaned_paragraphs = clean_text(text)
# Fail here, not deep inside index building, if the file yielded nothing.
assert cleaned_paragraphs, f"no paragraphs parsed from {file_path}"

In [None]:
# Preview only: the full list holds every paragraph of the input file.
len(cleaned_paragraphs), cleaned_paragraphs[:5]

In [None]:
# Embed every paragraph with the DPR context encoder, then build a FAISS index.
# DPR scores question/passage pairs by inner product, so the index uses
# METRIC_INNER_PRODUCT instead of the default L2 distance; with L2 the
# ranking would also depend on embedding norms.
batch_size = 20
ds = Dataset.from_dict({
    "text": cleaned_paragraphs,
    # 1-based ids: paragraph_id k corresponds to cleaned_paragraphs[k - 1].
    "paragraph_id": list(range(1, len(cleaned_paragraphs) + 1)),
})
ds = ds.map(embed_text_batch, batched=True, batch_size=batch_size).add_faiss_index(
    column="embeddings", metric_type=faiss.METRIC_INNER_PRODUCT
)

In [None]:
def search(query, k=3):
    """Retrieve the ``k`` paragraphs whose DPR embeddings best match ``query``.

    Args:
        query: Natural-language question.
        k: Number of paragraphs to return.

    Returns:
        List of ``(paragraph_id, text)`` tuples in the order returned by the
        FAISS index on ``ds``.
    """
    # Truncate so an over-long query cannot exceed the encoder's 512-token limit.
    encoded = q_tokenizer(query, truncation=True, max_length=512, return_tensors="pt")
    # Single query -> take row 0 of the pooled question embedding.
    question_embedding = q_encoder(**encoded).pooler_output[0].numpy()
    scores, retrieved = ds.get_nearest_examples("embeddings", question_embedding, k=k)
    return [
        (retrieved["paragraph_id"][i], retrieved["text"][i])
        for i in range(len(scores))
    ]

In [None]:
# Sample query: retrieve the top-k paragraphs for one question and print them.
query = "Who is tolerable?"
results = search(query)

# Loop variable is `para_text`, not `text`, so it does not overwrite the
# global `text` that holds the raw file contents.
for para_id, para_text in results:
    print(f"Paragraph {para_id}: {para_text}\n")

In [None]:
!pip install openai google-generativeai

In [None]:
def compute_coverage(retrieved_paragraphs, ground_truth_paragraphs):
    """Fraction of ground-truth paragraph ids that appear in the retrieved ids (recall).

    Duplicates are ignored on both sides. Returns 0.0 when there is no
    ground truth, instead of dividing by zero.
    """
    gold = set(ground_truth_paragraphs)
    if not gold:
        return 0.0
    return len(set(retrieved_paragraphs) & gold) / len(gold)


def compute_citation(retrieved_paragraphs, ground_truth_paragraphs):
    """Fraction of retrieved paragraph ids that are in the ground truth (precision).

    Duplicates are ignored on both sides, so repeating an id neither inflates
    nor dilutes the score. Returns 0.0 when nothing was retrieved.
    """
    retrieved = set(retrieved_paragraphs)
    if not retrieved:
        return 0.0
    return len(retrieved & set(ground_truth_paragraphs)) / len(retrieved)

In [None]:
# Hand-labelled ground truth: query -> 1-based paragraph ids that answer it.
# TODO: add more queries.
ground_truth_paragraphs = {"Who is tolerable?": [242, 244]}

retrieved_paragraphs = [pid for pid, _unused_text in results]
gold_ids = ground_truth_paragraphs[query]
coverage_score = compute_coverage(retrieved_paragraphs, gold_ids)
citation_score = compute_citation(retrieved_paragraphs, gold_ids)

print(f"Coverage Score: {coverage_score:.2f}")
print(f"Citation Score: {citation_score:.2f}")


In [None]:
from openai import OpenAI
from google import genai

# API clients; keys come from Colab's secret store (userdata), never hardcoded.
openai_client = OpenAI(api_key=userdata.get('OPENAI_API_KEY'))
google_client = genai.Client(api_key=userdata.get('GEMINI_API'))

# The question every model/context combination below is asked.
query = "Who is tolerable?"
def get_context(paragraphs):
    """Join paragraph strings into a single context, separated by blank lines."""
    separator = "\n\n"
    return separator.join(paragraphs)

# The whole document is too long to send as context, so the "full" baseline
# only sees the first N paragraphs.
N = 300
truncated_full_context = get_context(cleaned_paragraphs[:N])

# Context made of the paragraphs DPR retrieved for the query.
dpr_context = get_context([para for _pid, para in results])
# Oracle context: exactly the hand-labelled paragraphs (ids are 1-based).
oracle_context = get_context(
    [cleaned_paragraphs[pid - 1] for pid in ground_truth_paragraphs[query]]
)
full_context = get_context(cleaned_paragraphs)


def generate_gpt4o_answer(query, context):
    """Ask gpt-4o-mini to answer ``query`` from ``context``, citing paragraph numbers.

    Args:
        query: Question text.
        context: Paragraph text the model should ground its answer in.

    Returns:
        The message content of the first choice in the chat completion.
    """
    system_prompt = "You are a helpful assistant. Use the context to answer the question and Always cite the paragraph number(s) like this: 36 or 175 in your answer."
    user_prompt = f"Context:\n{context}\n\nQuestion: {query}"
    completion = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.3,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

def generate_gemini_answer(query, context):
    """Ask gemini-1.5-flash to answer ``query`` from ``context``, citing paragraph numbers.

    Args:
        query: Question text.
        context: Paragraph text the model should ground its answer in.

    Returns:
        The ``.text`` of the ``google_client.models.generate_content`` response.
    """
    prompt = f"Context:\n{context}\n\nQuestion: {query}\nAnswer clearly and always include paragraph numbers (e.g. 36 or 157) if relevant."
    # Single request through the google-genai client created earlier; there is
    # no separate `model` object in this notebook.
    response = google_client.models.generate_content(
        model="gemini-1.5-flash",
        contents=prompt,
    )
    return response.text

# Context variant per evaluation mode; "full" uses the truncated context.
contexts_by_mode = {
    "dpr": dpr_context,
    "oracle": oracle_context,
    "full": truncated_full_context,
}

# One answer per mode for each model. Requests are issued in the order
# GPT dpr/oracle/full, then Gemini dpr/oracle/full.
gpt4o_answers = {mode: generate_gpt4o_answer(query, ctx) for mode, ctx in contexts_by_mode.items()}
gemini_answers = {mode: generate_gemini_answer(query, ctx) for mode, ctx in contexts_by_mode.items()}



In [None]:
# Print each model's answer for every context mode, GPT first, then Gemini.
for label, answer_map, mode_key in (
    ("GPT-4o-mini FULL:\n", gpt4o_answers, "full"),
    ("GPT-4o-mini DPR:\n", gpt4o_answers, "dpr"),
    ("GPT-4o-mini Oracle:\n", gpt4o_answers, "oracle"),
    ("\nGemini-1.5-flash FULL:\n", gemini_answers, "full"),
    ("Gemini-1.5-flash DPR:\n", gemini_answers, "dpr"),
    ("Gemini-1.5-flash ORACLE:\n", gemini_answers, "oracle"),
):
    print(label, answer_map[mode_key])

In [None]:
# import re

# def extract_cited_paragraphs(answer_text):
#     matches = re.findall(r"(?:paragraph|para)?\s*\(?(\d{1,4})\)?", answer_text.lower())
#     return list(set(map(int, matches)))

# def evaluate_model_output(answer_text, ground_truth_ids):
#     cited = extract_cited_paragraphs(answer_text)
#     if not cited:
#         return 0.0, 0.0
#     coverage = compute_coverage(cited, ground_truth_ids)
#     citation = compute_citation(cited, ground_truth_ids)
#     return coverage, citation

# for model_name, outputs in [("GPT-4o", gpt4o_answers), ("Gemini", gemini_answers)]:
#     print(f"--- {model_name} ---")
#     for mode in ["dpr", "oracle", "full"]:
#         cov, cit = evaluate_model_output(outputs[mode], ground_truth_paragraphs[query])
#         print(f"{mode.upper()} → Coverage: {cov:.2f}, Citation: {cit:.2f}")
