In [6]:
!pip install --upgrade pip

[0m

In [8]:
!pip install tqdm sentence-transformers transformers faiss-gpu datasets

Collecting sentence-transformers
  Using cached sentence_transformers-4.1.0-py3-none-any.whl.metadata (13 kB)
Collecting transformers
  Downloading transformers-4.51.3-py3-none-any.whl.metadata (38 kB)
Collecting faiss-gpu
  Using cached faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Collecting datasets
  Using cached datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting huggingface-hub>=0.20.0 (from sentence-transformers)
  Using cached huggingface_hub-0.30.2-py3-none-any.whl.metadata (13 kB)
Collecting regex!=2019.12.17 (from transformers)
  Using cached regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers)
  Downloading tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting safetensors>=0.4.3 (from transformers)
  Downloading safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_

Create Dataset

In [9]:
from datasets import load_dataset

query_data = load_dataset("princeton-nlp/LitSearch", "query", split="full")
corpus_clean_data = load_dataset("princeton-nlp/LitSearch", "corpus_clean", split="full")

README.md:   0%|          | 0.00/1.81k [00:00<?, ?B/s]

full-00000-of-00001.parquet:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

Generating full split:   0%|          | 0/597 [00:00<?, ? examples/s]

full-00000-of-00006.parquet:   0%|          | 0.00/275M [00:00<?, ?B/s]

full-00001-of-00006.parquet:   0%|          | 0.00/196M [00:00<?, ?B/s]

full-00002-of-00006.parquet:   0%|          | 0.00/195M [00:00<?, ?B/s]

full-00003-of-00006.parquet:   0%|          | 0.00/199M [00:00<?, ?B/s]

full-00004-of-00006.parquet:   0%|          | 0.00/194M [00:00<?, ?B/s]

full-00005-of-00006.parquet:   0%|          | 0.00/198M [00:00<?, ?B/s]

Generating full split:   0%|          | 0/64183 [00:00<?, ? examples/s]

In [10]:
from typing import List, Tuple, Any
from datasets import Dataset

def get_clean_corpusid(item: dict) -> int:
    return item['corpusid']

def get_clean_title(item: dict) -> str:
    return item['title']

def get_clean_abstract(item: dict) -> str:
    return item['abstract']

def get_clean_title_abstract(item: dict) -> str:
    title = get_clean_title(item)
    abstract = get_clean_abstract(item)
    return f"Title: {title}\nAbstract: {abstract}"

def get_clean_full_paper(item: dict) -> str:
    return item['full_paper']

def get_clean_paragraph_indices(item: dict) -> List[Tuple[int, int]]:
    text = get_clean_full_paper(item)
    paragraph_indices = []
    paragraph_start = 0
    paragraph_end = 0
    while paragraph_start < len(text):
        paragraph_end = text.find("\n\n", paragraph_start)
        if paragraph_end == -1:
            paragraph_end = len(text)
        paragraph_indices.append((paragraph_start, paragraph_end))
        paragraph_start = paragraph_end + 2
    return paragraph_indices

def get_clean_text(item: dict, start_idx: int, end_idx: int) -> str:
    text = get_clean_full_paper(item)
    assert start_idx >= 0 and end_idx >= 0
    assert start_idx <= end_idx
    assert end_idx <= len(text)
    return text[start_idx:end_idx]

def get_clean_paragraphs(item: dict, min_words: int = 10) -> List[str]:
    paragraph_indices = get_clean_paragraph_indices(item)
    paragraphs = [get_clean_text(item, paragraph_start, paragraph_end) for paragraph_start, paragraph_end in paragraph_indices]
    paragraphs = [paragraph for paragraph in paragraphs if len(paragraph.split()) >= min_words]
    return paragraphs

def get_clean_citations(item: dict) -> List[int]:
    return item['citations']

def get_clean_dict(data: Dataset) -> dict:
    return {get_clean_corpusid(item): item for item in data}

def create_kv_pairs(data: List[dict], key: str) -> dict:
    if key == "title_abstract":
        kv_pairs = {get_clean_title_abstract(record): get_clean_corpusid(record) for record in data}
    elif key == "full_paper":
        kv_pairs = {get_clean_full_paper(record): get_clean_corpusid(record) for record in data}
    elif key == "paragraphs":
        kv_pairs = {}s
        for record in data:
            corpusid = get_clean_corpusid(record)
            paragraphs = get_clean_paragraphs(record)
            for paragraph_idx, paragraph in enumerate(paragraphs):
                kv_pairs[paragraph] = (corpusid, paragraph_idx)
    else:
        raise ValueError("Invalid key")
    return kv_pairs

In [11]:
kv_pairs = create_kv_pairs(corpus_clean_data, "title_abstract")

Model
1. Scibert
2. SciNCL
3. SPECTER2

Parent

In [14]:
from enum import Enum
from typing import Dict, Any, List

import numpy as np

class TextType(Enum):
    KEY = 1
    QUERY = 2

class Retrieval:
    def __init__(self, index_name: str, index_type: str) -> None:
        self.index_name = index_name
        self.index_type = index_type

        self.keys = []
        self.values = []

    def __len__(self) -> int:
        return len(self.keys)

    def _get_embeddings(self, textList: List[str], type: TextType, show_progress_bar: bool = False) -> Any:
        raise NotImplementedError

    def _query(self, query_embedding: np.ndarray, top_k: int = 10) -> List[int]:
        raise NotImplementedError

    def query(self, query_text: str, n: int, return_keys: bool = False) -> List[Any]:
        embedding_query = self._get_embeddings([query_text], TextType.QUERY)
        indices = self._query(embedding_query, n)
        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]
        return results

    def clear(self) -> None:
        self.keys = []
        self.encoded_keys = []
        self.values = []

    def create_index(self, key_value_pairs: Dict[str, int]) -> None:
        if len(self.keys) > 0:
            raise ValueError("Index is not empty. Please create a new index or clear the existing one.")

        for key, value in key_value_pairs.items():
            self.keys.append(key)
            self.values.append(value)


SciBert

In [16]:
import torch
from transformers import AutoTokenizer, AutoModel
import faiss

from tqdm import tqdm
from typing import List, Dict, Any

class SciBert(Retrieval):
    def __init__(self, index_name: str):
        super().__init__(index_name, 'SciBert')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
        self._model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased').to(self.device)
        self.index = None
        self.faiss_index = None

        def clear(self):
            super().clear()
            self.index = None
            self.faiss_index = None

    def _get_embeddings(self, textList: List[str], type: TextType, show_progress_bar: bool = True) -> torch.Tensor:
        batch_size = 16
        embeddings = []

        for i in tqdm(range(0, len(textList), batch_size), desc="Getting embeddings"):
            batch_keys = self.keys[i:i+batch_size]
            encoded = self._tokenizer(
                batch_keys,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512  # 최대 길이 명시적 지정
            ).to(self.device)
            with torch.no_grad():
                outputs = self._model(**encoded)

            batch_embeddings = outputs.last_hidden_state[:,0,:].cpu()
            embeddings.append(batch_embeddings)

            if i % (batch_size * 10) == 0:
                torch.cuda.empty_cache()

        embeddings = torch.cat(embeddings, dim=0)

        return embeddings

    def create_index(self, key_value_pairs: Dict[str, Any]) -> None:
        super().create_index(key_value_pairs)
        self.index = self._get_embeddings(self.keys, TextType.KEY)

        # FAISS 인덱스 생성
        vector_dim = self.index.shape[1]
        index_flat = faiss.IndexFlatIP(vector_dim)
        index_vectors = self.index.numpy()
        faiss.normalize_L2(index_vectors)

        # 인덱스에 벡터 추가
        index_flat.add(index_vectors)
        self.faiss_index = index_flat

    def _query(self, query_embedding: torch.Tensor, top_k: int = 10) -> List[int]:
        if self.faiss_index is None:
            raise ValueError("FAISS index has not been created yet. Call create_index first.")

        query_vector = query_embedding.numpy()
        faiss.normalize_L2(query_vector)
        distances, indices = self.faiss_index.search(query_vector, top_k)

        return indices[0].tolist()

    def query(self, query_text: str, n: int, return_keys: bool = False) -> List[Any]:
        query_embedding = self._get_embeddings([query_text], TextType.QUERY)
        indices = self._query(query_embedding, n)

        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]

        return results


2. SciNCL

In [18]:
class Scincl(Retrieval):
    def __init__(self, index_name:str):
        super().__init__(index_name, 'SciNCL')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
        self._model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased').to(self.device)
        self.index = None
        self.faiss_index = None

    def clear(self):
        super().clear()
        self.index = None
        self.faiss_index = None

    def _get_embeddings(self, textList: List[str], type: TextType, show_progress_bar: bool = True) -> torch.Tensor:
        batch_size = 16
        embeddings = []

        for i in tqdm(range(0, len(textList), batch_size), desc="Getting embeddings"):
            batch_keys = self.keys[i:i+batch_size]
            encoded = self._tokenizer(
                batch_keys,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512  # 최대 길이 명시적 지정
            ).to(self.device)
            with torch.no_grad():
                outputs = self._model(**encoded)

            batch_embeddings = outputs.last_hidden_state[:,0,:].cpu()
            embeddings.append(batch_embeddings)

            if i % (batch_size * 10) == 0:
                torch.cuda.empty_cache()

        embeddings = torch.cat(embeddings, dim=0)

        return embeddings

    def create_index(self, key_value_pairs: Dict[str, Any]) -> None:
        super().create_index(key_value_pairs)
        self.index = self._get_embeddings(self.keys, TextType.KEY)

        # FAISS 인덱스 생성
        vector_dim = self.index.shape[1]
        index_flat = faiss.IndexFlatIP(vector_dim)
        index_vectors = self.index.numpy()
        faiss.normalize_L2(index_vectors)

        # 인덱스에 벡터 추가
        index_flat.add(index_vectors)
        self.faiss_index = index_flat

    def _query(self, query_embedding: torch.Tensor, top_k: int = 10) -> List[int]:
        if self.faiss_index is None:
            raise ValueError("FAISS index has not been created yet. Call create_index first.")

        query_vector = query_embedding.numpy()
        faiss.normalize_L2(query_vector)
        distances, indices = self.faiss_index.search(query_vector, top_k)

        return indices[0].tolist()

    def query(self, query_text: str, n: int, return_keys: bool = False) -> List[Any]:
        query_embedding = self._get_embeddings([query_text], TextType.QUERY)
        indices = self._query(query_embedding, n)

        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]

        return results


3. SPECTER2

In [45]:
!pip install adapters

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting adapters
  Downloading adapters-1.1.1-py3-none-any.whl.metadata (17 kB)
Collecting transformers~=4.48.3 (from adapters)
  Downloading transformers-4.48.3-py3-none-any.whl.metadata (44 kB)
Downloading adapters-1.1.1-py3-none-any.whl (289 kB)
Downloading transformers-4.48.3-py3-none-any.whl (9.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.7/9.7 MB[0m [31m30.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: transformers, adapters
  Attempting uninstall: transformers
    Found existing installation: transformers 4.51.3
    Uninstalling transformers-4.51.3:
      Successfully uninstalled transformers-4.51.3
Successfully installed adapters-1.1.1 transformers-4.48.3
[0m

In [63]:
from typing import List, Any, Tuple, Dict, Optional
from transformers import AutoTokenizer
from adapters import AutoAdapterModel

class Specter2(Retrieval):
    def __init__(self, index_name: str):
        super().__init__(index_name, 'SPECTER2')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # SPECTER2 base model과 tokenizer 로드
        self._tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base')
        self._model = AutoAdapterModel.from_pretrained('allenai/specter2_base').to(self.device)

        # Retrieval용 proximity adapter 로드 (문서 임베딩용)
        self._model.load_adapter("allenai/specter2", source="hf", load_as="proximity")
        # Adhoc query adapter 로드 (쿼리 임베딩용)
        self._model.load_adapter("allenai/specter2_adhoc_query", source="hf", load_as="adhoc_query")

        print(f"Model device check: {next(self._model.parameters()).device}")

        self.index = None
        self.faiss_index = None

    def clear(self):
        super().clear()
        self.index = None
        self.faiss_index = None

    def _get_embeddings(self, textList: List[str], type: TextType, show_progress_bar: bool = True) -> torch.Tensor:

        if type == TextType.KEY:
            self._model.set_active_adapters("proximity")
        else:
            self._model.set_active_adapters("adhoc_query")

        batch_size = 16
        embeddings = []

        for i in tqdm(range(0, len(textList), batch_size), desc="Getting embeddings"):
            batch_texts = textList[i:i+batch_size]
            encoded = self._tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512  # 최대 길이 명시적 지정
            ).to(self.device)
            print(f"Model device check: {encoded.device}")
            
            with torch.no_grad():
                outputs = self._model(**encoded)

            batch_embeddings = outputs.last_hidden_state[:,0,:].cpu()
            embeddings.append(batch_embeddings)

            if i % (batch_size * 10) == 0:
                torch.cuda.empty_cache()

        embeddings = torch.cat(embeddings, dim=0)

        return embeddings

    def create_index(self, key_value_pairs: Dict[str, Any]) -> None:
        super().create_index(key_value_pairs)
        self.index = self._get_embeddings(self.keys, TextType.KEY)

        # FAISS 인덱스 생성
        vector_dim = self.index.shape[1]
        index_flat = faiss.IndexFlatIP(vector_dim)
        index_vectors = self.index.numpy()
        faiss.normalize_L2(index_vectors)

        # 인덱스에 벡터 추가
        index_flat.add(index_vectors)
        self.faiss_index = index_flat

    def _query(self, query_embedding: torch.Tensor, top_k: int = 10) -> List[int]:
        if self.faiss_index is None:
            raise ValueError("FAISS index has not been created yet. Call create_index first.")

        query_vector = query_embedding.numpy()
        faiss.normalize_L2(query_vector)
        distances, indices = self.faiss_index.search(query_vector, top_k)

        return indices[0].tolist()

    def query(self, query_text: str, n: int, return_keys: bool = False) -> List[Any]:
        query_embedding = self._get_embeddings([query_text], TextType.QUERY)
        indices = self._query(query_embedding, n)

        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]

        return results


Evaluate for litSearch datasets

In [53]:
scibert = SciBert("Title_Abstract")

In [54]:
scibert.create_index(kv_pairs)

Getting embeddings:   1%|          | 27/3604 [00:06<13:58,  4.27it/s]


KeyboardInterrupt: 

In [29]:
scibert_query_set = [query for query in query_data]
for query in tqdm(scibert_query_set):
    query_text = query["query"]
    top_k = scibert.query(query_text, 20)
    query["retrieved"] = top_k

  0%|          | 0/597 [00:00<?, ?it/s]
Getting embeddings:   0%|          | 0/1 [00:00<?, ?it/s][A
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00,  4.19it/s][A
  0%|          | 1/597 [00:00<03:37,  2.75it/s]
Getting embeddings:   0%|          | 0/1 [00:00<?, ?it/s][A
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00,  4.18it/s][A
  0%|          | 2/597 [00:00<03:44,  2.65it/s]
Getting embeddings:   0%|          | 0/1 [00:00<?, ?it/s][A
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00,  4.18it/s][A
  1%|          | 3/597 [00:01<03:26,  2.88it/s]
Getting embeddings:   0%|          | 0/1 [00:00<?, ?it/s][A
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00,  4.16it/s][A
  1%|          | 4/597 [00:01<03:35,  2.75it/s]
Getting embeddings:   0%|          | 0/1 [00:00<?, ?it/s][A
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00,  4.19it/s][A
  1%|          | 5/597 [00:01<03:25,  2.89it/s]
Getting embeddings:   0%|          | 0/1 [00:00<?, ?it/s][A
Getting e

Separate specific and broad set

In [32]:
def calculate_recall(corpusids: list, retrieved: list, k: int):
    top_k = retrieved[:k]
    intersection = set(corpusids) & set(top_k)
    return len(intersection) / len(corpusids) if corpusids else 0.0

In [30]:
import pandas as pd

scibert_query_set_df = pd.DataFrame(scibert_query_set)
broad_scibert_query_set_df = scibert_query_set_df[scibert_query_set_df['specificity'] == 0]
specific_scibert_query_set_df = scibert_query_set_df[scibert_query_set_df['specificity'] == 1]

In [33]:
# Specific

scibert_all_recall_at5 = []

for _, query in specific_scibert_query_set_df.iterrows():
    r5 = calculate_recall(query['corpusids'], query['retrieved'], 5)
    scibert_all_recall_at5.append(r5)

scibert_mean_recall_at5 = np.mean(scibert_all_recall_at5)

In [34]:
# Broad

scibert_all_recall_at20 = []

for _, query in broad_scibert_query_set_df.iterrows():
    r20 = calculate_recall(query['corpusids'], query['retrieved'], 20)
    scibert_all_recall_at20.append(r20)

scibert_mean_recall_at20 = np.mean(scibert_all_recall_at20)

In [35]:
print(scibert_mean_recall_at5, scibert_mean_recall_at20)

0.0 0.0


2. SciNCL

In [36]:
scincl = Scincl("Title_Abstract")

In [37]:
scincl.create_index(kv_pairs)

Getting embeddings: 100%|██████████| 3604/3604 [10:33<00:00,  5.69it/s]


In [38]:
scincl_query_set = [query for query in query_data]
for query in tqdm(scincl_query_set):
    query_text = query["query"]
    top_k = scincl.query(query_text, 20)
    query["retrieved"] = top_k

  0%|          | 0/597 [00:00<?, ?it/s]
Getting embeddings:   0%|          | 0/1 [00:00<?, ?it/s][A
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00,  4.14it/s][A
  0%|          | 1/597 [00:00<03:17,  3.02it/s]
Getting embeddings:   0%|          | 0/1 [00:00<?, ?it/s][A
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00,  4.18it/s][A
  0%|          | 2/597 [00:00<03:47,  2.61it/s]
Getting embeddings:   0%|          | 0/1 [00:00<?, ?it/s][A
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00,  4.03it/s][A
  1%|          | 3/597 [00:01<04:10,  2.37it/s]
Getting embeddings:   0%|          | 0/1 [00:00<?, ?it/s][A
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00,  4.18it/s][A
  1%|          | 4/597 [00:01<04:25,  2.23it/s]
Getting embeddings:   0%|          | 0/1 [00:00<?, ?it/s][A
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00,  4.16it/s][A
  1%|          | 5/597 [00:02<04:07,  2.39it/s]
Getting embeddings:   0%|          | 0/1 [00:00<?, ?it/s][A
Getting e

In [39]:
scincl_query_set_df = pd.DataFrame(scincl_query_set)
broad_scincl_query_set_df = scincl_query_set_df[scincl_query_set_df['specificity'] == 0]
specific_scincl_query_set_df = scincl_query_set_df[scincl_query_set_df['specificity'] == 1]

In [40]:
# Specific

scincl_all_recall_at5 = []

for _, query in specific_scincl_query_set_df.iterrows():
    r5 = calculate_recall(query['corpusids'], query['retrieved'], 5)
    scincl_all_recall_at5.append(r5)

scincl_mean_recall_at5 = np.mean(scincl_all_recall_at5)

In [41]:
# Broad

scincl_all_recall_at20 = []

for _, query in broad_scincl_query_set_df.iterrows():
    r20 = calculate_recall(query['corpusids'], query['retrieved'], 20)
    scincl_all_recall_at20.append(r20)

scincl_mean_recall_at20 = np.mean(scincl_all_recall_at20)

In [42]:
print(scincl_mean_recall_at5, scincl_mean_recall_at20)

0.0 0.0


3. SPECTER2

In [87]:
from typing import List, Any, Tuple, Dict, Optional
from transformers import AutoTokenizer
from adapters import AutoAdapterModel

class Specter2(Retrieval):
    def __init__(self, index_name: str):
        super().__init__(index_name, 'SPECTER2')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # SPECTER2 base model과 tokenizer 로드
        self._tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base')
        self._model = AutoAdapterModel.from_pretrained('allenai/specter2_base')

        # Retrieval용 proximity adapter 로드 (문서 임베딩용)
        self._model.load_adapter("allenai/specter2", source="hf", load_as="proximity")
        # Adhoc query adapter 로드 (쿼리 임베딩용)
        self._model.load_adapter("allenai/specter2_adhoc_query", source="hf", load_as="adhoc_query")

        # load_adapter 이후 gpu로 옮기지 않으면 adapter parameter가 CPU에 남아있습니다.
        self._model = self._model.to(self.device)

        self.index = None
        self.faiss_index = None

    def clear(self):
        super().clear()
        self.index = None
        self.faiss_index = None

    def _get_embeddings(self, textList: List[str], type: TextType, show_progress_bar: bool = True) -> torch.Tensor:

        if type == TextType.KEY:
            self._model.set_active_adapters("proximity")
        else:
            self._model.set_active_adapters("adhoc_query")

        batch_size = 16
        embeddings = []

        for i in tqdm(range(0, len(textList), batch_size), desc="Getting embeddings"):
            batch_texts = textList[i:i+batch_size]
            encoded = self._tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512  # 최대 길이 명시적 지정
            ).to(self.device)
            
            with torch.no_grad():
                outputs = self._model(**encoded)

            batch_embeddings = outputs.last_hidden_state[:,0,:].cpu()
            embeddings.append(batch_embeddings)

            if i % (batch_size * 10) == 0:
                torch.cuda.empty_cache()

        embeddings = torch.cat(embeddings, dim=0)

        return embeddings

    def create_index(self, key_value_pairs: Dict[str, Any]) -> None:
        super().create_index(key_value_pairs)
        self.index = self._get_embeddings(self.keys, TextType.KEY)

        # FAISS 인덱스 생성
        vector_dim = self.index.shape[1]
        index_flat = faiss.IndexFlatIP(vector_dim)
        index_vectors = self.index.numpy()
        faiss.normalize_L2(index_vectors)

        # 인덱스에 벡터 추가
        index_flat.add(index_vectors)
        self.faiss_index = index_flat

    def _query(self, query_embedding: torch.Tensor, top_k: int = 10) -> List[int]:
        if self.faiss_index is None:
            raise ValueError("FAISS index has not been created yet. Call create_index first.")

        query_vector = query_embedding.numpy()
        faiss.normalize_L2(query_vector)
        distances, indices = self.faiss_index.search(query_vector, top_k)

        return indices[0].tolist()

    def query(self, query_text: str, n: int, return_keys: bool = False) -> List[Any]:
        query_embedding = self._get_embeddings([query_text], TextType.QUERY)
        indices = self._query(query_embedding, n)

        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]

        return results


In [88]:
specter2 = Specter2("Title_Abstract")

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

In [89]:
specter2.create_index(kv_pairs)

Getting embeddings: 100%|██████████| 3604/3604 [11:04<00:00,  5.43it/s]


In [90]:
specter2_query_set = [query for query in query_data]
for query in tqdm(specter2_query_set):
    query_text = query["query"]
    top_k = specter2.query(query_text, 20)
    query["retrieved"] = top_k

  0%|          | 0/597 [00:00<?, ?it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 60.36it/s]

Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 72.81it/s]
  0%|          | 2/597 [00:00<00:35, 16.96it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 80.61it/s]

Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 77.70it/s]
  1%|          | 4/597 [00:00<00:36, 16.36it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 77.16it/s]

Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 74.81it/s]
  1%|          | 6/597 [00:00<00:39, 15.04it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 75.53it/s]

Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 81.59it/s]
  1%|▏         | 8/597 [00:00<00:45, 12.82it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 76.52it/s]

Getting embeddings: 100%|██████████| 1/1 [00:00<00:00, 78.15it/s]
  2%|▏         | 10/597 [00:00<00:41, 14.23it/s]
Getting embeddings: 100%|██████████| 1/1 [00:00<00:00,

In [91]:
specter2_query_set_df = pd.DataFrame(specter2_query_set)
broad_specter2_query_set_df = specter2_query_set_df[specter2_query_set_df['specificity'] == 0]
specific_specter2_query_set_df = specter2_query_set_df[specter2_query_set_df['specificity'] == 1]

In [92]:
# Specific

specter2_all_recall_at5 = []

for _, query in specific_specter2_query_set_df.iterrows():
    r5 = calculate_recall(query['corpusids'], query['retrieved'], 5)
    specter2_all_recall_at5.append(r5)

specter2_mean_recall_at5 = np.mean(specter2_all_recall_at5)

In [93]:
# Broad

specter2_all_recall_at20 = []

for _, query in broad_specter2_query_set_df.iterrows():
    r20 = calculate_recall(query['corpusids'], query['retrieved'], 20)
    specter2_all_recall_at20.append(r20)

specter2_mean_recall_at20 = np.mean(specter2_all_recall_at20)

In [94]:
print(specter2_mean_recall_at5, specter2_mean_recall_at20)

0.3257918552036199 0.34559139784946236
