In [51]:
import json
import os
import random
import time
from itertools import islice
from pathlib import Path

from Bio import Entrez, Medline  # Accessing and parsing PubMed/NCBI data
from datasets import Dataset
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
from tqdm import tqdm

In [52]:
# Pretrained BioBERT sentence-embedding checkpoint; this is the model the
# triplet loss fine-tunes and that is saved at the end of the notebook.
MODEL_NAME = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
model = SentenceTransformer(MODEL_NAME)

In [53]:
# NCBI credentials are read from the environment so they are never committed with the notebook:
#   NCBI_EMAIL   - contact e-mail (the API will start complaining if it is not set)
#   NCBI_API_KEY - optional API key (allows more queries per second)
# NOTE(review): an API key was previously hardcoded in this cell; treat it as leaked and rotate it.
EMAIL = os.environ.get("NCBI_EMAIL")
API_KEY = os.environ.get("NCBI_API_KEY")

Entrez.email = EMAIL
Entrez.api_key = API_KEY

In [None]:
def get_article(pubmed_id, skip_sleep = False):
    """Fetch one PubMed record and return its title and abstract.

    Args:
        pubmed_id: PubMed identifier (PMID) passed to ``Entrez.efetch``.
        skip_sleep: When False (default), wait 0.34 s before the request to
            throttle back-to-back calls to roughly 3 requests per second.

    Returns:
        ``(title, abstract)``. ``abstract`` is every ``AbstractText`` section
        joined with single spaces (structured abstracts are split into
        several sections), or ``""`` when the record has no abstract.

    Raises:
        ValueError: if the response contains no ``PubmedArticle`` record.
    """
    if not skip_sleep:
        time.sleep(0.34)
    handle = Entrez.efetch(db="pubmed", id=pubmed_id, rettype="abstract", retmode="xml")
    try:
        records = Entrez.read(handle)
    finally:
        # Always release the HTTP handle, even if parsing fails.
        handle.close()

    articles = records["PubmedArticle"]
    if not articles:
        raise ValueError(f"No PubmedArticle record returned for PMID {pubmed_id}")

    # Extract title and abstract
    article = articles[0]["MedlineCitation"]["Article"]
    title = article["ArticleTitle"]
    # Some records have no abstract; structured ones have several AbstractText parts.
    abstract_parts = article.get("Abstract", {}).get("AbstractText", [])
    abstract = " ".join(str(part) for part in abstract_parts)

    return title, abstract

In [55]:
def concat_article(title, abstract):
    """Build the passage text used for embedding: title, one space, abstract."""
    passage = "{} {}".format(title, abstract)
    return passage

In [56]:
from itertools import islice

def filter_ids(generic_ids, ground_truth_ids, limit = 100):
    return list(islice((id_ for id_ in generic_ids if id_ not in ground_truth_ids), limit))

In [57]:
TRAIN_PATH = Path("Final_Notebook/training13b.json")
# Directory holding one "<question id>.json" file of retrieved candidate documents per question.
DOCS_DIR = Path("Final_Notebook/api_retrieval")

max_per_query = 10  # maximum number of (query, positive, negative) triplets per question
SEED = 42  # seeds the positive sampling below so the triplet set is reproducible


def _pmid_from_url(url):
    """Return the last path segment of a ground-truth document URL (the PMID)."""
    return url.rstrip("/").rsplit("/", 1)[-1]


def _load_candidate_docs(q_id):
    """Return the retrieved candidate documents stored for question ``q_id``, or ``[]`` if none."""
    docs_path = DOCS_DIR / f"{q_id}.json"
    if not docs_path.is_file():
        return []
    with open(docs_path, encoding="utf-8") as fh:
        return json.load(fh).get("documents", [])


with open(TRAIN_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

rng = random.Random(SEED)
train_examples = []

for q in tqdm(data["questions"]):
    q_query = q["body"]
    # De-duplicate while keeping order, so seeded sampling is reproducible.
    ground_truth = list(dict.fromkeys(_pmid_from_url(url) for url in q["documents"]))
    ground_truth_set = set(ground_truth)

    # Retrieved candidates that are not ground truth serve as negatives.
    # PMIDs are compared as strings because the ground-truth ids come from URL text.
    negatives = [
        doc for doc in _load_candidate_docs(q["id"])
        if str(doc["pmid"]) not in ground_truth_set
    ][:max_per_query]

    # Check before fetching: every get_article call is a throttled network request.
    if not negatives or not ground_truth:
        continue

    ground_truth_texts = [concat_article(*get_article(pmid)) for pmid in ground_truth]

    for neg_doc in negatives:
        pos_article = rng.choice(ground_truth_texts)
        # Take the text from the negative document itself so it matches its PMID.
        neg_article = concat_article(neg_doc.get("title", ""), neg_doc.get("abstract", ""))
        train_examples.append(InputExample(texts=[q_query, pos_article, neg_article]))



  2%|▏         | 2/100 [00:18<14:54,  9.12s/it]


KeyboardInterrupt: 

In [None]:
BATCH_SIZE = 8

# Fail fast: if the triplet-building cell produced nothing (e.g. it was
# interrupted), the DataLoader would be empty and fine-tuning would be a no-op.
if not train_examples:
    raise ValueError("train_examples is empty - run the triplet-building cell to completion first")
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)

In [None]:
# Triplet objective over (query, positive, negative) InputExamples; the
# distance metric and margin are left at the library defaults.
train_loss = losses.TripletLoss(model)

In [None]:
EPOCHS = 3
WARMUP_STEPS = 100  # learning-rate warm-up length, in training steps
OUTPUT_DIR = "triplet-finetuned-BioBERT"

objectives = [(train_dataloader, train_loss)]
model.fit(train_objectives=objectives, epochs=EPOCHS, warmup_steps=WARMUP_STEPS, show_progress_bar=True)

# Persist the fine-tuned model to OUTPUT_DIR
model.save(OUTPUT_DIR)




Step,Training Loss
