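# Semantic Retriever

Classify the citation intent (background / method / result) of a SciCite test citation, then retrieve the most similar training citations that share the predicted label, using SciBERT embeddings.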

In [20]:
import os
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import joblib

## Parameters

In [14]:
# Project root: assumed to be the parent of the notebook's working directory.
# Edit this if the notebook is run from somewhere else.
project_path = os.path.dirname(os.path.abspath(os.getcwd()))

# Folder holding the preprocessed SciCite artifacts (train_*.jsonl, test CSV, *.pkl).
dataset_dir = f"{project_path}/scicite_preprocessed"

# Feature-set suffix used to pick the test CSV: test-<dataset>.csv
dataset = "selected-features"

## 1. Load the retrieval database (training citations, one file per label)

In [5]:
def _read_train_split(label):
    """Read ``train_<label>.jsonl`` from ``dataset_dir`` (one JSON record per line) into a DataFrame."""
    return pd.read_json(f"{dataset_dir}/train_{label}.jsonl", lines=True)


background_df = _read_train_split("background")
method_df = _read_train_split("method")
result_df = _read_train_split("result")

## 2. Create Semantic Retrievers

In [55]:
class SemanticRetriever:
    """Nearest-neighbour retriever over short documents (citation sentences) using SciBERT.

    Every document is embedded once at construction time. A query is embedded on demand
    and ranked against the corpus by cosine similarity. Documents and paper IDs are kept
    as parallel lists so each hit is reported as a ``(document, paper_id)`` pair.

    NOTE(review): embeddings are the raw ``[CLS]`` hidden state of a model that is not
    fine-tuned for sentence similarity. Mean pooling or a sentence-level encoder may rank
    better; worth evaluating before trusting the neighbours.
    """

    def __init__(self, documents, paper_ids, model_name="allenai/scibert_scivocab_uncased",
                 device="cuda" if torch.cuda.is_available() else "cpu", batch_size=32):
        """Load the tokenizer/model and embed the whole corpus.

        Args:
            documents: Iterable of document strings to index.
            paper_ids: Iterable of identifiers, parallel to ``documents``.
            model_name: Model identifier passed to ``from_pretrained``.
            device: Torch device string. The default is chosen once, when the class is
                defined: CUDA if available, else CPU.
            batch_size: Number of documents encoded per forward pass while indexing.

        Raises:
            ValueError: If ``documents`` and ``paper_ids`` differ in length.
        """
        documents = list(documents)
        paper_ids = list(paper_ids)
        # Validate before the (slow) model load so a mismatch fails fast instead of
        # producing wrong pairs or an IndexError at query time.
        if len(documents) != len(paper_ids):
            raise ValueError(
                f"documents and paper_ids must have the same length "
                f"(got {len(documents)} and {len(paper_ids)})"
            )

        self.device = device

        # Load SciBERT tokenizer & model (eval mode: no dropout).
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device).eval()

        # Parallel lists: paper_ids[i] identifies documents[i]. Lists make retrieval
        # positional even if the caller passed e.g. a pandas Series with a custom index.
        self.documents = documents
        self.paper_ids = paper_ids
        self.embeddings = self.embed_documents(documents, batch_size=batch_size)

    def embed_text(self, text, max_length=512):
        """Return the ``[CLS]`` embedding(s) of ``text`` as a NumPy array.

        Args:
            text: A single string, or a list of strings encoded as one padded batch.
            max_length: Token limit; longer inputs are truncated.

        Returns:
            A 1-D vector when exactly one text is encoded (the size-1 batch axis is
            squeezed away), otherwise a ``(batch, hidden)`` array.
        """
        tokens = self.tokenizer(
            text, padding=True, truncation=True, max_length=max_length, return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**tokens)
        return outputs.last_hidden_state[:, 0, :].squeeze(0).cpu().numpy()  # [CLS] token

    def embed_documents(self, documents, batch_size=32):
        """Embed every document, encoding ``batch_size`` documents per forward pass.

        Padding inside a batch should be ignored via the attention mask the tokenizer
        returns (forwarded through ``**tokens``). Rows should therefore match
        one-at-a-time encoding up to floating-point noise.

        Returns:
            A ``(len(documents), hidden)`` array; ``(0, hidden)`` for an empty input.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1 (got {batch_size})")
        documents = list(documents)
        if not documents:
            # Keep a 2-D shape so downstream matrix code never sees a ragged 1-D array.
            return np.empty((0, self.model.config.hidden_size), dtype=np.float32)

        chunks = []
        for start in range(0, len(documents), batch_size):
            batch = documents[start:start + batch_size]
            # embed_text squeezes a size-1 batch to 1-D; restore the row axis before stacking.
            chunks.append(np.atleast_2d(self.embed_text(batch)))
        return np.vstack(chunks)

    @staticmethod
    def _l2_normalize(matrix):
        """Scale each row of a 2-D array to unit L2 norm; all-zero rows stay zero."""
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        return matrix / np.where(norms == 0, 1.0, norms)

    def _unit_corpus(self):
        """Row-normalized ``self.embeddings``.

        The result is cached so the corpus is normalized once, not on every query. The
        cache is rebuilt when ``self.embeddings`` is reassigned; in-place edits of the
        array are NOT detected.
        """
        cached = getattr(self, "_unit_cache", None)
        if cached is None or cached[0] is not self.embeddings:
            cached = (self.embeddings, self._l2_normalize(self.embeddings))
            self._unit_cache = cached
        return cached[1]

    def retrieve(self, query, top_k=3):
        """Return up to ``top_k`` ``(document, paper_id)`` pairs, most similar first.

        Args:
            query: Query string, embedded with ``embed_text``.
            top_k: Maximum number of hits. Zero or negative values return ``[]``.
                Values larger than the corpus return every document.
        """
        # Guard: argsort()[-0:] would select the WHOLE corpus, so handle top_k <= 0 here.
        if top_k <= 0 or not self.documents:
            return []
        query_embedding = self.embed_text(query).reshape(1, -1)
        # Cosine similarity = dot product of unit vectors.
        similarities = (self._l2_normalize(query_embedding) @ self._unit_corpus().T).ravel()
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [(self.documents[i], self.paper_ids[i]) for i in top_indices]

In [56]:
# Pull the citation text ("string") and its identifier ("id") out of each training split.
method_docs = method_df["string"].tolist()
method_ids = method_df["id"].tolist()
background_docs = background_df["string"].tolist()
background_ids = background_df["id"].tolist()
result_docs = result_df["string"].tolist()
result_ids = result_df["id"].tolist()

# One retriever per label, so a query is only compared against training citations
# of its predicted label. Each construction embeds its whole corpus up front.
method_retriever = SemanticRetriever(documents=method_docs, paper_ids=method_ids)
background_retriever = SemanticRetriever(documents=background_docs, paper_ids=background_ids)
result_retriever = SemanticRetriever(documents=result_docs, paper_ids=result_ids)

## 3. Classify input query and retrieve similar citations

In [77]:
ori_test_df = pd.read_json(f"{project_path}/scicite/test.jsonl", lines=True)
test_df = pd.read_csv(f"{dataset_dir}/test-{dataset}.csv")

# The row lookup below is positional, so it only makes sense if both files describe
# the same rows. A length mismatch means the rows cannot correspond one-to-one.
assert len(ori_test_df) == len(test_df), (
    f"test.jsonl has {len(ori_test_df)} rows but test-{dataset}.csv has {len(test_df)}"
)

n_queries = 1  # number of random test citations to classify and use as queries
sample_test_df = test_df.sample(n=n_queries, random_state=42)  # fixed seed -> reproducible pick
X_test = sample_test_df.drop(columns=["label"])
y_test = sample_test_df["label"]

# Select the same rows from the original test set; its raw "string" text is the
# retrieval query. NOTE(review): assumes both frames share the default RangeIndex in
# the same row order. TODO: confirm preprocessing preserved row order.
ori_sample_test_df = ori_test_df.loc[sample_test_df.index]

# Display the selected rows (features fed to the classifier, then the raw record).
display(X_test)
display(ori_sample_test_df)

Unnamed: 0,citeEnd,citeStart,excerpt_index,source_acronym,source_acronymParen,source_andPhrase,source_etAlPhrase,source_explicit,source_properNoun,isKeyCitation_False,...,zhu et al_tfidf,zimbabwe_tfidf,zinc_tfidf,zinc finger_tfidf,zn_tfidf,zone_tfidf,äì_tfidf,ðþ_tfidf,βarr_tfidf,μm_tfidf
233,120.0,112.0,1,0.0,0.0,0.0,0.0,1.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


Unnamed: 0,source,citeEnd,sectionName,citeStart,string,label,citingPaperId,citedPaperId,isKeyCitation,id,unique_id,excerpt_index,label2,label_confidence
233,explicit,120.0,Plasma cell biology,112.0,When lymph node follicular B-cells encounter a...,background,3b6fce00a747ffaa22a67b334f733c9f40bff561,846066f8b6cefd92169d860e14bf60d9167d072a,False,3b6fce00a747ffaa22a67b334f733c9f40bff561>84606...,3b6fce00a747ffaa22a67b334f733c9f40bff561>84606...,1,,


In [78]:
# Load the trained citation-intent classifier. joblib files are pickle-based, so loading
# one can execute arbitrary code: only point this at artifacts you produced yourself.
classifier = joblib.load(f"{dataset_dir}/selected_classifier.pkl")
pred = classifier.predict(X_test)

# Side-by-side view of the (label-encoded) predictions and the ground truth.
df_comparison = pd.DataFrame(
    {
        "Predicted": pred,
        "True": y_test.values,
    }
)
print("Predictions vs True Labels:")
print(df_comparison)

Predictions vs True Labels:
   Predicted  True
0          0     0


In [88]:
# Decode the classifier's integer predictions back to label strings.
# joblib/pickle loading executes code on load: only use trusted artifacts.
label_encoder = joblib.load(f"{dataset_dir}/label_encoder.pkl")
label_strings = label_encoder.inverse_transform(pred)

# Route each query to the retriever built from training citations of its predicted label.
retrievers = {
    "background": background_retriever,
    "method": method_retriever,
    "result": result_retriever,
}
top_k = 3  # neighbouring citations shown per query

queries = ori_sample_test_df["string"].tolist()
assert len(queries) == len(label_strings), "expected exactly one prediction per sampled query"

# Collected as records so they can be turned into a DataFrame in one step if needed.
retrieved_docs = []

for query, label in zip(queries, label_strings):
    print(f"Query: {query}\n")
    print(f"Predicted Label: {label}\n\n")

    retriever = retrievers.get(label)
    if retriever is None:
        # Surface an unmapped label instead of silently producing no results.
        print(f"No retriever registered for label '{label}'; skipping retrieval.")
        continue

    for doc, paper_id in retriever.retrieve(query, top_k=top_k):
        print(f"Document: {doc}")
        print(f"Paper ID: {paper_id}")
        print("-" * 50)
        retrieved_docs.append({"Predicted Label": label, "Document": doc, "Paper ID": paper_id})

Query: When lymph node follicular B-cells encounter antigen and T-cell help, this results in germinal center formation (62, 74).

Predicted Label: background


Document: When IL-17 is blocked during a high-dose challenge, neutrophil recruitment is hindered, and this may alter subsequent development of inflammation (55).
Paper ID: 998b2e755456700a4fc07c6dd1f32d243899d09f>c5162ef5c68848eba50818b64f7a783f1cde4c73
--------------------------------------------------
Document: Untreated infection results in the progressive depletion of the helper T-cell population, and the resulting immunodeficiency leads to death by opportunistic infection (Thompson et al. (2012)).
Paper ID: 957bf5cc5e797dd078bbea21cb6cb17d41026dcc>8b4b51fffb43ad57ac9aa8e5deeac45f6356e015
--------------------------------------------------
Document: Antigen receptor signaling during clonal expansion likely drives down Bcl-2 levels and hyper-induce Mcl-1 in responding CD8 T cells (Dunkle et al., 2011; Opferman et al., 2003).
P

Bad pipe message: %s [b'\x95\x8er\x0eB\xc2\xdfE\x982\xed\xf1\xfd\xd3\x10_\x90W\x00\x02\xbc\x00\x00\x00\x01\x00\x02\x00\x03\x00\x04\x00\x05\x00\x06\x00\x07\x00\x08\x00\t\x00\n\x00\x0b\x00\x0c\x00\r\x00\x0e\x00\x0f\x00\x10\x00\x11\x00\x12\x00\x13\x00\x14\x00\x15\x00\x16\x00\x17\x00\x18\x00\x19\x00\x1a\x00\x1b\x00\x1e\x00\x1f\x00 \x00!\x00"\x00#\x00$\x00%\x00&\x00\'\x00(\x00)\x00*\x00+\x00,\x00-\x00.\x00/\x000\x001\x002\x003\x004\x005\x006\x007\x008\x009\x00:\x00;\x00<\x00=\x00>\x00?\x00@\x00A\x00B\x00C\x00D\x00E\x00F\x00g\x00h\x00i\x00j\x00k\x00l\x00m\x00\x84\x00']
Bad pipe message: %s [b'\x86\x00\x87\x00\x88\x00\x89\x00\x8a\x00\x8b\x00\x8c\x00\x8d\x00\x8e\x00\x8f\x00\x90\x00\x91\x00\x92\x00\x93\x00\x94\x00\x95\x00\x96\x00\x97\x00\x98\x00\x99\x00\x9a\x00\x9b\x00\x9c\x00\x9d\x00\x9e\x00\x9f\x00\xa0\x00\xa1\x00\xa2\x00\xa3\x00\xa4\x00\xa5\x00\xa6\x00\xa7\x00\xa8\x00\xa9\x00\xaa\x00\xab\x00\xac\x00\xad\x00\xae\x00\xaf\x00\xb0\x00\xb1\x00\xb2\x00\xb3\x00\xb4\x00\xb5\x00\xb6\x00\xb7\x00\xb8\x