In [11]:
import requests
import time

import ujson
import pandas as pd

from bigbio.dataloader import BigBioConfigHelpers
from tqdm import tqdm, trange
from collections import defaultdict

conhelps = BigBioConfigHelpers()



In [106]:

def query_pmid(pmids, url="http://bern2.korea.ac.kr/pubmed"):
    request_object =  requests.get(url + "/" + ",".join(pmids))
    # print(request_object.status_code)
    # print(request_object.reason)
    return request_object.json()

def check_query_status(pmids, url="http://bern2.korea.ac.kr/pubmed"):
    return requests.get(url + "/" + ",".join(pmids)).status_code == 200


def query_plain(text, url="http://bern2.korea.ac.kr/plain"):
    return requests.post(url, json={"text": text}).json()


def retrieve_pmid_list(pmid_list, chunksize=900, sleep_interval=100, pmids_to_omit=[]):
    all_retrieved_documents = []
    pmid_list = [x for x in pmid_list if x not in pmids_to_omit]
    for i in trange(len(pmid_list) // chunksize + 1):
        pmid_chunk = pmid_list[i * chunksize : (i + 1) * chunksize]
        retrieved_docs = query_pmid(pmid_chunk)
        if len(retrieved_docs) == 0:
            print("Error on PMIDS:", pmid_chunk)
        all_retrieved_documents.extend(retrieved_docs)
        time.sleep(sleep_interval)

    return all_retrieved_documents


def retrieve_full_text_documents(all_full_text_dict, chunksize=20, sleep_interval=10, pmids_to_pull=None):
    all_annotations = []
    chunk_iter = 0
    for pmid, doc in tqdm(all_full_text_dict.items()):
        if pmids_to_pull is not None:
            if pmid not in pmids_to_pull:
                continue
        if chunk_iter == chunksize:
            time.sleep(sleep_interval)
            chunk_iter = 0
        annotations = query_plain(doc)
        if len(annotations) == 0:
            print("Error for PMID:", pmid)
        annotations["document_id"] = pmid
        all_annotations.append(annotations)
        chunk_iter += 1

    return all_annotations



In [13]:

all_pmids = set([])
all_full_text = defaultdict(str)
total_docs = 0


for dataset in tqdm(
    ["medmentions_full", "bc5cdr", "gnormplus", "ncbi_disease", "nlmchem", "nlm_gene"]
):
    data = conhelps.for_config_name(f"{dataset}_bigbio_kb").load_dataset()
    for split in data.keys():
        for doc in data[split]:
            pmid = doc["document_id"]
            if pmid in all_pmids:
                continue

            all_pmids.add(pmid)
            doc_text = " ".join([" ".join(p["text"]) for p in doc["passages"]])
            all_full_text[pmid] = doc_text


# # PlantNorm
# print("Running Plant Norm")
# for subset in ['training','test','development']:
#     with open(f'../../PPRcorpus/corpus/DMCB_plant_{subset}_corpus.txt', 'r', encoding='utf-8', errors='ignore') as g:
#         all_text = g.read()
#         abstracts = all_text.strip().split('\n\n')
#         abstract_lines = [x.split('\n') for x in abstracts]
#         for abs in tqdm(abstract_lines):
#             pmid = abs[0].split('|')[0]
#             if pmid in all_pmids:
#                 continue
#             if len(abs[0].split('|')) == 1:
#                 abs.pop(0)
#             title = abs[0].split('|')[1]
#             abs_text = abs[1].split('|')[1]
#             doc_text = ' '.join([title, abs_text])

#             all_pmids.add(pmid)
#             all_full_text[pmid].add(doc_text)


all_pmids = list(all_pmids)


  0%|          | 0/6 [00:00<?, ?it/s]Found cached dataset medmentions (/nethome/dkartchner3/.cache/huggingface/datasets/bigbio___medmentions/medmentions_full_bigbio_kb/1.0.0/4ed5b6a69d807969022e559198c5a7386b9a978268a558758a090db6b451d6c4)


  0%|          | 0/3 [00:00<?, ?it/s]

 17%|█▋        | 1/6 [00:08<00:40,  8.18s/it]Found cached dataset bc5cdr (/nethome/dkartchner3/.cache/huggingface/datasets/bigbio___bc5cdr/bc5cdr_bigbio_kb/1.0.0/68f03988d9e501c974d9f9987183bf06474858d1318ed0d4e51cfc4584f0f51f)


  0%|          | 0/3 [00:00<?, ?it/s]

 33%|███▎      | 2/6 [00:10<00:18,  4.71s/it]Found cached dataset gnormplus (/nethome/dkartchner3/.cache/huggingface/datasets/bigbio___gnormplus/gnormplus_bigbio_kb/1.0.0/97a2714b58185305591c949b067cea2febfca2447016096c3d08021d84bf7b69)


  0%|          | 0/2 [00:00<?, ?it/s]

 50%|█████     | 3/6 [00:11<00:09,  3.20s/it]Found cached dataset ncbi_disease (/nethome/dkartchner3/.cache/huggingface/datasets/bigbio___ncbi_disease/ncbi_disease_bigbio_kb/1.0.0/5f3bb3f460b7487dc6d28ec539d7d7cd7d717705ff58314672581cab8e1d9957)


  0%|          | 0/3 [00:00<?, ?it/s]

 67%|██████▋   | 4/6 [00:13<00:04,  2.45s/it]Found cached dataset nlmchem (/nethome/dkartchner3/.cache/huggingface/datasets/bigbio___nlmchem/nlmchem_bigbio_kb/1.0.0/66bcefa38a4fe5d4ba1a0993a516040bad028699fbe3ef935f95532596668131)


  0%|          | 0/3 [00:00<?, ?it/s]

 83%|████████▎ | 5/6 [00:14<00:02,  2.15s/it]Found cached dataset nlm_gene (/nethome/dkartchner3/.cache/huggingface/datasets/bigbio___nlm_gene/nlm_gene_bigbio_kb/1.0.0/71526324bb52d82b3917dfc7c9b76f3bac4fb0d86d98c5c2e29951b8cee0e24f)


  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 6/6 [00:15<00:00,  2.63s/it]


In [90]:
def find_lowest_failing_index(pmid_list, min_ind, max_ind):
    if check_query_status(pmid_list[min_ind:max_ind]):
        return -1
    while min_ind < max_ind - 1:
        mid = (max_ind + min_ind) // 2
        if check_query_status(pmid_list[min_ind:mid]):
            min_ind = mid
        else: 
            max_ind = mid

    return min_ind
        


def multielement_binary_search(pmid_list, min_ind=0, ):
    '''
    Find PMIDS that should be omitted from BERN2 search
    '''
    pmids_to_omit=[]
    max_ind = len(pmid_list)
    
    while min_ind < max_ind - 1:
        failure_ind = find_lowest_failing_index(pmid_list, min_ind, max_ind)
        print(failure_ind)
        if failure_ind == -1:
            break
        pmids_to_omit.append(failure_ind)
        min_ind = failure_ind + 1
    return pmids_to_omit

# find_lowest_failing_index(all_pmids, 900, 1000)
inds_to_exclude = multielement_binary_search(all_pmids)

907
1029
1730
1899
2108
2555
3139
3355
3393
3673
3950
4097
4431
4461
4508
4781
4784
5192
5321
5834
6009
6423
6599
6764
6891
6926
7273
7681
7993
8016
-1


In [95]:
pmids_to_omit = [all_pmids[x] for x in inds_to_exclude]
nonexcluded_pmids = [x for i, x in enumerate(all_pmids) if i not in inds_to_exclude]
len(nonexcluded_pmids)

8073

In [57]:
len(out)


In [98]:

pulled_pubmed_docs = retrieve_pmid_list(all_pmids, sleep_interval=0, pmids_to_omit=pmids_to_omit)
with open("../data/bern2_annotations_from_pmids.json", "w") as f:
    f.write(ujson.dumps(pulled_pubmed_docs, indent=2))


In [103]:
pulled_pubmed_docs = pulled_pmids

In [104]:
pulled_pmids = [x['_id'] for x in pulled_pubmed_docs if len(x['annotations']) > 0]
len(pulled_pmids)

7914

In [105]:

unpulled_pmids = [x for x in all_pmids if x not in pulled_pmids]

In [108]:
pulled_full_text = retrieve_full_text_documents(all_full_text, pmids_to_pull=unpulled_pmids, sleep_interval=100)
with open("../data/bern2_annotations_from_full_text.json", "w") as f:
    f.write(ujson.dumps(pulled_full_text, indent=2))


100%|██████████| 8073/8073 [14:45<00:00,  9.12it/s]  
