# Document Retrieval

In [23]:
import os
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import joblib
from rank_bm25 import BM25Okapi
import nltk
from nltk.tokenize import word_tokenize
import json
from openai import OpenAI
import re

## Parameters

In [8]:
# change this accordingly
project_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

dataset_dir = f"{project_path}/scicite_preprocessed"
results_dir = f"{project_path}/results"

dataset = "selected-features"

## 1. Load database

In [3]:
background_df = pd.read_json(f"{dataset_dir}/train_background.jsonl", lines=True)
method_df = pd.read_json(f"{dataset_dir}/train_method.jsonl", lines=True)
result_df = pd.read_json(f"{dataset_dir}/train_result.jsonl", lines=True)

## 2. Create BM25, Semantic and Hybrid Retrievers

In [4]:
class SemanticRetriever:
    def __init__(self, documents, paper_ids, model_name="allenai/scibert_scivocab_uncased", device="cuda" if torch.cuda.is_available() else "cpu"):
        """
        A Semantic Retriever using SciBERT embeddings for short documents (citations).
        Stores document-paper_id pairs for retrieval.
        """
        self.device = device

        # Load SciBERT tokenizer & model
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device).eval()

        # Store (document, paper_id) pairs
        self.documents = documents
        self.paper_ids = paper_ids
        self.embeddings = self.embed_documents(documents)

    def embed_text(self, text, max_length=512):
        """ Converts text into SciBERT embeddings. """
        tokens = self.tokenizer(text, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**tokens)
        return outputs.last_hidden_state[:, 0, :].squeeze(0).cpu().numpy()  # [CLS] token

    def embed_documents(self, documents):
        """ Embeds all short documents using SciBERT. """
        return np.array([self.embed_text(doc) for doc in documents])

    def retrieve(self, query, top_k=3):
        """ Retrieves the most relevant documents along with their paper IDs. """
        query_embedding = self.embed_text(query).reshape(1, -1)
        similarities = cosine_similarity(query_embedding, self.embeddings).flatten()
        top_indices = similarities.argsort()[-top_k:][::-1]  # Get top-k sorted indices
        return [(self.documents[i], self.paper_ids[i]) for i in top_indices]

class BM25Retriever:
    def __init__(self, documents, paper_ids):
        """
        A BM25 Retriever using term frequency and inverse document frequency scores.
        Stores document-paper_id pairs for retrieval.
        """
        self.documents = documents
        self.paper_ids = paper_ids

        self.tokenized_docs = [word_tokenize(doc.lower()) for doc in self.documents]
        self.bm25 = BM25Okapi(self.tokenized_docs)

    def retrieve(self, query, top_k=3):
        
        query_tokens = word_tokenize(query.lower())
        scores = self.bm25.get_scores(query_tokens)
        top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]

        return [(self.documents[i], self.paper_ids[i]) for i in top_indices]

class HybridRetriever:
    def __init__(self, documents, paper_ids, bm25_weight=0.5, semantic_weight=0.5, model_name="allenai/scibert_scivocab_uncased", device="cuda" if torch.cuda.is_available() else "cpu"):
        """
        HybridRetriever: Combines BM25 and Semantic Retrieval.
        Uses a weighted combination of both scores to retrieve documents.
        """
        self.bm25_weight = bm25_weight
        self.semantic_weight = semantic_weight

        self.documents = documents
        self.paper_ids = paper_ids
        
        # BM25
        self.tokenized_docs = [word_tokenize(doc.lower()) for doc in self.documents]
        self.bm25 = BM25Okapi(self.tokenized_docs)

        # Semantic
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device).eval()
        self.embeddings = self.embed_documents(documents)
    
    def embed_text(self, text, max_length=512):
        """ Converts text into SciBERT embeddings. """
        tokens = self.tokenizer(text, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**tokens)
        return outputs.last_hidden_state[:, 0, :].squeeze(0).cpu().numpy()  # [CLS] token

    def embed_documents(self, documents):
        """ Embeds all short documents using SciBERT. """
        return np.array([self.embed_text(doc) for doc in documents])
    
    def retrieve(self, query, top_k=3):
        """
        Retrieves the top_k most relevant documents using a combination of BM25 and semantic scores.
        """
        # BM25
        query_tokens = word_tokenize(query.lower())
        bm25_scores = self.bm25.get_scores(query_tokens)

        # Semantic
        query_embedding = self.embed_text(query).reshape(1, -1)
        semantic_scores = cosine_similarity(query_embedding, self.embeddings).flatten()
        
        # Weighted combination
        combined_scores = (self.bm25_weight * bm25_scores) + (self.semantic_weight * semantic_scores)

        # Retrieve top-k sentences
        top_indices = sorted(range(len(combined_scores)), key=lambda i: combined_scores[i], reverse=True)[:top_k]
        
        return [(self.documents[i], self.paper_ids[i]) for i in top_indices]

In [5]:
# Extract citations & paper IDs for each label
method_docs, method_ids = method_df["string"].tolist(), method_df["id"].tolist()
background_docs, background_ids = background_df["string"].tolist(), background_df["id"].tolist()
result_docs, result_ids = result_df["string"].tolist(), result_df["id"].tolist()

# Initialize retrievers for each category
bm_method_retriever = BM25Retriever(method_docs, method_ids)
bm_background_retriever = BM25Retriever(background_docs, background_ids)
bm_result_retriever = BM25Retriever(result_docs, result_ids)

sem_method_retriever = SemanticRetriever(method_docs, method_ids)
sem_background_retriever = SemanticRetriever(background_docs, background_ids)
sem_result_retriever = SemanticRetriever(result_docs, result_ids)

hyb_method_retriever = HybridRetriever(method_docs, method_ids)
hyb_background_retriever = HybridRetriever(background_docs, background_ids)
hyb_result_retriever = HybridRetriever(result_docs, result_ids)

## 3. Classify input queries

In [42]:
ori_test_df = pd.read_json(f"{project_path}/scicite/test.jsonl", lines=True)
test_df = pd.read_csv(f"{dataset_dir}/test-{dataset}.csv")

sample_test_df = test_df.sample(n=100, random_state=42)
X_test = sample_test_df.drop(columns=['label'])
y_test = sample_test_df["label"]
# X_test = test_df.drop(columns=['label'])
# y_test = test_df["label"]

# Select the same indices from the original test dataset
ori_sample_test_df = ori_test_df.loc[sample_test_df.index]

# Display the selected rows
# display(X_test)
# display(ori_sample_test_df)

In [43]:
# Load the classifier 
classifier = joblib.load(f"{dataset_dir}/selected_classifier.pkl")

pred = classifier.predict(X_test)

# Print predictions and true labels
# print("Predictions vs True Labels:")
# df_comparison = pd.DataFrame({"Predicted": pred, "True": y_test.values})
# print(df_comparison)

## 4. Retrieve similar citations

In [44]:
# Initialize retrievers for each label
label_encoder = joblib.load(f"{dataset_dir}/label_encoder.pkl")
label_strings = label_encoder.inverse_transform(pred)

bm_retrievers = {
    "background": bm_background_retriever,
    "method": bm_method_retriever,
    "result": bm_result_retriever
}

sem_retrievers = {
    "background": sem_background_retriever,
    "method": sem_method_retriever,
    "result": sem_result_retriever
}

hyb_retrievers = {
    "background": hyb_background_retriever,
    "method": hyb_method_retriever,
    "result": hyb_result_retriever
}

# Initialize empty lists to hold the results for each retriever type
bm_retrieved_docs = []
sem_retrieved_docs = []
hyb_retrieved_docs = []

for idx, label in enumerate(label_strings):
    query = ori_sample_test_df.iloc[idx]["string"]
    bm_retriever = bm_retrievers.get(label)
    sem_retriever = sem_retrievers.get(label)
    hyb_retriever = hyb_retrievers.get(label)

    # BM Retriever
    relevant_docs = bm_retriever.retrieve(query, top_k=3)
    bm_results = []
    for doc, paper_id in relevant_docs:
        bm_results.append({"Document": doc, "Paper ID": paper_id})

    bm_retrieved_docs.append({
        "Query": query,
        "Predicted label": label,
        "Retrieved docs": bm_results
    })
    
    # Semantic Retriever
    relevant_docs = sem_retriever.retrieve(query, top_k=3)
    sem_results = []
    for doc, paper_id in relevant_docs:
        sem_results.append({"Document": doc, "Paper ID": paper_id})
        
    sem_retrieved_docs.append({
        "Query": query,
        "Predicted label": label,
        "Retrieved docs": sem_results
    })
    
    # Hybrid Retriever
    relevant_docs = hyb_retriever.retrieve(query, top_k=3)
    hyb_results = []
    for doc, paper_id in relevant_docs:
        hyb_results.append({"Document": doc, "Paper ID": paper_id})

    hyb_retrieved_docs.append({
        "Query": query,
        "Predicted label": label,
        "Retrieved docs": hyb_results
    })
    
# Save the retrieved documents for each retriever to a separate JSON file
with open(f"{results_dir}/bm_retrieved_docs.json", "w") as f:
    json.dump(bm_retrieved_docs, f, indent=4)

with open(f"{results_dir}/sem_retrieved_docs.json", "w") as f:
    json.dump(sem_retrieved_docs, f, indent=4)

with open(f"{results_dir}/hyb_retrieved_docs.json", "w") as f:
    json.dump(hyb_retrieved_docs, f, indent=4)

print(f"Saved results to {results_dir}")

Saved results to /home/brina/nus-mcomp/sem3/cs4248-natural-language-processing/Project/CS4248-NLP-Project/results


## 5. Evaluate retrieval

In [46]:
api_key = "fy3jHNMV7OC7t7fQprQkFgp7NeSlRsMG"
base_url = "https://api.deepinfra.com/v1/openai"
model = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
# model = "mistralai/Mistral-Small-24B-Instruct-2501"

llm = OpenAI(api_key=api_key, base_url=base_url)

# Function to call OpenAI and get relevance feedback (qualitative)
def evaluate_retrieval(query, retrieved_docs):
    # Prepare the prompt for OpenAI to evaluate the relevance

    prompt = """
                You are an evaluator tasked to determine whether each document is relevant to the citation query.
                For each document, return `1` if it is relevant to the citation, or `0` if irrelevant.
                Directly return the answer as `Answer: ` with no explanation.
                It should be a sequence of digits separated by spaces, with one digit per document and in the same order as presented.

                For example,
                Citation Query: "We adopt the method proposed by Smith et al. (2020) for neural text generation, which introduces a variational decoding strategy."  
                Document 1: "Smith et al. (2020) propose a variational decoding framework that improves diversity in neural text generation."  
                Document 2: "Brown et al. (2019) study reinforcement learning for policy optimization in robotics tasks."  
                Document 3: "The paper explores attention mechanisms in neural networks but does not mention variational decoding."
                Answer: 1 0 0
                
              """
    
    prompt += f"Citation Query: {query}\n"
                
    for i, doc in enumerate(retrieved_docs):
        prompt += f"Document {i + 1}: {doc['Document']}\n"

    prompt += "Answer: "

    messages = [{"role": "user", "content": prompt}]

    # Send the prompt to OpenAI's API
    response = llm.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=200
    )

    # Extract the digit sequence from the response
    result = response.choices[0].message.content.strip()
    # print(result)

    match = re.search(r"Answer:\s*(.*)", result)
    if match:
        digit_sequence = match.group(1)
    else:
        # If "Answer:" is not found, assume the model just returns the digits directly
        digit_sequence = result

    # Split and validate
    relevance_scores = digit_sequence.strip().split()
    
    # Ensure that relevance_scores only contain 0s and 1s, and match the length of retrieved_docs
    relevance_scores = [score if score in ['0', '1'] else '0' for score in relevance_scores]
    
    # If there are fewer relevance scores than retrieved docs, append 0s
    if len(relevance_scores) < len(retrieved_docs):
        relevance_scores += ['0'] * (len(retrieved_docs) - len(relevance_scores))
    
    # If there are more relevance scores than retrieved docs, truncate the extra scores
    elif len(relevance_scores) > len(retrieved_docs):
        relevance_scores = relevance_scores[:len(retrieved_docs)]

    # Package results
    evaluated_docs = []
    for doc, relevance in zip(retrieved_docs, relevance_scores):
        evaluated_docs.append({
            "Document": doc['Document'],
            "Paper ID": doc['Paper ID'],
            "Relevance": int(relevance)
        })

    return evaluated_docs

# Load the retrieved documents JSON files
with open(f"{results_dir}/bm_retrieved_docs.json", "r") as f:
    bm_retrieved_docs = json.load(f)

with open(f"{results_dir}/sem_retrieved_docs.json", "r") as f:
    sem_retrieved_docs = json.load(f)

with open(f"{results_dir}/hyb_retrieved_docs.json", "r") as f:
    hyb_retrieved_docs = json.load(f)

# Initialize lists to store evaluations
bm_evaluations = []
sem_evaluations = []
hyb_evaluations = []

# Evaluate each query for BM retriever
for query_data in bm_retrieved_docs:
    query = query_data['Query']
    retrieved_docs = query_data['Retrieved docs']
    evaluated_docs = evaluate_retrieval(query, retrieved_docs)
    bm_evaluations.append({
        "Query": query,
        "Predicted label": query_data['Predicted label'],
        "Retrieved docs": evaluated_docs
    })

# Evaluate each query for Semantic retriever
for query_data in sem_retrieved_docs:
    query = query_data['Query']
    retrieved_docs = query_data['Retrieved docs']
    evaluated_docs = evaluate_retrieval(query, retrieved_docs)
    sem_evaluations.append({
        "Query": query,
        "Predicted label": query_data['Predicted label'],
        "Retrieved docs": evaluated_docs
    })

# Evaluate each query for Hybrid retriever
for query_data in hyb_retrieved_docs:
    query = query_data['Query']
    retrieved_docs = query_data['Retrieved docs']
    evaluated_docs = evaluate_retrieval(query, retrieved_docs)
    hyb_evaluations.append({
        "Query": query,
        "Predicted label": query_data['Predicted label'],
        "Retrieved docs": evaluated_docs
    })

# Save the evaluations to JSON files
with open(f"{results_dir}/bm_evaluations.json", "w") as f:
    json.dump(bm_evaluations, f, indent=4)

with open(f"{results_dir}/sem_evaluations.json", "w") as f:
    json.dump(sem_evaluations, f, indent=4)

with open(f"{results_dir}/hyb_evaluations.json", "w") as f:
    json.dump(hyb_evaluations, f, indent=4)

print(f"Saved results to {results_dir}")

Saved results to /home/brina/nus-mcomp/sem3/cs4248-natural-language-processing/Project/CS4248-NLP-Project/results


In [47]:
def compute_precision(evaluations):
    total_docs = 0
    total_relevant = 0
    for query_eval in evaluations:
        for doc in query_eval['Retrieved docs']:
            total_docs += 1
            if doc['Relevance'] == 1:
                total_relevant += 1
    precision = total_relevant / total_docs if total_docs else 0
    return precision

bm_precision = compute_precision(bm_evaluations)
sem_precision = compute_precision(sem_evaluations)
hyb_precision = compute_precision(hyb_evaluations)

print("\n=== Precision Scores ===")
print(f"BM Precision: {bm_precision:.4f}")
print(f"Semantic Precision: {sem_precision:.4f}")
print(f"Hybrid Precision: {hyb_precision:.4f}")


=== Precision Scores ===
BM Precision: 0.4033
Semantic Precision: 0.3467
Hybrid Precision: 0.3700
