In [1]:
!pip install tqdm rank_bm25==0.2.2 nltk transformers==4.30.2 numpy==1.23.5 torch==1.13.1 sentence-transformers==2.2.2 datasets==2.20.0

Collecting rank_bm25==0.2.2
  Downloading rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Collecting nltk
  Downloading nltk-3.9.1-py3-none-any.whl.metadata (2.9 kB)
Collecting transformers==4.30.2
  Downloading transformers-4.30.2-py3-none-any.whl.metadata (113 kB)
Collecting numpy==1.23.5
  Downloading numpy-1.23.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)
Collecting torch==1.13.1
  Downloading torch-1.13.1-cp310-cp310-manylinux1_x86_64.whl.metadata (24 kB)
Collecting sentence-transformers==2.2.2
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting datasets==2.20.0
  Downloading datasets-2.20.0-py3-none-any.whl.metadata (19 kB)
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers==4.30.2)
  Downloading huggingface_hub-0.30.2-py3-none-any.whl.metadata (13 kB)
Collecting regex!=2019.12.17 (from transformers==4.30.2)
  Downloading regex-2024.11.6-cp310-cp310-manylinux_2_17_x86

Load data

In [2]:
from datasets import load_dataset

query_data = load_dataset("princeton-nlp/LitSearch", "query", split="full")
corpus_clean_data = load_dataset("princeton-nlp/LitSearch", "corpus_clean", split="full")

Downloading readme:   0%|          | 0.00/1.81k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

Generating full split:   0%|          | 0/597 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/275M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/196M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/195M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/199M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/194M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/198M [00:00<?, ?B/s]

Generating full split:   0%|          | 0/64183 [00:00<?, ? examples/s]

In [3]:
from typing import List, Tuple, Any
from datasets import Dataset

def get_clean_corpusid(item: dict) -> int:
    return item['corpusid']

def get_clean_title(item: dict) -> str:
    return item['title']

def get_clean_abstract(item: dict) -> str:
    return item['abstract']

def get_clean_title_abstract(item: dict) -> str:
    title = get_clean_title(item)
    abstract = get_clean_abstract(item)
    return f"Title: {title}\nAbstract: {abstract}"

def get_clean_full_paper(item: dict) -> str:
    return item['full_paper']

def get_clean_paragraph_indices(item: dict) -> List[Tuple[int, int]]:
    text = get_clean_full_paper(item)
    paragraph_indices = []
    paragraph_start = 0
    paragraph_end = 0
    while paragraph_start < len(text):
        paragraph_end = text.find("\n\n", paragraph_start)
        if paragraph_end == -1:
            paragraph_end = len(text)
        paragraph_indices.append((paragraph_start, paragraph_end))
        paragraph_start = paragraph_end + 2
    return paragraph_indices

def get_clean_text(item: dict, start_idx: int, end_idx: int) -> str:
    text = get_clean_full_paper(item)
    assert start_idx >= 0 and end_idx >= 0
    assert start_idx <= end_idx
    assert end_idx <= len(text)
    return text[start_idx:end_idx]

def get_clean_paragraphs(item: dict, min_words: int = 10) -> List[str]:
    paragraph_indices = get_clean_paragraph_indices(item)
    paragraphs = [get_clean_text(item, paragraph_start, paragraph_end) for paragraph_start, paragraph_end in paragraph_indices]
    paragraphs = [paragraph for paragraph in paragraphs if len(paragraph.split()) >= min_words]
    return paragraphs

def get_clean_citations(item: dict) -> List[int]:
    return item['citations']

def get_clean_dict(data: Dataset) -> dict:
    return {get_clean_corpusid(item): item for item in data}

def create_kv_pairs(data: List[dict], key: str) -> dict:
    if key == "title_abstract":
        kv_pairs = {get_clean_title_abstract(record): get_clean_corpusid(record) for record in data}
    elif key == "full_paper":
        kv_pairs = {get_clean_full_paper(record): get_clean_corpusid(record) for record in data}
    elif key == "paragraphs":
        kv_pairs = {}
        for record in data:
            corpusid = get_clean_corpusid(record)
            paragraphs = get_clean_paragraphs(record)
            for paragraph_idx, paragraph in enumerate(paragraphs):
                kv_pairs[paragraph] = (corpusid, paragraph_idx)
    else:
        raise ValueError("Invalid key")
    return kv_pairs

In [4]:
kv_pairs = create_kv_pairs(corpus_clean_data, "title_abstract")

BM25 model

In [5]:
import os
import pickle
from tqdm import tqdm
from enum import Enum
from typing import List, Tuple, Any

class TextType(Enum):
    KEY = 1
    QUERY = 2

class KVStore:
    def __init__(self, index_name: str, index_type: str) -> None:
        self.index_name = index_name
        self.index_type = index_type

        self.keys = []
        self.encoded_keys = []
        self.values = []

    def __len__(self) -> int:
        return len(self.keys)

    def _encode(self, text: str, type: TextType) -> Any:
        return self._encode_batch([text], type, show_progress_bar=False)[0]
    
    def _encode_batch(self, texts: List[str], type: TextType, show_progress_bar: bool = True) -> List[Any]: 
        raise NotImplementedError
    
    def _query(self, encoded_query: Any, n: int) -> List[int]:
        raise NotImplementedError
    
    def clear(self) -> None:
        self.keys = []
        self.encoded_keys = []
        self.values = []

    def create_index(self, key_value_pairs: List[Tuple[str, Any]]) -> None:
        if len(self.keys) > 0:
            raise ValueError("Index is not empty. Please create a new index or clear the existing one.")
        
        for key, value in tqdm(key_value_pairs.items(), desc=f"Creating {self.index_name} index"):
            self.keys.append(key)
            self.values.append(value)
        self.encoded_keys = self._encode_batch(self.keys, TextType.KEY)

    def query(self, query_text: str, n: int, return_keys: bool = False) -> List[Any]:
        encoded_query = self._encode(query_text, TextType.QUERY)
        indices = self._query(encoded_query, n)
        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]
        return results

    # def save(self, dir_name: str) -> None:
    #     save_dict = {}
    #     for key, value in self.__dict__.items():
    #         if key[0] != "_":
    #             save_dict[key] = value

    #     print(f"Saving index to {os.path.join(dir_name, f'{self.index_name}.{self.index_type}')}")
    #     os.makedirs(dir_name, exist_ok=True)
    #     with open(os.path.join(dir_name, f"{self.index_name}.{self.index_type}"), 'wb') as file:
    #         pickle.dump(save_dict, file, protocol=pickle.HIGHEST_PROTOCOL)
        

    # def load(self, file_path: str) -> None:
    #     if len(self.keys) > 0:
    #         raise ValueError("Index is not empty. Please create a new index or clear the existing one before loading from disk.")
        
    #     print(f"Loading index from {file_path}...")
    #     with open(file_path, 'rb') as file:
    #         pickle_data = pickle.load(file)
        
    #     for key, value in pickle_data.items():
    #         setattr(self, key, value)

In [6]:
from rank_bm25 import BM25Okapi
import nltk
import numpy as np

class BM25(KVStore):
    def __init__(self, index_name: str):
        super().__init__(index_name, 'bm25')

        nltk.download('punkt')
        nltk.download('stopwords')
        nltk.download('punkt_tab')

        self._tokenizer = nltk.word_tokenize
        self._stop_words = set(nltk.corpus.stopwords.words('english'))
        self._stemmer = nltk.stem.PorterStemmer().stem
        self.index = None   # BM25 index

    def _encode_batch(self, texts: List[str], type: TextType, show_progress_bar: bool = True) -> List[str]:
        # lowercase, tokenize, remove stopwords, and stem
        tokens_list = []
        for text in tqdm(texts, disable=not show_progress_bar):
            tokens = self._tokenizer(text.lower())
            tokens = [token for token in tokens if token not in self._stop_words]
            tokens = [self._stemmer(token) for token in tokens]
            tokens_list.append(tokens)
        return tokens_list

    def _query(self, encoded_query: List[str], n: int) -> List[int]:
        top_indices = np.argsort(self.index.get_scores(encoded_query))[::-1][:n].tolist()
        return top_indices

    def clear(self) -> None:
        super().clear()
        self.index = None

    def create_index(self, key_value_pairs: List[Tuple[str, Any]]) -> None:
        super().create_index(key_value_pairs)
        self.index = BM25Okapi(self.encoded_keys)

    # def load(self, dir_name: str) -> None:
    #     super().load(dir_name)
    #     self._tokenizer = nltk.word_tokenize
    #     self._stop_words = set(nltk.corpus.stopwords.words('english'))
    #     self._stemmer = nltk.stem.PorterStemmer().stem
    #     return self


In [7]:
bm25_model = BM25("Title_Abstract_BM25")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


In [8]:
bm25_model.create_index(kv_pairs)

Creating Title_Abstract_BM25 index: 100%|██████████| 57657/57657 [00:00<00:00, 1565583.50it/s]
100%|██████████| 57657/57657 [02:11<00:00, 437.40it/s]


In [9]:
query_set = [query for query in query_data]
for query in tqdm(query_set):
    query_text = query["query"]
    top_k = bm25_model.query(query_text, 20)
    query["retrieved"] = top_k

100%|██████████| 597/597 [02:41<00:00,  3.71it/s]


In [10]:
# how does the element of query_set shape?
query_set[0]

{'query_set': 'inline_acl',
 'query': 'Are there any research papers on methods to compress large-scale language models using task-agnostic knowledge distillation techniques?',
 'specificity': 0,
 'quality': 2,
 'corpusids': [202719327],
 'retrieved': [248227350,
  221995575,
  257038997,
  235294276,
  258212842,
  3643430,
  256461337,
  201670719,
  219176798,
  247741658,
  259137871,
  247596810,
  250390982,
  258865530,
  234482764,
  239016109,
  259370686,
  258960671,
  237563200,
  222290473]}

In [11]:
import pandas as pd

query_set_df = pd.DataFrame(query_set)
broad_query_set_df = query_set_df[query_set_df['specificity'] == 0]
specific_query_set_df = query_set_df[query_set_df['specificity'] == 1]

In [12]:
broad_query_set_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 155 entries, 0 to 595
Data columns (total 6 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   query_set    155 non-null    object
 1   query        155 non-null    object
 2   specificity  155 non-null    int64 
 3   quality      155 non-null    int64 
 4   corpusids    155 non-null    object
 5   retrieved    155 non-null    object
dtypes: int64(2), object(4)
memory usage: 8.5+ KB


In [13]:
def calculate_recall(corpusids: list, retrieved: list, k: int):
    top_k = retrieved[:k]
    intersection = set(corpusids) & set(top_k)
    return len(intersection) / len(corpusids) if corpusids else 0.0

In [14]:
# Broad

all_recall_at20 = []

for _, query in broad_query_set_df.iterrows():
    r20 = calculate_recall(query['corpusids'], query['retrieved'], 20)
    all_recall_at20.append(r20)

mean_recall_at20 = np.mean(all_recall_at20)

In [15]:
mean_recall_at20

0.3993548387096774

In [16]:
# Specific

all_recall_at5 = []

for _, query in specific_query_set_df.iterrows():
    r5 = calculate_recall(query['corpusids'], query['retrieved'], 5)
    all_recall_at5.append(r5)

mean_recall_at5 = np.mean(all_recall_at5)

In [17]:
mean_recall_at5

0.5

SciBert model

In [18]:
from sklearn.preprocessing import normalize

import torch
from transformers import AutoTokenizer, AutoModel

In [55]:
class SciBert(KVStore):
    def __init__(self, index_name: str):
        super().__init__(index_name, 'SciBert')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self._model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', torch_dtype=torch.float16).to(self.device)
        self._tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
        self.index = None

    def _encode_batch(self, texts: List[str], type: TextType, show_progress_bar: bool = True):
        encoded = self._tokenizer(
                    texts,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=512  # 최대 길이 명시적 지정
                ).to(self.device)
        return encoded

    def _query(self, encoded_query: List[str], n: int) -> List[int]:
        chunk_size = 4096  # 메모리 제한에 맞게 조정
        similarities = []
        
        # 쿼리 임베딩을 CPU로 이동 및 FP16 변환
        query_embedding = encoded_query.float().cpu()  # [1, 768]
        
        for i in range(0, len(self.index), chunk_size):
            index_chunk = self.index[i:i+chunk_size].to(self.device)  # [chunk_size, 768]
            sim_chunk = torch.nn.functional.cosine_similarity(
                query_embedding.to(self.device),  # [1, 768]
                index_chunk,                      # [chunk_size, 768]
                dim=1
            )
            similarities.append(sim_chunk.cpu())
        
        full_sim = torch.cat(similarities)
        return full_sim.topk(n).indices.numpy().tolist()

    def clear(self):
        super().clear()
        self.index = None

    def create_index(self, key_value_pairs: List[Tuple[str, Any]]) -> None:
        super().create_index(key_value_pairs)
        batch_size = 16
        with torch.no_grad():
            embeddings = []
            for i in range(0, len(self.keys), batch_size):
                batch_keys = self.keys[i:i+batch_size]
                encoded = self._encode_batch(batch_keys, TextType.KEY)
                outputs = self._model(**encoded)
                emb = outputs.last_hidden_state[:,0,:].detach().cpu()
                embeddings.append(emb)
                del encoded, outputs
                torch.cuda.empty_cache()
            self.index = torch.cat(embeddings).pin_memory()

    def query(self, query_text: str, n: int, return_keys: bool = False) -> List[Any]:
        encoded_query = self._encode_batch([query_text], TextType.QUERY)
    
        with torch.no_grad(), torch.cuda.amp.autocast():
            outputs = self._model(**encoded_query)
            query_embedding = outputs.last_hidden_state[:,0,:].detach()
            
        indices = self._query(query_embedding, n)
        
        # 중간 텐서 메모리 해제
        del encoded_query, outputs, query_embedding
        torch.cuda.empty_cache()
        if return_keys:
            results = [(self.keys[i], self.values[i]) for i in indices]
        else:
            results = [self.values[i] for i in indices]
        return results


In [56]:
scibert_model = SciBert("Title_Abstract_SciBert")

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [57]:
scibert_model.create_index(kv_pairs)

Creating Title_Abstract_SciBert index: 100%|██████████| 57657/57657 [00:00<00:00, 1266038.02it/s]


In [58]:
scibert_query_set = [query for query in query_data]
for query in tqdm(scibert_query_set):
    query_text = query["query"]
    top_k = scibert_model.query(query_text, 20)
    query["retrieved"] = top_k

100%|██████████| 597/597 [00:11<00:00, 51.40it/s]


In [59]:
scibert_query_set

[{'query_set': 'inline_acl',
  'query': 'Are there any research papers on methods to compress large-scale language models using task-agnostic knowledge distillation techniques?',
  'specificity': 0,
  'quality': 2,
  'corpusids': [202719327],
  'retrieved': [243865348,
   220045412,
   9592788,
   219310310,
   52191945,
   429630,
   202770954,
   12385862,
   7490195,
   28804091,
   6580561,
   252715492,
   196217527,
   236486298,
   18036516,
   256461002,
   259370838,
   20943355,
   2357627,
   235623770]},
 {'query_set': 'inline_acl',
  'query': 'Are there any resources available for translating Tunisian Arabic dialect that contain both manually translated comments by native speakers and additional data augmented through methods like segmentation at stop words level?',
  'specificity': 1,
  'quality': 2,
  'corpusids': [227231792],
  'retrieved': [243865348,
   219310310,
   9592788,
   429630,
   6580561,
   28804091,
   7490195,
   52191945,
   36916465,
   202770954,
   98

In [61]:
scibert_query_set_df = pd.DataFrame(scibert_query_set)
broad_scibert_query_set_df = scibert_query_set_df[scibert_query_set_df['specificity'] == 0]
specific_scibert_query_set_df = scibert_query_set_df[scibert_query_set_df['specificity'] == 1]

In [62]:
# Specific

scibert_all_recall_at5 = []

for _, query in specific_scibert_query_set_df.iterrows():
    r5 = calculate_recall(query['corpusids'], query['retrieved'], 5)
    scibert_all_recall_at5.append(r5)

scibert_mean_recall_at5 = np.mean(scibert_all_recall_at5)

In [63]:
# Broad

scibert_all_recall_at20 = []

for _, query in broad_scibert_query_set_df.iterrows():
    r20 = calculate_recall(query['corpusids'], query['retrieved'], 20)
    scibert_all_recall_at20.append(r20)

scibert_mean_recall_at20 = np.mean(scibert_all_recall_at20)

In [64]:
scibert_mean_recall_at5

0.0

In [65]:
scibert_mean_recall_at20

0.0