In [None]:
import json
import os
import random
import time
import urllib.request
from itertools import islice
from pathlib import Path

from Bio import Entrez
from datasets import Dataset
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
# Pretrained sentence-embedding checkpoint that the training cell fine-tunes.
BASE_MODEL_NAME = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
model = SentenceTransformer(BASE_MODEL_NAME)

# Local copy of the training questions (JSON).
TRAIN_PATH = Path("./training13b.json")

# Upper bound on the number of triplets built per question; used as the
# default of create_train_examples().
max_per_query = 2

In [None]:
# NCBI E-utilities credentials. They are read from the environment so that no
# secret is stored in the notebook: export ENTREZ_EMAIL and ENTREZ_API_KEY
# before starting the kernel.
# NOTE(review): the key that used to be hardcoded here has been committed and
# must be treated as leaked - revoke it and generate a new one.
EMAIL = os.environ.get("ENTREZ_EMAIL") or None      # the API starts complaining if no e-mail is given
API_KEY = os.environ.get("ENTREZ_API_KEY") or None  # optional; allows more queries per second

Entrez.email = EMAIL
Entrez.api_key = API_KEY

In [None]:
# Parse the whole training file (UTF-8 JSON) into `data`.
data = json.loads(TRAIN_PATH.read_text(encoding="utf-8"))

In [None]:
# --- Retrieval configuration -------------------------------------------------
# Maximum number of PMIDs requested per search (default of esearch_pmids).
RETMX = 1000
# Where ensure_training_file() downloads the training set from.
TRAIN_DATA_URL  = "https://participants-area.bioasq.org/Tasks/13b/trainingDataset/training13b.json"
# Local file name of the training set.
DATA_FILE = "training13b.json"
# Default batch size of fetch_documents(). It is far larger than RETMX, so in
# practice all PMIDs of one query are sent in a single efetch request.
BATCH = 10_000_000

# Read the training file and keep its question list (empty if the key is absent).
with open(DATA_FILE, encoding="utf-8") as fh:
    data = json.loads(fh.read())
questions = data.get("questions", [])

##############################---Utility Functions---###################################

def ensure_training_file() -> None:
    """Make sure the training file exists locally, downloading it if needed.

    Returns immediately when ``DATA_FILE`` already exists and is non-empty.
    Otherwise ``TRAIN_DATA_URL`` is fetched with the standard library
    (``urllib.request``) and the payload is written to ``DATA_FILE``.

    Raises:
        urllib.error.URLError: on network failure; its subclass
            ``urllib.error.HTTPError`` is raised for HTTP error statuses.
        RuntimeError: if the first 100 bytes of the payload contain ``<html``,
            i.e. a web page was served instead of the JSON file.
    """
    if os.path.exists(DATA_FILE) and os.path.getsize(DATA_FILE) > 0:
        return
    print("Downloading training data …")
    # urlopen() raises HTTPError on 4xx/5xx responses, so no separate status
    # check is required.
    with urllib.request.urlopen(TRAIN_DATA_URL, timeout=30) as resp:
        payload = resp.read()
    # Validate before touching the disk so that a bad download never leaves a
    # non-empty file behind that later runs would accept as valid.
    if b"<html" in payload[:100].lower():
        raise RuntimeError("Downloaded content is HTML, not JSON - check URL/login.")
    with open(DATA_FILE, "wb") as fh:
        fh.write(payload)
    print(f"Saved → {DATA_FILE}")

def esearch_pmids(query: str, retmax: int = RETMX) -> list:
    """
    Perform a PubMed search using both Title/Abstract and MeSH terms.

    The lower-cased query is searched once as a quoted phrase and once per
    whitespace-separated token, in both fields, OR-ed together and restricted
    to records that have an abstract. Results are sorted by relevance.

    Args:
        query: Free-text question or keyword string.
        retmax: Maximum number of PMIDs to request.

    Returns:
        A list of PMIDs (strings). Empty when the query is blank or when all
        three attempts failed.
    """
    # Clean up the query. Double quotes are dropped because an embedded quote
    # would terminate the quoted phrase built below and corrupt the search term.
    phrase = query.lower().replace('"', " ").strip()
    if not phrase:
        # Nothing to search for; avoid sending an empty-phrase query.
        return []
    tokens = phrase.split()
    # Build query parts for the full phrase and the individual tokens. MeSH terms
    # are a standardized vocabulary that improves search precision by capturing
    # the semantic meaning of biomedical concepts, which helps document retrieval.
    parts = [f'"{phrase}"[Title/Abstract]', f'"{phrase}"[MeSH Terms]']
    for t in tokens:
        parts.append(f'{t}[Title/Abstract]')
        parts.append(f'{t}[MeSH Terms]')
    term = f"({' OR '.join(parts)}) AND hasabstract[text]"

    # Retry up to 3 times on failure, backing off 1 s, 2 s, 4 s.
    for attempt in range(3):
        try:
            handle = Entrez.esearch(
                db="pubmed",
                term=term,
                retmax=retmax,
                sort="relevance",
                retmode="xml"
            )
            try:
                result = Entrez.read(handle)
            finally:
                # Release the HTTP connection even if parsing fails.
                handle.close()
            time.sleep(0.1)  # short pause between successive requests
            # Hand back plain strings rather than parser-specific elements.
            return [str(pmid) for pmid in result.get('IdList', [])]
        except Exception as e:
            print(f"Esearch attempt {attempt+1} failed: {e}")
            time.sleep(2 ** attempt)
    return []

#####################################################################################

def fetch_documents(pmids: list[str], batch_size: int = BATCH) -> list[tuple[str,str,str]]:
    """Download title and abstract for the given PMIDs from PubMed.

    The ids are requested in slices of ``batch_size``. Each slice is tried up
    to three times (back-off 1 s, 2 s, 4 s); a slice that still fails is
    skipped after its errors have been printed, so the result may contain
    fewer documents than ids.

    Args:
        pmids: PubMed ids to fetch.
        batch_size: Number of ids sent per efetch request.

    Returns:
        ``(pmid, title, abstract)`` tuples of plain strings. Only entries under
        ``"PubmedArticle"`` are read; a missing title or abstract becomes ``""``
        and multi-part abstracts are joined with single spaces.
    """
    docs: list[tuple[str,str,str]] = []
    for start in range(0, len(pmids), batch_size):
        batch = pmids[start:start+batch_size]
        for attempt in range(3):
            try:
                h = Entrez.efetch(db="pubmed", id=",".join(batch), retmode="xml")
                try:
                    recs = Entrez.read(h)
                finally:
                    # Release the HTTP connection even if parsing fails.
                    h.close()
                time.sleep(0.34)         # pause between successive requests
                break
            except Exception as e:
                print(f"    [efetch] attempt {attempt+1}/3 failed → {e}")
                time.sleep(2 ** attempt)
        else:                            # all retries exhausted
            continue

        for art in recs.get("PubmedArticle", []):
            citation = art["MedlineCitation"]
            art_el = citation["Article"]
            sections = art_el.get("Abstract", {}).get("AbstractText", [])
            # str() turns the parser's element objects into the plain strings
            # promised by the return annotation.
            docs.append((
                str(citation["PMID"]),
                str(art_el.get("ArticleTitle", "")),
                " ".join(str(section) for section in sections),
            ))
    return docs

# Retrieve and cache candidate documents for every question in the training file.
def save_docs() -> None:
    """Search PubMed for each training question and cache the hits on disk.

    For every question in ``DATA_FILE`` one file ``api_retrieval/<id>.json`` is
    written, holding the question id, type, body and the retrieved documents
    (pmid, title, abstract). Questions whose file already exists are skipped,
    so an interrupted run can simply be started again.

    Ctrl-C stops the loop cleanly; the completion message is printed only when
    every question has been processed.
    """
    ensure_training_file()

    # Load input questions
    with open(DATA_FILE, "r", encoding="utf-8") as fh:
        questions = json.load(fh)["questions"]

    # Output directory setup
    out_dir = Path("api_retrieval")
    out_dir.mkdir(exist_ok=True)

    try:
        for q in tqdm(questions, total=len(questions), unit="q"):
            qid, qtype = q["id"], q["type"]
            out_path = out_dir / (qid + ".json")

            # Resume support: this question was handled by an earlier run.
            if out_path.is_file():
                continue

            body = q["body"].strip()
            pmids = esearch_pmids(body)
            docs = fetch_documents(pmids)
            # NOTE: esearch_pmids() returns [] after exhausting its retries, so
            # a question whose search failed is cached with an empty document
            # list and is not retried by later runs.

            record = {
                "id": qid,
                "type": qtype,
                "body": body,
                "documents": [
                    {"pmid": p, "title": t, "abstract": a}
                    for p, t, a in docs
                ]
            }

            # Write to a temporary file and rename it afterwards, so that an
            # interruption never leaves a truncated <id>.json behind that the
            # skip check above would accept as complete.
            tmp_path = out_dir / (qid + ".json.tmp")
            with tmp_path.open("w", encoding="utf-8") as fout:
                fout.write(json.dumps(record))
            tmp_path.replace(out_path)
    except KeyboardInterrupt:
        print("\nInterrupted by user.")
    else:
        print("\n✅ Finished.")

save_docs()

In [None]:
def get_article(pubmed_id: str, skip_sleep: bool = False) -> tuple[str, str]:
    """Fetch title and abstract of a single PubMed record.

    Args:
        pubmed_id: PMID of the article.
        skip_sleep: When False (default) wait 0.1 s before the request.

    Returns:
        ``(title, abstract)`` as plain strings. All sections of a multi-part
        abstract are joined with single spaces. A record without an abstract
        yields ``(title, "")``; if the article itself cannot be extracted from
        the response the problem is printed and ``("", "")`` is returned.

    Errors raised by the request or by parsing the response are not caught
    and propagate to the caller.
    """
    # Fetch the record from PubMed
    if not skip_sleep:
        time.sleep(0.1)
    handle = Entrez.efetch(db="pubmed", id=pubmed_id, rettype="abstract", retmode="xml")
    try:
        records = Entrez.read(handle)
    finally:
        # Release the HTTP connection even if parsing fails.
        handle.close()

    # Extract the article entry and its title
    try:
        article = records['PubmedArticle'][0]['MedlineCitation']['Article']
        title = str(article['ArticleTitle'])
    except Exception as e:
        print(f"An error occurred: {e}")
        print(f"The id is {pubmed_id}")
        print(f"{records}")
        return "", ""

    # Join every abstract section instead of keeping only the first one, the
    # same way fetch_documents() builds its abstracts.
    sections = article.get('Abstract', {}).get('AbstractText', [])
    abstract = " ".join(str(section) for section in sections)

    return title, abstract

In [None]:
def concat_article(title: str, abstract: str) -> str:
    """Return the title followed by a single space and the abstract."""
    return "{} {}".format(title, abstract)

In [None]:

def filter_ids(generic_ids: list[str], ground_truth_ids: list[str], limit: int = 100) -> list[str]:
    """Return the first ``limit`` ids of ``generic_ids`` that are not ground truth.

    The order of ``generic_ids`` is kept and iteration stops as soon as
    ``limit`` ids have been collected.
    """
    def _is_not_ground_truth(candidate):
        return candidate not in ground_truth_ids

    remaining = filter(_is_not_ground_truth, generic_ids)
    return list(islice(remaining, limit))

In [None]:
def create_train_examples(data: dict = data, max_per_query: int = max_per_query):
    """Build (question, positive, negative) triplets for triplet-loss training.

    For every question that has a cached retrieval file
    ``./api_retrieval/<id>.json``:

    * positives are the question's gold documents (``q["documents"]`` URLs),
      whose text is fetched with ``get_article``;
    * negatives are cached documents whose PMID is not a gold PMID, taken in
      cached order, at most ``max_per_query`` of them.

    Each negative is paired with one randomly chosen positive. Questions
    without a cache file, without negatives or without any non-blank positive
    text are skipped.

    Args:
        data: Parsed training file; ``data["questions"]`` is iterated.
        max_per_query: Maximum number of triplets per question.

    Returns:
        A list of ``InputExample`` objects with
        ``texts=[question, positive, negative]``.
    """
    train_examples = []

    for q in tqdm(data["questions"]):
        docs_path = Path("./api_retrieval/" + q["id"] + ".json")

        # Check the cache first: without retrieved candidates there are no
        # negatives, so fetching the gold abstracts would be wasted requests.
        if not docs_path.is_file():
            continue

        with open(docs_path, encoding="utf-8") as f:
            docs = json.load(f)["documents"]

        # The PMID is the last path segment of the document URL. Taking the
        # last segment (ignoring a trailing slash) does not depend on how many
        # segments precede it.
        gold_ids = [url.rstrip("/").rsplit("/", 1)[-1] for url in q["documents"]]
        gold_id_set = set(gold_ids)

        # Keep the complete documents, not just their ids, so that each
        # negative's text comes from the negative itself. Indexing the
        # unfiltered list by position could pick a gold document instead.
        negatives = [doc for doc in docs if doc["pmid"] not in gold_id_set]
        negatives = negatives[:max(max_per_query, 0)]
        if not negatives:
            continue

        # Drop positives that came back blank (failed or empty fetch); an empty
        # list would also make random.choice() raise.
        positives = [
            text
            for text in (concat_article(*get_article(pmid)) for pmid in gold_ids)
            if text.strip()
        ]
        if not positives:
            continue

        for neg_doc in negatives:
            pos_article = random.choice(positives)
            neg_article = concat_article(neg_doc["title"], neg_doc["abstract"])
            train_examples.append(InputExample(texts=[q["body"], pos_article, neg_article]))

    return train_examples

In [None]:
# Build the (question, positive, negative) triplets from the training data,
# then serve them in shuffled mini-batches of 8.
train_examples = create_train_examples(data)
train_dataloader = DataLoader(dataset=train_examples, batch_size=8, shuffle=True)

In [None]:
# Triplet objective over the model being fine-tuned (library default settings).
train_loss = losses.TripletLoss(model)

In [None]:
# Fine-tune for two epochs on the single (dataloader, loss) objective.
train_objectives = [(train_dataloader, train_loss)]
model.fit(train_objectives=train_objectives, epochs=2, warmup_steps=100, show_progress_bar=True)

# Persist the fine-tuned model next to the notebook.
OUTPUT_DIR = "./triplet-finetuned-BioBERT"
model.save(OUTPUT_DIR)
