In [2]:

"""
Initialization
"""
# imports
import os
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from collections import defaultdict
from graphviz import Digraph
import numpy as np
import re
import seaborn as sns
import sys
from matplotlib.colors import LinearSegmentedColormap
import importlib, util
PROJECT_PATH = "/Users/antata/Library/CloudStorage/OneDrive-BaylorCollegeofMedicine/text-mining/categories"
SFIG_PATH = "/Users/antata/Library/CloudStorage/OneDrive-BaylorCollegeofMedicine/text-mining/manuscript/sfigs"
os.chdir(PROJECT_PATH)
importlib.reload(util)
from util import CATEGORY_NAMES, COLOR_TEMPLATE, CORSIV_PROBE_LIST, CONTROLS, CATEGORIES, read_in_probes, calculate_points, plot_enrichment, breakdown, export_paper


# Load the EPIC and HM450 array manifests (tab-separated, no header row).
epic = pd.read_csv("../humanData/database/EPIC.hg38.txt", sep="\t", header=None)
hm450 = pd.read_csv("../humanData/database/HM450.hg38.txt", sep="\t", header=None)

# Collect the fourth column (positional index 3) of each manifest as a set.
# assumes that column holds the probe identifiers — TODO confirm against the manifest files
epic_probe_list = set(epic.iloc[:, 3])
hm450_probe_list = set(hm450.iloc[:, 3])

# Every identifier present on at least one of the two arrays.
illumina = epic_probe_list | hm450_probe_list

# Global matplotlib styling applied to the figures drawn after this cell.
plt.rcParams.update({
    'font.family': 'Helvetica',
    'font.size': 16,
    'axes.linewidth': 2,      # thicker outer box
    'xtick.major.width': 2,   # thicker x-axis ticks
    'ytick.major.width': 2,   # thicker y-axis ticks
})



In [2]:
# Read the archived (pre-filter) probe table of every category. Two categories
# are stored under a different file name, so map them before building the path.
dfs = [
    pd.read_csv(
        "probe/archive/before_filter1000/"
        + {"obesity": "anthropometric", "urogenital": "reproductive"}.get(cat, cat)
        + "_all_probes.csv"
    )
    for cat in CATEGORY_NAMES
]

# Stack all categories and keep each (paper, probe) pair once.
df = pd.concat(dfs).drop_duplicates(subset=["pmcid", "probeId"])

# Attach each paper's category label from the cleaned study list (left join:
# papers missing from the list keep a NaN category).
ref = pd.read_csv("pubmed_search/all_studies_cleaned.csv")[["PMCID", "Category"]]
df = df.merge(ref, how="left", left_on="pmcid", right_on="PMCID")[["Category", "pmcid", "probeId"]]

# Number of distinct probes reported by each paper.
grouped_df = (
    df.groupby('pmcid')['probeId']
    .nunique()
    .reset_index()
    .rename(columns={'probeId': 'unique_probe_count'})
)


In [None]:


# Histogram of the number of distinct probes per paper, over all papers,
# using 20 equal-width bins.
probe_counts_per_paper = grouped_df['unique_probe_count']

plt.figure()
plt.hist(probe_counts_per_paper, bins=20, color='skyblue', edgecolor='black')
plt.title('Distribution of Papers by the Number of Probes Reported')
plt.xlabel('Number of Probes Reported')
plt.ylabel('Number of Papers')
plt.tight_layout()
plt.show()

# Saving to disk is currently disabled:
# plt.savefig(f"{SFIG_PATH}/supp_table_size_all.jpeg", dpi=300, bbox_inches='tight')
# plt.close()


In [None]:
# Zoomed-in histogram: only papers reporting fewer than 10,000 probes,
# in bins of 100 probes, with a dashed marker at 1,000 probes.
upper_limit = 10000
bin_width = 100
tmp = grouped_df.loc[grouped_df['unique_probe_count'] < upper_limit]
bin_edges = range(0, int(tmp['unique_probe_count'].max()) + bin_width, bin_width)

plt.figure(figsize=(6, 6))
plt.hist(tmp['unique_probe_count'], bins=bin_edges, color='skyblue', edgecolor='black')
plt.axvline(x=1000, linestyle='--', color='red', alpha=0.7)

# Labels and title
plt.title('Distribution of Papers by the Number of Probes Reported', fontsize=24, pad=20)
plt.xlabel('Number of Probes Reported', fontsize=24)
plt.ylabel('Number of Papers', fontsize=24)

# plt.yscale('log')
plt.xlim(0, upper_limit)

plt.tight_layout()
plt.show()
# Saving to disk is currently disabled:
# plt.savefig(f"{SFIG_PATH}/supp_table_size_zoomedin.jpeg", dpi=300, bbox_inches='tight')
# plt.close()

In [None]:
# Supplementary figure: one scatter panel per region metric, CoRSIV value on the
# x-axis against the control value on the y-axis, points coloured by how many
# regions share the same (x, y) pair.

# Load the table in this cell so it runs on a fresh kernel (the read used to be
# commented out, leaving `control_info` dependent on leftover kernel state).
# NOTE(review): sheet "S5" is taken from the previously commented-out line — confirm it is still the right sheet.
control_info = pd.read_excel("../manuscript/supplementary_tables.xlsx", sheet_name="S5")
control_info = control_info[control_info["CoRSIV Probe Count"] > 0]

metrics = ["Region Size (bp)", "Probe Count", "CpG Count", "TSS Count", "Gene Body Count", "TES Count"]
fig, axes = plt.subplots(2, 3, figsize=(8, 6), gridspec_kw={'hspace': 0.2, 'wspace':0.6})
axes = axes.flatten()

# One colour scale shared by every panel and by the colourbar; counts above
# vmax saturate at the darkest colour.
norm = plt.Normalize(vmin=1, vmax=1000)

for ax, m in zip(axes, metrics):
    data = pd.DataFrame({
        'x': control_info[f'CoRSIV {m}'],
        'y': control_info[f'Control {m}']
    })
    # Number of rows sharing this exact (x, y) pair.
    data['frequency'] = data.groupby(['x', 'y'])['x'].transform('count')
    ax.scatter(data['x'], data['y'],
               c=data['frequency'],
               cmap='Blues',
               norm=norm,
               s=50,
               edgecolor='grey',
               alpha=0.8)
    ax.set_aspect('equal')
    # Use the same upper limit on both axes so the panel is square in data units.
    max_val = max(ax.get_xlim()[1], ax.get_ylim()[1])
    ax.set_xlim(0, max_val)
    ax.set_ylim(0, max_val)
    ax.set_title(m, fontsize=16)
    # Copy the y tick positions onto the x-axis so both axes show identical ticks.
    ax.set_xticks(ax.get_yticks())
    ax.set_yticks(ax.get_yticks())
    # Start slightly below zero so points lying on an axis are not cut in half.
    ax.set_xlim(-1, max_val)
    ax.set_ylim(-1, max_val)

# Shared colourbar in its own axes on the right-hand side.
plt.subplots_adjust(right=0.88)
# Named `color_mappable` (not `sm`) so it cannot shadow the statsmodels alias used in other cells.
color_mappable = plt.cm.ScalarMappable(cmap='Blues', norm=norm)
color_mappable.set_array([])
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
cbar = fig.colorbar(color_mappable, cax=cbar_ax, label='Number of Regions')
cbar.ax.set_ylabel('Number of Regions', fontsize=20)
cbar.ax.tick_params(labelsize=16)
# Linear ticks every 200; the top tick is labelled ">1000" because the scale saturates there.
tick_values = list(range(200, 1200, 200))
cbar.set_ticks(tick_values)
cbar.set_ticklabels([str(v) if v < 1000 else ">1000" for v in tick_values])

# Shared y-axis rule and label for the whole grid.
x = 0.03
fig.add_artist(plt.Line2D([x, x], [0.15,0.85], transform=fig.transFigure, color='black', linestyle='-', linewidth=2))
fig.text(x-0.08, 0.5, 'Control', va='center', rotation='vertical', fontsize=24)

# Shared x-axis rule and label for the whole grid.
y = 0.05
fig.add_artist(plt.Line2D([0.12, 0.88], [y, y], transform=fig.transFigure, color='black', linestyle='-', linewidth=2))
fig.text(0.5, y-0.08, "CoRSIV", ha='center', rotation='horizontal', fontsize=24)
# plt.tight_layout()
# bbox_inches='tight' keeps the labels placed at negative figure coordinates inside the saved image.
plt.savefig(f"{SFIG_PATH}/control_metrics.jpeg", dpi=300, bbox_inches='tight')



In [57]:
# Figure 3A-F: decay plots for each category

# Build one {probeId: count} mapping per category, in CATEGORY_NAMES order.
cat_probes_dict = []
for cat in CATEGORY_NAMES:
    if cat != "cancer":
        cat_probes_dict.append(read_in_probes(cat))
        continue
    # Cancer is counted directly from its backup probe table rather than via read_in_probes().
    df = pd.read_csv(f"probe/{cat}_all_probes_backup.csv")
    probe_list = df["probeId"].to_list()
    c = dict(Counter(probe_list))
    cat_probes_dict.append(c)

# Only the categories at these positions of CATEGORY_NAMES are plotted here.
for i in (1, 2, 4, 8, 9):
    category_name = CATEGORY_NAMES[i]
    # NOTE(review): output_path is built but never used — plot_enrichment receives output=None.
    output_path = f"{SFIG_PATH}/{category_name}_decay_enrichment.svg"
    l1, l2, l3, p, p2, _ = calculate_points(cat_probes_dict[i], category_name)
    # Only the first plotted panel (index 1) carries the y label and the legend.
    show_y_label = show_legend = (i == 1)
    paper, r = plot_enrichment(
        [l1, l2, l3, p, p2],
        category_name,
        i,
        output=None,
        show_figure=None,
        show_y_label=show_y_label,
        format="svg",
        show_legend=show_legend,
        export_all=True,
    )

In [None]:
current_category = "cancer"
target_idx = CATEGORY_NAMES.index(current_category)

# Parse the MeSH tree file: each well-formed line is "<term>;<tree code>".
# mesh_ttoc maps a term to the set of all tree codes listed for it.
mesh_ttoc = defaultdict(set) #term:code
file_path = '../humanData/database/mtrees2024.txt'
with open(file_path, 'r') as tree_file:
    for raw_line in tree_file:
        fields = raw_line.strip().split(';')
        if len(fields) != 2:
            # Lines that do not split into exactly two fields are only reported.
            print(fields)
            continue
        mesh_name, tree_code = fields
        mesh_ttoc[mesh_name].add(tree_code)

# Reverse lookup: tree code -> term.
mesh_ctot = {}
for mesh_name, tree_codes in mesh_ttoc.items():
    for tree_code in tree_codes:
        mesh_ctot[tree_code] = mesh_name

def starts_with_any(given_string, string_list):
    """Return True if `given_string` starts with at least one prefix in `string_list`.

    `string_list` may be any iterable of strings; an empty iterable yields False.
    """
    return any(given_string.startswith(prefix) for prefix in string_list)
# For each keyword of the selected category, collect every MeSH term that has at
# least one tree code starting with one of the keyword's own tree codes.
# presumably each entry of CATEGORIES[target_idx] is a MeSH term name — verify in util
neuro_mesh_tree = {}
keywords = CATEGORIES[target_idx]
for kw in keywords:
    # Look the keyword up once, with .get(): indexing the defaultdict with an
    # unknown keyword would insert a new key while mesh_ttoc is being iterated
    # and raise "dictionary changed size during iteration".
    kw_codes = mesh_ttoc.get(kw, set())
    neuro_mesh_tree[kw] = {
        mesh_name
        for mesh_name, codes in mesh_ttoc.items()
        if any(starts_with_any(code, kw_codes) for code in codes)
    }
def filter_mesh_list(input, mesh_tree=None):
    """Return True if `input` is a member of any value collection of the MeSH tree.

    Parameters
    ----------
    input : the term to look up (parameter name kept for backward compatibility
        even though it shadows the builtin).
    mesh_tree : optional mapping whose values are collections of terms. When
        omitted, the module-level `neuro_mesh_tree` is used, exactly as before.
    """
    tree = neuro_mesh_tree if mesh_tree is None else mesh_tree
    return any(input in members for members in tree.values())
# Load the probe table of the selected category and turn its
# "Filtered Mesh Term" column into a Python list per row.
# The metabolic table lives under a different file name, and the cancer and
# metabolic tables are parsed with eval() while the others are split on "|".
# NOTE(review): eval() executes whatever text is stored in the CSV cell;
# ast.literal_eval would be the safe equivalent if the cells are plain list literals.
if current_category != "metabolic":
    neuro_df = pd.read_csv(f"probe/{current_category}_all_probes.csv")
    if current_category != "cancer":
        neuro_df["Filtered Mesh Term"] = neuro_df["Filtered Mesh Term"].apply(lambda x: [term.strip() for term in x.split("|")])
    else:
        neuro_df["Filtered Mesh Term"] = neuro_df["Filtered Mesh Term"].apply(eval)
else:
    neuro_df = pd.read_csv(f"probe/metabolic_diseases_all_probes.csv")  
    neuro_df["Filtered Mesh Term"] = neuro_df["Filtered Mesh Term"].apply(eval)
# Disabled per-paper filters, kept for reference:
# df = pd.read_csv("probe_main_table/cancer_main_probes.csv")
# probes_to_keep = df[df["pmcid"] == "PMC10275808"]["probeId"].tolist()
# neuro_df = neuro_df[(~(neuro_df["pmcid"] == "PMC10275808")) | (neuro_df["probeId"].isin(probes_to_keep))]

# neuro_df = neuro_df[~neuro_df["pmcid"].isin(["PMC4222689", "PMC4913906"])]

# One row per paper, so each paper contributes its term list exactly once.
temp1 = neuro_df.drop_duplicates(subset="pmcid")
mesh_count_by_study = defaultdict(int)
mesh_terms = list(temp1["Filtered Mesh Term"])
pmcids = list(temp1["pmcid"])
term_pmcid_map = defaultdict(set)
# Count papers per term and remember which papers carry each term.
for i, m in enumerate(mesh_terms):
    for t in m:
        mesh_count_by_study[t] += 1
        term_pmcid_map[t].add(pmcids[i])
# Keep only terms found in neuro_mesh_tree, then add one extra entry for the
# category as a whole (its count is the number of distinct papers).
mesh_count_by_study = {key:count for key, count in mesh_count_by_study.items() if filter_mesh_list(key)}
mesh_count_by_study[current_category.capitalize()] = len(set(neuro_df["pmcid"]))
# Sorted by (term, count) pairs in descending order, i.e. reverse alphabetical by term.
mesh_count_by_study = dict(sorted(mesh_count_by_study.items(), reverse=True))
categories = list(mesh_count_by_study.keys())
counts = list(mesh_count_by_study.values())

# NOTE(review): these two lists are never filled in this cell.
enriched_categories = []
not_enriched_categories = []

# Per-term results collected from breakdown(), in the order of `terms`.
d1 = []
d2 = []
probes = []
papers_ct = []
# `counts` is rebound here as a tuple aligned with `terms`.
terms, counts = zip(*[(k, v) for k, v in mesh_count_by_study.items()])
paper_sets = []
for term in terms:
    # Only this single term is processed; every other term is skipped, so the
    # result lists end up shorter than `terms`.
    if term != "Genital Neoplasms, Male":
        continue
    # NOTE(review): output_path is built but not passed on — breakdown() receives output=None.
    output_path = f"{SFIG_PATH}/{current_category if current_category != 'neurological' else 'neuro'}/{term}.svg"
    p, paper, r, p2, curr_paper_set = breakdown(neuro_df, term_pmcid_map, [term], target_idx, output=None, show_figure=True, format="svg", export_all=True, show_y_label=True)
    probes.append(p)
    d1.append(paper)
    d2.append(r)
    papers_ct.append(p2)
    paper_sets.append(curr_paper_set)

# Summary table construction is disabled. It could not run as written while the
# single-term filter above is active, because the lists would differ in length.
# df = pd.DataFrame({"Categories":terms, "Enrichment Ratio":d2, "CoRSIV Probes": probes, "CoRSIV Papers":papers_ct, "Highest Number of Papers": d1, "Total Number of Papers": counts, "Paper Sets": paper_sets})
# df.sort_values("Enrichment Ratio", ascending=False, inplace=True)
# df.index = df["Categories"]
# df.drop(columns=["Categories"], inplace=True)
# df = df[(df["Enrichment Ratio"] > 1) & (df["Highest Number of Papers"]> 1) & (df.index!=current_category.capitalize())]


In [None]:
# Print the whole `df` table as plain text (no row/column truncation).
# NOTE(review): the cell directly above leaves its `df` construction commented
# out, so `df` here is whatever an earlier cell last bound to that name.
print(df.to_string())

In [232]:
import json

# PMCIDs listed in the cleaned study table; used below to restrict the
# full-text tables to those studies.
all_studies = pd.read_csv("pubmed_search/all_studies_cleaned.csv")["PMCID"].tolist()
def mqtl(text):
    """Count mQTL mentions in one full-text record.

    `text` is a JSON string whose first document holds a list of passages; the
    non-empty "text" fields of those passages are joined with spaces and the
    case-sensitive occurrences of "mQTL", "meQTL" and
    "methylation quantitative trait loci" are summed.

    Returns None when `text` is a float (how a missing CSV cell arrives).
    """
    if type(text) is float:
        return None
    passages = json.loads(text)["documents"][0]["passages"]
    allp = " ".join(passage.get("text") for passage in passages if passage.get("text"))
    keywords = ("mQTL", "meQTL", "methylation quantitative trait loci")
    return sum(allp.count(keyword) for keyword in keywords)
# Count mQTL mentions in the full text of every listed study, one category file
# at a time, and cache the per-paper counts in mqtl_counts.csv.
dfs = []
for cat in CATEGORY_NAMES:
    # The obesity full-text file is stored under the name "anthropometric".
    cat = "anthropometric" if cat == "obesity" else cat

    df = pd.read_csv(f"full_text/{cat}_full_text.csv")
    # Filter once, and copy so the new column is written to an independent
    # frame instead of a slice of the file contents (avoids SettingWithCopyWarning).
    df = df[df["PMCID"].isin(all_studies)].copy()
    df["mQTL"] = df["Full Text"].apply(mqtl)
    dfs.append(df[["PMCID", "mQTL"]])

# A paper can appear in several category files; keep its first occurrence.
dfs = pd.concat(dfs).drop_duplicates(subset="PMCID")
dfs.to_csv("mqtl_counts.csv", index=False)

In [None]:
# Show the cached mQTL counts of the neurological papers tagged with ADHD.
mqtl_info = pd.read_csv("mqtl_counts.csv")
neuro_df = pd.read_csv("probe/neurological_all_probes.csv")

# "Filtered Mesh Term" is stored as a "|"-separated string; turn it into a list.
neuro_df["Filtered Mesh Term"] = neuro_df["Filtered Mesh Term"].map(
    lambda raw: [piece.strip() for piece in raw.split("|")]
)

has_adhd = neuro_df["Filtered Mesh Term"].map(
    lambda mesh_list: "Attention Deficit Disorder with Hyperactivity" in mesh_list
)
adhd_studies = neuro_df.loc[has_adhd, "pmcid"].tolist()
mqtl_info.loc[mqtl_info["PMCID"].isin(adhd_studies)]

In [None]:
# For every category: compute one enrichment ratio per MeSH term (via
# breakdown()), pair it with the mean cached mQTL count of the papers carrying
# that term, and write the table to mqtl/mqtl_counts_<category>.csv.
# Relies on mesh_ttoc, starts_with_any and filter_mesh_list from an earlier cell.
mqtl_info = pd.read_csv("mqtl_counts.csv")

for j, current_category in enumerate(CATEGORY_NAMES):
    keywords = CATEGORIES[j]
    # Rebuilt for each category; filter_mesh_list() reads this global.
    neuro_mesh_tree = {}
    for kw in keywords:
        # .get() instead of mesh_ttoc[kw]: indexing the defaultdict with an
        # unknown keyword would insert a key while mesh_ttoc is being iterated.
        kw_codes = mesh_ttoc.get(kw, set())
        neuro_mesh_tree[kw] = {
            mesh_name
            for mesh_name, codes in mesh_ttoc.items()
            if any(starts_with_any(code, kw_codes) for code in codes)
        }

    # The metabolic table has a different file name; cancer and metabolic store
    # the term list as a Python literal, the others as a "|"-separated string.
    # NOTE(review): eval() executes the cell text; ast.literal_eval would be
    # safer if the cells are plain list literals.
    if current_category != "metabolic":
        neuro_df = pd.read_csv(f"probe/{current_category}_all_probes.csv")
        if current_category == "cancer":
            neuro_df["Filtered Mesh Term"] = neuro_df["Filtered Mesh Term"].apply(eval)
        else:
            neuro_df["Filtered Mesh Term"] = neuro_df["Filtered Mesh Term"].apply(lambda x: [term.strip() for term in x.split("|")])
    else:
        neuro_df = pd.read_csv("probe/metabolic_diseases_all_probes.csv")
        neuro_df["Filtered Mesh Term"] = neuro_df["Filtered Mesh Term"].apply(eval)

    # One row per paper: count papers per term and remember which papers carry it.
    temp1 = neuro_df.drop_duplicates(subset="pmcid")
    mesh_count_by_study = defaultdict(int)
    term_pmcid_map = defaultdict(set)
    for pmcid, study_terms in zip(temp1["pmcid"], temp1["Filtered Mesh Term"]):
        for t in study_terms:
            mesh_count_by_study[t] += 1
            term_pmcid_map[t].add(pmcid)

    # Keep terms found in neuro_mesh_tree, add one entry for the whole category,
    # and order by term name, descending.
    mesh_count_by_study = {key: count for key, count in mesh_count_by_study.items() if filter_mesh_list(key)}
    mesh_count_by_study[current_category.capitalize()] = len(set(neuro_df["pmcid"]))
    mesh_count_by_study = dict(sorted(mesh_count_by_study.items(), reverse=True))
    terms, counts = zip(*mesh_count_by_study.items())

    # Per-term results from breakdown(), aligned with `terms`.
    d1 = []
    d2 = []
    probes = []
    papers_ct = []
    paper_sets = []
    for term in terms:
        p, paper, r, p2, curr_paper_set = breakdown(neuro_df, term_pmcid_map, [term], j, output=None, show_figure=False, format="svg", export_all=False)
        probes.append(p)
        d1.append(paper)
        d2.append(r)
        papers_ct.append(p2)
        paper_sets.append(curr_paper_set)

    df = pd.DataFrame({"Categories":terms, "Enrichment Ratio":d2, "CoRSIV Probes": probes, "CoRSIV Papers":papers_ct, "Highest Number of Papers": d1, "Total Number of Papers": counts, "Paper Sets": paper_sets})
    df = df.sort_values("Enrichment Ratio", ascending=False)

    # Drop terms backed by too few papers (cancer uses a higher cutoff), then
    # keep only the columns that are exported. copy() gives an independent frame
    # so the new column below is not written onto a slice.
    cutoff = 15 if current_category != "cancer" else 20
    df = df.loc[df["Total Number of Papers"] >= cutoff, ["Categories", "Enrichment Ratio"]].copy()

    # Mean cached mQTL count over the papers carrying each term.
    # NOTE(review): the whole-category row (e.g. "Cancer") is normally not a key
    # of term_pmcid_map, so its mean is NaN — confirm that downstream
    # regressions are meant to see (or drop) that row.
    df["mQTL"] = df["Categories"].apply(
        lambda x: mqtl_info.loc[mqtl_info["PMCID"].isin(term_pmcid_map.get(x, set())), "mQTL"].mean()
    )
    df.to_csv(f"mqtl/mqtl_counts_{current_category}.csv", index=False)


In [None]:
# Parse the MeSH tree file: each well-formed line is "<term>;<tree code>".
# mesh_ttoc maps a term to the set of all tree codes listed for it.
mesh_ttoc = defaultdict(set) #term:code
file_path = '../humanData/database/mtrees2024.txt'
with open(file_path, 'r') as tree_file:
    for raw_line in tree_file:
        fields = raw_line.strip().split(';')
        if len(fields) != 2:
            # Lines that do not split into exactly two fields are only reported.
            print(fields)
            continue
        mesh_name, tree_code = fields
        mesh_ttoc[mesh_name].add(tree_code)

# Reverse lookup: tree code -> term.
mesh_ctot = {}
for mesh_name, tree_codes in mesh_ttoc.items():
    for tree_code in tree_codes:
        mesh_ctot[tree_code] = mesh_name

def starts_with_any(given_string, string_list):
    """Return True if `given_string` starts with at least one prefix in `string_list`.

    `string_list` may be any iterable of strings; an empty iterable yields False.
    """
    return any(given_string.startswith(prefix) for prefix in string_list)

def filter_mesh_list(input, mesh_tree=None):
    """Return True if `input` is a member of any value collection of the MeSH tree.

    Parameters
    ----------
    input : the term to look up (parameter name kept for backward compatibility
        even though it shadows the builtin).
    mesh_tree : optional mapping whose values are collections of terms. When
        omitted, the module-level `neuro_mesh_tree` is used, exactly as before.
    """
    tree = neuro_mesh_tree if mesh_tree is None else mesh_tree
    return any(input in members for members in tree.values())

# For each listed category, export one CSV of paper details: bibliographic
# columns, the CoRSIV probes each paper reports, and one boolean column per
# disease term of interest saying whether the paper belongs to that term's set.
for current_category in ['cancer', 'endocrine', 'immune', 'metabolic', "neurological", 'urogenital', 'obesity']:
    j = CATEGORY_NAMES.index(current_category)
    keywords = CATEGORIES[j]
    # Rebuilt for each category; filter_mesh_list() reads this global.
    neuro_mesh_tree = {}
    for kw in keywords:
        # .get() instead of mesh_ttoc[kw]: indexing the defaultdict with an
        # unknown keyword would insert a key while mesh_ttoc is being iterated.
        kw_codes = mesh_ttoc.get(kw, set())
        neuro_mesh_tree[kw] = {
            mesh_name
            for mesh_name, codes in mesh_ttoc.items()
            if any(starts_with_any(code, kw_codes) for code in codes)
        }

    # The metabolic files are stored under "metabolic_diseases"; resolve that
    # once and use it for every file of this category.
    read_in_category = current_category if current_category != "metabolic" else "metabolic_diseases"
    neuro_df = pd.read_csv(f"probe/{read_in_category}_all_probes.csv")
    # Cancer and metabolic store the term list as a Python literal, the others
    # as a "|"-separated string.
    # NOTE(review): eval() executes the cell text; ast.literal_eval would be
    # safer if the cells are plain list literals.
    if current_category in ("cancer", "metabolic"):
        neuro_df["Filtered Mesh Term"] = neuro_df["Filtered Mesh Term"].apply(eval)
    else:
        neuro_df["Filtered Mesh Term"] = neuro_df["Filtered Mesh Term"].apply(lambda x: [term.strip() for term in x.split("|")])

    # One row per paper: count papers per term and remember which papers carry it.
    temp1 = neuro_df.drop_duplicates(subset="pmcid")
    mesh_count_by_study = defaultdict(int)
    term_pmcid_map = defaultdict(set)
    for pmcid, study_terms in zip(temp1["pmcid"], temp1["Filtered Mesh Term"]):
        for t in study_terms:
            mesh_count_by_study[t] += 1
            term_pmcid_map[t].add(pmcid)

    # Keep terms found in neuro_mesh_tree, add one entry for the whole category,
    # and order by term name, descending.
    mesh_count_by_study = {key: count for key, count in mesh_count_by_study.items() if filter_mesh_list(key)}
    mesh_count_by_study[current_category.capitalize()] = len(set(neuro_df["pmcid"]))
    mesh_count_by_study = dict(sorted(mesh_count_by_study.items(), reverse=True))
    terms, counts = zip(*mesh_count_by_study.items())

    # Only the last value returned by breakdown() is used here.
    # presumably it is the set of papers behind the term's result — verify in util.breakdown
    paper_sets = []
    for term in terms:
        p, paper, r, p2, curr_paper_set = breakdown(neuro_df, term_pmcid_map, [term], j, output=None, show_figure=False, format="svg", export_all=False)
        paper_sets.append(curr_paper_set)

    df = pd.DataFrame({"Categories":terms, "Paper Sets": paper_sets})

    # Restrict to the diseases listed on this category's sheet of the
    # permutation-testing workbook; iloc drops the duplicated "Disease" key column.
    categories_we_want = pd.read_excel("../permutation_testing/pt_results.xlsx", sheet_name=current_category)
    m = pd.merge(categories_we_want["Disease"], df, left_on="Disease", right_on="Categories").iloc[:, 1:]

    # Boolean membership matrix: rows are papers, columns are disease terms.
    paper_df = pd.DataFrame(index=sorted(set().union(*m["Paper Sets"])))
    for cat, papers in zip(m["Categories"], m["Paper Sets"]):
        paper_df[cat] = paper_df.index.isin(papers)

    paper_details = (
        pd.read_csv(f"pubmed_search/{read_in_category}_final.csv")[["PMID", "PMCID", "Journal", "Last Name", "Year", "Title", "Abstract"]]
        .rename(columns={"Last Name": "First Author Last Name"})
    )

    # CoRSIV probes that occur in more than one row of this category's probe
    # table, joined into one comma-separated string per paper. The probe table
    # is the file already loaded above, so it is not read a second time.
    probes = neuro_df[["probeId", "pmcid"]]
    c = Counter(probes["probeId"])
    c = {k: v for k, v in c.items() if v > 1}
    probes = probes[(probes["probeId"].isin(CORSIV_PROBE_LIST)) & (probes["probeId"].isin(c.keys()))]
    probes = probes.groupby("pmcid").agg({"probeId": lambda x: ",".join(x)})
    probes.columns = ["CoRSIV Probes Reported"]

    # Inner joins: only papers with at least one such probe, and present in the
    # membership matrix, reach the exported file.
    paper_details = pd.merge(paper_details, probes, left_on="PMCID", right_on="pmcid")
    final_df = pd.merge(paper_details, paper_df, left_on="PMCID", right_index=True)
    final_df.to_csv(f"../manuscript/stables/paper_details/{current_category}_paper_details_2.csv", index=False)
    


In [None]:
# Enrichment plot restricted to cancer papers tagged "Prostatic Neoplasms".

# For paper PMC10275808 keep only the probes listed in the main-table file;
# every other paper keeps all of its probes.
df = pd.read_csv("probe_main_table/cancer_main_probes.csv")
probes_to_keep = df.loc[df["pmcid"] == "PMC10275808", "probeId"].tolist()
cancer_df = pd.read_csv("probe/cancer_all_probes.csv")
is_other_paper = cancer_df["pmcid"] != "PMC10275808"
is_kept_probe = cancer_df["probeId"].isin(probes_to_keep)
cancer_df = cancer_df[is_other_paper | is_kept_probe]

# NOTE(review): eval() executes the cell text read from the CSV.
cancer_df["Filtered Mesh Term"] = cancer_df["Filtered Mesh Term"].apply(eval)

term = "Prostatic Neoplasms"
mentions_term = cancer_df["Filtered Mesh Term"].apply(lambda mesh_list: term in mesh_list)
cancer_df = cancer_df[mentions_term]

# Rows per probe among the remaining papers, then plot.
c = Counter(cancer_df["probeId"])
l1, l2, l3, p, p2, _ = calculate_points(c, term)
paper, r = plot_enrichment(
    [l1, l2, l3, p, p2],
    term,
    0,
    output=None,
    show_figure=True,
    box_placement=(0.15, 0.9),
    show_y_label=True,
    format="svg",
    show_legend=False,
)



In [None]:
# Display (rows, columns) of the cancer probe-table rows belonging to PMC10275808.
cancer_df = pd.read_csv("probe/cancer_all_probes.csv")
cancer_df.loc[cancer_df["pmcid"] == "PMC10275808"].shape
# cancer_df.to_csv("probe/cancer_all_probes_corsiv.csv", index=False)


In [None]:
# Display the permutation-testing results sheet of one category.
current_category = "cancer"
categories_we_want = pd.read_excel(
    io="../permutation_testing/pt_results.xlsx",
    sheet_name=current_category,
)
categories_we_want

In [88]:
import statsmodels.api as sm
from scipy import stats

# One panel per category: per-term enrichment ratio (y) against the average
# number of mQTL mentions (x), with an ordinary-least-squares fit line and an
# R² / p-value annotation (red when the fit's F-test p-value is below 0.05).
fig, axes = plt.subplots(3, 2, figsize=(6,9))
axes = axes.flatten()
panel_categories = ["cancer", "endocrine", "immune", "metabolic", "neurological", "urogenital"]
for ax, current_category in zip(axes, panel_categories):
    j = CATEGORY_NAMES.index(current_category)
    df = pd.read_csv(f"mqtl/mqtl_counts_{current_category}.csv")
    X = np.array(df["mQTL"])
    y = np.array(df["Enrichment Ratio"])
    # add_constant adds the intercept column expected by OLS.
    X_const = sm.add_constant(X)
    model = sm.OLS(y, X_const).fit()
    ax.scatter(X, y, color=COLOR_TEMPLATE[j], s=50)
    ax.plot(X, model.predict(X_const), color=COLOR_TEMPLATE[j], linestyle='solid', linewidth=4, alpha=0.6)
    r_squared = model.rsquared
    f_pvalue = model.f_pvalue
    # These three categories print the p-value in scientific notation; the
    # others use two significant digits.
    if current_category in ["cancer", "endocrine", "urogenital"]:
        annotation_text = f"R² = {r_squared:.2f}\nP = {f_pvalue:.1e}"
    else:
        annotation_text = f"R² = {r_squared:.2f}\nP = {f_pvalue:.2g}"
    ax.annotate(annotation_text,
                xy=(0.05, 0.85), # Position in axes coordinates
                xycoords='axes fraction',
                color="red" if f_pvalue < 0.05 else "black",
                fontsize=16,
                ha='left',
                va='center',
                bbox=dict(boxstyle='round,pad=0.5', fc='none', ec='none', alpha=1))
    ax.set_title(f"{current_category.capitalize()}", fontsize=20)

# Shared y-axis rule and label, placed to the left of the panel grid.
x = 0.02
fig.add_artist(plt.Line2D([x, x], [0.06,0.94], transform=fig.transFigure, color='black', linestyle='-', linewidth=2))
fig.text(x-0.07, 0.5, 'Enrichment Ratio', va='center', rotation='vertical', fontsize=24)

# Shared x-axis rule and label, placed below the panel grid.
y = 0.02
fig.add_artist(plt.Line2D([0.1, 0.95], [y, y], transform=fig.transFigure, color='black', linestyle='-', linewidth=2))
fig.text(0.5, y-0.07, "Average mQTL Mentions", ha='center', rotation='horizontal', fontsize=24)
plt.tight_layout()
# plt.show()
# Both shared labels sit at negative figure coordinates; bbox_inches='tight'
# expands the saved canvas to include them instead of clipping them.
plt.savefig(f"{SFIG_PATH}/mqtl_enrichment_ratio_enriched.svg", format="svg", bbox_inches='tight')
plt.close()

In [87]:
import statsmodels.api as sm
from scipy import stats

# One panel per category: per-term enrichment ratio (y) against the average
# number of mQTL mentions (x), with an ordinary-least-squares fit line and an
# R² / p-value annotation (red when the fit's F-test p-value is below 0.05).
fig, axes = plt.subplots(2, 2, figsize=(6,6))
axes = axes.flatten()
panel_categories = ["cardiovascular", "digestive", "hematological", "respiratory"]
for ax, current_category in zip(axes, panel_categories):
    j = CATEGORY_NAMES.index(current_category)
    df = pd.read_csv(f"mqtl/mqtl_counts_{current_category}.csv")
    X = np.array(df["mQTL"])
    y = np.array(df["Enrichment Ratio"])
    # add_constant adds the intercept column expected by OLS.
    X_const = sm.add_constant(X)
    model = sm.OLS(y, X_const).fit()
    ax.scatter(X, y, color=COLOR_TEMPLATE[j], s=50)
    ax.plot(X, model.predict(X_const), color=COLOR_TEMPLATE[j], linestyle='solid', linewidth=4, alpha=0.6)
    r_squared = model.rsquared
    f_pvalue = model.f_pvalue

    annotation_text = f"R² = {r_squared:.2f}\nP = {f_pvalue:.2f}"
    ax.annotate(annotation_text,
                xy=(0.55, 0.85), # Position in axes coordinates
                xycoords='axes fraction',
                color="red" if f_pvalue < 0.05 else "black",
                fontsize=16,
                ha='left',
                va='center',
                bbox=dict(boxstyle='round,pad=0.5', fc='none', ec='none', alpha=1))
    ax.set_title(f"{current_category.capitalize()}", fontsize=20)

# Shared y-axis rule and label, placed to the left of the panel grid.
x = 0.02
fig.add_artist(plt.Line2D([x, x], [0.09,0.91], transform=fig.transFigure, color='black', linestyle='-', linewidth=2))
fig.text(x-0.07, 0.5, 'Enrichment Ratio', va='center', rotation='vertical', fontsize=24)

# Shared x-axis rule and label, placed below the panel grid.
y = 0.02
fig.add_artist(plt.Line2D([0.1, 0.95], [y, y], transform=fig.transFigure, color='black', linestyle='-', linewidth=2))
fig.text(0.5, y-0.07, "Average mQTL Mentions", ha='center', rotation='horizontal', fontsize=24)
plt.tight_layout()
# plt.show()
# Both shared labels sit at negative figure coordinates; bbox_inches='tight'
# expands the saved canvas to include them instead of clipping them.
plt.savefig(f"{SFIG_PATH}/mqtl_enrichment_ratio_not_enriched.svg", format="svg", bbox_inches='tight')
plt.close()


In [None]:
# Disabled preparation steps (cancer probes -> "Prostatic Neoplasms" papers -> CoRSIV probes only):
# df = pd.read_csv("probe/cancer_all_probes.csv")
# df["Filtered Mesh Term"] = df["Filtered Mesh Term"].apply(eval)
# df = df[df["Filtered Mesh Term"].apply(lambda x: "Prostatic Neoplasms" in x)]
# df = df[df["probeId"].isin(CORSIV_PROBE_LIST)]
# Display the distinct paper ids in `df`.
# NOTE(review): with the lines above disabled, `df` is whatever an earlier cell
# last bound to that name, and must have a "pmcid" column for this to run.
df["pmcid"].unique()

In [None]:
# Figure 5C: becon regression plots for all categories except neurological
import math
import scipy.stats as stats
import matplotlib.pyplot as plt
import numpy as np
import statsmodels.api as sm

# For each category and each probe set (CoRSIV vs. all other array probes):
# take the probes reported by exactly k papers, compute the median of the
# "Mean Cor All Brain" column over them, and regress k on that median.
non_corsiv_baseline = illumina - CORSIV_PROBE_LIST
target_col = "Mean Cor All Brain"
df = pd.read_csv("becon/becon_all_probes.csv")
regions = list(zip(["CoRSIV", "Non-CoRSIV"], [CORSIV_PROBE_LIST, non_corsiv_baseline]))

# Create subplot grid (five category panels; the sixth holds the legend).
fig, axs = plt.subplots(3, 2, figsize=(12, 11))
axs = axs.flatten()

# Plot each category except neurological
plot_idx = 0
for cat in ["cancer", "endocrine", "immune", "metabolic", "urogenital"]:
    cat_probes = read_in_probes(cat)
    ax = axs[plot_idx]
    max_papers = max(cat_probes.values())

    for rname, rset in regions:
        medians = []
        paper_counts = []
        for pidx in range(1, max_papers + 1):
            p = set(k for k, v in cat_probes.items() if v == pidx)
            probes_in_region = rset.intersection(p)
            filtered_df = df[df["CpG ID"].isin(probes_in_region)]
            if len(filtered_df) < 15:
                # Too few probes for a stable median: stop here. The lowered
                # max_papers also caps the next probe set and the y-axis below.
                max_papers = pidx - 1
                break
            else:
                medians.append(filtered_df[target_col].median())
                paper_counts.append(pidx)

        # A line cannot be fitted through fewer than two points; skip this
        # probe set instead of failing inside the regression or annotation.
        if len(medians) < 2:
            continue

        X = np.array(medians)
        y = np.array(paper_counts)
        # add_constant adds the intercept column expected by OLS.
        X_const = sm.add_constant(X)
        model = sm.OLS(y, X_const).fit()

        color = COLOR_TEMPLATE[CATEGORY_NAMES.index(cat)] if rname == 'CoRSIV' else 'grey'
        marker = 'D' if rname == 'CoRSIV' else 'o'
        ax.scatter(X, y, color=color, label=rname, s=200, marker=marker)
        ax.plot(X, model.predict(X_const), color=color, linestyle='--', linewidth=4, zorder=10, alpha=0.6)
        r_squared = round(model.rsquared, 3)
        f_pvalue = round(model.f_pvalue, 3)

        annotation_text = f"R² = {r_squared:.2f}\nP = {f_pvalue:.2f}"

        # Place the annotation next to the last point, at half its height.
        ax.annotate(annotation_text,
                    xy=(X[-1]-0.04, y[-1]/2),
                    xytext=(10, 0),
                    textcoords='offset points',
                    color=color,
                    fontsize=20,
                    ha='left',
                    va='center',
                    bbox=dict(boxstyle='round,pad=0.5', fc='none', ec='none', alpha=1))

    # ax.set_xlabel("Median Brain-Blood Correlation (BECon)", fontsize=22)
    # ax.set_ylabel('Number of Papers', fontsize=18)
    ax.set_title(cat.capitalize(), fontsize=28, pad=15)
    ax.set_xlim(-0.05, 0.6)
    ax.set_xticks(np.arange(0.0, 0.7, 0.1))
    ax.tick_params(axis='x', which='major', length=5)
    ax.set_yticks(range(1, max_papers + 1))
    ax.set_ylim(0, max_papers + 1)
    ax.tick_params(axis='both', which='major', labelsize=25)

    plot_idx += 1

# The last (unused) subplot only hosts the legend.
axs[-1].axis('off')
handles = [plt.Line2D([0], [0], marker='D', color='w', markerfacecolor="k",
                      label="CoRSIV", markersize=15)]
handles.append(plt.Line2D([0], [0], marker='o', color='w', label='Control', markersize=20, markerfacecolor='grey'))

axs[-1].legend(handles=handles, loc='center', fontsize=24, frameon=False)

# Shared y-axis rule and label, placed to the left of the panel grid.
x = 0
fig.add_artist(plt.Line2D([x, x], [0.06,0.94], transform=fig.transFigure, color='black', linestyle='-', linewidth=2))
fig.text(x-0.04, 0.5, 'Number of Papers', va='center', rotation='vertical', fontsize=24)

# Shared x-axis rule and label, placed below the panel grid.
y = 0
fig.add_artist(plt.Line2D([0.04, 0.96], [y, y], transform=fig.transFigure, color='black', linestyle='-', linewidth=2))
fig.text(0.5, y-0.04, "Median Brain-Blood Correlation (BECon)", ha='center', rotation='horizontal', fontsize=24)
plt.tight_layout()

output = f"{SFIG_PATH}/becon_regression_all_categories.svg"
# Both shared labels sit at negative figure coordinates; bbox_inches='tight'
# expands the saved canvas to include them instead of clipping them.
plt.savefig(output, format="svg", bbox_inches='tight')

In [None]:
# Category-level comparison: hard-coded enrichment ratio of each category
# against the average mQTL mention count of its papers.
# The ratios are positional: entry i belongs to CATEGORY_NAMES[i].
enrichment_ratios = [10.90, 4.69, 0.67, 53.48, 1.61, 10.08, 23.70, 25.66, 5.72, 2.24, 6.93]
assert len(enrichment_ratios) == len(CATEGORY_NAMES), "one enrichment ratio is required per category"

mqtl_counts = []
number_of_mqtl_studies = []
mqtl_info = pd.read_csv("mqtl_counts.csv")
studies = pd.read_csv("../manuscript/stables/all_2203_studies.csv")
for cat in CATEGORY_NAMES:
    # The capitalised category column is used as a boolean row mask.
    paperid = studies[studies[cat.capitalize()]]["PMCID"].tolist()
    current_mqtl = mqtl_info[mqtl_info["PMCID"].isin(paperid)]
    mqtl_counts.append(round(current_mqtl["mQTL"].mean(), 2))
    # Whole-number percentage of the category's papers with at least one mention.
    # Scale to percent before rounding: round(x, 2) * 100 leaves float
    # artefacts such as 28.999999999999996. An empty category yields NaN
    # instead of a division by zero.
    n_mentioning = current_mqtl[current_mqtl["mQTL"] > 0].shape[0]
    number_of_mqtl_studies.append(round(100 * n_mentioning / len(paperid)) if paperid else float("nan"))

plt.figure(figsize=(6,6))
plt.scatter(enrichment_ratios, mqtl_counts, c=COLOR_TEMPLATE, s=100)
plt.xlabel("Enrichment Ratio", fontsize=18)
plt.ylabel("Average mQTL Count", fontsize=18)
plt.tight_layout()
plt.show()
# axes[j].scatter(df["Enrichment Ratio"], df["mQTL"], c=COLOR_TEMPLATE[j])
# axes[j].set_xlabel("Enrichment Ratio", fontsize=18)
# axes[j].set_ylabel("Average mQTL Count", fontsize=18) 
# axes[j].set_title(f"{current_category.capitalize()}", fontsize=20)

In [None]:
# Category-level comparison: enrichment ratio against the percentage of papers
# that mention mQTL at least once (both lists come from the previous cell).
plt.figure(figsize=(6, 6))
plt.scatter(x=enrichment_ratios, y=number_of_mqtl_studies, s=100, c=COLOR_TEMPLATE)
plt.ylabel("Percentage of Studies mentioning mQTL (%)", fontsize=18)
plt.xlabel("Enrichment Ratio", fontsize=18)
plt.tight_layout()
plt.show()

In [None]:
# Which ADHD-tagged neurological papers report at least one probe listed in
# `corsiv_control_probes`?
# The working directory is already the "categories" project folder (see the
# os.chdir in the first cell), so the path is relative to it like in every
# other cell; the former "categories/probe/..." prefix pointed one level too deep.
df = pd.read_csv("probe/neurological_all_probes.csv")
df["Filtered Mesh Term"] = df["Filtered Mesh Term"].apply(lambda x: [y.strip() for y in x.split("|")])
adhd_studies = set(df[df["Filtered Mesh Term"].apply(lambda x: "Attention Deficit Disorder with Hyperactivity" in x)]["pmcid"].tolist())
# NOTE(review): `corsiv_control_probes` is not defined in the visible part of
# the notebook — it must be left over from another cell; confirm where it is built.
set(df[df["probeId"].isin(corsiv_control_probes["probeId"])]["pmcid"].tolist()).intersection(adhd_studies)
# adhd_studies.difference(set(df[df["probeId"].isin(corsiv_control_probes["probeId"])]["pmcid"].tolist()))

In [34]:
# Write one probe CSV per exported term for permutation testing.
# Expects `df` to be the per-term summary table indexed by term name, and
# `neuro_df` / `current_category` to be the matching probe table and category.
# NOTE(review): all three come from earlier cells — confirm they are in sync before running.
is_enriched = df["Enrichment Ratio"] > 1
has_enough_papers = df["Total Number of Papers"] >= 20
is_replicated = df["Highest Number of Papers"] > 1
export_categories = df[is_enriched & has_enough_papers & is_replicated].index.tolist()

for catname in export_categories:
    has_term = neuro_df["Filtered Mesh Term"].apply(lambda x: catname in x)
    to_export_df = neuro_df[has_term]
    to_export_df.to_csv(f"../permutation_testing/pt/{current_category}/{catname}_probes.csv", index=False)
    



In [None]:
# Scatter of the highest per-probe paper count against the total paper count of
# each term; terms with an enrichment ratio above 1 use the category colour,
# all others are grey.
point_colors = df["Enrichment Ratio"].apply(lambda x: COLOR_TEMPLATE[target_idx] if x > 1 else "grey")

plt.figure(figsize=(6, 6))
plt.scatter(
    df["Total Number of Papers"],
    df["Highest Number of Papers"],
    c=point_colors,
    alpha=0.6,
)

plt.ylabel("Highest Number of Papers Reporting a Probe", fontsize=12)
plt.xlabel("Total Number of Papers Reporting a Probe", fontsize=12)

# Largest value on either axis.
# NOTE(review): computed for a 1:1 reference line that is never drawn; the
# axis limits below are fixed instead.
max_val = max(df["Total Number of Papers"].max(), df["Highest Number of Papers"].max())
plt.xlim(0, 20)
plt.ylim(0, 10)
plt.tight_layout()
# plt.savefig(f"{SFIG_PATH}/{current_category}/highest_vs_total_papers.jpeg", dpi=300)
plt.show()


In [None]:
# Figure 4A: neurological mesh hierarchy
from graphviz import Digraph
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches


def visualize(input_nodes, output_name=None, format="svg", target="Enrichment Ratio"):
    """Draw the MeSH hierarchy spanned by ``input_nodes`` with Graphviz and save a colour legend.

    Reads these notebook globals: ``mesh_ttoc`` (term -> set of tree codes),
    ``mesh_ctot`` (tree code -> term), ``keywords``, ``df`` (indexed by term, with the
    columns 'Highest Number of Papers', 'Enrichment Ratio', 'Total Number of Papers'
    and ``target``), ``COLOR_TEMPLATE``, ``target_idx``, ``current_category`` and
    ``SFIG_PATH``.

    Parameters
    ----------
    input_nodes : iterable of keys of ``mesh_ttoc``
        Terms to include; every tree code of each term is expanded to its ancestors.
    output_name : str or None
        When given, the graph is rendered to this path via ``dot.render``;
        otherwise ``dot.view()`` is called.
    format : str
        Output format handed to ``dot.render``.
    target : str
        Column of ``df`` that drives node colour intensity and the legend range.

    Returns
    -------
    dict
        Mapping of tree code -> term for every node placed in the graph.

    Side effects
    ------------
    Always writes the colour legend to
    ``{SFIG_PATH}/hierarchy_plots/{current_category}_legend.svg``.
    """
    # Sets of strings iterate in a different order on every kernel start (hash randomisation),
    # and dot's layout depends on declaration order, so sort for a reproducible figure.
    input_codes = sorted({code for term in input_nodes for code in mesh_ttoc[term]})
    dot = Digraph()
    dot.attr(rankdir='LR')  # Keep layout horizontal (Left-to-Right)

    # Tree codes of the keyword terms; only codes starting with one of them are drawn.
    # Computed once instead of once per candidate code.
    root_codes = tuple(c for k in keywords for c in mesh_ttoc[k])

    edges = set()
    nodes = {}

    for code in input_codes:
        parts = code.split('.')
        # Walk from the top-level code down to `code`, registering each known ancestor once.
        for depth in range(1, len(parts) + 1):
            partial_code = '.'.join(parts[:depth])
            if partial_code not in nodes and partial_code in mesh_ctot and partial_code.startswith(root_codes):
                nodes[partial_code] = mesh_ctot[partial_code]

    def add_code_edges(code):
        # Link each consecutive (parent, child) prefix pair of `code` when both are drawn.
        parts = code.split('.')
        for i in range(1, len(parts)):
            parent_code = '.'.join(parts[:i])
            child_code = '.'.join(parts[:i+1])
            if parent_code in nodes and child_code in nodes:
                edges.add((parent_code, child_code))

    def get_node_color(node):
        # White unless the term is in `df`, has enrichment >= 1 and at least 2 "highest" papers;
        # otherwise the category colour blended towards white in proportion to `target`.
        term = nodes[node]
        if term not in df.index:
            return "#FFFFFF"  # White for nodes not in df

        highest_papers = df.loc[term, 'Highest Number of Papers']
        enrichment_ratio = df.loc[term, 'Enrichment Ratio']

        if enrichment_ratio < 1 or highest_papers < 2:
            return "#FFFFFF"  # White for nodes with low papers or enrichment

        color_rgb = mcolors.to_rgb(COLOR_TEMPLATE[target_idx])
        intensity = min(1, df.loc[term, target] / df[target].max())
        scaled_color = tuple(1 - (1 - c) * intensity for c in color_rgb)  # intensity 0 -> white, 1 -> full colour
        return mcolors.to_hex(scaled_color)

    for code in nodes:
        add_code_edges(code)  # Add edges between parent and child nodes
    if current_category == "neurological":
        # Synthetic root joining the two tree branches used for this category.
        nodes["N"] = "Neurological" 
        edges.add(("N", "F03"))
        edges.add(("N", "C10"))
    # Add nodes for each unique code with term name as label
    for code, term in nodes.items():
        node_color = get_node_color(code)
        if term in df.index:
            custom_label = f"{term} ({df.loc[term, 'Total Number of Papers']})"
        else:
            custom_label = term
        dot.node(code, label=custom_label, shape='box', style='filled', fillcolor=node_color, fontname='Helvetica')
    # Sorted so that edge declaration order is stable between runs.
    for parent_code, child_code in sorted(edges):
        dot.edge(parent_code, child_code)

    dot.attr(concentrate='true', splines='ortho')
    dot.attr(nodesep='0.2', ranksep='0.5')

    # Create color legend
    fig, ax = plt.subplots(figsize=(6, 1))
    cmap = mcolors.LinearSegmentedColormap.from_list("custom", ["white", COLOR_TEMPLATE[target_idx]])
    norm = mcolors.Normalize(vmin=0, vmax=df[target].max())
    plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap),
                 cax=ax, orientation='horizontal', label=target)

    # Save legend as separate image
    legend_path =f"{SFIG_PATH}/hierarchy_plots/{current_category}_legend.svg"
    fig.savefig(legend_path, bbox_inches='tight')

    if output_name:
        dot.render(output_name, format=format, cleanup=True)
    else:
        dot.view()
    return nodes

# Draw the hierarchy for MeSH terms whose count in `mesh_count_by_study` reaches `minimum_papers`.
minimum_papers = 20
if minimum_papers > 0:
    output_name = f"{SFIG_PATH}/hierarchy_plots/{current_category}_hierarchy_{minimum_papers}_papers"
else:
    output_name = f"{SFIG_PATH}/hierarchy_plots/{current_category}_full_hierarchy"
tmp = dict(filter(lambda item: item[1] >= minimum_papers, mesh_count_by_study.items()))

# NOTE: `output_name` is prepared above but None is passed here, so the graph is opened
# in a viewer rather than rendered to that path.
nodes = visualize(tmp.keys(), output_name=None, format="svg")

In [None]:
import math

# Manually curated ADHD probes (archived copy): keep "cg"/"ch." probe ids present on EPIC or HM450.
manual = pd.read_csv("pubmed_search/archive/adhd_test/probes/all_ADHD_probes.csv")
has_cg_prefix = manual["probeId"].str.startswith("cg")
has_ch_prefix = manual["probeId"].str.startswith("ch.")
manual = manual[has_cg_prefix | has_ch_prefix]
manual = manual[manual["probeId"].isin(epic_probe_list | hm450_probe_list)]
def to_float(x):
    """Convert a reported significance value to a float where possible.

    Numeric input (or numeric text) is returned as ``float(x)``. A fixed set of
    qualitative strings is mapped to a representative number just inside the
    stated bound (for example "<0.05" -> 0.049); "not pass" and a single space
    become NaN. Anything else is returned unchanged so that unexpected values
    stay visible to the caller instead of being silently coerced.
    """
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are handled here; a bare `except:` would also
        # swallow KeyboardInterrupt / SystemExit raised while the column is processed.
        pass
    if not isinstance(x, str):
        return x
    nan = float("nan")
    qualitative = {
        "<0.05": 0.049,
        "pass": 0.049,
        "not pass": nan,
        " ": nan,
        "> 0.1": 0.11,
        "≤ 0.1": 0.09,
        "≤ 0.05": 0.049,
        "≤0.01": 0.009,
    }
    return qualitative.get(x, x)
# Normalise the three significance columns: map the unicode minus/hyphen/dash variants to
# ASCII "-", rewrite " × E" exponents as "e", then parse with to_float.
for column in ("p-value", "q-value", "adj-p-value"):
    normalised = manual[column]
    for dash in ("−", "‐", "–"):
        normalised = normalised.str.replace(dash, "-")
    manual[column] = normalised.str.replace(" × E", "e").apply(to_float)
# Drop any (paper, source) group that contributes more than 1000 rows.
manual = manual.groupby(['pmcid', 'From']).filter(lambda group: len(group) <= 1000)


keep = pd.read_csv("pubmed_search/archive/adhd_test/probes/all_ADHD_source.csv")
manual = manual.merge(keep, on=["pmcid", "Notes", "From", "Title"])
manual = manual[manual["Keep1"] == 1]
# Nominal significance only.
manual = manual[manual["p-value"] < 0.05]

# manual = manual[(manual["q-value"] < 0.05) | (manual["adj-p-value"] < 0.05) | (manual["p-value"] < 1e-5)]
manual = manual.drop_duplicates(subset=["pmcid", "probeId"])
print(manual.shape)
c = Counter(manual["probeId"])
l1, l2, l3, p, p2 = calculate_points(c, manual)
output_path = f"{SFIG_PATH}/adhd_manual_enrichment.svg"
# NOTE: `output_path` is prepared but output=None is passed, so nothing is written to disk here.
paper, r = plot_enrichment(
    [l1, l2, l3, p, p2],
    "ADHD (Manual)",
    7,
    output=None,
    format="svg",
)

In [None]:
import math

# Manually curated ADHD probes (up-to-date copy): keep "cg"/"ch." probe ids present on EPIC or HM450.
manual = pd.read_csv("pubmed_search/adhd_test/probes/all_ADHD_probes_uptodate.csv")
has_cg_prefix = manual["probeId"].str.startswith("cg")
has_ch_prefix = manual["probeId"].str.startswith("ch.")
manual = manual[has_cg_prefix | has_ch_prefix]
manual = manual[manual["probeId"].isin(epic_probe_list | hm450_probe_list)]
def to_float(x):
    """Convert a reported significance value to a float where possible.

    Numeric input (or numeric text) is returned as ``float(x)``. A fixed set of
    qualitative strings is mapped to a representative number just inside the
    stated bound (for example "<0.05" -> 0.049); "not pass" and a single space
    become NaN. Anything else is returned unchanged so that unexpected values
    stay visible to the caller instead of being silently coerced.
    """
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are handled here; a bare `except:` would also
        # swallow KeyboardInterrupt / SystemExit raised while the column is processed.
        pass
    if not isinstance(x, str):
        return x
    nan = float("nan")
    qualitative = {
        "<0.05": 0.049,
        "pass": 0.049,
        "not pass": nan,
        " ": nan,
        "> 0.1": 0.11,
        "≤ 0.1": 0.09,
        "≤ 0.05": 0.049,
        "≤0.01": 0.009,
    }
    return qualitative.get(x, x)
# Parse the three significance columns into floats (qualitative entries are handled by to_float).
for column in ("p-value", "q-value", "adj-p-value"):
    manual[column] = manual[column].apply(to_float)
# Drop any (paper, source) group that contributes more than 1000 rows.
manual = manual.groupby(['pmcid', 'From']).filter(lambda group: len(group) <= 1000)


keep = pd.read_csv("pubmed_search/adhd_test/probes/all_ADHD_source_uptodate.csv")
manual = manual.merge(keep, on=["pmcid", "Notes", "From", "Title"])
manual = manual[manual["Keep1"] == 1]
# Nominal significance only.
manual = manual[manual["p-value"] < 0.05]
# manual = manual[(manual["q-value"] < 0.05) | (manual["adj-p-value"] < 0.05) | (manual["p-value"] < 1e-5)]
manual = manual.drop_duplicates(subset=["pmcid", "probeId"])
# manual[["pmcid","probeId"]].to_csv("../permutation_testing/adhd_manual_probes.csv", index=False)
c = Counter(manual["probeId"])
l1, l2, l3, p, p2 = calculate_points(c, manual)
output_path = f"{SFIG_PATH}/adhd_manual_enrichment_nominal.svg"
paper, r = plot_enrichment(
    [l1, l2, l3, p, p2],
    "ADHD (Nominally Significant)",
    7,
    output=output_path,
    format="svg",
)

In [None]:
# Enrichment plot for the exported manual ADHD probe list (columns pmcid / probeId).
df = pd.read_csv("../permutation_testing/adhd_manual_probes.csv")
print(df.shape)
c = Counter(df["probeId"])
# Pass the frame the counter was built from. The previous code passed `manual`, a variable
# left over from whichever cell ran last, so the paper count could come from different rows.
l1, l2, l3, p, p2, _ = calculate_points(c, df)
# NOTE(review): this is the same file name the "Nominally Significant" cell writes to -- confirm the overwrite is intended.
output_path = f"{SFIG_PATH}/adhd_manual_enrichment_nominal.svg"
paper, r = plot_enrichment([l1, l2, l3, p, p2, _], "ADHD (Manual)", 7, output=output_path, format="svg")

In [None]:
# For prostate-cancer papers with a positive mQTL count, list the number of distinct probes per paper.
# NOTE(review): `eval` executes whatever text the column holds; if it only ever contains list
# literals, ast.literal_eval would parse it without that risk.
df = pd.read_csv("probe/cancer_all_probes.csv")
df["Filtered Mesh Term"] = df["Filtered Mesh Term"].apply(eval)
has_term = df["Filtered Mesh Term"].apply(lambda terms: "Prostatic Neoplasms" in terms)
df = df[has_term]
prostate_cancer_pmcids = df["pmcid"].unique()

mqtl_info = pd.read_csv("mqtl_counts.csv")
mqtl_info = mqtl_info[mqtl_info["PMCID"].isin(prostate_cancer_pmcids)]
mentions_mqtl = mqtl_info["mQTL"] > 0
questionable_mqtl_papers = mqtl_info[mentions_mqtl]["PMCID"].unique()

# Re-read the full table (the frame above was narrowed to the MeSH term) and keep only those papers.
df = pd.read_csv("probe/cancer_all_probes.csv")
df = df[df["pmcid"].isin(questionable_mqtl_papers)]
df.groupby("pmcid")["probeId"].nunique()
# df[df["pmcid"] == "PMC7145271"]

In [None]:
# For type 2 diabetes papers with a positive mQTL count, list the number of distinct probes per paper.
# NOTE(review): `eval` executes whatever text the column holds; if it only ever contains list
# literals, ast.literal_eval would parse it without that risk.
df = pd.read_csv("probe/metabolic_diseases_all_probes.csv")
df["Filtered Mesh Term"] = df["Filtered Mesh Term"].apply(eval)
has_term = df["Filtered Mesh Term"].apply(lambda terms: "Diabetes Mellitus, Type 2" in terms)
df = df[has_term]
# NOTE(review): the name is a leftover from the prostate-cancer cell; here it holds the
# type 2 diabetes paper ids.
prostate_cancer_pmcids = df["pmcid"].unique()

mqtl_info = pd.read_csv("mqtl_counts.csv")
mqtl_info = mqtl_info[mqtl_info["PMCID"].isin(prostate_cancer_pmcids)]
mentions_mqtl = mqtl_info["mQTL"] > 0
questionable_mqtl_papers = mqtl_info[mentions_mqtl]["PMCID"].unique()

# Re-read the full table (the frame above was narrowed to the MeSH term) and keep only those papers.
df = pd.read_csv("probe/metabolic_diseases_all_probes.csv")
df = df[df["pmcid"].isin(questionable_mqtl_papers)]
df.groupby("pmcid")["probeId"].nunique()
# df[df["pmcid"] == "PMC7145271"]

In [None]:
# Enrichment plot for type 2 diabetes probes with two papers excluded by id.
# NOTE(review): despite its name, `cancer_df` holds rows from the metabolic-diseases table.
# NOTE(review): `eval` executes whatever text the column holds; ast.literal_eval would be safer for list literals.
term = "Diabetes Mellitus, Type 2"
excluded_pmcids = ["PMC4222689", "PMC4913906"]
cancer_df = pd.read_csv("probe/metabolic_diseases_all_probes.csv")
cancer_df["Filtered Mesh Term"] = cancer_df["Filtered Mesh Term"].apply(eval)
has_term = cancer_df["Filtered Mesh Term"].apply(lambda terms: term in terms)
cancer_df = cancer_df[has_term]
cancer_df = cancer_df[~cancer_df["pmcid"].isin(excluded_pmcids)]
c = Counter(cancer_df["probeId"])
l1, l2, l3, p, p2,_ = calculate_points(c, "metabolic")
paper, r = plot_enrichment(
    [l1, l2, l3, p, p2],
    "metabolic",
    0,
    output=None,
    show_figure=True,
    box_placement=(0.15, 0.9),
    show_y_label=True,
    format="svg",
    show_legend=False,
)



In [None]:
# Figure 5B: becon density plot
import math
import scipy.stats as stats
import matplotlib.pyplot as plt

non_corsiv_baseline = illumina - CORSIV_PROBE_LIST
bins = [-1, 0, 1.0]  # not used in this cell

# `idx` picks one entry from each of the parallel lists below (0 = BECon, 1 = ICC, 2 = IIR).
idx = 1
target_cols = ["Mean Cor All Brain", "ICC", "iir1"]
fnames = ["becon/becon_all_probes.csv", "iir_icc/Flanagan_icc_results.csv", "iir_icc/Flanagan_iir_results.csv"]
probe_cols = ["CpG ID", "ID", "ID"]
xlabels = ["Brain-Blood Correlation (BECon)", "Intraclass Correlation Coefficient (ICC)", r"IIR$_{2-98\%}$"]
output_path = ["becon", "icc", "iir"]

target_col = target_cols[idx]
df = pd.read_csv(fnames[idx])
colname = probe_cols[idx]
regions = list(zip(["Non-CoRSIV", "CoRSIV"], [non_corsiv_baseline, CORSIV_PROBE_LIST]))
xlabel = xlabels[idx]
output_id = output_path[idx]


def _draw_region_densities(ax, frames_by_region, corsiv_color):
    """Draw one filled Gaussian-KDE curve of `target_col` per (frame, region name) pair on `ax`."""
    for df_subset, rname in frames_by_region:
        density = stats.gaussian_kde(df_subset[target_col])
        xs = np.linspace(-1, 1, 200)
        ys = density(xs)
        color = corsiv_color if rname == 'CoRSIV' else 'grey'
        ax.plot(xs, ys, "-", color=color, label=rname, linewidth=3)
        ax.fill_between(xs, ys, alpha=0.5, color=color)
        # Label each curve just above its peak.
        peak_index = np.argmax(ys)
        ax.text(xs[peak_index], ys[peak_index]+0.04, f"{rname}", fontsize=15, #\n(n={len(df_subset):,})
                verticalalignment='bottom', horizontalalignment='center',
                color=color)


def _style_density_axis(ax):
    """Apply the tick layout that belongs to the metric selected by `idx`."""
    ax.tick_params(axis='both', which='major', labelsize=16, length=5)
    if idx == 2:
        ax.set_xticks([0, 0.5, 1.0])
        ax.set_xlim(-0.1, 1.1)
    elif idx == 1:
        ax.set_xticks([-1, -0.5, 0.0, 0.5, 1.0])
    else:
        ax.set_yticks([0, 1.0, 2.0])
        ax.set_ylim(0, 2.5)
        ax.set_xticks([-1, 0.0, 1.0])


# Create subplots with 3 rows and 4 columns
fig, axes = plt.subplots(3, 4, figsize=(12, 9))
axes = axes.flatten()
# For idx == 0 the neurological category is left out of the grid.
lst = [c for c in CATEGORY_NAMES if c != "neurological"] if idx == 0 else CATEGORY_NAMES
# One panel per category: probes whose count in `cat_probes_dict` is at least `papers_threshold`.
for cat_idx, catname in enumerate(lst):
    ax = axes[cat_idx]
    i = CATEGORY_NAMES.index(catname)

    papers_threshold = 2
    p = set(k for k, v in cat_probes_dict[i].items() if v >= papers_threshold)

    dfs_for_plot = []
    for rname, rset in regions:
        probes_in_region = rset.intersection(p)
        filtered_df = df[df[colname].isin(probes_in_region)]
        dfs_for_plot.append((filtered_df, rname))

    _draw_region_densities(ax, dfs_for_plot, COLOR_TEMPLATE[i])
    _style_density_axis(ax)
    ax.set_title(catname.capitalize(), 
                 fontsize=20, pad=10)

# Final panel: every probe in each region, regardless of category.
ax = axes[10] if idx == 0 else axes[11]

dfs_for_plot = []
for rname, rset in regions:
    filtered_df = df[df[colname].isin(rset)]
    dfs_for_plot.append((filtered_df, rname))

_draw_region_densities(ax, dfs_for_plot, 'black')
_style_density_axis(ax)
ax.set_title("All Probes", 
                fontsize=20, pad=10)
if idx == 0:
    # Remove the last (empty) subplot
    fig.delaxes(axes[-1])

# Shared y axis: a vertical line at the left figure edge with its label to the left of it.
x = 0
fig.add_artist(plt.Line2D([x, x], [0.06,0.94], transform=fig.transFigure, color='black', linestyle='-', linewidth=2))
fig.text(x-0.04, 0.5, 'Density', va='center', rotation='vertical', fontsize=24)

# Shared x axis: a horizontal line at the bottom figure edge with its label below it.
y = 0
fig.add_artist(plt.Line2D([0.04, 0.98], [y, y], transform=fig.transFigure, color='black', linestyle='-', linewidth=2))
fig.text(0.5, y-0.04, xlabel, ha='center', rotation='horizontal', fontsize=24)
plt.tight_layout()
# Save BEFORE showing: once plt.show() has run, the current figure is no longer this one,
# so a later plt.savefig writes an empty image. bbox_inches="tight" keeps the shared
# labels, which sit at negative figure coordinates, inside the saved file.
output = f"{SFIG_PATH}/{output_id}_kde_all_categories.svg"
fig.savefig(output, format="svg", bbox_inches="tight")
plt.show()

In [None]:
# Figure 5C: becon regression plots for all categories except neurological
import math
import scipy.stats as stats
import matplotlib.pyplot as plt
import numpy as np
import statsmodels.api as sm

non_corsiv_baseline = illumina - CORSIV_PROBE_LIST
target_col = "Mean Cor All Brain"
df = pd.read_csv("becon/becon_all_probes.csv")
regions = list(zip(["CoRSIV", "Non-CoRSIV"], [CORSIV_PROBE_LIST, non_corsiv_baseline]))

# Create subplot grid
fig, axs = plt.subplots(3, 2, figsize=(12, 11))
axs = axs.flatten()

# One panel per listed category; the sixth panel is reserved for the legend.
for plot_idx, cat in enumerate(["cancer", "endocrine", "immune", "metabolic", "urogenital"]):
    cat_probes = read_in_probes(cat)
    ax = axs[plot_idx]
    max_papers = max(cat_probes.values())

    for rname, rset in regions:
        medians = []
        paper_counts = []
        for pidx in range(1, max_papers + 1):
            # Probes with a count of exactly `pidx` that fall inside this region.
            p = set(k for k, v in cat_probes.items() if v == pidx)
            probes_in_region = rset.intersection(p)
            filtered_df = df[df["CpG ID"].isin(probes_in_region)]
            if len(filtered_df) < 15:
                # Too few probes for a stable median: stop here. This also lowers `max_papers`
                # for the region processed next and for the y-axis range of the panel.
                # NOTE(review): confirm that capping the second region by the first is intended.
                max_papers = pidx - 1
                break
            else:
                medians.append(filtered_df[target_col].median())
                paper_counts.append(pidx)

        # Ordinary least squares of paper count on the median of `target_col`.
        X = np.array(medians)
        y = np.array(paper_counts)
        X_const = sm.add_constant(X)
        model = sm.OLS(y, X_const).fit()

        color = COLOR_TEMPLATE[CATEGORY_NAMES.index(cat)] if rname == 'CoRSIV' else 'grey'
        marker = 'D' if rname == 'CoRSIV' else 'o'
        ax.scatter(X, y, color=color, label=rname, s=200, marker=marker)
        ax.plot(X, model.predict(X_const), color=color, linestyle='--', linewidth=4, zorder=10, alpha=0.6)
        r_squared = round(model.rsquared, 3)
        f_pvalue = round(model.f_pvalue, 3)

        annotation_text = f"R² = {r_squared:.2f}\nP = {f_pvalue:.2f}"

        # Add annotation next to the last plotted point of this region
        ax.annotate(annotation_text, 
                    xy=(X[-1]-0.04, y[-1]/2),
                    xytext=(10, 0), 
                    textcoords='offset points',
                    color=color,
                    fontsize=20,
                    ha='left', 
                    va='center',
                    bbox=dict(boxstyle='round,pad=0.5', fc='none', ec='none', alpha=1))

    # ax.set_xlabel("Median Brain-Blood Correlation (BECon)", fontsize=22)
    # ax.set_ylabel('Number of Papers', fontsize=18)
    ax.set_title(cat.capitalize(), fontsize=28, pad=15)
    ax.set_xlim(-0.05, 0.6)
    ax.set_xticks(np.arange(0.0, 0.7, 0.1))
    ax.tick_params(axis='x', which='major', length=5)
    ax.set_yticks(range(1, max_papers + 1))
    ax.set_ylim(0, max_papers + 1)
    ax.tick_params(axis='both', which='major', labelsize=25)

axs[-1].axis('off')
# Add legend to the last subplot
handles = [plt.Line2D([0], [0], marker='D', color='w', markerfacecolor="k", 
                      label="CoRSIV", markersize=15)]
handles.append(plt.Line2D([0], [0], marker='o', color='w', label='Control', markersize=20, markerfacecolor='grey'))

axs[-1].legend(handles=handles, loc='center', fontsize=24, frameon=False)

# Shared y axis: a vertical line at the left figure edge with its label to the left of it.
x = 0
fig.add_artist(plt.Line2D([x, x], [0.06,0.94], transform=fig.transFigure, color='black', linestyle='-', linewidth=2))
fig.text(x-0.04, 0.5, 'Number of Papers', va='center', rotation='vertical', fontsize=24)

# Shared x axis: a horizontal line at the bottom figure edge with its label below it.
y = 0
fig.add_artist(plt.Line2D([0.04, 0.96], [y, y], transform=fig.transFigure, color='black', linestyle='-', linewidth=2))
fig.text(0.5, y-0.04, "Median Brain-Blood Correlation (BECon)", ha='center', rotation='horizontal', fontsize=24)
plt.tight_layout()

# The shared labels are placed at negative figure coordinates; without a tight bounding box
# they fall outside the saved canvas and are cut off in the SVG.
output = f"{SFIG_PATH}/becon_regression_all_categories.svg"
fig.savefig(output, format="svg", bbox_inches="tight")

In [None]:
# Venn diagram of (pmcid, probeId) rows found by manual curation vs. the automated pipeline for ADHD.
from matplotlib_venn import venn2

manual = pd.read_csv("../permutation_testing/adhd_manual_probes.csv")
automtated = pd.read_csv("probe/neurological_all_probes.csv")
# MeSH terms are stored as one "|"-separated string per row.
automtated["Filtered Mesh Term"] = automtated["Filtered Mesh Term"].apply(
    lambda raw: [piece.strip() for piece in raw.split("|")]
)
is_adhd_row = automtated["Filtered Mesh Term"].apply(
    lambda terms: "Attention Deficit Disorder with Hyperactivity" in terms
)
automtated = automtated[is_adhd_row][["pmcid", "probeId"]]

# Outer merge on the shared columns; `_merge` records which side each row came from.
m = manual.merge(automtated, how="outer", indicator=True)
# Create counts for Venn diagram
left_only = int((m['_merge'] == 'left_only').sum())
right_only = int((m['_merge'] == 'right_only').sum())
both = int((m['_merge'] == 'both').sum())

# Create and plot Venn diagram
plt.figure(figsize=(8,8))
venn2(
    subsets=(left_only, right_only, both),
    set_labels=('Manual', 'Automated'),
    set_colors=('lightblue', 'lightgreen'),
)
plt.title('Overlap between Manual and \nAutomated Probe Instances', pad=20, fontsize=24)
# plt.show()
plt.savefig(f"{SFIG_PATH}/venn_adhd.svg", format="svg")
# First two columns of the rows that only the automated pipeline produced; used by later cells.
unique_to_automated = m[m["_merge"] == "right_only"].iloc[:,0:2]


In [165]:
# Manually curated ADHD probes (up-to-date copy): keep "cg"/"ch." probe ids present on EPIC or HM450.
manual = pd.read_csv("pubmed_search/adhd_test/probes/all_ADHD_probes_uptodate.csv")
has_cg_prefix = manual["probeId"].str.startswith("cg")
has_ch_prefix = manual["probeId"].str.startswith("ch.")
manual = manual[has_cg_prefix | has_ch_prefix]
manual = manual[manual["probeId"].isin(epic_probe_list | hm450_probe_list)]
def to_float(x):
    """Convert a reported significance value to a float where possible.

    Numeric input (or numeric text) is returned as ``float(x)``. A fixed set of
    qualitative strings is mapped to a representative number just inside the
    stated bound (for example "<0.05" -> 0.049); "not pass" and a single space
    become NaN. Anything else is returned unchanged so that unexpected values
    stay visible to the caller instead of being silently coerced.
    """
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are handled here; a bare `except:` would also
        # swallow KeyboardInterrupt / SystemExit raised while the column is processed.
        pass
    if not isinstance(x, str):
        return x
    # float("nan") rather than math.nan: this cell does not import `math` itself.
    nan = float("nan")
    qualitative = {
        "<0.05": 0.049,
        "pass": 0.049,
        "not pass": nan,
        " ": nan,
        "> 0.1": 0.11,
        "≤ 0.1": 0.09,
        "≤ 0.05": 0.049,
        "≤0.01": 0.009,
    }
    return qualitative.get(x, x)
# Parse the three significance columns into floats (qualitative entries are handled by to_float).
for column in ("p-value", "q-value", "adj-p-value"):
    manual[column] = manual[column].apply(to_float)
keep = pd.read_csv("pubmed_search/adhd_test/probes/all_ADHD_source_uptodate.csv")
manual = manual.merge(keep, on=["pmcid", "Notes", "From", "Title"])
# manual.drop_duplicates(subset=["pmcid", "probeId"], inplace=True)

# Attach the manual annotations to every automated-only (pmcid, probeId) pair; pairs with no
# manual record keep NaN in the manual columns (left join).
qc = unique_to_automated.merge(manual, on=["pmcid", "probeId"], how="left")
# qc.to_csv("pubmed_search/adhd_test/probes/probes_unique_to_automated.csv", index=False)

In [None]:
# Spot check: automated-only rows that belong to paper PMC5540511 (`unique_to_automated` comes from the Venn cell).
unique_to_automated[unique_to_automated["pmcid"]=='PMC5540511']

In [None]:
# Load the automated-only probe table from disk.
# NOTE(review): presumably the file written by the commented-out `qc.to_csv(...)` above and then annotated by hand -- confirm.
qc = pd.read_csv("pubmed_search/adhd_test/probes/probes_unique_to_automated.csv")


In [None]:
# Pie chart of why automated-only probes are absent from the manual set.
# Only the last three columns of the table are kept; 'Excluded' and 'Nonsignificant' are summed.
qc = pd.read_csv("pubmed_search/adhd_test/probes/probes_unique_to_automated.csv")
qc = qc.iloc[:,-3:]
# Calculate total counts for each column
excluded_count = qc['Excluded'].sum()
nonsig_count = qc['Nonsignificant'].sum()

slice_sizes = [excluded_count, nonsig_count]
slice_labels = [
    f'Excluded\n({excluded_count} probes)',
    f'Not Significant\n({nonsig_count} probes)',
]

# Create pie chart
plt.figure(figsize=(4,4))
plt.pie(
    slice_sizes,
    labels=slice_labels,
    autopct='%1.1f%%',
    colors=['#fbb4ae', '#b3cde3'],
)
plt.title('Reasons for Probe Mismatch \nBetween Automated and Manual Approaches')
plt.savefig(f"{SFIG_PATH}/probe_mismatch_reasons.svg", format="svg")

In [None]:
# Enrichment plot for the exported manual ADHD probe list (columns pmcid / probeId), shown inline only.
df = pd.read_csv("../permutation_testing/adhd_manual_probes.csv")
print(df.shape)
c = Counter(df["probeId"])
# Pass the frame the counter was built from. The previous code passed `manual`, which at this
# point is whatever an earlier cell left behind (possibly the unfiltered merge), so the paper
# count could be derived from rows that are not part of `c`.
l1, l2, l3, p, p2 = calculate_points(c, df)
paper, r = plot_enrichment([l1, l2, l3, p, p2], "ADHD (Manual)", 7, output=None, format="svg")

In [33]:
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from math import log10
from statistics import mean
from matplotlib import ticker

def calculate_points(count_dictionary, input_cat):
    """Summarise CoRSIV vs. control overlap for probes grouped by paper count.

    Parameters
    ----------
    count_dictionary : mapping of probe id -> count
        Typically ``Counter(frame["probeId"])``.
    input_cat : str or DataFrame
        Either a category name (the table ``probe/{cat}_all_probes.csv`` is read,
        with "metabolic" mapped to "metabolic_diseases") or a frame that has
        ``probeId`` and ``pmcid`` columns. Used only to count papers.

    Returns
    -------
    tuple of seven items, in this order
        paper_threshold_count : list of (i, number of probes with count == i)
        corsiv_count          : list of (i, probes with count == i in CORSIV_PROBE_LIST)
        control_count         : list of (i, mean overlap count over the sets in CONTROLS)
        corsiv_pct            : list of (i, CoRSIV overlap as a percentage of the probes at i)
        control_pct           : list of (i, mean control overlap percentage)
        probe_res_count       : CoRSIV-overlapping probes summed over every i where the
                                CoRSIV percentage exceeds the mean control percentage
        paper count           : number of distinct ``pmcid`` values reporting those probes

    The lists stop at the largest i that still has at least 10 probes (or at the
    maximum count when no i reaches 10). For an empty ``count_dictionary`` five
    empty lists and two zeros are returned, so callers can always unpack seven values.
    """
    if not count_dictionary:
        # Same arity as the normal return (this used to return only five items).
        return [], [], [], [], [], 0, 0

    # Bucket probes by count once instead of re-filtering the whole dict for every i.
    probes_by_count = {}
    for probe, n_reports in count_dictionary.items():
        probes_by_count.setdefault(n_reports, set()).add(probe)
    max_probe_count = max(count_dictionary.values())

    paper_threshold_count = [
        (i, len(probes_by_count.get(i, ())))
        for i in range(1, max_probe_count + 1)
    ]

    # Trim the sparse tail: walk down from the maximum to the first i with >= 10 probes.
    probe_cutoff = max_probe_count
    for i in range(paper_threshold_count[-1][0], 0, -1):
        if paper_threshold_count[i-1][1] >= 10:
            probe_cutoff = i
            break
    paper_threshold_count = paper_threshold_count[:probe_cutoff]

    if isinstance(input_cat, str):
        cat = "metabolic_diseases" if input_cat == "metabolic" else input_cat
        df = pd.read_csv(f"probe/{cat}_all_probes.csv")
    else:
        df = input_cat

    corsiv_count = []
    control_count = []
    corsiv_pct = []
    control_pct = []
    probe_res_count = 0
    corsiv_paper_set = set()
    for i in range(1, probe_cutoff + 1):
        probes = probes_by_count.get(i, set())
        n_probes = len(probes)

        overlapping_probes = probes.intersection(CORSIV_PROBE_LIST)
        corsiv_overlap_count = len(overlapping_probes)
        corsiv_overlap = corsiv_overlap_count / n_probes * 100 if n_probes > 0 else 0
        corsiv_pct.append((i, corsiv_overlap))
        corsiv_count.append((i, corsiv_overlap_count))

        curr_control_count = [len(probes.intersection(control_set)) for control_set in CONTROLS]
        curr_control = [
            overlap_count / n_probes * 100 if n_probes > 0 else 0
            for overlap_count in curr_control_count
        ]
        mean_control_overlap = mean(curr_control)
        control_pct.append((i, mean_control_overlap))
        control_count.append((i, mean(curr_control_count)))

        # Only thresholds where CoRSIV beats the control average contribute to the totals.
        if corsiv_overlap > mean_control_overlap:
            probe_res_count += corsiv_overlap_count
            corsiv_paper_set.update(df[df["probeId"].isin(overlapping_probes)]["pmcid"].unique())
    return paper_threshold_count, corsiv_count, control_count, corsiv_pct, control_pct, probe_res_count, len(corsiv_paper_set)

def plot_enrichment(counts, title_text, cat_index, output=None, show_ratio=True, format="svg", show_legend=False, show_figure=True, show_y_label=True, box_placement=(0.15, 0.9), export_all=False):
    """Draw the two-panel enrichment figure for one category.

    Parameters
    ----------
    counts : sequence
        Indexed as follows (each list holds (x, y) pairs):
        [0] number of probes per paper count (top panel),
        [1] CoRSIV overlap percentage, [2] control overlap percentage (bottom panel),
        [3] probe total and [4] paper total shown in the summary box,
        [5] CoRSIV and [6] control overlap counts used for the "... probes" annotations.
    title_text : str
        Figure title; capitalised only when it contains no space.
    cat_index : int
        Index into COLOR_TEMPLATE for the category colour.
    output : str or None
        Path to save to. When None the figure is shown if ``show_figure`` is true.
        When set, the file is written only if ``export_all`` is true or the ratio
        is above 1 and there is more than one x value.
    show_ratio, show_legend, show_y_label, box_placement, format, export_all
        Presentation switches; see the body.

    Returns
    -------
    (int, float)
        Number of x values in the top panel and the enrichment ratio
        (0 when ``show_ratio`` is false, -1 when the control sum is zero).
        ``(0, 0)`` is returned without plotting when ``counts[0]`` is empty.
    """
    if counts[0] == []:
        return 0, 0
    fig = plt.figure(figsize=(6, 6))
    gs = gridspec.GridSpec(2, 1, height_ratios=[1, 2])
    ax1 = fig.add_subplot(gs[0])
    ax2 = fig.add_subplot(gs[1], sharex=ax1)     

    # Top panel: number of probes per paper count on a log scale.
    x_values, y_values = zip(*counts[0])
    ax1.plot(x_values, y_values, marker='o', linestyle='-', color = COLOR_TEMPLATE[cat_index])
    for x, y in zip(x_values, y_values):
        ax1.annotate(f'{y:,}', (x, y), textcoords="offset points", xytext=(0,10), ha='center', fontsize=14)
    ax1.set_title(title_text.capitalize() if " " not in title_text else title_text, fontsize=30)# .capitalize()
    max_y = int(log10(max(y_values))) + 1
    exponents = range(1, max_y + 1)
    ax1.set_yscale('log')
    ax1.set_yticks([10**e for e in exponents])
    # Braces around the exponent keep multi-digit exponents together in mathtext
    # ("$10^10$" would render as 10^1 followed by 0).
    ax1.set_yticklabels([f'$10^{{{e}}}$' for e in exponents])
    # ax1.yaxis.set_minor_locator(ticker.LogLocator(subs=range(2, 10)))
    ax1.set_xticks(range(1, int(max(x_values))+1))
    ax1.tick_params(axis='x', bottom=True, direction='inout', labelbottom=False, length=10)
    ax1.tick_params(axis='y', which='both', left=True, labelleft=True)

    def _plot_overlap(pct_points, count_points, color, label, format_pct):
        # One line on the bottom panel with a percentage and a probe-count annotation per point.
        xs, pcts = zip(*pct_points)
        _, probe_counts = zip(*count_points)
        ax2.plot(xs, pcts, marker='o', linestyle='-', color = color, label=label)
        for i, (x, y) in enumerate(zip(xs, pcts)):
            ax2.annotate(format_pct(y), (x, y), textcoords="offset points", xytext=(0,10), ha='center', fontsize=14)
            ax2.annotate(f'{probe_counts[i]:,} probes', (x, y), textcoords="offset points", xytext=(0,0), ha='center', fontsize=14)
        return xs

    # Bottom panel: CoRSIV vs. control overlap percentages.
    _plot_overlap(counts[1], counts[5], COLOR_TEMPLATE[cat_index], "CoRSIV",
                  lambda y: f'{y:.1f}%')
    x_values = _plot_overlap(counts[2], counts[6], "grey", "Control",
                             lambda y: f'{y:.2f}%' if y > 0.1 else f'{y:.3f}%')
    ax2.set_xticks(range(1, int(max(x_values))+1))
    ax2.set_xlabel('Number of Papers Reporting Probe', fontsize=18)
    if show_y_label:
        ax1.set_ylabel('Number of Probes', fontsize=16)
        ax2.set_ylabel('Overlapping Probes (%)', fontsize=16)
        ax2.tick_params(axis='y', labelsize=18)
    ax2.tick_params(axis='x', which='both', bottom=True, labelbottom=True, labelsize=18)
    if not show_y_label:
        # Without an axis label, put the percent sign on the tick labels instead.
        decimals = 1 if any(tick % 1 != 0 for tick in ax2.get_yticks()) else 0
        ax2.yaxis.set_major_formatter(ticker.PercentFormatter(decimals=decimals))
    enrichment_ratio = 0
    if show_ratio:
        # Ratio of the paper-count-weighted percentage sums (CoRSIV over control).
        corsiv_ratio = sum([i*pct for i, pct in counts[1]])
        control_ratio = sum([i*pct for i, pct in counts[2]])
        enrichment_ratio = round(corsiv_ratio / control_ratio, 1) if control_ratio != 0 else -1
        middle_text = f"Ratio = {enrichment_ratio}\n{counts[3]:,} Probes\n{counts[4]:,} Papers" if enrichment_ratio > 1 else "No Enrichment"

        ax2.annotate(middle_text,
                    xy=box_placement,  # Adjust position to account for bbox size
                    xycoords='axes fraction',  # Use axes fraction for coordinates
                    ha='center',  # Center the text horizontally
                    va='top',     # Align text to the top of the box
                    fontsize=16,  # Font size
                    color="red",
                    bbox=dict(facecolor='white', edgecolor='red', linewidth=2, boxstyle='square,pad=0.5'))  # Add thicker square box around text
    plt.tight_layout()
    plt.subplots_adjust(hspace=0)

    if show_legend:
        ax2.legend(bbox_to_anchor=(0.35, 0.5), shadow=False, frameon=False)

    if output is None:
        if show_figure:
            plt.show() 
    else:
        if export_all or (enrichment_ratio > 1 and len(counts[0]) > 1):
            fig.savefig(output, format=format, bbox_inches='tight')
    plt.close()
    return len(counts[0]), enrichment_ratio
# Annotated example figure. Uses `c` (probe counter) and `df` left over from an earlier cell.
# The seven values unpack in the order calculate_points returns them:
# threshold counts, CoRSIV counts, control counts, CoRSIV %, control %, probe total, paper total.
total_probe, c1count, c2count, c1pct, c2pct, p, p2 = calculate_points(c, df)
output = f"{SFIG_PATH}/annotated_example.svg"
# plot_enrichment reads its list as [thresholds, CoRSIV %, control %, probes, papers, CoRSIV counts, control counts].
paper, r = plot_enrichment([total_probe, c1pct, c2pct, p, p2, c1count, c2count], "Neurodevelopmental Disorders", 6, output=output, show_figure=True)

In [69]:
from collections import defaultdict
# Lookup tables built from the MeSH tree file, which is read as one "term;code" pair per line.
mesh_ttoc = defaultdict(set) #term:code

mesh_tree = {} #big category:set of subcategories
file_path = '../humanData/database/mtrees2024.txt'

with open(file_path, 'r') as handle:
    for raw_line in handle:
        fields = raw_line.strip().split(';')
        if len(fields) != 2:
            # Not a "term;code" pair: print it so malformed lines are noticed.
            print(fields)
            continue
        mesh_term, tree_code = fields
        mesh_ttoc[mesh_term].add(tree_code)

# Reverse lookup: each tree code maps back to its term.
mesh_ctot = {
    tree_code: mesh_term
    for mesh_term, tree_codes in mesh_ttoc.items()
    for tree_code in tree_codes
} #code:term


def next_layer(given_string, string_list):
    """Return True when ``given_string`` is exactly one dotted level below any entry of ``string_list``.

    That is, for some prefix in ``string_list`` the string equals
    ``prefix + "." + rest`` where ``rest`` contains no further dot.
    """
    return any(
        given_string.startswith(prefix + '.')
        and '.' not in given_string[len(prefix) + 1:]
        for prefix in string_list
    )


# keywords = ["Endocrine System Diseases"]
# Flatten the per-category term lists into one list of keyword terms.
keywords = [cc for c in CATEGORIES for cc in c]

# For every keyword, collect the terms that have a tree code exactly one level below
# one of the keyword's own codes.
for kw in keywords:
    # Look the keyword up BEFORE iterating. `mesh_ttoc` is a defaultdict, so indexing an
    # unknown keyword inside the loop inserted a key while the dict was being iterated and
    # raised "dictionary changed size during iteration". `.get` also avoids adding empty entries.
    parent_codes = mesh_ttoc.get(kw, set())
    mesh_tree[kw] = {
        term
        for term, codes in mesh_ttoc.items()
        if any(next_layer(code, parent_codes) for code in codes)
    }


In [None]:
import plotly.graph_objects as go
# One Sankey (alluvial) diagram per category, linking CoRSIV probes to the second-level MeSH
# terms of the rows that report them; each diagram is written to an HTML file.
for key_category in CATEGORY_NAMES:
    # The metabolic table stores the MeSH terms as a Python literal, the others as "|"-separated text.
    # NOTE(review): `eval` executes whatever text the column holds; ast.literal_eval would be safer for list literals.
    if key_category == "metabolic":
        df = pd.read_csv("probe/metabolic_diseases_all_probes.csv")
        df["Filtered Mesh Term"] = df["Filtered Mesh Term"].apply(eval)
    else:
        df = pd.read_csv(f"probe/{key_category}_all_probes.csv")
        df["Filtered Mesh Term"] = df["Filtered Mesh Term"].apply(
            lambda raw: [piece.strip() for piece in raw.split("|")]
        )

    # Terms one level below any of this category's keywords.
    target_categories = set([])
    for c in CATEGORIES[CATEGORY_NAMES.index(key_category)]:
        target_categories |= mesh_tree[c]
    df["Filtered Mesh Term"] = df["Filtered Mesh Term"].apply(
        lambda terms: [t for t in terms if t in target_categories]
    )

    # Keep CoRSIV probes that occur in at least two rows.
    c = Counter(df["probeId"])
    c = {k:v for k, v in c.items() if v >= 2 and k in CORSIV_PROBE_LIST}
    df = df[df["probeId"].isin(c)]
    # categories = list(set([term for terms in df["Filtered Mesh Term"] for term in terms]))
    # probes = list(set(df["probeId"].unique()))

    # One (probe, term) link per term of every remaining row.
    affiliations = [
        (probe_id, mesh_term)
        for probe_id, row_terms in zip(df["probeId"], df["Filtered Mesh Term"])
        for mesh_term in row_terms
    ]

    linked_probes = [probe_id for probe_id, _term in affiliations]
    linked_terms = [mesh_term for _probe, mesh_term in affiliations]
    nodes = list(set(linked_probes + linked_terms))
    node_indices = {node: i for i, node in enumerate(nodes)}
    links = {
        'source': [node_indices[probe_id] for probe_id in linked_probes],  # probe node indices
        'target': [node_indices[mesh_term] for mesh_term in linked_terms],  # MeSH term node indices
        'value': [1 for _ in affiliations]  # All links have equal weight (1)
    }

    # Define the Sankey diagram
    sankey = go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=nodes
        ),
        link=dict(
            source=links['source'],
            target=links['target'],
            value=links['value']
        )
    )
    fig = go.Figure(sankey)

    # Update layout and write to disk
    fig.update_layout(title_text=key_category.capitalize(), font_size=10, width=1000, height=2000)
    fig.write_html(f"{SFIG_PATH}/alluvial/{key_category}.html")
    # fig.show()


In [None]:
import plotly.graph_objects as go
# For each category: load its probe table, restrict the terms to that
# category, keep CoRSIV probes seen in 2-5 rows, and draw a probe x term
# clustermap that is saved as a JPEG. dfs2 collects, per category, the
# exploded table followed by the pivoted count matrix.
dfs2 = []
for key_category in ["metabolic", "endocrine"]:
    # The metabolic file stores the term list as a Python literal, the other
    # file stores it "|"-separated, so each branch picks its own parser.
    # NOTE(review): eval() executes whatever is in the CSV cell; this is only
    # safe while that file is produced by this project.
    if key_category == "metabolic":
        csv_path = "probe/metabolic_diseases_all_probes.csv"
        parse_terms = eval
    else:
        csv_path = f"probe/{key_category}_all_probes.csv"
        parse_terms = lambda raw: [piece.strip() for piece in raw.split("|")]
    df = pd.read_csv(csv_path)
    df["Filtered Mesh Term"] = df["Filtered Mesh Term"].apply(parse_terms)

    # Union of the mesh_tree entries for every key listed for this category.
    target_categories = set().union(
        *(mesh_tree[branch] for branch in CATEGORIES[CATEGORY_NAMES.index(key_category)])
    )
    df["Filtered Mesh Term"] = df["Filtered Mesh Term"].apply(
        lambda terms: [term for term in terms if term in target_categories]
    )

    # Keep probes that occur in 2 to 5 rows and belong to CORSIV_PROBE_LIST.
    rows_per_probe = Counter(df["probeId"])
    selected_probes = {
        probe
        for probe, n_rows in rows_per_probe.items()
        if 2 <= n_rows <= 5 and probe in CORSIV_PROBE_LIST
    }
    df = df[df["probeId"].isin(selected_probes)]

    # One row per (probe, term); then a probe x term matrix of row counts,
    # with absent combinations filled with 0 and the axis names cleared.
    df = df.explode('Filtered Mesh Term')
    dfs2.append(df)
    df_pivoted = (
        df.groupby(['probeId', 'Filtered Mesh Term'])
        .size()
        .reset_index(name='count')
        .pivot(index='probeId', columns='Filtered Mesh Term', values='count')
        .fillna(0)
        .rename_axis(index=None, columns=None)
    )
    dfs2.append(df_pivoted)

    # White-to-blue colour scale, capped at 5 to match the row-count filter.
    cmap = LinearSegmentedColormap.from_list("custom_blue", [(1, 1, 1), (0, 0, 1)], N=100)
    g = sns.clustermap(df_pivoted, cmap=cmap, vmax=5, annot=False, figsize=(8, 6), method='ward', metric='euclidean')
    heatmap_ax = g.ax_heatmap

    # Colour bar: only the two end ticks, plus a label.
    cbar = heatmap_ax.collections[0].colorbar
    cbar.set_ticks([0, 5])
    cbar.set_ticklabels(['0', '5'])
    cbar.set_label("Number of \nPapers", rotation=90, fontsize=16)

    # Hide the row dendrogram and thicken the column dendrogram lines.
    g.ax_row_dendrogram.set_visible(False)
    for branch_lines in g.ax_col_dendrogram.collections:
        branch_lines.set_linewidth(2)

    # No probe labels on the y axis; vertical term labels on the x axis.
    heatmap_ax.set_yticks([])
    heatmap_ax.set_xticklabels(heatmap_ax.get_xticklabels(), rotation=90, ha='center', va='top', fontsize=20)
    heatmap_ax.tick_params(axis='x', which='major', pad=10)

    # NOTE(review): the title says "≥2", but probes with more than 5 rows
    # were removed by the filter above.
    title = f"{len(df_pivoted):,} CoRSIV Probes Reported in ≥2 {key_category.capitalize()} Papers"
    g.figure.suptitle(title, y=1.05, fontsize=26)
    g.savefig(f"{SFIG_PATH}/alluvial/{key_category}_subcategory_2papers.jpeg", format="jpeg", dpi=300)