In [1]:
# !pip install rank_bm25 langchain transformers torch faiss-cpu
from langchain.text_splitter import RecursiveCharacterTextSplitter
from multiprocessing import Pool
from rank_bm25 import BM25Okapi
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

import faiss
import json
import numpy as np
import os
import pandas as pd
import re
import torch

In [2]:
def construct_qid2query(fout='./data/json_format/qid2query.json'):
    """Write a ``{qid: question}`` JSON map covering both question sources.

    Questions are read from ``./mirage/mirage_benchmark.json`` (keyed as
    ``"<dataset>:<question key>"``) and from ``./medq/medq-30.xlsx`` (keyed as
    ``"medq30:<ID column>"``).

    Args:
        fout: Path of the JSON file to write. The parent directory must exist.
    """
    # Context managers make sure the handles are closed even if parsing fails.
    with open('./mirage/mirage_benchmark.json', encoding='utf-8') as fin:
        mirage = json.load(fin)
    medq = pd.read_excel('./medq/medq-30.xlsx')

    questions = {}
    for dataset, entries in mirage.items():
        for q, entry in entries.items():
            questions[f'{dataset}:{q}'] = entry['question']

    # iterrows() is kept on purpose: it yields rows of Python-level values,
    # which stay JSON-serialisable.
    for _, row in medq.iterrows():
        questions[f"medq30:{row['ID']}"] = row['Question']

    with open(fout, 'w', encoding='utf-8') as out:
        json.dump(questions, out, indent=2)

In [3]:
def construct_qid2answer(fout='./data/json_format/qid2answer.json'):
    """Write a ``{qid: {'answer': ..., 'options': ...}}`` JSON map.

    Entries from ``./mirage/mirage_benchmark.json`` carry ``'answer'`` and,
    when the source entry has one, ``'options'``. Rows from
    ``./medq/medq-30.xlsx`` carry only ``'answer'`` and are keyed as
    ``"medq30:<ID column>"``.

    Args:
        fout: Path of the JSON file to write. The parent directory must exist.
    """
    # Context managers make sure the handles are closed even if parsing fails.
    with open('./mirage/mirage_benchmark.json', encoding='utf-8') as fin:
        mirage = json.load(fin)
    medq = pd.read_excel('./medq/medq-30.xlsx')

    answers = {}
    for dataset, entries in mirage.items():
        for q, entry in entries.items():
            answer = {'answer': entry['answer']}
            if 'options' in entry:
                answer['options'] = entry['options']
            answers[f'{dataset}:{q}'] = answer

    # iterrows() is kept on purpose: it yields rows of Python-level values,
    # which stay JSON-serialisable.
    for _, row in medq.iterrows():
        answers[f"medq30:{row['ID']}"] = {'answer': row['Answer']}

    with open(fout, 'w', encoding='utf-8') as out:
        json.dump(answers, out, indent=2)

In [4]:
def construct_all_docs_content(fout='./data/json_format/content_map_all.json'):
    """Write a ``{doc_id: text}`` JSON map for every PubMed and textbook document.

    PubMed entries are read from the 38 files
    ``./medcpt/pubmed_embeds/pubmed_chunk_{0..37}.json``; each document's text
    is its ``'t'`` and ``'a'`` fields joined by a single space.
    Textbook entries are the ``'content'`` field of every line of every
    ``*.jsonl`` file in ``./data/textbooks/chunk/`` and are keyed
    ``textbooks-<running index>``.

    Args:
        fout: Path of the JSON file to write. The parent directory must exist.
    """
    def get_abstract_content(pmid, pmid2content):
        # Unknown ids map to an empty string rather than raising.
        if str(pmid) not in pmid2content:
            return ''
        raw_content = pmid2content[str(pmid)]
        return ' '.join([raw_content['t'], raw_content['a']])

    pmid2content = {}
    for i in tqdm(range(38)):
        with open(f'./medcpt/pubmed_embeds/pubmed_chunk_{i}.json',
                  encoding='utf-8') as fin:
            pmid2content.update(json.load(fin))

    textbooks = []
    chunk_dir = './data/textbooks/chunk/'
    # os.listdir() returns names in arbitrary order; sort so that the
    # positional 'textbooks-<i>' ids are reproducible across runs and machines.
    chunk_files = sorted(
        t for t in os.listdir(chunk_dir)
        if t.endswith('.jsonl')
    )
    for f in chunk_files:
        with open(os.path.join(chunk_dir, f), encoding='utf-8') as fin:
            for line in fin:
                line = line.strip()
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                if line:
                    textbooks.append(json.loads(line)['content'])

    content_map = {}
    for pmid in pmid2content:
        content_map[str(pmid)] = get_abstract_content(pmid, pmid2content)
    for i, text in enumerate(textbooks):
        content_map[f'textbooks-{i}'] = text

    with open(fout, 'w', encoding='utf-8') as out:
        json.dump(content_map, out, indent=2)

In [5]:
def construct_qid2dense_retrieval(fout='./data/json_format/qid2dense_retrieval_k64.json'):
    """Run dense retrieval for every query and write the top-64 hits as JSON.

    Steps:
      1. Build a 768-dimensional inner-product FAISS index from the 38
         pre-computed arrays ``./medcpt/pubmed_embeds/embeds_chunk_{i}.npy``;
         the matching ids come from ``pmids_chunk_{i}.json``.
      2. Encode every ``textbooks*`` entry of
         ``./data/json_format/content_map_all.json`` with
         ``ncbi/MedCPT-Article-Encoder`` and append it to the same index.
      3. Encode every query of ``./data/json_format/qid2query.json`` with
         ``ncbi/MedCPT-Query-Encoder`` and search the index with ``k=64``.

    The output maps each qid to a dict with keys ``'qid'``, ``'question'``,
    ``'relevance_scores'`` and ``'doc_ids'``.

    Args:
        fout: Path of the JSON file to write. The parent directory must exist.
    """
    doc_ids = []
    index = faiss.IndexFlatIP(768)

    # --- 1. pre-computed PubMed embeddings -------------------------------
    for i in tqdm(range(38)):
        index.add(np.load(f'./medcpt/pubmed_embeds/embeds_chunk_{i}.npy'))
        with open(f'./medcpt/pubmed_embeds/pmids_chunk_{i}.json',
                  encoding='utf-8') as fin:
            doc_ids += json.load(fin)

    # --- 2. textbook chunks, embedded here --------------------------------
    with open('./data/json_format/content_map_all.json',
              encoding='utf-8') as fin:
        content_map = json.load(fin)
    textbooks_id = [k for k in content_map if k.startswith('textbooks')]
    textbooks = [content_map[i] for i in textbooks_id]
    del content_map  # only the textbook subset is needed from here on
    model = AutoModel.from_pretrained("ncbi/MedCPT-Article-Encoder")
    tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Article-Encoder")

    batch_size = 100
    for i in tqdm(range(0, len(textbooks), batch_size)):
        batch = textbooks[i: i+batch_size]
        with torch.no_grad():
            # tokenize the textbook chunks
            encoded = tokenizer(
                batch,
                truncation=True,
                padding=True,
                return_tensors='pt',
                max_length=512,
            )

            # encode the chunks (use the [CLS] last hidden states as the representations)
            textbooks_embeds = model(**encoded).last_hidden_state[:, 0, :]
        # Batches are added in order, so index positions stay aligned with
        # the order of `textbooks_id` appended to `doc_ids` below.
        index.add(np.ascontiguousarray(
            textbooks_embeds.cpu().numpy(), dtype=np.float32))
    doc_ids += textbooks_id

    # --- 3. query encoding + search ----------------------------------------
    with open('./data/json_format/qid2query.json', encoding='utf-8') as fin:
        qid2query = json.load(fin)
    question_ids = list(qid2query)
    queries = [qid2query[k] for k in question_ids]

    model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder")
    tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")

    batch_size = 64
    retrieval_results = {}
    for i in tqdm(range(0, len(queries), batch_size)):
        batch = queries[i:i+batch_size]
        batch_qids = question_ids[i:i+batch_size]
        with torch.no_grad():
            # tokenize the queries
            encoded = tokenizer(
                batch,
                truncation=True,
                padding=True,
                return_tensors='pt',
                max_length=64,
            )

            # encode the queries (use the [CLS] last hidden states as the representations)
            embeds = model(**encoded).last_hidden_state[:, 0, :]
        # Hand FAISS an explicit contiguous float32 numpy array instead of
        # relying on implicit conversion of a torch tensor.
        query_embeds = np.ascontiguousarray(
            embeds.cpu().numpy(), dtype=np.float32)
        scores_batch, inds_batch = index.search(query_embeds, k=64)
        for qid, scores, inds in zip(batch_qids, scores_batch, inds_batch):
            retrieval_results[qid] = {
                'qid': qid,
                'question': qid2query[qid],
                'relevance_scores': scores.tolist(),
                'doc_ids': [doc_ids[ind] for ind in inds],
            }

    with open(fout, 'w', encoding='utf-8') as out:
        json.dump(retrieval_results, out, indent=2)

In [9]:
def construct_docs_content(fout='./data/json_format/docs_content.json'):
    """Write the subset of the content map that dense retrieval actually hit.

    Reads ``./data/json_format/content_map_all.json`` and
    ``./data/json_format/qid2dense_retrieval_k64.json``, collects every id
    listed under any query's ``'doc_ids'``, prints how many distinct ids there
    are, and writes ``{doc_id: text}`` for exactly those ids.

    Args:
        fout: Path of the JSON file to write. The parent directory must exist.

    Raises:
        KeyError: If a retrieved id is missing from the content map.
    """
    with open('./data/json_format/content_map_all.json',
              encoding='utf-8') as fin:
        all_docs_content = json.load(fin)
    with open('./data/json_format/qid2dense_retrieval_k64.json',
              encoding='utf-8') as fin:
        dense_retrieval = json.load(fin)

    keys = set()
    for qid in dense_retrieval:
        # In-place update avoids rebuilding the whole set on every query.
        keys.update(dense_retrieval[qid]['doc_ids'])
    print(len(keys))

    # Sorted so the output file is identical from run to run (set iteration
    # order for strings varies between interpreter sessions).
    docs_content = {k: all_docs_content[k] for k in sorted(keys)}

    with open(fout, 'w', encoding='utf-8') as out:
        json.dump(docs_content, out, indent=2)

In [11]:
# Module-level state shared with the BM25 re-ranking cell below.
# Files are opened with context managers so the handles are closed promptly.
with open('./data/json_format/qid2dense_retrieval_k64.json',
          encoding='utf-8') as _fin:
    dense_retrieval = json.load(_fin)
with open('./data/json_format/docs_content.json', encoding='utf-8') as _fin:
    docs_content = json.load(_fin)

# Reverse map text -> doc id. Documents with identical text collapse onto a
# single id here (the last one iterated wins), so this map is lossy.
content2id = {}
for k in docs_content:
    content2id[docs_content[k]] = k

In [12]:
def bm25_rerank(qid):
    """Re-order one query's dense-retrieval candidates by BM25 score.

    Uses the module-level ``dense_retrieval`` (for the query text and its
    candidate ``'doc_ids'``) and ``docs_content`` (for each candidate's text).
    Both query and documents are tokenised by splitting on single spaces.

    Args:
        qid: Key into ``dense_retrieval``.

    Returns:
        The same candidate ids, ordered from highest to lowest BM25 score.
        An empty candidate list yields ``[]``.
    """
    qa_data = dense_retrieval[qid]
    doc_ids = qa_data['doc_ids']
    if not doc_ids:
        # BM25Okapi cannot be built on an empty corpus (it averages document
        # length over the corpus size).
        return []

    tokenized_corpus = [docs_content[k].split(' ') for k in doc_ids]
    bm25 = BM25Okapi(tokenized_corpus)

    # Rank the ids themselves instead of ranking the texts and mapping them
    # back through a text -> id dict: that reverse lookup is ambiguous when
    # two candidates share identical text and would return the same id twice
    # while dropping the other.
    return bm25.get_top_n(
        qa_data['question'].split(' '), doc_ids, n=len(doc_ids))

def construct_qid2sparse_rerank(fout='./data/json_format/qid2sparse_rerank.json',
                                processes=64):
    """BM25 re-rank every query's candidates in parallel and write the result.

    Applies ``bm25_rerank`` to every key of the module-level
    ``dense_retrieval`` using a multiprocessing pool and writes
    ``{qid: [doc_id, ...]}`` as JSON.

    Args:
        fout: Path of the JSON file to write. The parent directory must exist.
        processes: Number of worker processes for the pool (default 64, the
            value previously hard-coded).
    """
    qids = list(dense_retrieval.keys())
    with Pool(processes) as p:
        # Pool.map preserves input order, so results line up with `qids`.
        sparse_rerank = p.map(bm25_rerank, qids)
    results = dict(zip(qids, sparse_rerank))
    with open(fout, 'w', encoding='utf-8') as out:
        json.dump(results, out, indent=2)

In [14]:
def construct_qid2key_info(fout='./data/json_format/qid2key_info.json'):
    """Write a ``{qid: PMID}`` JSON map for benchmark entries that have one.

    Reads ``./mirage/mirage_benchmark.json`` and keeps, for every entry that
    contains a ``'PMID'`` field, that field's value under the key
    ``"<dataset>:<question key>"``. Entries without ``'PMID'`` are skipped.

    Args:
        fout: Path of the JSON file to write. The parent directory must exist.
    """
    with open('./mirage/mirage_benchmark.json', encoding='utf-8') as fin:
        mirage = json.load(fin)

    qid2pmid = {
        f'{dataset}:{q}': entry['PMID']
        for dataset, entries in mirage.items()
        for q, entry in entries.items()
        if 'PMID' in entry
    }

    with open(fout, 'w', encoding='utf-8') as out:
        json.dump(qid2pmid, out, indent=2)