In [1]:
!pip install --upgrade pip

[0m

In [2]:
!pip install tqdm sentence-transformers transformers faiss-gpu datasets adapters

Collecting sentence-transformers
  Using cached sentence_transformers-4.1.0-py3-none-any.whl.metadata (13 kB)
Collecting transformers
  Using cached transformers-4.51.3-py3-none-any.whl.metadata (38 kB)
Collecting faiss-gpu
  Using cached faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting adapters
  Using cached adapters-1.1.1-py3-none-any.whl.metadata (17 kB)
Collecting huggingface-hub>=0.20.0 (from sentence-transformers)
  Downloading huggingface_hub-0.31.1-py3-none-any.whl.metadata (13 kB)
Collecting regex!=2019.12.17 (from transformers)
  Using cached regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers)
  Using cached tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting safetensors>=0.4.3 (from transformers

Create Dataset

In [3]:
from datasets import load_dataset

query_data = load_dataset("princeton-nlp/LitSearch", "query", split="full")
corpus_clean_data = load_dataset("princeton-nlp/LitSearch", "corpus_clean", split="full")

In [11]:
from typing import List, Tuple, Any
from datasets import Dataset

def get_clean_corpusid(item: dict) -> int:
    return item['corpusid']

def get_clean_title(item: dict) -> str:
    return item['title']

def get_clean_abstract(item: dict) -> str:
    return item['abstract']

def get_clean_title_abstract(item: dict) -> str:
    title = get_clean_title(item)
    abstract = get_clean_abstract(item)
    return title + '[SEP]' + abstract

def get_clean_full_paper(item: dict) -> str:
    return item['full_paper']

def get_clean_paragraph_indices(item: dict) -> List[Tuple[int, int]]:
    text = get_clean_full_paper(item)
    paragraph_indices = []
    paragraph_start = 0
    paragraph_end = 0
    while paragraph_start < len(text):
        paragraph_end = text.find("\n\n", paragraph_start)
        if paragraph_end == -1:
            paragraph_end = len(text)
        paragraph_indices.append((paragraph_start, paragraph_end))
        paragraph_start = paragraph_end + 2
    return paragraph_indices

def get_clean_text(item: dict, start_idx: int, end_idx: int) -> str:
    text = get_clean_full_paper(item)
    assert start_idx >= 0 and end_idx >= 0
    assert start_idx <= end_idx
    assert end_idx <= len(text)
    return text[start_idx:end_idx]

def get_clean_paragraphs(item: dict, min_words: int = 10) -> List[str]:
    paragraph_indices = get_clean_paragraph_indices(item)
    paragraphs = [get_clean_text(item, paragraph_start, paragraph_end) for paragraph_start, paragraph_end in paragraph_indices]
    paragraphs = [paragraph for paragraph in paragraphs if len(paragraph.split()) >= min_words]
    return paragraphs

def get_clean_citations(item: dict) -> List[int]:
    return item['citations']

def get_clean_dict(data: Dataset) -> dict:
    return {get_clean_corpusid(item): item for item in data}

def create_kv_pairs(data: List[dict], key: str) -> dict:
    if key == "title_abstract":
        kv_pairs = {get_clean_title_abstract(record): get_clean_corpusid(record) for record in data}
    elif key == "full_paper":
        kv_pairs = {get_clean_full_paper(record): get_clean_corpusid(record) for record in data}
    elif key == "paragraphs":
        kv_pairs = {}
        for record in data:
            corpusid = get_clean_corpusid(record)
            paragraphs = get_clean_paragraphs(record)
            for paragraph_idx, paragraph in enumerate(paragraphs):
                kv_pairs[paragraph] = (corpusid, paragraph_idx)
    else:
        raise ValueError("Invalid key")
    return kv_pairs

In [12]:
kv_pairs = create_kv_pairs(corpus_clean_data, "title_abstract")

Model
1. Scibert
2. SciNCL
3. SPECTER2

Parent

In [6]:
from enum import Enum
from typing import Dict, Any, List

import numpy as np

class TextType(Enum):
    KEY = 1
    QUERY = 2

class Retrieval:
    def __init__(self, index_name: str, index_type: str) -> None:
        self.index_name = index_name
        self.index_type = index_type

        self.keys = []
        self.values = []

    def __len__(self) -> int:
        return len(self.keys)

    def _get_embeddings(self, textList: List[str], type: TextType, show_progress_bar: bool = False) -> Any:
        raise NotImplementedError

    def _query(self, query_embedding: np.ndarray, top_k: int = 10) -> List[int]:
        raise NotImplementedError

    def query(self, query_text: str, n: int, return_keys: bool = False) -> List[Any]:
        embedding_query = self._get_embeddings([query_text], TextType.QUERY)
        indices = self._query(embedding_query, n)
        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]
        return results

    def clear(self) -> None:
        self.keys = []
        self.encoded_keys = []
        self.values = []

    def create_index(self, key_value_pairs: Dict[str, int]) -> None:
        if len(self.keys) > 0:
            raise ValueError("Index is not empty. Please create a new index or clear the existing one.")

        for key, value in key_value_pairs.items():
            self.keys.append(key)
            self.values.append(value)


SciBert

In [27]:
import torch
from transformers import AutoTokenizer, AutoModel
import faiss

from tqdm import tqdm
from typing import List, Dict, Any

class SciBert(Retrieval):
    def __init__(self, index_name: str):
        super().__init__(index_name, 'SciBert')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
        self._model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased').to(self.device)
        self.index = None
        self.faiss_index = None

        def clear(self):
            super().clear()
            self.index = None
            self.faiss_index = None

    def _get_embeddings(self, textList: List[str], type: TextType, show_progress_bar: bool = True) -> torch.Tensor:
        batch_size = 32
        embeddings = []

        should_show_progress = show_progress_bar and (type == TextType.KEY)

        iterator = range(0, len(textList), batch_size)
        if should_show_progress:
            iterator = tqdm(iterator, desc="Processing document embeddings")

        for i in iterator:
            batch_texts = textList[i:i+batch_size]
            encoded = self._tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            ).to(self.device)

            with torch.no_grad():
                outputs = self._model(**encoded)

            batch_embeddings = outputs.last_hidden_state[:,0,:].cpu()
            embeddings.append(batch_embeddings)

            if i % (batch_size * 10) == 0:
                torch.cuda.empty_cache()

        embeddings = torch.cat(embeddings, dim=0)

        return embeddings

    def create_index(self, key_value_pairs: Dict[str, Any]) -> None:
        super().create_index(key_value_pairs)
        self.index = self._get_embeddings(self.keys, TextType.KEY)

        # FAISS 인덱스 생성
        vector_dim = self.index.shape[1]
        index_flat = faiss.IndexFlatIP(vector_dim)
        index_vectors = self.index.numpy()
        faiss.normalize_L2(index_vectors)

        # 인덱스에 벡터 추가
        index_flat.add(index_vectors)
        self.faiss_index = index_flat

    def _query(self, query_embedding: torch.Tensor, top_k: int = 10) -> List[int]:
        if self.faiss_index is None:
            raise ValueError("FAISS index has not been created yet. Call create_index first.")

        query_vector = query_embedding.numpy()
        faiss.normalize_L2(query_vector)
        distances, indices = self.faiss_index.search(query_vector, top_k)

        return indices[0].tolist()

    def query(self, query_text: str, n: int, return_keys: bool = False) -> List[Any]:
        query_embedding = self._get_embeddings([query_text], TextType.QUERY)
        indices = self._query(query_embedding, n)

        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]

        return results


2. SciNCL

In [30]:
class Scincl(Retrieval):
    def __init__(self, index_name:str):
        super().__init__(index_name, 'SciNCL')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
        self._model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased').to(self.device)
        self.index = None
        self.faiss_index = None

    def clear(self):
        super().clear()
        self.index = None
        self.faiss_index = None

    def _get_embeddings(self, textList: List[str], type: TextType, show_progress_bar: bool = True) -> torch.Tensor:
        batch_size = 32
        embeddings = []

        should_show_progress = show_progress_bar and (type == TextType.KEY)

        iterator = range(0, len(textList), batch_size)
        if should_show_progress:
            iterator = tqdm(iterator, desc="Processing document embeddings")

        for i in iterator:
            batch_texts = textList[i:i+batch_size]
            encoded = self._tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            ).to(self.device)

            with torch.no_grad():
                outputs = self._model(**encoded)

            batch_embeddings = outputs.last_hidden_state[:,0,:].cpu()
            embeddings.append(batch_embeddings)

            if i % (batch_size * 10) == 0:
                torch.cuda.empty_cache()

        embeddings = torch.cat(embeddings, dim=0)

        return embeddings

    def create_index(self, key_value_pairs: Dict[str, Any]) -> None:
        super().create_index(key_value_pairs)
        self.index = self._get_embeddings(self.keys, TextType.KEY)

        # FAISS 인덱스 생성
        vector_dim = self.index.shape[1]
        index_flat = faiss.IndexFlatIP(vector_dim)
        index_vectors = self.index.numpy()
        faiss.normalize_L2(index_vectors)

        # 인덱스에 벡터 추가
        index_flat.add(index_vectors)
        self.faiss_index = index_flat

    def _query(self, query_embedding: torch.Tensor, top_k: int = 10) -> List[int]:
        if self.faiss_index is None:
            raise ValueError("FAISS index has not been created yet. Call create_index first.")

        query_vector = query_embedding.numpy()
        faiss.normalize_L2(query_vector)
        distances, indices = self.faiss_index.search(query_vector, top_k)

        return indices[0].tolist()

    def query(self, query_text: str, n: int, return_keys: bool = False) -> List[Any]:
        query_embedding = self._get_embeddings([query_text], TextType.QUERY)
        indices = self._query(query_embedding, n)

        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]

        return results


3. SPECTER2

In [80]:
from typing import List, Any, Tuple, Dict, Optional
from transformers import AutoTokenizer
from adapters import AutoAdapterModel

class Specter2(Retrieval):
    def __init__(self, index_name: str):
        super().__init__(index_name, 'SPECTER2')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # SPECTER2 base model과 tokenizer 로드
        self._tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base')
        self._model = AutoAdapterModel.from_pretrained('allenai/specter2_base')

        # Retrieval용 proximity adapter 로드 (문서 임베딩용)
        self._model.load_adapter("allenai/specter2", source="hf", load_as="proximity")
        # Adhoc query adapter 로드 (쿼리 임베딩용)
        self._model.load_adapter("allenai/specter2_adhoc_query", source="hf", load_as="adhoc_query")


        self._model.to(self.device)
        self.index = None
        self.faiss_index = None

    def clear(self):
        super().clear()
        self.index = None
        self.faiss_index = None

    def _get_embeddings(self, textList: List[str], type: TextType, show_progress_bar: bool = True) -> torch.Tensor:

        if type == TextType.KEY:
            self._model.set_active_adapters("proximity")
        else:
            self._model.set_active_adapters("adhoc_query")

        batch_size = 32
        embeddings = []

        should_show_progress = show_progress_bar and (type == TextType.KEY)

        iterator = range(0, len(textList), batch_size)
        if should_show_progress:
            iterator = tqdm(iterator, desc="Processing document embeddings")

        for i in iterator:
            batch_texts = textList[i:i+batch_size]
            encoded = self._tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            ).to(self.device)

            with torch.no_grad():
                outputs = self._model(**encoded)

            batch_embeddings = outputs.last_hidden_state[:,0,:].cpu()
            embeddings.append(batch_embeddings)

            if i % (batch_size * 10) == 0:
                torch.cuda.empty_cache()

        embeddings = torch.cat(embeddings, dim=0)

        return embeddings

    def create_index(self, key_value_pairs: Dict[str, Any]) -> None:
        super().create_index(key_value_pairs)
        self.index = self._get_embeddings(self.keys, TextType.KEY)

        # FAISS 인덱스 생성
        vector_dim = self.index.shape[1]
        index_flat = faiss.IndexFlatIP(vector_dim)
        index_vectors = self.index.numpy()
        faiss.normalize_L2(index_vectors)

        # 인덱스에 벡터 추가
        index_flat.add(index_vectors)
        self.faiss_index = index_flat

    def _query(self, query_embedding: torch.Tensor, top_k: int = 10) -> List[int]:
        if self.faiss_index is None:
            raise ValueError("FAISS index has not been created yet. Call create_index first.")

        query_vector = query_embedding.numpy()
        faiss.normalize_L2(query_vector)
        distances, indices = self.faiss_index.search(query_vector, top_k)

        return indices[0].tolist()

    def query(self, query_text: str, n: int, return_keys: bool = False) -> List[Any]:
        query_embedding = self._get_embeddings([query_text], TextType.QUERY)
        indices = self._query(query_embedding, n)

        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]

        return results


Evaluate for litSearch datasets

In [28]:
scibert = SciBert("Title_Abstract")

In [29]:
scibert.create_index(kv_pairs)

Processing document embeddings: 100%|██████████| 1802/1802 [11:39<00:00,  2.58it/s]


In [32]:
scibert_query_set = [query for query in query_data]
for query in tqdm(scibert_query_set):
    query_text = query["query"]
    top_k = scibert.query(query_text, 20)
    query["retrieved"] = top_k

100%|██████████| 597/597 [00:39<00:00, 15.09it/s]


Separate specific and broad set

In [61]:
def calculate_recall(corpusids: list, retrieved: list, k: int):
    top_k = retrieved[:k]
    intersection = set(corpusids) & set(top_k)
    return len(intersection) / len(corpusids) if corpusids and intersection else 0.0

In [64]:
import pandas as pd

scibert_query_df = pd.DataFrame(scibert_query_set)
scibert_broad_query_df = scibert_query_df[scibert_query_df['specificity'] == 0]
scibert_specific_query_df = scibert_query_df[scibert_query_df['specificity'] == 1]

scibert_broad_inline_query_df = scibert_broad_query_df[scibert_broad_query_df['query_set'].str.contains('inline')]
scibert_broad_manual_query_df = scibert_broad_query_df[scibert_broad_query_df['query_set'].str.contains('manual')]

scibert_specific_inline_query_df = scibert_specific_query_df[scibert_specific_query_df['query_set'].str.contains('inline')]
scibert_specific_manual_query_df = scibert_specific_query_df[scibert_specific_query_df['query_set'].str.contains('manual')]

In [65]:
recall_scibert_broad_inline = []
recall_scibert_broad_manual = []
recall_scibert_specific_inline = []
recall_scibert_specific_manual = []

for _, query in scibert_broad_inline_query_df.iterrows():
    result = calculate_recall(query['corpusids'], query['retrieved'], 20)
    recall_scibert_broad_inline.append(result)

for _, query in scibert_broad_manual_query_df.iterrows():
    result = calculate_recall(query['corpusids'], query['retrieved'], 20)
    recall_scibert_broad_manual.append(result)

for _, query in scibert_specific_inline_query_df.iterrows():
    result = calculate_recall(query['corpusids'], query['retrieved'], 20)
    recall_scibert_specific_inline.append(result)

for _, query in scibert_specific_manual_query_df.iterrows():
    result = calculate_recall(query['corpusids'], query['retrieved'], 20)
    recall_scibert_specific_manual.append(result)

result_scibert_broad_inline = np.mean(recall_scibert_broad_inline)
result_scibert_broad_manual = np.mean(recall_scibert_broad_manual)
result_scibert_specific_inline = np.mean(recall_scibert_specific_inline)
result_scibert_specific_manual = np.mean(recall_scibert_specific_manual)

In [66]:
print("=" * 50)
print("SciBERT Model Evaluation Results")
print("=" * 50)
print(f"{'Query Type':<20} {'Inline':<15} {'Manual':<15}")
print("-" * 50)
print(f"{'Broad Queries':<20} {result_scibert_broad_inline:.4f}{' '*9} {result_scibert_broad_manual:.4f}")
print(f"{'Specific Queries':<20} {result_scibert_specific_inline:.4f}{' '*9} {result_scibert_specific_manual:.4f}")
print("=" * 50)

SciBERT Model Evaluation Results
Query Type           Inline          Manual         
--------------------------------------------------
Broad Queries        0.0000          0.0000
Specific Queries     0.0000          0.0000


In [70]:
k_values = [1, 5, 10, 20]
recalls = {
    'scibert_broad_inline': {k: [] for k in k_values},
    'scibert_broad_manual': {k: [] for k in k_values},
    'scibert_specific_inline': {k: [] for k in k_values},
    'scibert_specific_manual': {k: [] for k in k_values}
}

for _, query in scibert_broad_inline_query_df.iterrows():
    for k in k_values:
        result = calculate_recall(query['corpusids'], query['retrieved'], k)
        recalls['scibert_broad_inline'][k].append(result)

for _, query in scibert_broad_manual_query_df.iterrows():
    for k in k_values:
        result = calculate_recall(query['corpusids'], query['retrieved'], k)
        recalls['scibert_broad_manual'][k].append(result)

for _, query in scibert_specific_inline_query_df.iterrows():
    for k in k_values:
        result = calculate_recall(query['corpusids'], query['retrieved'], k)
        recalls['scibert_specific_inline'][k].append(result)

for _, query in scibert_specific_manual_query_df.iterrows():
    for k in k_values:
        result = calculate_recall(query['corpusids'], query['retrieved'], k)
        recalls['scibert_specific_manual'][k].append(result)

results = {
    'scibert_broad_inline': {k: np.mean(recalls['scibert_broad_inline'][k]) for k in k_values},
    'scibert_broad_manual': {k: np.mean(recalls['scibert_broad_manual'][k]) for k in k_values},
    'scibert_specific_inline': {k: np.mean(recalls['scibert_specific_inline'][k]) for k in k_values},
    'scibert_specific_manual': {k: np.mean(recalls['scibert_specific_manual'][k]) for k in k_values}
}

In [71]:
print("=" * 80)
print("SciBERT Model Evaluation Results")
print("=" * 80)
print(f"{'Query Type':<20} {'k=1':<15} {'k=5':<15} {'k=10':<15} {'k=20':<15}")
print("-" * 80)
print(f"{'Broad Inline':<20} {results['scibert_broad_inline'][1]:.4f}{' '*9} {results['scibert_broad_inline'][5]:.4f}{' '*9} {results['scibert_broad_inline'][10]:.4f}{' '*9} {results['scibert_broad_inline'][20]:.4f}")
print(f"{'Broad Manual':<20} {results['scibert_broad_manual'][1]:.4f}{' '*9} {results['scibert_broad_manual'][5]:.4f}{' '*9} {results['scibert_broad_manual'][10]:.4f}{' '*9} {results['scibert_broad_manual'][20]:.4f}")
print(f"{'Specific Inline':<20} {results['scibert_specific_inline'][1]:.4f}{' '*9} {results['scibert_specific_inline'][5]:.4f}{' '*9} {results['scibert_specific_inline'][10]:.4f}{' '*9} {results['scibert_specific_inline'][20]:.4f}")
print(f"{'Specific Manual':<20} {results['scibert_specific_manual'][1]:.4f}{' '*9} {results['scibert_specific_manual'][5]:.4f}{' '*9} {results['scibert_specific_manual'][10]:.4f}{' '*9} {results['scibert_specific_manual'][20]:.4f}")
print("=" * 80)

SciBERT Model Evaluation Results
Query Type           k=1             k=5             k=10            k=20           
--------------------------------------------------------------------------------
Broad Inline         0.0000          0.0000          0.0000          0.0000
Broad Manual         0.0000          0.0000          0.0000          0.0000
Specific Inline      0.0000          0.0000          0.0000          0.0000
Specific Manual      0.0000          0.0000          0.0000          0.0000


2. SciNCL

In [67]:
scincl = Scincl("Title_Abstract")

In [68]:
scincl.create_index(kv_pairs)

Processing document embeddings: 100%|██████████| 1802/1802 [11:39<00:00,  2.58it/s]


In [69]:
scincl_query_set = [query for query in query_data]
for query in tqdm(scincl_query_set):
    query_text = query["query"]
    top_k = scincl.query(query_text, 20)
    query["retrieved"] = top_k

100%|██████████| 597/597 [00:39<00:00, 15.15it/s]


In [73]:
scincl_query_df = pd.DataFrame(scincl_query_set)
scincl_broad_query_df = scincl_query_df[scincl_query_df['specificity'] == 0]
scincl_specific_query_df = scincl_query_df[scincl_query_df['specificity'] == 1]

scincl_broad_inline_query_df = scincl_broad_query_df[scincl_broad_query_df['query_set'].str.contains('inline')]
scincl_broad_manual_query_df = scincl_broad_query_df[scincl_broad_query_df['query_set'].str.contains('manual')]

scincl_specific_inline_query_df = scincl_specific_query_df[scincl_specific_query_df['query_set'].str.contains('inline')]
scincl_specific_manual_query_df = scincl_specific_query_df[scincl_specific_query_df['query_set'].str.contains('manual')]

In [74]:
k_values = [1, 5, 10, 20]
recalls = {
    'scincl_broad_inline': {k: [] for k in k_values},
    'scincl_broad_manual': {k: [] for k in k_values},
    'scincl_specific_inline': {k: [] for k in k_values},
    'scincl_specific_manual': {k: [] for k in k_values}
}

for _, query in scincl_broad_inline_query_df.iterrows():
    for k in k_values:
        result = calculate_recall(query['corpusids'], query['retrieved'], k)
        recalls['scincl_broad_inline'][k].append(result)

for _, query in scincl_broad_manual_query_df.iterrows():
    for k in k_values:
        result = calculate_recall(query['corpusids'], query['retrieved'], k)
        recalls['scincl_broad_manual'][k].append(result)

for _, query in scincl_specific_inline_query_df.iterrows():
    for k in k_values:
        result = calculate_recall(query['corpusids'], query['retrieved'], k)
        recalls['scincl_specific_inline'][k].append(result)

for _, query in scincl_specific_manual_query_df.iterrows():
    for k in k_values:
        result = calculate_recall(query['corpusids'], query['retrieved'], k)
        recalls['scincl_specific_manual'][k].append(result)

results = {
    'scincl_broad_inline': {k: np.mean(recalls['scincl_broad_inline'][k]) for k in k_values},
    'scincl_broad_manual': {k: np.mean(recalls['scincl_broad_manual'][k]) for k in k_values},
    'scincl_specific_inline': {k: np.mean(recalls['scincl_specific_inline'][k]) for k in k_values},
    'scincl_specific_manual': {k: np.mean(recalls['scincl_specific_manual'][k]) for k in k_values}
}

In [75]:
print("=" * 80)
print("SciNCL Model Evaluation Results")
print("=" * 80)
print(f"{'Query Type':<20} {'k=1':<15} {'k=5':<15} {'k=10':<15} {'k=20':<15}")
print("-" * 80)
print(f"{'Broad Inline':<20} {results['scincl_broad_inline'][1]:.4f}{' '*9} {results['scincl_broad_inline'][5]:.4f}{' '*9} {results['scincl_broad_inline'][10]:.4f}{' '*9} {results['scincl_broad_inline'][20]:.4f}")
print(f"{'Broad Manual':<20} {results['scincl_broad_manual'][1]:.4f}{' '*9} {results['scincl_broad_manual'][5]:.4f}{' '*9} {results['scincl_broad_manual'][10]:.4f}{' '*9} {results['scincl_broad_manual'][20]:.4f}")
print(f"{'Specific Inline':<20} {results['scincl_specific_inline'][1]:.4f}{' '*9} {results['scincl_specific_inline'][5]:.4f}{' '*9} {results['scincl_specific_inline'][10]:.4f}{' '*9} {results['scincl_specific_inline'][20]:.4f}")
print(f"{'Specific Manual':<20} {results['scincl_specific_manual'][1]:.4f}{' '*9} {results['scincl_specific_manual'][5]:.4f}{' '*9} {results['scincl_specific_manual'][10]:.4f}{' '*9} {results['scincl_specific_manual'][20]:.4f}")
print("=" * 80)

SciNCL Model Evaluation Results
Query Type           k=1             k=5             k=10            k=20           
--------------------------------------------------------------------------------
Broad Inline         0.0000          0.0000          0.0000          0.0000
Broad Manual         0.0000          0.0000          0.0000          0.0000
Specific Inline      0.0000          0.0000          0.0000          0.0000
Specific Manual      0.0000          0.0000          0.0000          0.0000


3. SPECTER2

In [81]:
specter2 = Specter2("Title_Abstract")

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

In [77]:
def get_clean_title_abstract_for_specter2(item: dict) -> str:
    title = get_clean_title(item)
    abstract = get_clean_abstract(item)
    return f"{title} + {specter2._tokenizer.sep_token} + {abstract}"

def create_kv_pairs(data: List[dict], key: str) -> dict:
    if key == "title_abstract":
        kv_pairs = {get_clean_title_abstract_for_specter2(record): get_clean_corpusid(record) for record in data}
    elif key == "full_paper":
        kv_pairs = {get_clean_full_paper(record): get_clean_corpusid(record) for record in data}
    elif key == "paragraphs":
        kv_pairs = {}
        for record in data:
            corpusid = get_clean_corpusid(record)
            paragraphs = get_clean_paragraphs(record)
            for paragraph_idx, paragraph in enumerate(paragraphs):
                kv_pairs[paragraph] = (corpusid, paragraph_idx)
    else:
        raise ValueError("Invalid key")
    return kv_pairs

In [78]:
kv_pairs_for_specter2 = create_kv_pairs(corpus_clean_data, "title_abstract")

In [82]:
specter2.create_index(kv_pairs_for_specter2)

Processing document embeddings: 100%|██████████| 1802/1802 [12:14<00:00,  2.45it/s]


In [83]:
specter2_query_set = [query for query in query_data]
for query in tqdm(specter2_query_set):
    query_text = query["query"]
    top_k = specter2.query(query_text, 20)
    query["retrieved"] = top_k

100%|██████████| 597/597 [00:44<00:00, 13.54it/s]


In [84]:
specter2_query_df = pd.DataFrame(specter2_query_set)
specter2_broad_query_df = specter2_query_df[specter2_query_df['specificity'] == 0]
specter2_specific_query_df = specter2_query_df[specter2_query_df['specificity'] == 1]

specter2_broad_inline_query_df = specter2_broad_query_df[specter2_broad_query_df['query_set'].str.contains('inline')]
specter2_broad_manual_query_df = specter2_broad_query_df[specter2_broad_query_df['query_set'].str.contains('manual')]

specter2_specific_inline_query_df = specter2_specific_query_df[specter2_specific_query_df['query_set'].str.contains('inline')]
specter2_specific_manual_query_df = specter2_specific_query_df[specter2_specific_query_df['query_set'].str.contains('manual')]

In [85]:
k_values = [1, 5, 10, 20]
recalls = {
    'specter2_broad_inline': {k: [] for k in k_values},
    'specter2_broad_manual': {k: [] for k in k_values},
    'specter2_specific_inline': {k: [] for k in k_values},
    'specter2_specific_manual': {k: [] for k in k_values}
}

for _, query in specter2_broad_inline_query_df.iterrows():
    for k in k_values:
        result = calculate_recall(query['corpusids'], query['retrieved'], k)
        recalls['specter2_broad_inline'][k].append(result)

for _, query in specter2_broad_manual_query_df.iterrows():
    for k in k_values:
        result = calculate_recall(query['corpusids'], query['retrieved'], k)
        recalls['specter2_broad_manual'][k].append(result)

for _, query in specter2_specific_inline_query_df.iterrows():
    for k in k_values:
        result = calculate_recall(query['corpusids'], query['retrieved'], k)
        recalls['specter2_specific_inline'][k].append(result)

for _, query in specter2_specific_manual_query_df.iterrows():
    for k in k_values:
        result = calculate_recall(query['corpusids'], query['retrieved'], k)
        recalls['specter2_specific_manual'][k].append(result)

results = {
    'specter2_broad_inline': {k: np.mean(recalls['specter2_broad_inline'][k]) for k in k_values},
    'specter2_broad_manual': {k: np.mean(recalls['specter2_broad_manual'][k]) for k in k_values},
    'specter2_specific_inline': {k: np.mean(recalls['specter2_specific_inline'][k]) for k in k_values},
    'specter2_specific_manual': {k: np.mean(recalls['specter2_specific_manual'][k]) for k in k_values}
}

In [88]:
combined_results = {}
# Broad 카테고리 - Inline과 Manual 합침
for k in k_values:
    # 원본 데이터 리스트 합치기
    combined_broad_data = recalls['specter2_broad_inline'][k] + recalls['specter2_broad_manual'][k]
    # 합친 데이터의 평균 계산
    combined_results[f'specter2_broad_combined_{k}'] = np.mean(combined_broad_data) if combined_broad_data else 0.0

# Specific 카테고리 - Inline과 Manual 합침
for k in k_values:
    # 원본 데이터 리스트 합치기
    combined_specific_data = recalls['specter2_specific_inline'][k] + recalls['specter2_specific_manual'][k]
    # 합친 데이터의 평균 계산
    combined_results[f'specter2_specific_combined_{k}'] = np.mean(combined_specific_data) if combined_specific_data else 0.0

In [86]:
print("=" * 80)
print("SPECTER2 Model Evaluation Results")
print("=" * 80)
print(f"{'Query Type':<20} {'k=1':<15} {'k=5':<15} {'k=10':<15} {'k=20':<15}")
print("-" * 80)
print(f"{'Broad Inline':<20} {results['specter2_broad_inline'][1]:.4f}{' '*9} {results['specter2_broad_inline'][5]:.4f}{' '*9} {results['specter2_broad_inline'][10]:.4f}{' '*9} {results['specter2_broad_inline'][20]:.4f}")
print(f"{'Broad Manual':<20} {results['specter2_broad_manual'][1]:.4f}{' '*9} {results['specter2_broad_manual'][5]:.4f}{' '*9} {results['specter2_broad_manual'][10]:.4f}{' '*9} {results['specter2_broad_manual'][20]:.4f}")
print(f"{'Specific Inline':<20} {results['specter2_specific_inline'][1]:.4f}{' '*9} {results['specter2_specific_inline'][5]:.4f}{' '*9} {results['specter2_specific_inline'][10]:.4f}{' '*9} {results['specter2_specific_inline'][20]:.4f}")
print(f"{'Specific Manual':<20} {results['specter2_specific_manual'][1]:.4f}{' '*9} {results['specter2_specific_manual'][5]:.4f}{' '*9} {results['specter2_specific_manual'][10]:.4f}{' '*9} {results['specter2_specific_manual'][20]:.4f}")
print("=" * 80)

SPECTER2 Model Evaluation Results
Query Type           k=1             k=5             k=10            k=20           
--------------------------------------------------------------------------------
Broad Inline         0.1208          0.2267          0.3033          0.3814
Broad Manual         0.0857          0.1286          0.1714          0.2000
Specific Inline      0.1710          0.3377          0.4069          0.4502
Specific Manual      0.1754          0.3460          0.4265          0.5118


In [89]:
print("\nCombined Results (Inline + Manual):")
print("-" * 80)
print(f"{'Query Type':<20} {'k=1':<15} {'k=5':<15} {'k=10':<15} {'k=20':<15}")
print("-" * 80)
print(f"{'Broad Combined':<20} {combined_results['specter2_broad_combined_1']:.4f}{' '*9} {combined_results['specter2_broad_combined_5']:.4f}{' '*9} {combined_results['specter2_broad_combined_10']:.4f}{' '*9} {combined_results['specter2_broad_combined_20']:.4f}")
print(f"{'Specific Combined':<20} {combined_results['specter2_specific_combined_1']:.4f}{' '*9} {combined_results['specter2_specific_combined_5']:.4f}{' '*9} {combined_results['specter2_specific_combined_10']:.4f}{' '*9} {combined_results['specter2_specific_combined_20']:.4f}")
print("=" * 80)


Combined Results (Inline + Manual):
--------------------------------------------------------------------------------
Query Type           k=1             k=5             k=10            k=20           
--------------------------------------------------------------------------------
Broad Combined       0.1129          0.2045          0.2735          0.3404
Specific Combined    0.1731          0.3416          0.4163          0.4796
