In [None]:
%load_ext autoreload
%autoreload 2
import sys
import os
import yaml
ProjDIR = "/home/jw3514/Work/ASD_Circuits_CellType/" # Change to your project directory
sys.path.insert(1, f'{ProjDIR}/src/')
from ASD_Circuits import *

# Switch into the rebuttal notebook directory; the relative paths used further
# down ("../config", "../dat", "../results") are written relative to it.
try:
    os.chdir(f"{ProjDIR}/notebook_rebuttal/")
    print(f"Current working directory: {os.getcwd()}")
except FileNotFoundError as e:
    # NOTE(review): the failure is only printed, not raised, so the notebook
    # keeps running from whatever directory the kernel started in.
    print(f"Error: Could not change directory - {e}")
except Exception as e:
    print(f"Unexpected error: {e}")


# Gene-identifier lookup tables; GeneSymbol2Entrez is read by the gene-weight
# functions defined later in this notebook.
# LoadGeneINFO is not defined here — presumably it comes from the ASD_Circuits star import.
HGNC, ENSID2Entrez, GeneSymbol2Entrez, Entrez2Symbol = LoadGeneINFO()

In [None]:
# Load config file
with open("../config/config.yaml", "r") as f:
    config = yaml.safe_load(f)

expr_matrix_path = config["analysis_types"]["STR_ISH"]["expr_matrix"]
STR_BiasMat = pd.read_parquet(f"../{expr_matrix_path}")
Anno = STR2Region()

In [None]:
def plot_structure_bias_correlation(df_a, df_b, label_a='Dataset A', label_b='Dataset B', title=None):
    """
    Scatter-plot the EFFECT values of two structure-bias tables against each
    other and annotate the Pearson correlation.

    Rows are paired by index label (inner join), so only structures present in
    both tables are compared. Pairs with a missing EFFECT on either side are
    dropped before plotting and before the correlation is computed.

    Parameters:
    df_a: DataFrame with EFFECT column for first dataset (x-axis)
    df_b: DataFrame with EFFECT column for second dataset (y-axis)
    label_a: Label for x-axis (first dataset)
    label_b: Label for y-axis (second dataset)
    title: Accepted for backward compatibility; ignored, no title is drawn

    Returns:
    correlation: Pearson correlation coefficient

    Raises:
    ValueError: if fewer than two structures have EFFECT values in both datasets
    """
    from scipy.stats import pearsonr

    # Merge the datasets on structure names for comparison
    merged_data = pd.merge(df_a[['EFFECT']], df_b[['EFFECT']],
                          left_index=True, right_index=True, suffixes=('_A', '_B'))
    # Keep complete pairs only so a missing value cannot leak into the correlation.
    merged_data = merged_data.dropna(subset=['EFFECT_A', 'EFFECT_B'])
    if len(merged_data) < 2:
        # Fail before any figure is created, with a message that names the cause.
        raise ValueError(
            "Need at least two structures with EFFECT values in both datasets "
            f"to compute a correlation; found {len(merged_data)}."
        )

    # NOTE: plt.style.use changes the global matplotlib style, so it also
    # applies to figures drawn after this call.
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(1, 1, dpi=120, figsize=(5, 4), facecolor='none')

    # Transparent figure and axes backgrounds
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)

    # Create scatter plot
    ax.scatter(merged_data['EFFECT_A'], merged_data['EFFECT_B'],
              alpha=1, s=20, c='#1f77b4', edgecolors='black', linewidth=0.5)

    # Add diagonal line for reference, spanning the combined range of both axes
    min_val = min(merged_data['EFFECT_A'].min(), merged_data['EFFECT_B'].min())
    max_val = max(merged_data['EFFECT_A'].max(), merged_data['EFFECT_B'].max())
    ax.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.8, linewidth=2, label='y=x')

    # Calculate correlation and p-value
    correlation, pval = pearsonr(merged_data['EFFECT_A'], merged_data['EFFECT_B'])

    # Set labels (no title)
    ax.set_xlabel(f'{label_a}', fontsize=14)
    ax.set_ylabel(f'{label_b}', fontsize=14)

    # Very small p-values are shown as an upper bound instead of the exact value
    if pval < 1e-10:
        p_disp = f"p <{1e-10:.0e}"
    else:
        p_disp = f"p ={pval:.2g}"

    # Put correlation and p-value inside plot as annotation
    ax.annotate(f'r = {correlation:.3f}\n{p_disp}',
                xy=(0.05, 0.95), xycoords='axes fraction',
                ha='left', va='top', fontsize=15,
                bbox=dict(boxstyle="round,pad=0.3", fc="w", ec="gray", alpha=0.8))

    # Add grid
    ax.grid(True, linestyle='--', alpha=0.7)

    # Add legend
    ax.legend(fontsize=12)

    # Make axes equal for better comparison
    ax.set_aspect('equal', adjustable='box')

    plt.tight_layout()
    plt.show()

    return correlation

In [None]:
Spark_ASD_STR_Bias = pd.read_csv("../dat/Unionize_bias/Spark_Meta_EWS.Z2.bias.FDR.csv", index_col=0)
Spark_ASD_STR_Bias["Region"] = [Anno.get(ct_idx, "Unknown") for ct_idx in Spark_ASD_STR_Bias.index.values]

Spark_ASD_159_STR_Bias = pd.read_csv("/home/jw3514/Work/ASD_Circuits_CellType/results/STR_ISH/ASD_SPARK_159_bias_addP_sibling.csv", index_col=0)
Spark_ASD_159_STR_Bias["Region"] = [Anno.get(ct_idx, "Unknown") for ct_idx in Spark_ASD_159_STR_Bias.index.values]


## Satterstrom et al.

In [None]:
file_path = "/home/jw3514/Work/ASD_Circuits_CellType/dat/Genetics/ASC-102Genes.xlsx"
ASC_102_DF_Autosomal = pd.read_excel(file_path, sheet_name="Autosomal", skiprows=0)
ASC_102_DF_Xlinked = pd.read_excel(file_path, sheet_name="ChrX", skiprows=0)
ASC_102_DF = pd.read_excel(file_path, sheet_name="102_ASD", skiprows=0)

In [None]:
ASC_102_DF = ASC_102_DF[["gene", "entrez_id", "dn.ptv", "dn.misa", "dn.misb", "qval_dnccPTV"]]
ASC_102_DF = ASC_102_DF[ASC_102_DF["entrez_id"].notna()].copy()
ASC_102_DF["entrez_id"] = ASC_102_DF["entrez_id"].astype(int)

In [None]:
# Attach the per-gene gamma columns to the 102-gene table.
# Only the "Autosomal" sheet is merged; the X-linked variant is kept below,
# disabled, for reference. With a left join, any gene absent from the
# autosomal sheet ends up with NaN in all three gamma columns.
gamma_cols = ["gamma_dn.ptv", "gamma_dn.misa", "gamma_dn.misb"]
gamma_df_auto = ASC_102_DF_Autosomal[["gene"] + gamma_cols].copy()
#gamma_df_x = ASC_102_DF_Xlinked[["gene"] + gamma_cols].copy()

# Disabled: concatenate both sheets (genes appear in only one)
#gamma_df_combined = pd.concat([gamma_df_auto, gamma_df_x], ignore_index=True)

# Merge with ASC_102_DF to add gamma columns.
# Gamma columns left over from an earlier run of this cell are dropped first;
# otherwise the merge would rename them with _x/_y suffixes and the weight
# function below would no longer find "gamma_dn.*".
ASC_102_DF = ASC_102_DF.drop(columns=gamma_cols, errors="ignore").merge(
    gamma_df_auto,
    on="gene",
    how="left"
)


In [None]:
ASC_102_DF

In [None]:
# ASC 102-gene table: weight each gene by its de novo counts scaled by the
# matching gamma_dn.* column.
def GeneWeights_ASC_102(ASC_102_DF):
    """
    Build a {entrez_id: weight} dict from the ASC 102-gene table.

    For every row the weight is
        dn.ptv * gamma_dn.ptv + dn.misb * gamma_dn.misb + dn.misa * gamma_dn.misa
    and weights of rows sharing an entrez_id are summed.

    A row is skipped (with a printed message) when
      * its gene symbol is not a key of the module-level GeneSymbol2Entrez,
      * reading or combining its columns raises, or
      * the resulting weight is NaN (e.g. no gamma values were merged for it).

    Args:
        ASC_102_DF: DataFrame with columns gene, entrez_id, dn.ptv, dn.misa,
            dn.misb, gamma_dn.ptv, gamma_dn.misa, gamma_dn.misb.

    Returns:
        dict mapping entrez_id to the accumulated weight.
    """
    gene2MutN = {}
    for _, row in ASC_102_DF.iterrows():
        symbol = row["gene"]
        try:
            # The lookup result is not needed (the table's own entrez_id is the
            # key); it only filters out symbols missing from GeneSymbol2Entrez.
            _ = GeneSymbol2Entrez[symbol]
            entrez = row["entrez_id"]
            weight = (
                row["dn.ptv"] * row["gamma_dn.ptv"] +
                row["dn.misb"] * row["gamma_dn.misb"] +
                row["dn.misa"] * row["gamma_dn.misa"]
            )
        except Exception as e:
            print(f"Skipping gene {symbol} due to error: {e}")
            continue
        if pd.isna(weight):
            # A NaN weight would be stored and passed on to every consumer of
            # the returned dict, so drop the gene instead.
            print(f"Skipping gene {symbol} due to error: weight is NaN (missing gamma or mutation count)")
            continue
        gene2MutN[entrez] = gene2MutN.get(entrez, 0) + weight
    return gene2MutN

In [None]:
GW_ASC_102 = GeneWeights_ASC_102(ASC_102_DF)
Dict2Fil(GW_ASC_102, ProjDIR+"/dat/Genetics/GeneWeights/GW_ASC_102.gw")
ASC_102_STR_Bias = MouseSTR_AvgZ_Weighted(STR_BiasMat, GW_ASC_102)
ASC_102_STR_Bias["Region"] = [Anno.get(ct_idx, "Unknown") for ct_idx in ASC_102_STR_Bias.index.values]

In [None]:
# Zhou et al. 61-gene bias vs. Satterstrom et al. 102-gene bias.
# The title argument is accepted by the plotting function but not drawn.
plot_structure_bias_correlation(
    Spark_ASD_STR_Bias,
    ASC_102_STR_Bias,
    label_a='Mutation Bias \nZhou et al. 61 ASD genes',
    label_b='Mutation Bias \nSatterstrom et al. 102 ASD genes',
    title='Structure Bias Comparison: Spark vs ASC_102',
)

## Fu et al. 

In [None]:
file_path = "/home/jw3514/Work/ASD_Circuits_CellType/dat/Genetics/Fu_et_al_2022.xlsx"
fu_DF_SSCASC = pd.read_excel(file_path, sheet_name="Supplementary Table 5", skiprows=0)
fu_DF_SPARK = pd.read_excel(file_path, sheet_name="Supplementary Table 6", skiprows=0)
fu_DF_TADA_PR = pd.read_excel(file_path, sheet_name="Supplementary Table 8", skiprows=0)
fu_DF_Pval = pd.read_excel(file_path, sheet_name="Supplementary Table 11", skiprows=0)

fu_DF_SSCASC = fu_DF_SSCASC[fu_DF_SSCASC["gene_id"].notna()]
fu_DF_SPARK = fu_DF_SPARK[fu_DF_SPARK["gene_id"].notna()]
fu_DF_Pval = fu_DF_Pval[fu_DF_Pval["gene_id"].notna()]
fu_DF_TADA_PR = fu_DF_TADA_PR[fu_DF_TADA_PR["gene_id"].notna()]
fu_DF_TADA_PR.set_index("gene_gencodeV33", inplace=True)

In [None]:
Fu_ASD_72 = fu_DF_Pval[fu_DF_Pval["ASD72"]==1]
Fu_ASD_185 = fu_DF_Pval[fu_DF_Pval["ASD185"]==1]

In [None]:
# Fu et al. 2022 USE BF for each mutation type as weight
def GeneWeights_Fu2022(DF_Filt, DF_PR, mut_DFs):
    """
    Accumulate a per-gene weight over several mutation-count tables.

    Each table in mut_DFs is first restricted to genes whose gene_gencodeV33
    value also occurs in DF_Filt. For every remaining row the contribution
        dn.ptv * prior.dn.ptv + dn.misb * prior.dn.misb + dn.misa * prior.dn.misa
    is added to the gene's running total, with the prior.* values read from
    DF_PR by gene symbol.

    Args:
        DF_Filt: DataFrame whose gene_gencodeV33 column lists the genes to keep.
        DF_PR: DataFrame indexed by gene symbol with the prior.dn.* columns.
        mut_DFs: iterable of DataFrames with gene_gencodeV33, dn.ptv, dn.misa
            and dn.misb columns.

    Returns:
        dict keyed by GeneSymbol2Entrez[symbol] holding the summed weight.
        Rows that raise anywhere in the lookup or arithmetic are reported with
        a printed message and otherwise ignored.
    """
    weights_by_entrez = {}
    prior_columns = ("prior.dn.ptv", "prior.dn.misa", "prior.dn.misb")
    for table in mut_DFs:
        keep_mask = table["gene_gencodeV33"].isin(DF_Filt["gene_gencodeV33"])
        for _, record in table[keep_mask].iterrows():
            symbol = record["gene_gencodeV33"]
            try:
                entrez = GeneSymbol2Entrez[symbol]
                prior_ptv, prior_misa, prior_misb = [
                    DF_PR.loc[symbol, column] for column in prior_columns
                ]
                # Register the gene before adding, so it is present even if
                # the arithmetic below fails.
                weights_by_entrez.setdefault(entrez, 0)
                weights_by_entrez[entrez] += (
                    record["dn.ptv"] * prior_ptv +
                    record["dn.misb"] * prior_misb +
                    record["dn.misa"] * prior_misa
                )
            except Exception as err:
                print(f"Skipping gene {symbol} due to error: {err}")
    return weights_by_entrez

In [None]:
GW_Fu_ASD_72 = GeneWeights_Fu2022(Fu_ASD_72, fu_DF_TADA_PR, [fu_DF_SSCASC, fu_DF_SPARK])
GW_Fu_ASD_185 = GeneWeights_Fu2022(Fu_ASD_185, fu_DF_TADA_PR, [fu_DF_SSCASC, fu_DF_SPARK])
Dict2Fil(GW_Fu_ASD_72, ProjDIR+"/dat/Genetics/GeneWeights/GW_Fu_ASD_72.gw")
Dict2Fil(GW_Fu_ASD_185, ProjDIR+"/dat/Genetics/GeneWeights/GW_Fu_ASD_185.gw")

In [None]:
Fu_ASD_72_STR_Bias = MouseSTR_AvgZ_Weighted(STR_BiasMat, GW_Fu_ASD_72)
Fu_ASD_72_STR_Bias["Region"] = [Anno.get(ct_idx, "Unknown") for ct_idx in Fu_ASD_72_STR_Bias.index.values]

Fu_ASD_185_STR_Bias = MouseSTR_AvgZ_Weighted(STR_BiasMat, GW_Fu_ASD_185)
Fu_ASD_185_STR_Bias["Region"] = [Anno.get(ct_idx, "Unknown") for ct_idx in Fu_ASD_185_STR_Bias.index.values]

In [None]:
Fu_ASD_72_STR_Bias.head(10)

In [None]:
Fu_ASD_185_STR_Bias.head(10)

In [None]:
plot_structure_bias_correlation(Spark_ASD_STR_Bias, Spark_ASD_159_STR_Bias, label_a='Mutation Bias \nZhou et al. 61 ASD genes', label_b='Mutation Bias \nZhou et al. 159 ASD genes', title='Structure Bias Comparison: Spark vs Fu_ASD_72')
plot_structure_bias_correlation(Spark_ASD_STR_Bias, Fu_ASD_72_STR_Bias, label_a='Mutation Bias \nZhou et al. 61 ASD genes', label_b='Mutation Bias \nFu et al. 72 ASD genes', title='Structure Bias Comparison: Spark vs Fu_ASD_72')
plot_structure_bias_correlation(Spark_ASD_STR_Bias, Fu_ASD_185_STR_Bias, label_a='Mutation Bias \nZhou et al. 61 ASD genes', label_b='Mutation Bias \nFu et al. 185 ASD genes', title='Structure Bias Comparison: Spark vs Fu_ASD_185')

In [None]:
ASD_Neuron_den_norm_bias = pd.read_csv("../dat/Unionize_bias/ASD.neuron.density.norm.bias.csv", 
                                      index_col="STR")
ASD_Glia_norm_bias = pd.read_csv("../dat/Unionize_bias/ASD.neuro2glia.norm.bias.csv", 
                                      index_col="STR")

In [None]:
plot_structure_bias_correlation(Spark_ASD_STR_Bias, ASD_Neuron_den_norm_bias, label_a='Mutation Bias \nZhou et al. 61 ASD genes', label_b='Neuronal Density Normalized Bias\nZhou et al. 61 ASD genes', title='Structure Bias Comparison: Spark vs Fu_ASD_72')
plot_structure_bias_correlation(Spark_ASD_STR_Bias, ASD_Glia_norm_bias, label_a='Mutation Bias \nZhou et al. 61 ASD genes', label_b='Neuro-to-Glia Ratio Normalized Bias\nZhou et al. 61 ASD genes', title='Structure Bias Comparison: Spark vs Fu_ASD_185')

In [None]:
ASD_Male = pd.read_csv("../dat/Unionize_bias/ASD.Male.ALL.bias.csv", 
                                      index_col="STR")
ASD_Female = pd.read_csv("../dat/Unionize_bias/ASD.Female.ALL.bias.csv", 
                                      index_col="STR")

In [None]:


plot_structure_bias_correlation(ASD_Male, ASD_Female, label_a='Male Mutation Bias\nZhou et al. 61 ASD genes', label_b='Female Mutation Bias\nZhou et al. 61 ASD genes', title='Structure Bias Comparison: Spark vs Fu_ASD_72')
#plot_structure_bias_correlation(Spark_ASD_STR_Bias, ASD_Female, label_a='Mutation Bias \nZhou et al. 61 ASD genes', label_b='Neuro-to-Glia Ratio Normalized Bias\nZhou et al. 61 ASD genes', title='Structure Bias Comparison: Spark vs Fu_ASD_185')

In [None]:
ScoreMatDir="/home/jw3514/Work/ASD_Circuits/dat/allen-mouse-conn/ScoreingMat_jw_v3/"
IpsiInfoMat=pd.read_csv(ScoreMatDir + "InfoMat.Ipsi.csv", index_col=0)
IpsiInfoMatShort_v1=pd.read_csv(ScoreMatDir + "InfoMat.Ipsi.Short.3900.csv", index_col=0)
IpsiInfoMatLong_v1=pd.read_csv(ScoreMatDir + "InfoMat.Ipsi.Long.3900.csv", index_col=0)

DIR = "/home/jw3514/Work/ASD_Circuits/scripts/RankScores/"
Cont_Distance = np.load("{}/RankScore.Ipsi.Cont.npy".format(DIR))
Cont_DistanceShort = np.load("{}/RankScore.Ipsi.Short.3900.Cont.npy".format(DIR))
Cont_DistanceLong = np.load("{}/RankScore.Ipsi.Long.3900.Cont.npy".format(DIR))

In [None]:
from collections import defaultdict

def compute_circuit_scores_for_profiles(
    profile_bias_dict,
    topNs,
    info_mats_dict,
    scoring_func=ScoreCircuit_SI_Joint,
):
    """
    Score the top-ranked structures of every profile against every
    connectivity matrix.

    For each profile the structures are ordered by descending EFFECT; for each
    cutoff in topNs the leading `topN` structures are handed to `scoring_func`
    together with each information matrix.

    Args:
        profile_bias_dict: {profile_name: STR_Bias DataFrame with an EFFECT column}
        topNs: iterable of cutoffs (number of top structures to keep)
        info_mats_dict: {conn_type: info_mat pandas DataFrame}
        scoring_func: callable (top_str_list, info_mat) -> score

    Returns:
        {profile_name: {conn_type: np.array of scores, one per entry of topNs}}
    """
    all_scores = {}
    for profile_name, bias_df in profile_bias_dict.items():
        ranked_structures = bias_df.sort_values("EFFECT", ascending=False).index.values
        per_conn = {conn_type: [] for conn_type in info_mats_dict}
        for cutoff in topNs:
            leading = ranked_structures[:cutoff]
            for conn_type, info_mat in info_mats_dict.items():
                per_conn[conn_type].append(scoring_func(leading, info_mat))
        # Arrays are more convenient than lists for the plotting code downstream
        all_scores[profile_name] = {
            conn_type: np.array(values) for conn_type, values in per_conn.items()
        }
    return all_scores

def plot_circuit_connectivity_scores_multi(
    topNs,
    circuit_scores_results,
    cont_distance_dict,
    profile_plot_kwargs=None,
    show_siblings=True,
    profile_labels=None,
    xlim=(0, 121)
):
    """
    Plot circuit connectivity scores for multiple profiles, one panel per
    connection type.

    Args:
        topNs: x positions (structure-rank cutoffs), one per score
        circuit_scores_results: output of compute_circuit_scores_for_profiles,
            {profile_name: {conn_type: scores}}
        cont_distance_dict: {conn_type: np.array [n_iter, len(topNs)]} of
            control scores; its keys decide which panels are drawn and must
            also be keys of every profile's score dict
        profile_plot_kwargs: optional {profile_name: kwargs for ax.plot}.
            Profiles without an entry get an automatic colour/marker style.
            A custom entry without a "label" key receives the profile's label
            so the line still appears in the legend.
        show_siblings: draw the control centre line with its percentile band
        profile_labels: optional {profile_name: legend label}; defaults to the
            profile name
        xlim: x-axis limits applied to every panel

    Returns:
        The matplotlib Figure holding all panels.
    """
    conn_types = list(cont_distance_dict.keys())
    n_conn = len(conn_types)
    fig, axes = plt.subplots(n_conn, 1, dpi=480, figsize=(7, 4*n_conn))

    # plt.subplots returns a bare Axes (not an array) for a single panel
    if n_conn == 1:
        axes = [axes]

    # Half-width of the control band in percentile points:
    # 50 +/- 34.1 gives the 15.9th and 84.1th percentiles.
    BarLen = 34.1

    colors = ["blue", "red", "purple", "orange", "green", "black", "brown"]
    markers = ["o", "s", "^", "d", "x", "v"]
    if profile_plot_kwargs is None:
        profile_plot_kwargs = {}
    if profile_labels is None:
        profile_labels = {}

    for i, conn_type in enumerate(conn_types):
        ax = axes[i]
        controls = cont_distance_dict[conn_type]
        # Centre line: mean for "long*" connection types, median otherwise.
        # Both are NaN-aware, matching the nanpercentile band below.
        center_stat = np.nanmean if conn_type.lower().startswith("long") else np.nanmedian
        cont = center_stat(controls, axis=0)

        # Plot scores for all profiles
        for idx, (profile_name, prof_scores) in enumerate(circuit_scores_results.items()):
            scores = prof_scores[conn_type]
            label = profile_labels.get(profile_name, profile_name)
            # Copy so the caller's kwargs dict is never modified
            plot_args = dict(profile_plot_kwargs.get(profile_name) or {})
            if not plot_args:
                # Generate plot styles dynamically; the first profile is
                # emphasised with larger markers and a dashed line.
                plot_args = dict(
                    color=colors[idx % len(colors)],
                    marker=markers[idx % len(markers)],
                    markersize=5 if idx == 0 else 3,
                    lw=1,
                    ls='dashed' if idx == 0 else '-',
                    label=label
                )
            else:
                # Custom styling must not drop the profile from the legend
                plot_args.setdefault("label", label)
            ax.plot(topNs, scores, **plot_args)

        # Plot Sibling controls
        if show_siblings:
            lower = np.nanpercentile(controls, 50-BarLen, axis=0)
            upper = np.nanpercentile(controls, 50+BarLen, axis=0)
            ax.errorbar(
                topNs, cont, color="grey", marker="o", markersize=1.5, lw=1,
                yerr=(cont - lower, np.abs(upper - cont)),
                ls="dashed", label="Siblings"
            )

        ax.set_xlabel("Structure Rank\n", fontsize=17)
        ax.set_ylabel("Circuit Connectivity Score", fontsize=15)
        ax.legend(fontsize=13)
        ax.set_xlim(*xlim)
        ax.grid(True)
        ax.set_title(f'Connection Type: {conn_type}', fontsize=15)
    plt.tight_layout()
    return fig

# -- USAGE EXAMPLE/INSTANTIATION --

# Structure-rank cutoffs scanned from 200 down to 6 (the range stop, 5, is exclusive).
topNs = list(range(200, 5, -1))

# Define profiles to use (easily reusable/extendable!)
profiles = {
    "Fu_ASD_185": Fu_ASD_185_STR_Bias,
    "Fu_ASD_72": Fu_ASD_72_STR_Bias,
    "Spark 61": Spark_ASD_STR_Bias,
    "Spark 159": Spark_ASD_159_STR_Bias,
    # "Spark": Spark_ASD_STR_Bias, # add more as needed
}
# Connectivity information matrices, keyed by connection type.
info_mats = {
    "Standard": IpsiInfoMat,
    "Short": IpsiInfoMatShort_v1,
    "Long": IpsiInfoMatLong_v1,
}
# Control score arrays. The keys must match those of info_mats, because the
# plotting function looks up each profile's scores by these keys.
# NOTE(review): assumes each array has one column per entry of topNs — confirm.
cont_distance_dict = {
    "Standard": Cont_Distance,
    "Short": Cont_DistanceShort,
    "Long": Cont_DistanceLong,
}

# Compute all
circuit_scores = compute_circuit_scores_for_profiles(
    profiles, topNs, info_mats
)

# Plot. Only the two Fu profiles get custom legend labels; the others fall
# back to their dictionary keys.
fig = plot_circuit_connectivity_scores_multi(
    topNs,
    circuit_scores,
    cont_distance_dict,
    profile_labels={
        "Fu_ASD_185": "185 genes (Fu_ASD_185)",
        "Fu_ASD_72": "72 genes (Fu_ASD_72)",
    },
    xlim=(0, 121)
)
plt.show()

# Try SPARK gene lists of different sizes

In [None]:
# Spark_Denovo_Stage1 = pd.read_excel("../dat/Genetics/41588_2022_1148_MOESM4_ESM.xlsx",
#                            skiprows=2, sheet_name="Table S6")
# Spark_Denovo_Stage1 = Spark_Denovo_Stage1[Spark_Denovo_Stage1[
#     "pDenovoWEST"]!="."]
# Spark_Denovo_Stage1.shape

Spark_Denovo_Stage1 = pd.read_excel("~/Work/SPARK2020/TabS_DenovoWEST_Stage1.xlsx",
                           skiprows=1, sheet_name="AllGenes")
Spark_Denovo_Stage1 = Spark_Denovo_Stage1[Spark_Denovo_Stage1["pDenovoWEST"]!="."]
Spark_Denovo_Stage1.shape                       

In [None]:
Spark_Denovo_Stage1.columns

In [None]:
# Structure-bias profiles for progressively larger SPARK gene sets, computed
# once with each of the two weight dicts returned by SPARK_Gene_Weights.
Spark_Bias_DF_list = []
Spark_Bias_DF_list_Unif = []
# Gene-set sizes: 200..1900 in steps of 100, then 2000..4500 in steps of 500.
TopN_list = np.concatenate([np.arange(200, 2000, 100), np.arange(2000, 5000, 500)])  # Combine arrays properly

for topN in TopN_list:
    # Take the first topN rows, then drop the first 61 of them (rows 62..topN remain).
    # NOTE(review): assumes Spark_Denovo_Stage1 is already ordered by significance — confirm.
    SPARK_ASD_topN = Spark_Denovo_Stage1.head(topN).iloc[61:]
    GW_Unif, GW = SPARK_Gene_Weights(SPARK_ASD_topN)
    SPARK_ASD_topN_STR_Bias = MouseSTR_AvgZ_Weighted(STR_BiasMat, GW)
    SPARK_ASD_topN_STR_Bias_Unif = MouseSTR_AvgZ_Weighted(STR_BiasMat, GW_Unif)
    Spark_Bias_DF_list.append(SPARK_ASD_topN_STR_Bias)
    Spark_Bias_DF_list_Unif.append(SPARK_ASD_topN_STR_Bias_Unif)

In [None]:
Spark_Bias_DF_list[0]

In [None]:
import seaborn as sns
from scipy.stats import spearmanr, pearsonr

def plot_correlation_profile_together(
    TopN_list, 
    DF_list1, 
    DF_list2, 
    reference_df,
    label1="Weighted", 
    label2="Uniform",
    title="Correlation with Reference vs TopN",
    ylabel="Spearman r with Reference"
):
    """
    Draw two curves of Spearman correlation against a reference profile.

    Every DataFrame in DF_list1 and DF_list2 is reduced to the index labels it
    shares with reference_df, and the Spearman r between its EFFECT column and
    the reference EFFECT values (looked up in the same order) is computed.
    Both series are plotted against TopN_list on a single figure.
    """
    ref_effect = reference_df['EFFECT']
    ref_index = reference_df.index

    def _rho_against_reference(frame):
        # Keep only structures shared with the reference, then read the
        # reference in that same order so the two vectors are paired.
        shared = frame.loc[frame.index.intersection(ref_index)].copy()
        return spearmanr(shared['EFFECT'], ref_effect.loc[shared.index]).correlation

    rho_first = [_rho_against_reference(frame) for frame in DF_list1]
    rho_second = [_rho_against_reference(frame) for frame in DF_list2]

    plt.figure(figsize=(8, 5.5))
    sns.set(style="whitegrid", font_scale=1.25)
    curves = (
        (rho_first, 'o', label1, 'royalblue'),
        (rho_second, 's', label2, 'orange'),
    )
    for rho_values, marker, curve_label, curve_color in curves:
        sns.lineplot(x=TopN_list, y=rho_values, marker=marker, linewidth=2.5,
                     label=curve_label, color=curve_color)
    plt.xlabel("Number of Top ASD Genes (Excluding Top 61)", fontsize=14)
    plt.ylabel(ylabel, fontsize=14)
    plt.title(title, fontsize=16, weight='bold')
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.legend(loc='best')
    plt.tight_layout()
    plt.show()

# Plot correlation with reference (Spark_ASD_STR_Bias)
plot_correlation_profile_together(
    TopN_list,
    Spark_Bias_DF_list,
    Spark_Bias_DF_list_Unif,
    reference_df=Spark_ASD_STR_Bias,
    label1="Weighted",
    label2="Uniform",
    title="SPARK: Correlation (Spearman) with Top 61 Gene Structure Bias",
    ylabel="Spearman r with Top 61 Gene Structure Bias"
)

# Try ASC gene lists of different sizes

In [None]:
fu_DF_Pval = fu_DF_Pval.sort_values(by="p_TADA_ASD", ascending=True)
fu_DF_Pval.head(5)

In [None]:
# Structure-bias profiles for Fu et al. gene sets of increasing size, using
# the p_TADA_ASD-sorted fu_DF_Pval from the previous cell.
# NOTE(review): Bias_DF_list is reassigned in the random long-tail sampling
# cell further down, and the only plot call that used it is commented out.
Bias_DF_list = []
for topN in TopN_list:
    # This gets topN genes, but excludes the top 72 (i.e. gets genes ranked 73 to topN)
    Fu_ASD_topN = fu_DF_Pval.head(topN).iloc[72:]
    GW_Fu_ASD_topN = GeneWeights_Fu2022(Fu_ASD_topN, fu_DF_TADA_PR, [fu_DF_SSCASC, fu_DF_SPARK])
    Fu_ASD_topN_STR_Bias = MouseSTR_AvgZ_Weighted(STR_BiasMat, GW_Fu_ASD_topN)
    Bias_DF_list.append(Fu_ASD_topN_STR_Bias)

In [None]:
#plot_correlation_vs_topn(TopN_list, Bias_DF_list, Fu_ASD_72_STR_Bias, reference_label="Top 72 (ASC)", title="Correlation vs TopN_list (ASC)")

In [None]:
## Test SCZ
SCZ_GeneDF = pd.read_csv("/home/jw3514/Work/CellType_Psy/CellTypeBias_VIP/dat/SCZ.ALLGENE.MutCountModified.csv", index_col=0)

In [None]:
def Aggregate_Gene_Weights_SCZ_Daly(MutFil, allen_mouse_genes, usepLI=False, Bmis=False, out=None, mode="MC", 
                                  lgd_weight=0.33, mis3_weight=0.27, mis2_weight=0.12):
    """
    Build a {entrez_id: weight} dict from a per-gene mutation table.

    Args:
        MutFil: DataFrame indexed by Entrez ID, one row per gene. Columns read:
            nLGD / nMis3 / nMis2 (mode "MC" and the pLI branch),
            LGD_OR / Mis3_OR / Mis2_OR (mode "OR"), pLI (when usepLI is True).
        allen_mouse_genes: collection of Entrez IDs present in the expression
            data; genes outside it are skipped.
        usepLI: if True, `mode` and the *_weight arguments are not used.
            Fixed coefficients are applied instead, chosen by whether the
            gene's pLI is >= 0.5. A missing or non-numeric pLI counts as 0.
        Bmis: unused; kept so existing callers keep working.
        out: optional path. When given, (gene, weight) rows are written as CSV
            in order of descending weight.
        mode: "MC" (mutation counts) or "OR" (the *_OR columns). "ORMC" passes
            validation but has no implementation.
        lgd_weight, mis3_weight, mis2_weight: coefficients for the LGD, Mis3
            and Mis2 terms when usepLI is False.

    Returns:
        dict mapping int Entrez ID to weight.

    Raises:
        AssertionError: if mode is not one of "OR", "MC", "ORMC".
        NotImplementedError: if mode is "ORMC" and usepLI is False (previously
            this silently produced an empty dict).
    """
    # Needed only for the optional `out` file; imported here so the function
    # does not depend on what the notebook's star import happens to expose.
    import csv

    assert mode in ["OR", "MC", "ORMC"]
    print(mode)
    if not usepLI and mode == "ORMC":
        raise NotImplementedError('mode "ORMC" is accepted but not implemented')

    # Set membership is O(1); testing `in` against an array rescans it per row.
    expressed_genes = set(allen_mouse_genes)

    gene2MutN = {}
    for i, row in MutFil.iterrows():
        try:
            g = int(i)
        except (TypeError, ValueError):
            # Report the offending index value and move on. Falling through
            # would reuse the previous row's gene id and overwrite its weight.
            print(i, "Error converting Entrez ID")
            continue
        if g not in expressed_genes:
            print(g, "not in Expression dataset")
            continue
        if usepLI:
            try:
                pLI = float(row["pLI"])
            except (KeyError, TypeError, ValueError):
                print(g, "don't have pLI score on file, set to 0")
                pLI = 0.0
            if pLI >= 0.5:
                gene2MutN[g] = row["nLGD"] * 0.26 + row["nMis3"] * 0.25 + row["nMis2"] * 0.06
            else:
                gene2MutN[g] = row["nLGD"] * 0.01 + row["nMis3"] * 0.01 + row["nMis2"] * 0
        elif mode == "OR":
            gene2MutN[g] = row["LGD_OR"] * lgd_weight + row["Mis3_OR"] * mis3_weight + row["Mis2_OR"] * mis2_weight
        elif mode == "MC":
            gene2MutN[g] = row["nLGD"] * lgd_weight + row["nMis3"] * mis3_weight + row["nMis2"] * mis2_weight
    if out is not None:
        # The context manager closes (and flushes) the file; newline='' is the
        # csv module's documented way to open files for csv.writer.
        with open(out, 'wt', newline='') as handle:
            writer = csv.writer(handle)
            for k, v in sorted(gene2MutN.items(), key=lambda x: x[1], reverse=True):
                writer.writerow([k, v])
    return gene2MutN

In [None]:
TopGeneToTeset = 100
TopN_list2 = np.concatenate([np.arange(200, 1600, 100)])

In [None]:
SCZ_61GW = Aggregate_Gene_Weights_SCZ_Daly(SCZ_GeneDF.head(TopGeneToTeset), STR_BiasMat.index.values, mode="MC", mis2_weight=0)
SCZ_61_Bias = MouseSTR_AvgZ_Weighted(STR_BiasMat, SCZ_61GW)
SCZ_61_Bias.head(10)


In [None]:
# Structure-bias profiles for SCZ gene sets of increasing size.
SCZ_Bias_DF_list = []
for topN in TopN_list2:
    # First topN rows of SCZ_GeneDF with the leading TopGeneToTeset rows removed.
    SCZ_topN = SCZ_GeneDF.head(topN).iloc[TopGeneToTeset:]
    # NOTE(review): with usepLI=True the function takes its pLI branch, which
    # applies fixed coefficients and does not read mode or mis2_weight, whereas
    # the reference weights (SCZ_61GW) are built with usepLI=False. Confirm
    # that this difference is intended.
    GW_SCZ_topN = Aggregate_Gene_Weights_SCZ_Daly(SCZ_topN, STR_BiasMat.index.values, mode="MC", usepLI=True, mis2_weight=0)
    SCZ_topN_Bias = MouseSTR_AvgZ_Weighted(STR_BiasMat, GW_SCZ_topN)
    SCZ_Bias_DF_list.append(SCZ_topN_Bias)

In [None]:
# Rebuild the SPARK profiles on the TopN_list2 grid, this time excluding the
# leading TopGeneToTeset genes.
Spark_Bias_DF_list = []
# Reset the uniform-weight list as well. Without this the loop keeps appending
# to the list filled on the TopN_list grid, so the two lists stop matching and
# the uniform list grows every time the cell is re-run.
Spark_Bias_DF_list_Unif = []
for topN in TopN_list2:
    # First topN rows with the leading TopGeneToTeset rows removed.
    SPARK_ASD_topN = Spark_Denovo_Stage1.head(topN).iloc[TopGeneToTeset:]
    GW_Unif, GW = SPARK_Gene_Weights(SPARK_ASD_topN)
    SPARK_ASD_topN_STR_Bias = MouseSTR_AvgZ_Weighted(STR_BiasMat, GW)
    SPARK_ASD_topN_STR_Bias_Unif = MouseSTR_AvgZ_Weighted(STR_BiasMat, GW_Unif)
    Spark_Bias_DF_list.append(SPARK_ASD_topN_STR_Bias)
    Spark_Bias_DF_list_Unif.append(SPARK_ASD_topN_STR_Bias_Unif)

In [None]:
SPARK_ASD_topN = Spark_Denovo_Stage1.head(TopGeneToTeset)
GW_Unif, GW = SPARK_Gene_Weights(SPARK_ASD_topN)
SPARK_ASD_topN_STR_Bias = MouseSTR_AvgZ_Weighted(STR_BiasMat, GW)

In [None]:
# Plot ASD Spark vs SCZ: how well the bias profile of genes ranked beyond the
# top TopGeneToTeset correlates with the profile of the top TopGeneToTeset genes.
# NOTE(review): SPARK_ASD_topN_STR_Bias is reassigned to a top-20 profile in a
# later cell, so this cell gives the intended result only when the notebook is
# run in order.
reference_df = SPARK_ASD_topN_STR_Bias   
ref_effect = reference_df['EFFECT']
ref_index = reference_df.index

spearman_cors1 = []
pearson_cors1 = []
for df in Spark_Bias_DF_list:
    # Align indices before comparing
    df_matched = df.loc[df.index.intersection(ref_index)].copy()
    ref_matched = ref_effect.loc[df_matched.index]
    spearman_cors1.append(spearmanr(df_matched['EFFECT'], ref_matched).correlation)
    pearson_cors1.append(pearsonr(df_matched['EFFECT'], ref_matched)[0])

# Same comparison for SCZ. Despite the "61" in its name, SCZ_61_Bias is built
# from the top TopGeneToTeset rows of SCZ_GeneDF.
reference_df = SCZ_61_Bias   
ref_effect = reference_df['EFFECT']
ref_index = reference_df.index

spearman_cors2 = []
pearson_cors2 = []
for df in SCZ_Bias_DF_list:
    df_matched = df.loc[df.index.intersection(ref_index)].copy()
    ref_matched = ref_effect.loc[df_matched.index]
    spearman_cors2.append(spearmanr(df_matched['EFFECT'], ref_matched).correlation)
    pearson_cors2.append(pearsonr(df_matched['EFFECT'], ref_matched)[0])

plt.figure(figsize=(8, 5.5), dpi=300)
sns.set(style="whitegrid", font_scale=1.25)

# Plot both Spearman and Pearson on same figure (solid = Spearman, dashed = Pearson)
sns.lineplot(x=TopN_list2, y=spearman_cors1, marker='o', linewidth=2.5, label="ASD (Spearman)", color='royalblue')
sns.lineplot(x=TopN_list2, y=pearson_cors1, marker='o', linewidth=2.5, label="ASD (Pearson)", color='darkblue', linestyle='--')

sns.lineplot(x=TopN_list2, y=spearman_cors2, marker='s', linewidth=2.5, label="SCZ (Spearman)", color='orange')
sns.lineplot(x=TopN_list2, y=pearson_cors2, marker='s', linewidth=2.5, label="SCZ (Pearson)", color='darkorange', linestyle='--')

# The "100" in both axis labels is hard-coded; keep it in sync with TopGeneToTeset.
plt.xlabel("Number of Genes (Excluding Top 100 Genes)", fontsize=14)
plt.ylabel("StructureBias Correlation \nwith Top 100 Genes", fontsize=14)
#plt.title("Spearman and Pearson Correlations with Reference Gene Sets", fontsize=16, weight='bold')
plt.tick_params(labelsize=12)
plt.legend(loc='best')

plt.tight_layout()
plt.show()

# Top 20 plus long tail

In [None]:
# topN = 500 
# SPARK_ASD_topN = Spark_Denovo_Stage1.head(topN).iloc[20:]

In [None]:
SPARK_ASD_topN = Spark_Denovo_Stage1.head(20)
GW_Unif, top_20_GW = SPARK_Gene_Weights(SPARK_ASD_topN)
Dict2Fil(top_20_GW, ProjDIR+"/dat/Genetics/GeneWeights/Spark_top20.gw")
SPARK_ASD_topN_STR_Bias = MouseSTR_AvgZ_Weighted(STR_BiasMat, top_20_GW)

In [None]:
import os

# 1000 random gene sets: the fixed top-20 weights plus a random draw from the
# "long tail" (rows 21..500 of Spark_Denovo_Stage1). Each resulting bias table
# is kept in memory and also written to save_dir.
Bias_DF_list = []
# NOTE(review): the directory name says "Random40" but 41 genes are sampled below.
save_dir = "../results/Bootstrap_bias/Spark_top20_Random40/"
os.makedirs(save_dir, exist_ok=True)
random_seed = 42  # Set a seed for reproducibility
for i in range(1000):
    # Sample 41 rows from rows 21..500; the seed is offset by the iteration
    # number so every draw is different but repeatable.
    Random_longtail_ASD_topN = Spark_Denovo_Stage1.head(500).iloc[20:].sample(n=41, random_state=random_seed + i)
    _, GW_ASD_tmp = SPARK_Gene_Weights(Random_longtail_ASD_topN)
    # Merge the two weight dicts; on a shared key the sampled value replaces the top-20 one.
    combined_dict = {**top_20_GW, **GW_ASD_tmp}
    ASD_topN_STR_Bias = MouseSTR_AvgZ_Weighted(STR_BiasMat, combined_dict)

    Bias_DF_list.append(ASD_topN_STR_Bias)

    csv_path = os.path.join(save_dir, f"bias_df_{i}.csv")
    ASD_topN_STR_Bias.to_csv(csv_path)


In [None]:
ASD_topN_STR_Bias

In [None]:
# Spark_ASD_STR_Bias
# SPARK_ASD_topN_STR_Bias
# Bias_DF_list

# Compare structure correlation between Spark_ASD_STR_Bias and SPARK_ASD_topN_STR_Bias
# Ensure structure indices are aligned
print("Comparing Spark_ASD_STR_Bias vs SPARK_ASD_topN_STR_Bias")
print(f"Spark_ASD_STR_Bias structures: {len(Spark_ASD_STR_Bias)}")
print(f"SPARK_ASD_topN_STR_Bias structures: {len(SPARK_ASD_topN_STR_Bias)}")

# Find common structures
common_structures = Spark_ASD_STR_Bias.index.intersection(SPARK_ASD_topN_STR_Bias.index)
print(f"Common structures: {len(common_structures)}")

# Plot correlation with aligned indices
plot_structure_bias_correlation(
    Spark_ASD_STR_Bias.loc[common_structures], 
    SPARK_ASD_topN_STR_Bias.loc[common_structures], 
    label_a='Mutation Bias\nZhou et al. 61 ASD genes', 
    label_b='Mutation Bias\nTop 20 ASD genes', 
    title='Structure Bias Comparison: Spark_ASD_STR_Bias vs SPARK_ASD_topN_STR_Bias'
)

In [None]:
# Compare SPARK_ASD_topN_STR_Bias with each sampling from Bias_DF_list
# Ensure structure indices are aligned for correct bias correlation

from scipy.stats import pearsonr, spearmanr
import numpy as np

# Get common structures between SPARK_ASD_topN_STR_Bias and all DataFrames in Bias_DF_list
ref_index = SPARK_ASD_topN_STR_Bias.index
ref_effect = SPARK_ASD_topN_STR_Bias['EFFECT']

# Compute correlations for each bootstrap sample
pearson_cors = []
spearman_cors = []

print(f"Reference (SPARK_ASD_topN_STR_Bias) has {len(ref_index)} structures")
print(f"Computing correlations with {len(Bias_DF_list)} bootstrap samples...")

for i, df in enumerate(Bias_DF_list):
    # Align indices - find common structures
    common_idx = ref_index.intersection(df.index)
    
    if len(common_idx) == 0:
        print(f"Warning: No common structures found for bootstrap sample {i}")
        pearson_cors.append(np.nan)
        spearman_cors.append(np.nan)
        continue
    
    # Extract aligned EFFECT values
    ref_aligned = ref_effect.loc[common_idx]
    df_aligned = df.loc[common_idx, 'EFFECT']
    
    # Compute correlations
    pearson_r, pearson_p = pearsonr(ref_aligned, df_aligned)
    spearman_r, spearman_p = spearmanr(ref_aligned, df_aligned)
    
    pearson_cors.append(pearson_r)
    spearman_cors.append(spearman_r)
    
    if (i + 1) % 100 == 0:
        print(f"Processed {i + 1}/{len(Bias_DF_list)} samples")

pearson_cors = np.array(pearson_cors)
spearman_cors = np.array(spearman_cors)

print(f"\nSummary statistics:")
print(f"Pearson correlation - Mean: {np.nanmean(pearson_cors):.4f}, Std: {np.nanstd(pearson_cors):.4f}")
print(f"Spearman correlation - Mean: {np.nanmean(spearman_cors):.4f}, Std: {np.nanstd(spearman_cors):.4f}")


In [None]:
# Visualize the correlation distribution and print 95% CIs

import matplotlib.pyplot as plt
import seaborn as sns

# Calculate 95% confidence intervals for the correlations (ignoring NaNs)
def correlation_CI(data, alpha=0.05):
    """
    Percentile interval of `data` with NaN entries removed.

    Returns the (alpha/2, 1 - alpha/2) percentiles as a (lower, upper) tuple;
    with the default alpha these are the 2.5th and 97.5th percentiles.
    """
    observed = data[np.logical_not(np.isnan(data))]
    tail = alpha / 2
    bounds = (
        np.percentile(observed, 100 * tail),
        np.percentile(observed, 100 * (1 - tail)),
    )
    return bounds

pearson_ci_lower, pearson_ci_upper = correlation_CI(pearson_cors)
spearman_ci_lower, spearman_ci_upper = correlation_CI(spearman_cors)

print(f"Pearson correlation 95% CI: [{pearson_ci_lower:.4f}, {pearson_ci_upper:.4f}]")
print(f"Spearman correlation 95% CI: [{spearman_ci_lower:.4f}, {spearman_ci_upper:.4f}]")

fig, axes = plt.subplots(1, 2, figsize=(12, 5), dpi=120)

# Pearson correlation histogram
axes[0].hist(pearson_cors, bins=50, alpha=0.7, color='steelblue', edgecolor='black')
axes[0].axvline(np.nanmean(pearson_cors), color='red', linestyle='--', linewidth=2, 
                label=f'Mean: {np.nanmean(pearson_cors):.4f}')
axes[0].axvline(pearson_ci_lower, color='green', linestyle=':', linewidth=2, label=f'95% CI: [{pearson_ci_lower:.2f}, {pearson_ci_upper:.2f}]')
axes[0].axvline(pearson_ci_upper, color='green', linestyle=':', linewidth=2)
axes[0].set_xlabel('Pearson Correlation', fontsize=12)
axes[0].set_ylabel('Frequency', fontsize=12)
axes[0].set_title('Distribution of Pearson Correlations\nSPARK_ASD_topN_STR_Bias vs Bootstrap Samples', fontsize=13)
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Spearman correlation histogram
axes[1].hist(spearman_cors, bins=50, alpha=0.7, color='coral', edgecolor='black')
axes[1].axvline(np.nanmean(spearman_cors), color='red', linestyle='--', linewidth=2, 
                label=f'Mean: {np.nanmean(spearman_cors):.4f}')
axes[1].axvline(spearman_ci_lower, color='green', linestyle=':', linewidth=2, label=f'95% CI: [{spearman_ci_lower:.2f}, {spearman_ci_upper:.2f}]')
axes[1].axvline(spearman_ci_upper, color='green', linestyle=':', linewidth=2)
axes[1].set_xlabel('Spearman Correlation', fontsize=12)
axes[1].set_ylabel('Frequency', fontsize=12)
axes[1].set_title('Distribution of Spearman Correlations\nSPARK_ASD_topN_STR_Bias vs Bootstrap Samples', fontsize=13)
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


In [None]:
# Create scatter plots for a few example bootstrap samples to visualize alignment
# Select a few representative samples
sample_indices = [0, 100, 500, 999]  # First, middle, and last samples

fig, axes = plt.subplots(2, 2, figsize=(12, 12), dpi=120)
axes = axes.flatten()

for idx, sample_idx in enumerate(sample_indices):
    if sample_idx >= len(Bias_DF_list):
        continue
    
    df_sample = Bias_DF_list[sample_idx]
    
    # Align indices
    common_idx = ref_index.intersection(df_sample.index)
    ref_aligned = ref_effect.loc[common_idx]
    sample_aligned = df_sample.loc[common_idx, 'EFFECT']
    
    # Compute correlation
    pearson_r, pearson_p = pearsonr(ref_aligned, sample_aligned)
    
    # Plot
    ax = axes[idx]
    ax.scatter(ref_aligned, sample_aligned, alpha=0.6, s=30, edgecolors='black', linewidth=0.3)
    
    # Add diagonal line
    min_val = min(ref_aligned.min(), sample_aligned.min())
    max_val = max(ref_aligned.max(), sample_aligned.max())
    ax.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.8, linewidth=1.5)
    
    # Format p-value
    if pearson_p < 1e-10:
        p_disp = f"p <{1e-10:.0e}"
    else:
        p_disp = f"p ={pearson_p:.2g}"
    
    ax.annotate(f'r = {pearson_r:.3f}\n{p_disp}',
                xy=(0.05, 0.95), xycoords='axes fraction',
                ha='left', va='top', fontsize=11,
                bbox=dict(boxstyle="round,pad=0.3", fc="w", ec="gray", alpha=0.8))
    
    ax.set_xlabel('SPARK_ASD_topN_STR_Bias (EFFECT)', fontsize=11)
    ax.set_ylabel(f'Bootstrap Sample {sample_idx} (EFFECT)', fontsize=11)
    ax.set_title(f'Bootstrap Sample {sample_idx}\n{len(common_idx)} common structures', fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal', adjustable='box')

plt.tight_layout()
plt.show()
