# Document Retrieval

In [18]:
import os
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import joblib
from rank_bm25 import BM25Okapi
import nltk
from nltk.tokenize import word_tokenize

## Parameters

In [2]:
# Project root = the parent of the notebook's working directory.
# Edit `project_path` if your checkout uses a different layout.
project_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

# Folder holding the preprocessed SciCite splits, and the feature-set name
# that selects which preprocessed test CSV / classifier to use.
dataset_dir = "{}/scicite_preprocessed".format(project_path)
dataset = "selected-features"

## 1. Load database

In [3]:
# One JSON-lines training split per citation intent, read in the order
# background -> method -> result.
_train_split_paths = (
    f"{dataset_dir}/train_background.jsonl",
    f"{dataset_dir}/train_method.jsonl",
    f"{dataset_dir}/train_result.jsonl",
)
background_df, method_df, result_df = (
    pd.read_json(split_path, lines=True) for split_path in _train_split_paths
)

## 2. Create BM25, Semantic and Hybrid Retrievers

In [20]:
class SemanticRetriever:
    def __init__(self, documents, paper_ids, model_name="allenai/scibert_scivocab_uncased", device="cuda" if torch.cuda.is_available() else "cpu"):
        """
        A Semantic Retriever using SciBERT embeddings for short documents (citations).
        Stores document-paper_id pairs for retrieval.

        Args:
            documents: list of document strings to index.
            paper_ids: list of ids, aligned position-by-position with `documents`.
            model_name: Hugging Face model id used for both tokenizer and encoder.
            device: torch device the encoder runs on.
        """
        self.device = device

        # Load SciBERT tokenizer & model (eval mode: no dropout at inference).
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device).eval()

        # Store (document, paper_id) pairs and their precomputed embeddings.
        self.documents = documents
        self.paper_ids = paper_ids
        self.embeddings = self.embed_documents(documents)

    def _cls_embeddings(self, texts, max_length=512):
        """Encode `texts` (a string or list of strings); return the [CLS] vectors as a (batch, hidden) tensor."""
        tokens = self.tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**tokens)
        return outputs.last_hidden_state[:, 0, :]  # [CLS] token

    def embed_text(self, text, max_length=512):
        """ Converts text into SciBERT embeddings ([CLS] vector; 1-D for a single string). """
        return self._cls_embeddings(text, max_length=max_length).squeeze(0).cpu().numpy()

    def embed_documents(self, documents, batch_size=32, max_length=512):
        """
        Embeds all short documents using SciBERT, `batch_size` documents per forward pass.

        Returns a (len(documents), hidden) array; an empty (0, 0) array when there are no documents.
        """
        if len(documents) == 0:
            return np.empty((0, 0), dtype=np.float32)
        batch_size = max(1, int(batch_size))
        batches = []
        for start in range(0, len(documents), batch_size):
            batch = list(documents[start:start + batch_size])
            # No squeeze here: a final batch of one document must stay 2-D for concatenation.
            batches.append(self._cls_embeddings(batch, max_length=max_length).cpu().numpy())
        return np.concatenate(batches, axis=0)

    def retrieve(self, query, top_k=3):
        """
        Retrieves the most relevant documents along with their paper IDs.

        Returns up to `top_k` (document, paper_id) tuples, most similar first;
        an empty list when `top_k <= 0` or the index holds no documents.
        """
        # Guard: `argsort()[-0:]` would otherwise select every document.
        if top_k <= 0 or len(self.documents) == 0:
            return []
        query_embedding = self.embed_text(query).reshape(1, -1)
        similarities = cosine_similarity(query_embedding, self.embeddings).flatten()
        top_indices = similarities.argsort()[::-1][:top_k]  # descending similarity
        return [(self.documents[i], self.paper_ids[i]) for i in top_indices]

class BM25Retriever:
    def __init__(self, documents, paper_ids):
        """
        A lexical retriever ranking documents with Okapi BM25 over lower-cased NLTK word tokens.
        Stores document-paper_id pairs for retrieval.

        Args:
            documents: list of document strings to index.
            paper_ids: list of ids, aligned position-by-position with `documents`.
        """
        self.documents = documents
        self.paper_ids = paper_ids

        self.tokenized_docs = [word_tokenize(doc.lower()) for doc in self.documents]
        # BM25Okapi cannot be built from an empty corpus (it averages document
        # length over the corpus size), so leave the index unset in that case.
        self.bm25 = BM25Okapi(self.tokenized_docs) if self.tokenized_docs else None

    def retrieve(self, query, top_k=3):
        """
        Returns up to `top_k` (document, paper_id) tuples, highest BM25 score first.

        Returns an empty list when `top_k <= 0` or the index holds no documents
        (a negative slice bound would otherwise return almost every document).
        """
        if top_k <= 0 or self.bm25 is None:
            return []

        query_tokens = word_tokenize(query.lower())
        scores = self.bm25.get_scores(query_tokens)
        top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]

        return [(self.documents[i], self.paper_ids[i]) for i in top_indices]

class HybridRetriever:
    def __init__(self, documents, paper_ids, bm25_weight=0.5, semantic_weight=0.5, model_name="allenai/scibert_scivocab_uncased", device="cuda" if torch.cuda.is_available() else "cpu"):
        """
        HybridRetriever: Combines BM25 and Semantic Retrieval.

        Each score family is min-max normalized to [0, 1] per query, then combined as
        `bm25_weight * bm25 + semantic_weight * semantic`. Without that normalization the
        unbounded BM25 scores would swamp cosine similarities (at most 1) and the weights
        would have no practical effect.

        Args:
            documents: list of document strings to index.
            paper_ids: list of ids, aligned position-by-position with `documents`.
            bm25_weight / semantic_weight: mixing weights for the normalized scores.
            model_name: Hugging Face model id used for both tokenizer and encoder.
            device: torch device the encoder runs on.
        """
        self.bm25_weight = bm25_weight
        self.semantic_weight = semantic_weight

        self.documents = documents
        self.paper_ids = paper_ids

        # BM25 (BM25Okapi cannot be built from an empty corpus, so skip it then).
        self.tokenized_docs = [word_tokenize(doc.lower()) for doc in self.documents]
        self.bm25 = BM25Okapi(self.tokenized_docs) if self.tokenized_docs else None

        # Semantic
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device).eval()
        self.embeddings = self.embed_documents(documents)

    def _cls_embeddings(self, texts, max_length=512):
        """Encode `texts` (a string or list of strings); return the [CLS] vectors as a (batch, hidden) tensor."""
        tokens = self.tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**tokens)
        return outputs.last_hidden_state[:, 0, :]  # [CLS] token

    def embed_text(self, text, max_length=512):
        """ Converts text into SciBERT embeddings ([CLS] vector; 1-D for a single string). """
        return self._cls_embeddings(text, max_length=max_length).squeeze(0).cpu().numpy()

    def embed_documents(self, documents, batch_size=32, max_length=512):
        """
        Embeds all short documents using SciBERT, `batch_size` documents per forward pass.

        Returns a (len(documents), hidden) array; an empty (0, 0) array when there are no documents.
        """
        if len(documents) == 0:
            return np.empty((0, 0), dtype=np.float32)
        batch_size = max(1, int(batch_size))
        batches = []
        for start in range(0, len(documents), batch_size):
            batch = list(documents[start:start + batch_size])
            # No squeeze here: a final batch of one document must stay 2-D for concatenation.
            batches.append(self._cls_embeddings(batch, max_length=max_length).cpu().numpy())
        return np.concatenate(batches, axis=0)

    @staticmethod
    def _min_max(scores):
        """Rescale scores to [0, 1]; a constant score vector carries no ranking signal and maps to zeros."""
        scores = np.asarray(scores, dtype=float)
        if scores.size == 0:
            return scores
        low, high = scores.min(), scores.max()
        if high == low:
            return np.zeros_like(scores)
        return (scores - low) / (high - low)

    def _combine_scores(self, bm25_scores, semantic_scores):
        """Weighted sum of the per-query min-max normalized BM25 and semantic scores."""
        return (self.bm25_weight * self._min_max(bm25_scores)) + (self.semantic_weight * self._min_max(semantic_scores))

    @staticmethod
    def _top_indices(scores, top_k):
        """Indices of the `top_k` highest scores, best first (ties keep corpus order)."""
        return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]

    def retrieve(self, query, top_k=3):
        """
        Retrieves the top_k most relevant documents using a combination of BM25 and semantic scores.

        Returns up to `top_k` (document, paper_id) tuples, best first; an empty list when
        `top_k <= 0` or the index holds no documents.
        """
        if top_k <= 0 or len(self.documents) == 0:
            return []

        # BM25
        query_tokens = word_tokenize(query.lower())
        bm25_scores = self.bm25.get_scores(query_tokens)

        # Semantic
        query_embedding = self.embed_text(query).reshape(1, -1)
        semantic_scores = cosine_similarity(query_embedding, self.embeddings).flatten()

        # Weighted combination on a common [0, 1] scale
        combined_scores = self._combine_scores(bm25_scores, semantic_scores)

        # Retrieve top-k sentences
        top_indices = self._top_indices(combined_scores, top_k)

        return [(self.documents[i], self.paper_ids[i]) for i in top_indices]

In [21]:
def _texts_and_ids(split_df):
    """Return a training split's citation strings and ids as two parallel lists."""
    return split_df["string"].tolist(), split_df["id"].tolist()

method_docs, method_ids = _texts_and_ids(method_df)
background_docs, background_ids = _texts_and_ids(background_df)
result_docs, result_ids = _texts_and_ids(result_df)

# (docs, ids) per label, in method -> background -> result order.
_label_corpora = (
    (method_docs, method_ids),
    (background_docs, background_ids),
    (result_docs, result_ids),
)

# Build every BM25 index first, then the semantic ones, then the hybrid ones.
bm_method_retriever, bm_background_retriever, bm_result_retriever = [
    BM25Retriever(docs, ids) for docs, ids in _label_corpora
]
sem_method_retriever, sem_background_retriever, sem_result_retriever = [
    SemanticRetriever(docs, ids) for docs, ids in _label_corpora
]
hyb_method_retriever, hyb_background_retriever, hyb_result_retriever = [
    HybridRetriever(docs, ids) for docs, ids in _label_corpora
]

## 3. Classify input query

In [22]:
# Raw SciCite test split (holds the citation text) and its preprocessed feature matrix.
ori_test_df = pd.read_json(f"{project_path}/scicite/test.jsonl", lines=True)
test_df = pd.read_csv(f"{dataset_dir}/test-{dataset}.csv")

# How many test citations to use as queries, and the seed that selects them.
N_QUERIES = 1
SAMPLE_SEED = 42

sample_test_df = test_df.sample(n=N_QUERIES, random_state=SAMPLE_SEED)
X_test = sample_test_df.drop(columns=['label'])
y_test = sample_test_df["label"]

# Select the same rows from the raw test split.
# NOTE(review): this relies on the preprocessed CSV keeping the raw file's row order
# (both frames get a default RangeIndex on load) — confirm preprocessing never drops
# or reorders rows, otherwise features and query text would be misaligned.
ori_sample_test_df = ori_test_df.loc[sample_test_df.index]

# Display the selected rows
display(X_test)
display(ori_sample_test_df)

Unnamed: 0,citeEnd,citeStart,excerpt_index,source_acronym,source_acronymParen,source_andPhrase,source_etAlPhrase,source_explicit,source_properNoun,isKeyCitation_False,...,zhu et al_tfidf,zimbabwe_tfidf,zinc_tfidf,zinc finger_tfidf,zn_tfidf,zone_tfidf,äì_tfidf,ðþ_tfidf,βarr_tfidf,μm_tfidf
233,120.0,112.0,1,0.0,0.0,0.0,0.0,1.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


Unnamed: 0,source,citeEnd,sectionName,citeStart,string,label,citingPaperId,citedPaperId,isKeyCitation,id,unique_id,excerpt_index,label2,label_confidence
233,explicit,120.0,Plasma cell biology,112.0,When lymph node follicular B-cells encounter a...,background,3b6fce00a747ffaa22a67b334f733c9f40bff561,846066f8b6cefd92169d860e14bf60d9167d072a,False,3b6fce00a747ffaa22a67b334f733c9f40bff561>84606...,3b6fce00a747ffaa22a67b334f733c9f40bff561>84606...,1,,


In [23]:
# Load the trained citation-intent classifier.
# joblib.load unpickles the file, which can execute arbitrary code — only load artifacts you produced yourself.
classifier = joblib.load(f"{dataset_dir}/selected_classifier.pkl")

pred = classifier.predict(X_test)

# Encoded predicted vs. true labels for the sampled queries, shown as a rich table.
df_comparison = pd.DataFrame({"Predicted": pred, "True": y_test.values})
print("Predictions vs True Labels:")
display(df_comparison)

Predictions vs True Labels:
   Predicted  True
0          0     0


In [28]:
# Map encoded predictions back to label names.
# joblib.load unpickles the file — only load encoders you produced yourself.
label_encoder = joblib.load(f"{dataset_dir}/label_encoder.pkl")
label_strings = label_encoder.inverse_transform(pred)

bm_retrievers = {
    "background": bm_background_retriever,
    "method": bm_method_retriever,
    "result": bm_result_retriever
}

sem_retrievers = {
    "background": sem_background_retriever,
    "method": sem_method_retriever,
    "result": sem_result_retriever
}

hyb_retrievers = {
    "background": hyb_background_retriever,
    "method": hyb_method_retriever,
    "result": hyb_result_retriever
}

# (name stored in the records, header printed before results, label -> retriever), in display order.
retriever_families = [
    ("BM25", "BM Retriever:\n", bm_retrievers),
    ("Semantic", "\n\nSemantic Retriever:\n", sem_retrievers),
    ("Hybrid", "\n\nHybrid Retriever:\n", hyb_retrievers),
]
TOP_K = 3

# Retrieve relevant documents based on predicted labels
retrieved_docs = []

for idx, label in enumerate(label_strings):
    # Fail loudly instead of an opaque `'NoneType' has no attribute 'retrieve'` later.
    if label not in bm_retrievers:
        raise KeyError(f"No retrievers configured for predicted label {label!r}")

    query = ori_sample_test_df.iloc[idx]["string"]
    print(f"Query: {query}\n")
    print(f"Predicted Label: {label}\n\n")

    for family, header, retrievers in retriever_families:
        print(header)
        for doc, paper_id in retrievers[label].retrieve(query, top_k=TOP_K):
            print(f"Document: {doc}")
            print(f"Paper ID: {paper_id}")
            print("-" * 50)
            # "Retriever" records which strategy produced the hit, so results can be compared later.
            retrieved_docs.append({"Predicted Label": label, "Retriever": family, "Document": doc, "Paper ID": paper_id})


Query: When lymph node follicular B-cells encounter antigen and T-cell help, this results in germinal center formation (62, 74).

Predicted Label: background


BM Retriever:

Document: reported to be 3% by another study, lymphadenopathy was not among the clinical symptoms in the south of Iran.([23]) In previous studies, the lymph node metastasis was not reported in this area.
Paper ID: 4a254161278510ddcd1e386d71c8e97ca572774f>d81ce287e36056710a5d77ed74245a6cbe89824e
--------------------------------------------------
Document: Untreated infection results in the progressive depletion of the helper T-cell population, and the resulting immunodeficiency leads to death by opportunistic infection (Thompson et al. (2012)).
Paper ID: 957bf5cc5e797dd078bbea21cb6cb17d41026dcc>8b4b51fffb43ad57ac9aa8e5deeac45f6356e015
--------------------------------------------------
Document: ankle-link antigen (ALA) with ankle links (Goodyear and Richardson, 1999), and the tip-link antigen (TLA) with tip and kin