In [None]:
import os
import gc
import ast
import time
import random
import itertools
import requests
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from xml.etree import ElementTree as ET

In [None]:
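# For each MeSH disease term, count how many of the analysed PMIDs are returned by a
# PubMed E-utilities search for the term's label. This is slow (at least one API
# request plus a 0.5-1.0 s pause per term), so it only runs when RE_CALCULATION is True.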
def load_mesh_label():
    """Return a mapping from MeSH disease ID to its label.

    Reads the annotated MeSH disease-leaf table and pairs the ``mesh_id``
    column with the ``label`` column (a later row wins on a duplicate ID).
    """
    annotation = pd.read_csv("../data/fig2-3/MeSH/mesh_disease_leaves_w_annotation.csv")
    return {mesh_id: label for mesh_id, label in zip(annotation["mesh_id"], annotation["label"])}

def load_pmids_pubtator3(chunk=1):
    """Return the PMIDs stored in one similarity-score chunk, as strings.

    Parameters
    ----------
    chunk : int, default 1
        Which ``All_MeSH_diseases_pmid_bert_corrs_chunk{chunk}.parquet`` file
        to read. This was previously hard-coded to 1; the default keeps that.

    Returns
    -------
    numpy.ndarray
        The ``pmid`` column cast to ``str``.
    """
    path = f"../data/fig2-3/All_MeSH_diseases_pmid_bert_corrs_chunk{chunk}.parquet"
    # Only the pmid column is needed; skipping the (many) score columns saves
    # time and memory.
    df = pd.read_parquet(path, columns=["pmid"])
    return df["pmid"].astype(str).values


def get_pubmed_counts(mesh_id, mesh_id2name, pmid2use, timeout=60):
    """Count how many PMIDs in ``pmid2use`` a PubMed search for a MeSH label returns.

    The label ``mesh_id2name[mesh_id]`` is sent to E-utilities ESearch. The
    function first requests the total hit count, then pages through the PMID
    list and intersects it with ``pmid2use``.

    Parameters
    ----------
    mesh_id : str
        Key into ``mesh_id2name``.
    mesh_id2name : dict
        Mapping from MeSH ID to the search term sent to PubMed.
    pmid2use : array-like
        PMIDs to intersect with; compared as strings.
    timeout : float, default 60
        Per-request timeout in seconds; without one a stalled connection
        would block the loop forever.

    Returns
    -------
    int or None
        Size of the intersection, or ``None`` (after printing the status code)
        if any request does not return HTTP 200.

    Raises
    ------
    KeyError
        If ``mesh_id`` is not in ``mesh_id2name``.
    """
    # Random pause between calls, presumably to respect NCBI's request-rate limit.
    time.sleep(random.uniform(0.5, 1.0))

    esearch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    query = mesh_id2name[mesh_id]

    # First, get the total number of hits. Passing the term through `params`
    # lets requests URL-encode it. Interpolating it into the URL by hand broke
    # labels containing '&', '#', '+', etc.
    response = requests.get(
        esearch_url,
        params={"db": "pubmed", "term": query, "rettype": "count"},
        timeout=timeout,
    )
    if response.status_code != 200:
        print(f"[{mesh_id}] Error: {response.status_code}")
        return None
    total_hits = int(ET.fromstring(response.text).find("Count").text)

    # Second, page through the PMIDs.
    # NOTE(review): per the NCBI E-utilities docs, ESearch returns at most the
    # first 10,000 PubMed records of a query. Counts for terms with more hits
    # are therefore likely truncated. Confirm this, and switch to EDirect or
    # date-sliced queries if exact counts matter.
    retmax = 100000
    total_pmids = set()
    for retstart in range(0, total_hits, retmax):
        response = requests.get(
            esearch_url,
            params={"db": "pubmed", "term": query, "retstart": retstart, "retmax": retmax},
            timeout=timeout,
        )
        if response.status_code != 200:
            print(f"[{mesh_id}] Error: {response.status_code}")
            return None
        root = ET.fromstring(response.text)
        total_pmids.update(id_elem.text for id_elem in root.findall(".//Id"))

    # Third, intersect with the PMIDs under analysis (compared as strings).
    return len(set(np.asarray(pmid2use).astype(str)) & total_pmids)

# Set to True to re-query the PubMed API for every MeSH term and overwrite the
# saved .npy file. This is slow: at least one request plus a 0.5-1.0 s pause per term.
RE_CALCULATION = False
if RE_CALCULATION:
    mesh_id2name = load_mesh_label()
    pmid2use = load_pmids_pubtator3()

    pubmed_api_counts = []
    for mesh_id in tqdm(mesh_id2name.keys()):
        # get_pubmed_counts needs the label mapping and the PMID set. The previous
        # call passed only mesh_id and raised TypeError as soon as this branch ran.
        count = get_pubmed_counts(mesh_id, mesh_id2name, pmid2use)
        pubmed_api_counts.append(count)

    # Failed lookups (None) are dropped, so positions in the saved array no
    # longer line up one-to-one with the keys of mesh_id2name.
    n_failed = sum(c is None for c in pubmed_api_counts)
    if n_failed:
        print(f"{n_failed} MeSH terms failed and were excluded")
    pubmed_api_counts = np.array([int(x) for x in pubmed_api_counts if x is not None])
    np.save("../data/fig2-3/pubmed_api_counts.npy", pubmed_api_counts)

In [None]:
# ==================== CONFIGURATION ====================
def configure_plot_style():
    """Apply the seaborn "ticks" theme with the top/right spines hidden.

    Then override tick, label and title font sizes and the default line width
    in matplotlib's rcParams.
    """
    sns.set_theme(
        style="ticks",
        rc={"axes.spines.right": False, "axes.spines.top": False},
    )
    font_and_line_sizes = {
        "xtick.labelsize": 14,
        "ytick.labelsize": 13,
        "axes.labelsize": 16,
        "axes.titlesize": 16,
        "lines.linewidth": 1.5,
    }
    mpl.rcParams.update(font_and_line_sizes)


# ==================== DATA LOADING ====================
def load_pmids():
    """Read the PubTator3/BioREX PMID-to-MeSH table.

    The ``from_mesh`` column holds the string repr of a Python list. Each
    entry is turned back into a real list with ``ast.literal_eval``.
    """
    csv_path = "../data/fig2-3/PubTator3/Pubtator3_BioREX_pmid_mesh.csv"
    pmid_mesh = pd.read_csv(csv_path)
    pmid_mesh["from_mesh"] = pmid_mesh["from_mesh"].map(ast.literal_eval)
    return pmid_mesh


def load_similarity_scores(num_chunks=3):
    """Load the per-chunk BERT similarity tables and join them column-wise.

    The result has ``pmid`` (taken from chunk 1) as its first column. After it
    come the non-``pmid`` columns of every chunk, in chunk order.

    Parameters
    ----------
    num_chunks : int, default 3
        Number of ``..._chunk{i}.parquet`` files to read, for i = 1..num_chunks.

    Returns
    -------
    pandas.DataFrame

    Raises
    ------
    ValueError
        If a later chunk has a ``pmid`` column whose values or order differ
        from chunk 1. The column-wise concat would otherwise silently pair
        scores with the wrong articles.
    """
    all_chunks = []
    reference_pmids = None
    for chunk in range(1, num_chunks + 1):
        df = pd.read_parquet(f"../data/fig2-3/All_MeSH_diseases_pmid_bert_corrs_chunk{chunk}.parquet")
        if chunk == 1:
            reference_pmids = df["pmid"].to_numpy()
            all_chunks.append(df[["pmid"]])
        elif "pmid" in df.columns and not np.array_equal(df["pmid"].to_numpy(), reference_pmids):
            raise ValueError(f"chunk{chunk}: pmid column is not aligned with chunk1")
        all_chunks.append(df.drop(columns=["pmid"], errors='ignore'))
        # Chunks are large; release each one before reading the next.
        del df
        gc.collect()
    return pd.concat(all_chunks, axis=1)


def load_mesh_categories():
    """Load the annotated MeSH disease-leaf table with parsed tree categories.

    Each string in ``tree_categories`` is parsed with ``ast.literal_eval``.
    Non-string values (already-parsed lists, or NaN for missing entries) are
    left unchanged.

    Returns
    -------
    pandas.DataFrame
    """
    df = pd.read_csv("../data/fig2-3/MeSH/mesh_disease_leaves_w_annotation.csv")
    # Parse element by element rather than probing only row 0. The old probe
    # raised IndexError on an empty table. It also skipped parsing entirely if
    # the first row was missing (NaN) while later rows were strings.
    df["tree_categories"] = df["tree_categories"].map(
        lambda value: ast.literal_eval(value) if isinstance(value, str) else value
    )
    return df


def load_category_mapping():
    """Return ``{mesh_id: name}`` read from the large-category MeSH table."""
    categories = pd.read_csv("../data/fig2-3/MeSH/MeSH_large_categories.csv")
    id_to_name = {}
    for cat_id, cat_name in zip(categories["mesh_id"], categories["name"]):
        id_to_name[cat_id] = cat_name
    return id_to_name


# ==================== PROCESSING ====================
def create_mesh_one_hot(pmids_df, mesh2use):
    """Build an article x MeSH binary indicator matrix from PubTator annotations.

    Parameters
    ----------
    pmids_df : pandas.DataFrame
        Must have a ``from_mesh`` column whose entries are iterables of
        annotation strings. Only entries of the form ``"MESH:<id>"`` are used.
    mesh2use : sequence of str
        MeSH IDs (without the ``MESH:`` prefix) defining the matrix columns.

    Returns
    -------
    numpy.ndarray of int, shape (len(pmids_df), len(mesh2use))
        Row ``r`` corresponds to the ``r``-th row of ``pmids_df`` by position.
        An entry is 1 when that row mentions the column's MeSH ID.
    """
    mesh2idx = {mesh: idx for idx, mesh in enumerate(mesh2use)}
    mesh_oh = np.zeros((len(pmids_df), len(mesh2use)), dtype=int)

    # Enumerate positions instead of using the iterrows() index label. The label
    # equals the row position only for a default RangeIndex; any other index
    # would write to the wrong row or raise IndexError.
    for row_pos, annotations in enumerate(tqdm(pmids_df["from_mesh"], total=len(pmids_df))):
        for mesh in annotations:
            if mesh.startswith("MESH:"):
                mesh = mesh.split(":")[1]
                if mesh in mesh2idx:
                    mesh_oh[row_pos, mesh2idx[mesh]] = 1
    return mesh_oh


def compute_correlations(pmid_corrs, mesh_oh, mesh2use, rng=None):
    """Collect scores for MeSH-tagged articles and a matched random sample of untagged ones.

    For each MeSH ID with at least one tagged article, this collects the scores
    of all tagged articles (positives). It also collects the scores of an
    equally sized random sample of untagged articles (negatives), drawn
    without replacement.

    Parameters
    ----------
    pmid_corrs : pandas.DataFrame
        One score column per MeSH ID. Rows correspond by position to the rows
        of ``mesh_oh``.
    mesh_oh : numpy.ndarray, shape (n_articles, len(mesh2use))
        Binary indicator matrix; column ``i`` belongs to ``mesh2use[i]``.
    mesh2use : sequence of str
        MeSH IDs, also used as column names in ``pmid_corrs``.
    rng : numpy.random.Generator, optional
        Source of randomness for negative sampling. Defaults to the global
        ``numpy.random`` state, which was the previous behaviour. Pass a
        seeded Generator for reproducible results.

    Returns
    -------
    tuple of numpy.ndarray
        ``(corrs_pos, corrs_neg, mesh_ids)``, all the same length.
        ``mesh_ids`` labels each position with its MeSH ID.

    Raises
    ------
    ValueError
        From the sampler, if a MeSH ID has more tagged than untagged articles.
    """
    sampler = np.random if rng is None else rng
    corrs_pos, corrs_neg, mesh_ids = [], [], []

    for i, mesh_id in enumerate(tqdm(mesh2use)):
        tagged = mesh_oh[:, i]
        if tagged.sum() > 0:
            pos_idx = np.flatnonzero(tagged == 1)
            neg_idx = sampler.choice(np.flatnonzero(tagged == 0), len(pos_idx), replace=False)
            # Index by position rather than by row label (.loc), so the lookup
            # stays aligned with mesh_oh even when pmid_corrs has a non-default index.
            scores = pmid_corrs[mesh_id].to_numpy()
            corrs_pos.extend(scores[pos_idx])
            corrs_neg.extend(scores[neg_idx])
            mesh_ids.extend([mesh_id] * len(pos_idx))
    return np.array(corrs_pos), np.array(corrs_neg), np.array(mesh_ids)


# ==================== PUBMED API ====================
def get_pubmed_counts(mesh_id, mesh_id2name, pmid2use, timeout=60):
    """Count the PMIDs in ``pmid2use`` that a PubMed search for a MeSH label returns.

    Best-effort. If ``mesh_id`` has no non-empty label, the function returns
    ``None`` without printing. Any other failure (HTTP error, malformed XML,
    network error) is printed and also yields ``None``, so a long loop over
    MeSH terms keeps going.

    Parameters
    ----------
    mesh_id : str
        Key into ``mesh_id2name``.
    mesh_id2name : dict
        Mapping from MeSH ID to the search term sent to PubMed.
    pmid2use : array-like
        PMIDs to intersect with; compared as strings.
    timeout : float, default 60
        Per-request timeout in seconds passed to ``requests.get``.

    Returns
    -------
    int or None
    """
    # Random pause between calls, presumably to respect NCBI's request-rate limit.
    time.sleep(random.uniform(0.5, 1.0))
    query = mesh_id2name.get(mesh_id, "")
    if not query:
        return None

    esearch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    try:
        # `params` makes requests URL-encode the term. Interpolating it into the
        # URL by hand broke labels containing '&', '#', '+', etc.
        response = requests.get(
            esearch_url,
            params={"db": "pubmed", "term": query, "rettype": "count"},
            timeout=timeout,
        )
        # Report HTTP errors directly instead of as a confusing XML parse error.
        response.raise_for_status()
        total_hits = int(ET.fromstring(response.text).find("Count").text)

        # NOTE(review): per the NCBI E-utilities docs, ESearch returns at most the
        # first 10,000 PubMed records of a query. Counts for terms with more hits
        # are therefore likely truncated; confirm if exact counts matter.
        total_pmids = set()
        retmax = 100000
        for retstart in range(0, total_hits, retmax):
            response = requests.get(
                esearch_url,
                params={"db": "pubmed", "term": query, "retstart": retstart, "retmax": retmax},
                timeout=timeout,
            )
            response.raise_for_status()
            root = ET.fromstring(response.text)
            total_pmids.update(id_elem.text for id_elem in root.findall(".//Id"))

        return len(set(np.asarray(pmid2use).astype(str)) & total_pmids)
    except Exception as e:
        print(f"Error for {mesh_id}: {e}")
        return None


def fetch_pubmed_counts(mesh2use, mesh_id2name, pmid2use, save_path):
    """Return per-MeSH PubMed-API article counts, using a ``.npy`` cache when present.

    On a cache miss, ``get_pubmed_counts`` is called for every MeSH ID. Failed
    lookups (``None``) are dropped, so the result is NOT positionally aligned
    with ``mesh2use`` when any lookup fails.

    Parameters
    ----------
    mesh2use : iterable of str
        MeSH IDs to query.
    mesh_id2name : dict
        Passed through to ``get_pubmed_counts``.
    pmid2use : array-like
        Passed through to ``get_pubmed_counts``.
    save_path : str or os.PathLike
        Cache file. ``np.save`` appends ``.npy`` when it is missing, so the same
        normalised path is used for both the existence check and the save.

    Returns
    -------
    numpy.ndarray of int
    """
    cache_path = os.fspath(save_path)
    # np.save silently appends ".npy". Without this, the existence check looked
    # for a different file than the one written, and the cache never hit.
    if not cache_path.endswith(".npy"):
        cache_path += ".npy"

    if os.path.exists(cache_path):
        print(f"Loading cached PubMed counts from: {cache_path}")
        return np.load(cache_path)

    print("Fetching PubMed counts via API...")
    raw_counts = [get_pubmed_counts(mesh_id, mesh_id2name, pmid2use) for mesh_id in tqdm(mesh2use)]
    counts = np.array([int(x) for x in raw_counts if x is not None], dtype=int)

    n_failed = len(raw_counts) - len(counts)
    if n_failed:
        print(f"Warning: {n_failed} of {len(raw_counts)} PubMed lookups failed and were dropped.")
    if raw_counts and len(counts) == 0:
        # Every lookup failed (e.g. no network). Caching an empty array would
        # make every later run silently plot nothing.
        print("No PubMed counts retrieved; not writing cache.")
        return counts

    np.save(cache_path, counts)
    print(f"Saved PubMed counts to: {cache_path}")
    return counts


# ==================== PLOTTING ====================
def plot_article_count_comparison(mesh_tag_counts, pubmed_counts, bert_counts, output_path):
    """Draw a log-scale boxplot comparing per-MeSH article counts from three detection methods.

    The three count arrays are stacked into a long-form frame labelled
    "MeSH-tag", "PubMed-API" and "BERT-based" (in that order). The figure is
    saved to ``output_path`` at 300 dpi, displayed, then closed.
    """
    groups = [
        ("MeSH-tag", mesh_tag_counts),
        ("PubMed-API", pubmed_counts),
        ("BERT-based", bert_counts),
    ]
    plt_df = pd.DataFrame({
        "count": np.concatenate([values for _, values in groups]),
        "type": [label for label, values in groups for _ in range(len(values))],
    })

    fig, ax = plt.subplots(figsize=(5, 6))
    sns.boxplot(data=plt_df, x="type", y="count", palette=["gray", "green", "red"], ax=ax, width=0.5)
    # NOTE(review): zero counts cannot be placed on a log axis; check how they
    # are rendered before interpreting the lower whiskers.
    ax.set_yscale("log")
    ax.set(xlabel="", ylabel="Number of detected articles")
    ax.grid(True, axis="y")
    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.show()
    plt.close()

In [None]:
# ==================== MAIN ====================
configure_plot_style()

# Similarity score above which an article counts as BERT-detected for a MeSH term.
BERT_SIM_THRESHOLD = 0.2

pmids_df = load_pmids()
sim_df = load_similarity_scores()
mesh_category_df = load_mesh_categories()

# load_similarity_scores puts "pmid" first; every other column is a MeSH score.
# Checked rather than assumed, so a reordered file cannot turn pmid into a "MeSH term".
assert sim_df.columns[0] == "pmid", "expected 'pmid' as the first column of sim_df"
mesh2use = sim_df.columns[1:]
pmid2use = sim_df["pmid"].values
# Reorder the annotation table to sim_df's row order (create_mesh_one_hot relies on it).
pmids_df = pmids_df.set_index("pmid").loc[pmid2use].reset_index()
# Duplicate PMIDs in the PubTator3 table would add extra rows and inflate MeSH-tag counts.
assert len(pmids_df) == len(pmid2use), "duplicate pmids in the PubTator3 table"

mesh_oh = create_mesh_one_hot(pmids_df, mesh2use)

# Mapping
mesh_id2name = dict(zip(mesh_category_df["mesh_id"], mesh_category_df["label"]))
all_categories = sorted(set(itertools.chain(*mesh_category_df["tree_categories"].values)) - {"C22"})
cat2mesh = {
    cat: mesh_category_df[mesh_category_df["tree_categories"].apply(lambda x: cat in x)]["mesh_id"].values
    for cat in all_categories
}

# Visualization
pubmed_counts = fetch_pubmed_counts(
    mesh2use, mesh_id2name, pmid2use, "../data/fig2-3/all_mesh_pubmed_api_counts.npy"
)
plot_article_count_comparison(
    mesh_tag_counts=mesh_oh.sum(axis=0),
    pubmed_counts=pubmed_counts,
    bert_counts=(sim_df.iloc[:, 1:].values > BERT_SIM_THRESHOLD).sum(axis=0),
    output_path="../data/fig2-3/fig2.all_mesh_number_of_detected_articles.png",
)