In [None]:
"""
Author: Wen-Jou Chang
Baylor College of Medicine

This script is used to generate the main figures in the paper.
"""

In [29]:

"""
Initialization
"""
# imports
import os
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from collections import defaultdict
from graphviz import Digraph
import numpy as np
import re
import seaborn as sns
import sys
from matplotlib.colors import LinearSegmentedColormap
import importlib, util

PROJECT_PATH = "YOUR_PATH"
FIGURE_PATH = "YOUR_PATH"
os.chdir(PROJECT_PATH)
importlib.reload(util)
from util import CATEGORY_NAMES, COLOR_TEMPLATE, CORSIV_PROBE_LIST, CONTROLS, read_in_probes, calculate_points, plot_enrichment, breakdown, export_paper


# Load both Illumina array manifests (tab-separated, no header row) and collect
# the probe identifiers stored in the fourth column of each file.
epic = pd.read_csv("../humanData/database/EPIC.hg38.txt", sep="\t", header=None)
hm450 = pd.read_csv("../humanData/database/HM450.hg38.txt", sep="\t", header=None)
epic_probe_list = set(epic.iloc[:, 3])
hm450_probe_list = set(hm450.iloc[:, 3])
# Every probe ID present on either array.
illumina = epic_probe_list | hm450_probe_list

# Global matplotlib styling shared by the figures drawn in this notebook.
plt.rcParams.update({
    'font.family': 'Helvetica',
    'font.size': 16,
    'axes.linewidth': 2,       # thicker outer box
    'xtick.major.width': 2,    # thicker x-axis ticks
    'ytick.major.width': 2,    # thicker y-axis ticks
})

In [None]:
# Fig 1A: mesh hierarchy tree

# Parse the MeSH tree file; every well-formed line is "<term>;<tree code>".
mesh_ttoc = defaultdict(set)  # term -> set of tree codes
file_path = '../humanData/database/mtrees2024.txt'
with open(file_path, 'r') as file:
    for line in file:
        parts = line.strip().split(';')
        if len(parts) != 2:
            # Echo any line that does not split into exactly two fields.
            print(parts)
            continue
        term, code = parts
        mesh_ttoc[term].add(code)
# Reverse lookup: tree code -> term.
mesh_ctot = {tree_code: name for name, tree_codes in mesh_ttoc.items() for tree_code in tree_codes}


def visualize(input_nodes, output_name=None):
    """Draw the MeSH hierarchy leading to a set of tree codes with Graphviz.

    Every ancestor prefix of each input code that exists in ``mesh_ctot`` becomes
    a node, consecutive prefixes are joined by parent -> child edges, and each
    top-level code is attached to a synthetic root named after its first letter.
    Nodes whose term appears in ``category_color_map`` are drawn filled.

    Relies on the module-level ``mesh_ctot``, ``category_color_map`` and
    ``graph_order`` mappings.

    Args:
        input_nodes: iterable of MeSH tree codes; each must be a key of
            ``mesh_ctot`` (a ``KeyError`` is raised otherwise).
        output_name: when truthy, the graph is rendered to this path as SVG;
            otherwise it is opened in the default viewer.
    """
    input_nodes = {y: mesh_ctot[y] for y in input_nodes}
    print(input_nodes)
    dot = Digraph()
    dot.attr(rankdir='LR')  # Keep layout horizontal (Left-to-Right)

    edges = set()
    nodes = {}

    # Iterate in sorted order: the caller usually passes a set, whose iteration
    # order changes between kernel sessions (string hash randomisation). The
    # insertion order of `nodes` is the tie-break of the sort further down, so
    # it has to be deterministic for the figure to be reproducible.
    for code in sorted(input_nodes):
        if code.startswith("C23"):
            # Codes under the C23 branch are left out of the figure.
            continue
        parts = code.split('.')
        for i in range(1, len(parts) + 1):
            partial_code = '.'.join(parts[:i])
            if partial_code not in nodes and partial_code in mesh_ctot:
                nodes[partial_code] = mesh_ctot[partial_code]

    def add_code_edges(code):
        """Record a parent -> child edge for every level of ``code``."""
        parts = code.split('.')
        for i in range(1, len(parts)):
            parent_code = '.'.join(parts[:i])
            child_code = '.'.join(parts[:i+1])
            edges.add((parent_code, child_code))

    def get_node_color(node):
        """Return the category colour of ``node``, or black if it has none."""
        return category_color_map.get(nodes[node], "#000000")

    for code in nodes:
        add_code_edges(code)

    # Attach every top-level code (no dot in it) to the root named after its
    # first letter, then register the two roots themselves.
    for n in nodes:
        if "." not in n:
            edges.add((n[0].capitalize(), n))
    nodes["C"] = "Diseases"
    nodes["F"] = "Psychiatry and Psychology"

    # Add nodes for each unique code with term name as label.
    # `graph_order` is a defaultdict(str), so unlisted terms sort first ("").
    for code, term in sorted(nodes.items(), key=lambda x: graph_order[x[1]]):
        print(code, term)
        node_color = get_node_color(code)
        if node_color == "#000000":
            dot.node(code, label=term, shape='box', fontname='Helvetica', align='right', penwidth='1.5')
        else:
            dot.node(code, label=term, shape='box', style='filled', fillcolor=node_color, fontname='Helvetica', fontcolor='white', rank='same', penwidth='1.5')

    # Add edges for hierarchical relationships; sorted so the emitted DOT source
    # (and therefore the layout) is identical from run to run.
    for parent_code, child_code in sorted(edges):
        dot.edge(parent_code, child_code, penwidth='1.5')

    dot.attr(splines='ortho')
    dot.attr(nodesep='0.2', ranksep='0.5')
    if output_name:
        dot.render(output_name, format='svg')
    else:
        dot.view()

# MeSH root terms behind each disease category; position i lines up with
# CATEGORY_NAMES[i] and COLOR_TEMPLATE[i].
categories = [
    ["Neoplasms"],
    ["Cardiovascular Diseases"],
    ["Digestive System Diseases"],
    ["Endocrine System Diseases"],
    ["Hemic and Lymphatic Diseases"],
    ["Immune System Diseases"],
    ["Metabolic Diseases"],
    ["Mental Disorders", "Nervous System Diseases"],
    ["Obesity"],
    ["Respiratory Tract Diseases"],
    ["Urogenital Diseases"],
]

# Sort key used by visualize() when emitting nodes: root terms sort by their
# category name, and a few terms are pinned with explicit keys.
graph_order = defaultdict(str)
for idx, root_terms in enumerate(categories):
    for root_term in root_terms:
        graph_order[root_term] = CATEGORY_NAMES[idx]
graph_order.update({
    "Mental Disorders": "1",
    "Psychiatry and Psychology": "2",
    "Metabolic Diseases": "3",
    "Nutrition Disorders": "4",
})

# Root term -> fill colour of its category.
category_color_map = {
    root_term: fill
    for root_terms, fill in zip(categories, COLOR_TEMPLATE)
    for root_term in root_terms
}

# All tree codes registered for the root terms above.
nodes = {tree_code for root_terms in categories for root_term in root_terms for tree_code in mesh_ttoc[root_term]}
output_name = f"{FIGURE_PATH}/Fig1/mesh_tree_hierarchy"
# output_name is passed as None here, so visualize() opens the graph in a
# viewer and does not write the SVG to `output_name`.
visualize(nodes, output_name=None)


In [None]:
# Fig 1C: number of studies per category

articles = set()
ilumina_studies = []

for cat in CATEGORY_NAMES:
    # Input files use "metabolic_diseases" for the "metabolic" category.
    kw = cat if cat != "metabolic" else "metabolic_diseases"
    probes = pd.read_csv(f"probe/{kw}_all_probes.csv")
    papers = pd.read_csv(f"pubmed_search/{kw}.csv")
    # Keep only the papers whose PMCID also appears in the probe table.
    papers = papers[papers["PMCID"].isin(probes["pmcid"])]
    qc = len(set(probes["pmcid"]))
    ilumina_studies.append(len(papers))
    articles.update(papers["PMCID"])
# Number of distinct papers across all categories.
print(len(articles))

plt.figure()
bar_labels = [name.capitalize() for name in CATEGORY_NAMES]
bars = plt.barh(bar_labels, ilumina_studies, color=COLOR_TEMPLATE)
# Write the count just to the right of each bar.
for bar in bars:
    bar_width = bar.get_width()
    plt.text(bar_width+10, bar.get_y() + bar.get_height() / 2, f'{bar_width:,}',
             va='center', ha='left', fontsize=14)
plt.xlabel('Number of Papers Included')

plt.tight_layout()
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

# Leave headroom on the right for the count labels.
plt.xlim(0, max(ilumina_studies) * 1.2)
plt.gca().invert_yaxis()
plt.savefig(f"{FIGURE_PATH}/Fig1/illumina_studies_num.svg", format="svg", bbox_inches="tight")


In [None]:
# Fig 2A: cancer decay plot illustration

c = read_in_probes("cancer")

# (k, number of probes whose count is exactly k) for k = 1 .. largest count.
max_probe_count = max(c.values())
probes_per_count = Counter(c.values())
paper_threshold_count = [(k, probes_per_count[k]) for k in range(1, max_probe_count + 1)]

# Trim the tail: keep everything up to the largest k that still has at least
# 10 probes (or keep the whole series when no k reaches 10).
probe_cutoff = next(
    (k for k in range(paper_threshold_count[-1][0], 0, -1) if paper_threshold_count[k - 1][1] >= 10),
    max_probe_count,
)
paper_threshold_count = paper_threshold_count[:probe_cutoff]

fig = plt.figure(figsize=(5,4))

x_values, y_values = zip(*paper_threshold_count)
plt.plot(x_values, y_values, marker='o', linestyle='-', color=COLOR_TEMPLATE[0])

plt.title("Cancer", fontsize=20)
plt.ylabel('Number of Probes', fontsize=16)
plt.xlabel('Number of Papers Reporting Probe', fontsize=16)
# plt.yscale('log', subs=[])
plt.yticks(fontsize=16)
plt.ylim(10, 100000)
plt.xticks(range(2, 11, 2), fontsize=16)
plt.tight_layout()
# plt.show()
# plt.savefig(f"{FIGURE_PATH}/Fig2/cancer_decay.svg", format="svg")


In [None]:
# Fig 2B: cancer disgenet results
def get_color(actual_category, category_list):
    """Choose the bar colour for a term annotated with a comma-separated category list.

    Returns the colour of ``actual_category`` when it is one of the listed
    categories, otherwise the colour of the first listed category found in
    ``CATEGORY_NAMES``, otherwise light grey. Category colours are the matching
    ``COLOR_TEMPLATE`` entry with "50" appended. A non-string annotation
    (for example a missing value) yields light grey.
    """
    fallback = '#D3D3D3'
    if not isinstance(category_list, str):
        return fallback

    def tinted(label):
        return f"{COLOR_TEMPLATE[CATEGORY_NAMES.index(label)]}50"

    recognised = []
    for label in (piece.lower().strip() for piece in category_list.split(",")):
        if label == actual_category:
            return tinted(label)
        if label in CATEGORY_NAMES:
            recognised.append(tinted(label))
    if recognised:
        return recognised[0]
    return fallback
    
# Term -> category annotations used to colour the bars. The first row of the
# file is skipped and the two columns are named explicitly.
ref = pd.read_csv("disgenet_terms_annotated.csv", names=["Name", "Category"], skiprows=1)
category = "cancer"
size = 4
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(size*2, 4))
axes = [ax1, ax2]
# One panel per paper threshold k (probes reported in >= k papers).
for k in range(1, 3):
    panel = axes[k-1]
    df = pd.read_csv(f"go_probe/go_results/manifest/{category}_all_{k}papers_unique.csv")
    df = df[df["Database"] == "DisGeNET"]
    df = df.merge(ref, on="Name", how="left")
    # Drop parenthesised suffixes from the displayed term names.
    df["Name"] = df["Name"].apply(lambda x: re.sub(r'\([^)]*\)', '', x))
    if df.empty:
        continue
    df = df.sort_values(["Adjusted p-value", "P-value"], ascending=[True, True])
    combined_top = df.head(10)
    log_p = -np.log10(combined_top['Adjusted p-value'])
    category_colors = combined_top['Category'].apply(lambda x: get_color(category, x))

    bars = panel.barh(range(len(combined_top)), log_p, height=0.8,
                      color=category_colors)
    if k == 1:
        threshold = -np.log10(0.05)
        panel.axvline(x=threshold, color='black', linestyle='--', alpha=1)
        for i, (_, row) in enumerate(combined_top.iterrows()):
            # log_p keeps the (re-ordered) row labels of the sorted frame, so
            # the i-th bar must be read by position; `log_p[i]` would be a
            # label lookup and return another row's value or raise KeyError.
            panel.text(log_p.iloc[i]+0.5, i, f"{row['Name']}", ha='left', va='center', fontsize=12, color='black')
    else:
        for i, (_, row) in enumerate(combined_top.iterrows()):
            panel.text(1, i, f"{row['Name']}", ha='left', va='center', fontsize=12, color='black')
    # Probes that are non-zero for this category and zero for every other one.
    df2 = pd.read_csv(f"heatmap/probe_based_heatmap_{k}papers.csv", index_col=0)
    unique_probes = len(set(df2[(df2[category] != 0) & (df2.drop(columns=[category]) == 0).all(axis=1)].index))

    panel.set_title(f"{unique_probes:,} Probes in ≥ {k} {category.capitalize()} Paper{'s' if k > 1 else ''}", fontsize=16)
    panel.set_xlabel('-log₁₀(Adjusted P-value)', fontsize=16)

    panel.set_ylabel('')
    panel.set_yticks([])
    panel.set_yticklabels([])
    panel.invert_yaxis()  # most significant term at the top
    panel.tick_params(axis='x', which='both', labelsize=16)
    panel.set_xlim(0, 11)
    # NOTE(review): the tick list runs past the x-limit set just above and may
    # widen the visible range - confirm the intended x-range.
    panel.set_xticks(range(0, 30, 5))
plt.tight_layout()
plt.savefig(f"{FIGURE_PATH}/Fig2/{category}_disgenet.svg", format="svg")


In [None]:
# Figure 2C: probe based heatmap
# Clusters the probe-by-category table of paper counts and saves it as a JPEG.

# Increase the recursion limit
# NOTE(review): presumably needed for the dendrogram code on a table with this
# many rows - confirm.
sys.setrecursionlimit(40000)
paper_threshold = 2
# NOTE(review): other cells read "heatmap/probe_based_heatmap_{k}papers.csv";
# confirm that this differently named file in the working directory is intended.
df = pd.read_csv(f"probe_based_heatmap_{paper_threshold}papers_0107.csv", index_col=0)

# Optional down-sampling for quick test renders (disabled):
# df = df.sample(n=2000, random_state=42)

# Create a custom colormap: white to blue
colors = [(1, 1, 1),
          (0, 0, 1)]

cmap = LinearSegmentedColormap.from_list("custom_blue", colors, N=100)
# vmax=10 caps the colour scale, hence the '≥10' tick label set below.
g = sns.clustermap(df, method='ward', metric='euclidean', cmap=cmap, figsize=(8, 6), annot=False, vmax=10)

# Remove existing colorbar
g.ax_heatmap.collections[0].colorbar.remove()

# Create new colorbar in upper right
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
ax_ins = inset_axes(g.ax_heatmap, width="8%", height="40%", loc='upper right', 
                    bbox_to_anchor=(0.15, 0.3, 1, 1), bbox_transform=g.ax_heatmap.transAxes)
cbar = plt.colorbar(g.ax_heatmap.collections[0], cax=ax_ins, orientation='vertical')
cbar.set_ticks([0, 5, 10])
cbar.set_ticklabels(['0', '5', '≥10'])
cbar.set_label("Number of\nPapers", rotation=90, fontsize=16, labelpad=10)

# Hide the row dendrogram and thicken the lines of the column dendrogram.
g.ax_row_dendrogram.set_visible(False)
for line in g.ax_col_dendrogram.collections:
    line.set_linewidth(2)

# No per-row tick labels; column labels are rotated to vertical.
g.ax_heatmap.set_yticks([])
g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), rotation=90, ha='center', va='top', fontsize=20)
g.ax_heatmap.tick_params(axis='x', which='major', pad=10)
# The figure title is rotated and placed on the left, acting as a row-axis label.
g.figure.suptitle(f"{len(df):,} Probes\nReported in ≥ {paper_threshold} Papers", y=0.8, fontsize=20, rotation=90, x=0.1)
# plt.tight_layout()
# plt.show()
plt.savefig(f"{FIGURE_PATH}/Fig2/heatmap.jpeg", format="jpeg")



In [None]:
# Ad-hoc check: show the annotation row(s) for one DisGeNET term.
ref = pd.read_csv(
    "go_probe/go_results/nonunique/disgenet_terms_annotated_2.csv",
    names=["Name", "Category"],
)
ref.loc[ref["Name"] == "Blood basophil count"]

In [None]:
# Count probes that are non-zero for cancer and zero for every other column.
paper_threshold = 2
df2 = pd.read_csv(f"heatmap/probe_based_heatmap_{paper_threshold}papers.csv", index_col=0)
reported_in_cancer = df2["cancer"] != 0
absent_elsewhere = (df2.drop(columns=["cancer"]) == 0).all(axis=1)
unique_probes = len(set(df2[reported_in_cancer & absent_elsewhere].index))
print(unique_probes)

In [None]:
# Figure 2D-F: disgenet results for neurological, immune, and cardiovascular
# A bar chart is drawn for every category, but only the three above are saved.
# Uses get_color() defined in the Fig 2B cell.

# NOTE(review): unlike the Fig 2B cell, this file is read without skiprows=1;
# if it has a header line, that line becomes a data row - confirm.
ref = pd.read_csv("go_probe/go_results/manifest/disgenet_terms_annotated_2.csv", names=["Name", "Category"])
# specific_categories = ["cancer", "neurological", "immune"]
paper_threshold = 2
# The heatmap table does not depend on the category, so read it once up front
# rather than once per loop iteration.
df2 = pd.read_csv(f"heatmap/probe_based_heatmap_{paper_threshold}papers.csv", index_col=0)
for category in CATEGORY_NAMES:
    df = pd.read_csv(f"go_probe/go_results/manifest/{category}_all_{paper_threshold}papers_unique.csv")
    # .copy() so the column assignment below works on an independent frame, not
    # on a slice of the frame that was just read.
    df = df[df["Database"] == "DisGeNET"].copy()
    # Drop parenthesised suffixes before matching term names against `ref`.
    df["Name"] = df["Name"].apply(lambda x: re.sub(r'\([^)]*\)', '', x).strip())
    df = df.merge(ref, on="Name", how="left")
    if df.empty:
        continue
    df = df.sort_values(["Adjusted p-value", "P-value"], ascending=[True, True])
    combined_top = df.head(10)
    log_p = -np.log10(combined_top['Adjusted p-value'])

    # The metabolic panel is drawn shorter than the others.
    fig, ax = plt.subplots(figsize=(6, 2.5 if category == "metabolic" else 4))
    category_colors = combined_top['Category'].apply(lambda x: get_color(category, x))

    bars = ax.barh(range(len(combined_top)), log_p, height=0.8,
                   color=category_colors)

    # Term names are written inside the bars, at 5% of the longest bar.
    x_pos = 0.05 * max(log_p)
    for i, (_, row) in enumerate(combined_top.iterrows()):
        ax.text(x_pos, i, f"{row['Name']}", ha='left', va='center', fontsize=12)

    # Probes that are non-zero for this category and zero for every other one.
    unique_probes = len(set(df2[(df2[category] != 0) & (df2.drop(columns=[category]) == 0).all(axis=1)].index))
    # Customize the plot
    ax.set_title(f"{unique_probes:,} Probes in ≥ {paper_threshold} {category.capitalize()} Paper{'s' if paper_threshold>1 else ''}", y=1.03, fontsize=20)
    ax.set_xlabel('-log₁₀(Adjusted P-value)', fontsize=20)
    ax.set_ylabel('')
    ax.set_yticks([])
    ax.set_yticklabels([])
    ax.invert_yaxis()  # Invert y-axis to show highest significance at the top
    ax.tick_params(axis='x', which='both', labelsize=20)
    plt.tight_layout()
    if category in ["immune", "neurological", "cardiovascular"]:
        plt.savefig(f"{FIGURE_PATH}/Fig2/{category}_{paper_threshold}papers_disgenet.svg", format="svg")


In [None]:
# Figure 2G: gene region histograms
from scipy.stats import chi2_contingency, chi2

# Probe-level manifest; the code below reads its corsiv_id, Probe_ID,
# UCSC_RefGene_Name and UCSC_RefGene_Group columns.
df = pd.read_csv("../humanData/corsiv_manifest.csv")

def parse_gene_group(row):
    """Split one manifest row's gene annotation into two sets.

    Returns a two-element ``pd.Series``: the set of ``(gene name, region)``
    pairs obtained by pairing the ';'-separated ``UCSC_RefGene_Name`` and
    ``UCSC_RefGene_Group`` fields position by position, and the set of distinct
    regions. A row whose group is exactly "Intergenic" yields
    ``{("_", "Intergenic")}`` and ``{"Intergenic"}``.
    """
    region_field = row["UCSC_RefGene_Group"]
    if region_field == "Intergenic":
        return pd.Series([{("_", "Intergenic")}, {"Intergenic"}])
    gene_names = row["UCSC_RefGene_Name"].split(";")
    regions = region_field.split(";")
    return pd.Series([set(zip(gene_names, regions)), set(regions)])

def cleanup(row):
    """Remove "Intergenic" from a region set that has more than one member.

    The set is modified in place and the same object is returned; a set with
    zero or one element is returned untouched.
    """
    if len(row) <= 1:
        return row
    row.discard("Intergenic")
    return row

# Annotate every probe, then merge the region sets of all probes that share a
# corsiv_id so that each CoRSIV contributes one set of regions.
df[["gene_groups", "groups"]] = df.apply(parse_gene_group, axis=1)
df = df[["corsiv_id", "Probe_ID", "gene_groups", "groups"]]
all_corsiv_df = (
    df.groupby('corsiv_id')['groups']
    .apply(lambda region_sets: set().union(*region_sets))
    .reset_index()
)
all_corsiv_df["groups"] = all_corsiv_df["groups"].apply(cleanup)
illumina_covered_corsiv = len(all_corsiv_df)
# Number of CoRSIVs carrying each region label; "ExonBnd" is not reported.
all_corsiv_tally = dict(Counter(region for region_set in all_corsiv_df['groups'] for region in region_set))
all_corsiv_tally.pop("ExonBnd", None)

def tally_count(category, paper_threshold=2):
    """Tally gene regions over the CoRSIVs hit by a category's probes.

    Probes returned by ``read_in_probes(category)`` with a count of at least
    ``paper_threshold`` are looked up in the module-level ``df``; their region
    sets are merged per ``corsiv_id`` and passed through ``cleanup``.

    Returns:
        (tally, n): ``tally`` maps region label -> number of CoRSIVs carrying
        it (without "ExonBnd"), ``n`` is the number of CoRSIVs found.
    """
    probe_counts = read_in_probes(category)
    kept_probes = {probe: n for probe, n in probe_counts.items() if n >= paper_threshold}
    cat_df = df[df["Probe_ID"].isin(kept_probes)]
    cat_df = (
        cat_df.groupby('corsiv_id')['groups']
        .apply(lambda region_sets: set().union(*region_sets))
        .reset_index()
    )
    cat_df["groups"] = cat_df["groups"].apply(cleanup)
    region_counter = Counter()
    for region_set in cat_df['groups']:
        region_counter.update(region_set)
    target_corsiv_tally = dict(region_counter)
    target_corsiv_tally.pop("ExonBnd", None)
    return target_corsiv_tally, len(cat_df)

def plot_gene_hist_all_categories(all_corsiv_tally, category_names, output_path=None, show_figure=True, format="pdf", colors=COLOR_TEMPLATE, total_corsivs=1607):
    """Plot, per gene region, the share of each category's CoRSIVs next to the overall share.

    For every name in ``category_names`` the region tally from ``tally_count``
    is converted to percentages and drawn as one bar per region; a dashed line
    per region marks the percentage over all CoRSIVs. Chi-square statistics
    comparing the observed TSS200 and Intergenic counts of the categories with
    the counts expected from the overall rate are printed.

    Args:
        all_corsiv_tally: region label -> number of CoRSIVs, over all CoRSIVs.
        category_names: categories to plot, in bar order (``colors[i]`` is used
            for the i-th one).
        output_path: when truthy, the figure is saved there.
        show_figure: when true, the figure is shown before it is closed.
        format: file format handed to ``savefig``.
        colors: one colour per category.
        total_corsivs: total number of CoRSIVs, used to scale the expected
            counts and in the label of the dashed lines.
    """
    keys = ['Intergenic', 'TSS1500', 'TSS200', "5'UTR", '1stExon', 'Body', "3'UTR"]
    all_total = sum(all_corsiv_tally.values())
    a_pct = [all_corsiv_tally.get(key, 0) / all_total * 100 for key in keys]
    tss_all = []
    intergenic_all = []
    tss_categories = []
    intergenic_categories = []
    category_data = []
    # Use the categories the caller asked for (the argument used to be ignored
    # in favour of the global CATEGORY_NAMES).
    for name in category_names:
        d, count = tally_count(name)
        s_total = sum(d.values())
        # A category without any tallied region plots as all-zero bars instead
        # of raising ZeroDivisionError.
        s_pct = [d.get(key, 0) / s_total * 100 if s_total else 0.0 for key in keys]
        category_data.append((name, s_pct, count))
        tss_categories.append(d.get("TSS200", 0))
        intergenic_categories.append(d.get("Intergenic", 0))
        # Expected count = overall count scaled to this category's CoRSIV count.
        tss_all.append(all_corsiv_tally.get("TSS200", 0) * count / total_corsivs)
        intergenic_all.append(all_corsiv_tally.get("Intergenic", 0) * count / total_corsivs)
    plt.figure(figsize=(15.5, 5))
    bar_width = 0.07
    index = np.arange(len(keys))
    tss_categories = np.array(tss_categories)
    intergenic_categories = np.array(intergenic_categories)
    tss_all = np.array(tss_all)
    intergenic_all = np.array(intergenic_all)
    for i, (name, s_pct, count) in enumerate(category_data):
        plt.bar(index+i*bar_width, s_pct, bar_width, color=colors[i], label=f'{name.capitalize()} CoRSIVs ({count})', align='edge')

    # Dashed reference line across each region's group of bars.
    for j, pct in enumerate(a_pct):
        plt.plot([index[j], index[j] + len(category_names) * bar_width],
                 [pct, pct], color='black', linestyle='dashed', linewidth=2, alpha=0.8, label=f"All CoRSIVs ({total_corsivs})" if j == 0 else "")

    dof = len(category_names) - 1
    chi2_stat = np.sum((tss_categories - tss_all)**2 / tss_all)
    # Survival function rather than 1 - cdf: the subtraction rounds to exactly 0
    # once the tail probability drops below float precision.
    p_value = chi2.sf(chi2_stat, df=dof)
    print(f"TSS200: {chi2_stat:.2f}, {p_value}")
    chi2_stat = np.sum((intergenic_categories - intergenic_all)**2 / intergenic_all)
    p_value = chi2.sf(chi2_stat, df=dof)
    print(f"Intergenic: {chi2_stat:.2f}, {p_value}")
    plt.xlabel('Gene Region', fontsize=26)
    plt.ylabel('Percentage (%)', fontsize=26)
    plt.xticks(index + bar_width * (len(category_names) / 2), keys, ha='center', fontsize=26)
    plt.yticks(fontsize=26)

    # plt.legend(frameon=False, loc='upper left', fontsize=10, bbox_to_anchor=(0, 1), ncol=2)
    plt.tight_layout()
    if output_path:
        plt.savefig(output_path, format=format, bbox_inches='tight')
    if show_figure:
        plt.show()
    plt.close()
    
# NOTE(review): plot_categories and plot_colors are defined here but the call
# below passes CATEGORY_NAMES and COLOR_TEMPLATE instead.
plot_categories = ["cancer", "cardiovascular", "endocrine", "immune", "metabolic", "neurological"]
plot_colors = ['#e6194B', '#f58231', '#469990', '#2f8e3b', '#0db7dd', '#4363d8']
output_path = f"{FIGURE_PATH}/Fig2/all_categories_histogram.svg"

# Draw this figure with a thicker outer box and thicker ticks than the
# notebook-wide defaults.
thick_frame = {'axes.linewidth': 3, 'xtick.major.width': 3, 'ytick.major.width': 3}
with plt.rc_context(thick_frame):
    plot_gene_hist_all_categories(
        all_corsiv_tally,
        CATEGORY_NAMES,
        output_path,
        show_figure=True,
        format="svg",
        colors=COLOR_TEMPLATE,
    )


In [None]:
# Figure 3A-F: decay plots for each category

# Probe table of every category, in CATEGORY_NAMES order.
cat_probes_dict = [read_in_probes(cat) for cat in CATEGORY_NAMES]

# Annotation-box anchor: right-hand side for categories 1, 8 and 9, left-hand
# side for all others.
la = (0.15, 0.9)
ra = (0.85, 0.9)
box_placement = [ra if i in (1, 8, 9) else la for i in range(11)]

# Only these six category indices are plotted as panels.
for i in [0, 3, 5, 6, 7, 10]:
    name = CATEGORY_NAMES[i]
    output_path = f"{FIGURE_PATH}/Fig3/{name}.svg"
    l1, l2, l3, p, p2, _ = calculate_points(cat_probes_dict[i], name)
    show_y_label = i in [0, 6]
    show_legend = i == 0
    # Only the first panel (i == 0) is given an output path.
    paper, r = plot_enrichment(
        [l1, l2, l3, p, p2],
        name,
        i,
        output=output_path if i==0 else None,
        show_figure=True,
        box_placement=box_placement[i],
        show_y_label=show_y_label,
        format="svg",
        show_legend=show_legend,
    )

In [None]:
# Enrichment summary table: run the Figure 3 decay-plot pipeline for every
# category (nothing is saved) and collect the returned ratio and counts.

cat_probes_dict = [read_in_probes(cat) for cat in CATEGORY_NAMES]
probes_ct, ratios, papers_ct = [], [], []

for i, name in enumerate(CATEGORY_NAMES):
    output_path = f"{FIGURE_PATH}/Fig3/{name}.svg"
    l1, l2, l3, p, p2, _ = calculate_points(cat_probes_dict[i], name)
    paper, r = plot_enrichment([l1, l2, l3, p, p2], name, i, output=None, show_figure=True)
    probes_ct.append(p)
    ratios.append(r)
    papers_ct.append(p2)

df = pd.DataFrame({
    "Categories": CATEGORY_NAMES,
    "Enrichment Ratio": ratios,
    "Probes": probes_ct,
    "Papers": papers_ct,
})
df
# df.to_csv(f"{FIGURE_PATH}/Fig3/category_enrichment.csv", index=False)

In [6]:
# Figure 3G: permutation-test histograms (one JPEG per category)
# Each figure shows the enrichment ratios of the shuffled sets, sorted from
# largest to smallest, with the actual enrichment ratio as a dashed line.
CLOSE_TO_ZERO = 0  # marker for p-values recorded as exactly 0
adjusted_threshold = 0.05 / 11  # 0.05 divided by the 11 categories
# Hard-coded per-category values, in CATEGORY_NAMES order.
# NOTE(review): presumably copied from the permutation-test output - confirm
# they are current before regenerating the figures.
pvals = [3.08e-38, 0.04, 0.3, CLOSE_TO_ZERO, 0.44, 5.73e-94, 3.23e-45, CLOSE_TO_ZERO, 1.4e-09, 0.05, 1.42e-64]
enrcihment_ratios = [10.90, 4.69, 0.67, 53.48, 1.61, 10.08, 23.70, 25.66, 5.72, 2.24, 6.93]
# Per-category y-limits and range() arguments for the y-ticks.
ylims = [(0, 15), (0, 23), (0, 25), (0, 68), (0, 19), (0, 13), (0, 31), (0, 33), (0, 14), (0,10), (0, 10)]
yticks_list = [(0, 16, 5), (0, 24, 10), (0, 24, 10), (0, 61, 20), (0, 16, 5), (0, 14, 5), (0, 31, 10), (0, 31, 10), (0, 11, 5), (0, 10, 4), (0, 10, 4)]
PT_DIR = "../permutation_testing"
for i in range(11):
    df = pd.read_csv(f"{PT_DIR}/concatenated_results/{CATEGORY_NAMES[i]}_enrichment_after_permutations.bed", sep="\t")
    df = df.sort_values(by="enrichment_ratio", ascending=False).reset_index(drop=True)
    with plt.rc_context({'font.size': 30, 'axes.linewidth': 4, 'xtick.major.width': 4, 'ytick.major.width': 4, 'xtick.major.size': 10, 'ytick.major.size': 10}):
        p_0_ratio = enrcihment_ratios[i]
        plt.figure(figsize=(35, 6))
        # One bar per row of the file: a fixed range(100000) fails with a shape
        # mismatch whenever the file does not hold exactly 100,000 rows.
        plt.bar(range(len(df)), df["enrichment_ratio"], width=7, color=COLOR_TEMPLATE[i])
        plt.xlim(-1000, 100000)
        plt.xticks(range(0, 100001, 10000))
        plt.ylim(ylims[i][0], ylims[i][1])
        plt.yticks(range(*yticks_list[i]))
        plt.axhline(y=p_0_ratio, color='black', linestyle='--', linewidth=4)
        plt.text(3000, p_0_ratio+0.5, 'Actual enrichment ratio', verticalalignment='bottom', horizontalalignment='left', fontsize=35)
        pval = pvals[i]
        if pval == CLOSE_TO_ZERO:
            annotated_text = "P < 2.2e-308"
        elif pval >= adjusted_threshold:
            annotated_text = "Not Significant"
        else:
            annotated_text = f'P = {pval:.1e}'
        plt.annotate(annotated_text,
                     xy=(0.97, 0.9),
                     xycoords='axes fraction',
                     horizontalalignment='right',
                     verticalalignment='top',
                     bbox=dict(boxstyle="round,pad=0.3", fc="none", ec="none", lw=0))

        plt.xlabel('Shuffled Set', fontsize=35)
        plt.ylabel('Enrichment Ratio', fontsize=35)
        plt.title(f'Enrichment ratios from 100k Permutations - {CATEGORY_NAMES[i].capitalize()}', fontsize=40)
        plt.tight_layout()
        # plt.show()
        plt.savefig(f"{FIGURE_PATH}/Fig3/{CATEGORY_NAMES[i]}_100k_permutations.jpeg", format="jpeg", bbox_inches="tight", dpi=300)
        plt.close()



In [None]:
# Figure 4B-D: neurological subcategory decay plots

# Re-parse the MeSH tree file (same parsing as in the Fig 1A cell); every
# well-formed line is "<term>;<tree code>".
mesh_ttoc = defaultdict(set)  # term -> set of tree codes
file_path = '../humanData/database/mtrees2024.txt'
with open(file_path, 'r') as file:
    for line in file:
        parts = line.strip().split(';')
        if len(parts) != 2:
            # Echo any line that does not split into exactly two fields.
            print(parts)
            continue
        term, code = parts
        mesh_ttoc[term].add(code)
# Reverse lookup: tree code -> term.
mesh_ctot = {tree_code: name for name, tree_codes in mesh_ttoc.items() for tree_code in tree_codes}

def starts_with_any(given_string, string_list):
    """Return True when ``given_string`` begins with at least one entry of ``string_list``."""
    return any(given_string.startswith(prefix) for prefix in string_list)
# For each root keyword, collect every MeSH term that has at least one tree
# code starting with one of the root's tree codes (the root and its descendants).
neuro_mesh_tree = {}
keywords = ["Mental Disorders", "Nervous System Diseases"]
for kw in keywords:
    neuro_mesh_tree[kw] = set([k for k, v in mesh_ttoc.items() for c in v if starts_with_any(c, mesh_ttoc[kw])])
def filter_mesh_list(input):
    """Return True if ``input`` is in any of the term sets of ``neuro_mesh_tree``."""
    return any(input in sublist for sublist in neuro_mesh_tree.values())
# Category index handed to breakdown().
# NOTE(review): presumably the position of "neurological" in CATEGORY_NAMES - confirm.
target_idx = 7
neuro_df = pd.read_csv(f"probe/neurological_all_probes.csv")  # one row per (paper, probe) - TODO confirm
# "Filtered Mesh Term" is stored as a '|'-separated string; turn it into a list.
neuro_df["Filtered Mesh Term"] = neuro_df["Filtered Mesh Term"].apply(lambda x: [term.strip() for term in x.split("|")])

# One row per paper, so that each paper is counted once per MeSH term.
temp1 = neuro_df.drop_duplicates(subset="pmcid")
mesh_count_by_study = defaultdict(int)
mesh_terms = list(temp1["Filtered Mesh Term"])
pmcids = list(temp1["pmcid"])
term_pmcid_map = defaultdict(set)
for i in range(len(mesh_terms)):
    m = mesh_terms[i]
    for t in m:
        mesh_count_by_study[t] += 1
        term_pmcid_map[t].add(pmcids[i])
# Keep terms inside the selected subtrees that occur in at least 14 papers.
mesh_count_by_study = {key:count for key, count in mesh_count_by_study.items() if filter_mesh_list(key) and count >= 14}
# Extra entry standing for the whole category (all papers in the probe table).
mesh_count_by_study["Neurological"] = len(set(neuro_df["pmcid"]))
# Sorting the (term, count) pairs in reverse orders them by term name, descending.
mesh_count_by_study = dict(sorted(mesh_count_by_study.items(), reverse=True))
# NOTE: this rebinds `categories`, which the Fig 1A cell also defines.
categories = list(mesh_count_by_study.keys())
counts = list(mesh_count_by_study.values())

enriched_categories = []
not_enriched_categories = []

# Per-term values returned by breakdown(), collected in the order of `terms`.
d1 = []
d2 = []
probes = []
papers_ct = []
terms, counts = zip(*[(k, v) for k, v in mesh_count_by_study.items()])
for term in terms:
    # The whole-category entry is evaluated over both root keywords.
    term = keywords if term == "Neurological" else [term]
    show_y_label = term == ["Neurodevelopmental Disorders"]
    show_legend = term == ["Neurodevelopmental Disorders"]
    # NOTE(review): `term` is a list here, so the file name contains its repr;
    # output_path is also not passed to breakdown() below.
    output_path = f"{FIGURE_PATH}/Fig4/{term}.svg"
    p, paper, r, p2 = breakdown(neuro_df, term_pmcid_map, term, target_idx, show_figure=False, show_y_label=show_y_label, show_legend=show_legend)
    probes.append(p)
    d1.append(paper)
    d2.append(r)
    papers_ct.append(p2)
    # break

print(len(terms), len(d1), len(d2), len(probes), len(counts))
# Summary table indexed by term, highest enrichment first. The Figure 4A
# visualize() defined later in this notebook reads the global `df` and
# `target_idx`, so any cell that rebinds them afterwards changes its colours.
df = pd.DataFrame({"Categories":terms, "Enrichment Ratio":d2, "CoRSIV Probes": probes, "CoRSIV Papers":papers_ct, "Highest Number of Papers": d1, "Total Number of Papers": counts})
df.sort_values("Enrichment Ratio", ascending=False, inplace=True)
df.index = df["Categories"]
df.drop(columns=["Categories"], inplace=True)



In [None]:
# Metabolic-disease subcategory breakdown (same pipeline as the Figure 4B-D
# neurological cell, applied to the "Metabolic Diseases" MeSH subtree)

# Re-parse the MeSH tree file; every well-formed line is "<term>;<tree code>".
mesh_ttoc = defaultdict(set)  # term -> set of tree codes
file_path = '../humanData/database/mtrees2024.txt'
with open(file_path, 'r') as file:
    for line in file:
        parts = line.strip().split(';')
        if len(parts) != 2:
            # Echo any line that does not split into exactly two fields.
            print(parts)
            continue
        term, code = parts
        mesh_ttoc[term].add(code)
# Reverse lookup: tree code -> term.
mesh_ctot = {tree_code: name for name, tree_codes in mesh_ttoc.items() for tree_code in tree_codes}

def starts_with_any(given_string, string_list):
    """Tell whether ``given_string`` starts with any of the prefixes in ``string_list``."""
    return any(given_string.startswith(candidate) for candidate in string_list)
# For each root keyword, collect every MeSH term that has at least one tree
# code starting with one of the root's tree codes (the root and its descendants).
# NOTE(review): this cell rebinds `df`, `target_idx` and `mesh_ttoc`, which the
# Figure 4A visualize() defined later in this notebook reads as globals.
neuro_mesh_tree = {}
keywords = ["Metabolic Diseases"]
for kw in keywords:
    neuro_mesh_tree[kw] = set([k for k, v in mesh_ttoc.items() for c in v if starts_with_any(c, mesh_ttoc[kw])])
def filter_mesh_list(input):
    """Return True if ``input`` is in any of the term sets of ``neuro_mesh_tree``."""
    return any(input in sublist for sublist in neuro_mesh_tree.values())
# Category index handed to breakdown().
# NOTE(review): presumably the position of "metabolic" in CATEGORY_NAMES - confirm.
target_idx = 6
neuro_df = pd.read_csv("probe/metabolic_diseases_all_probes.csv")
# SECURITY NOTE(review): eval() executes whatever text is stored in this column.
# The neurological and cancer cells parse the same column as a '|'-separated
# string instead; if this file really stores Python list literals, prefer
# ast.literal_eval. Left unchanged because the file format is not visible here.
neuro_df["Filtered Mesh Term"] = neuro_df["Filtered Mesh Term"].apply(eval)

# One row per paper, so that each paper is counted once per MeSH term.
temp1 = neuro_df.drop_duplicates(subset="pmcid")
mesh_count_by_study = defaultdict(int)
term_pmcid_map = defaultdict(set)
mesh_terms = list(temp1["Filtered Mesh Term"])
pmcids = list(temp1["pmcid"])
for pmcid, paper_terms in zip(pmcids, mesh_terms):
    for t in paper_terms:
        mesh_count_by_study[t] += 1
        term_pmcid_map[t].add(pmcid)
# Keep only the terms inside the selected subtree.
mesh_count_by_study = {key: count for key, count in mesh_count_by_study.items() if filter_mesh_list(key)}
# Sorting the (term, count) pairs in reverse orders them by term name, descending.
mesh_count_by_study = dict(sorted(mesh_count_by_study.items(), reverse=True))
categories = list(mesh_count_by_study.keys())
counts = list(mesh_count_by_study.values())

enriched_categories = []
not_enriched_categories = []

d1 = []
d2 = []
probes = []
# zip(*[]) cannot be unpacked into two names, so handle "no term survived the
# filter" explicitly.
terms, counts = zip(*mesh_count_by_study.items()) if mesh_count_by_study else ((), ())
for term in terms:
    # The neurological cell unpacks four values from breakdown(); the starred
    # target lets this cell run whether it returns three values or more.
    p, paper, r, *_ = breakdown(neuro_df, term_pmcid_map, [term], target_idx, show_figure=True, show_y_label=True, show_legend=False, output=None)
    probes.append(p)
    d1.append(paper)
    d2.append(r)

print(len(terms), len(d1), len(d2), len(probes), len(counts))
# Summary table indexed by term, highest enrichment first.
df = pd.DataFrame({"Categories":terms, "Enrichment Ratio":d2, "Number of Probes": probes, "Highest Number of Papers": d1, "Total Number of Papers": counts})
df.sort_values("Enrichment Ratio", ascending=False, inplace=True)
df.index = df["Categories"]
df.drop(columns=["Categories"], inplace=True)
print(df.to_string())
# Terms that are not enriched (ratio <= 1).
print(df[df["Enrichment Ratio"]<= 1].to_string())



In [None]:
# Cancer subcategory breakdown (same pipeline as the Figure 4B-D neurological
# cell, applied to the "Neoplasms" MeSH subtree)

# Re-parse the MeSH tree file; every well-formed line is "<term>;<tree code>".
mesh_ttoc = defaultdict(set)  # term -> set of tree codes
file_path = '../humanData/database/mtrees2024.txt'
with open(file_path, 'r') as file:
    for line in file:
        parts = line.strip().split(';')
        if len(parts) != 2:
            # Echo any line that does not split into exactly two fields.
            print(parts)
            continue
        term, code = parts
        mesh_ttoc[term].add(code)
# Reverse lookup: tree code -> term.
mesh_ctot = {tree_code: name for name, tree_codes in mesh_ttoc.items() for tree_code in tree_codes}

def starts_with_any(given_string, string_list):
    """Check ``given_string`` against each prefix in ``string_list``; True on the first match."""
    return any(given_string.startswith(head) for head in string_list)
# For each root keyword, collect every MeSH term that has at least one tree
# code starting with one of the root's tree codes (the root and its descendants).
# NOTE(review): this cell rebinds `df`, `target_idx` and `mesh_ttoc`, which the
# Figure 4A visualize() defined later in this notebook reads as globals.
neuro_mesh_tree = {}
keywords = ["Neoplasms"]
for kw in keywords:
    neuro_mesh_tree[kw] = set([k for k, v in mesh_ttoc.items() for c in v if starts_with_any(c, mesh_ttoc[kw])])
def filter_mesh_list(input):
    """Return True if ``input`` is in any of the term sets of ``neuro_mesh_tree``."""
    return any(input in sublist for sublist in neuro_mesh_tree.values())
# Category index handed to breakdown().
# NOTE(review): presumably the position of "cancer" in CATEGORY_NAMES - confirm.
target_idx = 0
neuro_df = pd.read_csv("probe/cancer_all_probes.csv")
# Leave out one paper; .copy() so the column assignment below works on an
# independent frame, not on a slice of the frame that was just read.
neuro_df = neuro_df[neuro_df['pmcid']!="PMC10275808"].copy()
# "Filtered Mesh Term" is stored as a '|'-separated string; turn it into a list.
neuro_df["Filtered Mesh Term"] = neuro_df["Filtered Mesh Term"].apply(lambda x: [term.strip() for term in x.split("|")])

# One row per paper, so that each paper is counted once per MeSH term.
temp1 = neuro_df.drop_duplicates(subset="pmcid")
mesh_count_by_study = defaultdict(int)
term_pmcid_map = defaultdict(set)
mesh_terms = list(temp1["Filtered Mesh Term"])
pmcids = list(temp1["pmcid"])
for pmcid, paper_terms in zip(pmcids, mesh_terms):
    for t in paper_terms:
        mesh_count_by_study[t] += 1
        term_pmcid_map[t].add(pmcid)
# Keep only the terms inside the selected subtree.
mesh_count_by_study = {key: count for key, count in mesh_count_by_study.items() if filter_mesh_list(key)}
# mesh_count_by_study["Cancer"] = len(set(neuro_df["pmcid"]))
# Sorting the (term, count) pairs in reverse orders them by term name, descending.
mesh_count_by_study = dict(sorted(mesh_count_by_study.items(), reverse=True))
categories = list(mesh_count_by_study.keys())
counts = list(mesh_count_by_study.values())

enriched_categories = []
not_enriched_categories = []

# Only this term is evaluated in this cell.
selected_terms = {"Prostatic Neoplasms"}

d1 = []
d2 = []
probes = []
# Terms and paper counts of the rows that were actually evaluated. The summary
# table must be built from these: pairing the full `terms`/`counts` tuples with
# the shorter result lists makes pd.DataFrame raise a length-mismatch ValueError.
evaluated_terms = []
evaluated_counts = []
terms, counts = zip(*mesh_count_by_study.items()) if mesh_count_by_study else ((), ())
for term, count in zip(terms, counts):
    if term not in selected_terms:
        continue
    # The neurological cell unpacks four values from breakdown(); the starred
    # target lets this cell run whether it returns three values or more.
    p, paper, r, *_ = breakdown(neuro_df, term_pmcid_map, [term], target_idx, show_figure=False, show_y_label=False, show_legend=False, output=None)
    evaluated_terms.append(term)
    evaluated_counts.append(count)
    probes.append(p)
    d1.append(paper)
    d2.append(r)

print(len(terms), len(d1), len(d2), len(probes), len(counts))
# Summary table indexed by term, highest enrichment first.
df = pd.DataFrame({"Categories":evaluated_terms, "Enrichment Ratio":d2, "Number of Probes": probes, "Highest Number of Papers": d1, "Total Number of Papers": evaluated_counts})
df.sort_values("Enrichment Ratio", ascending=False, inplace=True)
df.index = df["Categories"]
df.drop(columns=["Categories"], inplace=True)
print(df.to_string())
# Terms that are not enriched (ratio <= 1).
print(df[df["Enrichment Ratio"]<= 1].to_string())



In [None]:
# Figure 4A: neurological mesh hierarchy
from graphviz import Digraph
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

def visualize(input_nodes, output_name=None, format="svg"):
    """Draw the MeSH hierarchy of the given terms, shaded by enrichment ratio.

    Each tree code of each input term is expanded into its chain of ancestor
    codes. Only codes that exist in ``mesh_ctot`` and lie under the ``F03`` or
    ``C10`` branch are kept. A synthetic root node ``"N"`` ("Neurological") is
    attached above whichever of those two branches is present.

    Node fill colour is read from the module-level summary table ``df``
    (indexed by term): white when the term is absent, has fewer than 2 in
    'Highest Number of Papers' or an 'Enrichment Ratio' below 1; otherwise
    ``COLOR_TEMPLATE[target_idx]`` blended towards white in proportion to the
    ratio. Terms present in ``df`` also get their 'Total Number of Papers'
    appended to the label.

    This definition replaces the earlier ``visualize`` defined for Fig 1A.

    Parameters
    ----------
    input_nodes : iterable of str
        MeSH terms to display.
    output_name : str, optional
        Path stem for the rendered file. When falsy the graph is opened in a
        viewer instead of being saved.
    format : str
        Graphviz output format used when ``output_name`` is given.

    Returns
    -------
    dict
        Mapping of every drawn code (including ``"N"``) to its label term.
    """
    root_code = "N"
    branch_roots = ("F03", "C10")

    # .get() rather than indexing: mesh_ttoc is a defaultdict, so indexing an
    # unknown term would silently insert it. Sorting makes node/edge order (and
    # therefore the layout) reproducible across kernels.
    leaf_codes = sorted({code for term in input_nodes for code in mesh_ttoc.get(term, ())})

    dot = Digraph()
    dot.attr(rankdir='LR')  # Keep layout horizontal (Left-to-Right)

    # Collect every ancestor prefix of every input code that is a known code
    # under one of the two branches.
    nodes = {}
    for code in leaf_codes:
        parts = code.split('.')
        for depth in range(1, len(parts) + 1):
            partial_code = '.'.join(parts[:depth])
            if partial_code in nodes:
                continue
            if partial_code in mesh_ctot and partial_code.startswith(branch_roots):
                nodes[partial_code] = mesh_ctot[partial_code]

    # Parent -> child edges between consecutive prefixes that are both drawn.
    edges = set()
    for code in nodes:
        parts = code.split('.')
        for depth in range(1, len(parts)):
            parent_code = '.'.join(parts[:depth])
            child_code = '.'.join(parts[:depth + 1])
            if parent_code in nodes and child_code in nodes:
                edges.add((parent_code, child_code))

    # Synthetic root. Only link it to branches that actually have a node;
    # an edge to a missing code would make Graphviz create an unstyled node.
    nodes[root_code] = "Neurological"
    edges.update((root_code, branch) for branch in branch_roots if branch in nodes)

    base_rgb = mcolors.to_rgb(COLOR_TEMPLATE[target_idx])
    max_ratio = df['Enrichment Ratio'].max()

    def get_node_color(node):
        """Return the hex fill colour for the node with the given code."""
        term = nodes[node]
        if term not in df.index:
            return "#FFFFFF"  # White for nodes not in df

        highest_papers = df.loc[term, 'Highest Number of Papers']
        enrichment_ratio = df.loc[term, 'Enrichment Ratio']
        if highest_papers < 2 or enrichment_ratio < 1:
            return "#FFFFFF"  # White for nodes with low papers or enrichment

        intensity = min(1, enrichment_ratio / max_ratio)  # Normalize to [0, 1]
        # intensity 0 -> white, intensity 1 -> full base colour
        scaled_color = tuple(1 - (1 - c) * intensity for c in base_rgb)
        return mcolors.to_hex(scaled_color)

    # One box per code, labelled with its term (plus paper count when known).
    for code, term in nodes.items():
        if term in df.index:
            custom_label = f"{term} ({df.loc[term, 'Total Number of Papers']})"
        else:
            custom_label = term
        dot.node(code, label=custom_label, shape='box', style='filled',
                 fillcolor=get_node_color(code), fontname='Helvetica')

    # Add edges for hierarchical relationships
    for parent_code, child_code in sorted(edges):
        dot.edge(parent_code, child_code)

    dot.attr(concentrate='true', splines='ortho')
    dot.attr(nodesep='0.2', ranksep='0.5')

    # Colour legend drawn as a separate matplotlib figure (saving is disabled below).
    fig, ax = plt.subplots(figsize=(6, 1))
    cmap = mcolors.LinearSegmentedColormap.from_list("custom", ["white", COLOR_TEMPLATE[target_idx]])
    norm = mcolors.Normalize(vmin=0, vmax=max_ratio)
    plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap),
                 cax=ax, orientation='horizontal', label='Enrichment Ratio')

    # legend_path = f"{FIGURE_PATH}/Fig4/legend.svg"
    # plt.savefig(legend_path, bbox_inches='tight')

    if output_name:
        dot.render(output_name, format=format, cleanup=True)
    else:
        dot.view()
    return nodes

output_name = f"{FIGURE_PATH}/Fig4/neurological_hierarchy"
# Draw the hierarchy for every MeSH term in mesh_count_by_study. NOTE(review): the path above
# is not passed (output_name=None), so visualize() opens a viewer instead of saving the file.
nodes = visualize(mesh_count_by_study.keys(), output_name=None, format="svg")


In [None]:
# Load the Flanagan median table and show the rows that carry no category.
df = pd.read_csv("iir_icc/Flanagan_median_15_probes_minimum_df.csv")
# df[(df["Number of Probes"]<= 15) & (df["region_type"] == "CoRSIV")]

df.loc[df["category"].isna()]


In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import to_rgb

# Figure 5A: median scatter plots on IIR and ICC for Flanagan dataset
def plot_median_scatter(col1, col2, ax, category, cmap, min_papers, max_papers, show_legend=False):
    """Scatter ``col1`` against ``col2`` for one category, coloured by paper count.

    Rows are taken from the module-level ``df``. CoRSIV rows are drawn as
    diamonds and Non-CoRSIV rows as circles; the face colour is looked up in
    ``cmap`` from the row's "papers" value scaled by the given range. Rows of
    ``df`` with a missing category are drawn as stars (filled for CoRSIV,
    hollow for Non-CoRSIV). A colorbar spanning ``min_papers``..``max_papers``
    is appended to the right of ``ax``.

    Parameters
    ----------
    col1, col2 : str
        Column names plotted on the x and y axis.
    ax : matplotlib Axes
        Axes to draw into.
    category : str
        Value of the "category" column to plot.
    cmap : matplotlib colormap
        Maps the scaled paper count (0..1) to a colour.
    min_papers, max_papers : number
        Paper-count range used for colour scaling and colorbar ticks.
    show_legend : bool
        Accepted for backward compatibility; no legend is drawn here (the
        calling cell builds the shared legend).
    """
    region_markers = {"CoRSIV": "D", "Non-CoRSIV": "o"}
    # +1 keeps the largest paper count strictly below the top of the colormap.
    paper_span = max_papers - min_papers + 1

    for region_type, marker in region_markers.items():
        region_df = df[(df["region_type"] == region_type) & (df["category"] == category)]
        for x_val, y_val, papers in zip(region_df[col1], region_df[col2], region_df["papers"]):
            color = cmap((papers - min_papers) / paper_span)
            ax.scatter(x_val, y_val, c=[color], s=150, alpha=1, marker=marker, edgecolors='black', linewidth=1)

    # Add black star for all CoRSIVs
    all_corsiv = df[(df["region_type"] == "CoRSIV") & (df["category"].isna())]
    ax.scatter(all_corsiv[col1], all_corsiv[col2], c='black', s=300, marker='*')

    # Add hollow black star for all Non-CoRSIVs
    all_non_corsiv = df[(df["region_type"] == "Non-CoRSIV") & (df["category"].isna())]
    ax.scatter(all_non_corsiv[col1], all_non_corsiv[col2], facecolors='none', edgecolors='black', s=280, marker='*', linewidth=1.5)

    # Identical 0..1 axes with only the end ticks labelled.
    tick_positions = np.arange(0, 1.2, 0.2)
    end_labels = ['0.0', '', '', '', '', '1.0']
    ax.set_xlim(0, 1.1)
    ax.set_ylim(0, 1.1)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(end_labels)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(end_labels)
    # ax.yaxis.set_tick_params(pad=2)  # Adjust the padding to move labels closer to the axis
    zero_label = ax.yaxis.get_major_ticks()[0].label1
    custom_va = 1.0  # This value can be adjusted as needed
    zero_label.set_va('center')
    zero_label.set_position((zero_label.get_position()[0], custom_va))

    ax.set_aspect('equal')
    ax.tick_params(axis='both', which='major', labelsize=20, length=5)

    # Add colorbar. Skipped when the category has no rows (min/max are NaN then),
    # because integer ticks cannot be built from NaN.
    if np.isfinite(min_papers) and np.isfinite(max_papers):
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=min_papers, vmax=max_papers))
        sm.set_array([])
        cbar = plt.colorbar(sm, cax=cax)
        if category == "cancer":
            cbar.set_label('≥ Number of Papers', rotation=90, labelpad=5, fontsize=18)
        paper_ticks = np.arange(int(min_papers), int(max_papers) + 1, 1)
        cbar.set_ticks(paper_ticks)
        cbar.ax.tick_params(size=5)  # Increase the length of the ticks
        cbar.set_ticklabels(paper_ticks, fontsize=16)

# plot_categories = ["cancer", "cardiovascular", "digestive", "endocrine", "immune", "metabolic", "neurological", "obesity", "respiratory", "urogenital"]
# plot_colors = ['#e6194B', '#f58231', '#f3c300', '#469990', '#2f8e3b', '#0db7dd', '#4363d8', '#800000', '#f032e6', '#911eb4']

# category_names = plot_categories
# color_template = plot_colors
# Build the 3x4 grid of per-category IIR-vs-ICC scatter panels for one dataset.
data = "Flanagan"
fig, axes = plt.subplots(
    3, 4,
    figsize=(16, 10),
    gridspec_kw={'width_ratios': [1, 1, 1, 1], 'wspace': 0.1, 'hspace': 0.5},
)

df = pd.read_csv(f"iir_icc/{data}_median_15_probes_minimum_df.csv")

# (min, max) paper count per category, used to scale each panel's colormap.
category_paper_ranges = {}
for category in CATEGORY_NAMES:
    papers_in_category = df.loc[df["category"] == category, "papers"]
    category_paper_ranges[category] = (papers_in_category.min(), papers_in_category.max())

for i, category in enumerate(CATEGORY_NAMES):
    # Panels are filled row by row.
    ax = axes.flat[i]

    # Per-category colormap running from a darkened shade to the category colour.
    base_color = COLOR_TEMPLATE[i]
    light_color = to_rgb(base_color)
    dark_color = tuple(0.3 * channel for channel in light_color)
    cmap = LinearSegmentedColormap.from_list('custom', [dark_color, light_color], N=100)

    min_papers, max_papers = category_paper_ranges[category]
    plot_median_scatter("Median iir1", "Median ICC", ax, category, cmap, min_papers, max_papers, show_legend=(i == 0))
    ax.set_title(category.capitalize(), fontsize=24, pad=10)

# Shared legend: one handle per region marker plus the two "all probes" stars.
legend_elements = [
    plt.Line2D([0], [0], marker='D', color='gray', label="CoRSIV", markersize=14, linestyle='None'),
    plt.Line2D([0], [0], marker='o', color='gray', label="Non-CoRSIV", markersize=16, linestyle='None'),
    plt.Line2D([0], [0], marker='*', color='black', label='All CoRSIVs', markersize=20, linestyle='None'),
    plt.Line2D([0], [0], marker='*', markerfacecolor='none', markeredgecolor='black', label='All Non-CoRSIVs', markersize=20, linestyle='None', linewidth=2),
]

# The legend lives in the last grid cell, whose own axis is hidden.
legend = axes[2,3].legend(handles=legend_elements, fontsize=20, frameon=False, bbox_to_anchor=(1.0, 1.0))
# axes[2, 2].axis('off')
axes[2, 3].axis('off')

# Shared y axis: a vertical rule with its label to the left.
x = 0.11
fig.add_artist(plt.Line2D([x, x], [0.11, 0.88], transform=fig.transFigure, color='black', linestyle='-', linewidth=3))
fig.text(x-0.025, 0.5, 'Median Intraclass Correlation Coefficient (ICC)', va='center', rotation='vertical', fontsize=24)

# Shared x axis: a horizontal rule with its label underneath.
y = 0.05
fig.add_artist(plt.Line2D([0.15, 0.88], [y, y], transform=fig.transFigure, color='black', linestyle='-', linewidth=3))
fig.text(0.4, y-0.035, r'Median IIR$_{2-98\%}$ at Time 1', va='center', rotation='horizontal', fontsize=24)

plt.show()
# plt.savefig(f"{FIGURE_PATH}/Fig5/{data}_median_scatter.svg", format="svg")

In [None]:
# Compare the "Prostatic Neoplasms" probe tables exported under the urogenital and
# cancer categories (first column dropped from each).
df1, df2 = (
    pd.read_csv(path).iloc[:, 1:]
    for path in (
        "../permutation_testing/pt/urogenital/Prostatic Neoplasms_probes.csv",
        "../permutation_testing/pt/cancer/Prostatic Neoplasms_probes.csv",
    )
)
# Outer merge on all shared columns; "_merge" records which table each row came from.
m = pd.merge(df1, df2, how="outer", indicator=True)
# m[m["_merge"] != "both"]["pmcid"].unique()
m["_merge"].value_counts()

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import to_rgb

# Figure 5A (variant): IIR-vs-ICC scatter for the Flanagan dataset, marker size = number of papers
def plot_median_scatter(col1, col2, ax, category, base_color, min_papers, max_papers, show_legend=False):
    """Scatter ``col1`` against ``col2`` for one category, sized by paper count.

    Rows come from the module-level ``df``. CoRSIV rows are diamonds filled
    with ``base_color``; Non-CoRSIV rows are gray circles. Rows of ``df`` with
    a missing category are drawn as stars (filled for CoRSIV, hollow for
    Non-CoRSIV).

    Note: the ``min_papers`` / ``max_papers`` arguments are replaced by a fixed
    1..7 scale on entry, and ``show_legend`` is not used.
    """
    # Fixed size scale (overrides the arguments).
    min_papers, max_papers = 1, 7
    step_size = 30
    min_size = 10
    max_size = min_size + step_size * (max_papers - min_papers)

    # region -> (face colour, marker)
    region_styles = {"CoRSIV": (base_color, 'D'), "Non-CoRSIV": ('gray', 'o')}
    for region_type, (face_color, marker) in region_styles.items():
        in_region = (df["region_type"] == region_type) & (df["category"] == category)
        region_df = df[in_region]
        for x_val, y_val, papers in zip(region_df[col1], region_df[col2], region_df["papers"]):
            # Marker area grows linearly with the paper count.
            size = min_size + (max_size - min_size) * (papers - min_papers) / (max_papers - min_papers)
            ax.scatter(x_val, y_val, c=face_color, s=size, alpha=1, marker=marker, edgecolors='black')

    uncategorised = df["category"].isna()

    # Filled star: all CoRSIVs
    all_corsiv = df[(df["region_type"] == "CoRSIV") & uncategorised]
    ax.scatter(all_corsiv[col1], all_corsiv[col2], c='black', s=300, marker='*')

    # Hollow star: all Non-CoRSIVs
    all_non_corsiv = df[(df["region_type"] == "Non-CoRSIV") & uncategorised]
    ax.scatter(all_non_corsiv[col1], all_non_corsiv[col2], facecolors='none', edgecolors='black', s=280, marker='*', linewidth=1.5)

    # 0..1 axes with only the end ticks labelled.
    tick_positions = np.arange(0, 1.2, 0.2)
    end_labels = ['0.0', '', '', '', '', '1.0']
    ax.set_xlim(0, 1.0)
    ax.set_ylim(0, 1.0)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(end_labels)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(end_labels)

    # Adjust the alignment/position of the first y tick label, then re-apply the labels.
    zero_label = ax.yaxis.get_major_ticks()[0].label1
    custom_va = 1.0
    zero_label.set_va('center')
    label_x = zero_label.get_position()[0]
    zero_label.set_position((label_x, custom_va))
    ax.set_yticklabels(end_labels)

    ax.set_aspect('equal')
    ax.tick_params(axis='both', which='major', labelsize=20, length=5)
    
    

# Build the 3x4 grid of per-category IIR-vs-ICC panels (marker-size variant).
data = "Flanagan"
fig, axes = plt.subplots(
    3, 4,
    figsize=(16, 10),
    gridspec_kw={'width_ratios': [1, 1, 1, 1], 'wspace': 0.1, 'hspace': 0.5},
)

df = pd.read_csv(f"iir_icc/{data}_median_15_probes_minimum_df.csv")

# (min, max) paper count per category.
category_paper_ranges = {}
for category in CATEGORY_NAMES:
    papers_in_category = df.loc[df["category"] == category, "papers"]
    category_paper_ranges[category] = (papers_in_category.min(), papers_in_category.max())

for i, category in enumerate(CATEGORY_NAMES):
    # Panels are filled row by row.
    ax = axes.flat[i]
    base_color = COLOR_TEMPLATE[i]
    min_papers, max_papers = category_paper_ranges[category]

    plot_median_scatter("Median iir1", "Median ICC", ax, category, base_color, min_papers, max_papers, show_legend=(i == 0))
    ax.set_title(category.capitalize(), fontsize=24, pad=10)

# Legend handles for the two "all probes" stars.
legend_elements = [
    plt.Line2D([0], [0], marker='*', color='black', label='All CoRSIVs', markersize=20, linestyle='None'),
    plt.Line2D([0], [0], marker='*', markerfacecolor='none', markeredgecolor='black', label='All Non-CoRSIVs', markersize=20, linestyle='None', linewidth=2),
]

# The legend lives in the last grid cell, whose own axis is hidden.
legend = axes[2,3].legend(handles=legend_elements, fontsize=20, frameon=False, bbox_to_anchor=(1.0, 1.0))
# axes[2, 2].axis('off')
axes[2, 3].axis('off')

# Shared y axis: a vertical rule with its label to the left.
x = 0.11
fig.add_artist(plt.Line2D([x, x], [0.11, 0.88], transform=fig.transFigure, color='black', linestyle='-', linewidth=3))
fig.text(x-0.025, 0.5, 'Median Intraclass Correlation Coefficient (ICC)', va='center', rotation='vertical', fontsize=24)

# Shared x axis: a horizontal rule with its label underneath.
y = 0.05
fig.add_artist(plt.Line2D([0.15, 0.88], [y, y], transform=fig.transFigure, color='black', linestyle='-', linewidth=3))
fig.text(0.4, y-0.035, r'Median IIR$_{2-98\%}$ at Time 1', va='center', rotation='horizontal', fontsize=24)
plt.show()

# plt.savefig(f"{FIGURE_PATH}/Fig5/{data}_median_scatter_2.svg", format="svg")

In [None]:
import statsmodels.api as sm
# Number of papers vs. median IIR per category, with an OLS fit per region type (Flanagan dataset)
def plot_median_scatter(ax, category):
    """Plot "papers" against "Median iir1" for one category with OLS trend lines.

    Rows are taken from the module-level ``df``. CoRSIV points use the
    category's colour from ``COLOR_TEMPLATE``; Non-CoRSIV points are gray.
    For each region with at least two points an ordinary-least-squares line
    (papers ~ const + median IIR) is drawn; with more than two points the R²
    and F-test p-value are annotated as well.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    category : str
        Value of the "category" column to plot; must be in ``CATEGORY_NAMES``.
    """
    category_color = COLOR_TEMPLATE[CATEGORY_NAMES.index(category)]
    highest_papers = None  # largest paper count over both regions

    for region_type in ["CoRSIV", "Non-CoRSIV"]:
        region_df = df[(df["region_type"] == region_type) & (df["category"] == category)]
        color = category_color if region_type == "CoRSIV" else "gray"
        # Plot scatter points
        X = np.array(region_df["Median iir1"])
        y = np.array(region_df["papers"])
        ax.scatter(X, y, c=color, s=150, alpha=1, marker="o")

        if len(y) > 0:
            region_max = region_df["papers"].max()
            highest_papers = region_max if highest_papers is None else max(highest_papers, region_max)

        # A line cannot be fitted to fewer than two points.
        if len(X) < 2:
            continue

        X_const = sm.add_constant(X)
        model = sm.OLS(y, X_const).fit()
        ax.plot(X, model.predict(X_const), color=color, linestyle='--', linewidth=4, zorder=10, alpha=0.6)
        if len(X) > 2:
            r_squared = round(model.rsquared, 3)
            f_pvalue = model.f_pvalue
            annotation_text = f"R² = {r_squared:.2f}\nP = {f_pvalue:.2e}"
            # Fixed data-coordinate anchors: CoRSIV text higher up, Non-CoRSIV near the bottom.
            pos = 2 if region_type == 'CoRSIV' else 0.1
            ax.annotate(annotation_text,
                xy=(0.85, pos),
                xytext=(10, 0),
                textcoords='offset points',
                color=color,
                fontsize=12,
                ha='right',
                va='bottom',
                bbox=dict(boxstyle='round,pad=0.5', fc='none', ec='none', alpha=1))  # invisible box, only adds padding

    # Set the y range once, from the maximum over both regions, so points of
    # the first region are never cut off by a smaller range of the second.
    if highest_papers is not None and np.isfinite(highest_papers):
        ax.set_ylim(0, highest_papers + 1)
    ax.set_xlim(0, 1.0)
    ax.set_xticks(np.arange(0, 1.2, 0.2))
    ax.set_xticklabels(['0.0', '', '', '', '', '1.0'])
# Build the 3x4 grid of per-category papers-vs-IIR regression panels.
data = "Flanagan"
fig, axes = plt.subplots(
    3, 4,
    figsize=(14, 10),
    gridspec_kw={'width_ratios': [1, 1, 1, 1], 'wspace': 0.5, 'hspace': 0.5},
)

df = pd.read_csv(f"iir_icc/{data}_median_15_probes_minimum_df.csv")

# (min, max) paper count per category.
category_paper_ranges = {}
for category in CATEGORY_NAMES:
    papers_in_category = df.loc[df["category"] == category, "papers"]
    category_paper_ranges[category] = (papers_in_category.min(), papers_in_category.max())

for i, category in enumerate(CATEGORY_NAMES):
    # Panels are filled row by row.
    ax = axes.flat[i]
    plot_median_scatter(ax, category)
    ax.set_title(category.capitalize(), fontsize=24, pad=10)

# Hide the axis of the last grid cell.
axes[2, 3].axis('off')


# plt.savefig(f"{FIGURE_PATH}/Fig5/{data}_median_scatter.svg", format="svg")

In [None]:
# Figure 5B: becon density plot
import math
import scipy.stats as stats
import matplotlib.pyplot as plt

# Probe -> paper-count mapping for every category, in CATEGORY_NAMES order.
cat_probes_dict = [read_in_probes(cat) for cat in CATEGORY_NAMES]

non_corsiv_baseline = illumina - CORSIV_PROBE_LIST
bins = [-1, 0, 1.0]
target_col = "Mean Cor All Brain"
df = pd.read_csv("becon/becon_all_probes.csv")
regions = [("Non-CoRSIV", non_corsiv_baseline), ("CoRSIV", CORSIV_PROBE_LIST)]

# Single panel for the neurological category.
fig, ax = plt.subplots(figsize=(6,5))

catname = 'neurological'
i = CATEGORY_NAMES.index(catname)

max_papers = max(cat_probes_dict[i].values())
papers_threshold = 2
# Probes reported in at least `papers_threshold` papers of this category.
p = {probe for probe, n_papers in cat_probes_dict[i].items() if n_papers >= papers_threshold}
# One (BECon rows, region name) pair per region.
dfs_for_plot = [
    (df[df["CpG ID"].isin(rset.intersection(p))], rname)
    for rname, rset in regions
]

# Top of the y axis; also converts a density value into the axes fraction
# that axvline expects for `ymax`.
y_axis_top = 2.1
xs = np.linspace(-1, 1, 200)

for df_subset, rname in dfs_for_plot:
    values = df_subset[target_col]
    density = stats.gaussian_kde(values)
    ys = density(xs)
    color = COLOR_TEMPLATE[i] if rname == 'CoRSIV' else 'grey'

    ax.plot(xs, ys, "-", color=color, label=rname, linewidth=3)
    # Dashed median line that stops where it meets the density curve.
    median = values.median()
    ax.axvline(median, color=color, linestyle='--', linewidth=2, ymax=density(median)[0] / y_axis_top)
    ax.text(median + 0.05, 0.5, f'median = {median:.2f}', rotation=90, color="black", fontsize=16)
    ax.fill_between(xs, ys, alpha=0.5, color=color)
    # Region name just above the peak of its curve.
    peak_index = np.argmax(ys)
    ax.text(xs[peak_index], ys[peak_index]+0.04, rname, fontsize=20,
            verticalalignment='bottom', horizontalalignment='center',
            color=color)

# ax.set_aspect('equal')
ax.set_xticks([-1, 0.0, 1.0])
ax.tick_params(axis='both', which='major', labelsize=25, length=5)
ax.set_yticks([0, 1.0, 2.0])
ax.set_ylim(0, y_axis_top)
ax.set_xlabel("Brain-Blood Correlation (BECon)", fontsize=25)
ax.set_ylabel('Density', fontsize=25)
ax.set_title(f'Probes in ≥ {papers_threshold} Neurological Papers', fontsize=25, pad=15)

plt.tight_layout()
output = f"{FIGURE_PATH}/Fig5/becon_kde_neurological.svg"
# Save before showing.
plt.savefig(output, format="svg")
plt.show()


In [None]:
# Figure 5C: becon regression plot for neurological category
import math
import scipy.stats as stats
import matplotlib.pyplot as plt
import numpy as np
import statsmodels.api as sm

# Fewest probes a paper-count bin needs before its median is plotted.
MIN_PROBES_PER_BIN = 15

non_corsiv_baseline = illumina - CORSIV_PROBE_LIST
target_col = "Mean Cor All Brain"
df = pd.read_csv("becon/becon_all_probes.csv")
regions = [("CoRSIV", CORSIV_PROBE_LIST), ("Non-CoRSIV", non_corsiv_baseline)]

# Read probes for neurological category (probe -> number of papers)
neurological_probes = read_in_probes("neurological")

# Group the probes by exact paper count once instead of rescanning the whole
# mapping for every paper count and region.
probes_by_paper_count = defaultdict(set)
for probe, n_papers in neurological_probes.items():
    probes_by_paper_count[n_papers].add(probe)

# Create a single subplot
fig, ax = plt.subplots(figsize=(8, 4))

max_papers = max(neurological_probes.values())
neurological_color = COLOR_TEMPLATE[CATEGORY_NAMES.index("neurological")]
highest_plotted = 0  # largest paper count drawn for any region

for rname, rset in regions:
    medians = []
    paper_counts = []
    for pidx in range(1, max_papers + 1):
        probes_in_region = rset.intersection(probes_by_paper_count.get(pidx, set()))
        filtered_df = df[df["CpG ID"].isin(probes_in_region)]
        if len(filtered_df) < MIN_PROBES_PER_BIN:
            # Stop at the first sparse bin. The reduced bound also applies to
            # the regions processed afterwards, as before.
            max_papers = pidx - 1
            break
        medians.append(filtered_df[target_col].median())
        paper_counts.append(pidx)

    # A regression needs at least two bins; skip the region instead of crashing.
    if len(paper_counts) < 2:
        print(f"Skipping {rname}: only {len(paper_counts)} paper-count bin(s) with >= {MIN_PROBES_PER_BIN} probes")
        continue
    highest_plotted = max(highest_plotted, paper_counts[-1])

    # Regress paper count on the median brain-blood correlation of each bin.
    X = np.array(medians)
    y = np.array(paper_counts)
    X_const = sm.add_constant(X)
    model = sm.OLS(y, X_const).fit()

    color = neurological_color if rname == 'CoRSIV' else 'grey'
    marker = 'D' if rname == 'CoRSIV' else 'o'
    ax.scatter(X, y, color=color, label=rname, s=200, marker=marker)
    ax.plot(X, model.predict(X_const), color=color, linestyle='--', linewidth=4, zorder=10, alpha=0.6)
    r_squared = round(model.rsquared, 3)
    f_pvalue = round(model.f_pvalue, 3)
    print(f"Model p-value: {model.f_pvalue}")

    annotation_text = f"R² = {r_squared:.2f}\nP = {f_pvalue:.3f}"

    # Anchor the text slightly left of the last bin's median, at half its paper count.
    ax.annotate(annotation_text,
                xy=(X[-1]-0.04, y[-1]/2),
                xytext=(10, 0),
                textcoords='offset points',
                color=color,
                fontsize=17,
                ha='left',
                va='center',
                bbox=dict(boxstyle='round,pad=0.5', fc='none', ec='none', alpha=1))  # invisible box, only adds padding

ax.set_xlabel("Median Brain-Blood Correlation (BECon)", fontsize=22)
ax.set_ylabel('Number of Papers', fontsize=18)
ax.set_title("Neurological", fontsize=25, pad=15)
ax.set_xlim(-0.05, 0.6)
ax.set_xticks(np.arange(0.0, 0.7, 0.1))
ax.tick_params(axis='x', which='major', length=5)  # Increase the length of x-axis ticks
# ax.set_xticklabels(['0.0', '', '0.2', '', '0.4', '', '0.6'])
# The y range follows the largest paper count actually drawn, so a region that
# stops early cannot shrink the axis below another region's points.
ax.set_yticks(range(1, highest_plotted + 1))
ax.set_ylim(0, highest_plotted + 1)
ax.tick_params(axis='both', which='major', labelsize=20)
# ax.legend(fontsize=14, frameon=False, loc='upper left')

plt.tight_layout()
# plt.close()
output = f"{FIGURE_PATH}/Fig5/becon_regression_neurological.svg"
plt.savefig(output, format="svg")
