In [None]:
import math
import os
import pickle
import random
from pathlib import Path
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.lines import Line2D
from scipy.stats import pearsonr, ttest_rel


In [None]:
# Load the full PrimeKG edge table (x_* / y_* columns describe the two endpoints
# of each edge). low_memory=False stops pandas from inferring dtypes chunk by
# chunk, which can otherwise yield mixed-type columns.
primekg = pd.read_csv(Path("./kg.csv"), low_memory=False)


In [None]:
# convert csv into dictionary of features per disease
def create_dictionaries(
    csv_path: str,
    label_left: str = "Measured Protein",
    label_right: str = "RABIT Protein",
    sep: str = "; ",
    strip_whitespace: bool = True,
    as_sets: bool = True,
    label_map_reverse: Optional[Dict[str, str]] = None,  # reverse of label_map if needed
) -> Tuple[Dict[str, Union[set, List[str]]], Dict[str, Union[set, List[str]]]]:
    """Read a per-disease feature CSV into two ``{disease: features}`` dictionaries.

    The CSV's first column becomes the index (one row per disease). The cells
    in ``label_left`` / ``label_right`` hold ``sep``-joined feature names.

    Args:
        csv_path: path of the CSV to read.
        label_left: column that fills the first returned dictionary.
        label_right: column that fills the second returned dictionary.
        sep: separator between feature names inside a cell.
        strip_whitespace: strip each name and drop names that are empty after
            stripping.
        as_sets: return ``set`` values (True) or ``list`` values in cell order.
        label_map_reverse: optional row-label -> output-key renaming.

    Returns:
        ``(dict_left, dict_right)`` keyed by the (optionally renamed) row label.
        Missing or empty cells become empty containers.
    """
    df = pd.read_csv(csv_path, index_col=0)

    # Defined once rather than re-created for every row.
    def process_cell(cell):
        if pd.isna(cell) or cell == "":
            return set() if as_sets else []
        items = str(cell).split(sep)
        if strip_whitespace:
            items = [x.strip() for x in items if x.strip()]
        return set(items) if as_sets else items

    dict_left = {}
    dict_right = {}
    for disease_name, row in df.iterrows():
        orig_name = label_map_reverse.get(disease_name, disease_name) if label_map_reverse else disease_name
        dict_left[orig_name] = process_cell(row[label_left])
        dict_right[orig_name] = process_cell(row[label_right])

    return dict_left, dict_right


In [None]:
# Per-disease feature sets: measured proteins (left column) vs RABIT-selected
# proteins (right column), both read from the same CSV.
measured_prot_unique, rabit_prot_unique = create_dictionaries(
    csv_path="./syn_real_feat_unique.csv",
    sep="; ",
    as_sets=True,
    label_left="Measured Protein",
    label_right="RABIT Protein",
)


In [None]:
# split compound proteins names into singular proteins
def transform_dictionary(input_dict):
    """Normalise compound protein names for every disease in ``input_dict``.

    Each name is rewritten with the same rules as ``transform_protein_list``:
      * names starting with ``"hla"``: every ``_`` becomes ``-``;
      * exactly one ``_`` with a one-character suffix: joined with ``-``
        (``"abc_d" -> "abc-d"``);
      * any other name containing ``_``: split into its parts;
      * all other names are kept unchanged.

    Args:
        input_dict: mapping of disease -> iterable of protein names.

    Returns:
        A new dict of disease -> list of normalised names. When the input
        values are sets, the output order follows set iteration order.
    """
    new_dict = {}
    # (The loop header was previously over-indented, which made the whole cell
    # fail with IndentationError.)
    for key, elements in input_dict.items():
        transformed_elements = []
        for element in elements:
            # if "hla", convert all underscores to hyphens.
            if element.startswith("hla"):
                transformed_elements.append(element.replace("_", "-"))
            # if underscore, split by underscore
            elif "_" in element:
                parts = element.split("_")
                # if underscore and second part is a single character, join with a hyphen.
                if len(parts) == 2 and len(parts[1]) == 1:
                    transformed_elements.append(parts[0] + "-" + parts[1])
                else:
                    transformed_elements.extend(parts)
            else:
                transformed_elements.append(element)
        new_dict[key] = transformed_elements

    return new_dict


# Apply the same name normalisation to both feature dictionaries.
measured_prot_unique_cleaned, rabit_prot_unique_cleaned = (
    transform_dictionary(measured_prot_unique),
    transform_dictionary(rabit_prot_unique),
)


In [None]:
# Manual translation of protein names that have no matching gene/protein node
# in PrimeKG: feature-list name -> name used in PrimeKG.
translation_key = dict(
    ntprobnp="nppb",
    palm2="palm",
    cert="cert1",
    sarg="c1orf116",
    leg1="c6orf58",
    gpr15l="c10orf99",
    wars="wars1",
    bap18="c17orf49",
    gatd3="gatd3a",
    ment="c1orf56",
)


def translate_dictionary2(input_dict, transdict):
    """Return a copy of ``input_dict`` with every element mapped through ``transdict``.

    Elements that are keys of ``transdict`` are replaced by their mapped value;
    all other elements are kept as-is. Each value in the result is a list in
    the input's iteration order.
    """
    return {
        key: [transdict[item] if item in transdict else item for item in elements]
        for key, elements in input_dict.items()
    }


# Replace names that PrimeKG knows under a different symbol.
measured_prot_unique_cleaned_translated, rabit_prot_unique_cleaned_translated = (
    translate_dictionary2(feature_dict, translation_key)
    for feature_dict in (measured_prot_unique_cleaned, rabit_prot_unique_cleaned)
)


In [None]:
# primekg shortest distances between protein and disease filtered to only use protein-protein intermediary nodes
def find_shortest_paths_for_proteins_filteredprimekg(
    primekg: pd.DataFrame,
    disease_name: str,
    protein_list: List[str],
    translated_diseases: Dict[str, Union[str, List[str]]]
) -> Dict[str, Any]:

    dp_edges = primekg[primekg['relation'] == 'disease_protein']
    pp_edges = primekg[primekg['relation'] == 'protein_protein']

    G_pp = nx.from_pandas_edgelist(
        pp_edges,
        source="x_index",
        target="y_index",
        create_using=nx.Graph()
    )

    raw = translated_diseases.get(disease_name, disease_name)
    disease_variants = [raw] if isinstance(raw, str) else list(raw)

    disease_ids = set()
    for variant in disease_variants:
        key = variant.lower()
        disease_ids.update(
            dp_edges.loc[
                (dp_edges["x_name"].str.lower() == key) &
                (dp_edges["x_type"] == "disease"),
                "x_index"
            ]
        )
        disease_ids.update(
            dp_edges.loc[
                (dp_edges["y_name"].str.lower() == key) &
                (dp_edges["y_type"] == "disease"),
                "y_index"
            ]
        )
    disease_ids = sorted(disease_ids)
    if not disease_ids:
        print(f"No disease nodes found for '{disease_name}' (variants: {disease_variants}).")
        return {}

    id_to_name: Dict[int,str] = {}
    for _, row in primekg.iterrows():
        id_to_name.setdefault(row["x_index"], row["x_name"])
        id_to_name.setdefault(row["y_index"], row["y_name"])

    result: Dict[str,Any] = {}
    for protein_name in protein_list:
        prot_ids = set()
        key = protein_name.casefold()
        
        prot_ids.update(
            primekg.loc[
                (primekg["x_name"].str.casefold() == key) &
                (primekg["x_type"] == "gene/protein"),
                "x_index"
            ]
        )
        
        prot_ids.update(
            primekg.loc[
                (primekg["y_name"].str.casefold() == key) &
                (primekg["y_type"] == "gene/protein"),
                "y_index"
            ]
        )
        protein_ids = sorted(prot_ids)
        if not protein_ids:
            print(f"No gene/protein nodes found for '{protein_name}'.")
            continue

        all_paths = []
        for d_id in disease_ids:
            init_prots = set(dp_edges.loc[dp_edges["x_index"] == d_id, "y_index"]) \
                       .union(dp_edges.loc[dp_edges["y_index"] == d_id, "x_index"])

            for p_id in protein_ids:
                best_path = None
                best_len = float('inf')

                if p_id in init_prots:
                    best_path = [d_id, p_id]
                    best_len = 2
                else:
                    if p_id not in G_pp:
                        continue
                    for prot in init_prots:
                        if prot not in G_pp:
                            continue
                        try:
                            pp_path = nx.shortest_path(G_pp, source=prot, target=p_id)
                        except nx.NetworkXNoPath:
                            continue

                        candidate = [d_id] + pp_path
                        if len(candidate) < best_len:
                            best_path = candidate
                            best_len = len(candidate)

                if best_path:
                    readable = " -> ".join(
                        f"{nid} ({id_to_name.get(nid,'Unknown')})"
                        for nid in best_path
                    )
                    all_paths.append({
                        "distance": len(best_path) - 1,
                        "path_codes": best_path,
                        "path_readable": readable
                    })

        if not all_paths:
            print(f"No valid constrained paths for '{disease_name}' → '{protein_name}'.")
            result[protein_name] = {"average_distance": None, "paths": []}
        else:
            avg_dist = sum(p["distance"] for p in all_paths) / len(all_paths)
            result[protein_name] = {
                "average_distance": avg_dist,
                "paths": all_paths
            }
    return result


def build_and_save_all_disease_protein_paths_filteredprimekg(primekg, disease_map, output_pickle_file, translated_diseases_dict):
    """Run the constrained path search for every disease and pickle the results.

    Args:
        primekg: PrimeKG edge table.
        disease_map: disease label -> list of protein names to score.
        output_pickle_file: destination path. Parent directories are created.
        translated_diseases_dict: disease label -> PrimeKG disease name(s).

    Returns:
        disease label -> per-protein result dict, as returned by
        ``find_shortest_paths_for_proteins_filteredprimekg``. The same object
        is written to ``output_pickle_file``.
    """
    # Local import: the final dump must not fail with NameError after the whole
    # (long) search has already run.
    import pickle

    final_results = {}
    for disease, proteins in disease_map.items():
        print(f"\n== Processing disease '{disease}' with {len(proteins)} protein(s) ==")
        final_results[disease] = find_shortest_paths_for_proteins_filteredprimekg(
            primekg=primekg,
            disease_name=disease,
            protein_list=proteins,
            translated_diseases=translated_diseases_dict
        )

    out_path = Path(output_pickle_file)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "wb") as f:
        pickle.dump(final_results, f)
    print(f"\nAll results saved to {output_pickle_file}")
    return final_results


In [None]:
# disease names converted to match formats of those in primekg
# NOTE(review): `filter_strings_containing` and `unique_list_diseases` were used
# here but never defined anywhere in this notebook, so Restart & Run All failed
# with NameError. The definitions below are reconstructions; confirm that they
# match the original helpers (e.g. case sensitivity of the substring match).
def filter_strings_containing(strings, substring):
    """Return the sorted, de-duplicated entries of ``strings`` that contain ``substring``.

    Matching is case-insensitive; non-string entries (e.g. NaN) are ignored.
    """
    needle = substring.lower()
    return sorted({s for s in strings if isinstance(s, str) and needle in s.lower()})


# All distinct disease node names in PrimeKG (from either edge endpoint).
unique_list_diseases = sorted(
    {
        name
        for side in ("x", "y")
        for name in primekg.loc[primekg[f"{side}_type"] == "disease", f"{side}_name"]
        if isinstance(name, str)
    }
)

disease_translation_dict = {
    'Leukemia': filter_strings_containing(unique_list_diseases, 'leukemia'),
    'Non-Hodgkin lymphoma': 'non-Hodgkin lymphoma',
    'Type 2 diabetes': 'type 2 diabetes mellitus',
    'Ischemic heart disease': 'myocardial ischemia',
    'Cerebrovascular diseases': 'cerebrovascular disorder',
    'Emphysema, COPD': 'chronic obstructive pulmonary disease',
    'Chronic liver diseases': 'chronic liver failure',
    'Chronic kidney diseases': 'chronic kidney disease',
    'All-cause dementia': filter_strings_containing(unique_list_diseases, 'dementia'),
    'Alzheimer’s disease': 'Alzheimer disease',
    'Parkinson’s disease and parkinsonism': 'Parkinson disease',
    'Rheumatoid arthritis': 'rheumatoid arthritis',
    'Osteoporosis': 'osteoporosis',
    'Osteoarthritis': 'osteoarthritis' }
    
    

In [None]:
# All Olink panel proteins: every column of the measured proteomics table except
# the "eid" identifier column, with any trailing "_protein" suffix removed.
protnames = pd.read_csv("./measured_proteomics_random.csv")
olinkpanel = [
    name.removesuffix("_protein")
    for name in protnames.columns
    if name != "eid"
]
olinkpanel

In [None]:
def transform_protein_list(protein_list):
    """Normalise a flat list of protein names.

    * ``hla*`` names: every underscore becomes a hyphen.
    * Exactly one underscore with a one-character suffix: joined with a hyphen.
    * Any other underscored name: split, and each part kept as its own entry.
    * Names without underscores: unchanged.
    """
    out = []
    for name in protein_list:
        if name.startswith("hla"):
            out.append(name.replace("_", "-"))
            continue
        if "_" not in name:
            out.append(name)
            continue
        head, *rest = name.split("_")
        if len(rest) == 1 and len(rest[0]) == 1:
            out.append(f"{head}-{rest[0]}")
        else:
            out.append(head)
            out.extend(rest)
    return out

# Split/hyphenate compound panel names the same way as the feature lists.
olinkpanel1 = transform_protein_list(protein_list=olinkpanel)


In [None]:
# Map panel protein names onto PrimeKG names with the manual `translation_key`
# built in the feature-translation cell earlier in this notebook. This cell used
# to redefine an identical copy of that dict; it now reuses it so the two
# mappings cannot drift apart.
## note: palm has multiple isoforms, one of which is palm2. there is another isoform (palm2akap2) that is present in primekg
olinkpanel2 = [translation_key.get(protein, protein) for protein in olinkpanel1]

In [None]:
# Candidate set: every Olink panel protein is scored against every disease.
# The disease keys come from the measured-protein dictionary; the name used here
# before (`realfeat_unique_protlist_dict_cleaned_translated`) was never defined.
all_olink_proteins_for_each_disease = {
    disease: olinkpanel2.copy() for disease in measured_prot_unique_cleaned_translated
}
# Show per-disease counts rather than dumping every protein list.
{disease: len(proteins) for disease, proteins in all_olink_proteins_for_each_disease.items()}

In [None]:
# Constrained shortest paths for every (disease, panel protein) pair; the result
# is also cached to primekg_results.pkl.
results_dict = build_and_save_all_disease_protein_paths_filteredprimekg(
    primekg,
    all_olink_proteins_for_each_disease,
    "primekg_results.pkl",
    disease_translation_dict,
)

In [None]:
# Reload the cached path results. This pickle was written by this notebook;
# never unpickle files from untrusted sources.
all_olink_proteins_shortdist = pickle.loads(Path("primekg_results.pkl").read_bytes())

In [None]:
# subset proteins and their shortest distances as desired
def subset_distance_dict(
    distance_dict: Mapping[str, Mapping[str, Mapping[str, Any]]],
    filter_dict:   Mapping[str, Iterable[str]],
    *,
    drop_if_missing: bool = True,
) -> Dict[str, Dict[str, Dict[str, Any]]]:
    """Keep, for each disease, only the proteins listed for it in ``filter_dict``.

    Args:
        distance_dict: disease -> protein -> path info.
        filter_dict: disease -> proteins to keep.
        drop_if_missing: when True, diseases absent from ``filter_dict`` are
            dropped; otherwise they are passed through unfiltered (shallow copy).

    Returns:
        A new nested dict. Diseases whose filtered protein dict is empty are
        omitted.
    """
    # (The old body referenced an undefined, unused `exclude_list`, which raised
    # NameError on every call.)
    out: Dict[str, Dict[str, Dict[str, Any]]] = {}
    for disease, prot_dict in distance_dict.items():
        if disease not in filter_dict:
            if not drop_if_missing:
                out[disease] = prot_dict.copy()
            continue
        # Set for O(1) membership instead of scanning a list per protein.
        proteins_to_keep = set(filter_dict[disease])
        subset_prots = {
            p: details
            for p, details in prot_dict.items()
            if p in proteins_to_keep
        }
        if subset_prots:
            out[disease] = subset_prots

    return out



In [None]:
# Restrict the all-panel distances to each disease's measured / RABIT features.
realprot_shortdist_sub, synprot_shortdist_sub = (
    subset_distance_dict(all_olink_proteins_shortdist, feature_dict, drop_if_missing=True)
    for feature_dict in (
        measured_prot_unique_cleaned_translated,
        rabit_prot_unique_cleaned_translated,
    )
)


In [None]:
def plot_distance_histograms_overlay(
    dist_dict1: Dict[str, Dict[str, dict]],
    dist_dict2: Dict[str, Dict[str, dict]],
    *,
    label1: str = "Dict1",
    label2: str = "Dict2",
    delta_missing: Optional[object] = 10,
    diseases: Optional[Sequence[str]] = None,
    exclude: Optional[Iterable[str]] = None,
    label_replacements: Optional[Dict[str, str]] = None,
    figsize_per_plot: Tuple[float, float] = (4, 3),
    ncols: int = 3,
    color1: str = "#4c72b0",
    color2: str = "#dd8452",
    colors: Optional[Tuple[str, str]] = None,   # legacy tuple; overrides color1/2 if given
    alpha1: float = 0.55,
    alpha2: float = 0.55,
    edgecolor1: str = "black",
    edgecolor2: str = "black",
    show: bool = True,
    save: bool = False,
    out_file: str = "distance_histograms_overlay.png",
    dpi: int = 300,
):
    """Overlay, per disease, bar histograms of per-protein shortest-path distances.

    Each protein contributes one value: the smallest ``distance`` among its
    ``paths``. If it has none, ``average_distance`` rounded to an int is used;
    failing that, the protein counts as missing.

    Args:
        dist_dict1, dist_dict2: disease -> protein -> info dict with optional
            ``"paths"`` (dicts carrying ``"distance"``) and ``"average_distance"``.
        delta_missing: ``None`` or the string ``"none"`` (any case) hides missing
            proteins. A number puts them in an extra bar at that x position,
            moved to ``max + 1`` if it collides with a real distance, and
            labelled with the number. Any other string puts them in an extra
            bar right of the largest distance, labelled with that string.
        diseases: explicit subset/order of diseases; defaults to all, sorted.
        exclude: diseases to drop.
        label_replacements: disease -> panel title.
        colors: legacy ``(color1, color2)`` pair that overrides ``color1``/``color2``.
        show, save, out_file, dpi: display and optional export of the figure.

    Returns:
        ``(fig, axes)`` with ``axes`` flattened to 1-D; the figure is closed
        before returning.

    Raises:
        ValueError: if ``colors`` is not length 2 or no diseases remain.
    """
    if colors is not None:
        if len(colors) != 2:
            raise ValueError("`colors` must be length-2.")
        color1, color2 = colors

    def _protein_distances(nested, disease_key):
        # One entry per protein: an int distance, or None when unmapped.
        per_protein = []
        for entry in nested.get(disease_key, {}).values():
            path_distances = [
                path.get("distance")
                for path in (entry.get("paths") or [])
                if "distance" in path
            ]
            if path_distances:
                per_protein.append(int(min(path_distances)))
                continue
            avg = entry.get("average_distance")
            usable = isinstance(avg, (int, float)) and not np.isnan(avg)
            per_protein.append(int(round(avg)) if usable else None)
        return per_protein

    # Which panels to draw.
    available = set(dist_dict1) | set(dist_dict2)
    if exclude:
        available -= set(exclude)
    if diseases is None:
        diseases = sorted(available)
    else:
        diseases = [name for name in diseases if name in available]
    if not diseases:
        raise ValueError("No diseases to plot after filtering.")

    # Grid layout.
    n_panels = len(diseases)
    ncols = max(1, min(ncols, n_panels))
    nrows = -(-n_panels // ncols)  # ceiling division
    panel_w = figsize_per_plot[0]
    panel_h = figsize_per_plot[1]
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(panel_w * ncols, panel_h * nrows), squeeze=False
    )
    axes = axes.ravel()

    # How to render proteins without a distance.
    hide_missing = delta_missing is None or (
        isinstance(delta_missing, str) and delta_missing.lower() == "none"
    )
    missing_text = None
    if isinstance(delta_missing, str) and not hide_missing:
        missing_text = delta_missing
    missing_pos = delta_missing if isinstance(delta_missing, (int, float)) else None

    for idx, (ax, disease) in enumerate(zip(axes, diseases)):
        values1 = _protein_distances(dist_dict1, disease)
        values2 = _protein_distances(dist_dict2, disease)
        known1 = [v for v in values1 if v is not None]
        known2 = [v for v in values2 if v is not None]
        n_missing1 = len(values1) - len(known1)
        n_missing2 = len(values2) - len(known2)

        distinct = sorted(set(known1) | set(known2))
        positions = list(distinct)
        tick_labels = [str(v) for v in distinct]
        heights1 = [known1.count(v) for v in distinct]
        heights2 = [known2.count(v) for v in distinct]

        if not hide_missing and (n_missing1 or n_missing2):
            if missing_pos is not None:
                slot = max(distinct) + 1 if missing_pos in distinct else missing_pos
                text = str(missing_pos)
            else:
                slot = max(distinct) + 1 if distinct else 0
                text = missing_text
            positions.append(slot)
            tick_labels.append(text)
            heights1.append(n_missing1)
            heights2.append(n_missing2)

        if not positions:
            ax.set_visible(False)
            continue

        bar_width = 0.75
        ax.bar(positions, heights1, width=bar_width,
               color=color1, alpha=alpha1, edgecolor=edgecolor1, label=label1)
        ax.bar(positions, heights2, width=bar_width,
               color=color2, alpha=alpha2, edgecolor=edgecolor2, label=label2)

        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels)
        ax.set_xlabel("Shortest path distance")
        ax.set_ylabel("# Proteins")
        ax.set_title(label_replacements.get(disease, disease) if label_replacements else disease)
        ax.spines[['top', 'right']].set_visible(False)

        if idx == 0:
            ax.legend(frameon=False, fontsize=8)

    for unused_ax in axes[n_panels:]:
        unused_ax.set_visible(False)

    plt.tight_layout()
    if save:
        os.makedirs(os.path.dirname(out_file) or ".", exist_ok=True)
        fig.savefig(out_file, dpi=dpi, bbox_inches="tight")
    if show:
        plt.show()
    plt.close(fig)
    return fig, axes



# Disease-name abbreviations used as panel titles. This is the first cell that
# passes `label_replacements`, so it must define it here; otherwise Restart &
# Run All raises NameError. Later cells rebind the same mapping.
label_replacements = {
    'Parkinson’s disease and parkinsonism': 'PD',
    'Alzheimer’s disease':  'AD',
    'Chronic liver diseases': 'CLD',
    'Ischemic heart disease': 'IHD',
    'Osteoporosis': 'OSP',
    'Emphysema, COPD': 'COPD',
    'Type 2 diabetes': 'T2DM',
    'Chronic kidney diseases': 'CKD',
    'Leukemia': 'Leukemia',
    'Non-Hodgkin lymphoma': 'NHL',
    'Cerebrovascular diseases': 'CBVD',
    'Osteoarthritis': 'OA',
    'Rheumatoid arthritis': 'RA',
    'All-cause dementia': 'Dementia'
}

fig, axes = plot_distance_histograms_overlay(
    realprot_shortdist_sub,
    synprot_shortdist_sub,
    label1="Measured Proteins",
    label2="RABIT",
    color1="#d89a97",
    color2="#94bed8",
    edgecolor1="#333333",
    edgecolor2="#333333",
    alpha1=0.6,
    alpha2=0.6,
    delta_missing="None",
    label_replacements=label_replacements,
    save=False
)



In [None]:
def plot_none_statistics_pct_only(
    dict1: Dict[str, Dict[str, dict]],
    dict2: Dict[str, Dict[str, dict]],
    *,
    label1: str = "realprot",
    label2: str = "synprot",
    exclude_list: Optional[Iterable[str]] = None,
    font_size: int = 14,
    label_map: Optional[Dict[str, str]] = None,
    colors: Tuple[str, str] = ("#1f77b4", "#ff7f0e"),   # hexes for dict1/dict2

    # extra styling
    pairline_color: str = "gray",
    pairline_alpha: float = 0.7,
    bar_width: float = 0.35,

    # saving
    save: bool = False,
    out_dir: str = "./none_stats_figs",
    fname_prefix: str = "none_stats",
    save_formats: Tuple[str, ...] = ("pdf",),
):
    exclude_list = set(exclude_list or [])
    label_map = label_map or {}
    color1, color2 = colors

    diseases = sorted((set(dict1) | set(dict2)) - exclude_list)

    pct1, pct2 = [], []
    for d in diseases:
        total1 = len(dict1.get(d, {}))
        n1 = sum(info.get("average_distance") is None
                 for info in dict1.get(d, {}).values())
        total2 = len(dict2.get(d, {}))
        n2 = sum(info.get("average_distance") is None
                 for info in dict2.get(d, {}).values())

        pct1.append((n1 / total1 * 100) if total1 else 0.0)
        pct2.append((n2 / total2 * 100) if total2 else 0.0)

    pval_pct = ttest_rel(pct1, pct2).pvalue if len(pct1) > 1 else np.nan

    def save_fig(fig: plt.Figure, suffix: str):
        if not save:
            return
        os.makedirs(out_dir, exist_ok=True)
        for ext in save_formats:
            fig.savefig(
                os.path.join(out_dir, f"{fname_prefix}_{suffix}.{ext}"),
                dpi=300, bbox_inches="tight"
            )

    figs = {}

    # % None by disease
    fig1, ax1 = plt.subplots(figsize=(14, 6))
    x = np.arange(len(diseases))
    ax1.bar(x - bar_width/2, pct1, bar_width, color=color1, label=label1)
    ax1.bar(x + bar_width/2, pct2, bar_width, color=color2, label=label2)
    for i in range(len(diseases)):
        ax1.plot([x[i] - bar_width/2, x[i] + bar_width/2],
                 [pct1[i], pct2[i]],
                 marker="o", linestyle="-",
                 color=pairline_color, alpha=pairline_alpha)
    ax1.set_ylabel("% Unmapped", fontsize=font_size)
    ax1.set_xticks(x)
    ax1.set_xticklabels([label_map.get(d, d) for d in diseases],
                        rotation=45, ha="right", fontsize=font_size)
    ax1.tick_params(axis="both", labelsize=font_size)
    ax1.legend(frameon=False, fontsize=font_size)
    fig1.tight_layout()
    save_fig(fig1, "pct_by_disease")
    figs["pct_by_disease"] = (fig1, ax1)

    # Average % None across diseases
    fig2, ax2 = plt.subplots(figsize=(6, 6))
    summary_pct = [np.mean(pct1), np.mean(pct2)]
    xs = np.arange(2)
    ax2.bar(xs[0], summary_pct[0], bar_width, color=color1)
    ax2.bar(xs[1], summary_pct[1], bar_width, color=color2)
    for i in range(len(diseases)):
        ax2.plot(xs, [pct1[i], pct2[i]], marker="o", linestyle="-",
                 color=pairline_color, alpha=pairline_alpha)

    ax2.set_ylabel("% Unmapped", fontsize=font_size)
    ax2.set_xticks(xs)
    ax2.set_xticklabels([label1, label2], rotation=0, ha="center",
                        fontsize=font_size)
    ax2.tick_params(axis="both", labelsize=font_size)

    y_all = pct1 + pct2 + summary_pct
    y_min = min(0.0, min(y_all))
    y_max = max(y_all)
    pad   = 0.08 * (y_max - y_min if y_max > y_min else 1.0)
    ax2.set_ylim(y_min - pad, y_max + pad * 2)

    y_top = ax2.get_ylim()[1]
    ax2.text(0.5, y_top - pad * 0.6, f"p = {pval_pct:.3f}",
             ha="center", va="bottom", fontsize=font_size)

    fig2.tight_layout()
    save_fig(fig2, "pct_avg")
    figs["pct_avg"] = (fig2, ax2)

    return figs


# Short disease labels used on plot axes.
label_replacements = dict(
    [
        ("Parkinson’s disease and parkinsonism", "PD"),
        ("Alzheimer’s disease", "AD"),
        ("Chronic liver diseases", "CLD"),
        ("Ischemic heart disease", "IHD"),
        ("Osteoporosis", "OSP"),
        ("Emphysema, COPD", "COPD"),
        ("Type 2 diabetes", "T2DM"),
        ("Chronic kidney diseases", "CKD"),
        ("Leukemia", "Leukemia"),
        ("Non-Hodgkin lymphoma", "NHL"),
        ("Cerebrovascular diseases", "CBVD"),
        ("Osteoarthritis", "OA"),
        ("Rheumatoid arthritis", "RA"),
        ("All-cause dementia", "Dementia"),
    ]
)



# Diseases to leave out of the unmapped-protein comparison (none by default).
# `exclude_list` was previously referenced here without ever being defined.
exclude_list: List[str] = []

# The plotting function defined above is `plot_none_statistics_pct_only`; the
# old call used `plot_none_statistics`, a name this notebook never defined.
figs = plot_none_statistics_pct_only(
    realprot_shortdist_sub,
    synprot_shortdist_sub,
    label1="Measured Proteins",
    label2="RABIT Proteins",
    exclude_list=exclude_list,
    label_map=label_replacements,
    colors=("#d89a97", "#94bed8"),
    save=True,
    out_dir="./figs_none",
    fname_prefix="none_stats",
    save_formats=("pdf",)
)


In [None]:
def compare_distance_dicts_pctplots(
    dict1: Dict[str, Dict[str, Dict[str, Any]]],
    dict2: Dict[str, Dict[str, Dict[str, Any]]],
    *,
    label1: str = "RealProt",
    label2: str = "Synprot",
    delta: float | None = 6.0,
    threshold: float | str = None,
    baseline_dict: dict | None = None,
    label_replacements: dict | None = None,
    exclude: list[str] | None = None,
    figsize=(6, 6),
    jitter: float = 0.08,
    seed=None,
    save_pdf: bool = False,
    out_dir: str | os.PathLike = "./figures",
    pdf_filename: str = "distance_plots.pdf",
    dpi: int = 300,
    show: bool = True,
    base_fontsize: int = 18,
    color1: str = "lightblue",
    color2: str = "lightcoral",
    colors: tuple[str, str] | None = None,
    bar_alpha: float = 0.85,
    bar_edgecolor: str = "black",
    diag_line_color: str = "black",
    violin_inner_color: str = "black",
    violin_scatter_color: str = "black",
    violin_scatter_alpha: float = 0.8,
    violin_linewidth: float = 1.2,
):
    if colors:
        color1, color2 = colors
    if seed is not None:
        random.seed(seed)

    mpl.rcParams.update({
        "font.size": base_fontsize,
        "axes.labelsize": base_fontsize,
        "axes.titlesize": base_fontsize * 1.25,
        "xtick.labelsize": base_fontsize * 0.9,
        "ytick.labelsize": base_fontsize * 0.9,
        "legend.fontsize": base_fontsize * 0.9,
        "figure.titlesize": base_fontsize * 1.3,
    })

    exclude = set(exclude or [])

    pdf = None
    if save_pdf:
        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        pdf = PdfPages(out_dir / pdf_filename)

    def save_fig(fig):
        if pdf:
            fig.savefig(pdf, format="pdf", bbox_inches="tight", dpi=dpi)

    def show_fig(fig):
        if show:
            plt.show()
        else:
            plt.close(fig)

    def avg_distances(nested, delta_val):
        out = {}
        for disease, prots in nested.items():
            dists = []
            for info in prots.values():
                v = info.get("average_distance")
                if isinstance(v, (int, float)) and not np.isnan(v):
                    dists.append(float(v))
                elif delta_val is not None:
                    dists.append(delta_val)
            out[disease] = np.mean(dists) if dists else float("nan")
        return out

    common = sorted(d for d in dict1 if d in dict2 and d not in exclude)
    if not common:
        raise ValueError("No overlapping diseases after exclusions.")

    # Threshold setup
    if threshold is None:
        raise ValueError("Need a threshold value or 'bydisease'.")

    if threshold == "bydisease":
        if baseline_dict is None:
            raise ValueError("baseline_dict required when threshold='bydisease'")
        base_means = avg_distances(baseline_dict, delta)
        thresh_map = {d: base_means[d] for d in common if d in base_means}
    elif isinstance(threshold, (int, float)):
        thresh_map = {d: float(threshold) for d in common}
    else:
        raise ValueError("Invalid threshold type.")

    # Compute % below threshold
    def pct_below(nested, tmap):
        res = {}
        for disease, prots in nested.items():
            if disease not in tmap or not prots or disease in exclude:
                continue
            thr = tmap[disease]
            total = len(prots)
            below = sum(
                1 for p in prots.values()
                if isinstance((v := p.get("average_distance")), (int, float)) and v < thr
            )
            res[disease] = 100.0 * below / total
        return res

    pct1 = pct_below(dict1, thresh_map)
    pct2 = pct_below(dict2, thresh_map)

    common_pct = [d for d in common if d in pct1 and d in pct2]
    if not common_pct:
        raise ValueError("No diseases with valid percentages.")

    df = pd.DataFrame({
        "Disease": common_pct,
        label1: [pct1[d] for d in common_pct],
        label2: [pct2[d] for d in common_pct],
        "Threshold": [thresh_map[d] for d in common_pct]
    })

    # Label replacements
    repl = label_replacements or {}
    df["Disease_plot"] = df["Disease"].map(repl).fillna(df["Disease"])

    # 1. Bar plot by disease
    x = np.arange(len(df))
    bw = 0.4
    fig_bar, ax_bar = plt.subplots(figsize=(max(6, len(df) * 0.55), 6))
    ax_bar.bar(x - bw/2, df[label1], bw, label=label1, color=color1, edgecolor=bar_edgecolor, alpha=bar_alpha)
    ax_bar.bar(x + bw/2, df[label2], bw, label=label2, color=color2, edgecolor=bar_edgecolor, alpha=bar_alpha)
    ax_bar.set_xticks(x)
    ax_bar.set_xticklabels(df["Disease_plot"], rotation=90)
    ylabel = "% proteins below disease-specific threshold" if threshold == "bydisease" else f"% proteins with distance < {threshold}"
    ax_bar.set_ylabel(ylabel)
    ax_bar.legend(frameon=False)
    fig_bar.tight_layout()
    save_fig(fig_bar); show_fig(fig_bar)

    # 2. Violin plot
    long_df = pd.melt(df, id_vars="Disease", value_vars=[label1, label2],
                      var_name="Source", value_name="PercentBelow")
    fig_violin, ax_violin = plt.subplots(figsize=figsize)
    sns.violinplot(data=long_df, x="Source", y="PercentBelow",
                   palette={label1: color1, label2: color2},
                   inner=None, ax=ax_violin)
    for i, src in enumerate([label1, label2]):
        vals = df[src]
        ax_violin.vlines(i, vals.min(), vals.max(), color=violin_inner_color, linewidth=violin_linewidth)
        ax_violin.hlines(vals.median(), i - 0.15, i + 0.15, color=violin_inner_color, linewidth=2)
        ax_violin.scatter(
            np.full(vals.size, i) + np.random.uniform(-jitter, jitter, vals.size),
            vals, color=violin_scatter_color, s=35, alpha=violin_scatter_alpha
        )
    ax_violin.set_ylabel(ylabel)
    fig_violin.tight_layout()
    save_fig(fig_violin); show_fig(fig_violin)

    # 3. Scatter plot per disease
    fig_scatter, ax_scatter = plt.subplots(figsize=figsize)
    ax_scatter.scatter(df[label2], df[label1], color="grey", edgecolor="black", s=60)
    lims = [0, max(df[label1].max(), df[label2].max()) * 1.05]
    ax_scatter.plot(lims, lims, ls="--", color=diag_line_color, linewidth=1)
    for _, row in df.iterrows():
        ax_scatter.annotate(repl.get(row["Disease"], row["Disease"]),
                            (row[label2], row[label1]),
                            textcoords="offset points", xytext=(5, 5),
                            fontsize=int(base_fontsize * 0.85))
    ax_scatter.set_xlabel(f"{label2}: % below threshold")
    ax_scatter.set_ylabel(f"{label1}: % below threshold")
    fig_scatter.tight_layout()
    save_fig(fig_scatter); show_fig(fig_scatter)

    # Pearson correlation
    r, p = pearsonr(df[label2], df[label1])
    print(f"Pearson r = {r:.3f}, p = {p:.3g}")

    if pdf:
        pdf.close()
        print(f"Saved PDF to {out_dir / pdf_filename}")

    return df


In [None]:
# Short disease labels used on plot axes and annotations.
label_replacements = dict(
    [
        ("Parkinson’s disease and parkinsonism", "PD"),
        ("Alzheimer’s disease", "AD"),
        ("Chronic liver diseases", "CLD"),
        ("Ischemic heart disease", "IHD"),
        ("Osteoporosis", "OSP"),
        ("Emphysema, COPD", "COPD"),
        ("Type 2 diabetes", "T2DM"),
        ("Chronic kidney diseases", "CKD"),
        ("Leukemia", "Leukemia"),
        ("Non-Hodgkin lymphoma", "NHL"),
        ("Cerebrovascular diseases", "CBVD"),
        ("Osteoarthritis", "OA"),
        ("Rheumatoid arthritis", "RA"),
        ("All-cause dementia", "Dementia"),
    ]
)

# Figure 3c: % of each disease's measured (M-M) vs RABIT (R-R) proteins whose
# distance falls below that disease's mean distance over all panel proteins
# that have a path (delta=None skips unmapped proteins in that mean).
df_pct = compare_distance_dicts_pctplots(
    realprot_shortdist_sub,
    synprot_shortdist_sub,
    label1="M-M",
    label2="R-R",
    threshold="bydisease",
    baseline_dict=all_olink_proteins_shortdist,
    delta=None,
    exclude=[],
    label_replacements=label_replacements,
    color1="#d89a97",
    color2="#94bed8",
    save_pdf=True,
    out_dir="../figures/figure3",
    pdf_filename="3c.pdf",
)


# GWAS Analysis

In [None]:
# Disease -> GWAS Catalog association exports (downloaded 2025-05-16). Each
# file name carries the EFO/MONDO trait id it was queried with.
ref_gene_dict = {
    "Leukemia": [
        "gwas-association-downloaded_2025-05-16-EFO_0000565-withChildTraits.tsv",
    ],
    "Non-Hodgkin lymphoma": [
        "gwas-association-downloaded_2025-05-16-EFO_0005952-withChildTraits.tsv",
    ],
    "Type 2 diabetes": [
        "gwas-association-downloaded_2025-05-16-MONDO_0005148.tsv",
    ],
    "Ischemic heart disease": [
        "gwas-association-downloaded_2025-05-16-EFO_1001375-withChildTraits.tsv",
    ],
    "Cerebrovascular diseases": [
        "gwas-association-downloaded_2025-05-16-EFO_0003763-withChildTraits.tsv",
    ],
    "Emphysema, COPD": [
        "gwas-association-downloaded_2025-05-16-EFO_0000464.tsv",
        "gwas-association-downloaded_2025-05-16-EFO_0000341-withChildTraits.tsv",
    ],
    "Chronic liver diseases": [
        "gwas-association-downloaded_2025-05-16-EFO_0001421-withChildTraits.tsv",
    ],
    "Chronic kidney diseases": [
        "gwas-association-downloaded_2025-05-16-EFO_0003884-withChildTraits.tsv",
    ],
    "All-cause dementia": [
        "gwas-association-downloaded_2025-05-16-MONDO_0001627-withChildTraits.tsv",
    ],
    "Alzheimer’s disease": [
        "gwas-association-downloaded_2025-05-16-MONDO_0004975-withChildTraits.tsv",
    ],
    "Parkinson’s disease and parkinsonism": [
        "gwas-association-downloaded_2025-05-16-MONDO_0005180-withChildTraits.tsv",
    ],
    "Rheumatoid arthritis": [
        "gwas-association-downloaded_2025-05-16-EFO_0000685-withChildTraits.tsv",
    ],
    "Osteoporosis": [
        "gwas-association-downloaded_2025-05-16-EFO_0003882-withChildTraits.tsv",
    ],
    "Osteoarthritis": [
        "gwas-association-downloaded_2025-05-16-MONDO_0005178-withChildTraits.tsv",
    ],
}

ref_gene_dict

In [None]:
def match_genes(
    lookup_genes:    Iterable[str],
    reference_genes: Iterable[str],
) -> pd.DataFrame:
    """Flag which lookup genes occur in a reference gene collection.

    Matching is case-insensitive: both sides are lower-cased before comparison.

    Parameters
    ----------
    lookup_genes : iterable of str
        Symbols to check (e.g. one disease's feature names). Order is kept.
    reference_genes : iterable of str
        Reference symbols (e.g. the pooled GWAS genes for the same disease).

    Returns
    -------
    pd.DataFrame
        One row per lookup gene with columns ``lookup_gene`` (the input
        symbol), ``reference_match`` (the input symbol when matched, otherwise
        ``None``) and ``is_matched`` (bool). The columns are present even when
        ``lookup_genes`` is empty, so ``df["is_matched"]`` never raises KeyError.
    """
    ref_set = {g.lower() for g in reference_genes}

    rows = []
    for gene in lookup_genes:
        hit = gene.lower() in ref_set
        rows.append({
            "lookup_gene":     gene,
            "reference_match": gene if hit else None,
            "is_matched":      hit,
        })

    # explicit columns: an empty `rows` would otherwise give a column-less frame
    return pd.DataFrame(rows, columns=["lookup_gene", "reference_match", "is_matched"])

def process_ref_genes_all(
    disease_dict: Dict[str, Dict[str, Dict[str, Any]]],
    ref_gene_dict: Dict[str, List[str]],
    gwas_dir:      str | os.PathLike,
    *,
    gene_col:  str  = "MAPPED_GENE",
    pval_col:  str  = "P-VALUE",
    be_strict: bool = False,
    pval_threshold: float = 5e-8,
    exclude: Optional[Iterable[str]] = None,
    split_mapped_genes: bool = False,
) -> Dict[str, pd.DataFrame]:
    """Match each disease's protein features against its GWAS reference genes.

    For every disease in ``disease_dict`` that has an entry in ``ref_gene_dict``
    and is not excluded, all TSV files listed for it are read from
    ``gwas_dir``. The symbols in ``gene_col`` are pooled, and the disease's
    feature names (the keys of its inner dict) are matched case-insensitively
    with ``match_genes``.

    Parameters
    ----------
    disease_dict : dict
        ``{disease: {feature: ...}}``; only the inner keys are used.
    ref_gene_dict : dict
        ``{disease: [tsv filename, ...]}``, with filenames relative to ``gwas_dir``.
    gwas_dir : str or PathLike
        Directory containing the tab-separated GWAS files.
    gene_col : str
        Column holding gene symbols.
    pval_col : str
        Column holding association p-values. Only used when ``be_strict``.
    be_strict : bool
        Keep only rows with ``pval_col < pval_threshold``.
    pval_threshold : float
        Cut-off used when ``be_strict`` (default 5e-8, genome-wide significance).
    exclude : iterable of str, optional
        Disease names to skip. When ``None``, a notebook-global ``exclude_list``
        is used if one exists (the previous behaviour); otherwise nothing is
        skipped.
    split_mapped_genes : bool
        If True, split multi-gene cells such as ``"A, B"``, ``"A; B"`` or
        ``"A - B"`` into individual symbols before matching. Off by default so
        that earlier results are reproduced.

    Returns
    -------
    dict
        ``{disease: DataFrame}`` as returned by ``match_genes``.

    Raises
    ------
    FileNotFoundError
        If a listed TSV file does not exist.
    KeyError
        If ``gene_col`` (or ``pval_col`` when strict) is missing from a file.
    """
    if exclude is None:
        # Earlier versions read a free global `exclude_list` (NameError on a fresh
        # kernel when undefined). Honour it if present so existing calls are unchanged.
        exclude = globals().get("exclude_list", ())
    excluded = set(exclude)
    gwas_root = Path(gwas_dir)

    matched_dfs: Dict[str, pd.DataFrame] = {}
    for disease, protein_map in disease_dict.items():
        if disease in excluded or disease not in ref_gene_dict:
            continue

        ref_genes: set[str] = set()
        for tsv in ref_gene_dict[disease]:
            gwas_df = pd.read_csv(gwas_root / tsv, sep="\t", low_memory=False)

            if be_strict:
                # coerce so text-typed p-values compare numerically; unparsable -> NaN -> dropped
                pvals = pd.to_numeric(gwas_df[pval_col], errors="coerce")
                gwas_df = gwas_df[pvals < pval_threshold]

            genes = gwas_df[gene_col].dropna().astype(str).str.strip()
            if split_mapped_genes:
                # separators with surrounding spaces only, so hyphenated symbols
                # such as "HLA-DRB1" stay intact
                genes = genes.str.split(r"\s*[,;]\s*|\s+-\s+").explode().str.strip()
                genes = genes[genes != ""]
            ref_genes.update(genes.unique())

        matched_dfs[disease] = match_genes(list(protein_map), ref_genes)

    return matched_dfs

In [None]:
gwas_dir = "./gwas_validation"
# Match measured protein features to GWAS hits; be_strict keeps only rows with
# P-VALUE < 5e-8.
matched_results_real = process_ref_genes_all(
    disease_dict=realprot_shortdist_sub,
    ref_gene_dict=ref_gene_dict,
    gwas_dir=gwas_dir,   # the missing comma here was a SyntaxError
    be_strict=True,
)


In [None]:
# Match RABIT protein features to GWAS hits (strict: P-VALUE < 5e-8 only).
gwas_dir = "./gwas_validation"
matched_results_syn = process_ref_genes_all(
    disease_dict=synprot_shortdist_sub,
    ref_gene_dict=ref_gene_dict,
    gwas_dir=gwas_dir,
    be_strict=True,
)



In [None]:
def plot_match_percentages_by_disease_nobins(
    real_results: Dict[str, pd.DataFrame],
    syn_results:  Dict[str, pd.DataFrame],
    *,
    # --- NEW / enhanced color + style knobs ---
    real_color: str = "#4C78A8",
    syn_color:  str = "#F58518",
    colors: Optional[Tuple[str, str]] = None,      # legacy tuple; overrides above if given
    real_alpha: float = 1.0,
    syn_alpha:  float = 1.0,
    edgecolor:  Optional[str] = None,
    edgewidth:  float = 0.0,
    legend_marker: str = "s",

    real_label: str = "Real Proteins",
    syn_label:  str = "Synthetic Proteins",
    label_replacements: Dict[str, str] | None = None,
    ncols: int = 3,
    title_fontsize: float = 16,
    label_fontsize: float = 14,
    tick_fontsize: float = 12,
    legend_fontsize: float = 16,
    legend_markerscale: float = 2.0,
    figsize_per_plot: tuple = (4, 4),

    # saving / showing
    save: bool = False,
    save_path: str = "match_percentages_grid.png",
    dpi: int = 300,
    show: bool = True,
):
    if colors is not None:
        if len(colors) != 2:
            raise ValueError("`colors` must be a 2-tuple (real_color, syn_color).")
        real_color, syn_color = colors

    if label_replacements is None:
        label_replacements = {}

    # ── 1. collect diseases present in either dict ──────────────────────
    diseases = sorted(set(real_results) | set(syn_results))
    if not diseases:
        print("No diseases to plot.")
        return None, None

    # ── 2. grid layout (reserve an empty cell for the legend) ───────────
    n = len(diseases)
    ncols = max(1, ncols)
    nrows = math.ceil(n / ncols)
    total_cells = nrows * ncols
    if total_cells == n:
        nrows += 1
        total_cells = nrows * ncols

    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(ncols * figsize_per_plot[0], nrows * figsize_per_plot[1]),
        squeeze=False,
    )
    axes = axes.flatten()

    # legend handles
    legend_handles = [
        Line2D([], [], marker=legend_marker, linestyle="", color=real_color,
               markeredgecolor=edgecolor, markeredgewidth=edgewidth, alpha=real_alpha),
        Line2D([], [], marker=legend_marker, linestyle="", color=syn_color,
               markeredgecolor=edgecolor, markeredgewidth=edgewidth, alpha=syn_alpha),
    ]
    legend_labels = [real_label, syn_label]

    # ── 3. per-disease plotting ─────────────────────────────────────────
    w = 0.4  # bar width
    for idx, disease in enumerate(diseases):
        ax = axes[idx]

        df_real = real_results.get(disease)
        df_syn  = syn_results.get(disease)

        def pct_match(df: pd.DataFrame | None) -> float | None:
            if df is None or df.empty:
                return None
            if "is_matched" in df:
                s = df["is_matched"]
            elif "reference_match" in df:
                s = df["reference_match"].notnull().astype(int)
            else:
                return None
            return s.mean() * 100

        p_r = pct_match(df_real)
        p_s = pct_match(df_syn)
        p_r = 0.0 if p_r is None else p_r
        p_s = 0.0 if p_s is None else p_s

        ax.bar(-w / 2, p_r, width=w,
               color=real_color, alpha=real_alpha,
               edgecolor=edgecolor, linewidth=edgewidth)
        ax.bar( w / 2, p_s, width=w,
               color=syn_color,  alpha=syn_alpha,
               edgecolor=edgecolor, linewidth=edgewidth)

        ax.set_xticks([-w / 2, w / 2])
        ax.set_xticklabels([real_label, syn_label],
                           fontsize=tick_fontsize, rotation=90)

        title = label_replacements.get(disease, disease)
        ax.set_title(title, fontsize=title_fontsize)
        ax.set_ylabel("Match Percentage (%)", fontsize=label_fontsize)
        ax.tick_params(axis="y", labelsize=tick_fontsize)

        ymax = max(p_r, p_s, 2) * 1.15
        ax.set_ylim(0, ymax)

    # ── 4. hide extra subplots except legend cell ───────────────────────
    first_empty = len(diseases)
    for ax in axes[first_empty + 1:]:
        ax.set_visible(False)

    # legend in the first empty cell
    if first_empty < len(axes):
        legend_ax = axes[first_empty]
        legend_ax.axis("off")
        legend_ax.legend(
            legend_handles,
            legend_labels,
            loc="center",
            fontsize=legend_fontsize,
            markerscale=legend_markerscale,
            frameon=False,
        )

    plt.tight_layout()

    if save:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")

    if show:
        plt.show()
    else:
        plt.close(fig)

    return fig, axes


# Disease name -> short panel label. Keys must equal the disease names used as
# keys in the GWAS match results (see `ref_gene_dict`), character for character.
label_replacements = {
    'Parkinson’s disease and parkinsonism': 'PD',
    'Alzheimer’s disease':  'AD',
    'Chronic liver diseases': 'CLD',
    'Ischemic heart disease': 'IHD',
    'Osteoporosis': 'OSP',
    'Emphysema, COPD': 'COPD',
    'Type 2 diabetes': 'T2DM',
    'Chronic kidney diseases': 'CKD',
    'Leukemia': 'Leukemia',
    'Non-Hodgkin lymphoma': 'NHL',
    'Cerebrovascular diseases': 'CBVD',
    'Osteoarthritis': 'OA',
    'Rheumatoid arthritis': 'RA',   # was misspelled "arthritist", so RA was never relabelled
    'All-cause dementia': 'Dementia'
}

# Per-disease GWAS match percentages: measured (M-M) vs RABIT (R-R) features.
fig, axes = plot_match_percentages_by_disease_nobins(
    matched_results_real,
    matched_results_syn,
    # legend / tick labels and panel titles
    real_label=r"$M\!-\!M$",
    syn_label=r"$R\!-\!R$",
    label_replacements=label_replacements,
    # bar styling
    real_color="#d89a97",
    syn_color="#94bed8",
    real_alpha=0.9,
    syn_alpha=0.9,
    edgecolor="#333333",
    edgewidth=0.8,
    # layout and typography
    ncols=4,
    title_fontsize=20,
    label_fontsize=16,
    tick_fontsize=14,
    legend_fontsize=18,
    legend_markerscale=3,
    save=False,
)



In [None]:
def plot_overall_match_percentages(
    real_results: Dict[str, pd.DataFrame],
    syn_results:  Dict[str, pd.DataFrame],
    *,
    save: bool = False,
    out_dir: str | os.PathLike = "./figures",
    filename_base: str = "overall_match_scatter",   # will save as *_labels.pdf
    real_label: str = r"$M-M$",
    syn_label:  str = r"$R-R$",
    point_color: str = "grey",
    edge_color: str = "black",
    point_size: int = 60,
    label_replacements: Dict[str, str] | None = None,
    text_fontsize: int = 9,
    text_offset: Tuple[int, int] = (5, 5),
) -> Tuple[float, float]:
    # per-disease % matched
    shared = sorted(set(real_results) & set(syn_results))
    if not shared:
        raise ValueError("The two dictionaries share no disease names.")

    real_pct = np.array([real_results[d]["is_matched"].mean() * 100 for d in shared])
    syn_pct  = np.array([syn_results[d]["is_matched"].mean() * 100 for d in shared])

    # Pearson correlation
    r, p = pearsonr(syn_pct, real_pct)
    print(f"Pearson r = {r:.3f}, p = {p:.3g}")

    # labeled scatter
    lim_max = max(real_pct.max(), syn_pct.max()) * 1.05
    repl = label_replacements or {}

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(syn_pct, real_pct, color=point_color, edgecolor=edge_color, s=point_size, zorder=3)
    ax.plot([0, lim_max], [0, lim_max], ls="--", color="black", linewidth=1)
    ax.set_xlabel(f"{syn_label} Match %")
    ax.set_ylabel(f"{real_label} Match %")
    ax.set_xlim(0, lim_max)
    ax.set_ylim(0, lim_max)

    for d, x, y in zip(shared, syn_pct, real_pct):
        ax.annotate(repl.get(d, d), (x, y), textcoords="offset points",
                    xytext=text_offset, fontsize=text_fontsize)

    fig.tight_layout()

    if save:
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, f"{filename_base}_labels.pdf")
        fig.savefig(path, format="pdf", bbox_inches="tight")
        print(f"Saved → {path}")

    plt.show()
    return r, p



# Overall scatter of measured vs RABIT GWAS match % across diseases.
# The previous "/outputdir/" + "/filename/" placeholders joined to the absolute
# path "/filename/_labels.pdf". Save alongside the other Figure 3 panels instead.
# NOTE(review): confirm the intended output directory and panel filename.
r, p = plot_overall_match_percentages(
    real_results=matched_results_real,
    syn_results=matched_results_syn,
    save=True,
    out_dir="../figures/figure3",
    filename_base="overall_match_scatter",
    label_replacements=label_replacements,
)
