In [1]:
import datasets
from datasets import load_dataset
# Load the "full" split of three LitSearch configurations via `datasets.load_dataset`:
# the query set, the cleaned corpus, and the S2ORC corpus.
# NOTE(review): the later cells shown here call `datasets.load_dataset` again for
# "corpus_clean" and "query" instead of reusing these variables, and
# `corpus_s2orc_data` is not referenced in the cells shown.
query_data = load_dataset("princeton-nlp/LitSearch", "query", split="full")
corpus_clean_data = load_dataset("princeton-nlp/LitSearch", "corpus_clean", split="full")
corpus_s2orc_data = load_dataset("princeton-nlp/LitSearch", "corpus_s2orc", split="full")


In [4]:
import os
from eval.retrieval.bm25 import BM25
from utils import utils
from eval.retrieval.kv_store import KVStore



In [9]:
import os
from typing import List
import argparse


def get_index_name(args: argparse.Namespace) -> str:
    """Build the index name as ``<dataset basename>.<key>``.

    Args:
        args: Namespace providing string attributes ``dataset_path`` and ``key``.

    Returns:
        The last path component of ``args.dataset_path`` joined to ``args.key``
        with a dot, e.g. ``"LitSearch.title_abstract"`` for
        ``dataset_path="princeton-nlp/LitSearch"`` and ``key="title_abstract"``.
    """
    # os.path.basename returns "" for a path that ends in a separator, which
    # would produce a nameless ".<key>" index; trim trailing separators first.
    separators = os.sep + (os.altsep or "")
    dataset_dir = args.dataset_path.rstrip(separators)
    return os.path.basename(dataset_dir) + "." + args.key

def create_index(args: argparse.Namespace) -> KVStore:
    """Instantiate an empty retrieval index for the requested backend.

    The index is named via ``get_index_name(args)``. The backend module is
    imported only once its branch is selected, so only the chosen backend's
    dependencies are needed.

    Args:
        args: Namespace providing ``index_type`` (one of "bm25", "instructor",
            "e5", "gtr", "grit"), plus ``key`` and ``dataset_path``.
            ``key`` selects the instruction text for the "instructor" and
            "grit" backends and must then be "title_abstract", "full_paper"
            or "paragraphs"; the other backends do not validate it here.

    Returns:
        The newly constructed index object.

    Raises:
        ValueError: "Invalid key" for an unsupported ``key`` with the
            "instructor" or "grit" backend; "Invalid index type" for an
            unknown ``index_type``.
    """
    name = get_index_name(args)
    backend = args.index_type

    if backend == "bm25":
        from eval.retrieval.bm25 import BM25
        return BM25(name)

    if backend == "instructor":
        # Import first, validate the key afterwards (same order as before).
        from eval.retrieval.instructor import Instructor
        # Rows are (key, query instruction, key instruction).
        instructor_prompts = (
            (
                "title_abstract",
                "Represent the research question for retrieving relevant research paper abstracts:",
                "Represent the title and abstract of the research paper for retrieval:",
            ),
            (
                "full_paper",
                "Represent the research question for retrieving relevant research papers:",
                "Represent the research paper for retrieval:",
            ),
            (
                "paragraphs",
                "Represent the research question for retrieving passages from relevant research papers:",
                "Represent the passage from the research paper for retrieval:",
            ),
        )
        for granularity, query_instruction, key_instruction in instructor_prompts:
            if args.key == granularity:
                # Instructor takes the key instruction before the query instruction.
                return Instructor(name, key_instruction, query_instruction)
        raise ValueError("Invalid key")

    if backend == "e5":
        from eval.retrieval.e5 import E5
        return E5(name)

    if backend == "gtr":
        from eval.retrieval.gtr import GTR
        return GTR(name)

    if backend == "grit":
        from eval.retrieval.grit import GRIT
        # Rows are (key, raw instruction).
        grit_prompts = (
            ("title_abstract", "Given a research query, retrieve the title and abstract of the relevant research paper"),
            ("full_paper", "Given a research query, retrieve the relevant research paper"),
            ("paragraphs", "Given a research query, retrieve the passage from the relevant research paper"),
        )
        for granularity, raw_instruction in grit_prompts:
            if args.key == granularity:
                return GRIT(name, raw_instruction)
        raise ValueError("Invalid key")

    raise ValueError("Invalid index type")

def create_kv_pairs(data: List[dict], key: str) -> dict:
    """Map the text to be indexed to an identifier of the paper it came from.

    Args:
        data: Iterable of corpus records, passed one at a time to the
            ``utils.get_clean_*`` helpers.
        key: Text granularity: "title_abstract", "full_paper" or "paragraphs".

    Returns:
        A dict keyed by text. For "title_abstract" and "full_paper" the value
        is the record's clean corpus id; for "paragraphs" it is a
        ``(corpusid, paragraph_idx)`` tuple. Because the text is the dict key,
        a text that occurs more than once keeps only the value written last.

    Raises:
        ValueError: If ``key`` is not one of the three supported values.
    """
    pairs = {}

    if key == "title_abstract":
        for entry in data:
            text = utils.get_clean_title_abstract(entry)
            pairs[text] = utils.get_clean_corpusid(entry)
        return pairs

    if key == "full_paper":
        for entry in data:
            text = utils.get_clean_full_paper(entry)
            pairs[text] = utils.get_clean_corpusid(entry)
        return pairs

    if key == "paragraphs":
        for entry in data:
            paper_id = utils.get_clean_corpusid(entry)
            passages = utils.get_clean_paragraphs(entry)
            pairs.update(
                {passage: (paper_id, position) for position, passage in enumerate(passages)}
            )
        return pairs

    raise ValueError("Invalid key")

In [10]:

# Build a BM25 index over the title+abstract text of the cleaned corpus and save it.
# The Namespace mimics the command-line arguments named in the inline comments.
# NOTE: `args` is rebound with a different set of fields in the retrieval cell further down,
# so re-running this cell after that one requires re-running this assignment too.
args = argparse.Namespace(
    index_type="bm25",  # Simulate the --index_type argument
    key="title_abstract",  # Simulate the --key argument
    dataset_path="princeton-nlp/LitSearch",  # Default value (or you can customize)
    index_root_dir="retrieval_indices"  # Default value (or you can customize)
)


# Same dataset, configuration and split as `corpus_clean_data` loaded in the first cell.
corpus_data = datasets.load_dataset(args.dataset_path, "corpus_clean", split="full")
index = create_index(args)
kv_pairs = create_kv_pairs(corpus_data, args.key)
index.create_index(kv_pairs)

# NOTE(review): `index_name` is not read again in the cells shown; `index.save` is given
# only the root directory.
index_name = get_index_name(args)
index.save(args.index_root_dir)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\demo123\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping tokenizers\punkt.zip.
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\demo123\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping corpora\stopwords.zip.
Creating LitSearch.title_abstract index: 100%|██████████| 57657/57657 [00:00<00:00, 3217035.41it/s]
100%|██████████| 57657/57657 [01:33<00:00, 613.62it/s]


Saving index to retrieval_indices\LitSearch.title_abstract.bm25


In [11]:

from tqdm import tqdm
from utils import utils
from eval.retrieval.kv_store import KVStore
def load_index(index_path: str) -> KVStore:
    """Load a previously saved retrieval index from ``index_path``.

    The backend is chosen from the text after the last "." in the file name
    (``LitSearch.title_abstract.bm25`` selects "bm25"). The matching class is
    imported on demand, constructed with ``None`` placeholders, and its
    ``load`` method is called with the full path.

    Args:
        index_path: Path to the saved index file.

    Returns:
        Whatever the selected backend's ``load(index_path)`` returns.

    Raises:
        ValueError: "Invalid index type" if the suffix is not one of "bm25",
            "instructor", "e5", "gtr" or "grit".
    """
    file_name = os.path.basename(index_path)
    # Text after the final dot; the whole name when there is no dot.
    backend = file_name.rpartition(".")[2]

    if backend == "bm25":
        from eval.retrieval.bm25 import BM25
        return BM25(None).load(index_path)

    if backend == "instructor":
        from eval.retrieval.instructor import Instructor
        return Instructor(None, None, None).load(index_path)

    if backend == "e5":
        from eval.retrieval.e5 import E5
        return E5(None).load(index_path)

    if backend == "gtr":
        from eval.retrieval.gtr import GTR
        return GTR(None).load(index_path)

    if backend == "grit":
        from eval.retrieval.grit import GRIT
        return GRIT(None, None).load(index_path)

    raise ValueError("Invalid index type")

In [12]:
# Settings for the retrieval run. This rebinds `args`, replacing the Namespace used by the
# index-building cell (which carried `index_type` and `key` instead of `index_name`/`top_k`).
# `index_name` matches the file name reported by the index-building cell's "Saving index to" output.
args = argparse.Namespace(
    index_name="LitSearch.title_abstract.bm25",  # Simulate the --index_name argument
    top_k=200,  # Simulate the --top_k argument with a default value
    retrieval_results_root_dir="results/retrieval",  # Default value
    index_root_dir="retrieval_indices",  # Default value
    dataset_path="princeton-nlp/LitSearch"  # Default value
)

In [13]:
# Load the saved index; load_index picks the backend from the suffix after the last "."
# in the file name.
index = load_index(os.path.join(args.index_root_dir, args.index_name))
# Materialise the query split as a list so each record can be annotated in place below.
query_set = [query for query in datasets.load_dataset(args.dataset_path, "query", split="full")]
for query in tqdm(query_set):
    query_text = query["query"]
    # Result of querying the index with the text and args.top_k.
    # NOTE(review): the return shape of index.query is not visible here — confirm it is
    # serialisable by utils.write_json.
    top_k = index.query(query_text, args.top_k)
    # Stored on the record itself, so `query_set` ends up holding the retrieval results.
    query["retrieved"] = top_k

# Write every annotated record to <retrieval_results_root_dir>/<index_name>.jsonl,
# creating the output directory if it does not exist yet.
os.makedirs(args.retrieval_results_root_dir, exist_ok=True)
output_path = os.path.join(args.retrieval_results_root_dir, f"{args.index_name}.jsonl")
utils.write_json(query_set, output_path)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\demo123\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\demo123\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Loading index from retrieval_indices\LitSearch.title_abstract.bm25...


100%|██████████| 597/597 [02:14<00:00,  4.44it/s]

Saved 597 records to results/retrieval\LitSearch.title_abstract.bm25.jsonl



