In [None]:
%load_ext autoreload
%autoreload 2
import sys
import os
ProjDIR = "/home/jw3514/Work/ASD_Circuits_CellType/" # Change to your project directory
sys.path.insert(1, f'{ProjDIR}/src/')
from ASD_Circuits import *
import scipy.io as sio
from scipy.stats import spearmanr

try:
    os.chdir(f"{ProjDIR}/notebooks_mouse_str/")
    print(f"Current working directory: {os.getcwd()}")
except FileNotFoundError as e:
    print(f"Error: Could not change directory - {e}")
except Exception as e:
    print(f"Unexpected error: {e}")

HGNC, ENSID2Entrez, GeneSymbol2Entrez, Entrez2Symbol = LoadGeneINFO()

# Mouse fMRI data validation (16 mouse models, Figure 4)

In [None]:
fmri_clusters_xlsx = "/home/jw3514/Work/FuncConnectome/ASD_Mouse/Clusters_Images/Clusters_Values.xlsx"
# One row per region (index from the "Name" column, presumably region acronyms).
FMRI = pd.read_excel(fmri_clusters_xlsx, index_col="Name")

In [None]:
FMRI.head(n=10)

In [None]:
# Compare the four fMRI clusters with each other across regions:
# Spearman (rank) and Pearson (linear) correlation, drawn as heatmaps.
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

cluster_cols = ['Cluster1', 'Cluster2', 'Cluster3', 'Cluster4']
cluster_data = FMRI[cluster_cols]

spearman_corr = cluster_data.corr(method='spearman')
pearson_corr = cluster_data.corr(method='pearson')

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
for ax, corr, method_label in ((ax1, spearman_corr, 'Spearman'), (ax2, pearson_corr, 'Pearson')):
    # Fixed [-1, 1] colour range so the two panels share one scale.
    sns.heatmap(corr, annot=True, cmap='coolwarm', center=0, vmin=-1, vmax=1,
                square=True, linewidths=0.5, cbar_kws={"shrink": .8}, ax=ax)
    ax.set_title(f'{method_label} Correlation between Clusters 1-4')

plt.tight_layout()
plt.show()

print("Spearman Correlation Matrix:")
print(spearman_corr.round(3))
print("\nPearson Correlation Matrix:")
print(pearson_corr.round(3))

In [None]:
# Supplementary Table S1 (structure bias). The index holds full structure names;
# ABA acronyms are attached in a later cell so it can be joined with FMRI.
# Built from ProjDIR so changing the project root in the setup cell applies here too.
GENCIC = pd.read_excel(os.path.join(ProjDIR, "results", "SupTabs.v57.xlsx"),
                       sheet_name="Table-S1- Structure Bias", index_col=0)


In [None]:
GENCIC.head(n=2)

In [None]:
# Allen Brain Atlas ontology keyed ("KEY") by the same structure names as GENCIC's index.
aba_ontology_csv = "/home/jw3514/Work/ASD_Circuits/dat/Other/ontology.csv"
ABA_Ontology = pd.read_csv(aba_ontology_csv, index_col="KEY")
ABA_Ontology.head(n=3)

In [None]:
# Re-index GENCIC by ABA acronym (to match FMRI's index) and keep the full
# structure name in a "Structure" column. Guarded so that re-running the cell
# is a no-op; previously a second run failed once the index was already acronyms.
if GENCIC.index.name != "acronym":
    # First occurrence wins if the ontology repeats a key (a plain .loc lookup
    # would return a Series there and break the assignment).
    _acronym_map = ABA_Ontology.loc[~ABA_Ontology.index.duplicated(keep="first"), "acronym"]
    for _str in GENCIC.index[~GENCIC.index.isin(_acronym_map.index)]:
        print(f"{_str} not in ABA_Ontology")
    # Structures missing from the ontology get a NaN acronym, as before.
    GENCIC["acronym"] = _acronym_map.reindex(GENCIC.index)
    GENCIC["Structure"] = GENCIC.index
    GENCIC = GENCIC.set_index("acronym")

In [None]:
GENCIC.head(n=5)

In [None]:
from scipy.stats import spearmanr, pearsonr

# The shared regions do not depend on the cluster, so look them up once.
common_acronyms = GENCIC.index.intersection(FMRI.index)
n_common = len(common_acronyms)

for cluster in ['Cluster1', 'Cluster2', 'Cluster3', 'Cluster4']:
    if n_common == 0:
        print(f"{cluster}: No common acronyms found between GENCIC and FMRI data")
        print()
        continue

    # Align GENCIC bias and this cluster's values on the shared regions.
    gencic_bias = GENCIC.loc[common_acronyms, 'Bias']
    fmri_cluster = FMRI.loc[common_acronyms, cluster]

    spearman_correlation, spearman_p_value = spearmanr(gencic_bias, fmri_cluster)
    pearson_correlation, pearson_p_value = pearsonr(gencic_bias, fmri_cluster)

    print(f"{cluster}:")
    print(f"  Number of common regions: {n_common}")
    print(f"  Spearman correlation: {spearman_correlation:.4f} (p = {spearman_p_value:.4f})")
    print(f"  Pearson correlation: {pearson_correlation:.4f} (p = {pearson_p_value:.4f})")
    print()


In [None]:
from scipy.stats import hypergeom

topN = 50

# Regions present in both datasets form the hypergeometric population.
common_acronyms = GENCIC.index.intersection(FMRI.index)
n_common = len(common_acronyms)

if n_common > 0:
    # The GENCIC top set does not depend on the cluster, so compute it once.
    gencic_topN = set(GENCIC.loc[common_acronyms].nlargest(topN, 'Bias').index)
    fmri_common = FMRI.loc[common_acronyms]

for cluster in ['Cluster1', 'Cluster2', 'Cluster3', 'Cluster4']:
    if n_common == 0:
        print(f"{cluster}: No common acronyms found between GENCIC and FMRI data")
        print()
        continue

    # Top / bottom N regions for this fMRI cluster.
    fmri_topN_largest = set(fmri_common.nlargest(topN, cluster).index)
    fmri_topN_smallest = set(fmri_common.nsmallest(topN, cluster).index)

    overlap_largest = len(gencic_topN & fmri_topN_largest)
    overlap_smallest = len(gencic_topN & fmri_topN_smallest)

    # Hypergeometric P(X >= overlap): population = common regions,
    # successes = fMRI set, draws = GENCIC set. Use the ACTUAL set sizes:
    # nlargest/nsmallest return fewer than topN when there are fewer common
    # regions than topN or when values are NaN.
    pval_largest = hypergeom.sf(overlap_largest - 1, n_common,
                                len(fmri_topN_largest), len(gencic_topN))
    pval_smallest = hypergeom.sf(overlap_smallest - 1, n_common,
                                 len(fmri_topN_smallest), len(gencic_topN))

    print(f"{cluster}:")
    print(f"  Number of common regions: {n_common}")
    print(f"  GENCIC top {topN} overlap with fMRI top {topN} (largest): {overlap_largest} (p = {pval_largest:.4f})")
    print(f"  GENCIC top {topN} overlap with fMRI top {topN} (smallest): {overlap_smallest} (p = {pval_smallest:.4f})")
    print()

In [None]:
# Add average and count columns to FMRI dataframe
# Per-region mean across the four cluster columns (pandas skips NaN).
FMRI['Average_Clusters'] = FMRI.loc[:, ['Cluster1', 'Cluster2', 'Cluster3', 'Cluster4']].mean(axis='columns')
FMRI.head(n=5)

In [None]:
# Spearman correlation of GENCIC bias with each cluster and with the cluster average.
common_acronyms = GENCIC.index.intersection(FMRI.index)
n_common = len(common_acronyms)

for cluster in ['Cluster1', 'Cluster2', 'Cluster3', 'Cluster4', 'Average_Clusters']:
    if not n_common > 0:
        print(f"{cluster}: No common acronyms found between GENCIC and FMRI data")
        print()
        continue

    gencic_bias = GENCIC.loc[common_acronyms, 'Bias']
    fmri_cluster = FMRI.loc[common_acronyms, cluster]
    correlation, p_value = spearmanr(gencic_bias, fmri_cluster)

    print(f"{cluster}:")
    print(f"  Number of common regions: {n_common}")
    print(f"  Spearman correlation: {correlation:.4f}")
    print(f"  P-value: {p_value:.8f}")
    print()

In [None]:
# Scatter of GENCIC bias vs. the mean fMRI cluster value, with a least-squares line.
common_acronyms = GENCIC.index.intersection(FMRI.index)

if len(common_acronyms) == 0:
    print("No common acronyms found between GENCIC and FMRI data")
else:
    gencic_bias = GENCIC.loc[common_acronyms, 'Bias']
    fmri_average_clusters = FMRI.loc[common_acronyms, 'Average_Clusters']

    spearman_correlation, spearman_p_value = spearmanr(gencic_bias, fmri_average_clusters)
    pearson_correlation, pearson_p_value = pearsonr(gencic_bias, fmri_average_clusters)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(gencic_bias, fmri_average_clusters, alpha=0.6, s=50)

    # Degree-1 polynomial fit -> straight trend line.
    trend = np.poly1d(np.polyfit(gencic_bias, fmri_average_clusters, 1))
    ax.plot(gencic_bias, trend(gencic_bias), "r--", alpha=0.8)

    ax.set_xlabel('GENCIC Bias')
    ax.set_ylabel('fMRI Average Clusters')
    ax.set_title(f'GENCIC Bias vs fMRI Average Clusters\nSpearman r = {spearman_correlation:.4f}, p = {spearman_p_value:.4f}\nPearson r = {pearson_correlation:.4f}, p = {pearson_p_value:.4f}')
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()

    print(f"Number of common regions: {len(common_acronyms)}")
    print(f"Spearman correlation: {spearman_correlation:.4f}")
    print(f"Spearman P-value: {spearman_p_value:.8f}")
    print(f"Pearson correlation: {pearson_correlation:.4f}")
    print(f"Pearson P-value: {pearson_p_value:.8f}")

In [None]:
# Number of clusters (out of 4) in which a region's value is strictly below Cut.
Cut = 0
FMRI['Count_Below_Threshold'] = FMRI[['Cluster1', 'Cluster2', 'Cluster3', 'Cluster4']].lt(Cut).sum(axis=1)

In [None]:
# Create new 'Area' column based on SubArea if valid, otherwise MacroArea
# 'Area' = SubArea when present and non-empty, otherwise fall back to MacroArea.
_subarea_missing = FMRI['SubArea'].isna() | (FMRI['SubArea'] == '')
FMRI['Area'] = FMRI['SubArea'].mask(_subarea_missing, FMRI['MacroArea'])

In [None]:
FMRI.loc[FMRI["Count_Below_Threshold"].ge(3)].head(60)

In [None]:
# Aggregate Count_Below_Threshold at MacroArea level
# Per Area: distribution of Count_Below_Threshold (mean / std / number of regions).
macro_area_aggregation = FMRI.groupby('Area')['Count_Below_Threshold'].agg(['mean', 'std', 'count']).reset_index()

cluster_names = ['Cluster1', 'Cluster2', 'Cluster3', 'Cluster4']
# Per Area and per cluster: number of regions whose value is below Cut.
cluster_counts = FMRI.groupby('Area')[cluster_names].apply(lambda x: (x < Cut).sum()).reset_index()

# Stacked bar chart: one bar per Area, one coloured segment per cluster.
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']  # Blue, Orange, Green, Red
width = 0.6
x = np.arange(len(cluster_counts))
bottom = np.zeros(len(cluster_counts))

fig, ax = plt.subplots(figsize=(12, 6))
for color, cluster in zip(colors, cluster_names):
    cluster_values = cluster_counts[cluster].to_numpy()
    ax.bar(x, cluster_values, width, bottom=bottom, label=cluster, color=color, alpha=0.8)
    bottom = bottom + cluster_values

ax.set_xlabel('Area')  # previously the typo 'MaArearoArea'
ax.set_ylabel('Total Count Below Threshold')
ax.set_title('Total Count Below Threshold by Area (Colored by 4 Clusters)')
ax.set_xticks(x)
ax.set_xticklabels(cluster_counts['Area'], rotation=45, ha='right')
ax.legend()
fig.tight_layout()
plt.show()

print("Count_Below_Threshold aggregated by Area:")
print(macro_area_aggregation)
print("\nTotal counts for each cluster below threshold by Area:")
print(cluster_counts)

In [None]:
# Overlap between the fMRI test set (regions below threshold in >= Number_gt of
# the 4 clusters) and the GENCIC Circuits.32 structures, with a hypergeometric
# test on the universe of regions present in both datasets.
Number_gt = 3
circuit_col = "Circuits.32"
test_set = FMRI[FMRI["Count_Below_Threshold"]>=Number_gt]
test_set_STRs = test_set.index.tolist()
GENCIC_STRs = GENCIC[GENCIC[circuit_col]==1].index.tolist()

# The overlap is automatically inside the common universe (it is in both indexes).
overlap_STRs = set(test_set_STRs).intersection(set(GENCIC_STRs))
overlap_count = len(overlap_STRs)

fMRI_size = len(FMRI.index)  # 163 regions
GENCIC_size = len(GENCIC.index)  # 213 regions

test_set_count = len(test_set_STRs)
GENCIC_count = len(GENCIC_STRs)

# The two subsets come from different region lists, so the universe is the
# set of regions present in BOTH datasets, and set sizes are restricted to it.
common_regions = set(FMRI.index).intersection(set(GENCIC.index))
universe_size = len(common_regions)

test_set_in_universe = set(test_set_STRs).intersection(common_regions)
GENCIC_in_universe = set(GENCIC_STRs).intersection(common_regions)
test_set_count_adj = len(test_set_in_universe)
GENCIC_count_adj = len(GENCIC_in_universe)

from scipy.stats import hypergeom
# P(X >= overlap_count) = hypergeom.sf(k - 1, N, K, n) with
# N = universe_size, K = GENCIC_count_adj, n = test_set_count_adj.
p_value = hypergeom.sf(overlap_count - 1, universe_size, GENCIC_count_adj, test_set_count_adj)

print(f"fMRI dataset size: {fMRI_size} regions")
print(f"GENCIC dataset size: {GENCIC_size} regions")
print(f"Common regions (universe): {universe_size} regions")
print(f"Test set (Count_Below_Threshold >= {Number_gt}): {test_set_count} regions ({test_set_count_adj} in universe)")
# Label now reports the circuit actually tested (it used to say Circuits.46).
print(f"GENCIC {circuit_col} = 1: {GENCIC_count} regions ({GENCIC_count_adj} in universe)")
print(f"Overlap: {overlap_count} regions")
print(f"Overlap regions: {list(overlap_STRs)}")
print(f"Hypergeometric test p-value: {p_value:.6f}")

expected_overlap = (test_set_count_adj * GENCIC_count_adj) / universe_size
print(f"Expected overlap under null hypothesis: {expected_overlap:.2f}")

overlap_percentage_test = (overlap_count / test_set_count_adj) * 100 if test_set_count_adj > 0 else 0
overlap_percentage_GENCIC = (overlap_count / GENCIC_count_adj) * 100 if GENCIC_count_adj > 0 else 0

print(f"Overlap as % of fMRI set: {overlap_percentage_test:.2f}%")
print(f"Overlap as % of GENCIC set: {overlap_percentage_GENCIC:.2f}%")

In [None]:
# Permutation check of the Circuits.32 overlap computed in the previous cell.
# Sizes come from that cell's variables; they were previously hard-coded as
# ranges 163 / 213, set sizes 67 / 32 and observed overlap 14.
# NOTE: permutation_test_overlap / plot_permutation_results are defined in a
# later cell; on a fresh kernel run that cell first unless ASD_Circuits
# already exports them.
# NOTE(review): drawing integer labels from two ranges of different length
# implicitly treats labels 1..min(len) as the shared regions -- confirm this
# null matches the intended universe (the hypergeometric test uses common regions).
SET1 = np.arange(1, fMRI_size + 1)
SET2 = np.arange(1, GENCIC_size + 1)
set1_size = test_set_count
set2_size = GENCIC_count
observed_overlap = overlap_count

results = permutation_test_overlap(SET1, SET2, set1_size, set2_size, observed_overlap, n_permutations=10000)

print(f"Mean intersection length: {results['mean_intersection']:.2f}")
print(f"Std intersection length: {results['std_intersection']:.2f}")
print(f"Min intersection length: {results['min_intersection']}")
print(f"Max intersection length: {results['max_intersection']}")
print(f"P-value for overlap >= {results['observed_overlap']}: {results['p_value']:.6f}")
print(f"Number of permutations with overlap >= {results['observed_overlap']}: {results['n_significant']}")

plot_permutation_results(results)

In [None]:
# Same overlap analysis as the Circuits.32 cell, against the Circuits.46 structures.
Number_gt = 3
circuit_col = "Circuits.46"
test_set = FMRI[FMRI["Count_Below_Threshold"]>=Number_gt]
test_set_STRs = test_set.index.tolist()
GENCIC_STRs = GENCIC[GENCIC[circuit_col]==1].index.tolist()

overlap_STRs = set(test_set_STRs).intersection(set(GENCIC_STRs))
overlap_count = len(overlap_STRs)

fMRI_size = len(FMRI.index)  # 163 regions
GENCIC_size = len(GENCIC.index)  # 213 regions

test_set_count = len(test_set_STRs)
GENCIC_count = len(GENCIC_STRs)

# Universe = regions present in both datasets.
common_regions = set(FMRI.index).intersection(set(GENCIC.index))
universe_size = len(common_regions)

test_set_in_universe = set(test_set_STRs).intersection(common_regions)
GENCIC_in_universe = set(GENCIC_STRs).intersection(common_regions)
test_set_count_adj = len(test_set_in_universe)
GENCIC_count_adj = len(GENCIC_in_universe)

from scipy.stats import hypergeom
# P(X >= overlap_count); previously computed but never reported.
p_value = hypergeom.sf(overlap_count - 1, universe_size, GENCIC_count_adj, test_set_count_adj)

print(f"fMRI dataset size: {fMRI_size} regions")
print(f"GENCIC dataset size: {GENCIC_size} regions")
print(f"Common regions (universe): {universe_size} regions")
print(f"Test set (Count_Below_Threshold >= {Number_gt}): {test_set_count} regions ({test_set_count_adj} in universe)")
print(f"GENCIC {circuit_col} = 1: {GENCIC_count} regions ({GENCIC_count_adj} in universe)")
print(f"Overlap: {overlap_count} regions")
print(f"Overlap regions: {list(overlap_STRs)}")
print(f"Hypergeometric test p-value: {p_value:.6f}")

# Permutation check with the sizes computed above (previously hard-coded
# 163 / 213 / 67 / 46 / 21).
# NOTE: permutation_test_overlap / plot_permutation_results are defined in a
# later cell; on a fresh kernel run that cell first unless ASD_Circuits
# already exports them.
SET1 = np.arange(1, fMRI_size + 1)
SET2 = np.arange(1, GENCIC_size + 1)
set1_size = test_set_count
set2_size = GENCIC_count
observed_overlap = overlap_count
results = permutation_test_overlap(SET1, SET2, set1_size, set2_size, observed_overlap, n_permutations=10000)
print(f"Mean intersection length: {results['mean_intersection']:.2f}")
print(f"P-value for overlap >= {results['observed_overlap']}: {results['p_value']:.6f}")

plot_permutation_results(results)

In [None]:
# Preview only; rendering all ~213 rows bloats the notebook. The next cell
# filters on the REGION column shown here.
GENCIC.head(10)

In [None]:
# Sensitivity check: repeat the Circuits.46 overlap test with one macro-region
# removed from BOTH datasets.
Number_gt = 3
Region2Exclude = "Thalamus"

FMRI_filt = FMRI[FMRI["MacroArea"] != Region2Exclude]
GENCIC_filt = GENCIC[GENCIC["REGION"] != Region2Exclude]
test_set = FMRI_filt[FMRI_filt["Count_Below_Threshold"] >= Number_gt]
test_set_STRs = test_set.index.tolist()
GENCIC_STRs = GENCIC_filt[GENCIC_filt["Circuits.46"] == 1].index.tolist()

overlap_STRs = set(test_set_STRs).intersection(set(GENCIC_STRs))
overlap_count = len(overlap_STRs)

# Dataset sizes AFTER the exclusion.
fMRI_size = len(FMRI_filt.index)
GENCIC_size = len(GENCIC_filt.index)
test_set_count = len(test_set_STRs)
GENCIC_count = len(GENCIC_STRs)

# Universe = regions present in both FILTERED datasets. Keeping the excluded
# regions would inflate N, because they can never be in either test set, and
# make the hypergeometric p-value anti-conservative.
common_regions = set(FMRI_filt.index).intersection(set(GENCIC_filt.index))
universe_size = len(common_regions)

test_set_in_universe = set(test_set_STRs).intersection(common_regions)
GENCIC_in_universe = set(GENCIC_STRs).intersection(common_regions)
test_set_count_adj = len(test_set_in_universe)
GENCIC_count_adj = len(GENCIC_in_universe)

from scipy.stats import hypergeom
p_value = hypergeom.sf(overlap_count - 1, universe_size, GENCIC_count_adj, test_set_count_adj)

print(f"Excluded macro-region: {Region2Exclude}")
print(f"fMRI dataset size: {fMRI_size} regions")
print(f"GENCIC dataset size: {GENCIC_size} regions")
print(f"Common regions (universe): {universe_size} regions")
print(f"Test set (Count_Below_Threshold >= {Number_gt}): {test_set_count} regions ({test_set_count_adj} in universe)")
print(f"GENCIC Circuits.46 = 1: {GENCIC_count} regions ({GENCIC_count_adj} in universe)")
print(f"Overlap: {overlap_count} regions")
print(f"Overlap regions: {list(overlap_STRs)}")
print(f"Hypergeometric test p-value: {p_value:.6f}")

# Permutation check with the filtered sizes. Previously hard-coded as
# 53 / 38 / 19, drawn from the UNFILTERED 163 / 213 ranges.
# NOTE: permutation_test_overlap / plot_permutation_results are defined in a
# later cell; on a fresh kernel run that cell first unless ASD_Circuits
# already exports them.
SET1 = np.arange(1, fMRI_size + 1)
SET2 = np.arange(1, GENCIC_size + 1)
set1_size = test_set_count
set2_size = GENCIC_count
observed_overlap = overlap_count

results = permutation_test_overlap(SET1, SET2, set1_size, set2_size, observed_overlap, n_permutations=10000)
print(f"Mean intersection length: {results['mean_intersection']:.2f}")
print(f"P-value for overlap >= {results['observed_overlap']}: {results['p_value']:.6f}")

plot_permutation_results(results)

In [None]:
def permutation_test_overlap(set1_range, set2_range, set1_size, set2_size, observed_overlap,
                             n_permutations=10000, seed=None):
    """
    Perform a permutation test to assess the significance of overlap between two sets.

    In each permutation, ``set1_size`` values are drawn without replacement from
    ``set1_range`` and ``set2_size`` from ``set2_range``, and the size of their
    intersection is recorded.

    Parameters:
    -----------
    set1_range : array-like
        Range of possible values for set 1 (e.g., np.arange(1, 164))
    set2_range : array-like
        Range of possible values for set 2 (e.g., np.arange(1, 214))
    set1_size : int
        Size of set 1 sample
    set2_size : int
        Size of set 2 sample
    observed_overlap : int
        The observed overlap to test against
    n_permutations : int, default=10000
        Number of permutations to perform (must be >= 1)
    seed : int or None, default=None
        If given, draws come from ``np.random.default_rng(seed)`` (reproducible).
        ``None`` keeps using the global ``np.random`` state, as before.

    Returns:
    --------
    dict : Dictionary containing test results, with keys 'intersections',
        'mean_intersection', 'std_intersection', 'min_intersection',
        'max_intersection', 'observed_overlap', 'p_value', 'n_significant',
        'n_permutations'. ``p_value`` is (n_significant + 1) / (n_permutations + 1),
        so a permutation p-value is never reported as exactly 0.

    Raises:
    -------
    ValueError
        If ``n_permutations`` < 1, or (from ``choice``) if a sample size exceeds its range.
    """
    if n_permutations < 1:
        raise ValueError("n_permutations must be >= 1")

    rng = np.random if seed is None else np.random.default_rng(seed)

    # All set-1 draws first, then all set-2 draws (same order as before, so
    # np.random.seed(...) reproduces earlier runs when seed is None).
    set1_samples = [rng.choice(set1_range, size=set1_size, replace=False) for _ in range(n_permutations)]
    set2_samples = [rng.choice(set2_range, size=set2_size, replace=False) for _ in range(n_permutations)]

    # Per-permutation intersection size (a Python loop, not vectorized).
    intersections = np.array([np.intersect1d(s1, s2).size
                              for s1, s2 in zip(set1_samples, set2_samples)])

    n_significant = int(np.sum(intersections >= observed_overlap))
    # Add-one estimator: counts the observed data as one permutation.
    p_value = (n_significant + 1) / (n_permutations + 1)

    return {
        'intersections': intersections,
        'mean_intersection': np.mean(intersections),
        'std_intersection': np.std(intersections),
        'min_intersection': np.min(intersections),
        'max_intersection': np.max(intersections),
        'observed_overlap': observed_overlap,
        'p_value': p_value,
        'n_significant': n_significant,
        'n_permutations': n_permutations
    }

def plot_permutation_results(results):
    """
    Plot the null distribution from a permutation test.

    Draws a normalized histogram of ``results['intersections']`` with one bar
    centred on each integer overlap value. Vertical lines mark the observed
    overlap (red) and the permutation mean (blue).

    Parameters:
    -----------
    results : dict
        Results dictionary from permutation_test_overlap; reads the keys
        'intersections', 'observed_overlap', 'mean_intersection', 'n_permutations'.
    """
    intersections = results['intersections']
    observed_overlap = results['observed_overlap']
    mean_intersection = results['mean_intersection']
    n_permutations = results['n_permutations']

    # Edges at k - 0.5 .. k + 0.5 so each bar is centred on its integer value.
    # With edges at the integers, every bar sat half a unit to the right of the
    # value it counts, misplacing it relative to the vertical lines.
    lo, hi = int(np.min(intersections)), int(np.max(intersections))
    bins = np.arange(lo - 0.5, hi + 1.5, 1.0)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(intersections, bins=bins, alpha=0.7, edgecolor='black', density=True)
    ax.axvline(observed_overlap, color='red', linestyle='--', linewidth=2,
               label=f'Observed overlap = {observed_overlap}')
    ax.axvline(mean_intersection, color='blue', linestyle='--', linewidth=2,
               label=f'Mean = {mean_intersection:.1f}')
    ax.set_xlabel('Intersection Length')
    ax.set_ylabel('Probability Density')
    ax.set_title(f'Distribution of Intersections from {n_permutations} Permutations')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()

# Stand-alone rerun with hard-coded numbers. They match the unfiltered
# Circuits.46 comparison: ranges 163 / 213, 67 fMRI regions, 46 circuit
# structures, 21 shared.
SET1, SET2 = np.arange(1, 164), np.arange(1, 214)
set1_size, set2_size, observed_overlap = 67, 46, 21

results = permutation_test_overlap(SET1, SET2, set1_size, set2_size, observed_overlap, n_permutations=10000)

summary_lines = [
    f"Mean intersection length: {results['mean_intersection']:.2f}",
    f"Std intersection length: {results['std_intersection']:.2f}",
    f"Min intersection length: {results['min_intersection']}",
    f"Max intersection length: {results['max_intersection']}",
    f"P-value for overlap >= {results['observed_overlap']}: {results['p_value']:.6f}",
    f"Number of permutations with overlap >= {results['observed_overlap']}: {results['n_significant']}",
]
print("\n".join(summary_lines))

plot_permutation_results(results)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Same hard-coded permutation test as the previous cell, followed by a simple
# two-set Venn sketch of the region counts.
SET1, SET2 = np.arange(1, 164), np.arange(1, 214)
set1_size, set2_size, observed_overlap = 67, 46, 21

results = permutation_test_overlap(SET1, SET2, set1_size, set2_size, observed_overlap, n_permutations=10000)

summary_lines = [
    f"Mean intersection length: {results['mean_intersection']:.2f}",
    f"Std intersection length: {results['std_intersection']:.2f}",
    f"Min intersection length: {results['min_intersection']}",
    f"Max intersection length: {results['max_intersection']}",
    f"P-value for overlap >= {results['observed_overlap']}: {results['p_value']:.6f}",
    f"Number of permutations with overlap >= {results['observed_overlap']}: {results['n_significant']}",
]
print("\n".join(summary_lines))

plot_permutation_results(results)


def _draw_two_set_venn(ax, only_first, only_second, shared):
    """Draw two overlapping circles annotated with exclusive and shared counts."""
    ax.add_patch(patches.Circle((0.35, 0.5), 0.3, alpha=0.5, color='blue', label='Set 1'))
    ax.add_patch(patches.Circle((0.65, 0.5), 0.3, alpha=0.5, color='red', label='Set 2'))
    for x_pos, count, extra in ((0.2, only_first, {}),
                                (0.8, only_second, {}),
                                (0.5, shared, {'weight': 'bold'})):
        ax.text(x_pos, 0.5, f'{count}', fontsize=14, ha='center', va='center', **extra)
    for x_pos, name in ((0.2, 'Set 1'), (0.8, 'Set 2')):
        ax.text(x_pos, 0.2, name, fontsize=12, ha='center', va='center', weight='bold')


fig, ax = plt.subplots(figsize=(8, 6))
_draw_two_set_venn(ax, set1_size - observed_overlap, set2_size - observed_overlap, observed_overlap)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_aspect('equal')
ax.axis('off')
ax.set_title(f'Venn Diagram of Set Overlap\nObserved Overlap: {observed_overlap}', fontsize=14)
plt.show()

In [None]:
# Variant where both sets are drawn from the same 163-region pool
# (sizes 67 and 43; the origin of 43 is not shown in this notebook -- TODO document).
SET1, SET2 = np.arange(1, 164), np.arange(1, 164)
set1_size, set2_size, observed_overlap = 67, 43, 21

results = permutation_test_overlap(SET1, SET2, set1_size, set2_size, observed_overlap, n_permutations=10000)

summary_lines = [
    f"Mean intersection length: {results['mean_intersection']:.2f}",
    f"Std intersection length: {results['std_intersection']:.2f}",
    f"Min intersection length: {results['min_intersection']}",
    f"Max intersection length: {results['max_intersection']}",
    f"P-value for overlap >= {results['observed_overlap']}: {results['p_value']:.6f}",
    f"Number of permutations with overlap >= {results['observed_overlap']}: {results['n_significant']}",
]
print("\n".join(summary_lines))

plot_permutation_results(results)

In [None]:
# Does the bottom-N region set of each individual cluster overlap the current
# test set? `test_set_STRs` is whatever the most recent overlap cell produced;
# in a top-to-bottom run that is the Thalamus-excluded set.
from scipy.stats import hypergeom

N_BOTTOM = 46
_pool_size = len(FMRI.index)  # all fMRI regions (previously hard-coded as 163)
_pool = np.arange(1, _pool_size + 1)
_test_set_count = len(test_set_STRs)

for cluster in ['Cluster1', 'Cluster2', 'Cluster3', 'Cluster4']:
    # N lowest-valued regions. Series.nsmallest skips NaN, whereas
    # sort_values(...).tail(N) placed NaN rows last and counted them as "lowest".
    _sub_test_set = FMRI[cluster].nsmallest(N_BOTTOM).index.values
    _sub_overlap = set(_sub_test_set).intersection(set(test_set_STRs))
    _sub_overlap_count = len(_sub_overlap)
    _sub_test_count = len(_sub_test_set)

    _sub_p_value = hypergeom.sf(_sub_overlap_count - 1, _pool_size, _test_set_count, _sub_test_count)
    # Permutation null using the ACTUAL set sizes. The previously hard-coded
    # 67 / 46 need not match len(test_set_STRs).
    perm_p_value = permutation_test_overlap(_pool, _pool, _test_set_count, _sub_test_count,
                                            _sub_overlap_count, n_permutations=10000)

    print(f"\n{cluster} Analysis:")
    print(f"Cluster regions: {_sub_test_count}")
    print(f"Test set regions: {_test_set_count}")
    print(f"Overlap: {_sub_overlap_count} regions")
    print(f"Hypergeometric p-value: {_sub_p_value:.6f}")
    print(f"Permutation p-value: {perm_p_value['p_value']:.6f}")

In [None]:
# Alternative definition: count the clusters in which a region's |value| exceeds
# Cut, i.e. strong under- OR over-connectivity. Despite "_Below_" in the column
# name, this counts values ABOVE the threshold in absolute value.
Cut = 0.5
FMRI['Count_Below_Threshold_v2'] = FMRI[['Cluster1', 'Cluster2', 'Cluster3', 'Cluster4']].abs().gt(Cut).sum(axis=1)

In [None]:
# Overlap of the |value|-based test set with the GENCIC Circuits.46 structures.
Number_gt = 2
test_set = FMRI[FMRI["Count_Below_Threshold_v2"]>=Number_gt]
test_set_STRs = test_set.index.tolist()
GENCIC_STRs = GENCIC[GENCIC["Circuits.46"]==1].index.tolist()

overlap_STRs = set(test_set_STRs).intersection(set(GENCIC_STRs))
overlap_count = len(overlap_STRs)

# Universe = regions present in BOTH datasets, as in the earlier overlap cells.
# A union universe counts regions that can never be in the overlap, which
# inflates N and makes the hypergeometric p-value anti-conservative.
common_regions = set(FMRI.index).intersection(set(GENCIC.index))
universe_size = len(common_regions)

test_set_count = len(test_set_STRs)
GENCIC_count = len(GENCIC_STRs)
test_set_count_adj = len(common_regions.intersection(test_set_STRs))
GENCIC_count_adj = len(common_regions.intersection(GENCIC_STRs))

from scipy.stats import hypergeom
p_value = hypergeom.sf(overlap_count - 1, universe_size, GENCIC_count_adj, test_set_count_adj)

print(f"Common regions (universe): {universe_size} regions")
print(f"Test set (Count_Below_Threshold_v2 >= {Number_gt}): {test_set_count} regions ({test_set_count_adj} in universe)")
print(f"GENCIC Circuits.46 = 1: {GENCIC_count} regions ({GENCIC_count_adj} in universe)")
print(f"Overlap: {overlap_count} regions")
print(f"Overlap regions: {list(overlap_STRs)}")

expected_overlap = (test_set_count_adj * GENCIC_count_adj) / universe_size if universe_size > 0 else 0
print(f"Expected overlap under null hypothesis: {expected_overlap:.2f}")

overlap_percentage_test = (overlap_count / test_set_count_adj) * 100 if test_set_count_adj > 0 else 0
overlap_percentage_GENCIC = (overlap_count / GENCIC_count_adj) * 100 if GENCIC_count_adj > 0 else 0

print(f"Overlap as % of fMRI set: {overlap_percentage_test:.2f}%")
print(f"Overlap as % of GENCIC set: {overlap_percentage_GENCIC:.2f}%")
print(f"Hypergeometric test p-value: {p_value:.6f}")

In [None]:
GENCIC.head(n=2)

In [None]:
# Attach to each fMRI region (looked up by acronym) its GENCIC bias, Circuits.46
# membership and full structure name. Regions absent from GENCIC get 0 in the
# two numeric columns and NaN in FullName.
# NOTE(review): a filled 0 bias is indistinguishable from a genuine zero bias.
_region_acronyms = FMRI.index
FMRI_annotated = FMRI.assign(
    GENCIC_Bias=_region_acronyms.map(GENCIC['Bias']),
    GENCIC_Circuits_46=_region_acronyms.map(GENCIC['Circuits.46']),
    FullName=_region_acronyms.map(GENCIC['Structure']),
).fillna({'GENCIC_Bias': 0, 'GENCIC_Circuits_46': 0})

FMRI_annotated

In [None]:
_annotated_csv = '/home/jw3514/Work/FuncConnectome/ASD_Mouse/Clusters_Images/FMRI_annotated.csv'
FMRI_annotated.to_csv(_annotated_csv)

# Mouse Model fMRI

In [None]:
DataDIR = "/home/jw3514/Work/FuncConnectome/ASD_Mouse/OneDrive_1_7-31-2025/"
# Global-connectivity matrices under two preprocessing variants (CSF / GSR files).
data_csf = sio.loadmat(DataDIR + "global_connectivity_allsubjs_CSF.mat")
data_gsr = sio.loadmat(DataDIR + "global_connectivity_allsubjs_GSR.mat")

# MouseGlobalConnectivity reads the index file with header=None and treats its
# values as 1-based parcel indices; read it the same way here. The default
# header=0 would consume the first index as a column name.
parcel_indices = pd.read_csv(DataDIR + "parcel_indices_424.csv", header=None, names=['index'])
parcel_labels = pd.read_csv(DataDIR + "parc_labels_424_LR.csv")

In [None]:
# Summarize the loaded .mat contents (variable -> array shape) instead of dumping
# every array; MATLAB header entries ('__header__', ...) are skipped.
{key: getattr(value, "shape", None) for key, value in data_csf.items() if not key.startswith("__")}

In [None]:
# Grid of per-(mouse model, genotype) matrices; the class below indexes it as [model, genotype].
_gc_key = "global_connectivity_allsubjs"
gc_csf = data_csf[_gc_key]

In [None]:
np.shape(gc_csf)

In [None]:
from pathlib import Path

class MouseGlobalConnectivity:
    """Per-parcel global connectivity for several mouse models and genotypes.

    Loads one MATLAB file whose ``global_connectivity_allsubjs`` entry is indexed
    as ``[mouse_model, genotype]`` (4x2, matching ``mouse_models`` x ``genotypes``).
    Each matrix row is labelled with a parcel name from the parcel metadata files.

    Parameters
    ----------
    mat_file : str or path-like
        MATLAB file readable by ``scipy.io.loadmat`` (v7.2 or older).
    parcel_idx_file : str or path-like
        CSV without a header row; its single column holds 1-based indices into
        ``parcel_label_file``.
    parcel_label_file : str or path-like
        CSV with a ``name`` column; row ``i`` (0-based) is parcel index ``i + 1``.
    """
    def __init__(self, mat_file, parcel_idx_file, parcel_label_file):
        self.mat_file = Path(mat_file)
        # List order defines the row (model) / column (genotype) lookup in
        # get_connectivity; assumed to match the MATLAB export layout -- TODO confirm.
        self.mouse_models = ['shank3b', 'chd8', 'cntnap2', 'mecp2']
        self.genotypes = ['mutant', 'wt']

        # Load parcel metadata
        self.parcel_indices = pd.read_csv(parcel_idx_file, header=None, names=['index'])  # 1-based indices
        self.parcel_labels = pd.read_csv(parcel_label_file)  # has 'name' column

        # Map indices to names (adjust 1-based to 0-based indexing); this order is
        # the row order used by get_dataframe.
        idx_zero_based = self.parcel_indices['index'].values - 1
        self.parcel_names = self.parcel_labels.iloc[idx_zero_based]['name'].tolist()

        # Load MATLAB data (eagerly, in the constructor)
        self.data = self._load_mat()

    def _load_mat(self):
        """Load MATLAB .mat file (v7.2 or older).

        ``squeeze_me=True`` drops singleton dimensions. MATLAB metadata entries
        (keys starting with ``'__'``, e.g. ``'__header__'``) are removed.

        Returns
        -------
        dict
            Variable name -> loaded value.
        """
        mat = sio.loadmat(self.mat_file, squeeze_me=True)
        return {k: v for k, v in mat.items() if not k.startswith('__')}

    def get_connectivity(self, mouse_model, genotype):
        """Return connectivity matrix for a given mouse model and genotype.

        Parameters
        ----------
        mouse_model : str
            One of ``self.mouse_models``.
        genotype : str
            One of ``self.genotypes`` ('mutant' or 'wt').

        Returns
        -------
        numpy.ndarray
            An ``np.array`` copy of the stored matrix; ``get_dataframe`` treats it
            as (n_parcels, n_subjects).

        Raises
        ------
        ValueError
            If ``mouse_model`` or ``genotype`` is not recognized.
        """
        if mouse_model not in self.mouse_models:
            raise ValueError(f"Unknown mouse model: {mouse_model}")
        if genotype not in self.genotypes:
            raise ValueError(f"Genotype must be one of {self.genotypes}")

        arr = self.data['global_connectivity_allsubjs']  # 4x2 array of matrices (rows = mouse_models, cols = genotypes)
        row = self.mouse_models.index(mouse_model)
        col = self.genotypes.index(genotype)

        mat = arr[row, col]
        return np.array(mat)

    def _merge_hemispheres(self, df, strategy="average"):
        """Merge left/right hemisphere parcels. If only one side exists, keep as is.

        The hemisphere is taken from a trailing ``_L`` / ``_R`` in the index.

        Parameters
        ----------
        df : pandas.DataFrame
            Parcels as rows (index = parcel names).
        strategy : {'average', 'concat', None}
            'average' -- one row per base name. When exactly two rows share the
            base name they are averaged; otherwise the first such row is kept.
            NOTE(review): more than two rows with the same base name are NOT averaged.
            'concat' -- one row per base name whose columns are the original
            columns suffixed ``_L`` and ``_R``. A single-sided parcel gets NaN in
            the other side's columns. NOTE(review): rows without an ``_L``/``_R``
            suffix are dropped by this branch, unlike 'average'.
            None -- ``df`` is returned unchanged.

        Returns
        -------
        pandas.DataFrame
            Merged frame, ordered by base name for 'average' and 'concat'.

        Raises
        ------
        ValueError
            For any other ``strategy``.
        """
        base_names = df.index.str.replace(r'_(L|R)$', '', regex=True)

        if strategy == "average":
            # For each base name, average L/R if both exist, else just keep the one present
            df = df.copy()
            df['__base__'] = base_names
            merged = []
            for base, group in df.groupby('__base__'):
                if len(group) == 2:
                    merged_row = group.drop(columns='__base__').mean()
                else:
                    merged_row = group.drop(columns='__base__').iloc[0]
                merged.append((base, merged_row))
            merged_df = pd.DataFrame([row for _, row in merged], index=[base for base, _ in merged])
            return merged_df

        elif strategy == "concat":
            # For each base name, concat L and R columns if both exist, else just keep the one present
            left_mask = df.index.str.endswith('_L')
            right_mask = df.index.str.endswith('_R')
            left_df = df[left_mask].copy()
            right_df = df[right_mask].copy()

            left_df.index = left_df.index.str.replace(r'_L$', '', regex=True)
            right_df.index = right_df.index.str.replace(r'_R$', '', regex=True)

            left_df.columns = [f"{c}_L" for c in left_df.columns]
            right_df.columns = [f"{c}_R" for c in right_df.columns]

            # Find all base names
            all_bases = set(left_df.index) | set(right_df.index)
            concat_rows = []
            concat_index = []
            for base in sorted(all_bases):
                left_row = left_df.loc[base] if base in left_df.index else None
                right_row = right_df.loc[base] if base in right_df.index else None
                if left_row is not None and right_row is not None:
                    row = pd.concat([left_row, right_row])
                elif left_row is not None:
                    row = left_row
                elif right_row is not None:
                    row = right_row
                else:
                    continue  # Should not happen
                concat_rows.append(row)
                concat_index.append(base)
            # DataFrame aligns the per-row Series on their labels, so a missing side becomes NaN.
            expanded = pd.DataFrame(concat_rows, index=concat_index)
            expanded = expanded.sort_index()
            return expanded

        elif strategy is None:
            return df

        else:
            raise ValueError("merge strategy must be 'average', 'concat', or None")

    def get_dataframe(self, mouse_model, genotype, merge=None):
        """
        Return DataFrame with parcel names and connectivity values.

        Rows are labelled by ``self.parcel_names`` (index name 'parcel_name') and
        columns are ``subj_1 .. subj_N``, so the matrix must be (n_parcels, n_subjects).
        merge: None, 'average', or 'concat' -- passed to ``_merge_hemispheres``.
        """
        mat = self.get_connectivity(mouse_model, genotype)
        df = pd.DataFrame(
            mat,
            index=self.parcel_names,
            columns=[f"subj_{i+1}" for i in range(mat.shape[1])]
        )
        df.index.name = 'parcel_name'

        if merge is not None:
            df = self._merge_hemispheres(df, strategy=merge)

        return df

In [None]:
from scipy.stats import mannwhitneyu
import numpy as np

def connectivity_test(data, method, gene):
    """
    Per-parcel Mann-Whitney U test of mutant vs wild-type connectivity.

    Parameters
    ----------
    data : dict
        Nested dict with ``data[method][gene]["mutant"]`` and
        ``data[method][gene]["wt"]`` DataFrames (rows = parcels,
        columns = subjects).
    method : str
        Preprocessing key (e.g. "CSF", "GSR").
    gene : str
        Mouse-model key (e.g. "shank3b").

    Returns
    -------
    pd.DataFrame
        Indexed by ``parcel_name`` with columns ``mut_mean``, ``wt_mean``,
        ``conn_diff`` (= mut_mean - wt_mean), ``mut_std``, ``wt_std`` (pandas
        sample std, ddof=1) and ``mwu_p`` (two-sided; NaN when either group has
        no finite values). Non-finite values (NaN, +/-inf) are excluded per
        parcel. Sorted by ascending ``mwu_p``.
    """
    mut_df = data[method][gene]["mutant"]
    # Align wild-type rows to the mutant parcels by label rather than by
    # position, so a differently ordered WT table cannot silently mis-pair
    # parcels. Parcels missing from WT become all-NaN rows (-> mwu_p NaN).
    wt_df = data[method][gene]["wt"].reindex(mut_df.index)
    results = []
    for i, parcel in enumerate(mut_df.index.values):
        mut_conn = mut_df.iloc[i, :]
        wt_conn = wt_df.iloc[i, :]
        # np.isfinite is False for NaN, inf and -inf alike.
        mut_valid = mut_conn[np.isfinite(mut_conn)]
        wt_valid = wt_conn[np.isfinite(wt_conn)]
        mut_mean = mut_valid.mean()
        wt_mean = wt_valid.mean()
        # Only test when both groups have at least one usable value.
        if len(mut_valid) > 0 and len(wt_valid) > 0:
            _, p = mannwhitneyu(mut_valid, wt_valid, alternative='two-sided')
        else:
            p = np.nan
        results.append({
            'parcel_name': parcel,
            'mut_mean': mut_mean,
            'wt_mean': wt_mean,
            'conn_diff': mut_mean - wt_mean,
            'mut_std': mut_valid.std(),
            'wt_std': wt_valid.std(),
            'mwu_p': p
        })
    results_df = pd.DataFrame(results).set_index('parcel_name')
    return results_df.sort_values(by='mwu_p')

def collapse_hemispheres(df, strategy='mean'):
    """
    Collapse left/right hemisphere rows identified by ``_L`` / ``_R`` index suffixes.

    Parameters
    ----------
    df : pd.DataFrame
        Rows indexed by parcel name, hemisphere encoded as a ``_L``/``_R`` suffix.
    strategy : {'mean', 'concat'}
        'mean'   : strip the suffix and average rows that share a base name
                   (rows without a suffix pass through under their own name).
        'concat' : one row per base name with the hemispheres side by side;
                   columns are renamed ``<col>_L`` and ``<col>_R`` (NaN where a
                   hemisphere is missing). Rows without an ``_L``/``_R``
                   suffix are not included.

    Raises
    ------
    ValueError
        For any other strategy.
    """
    if strategy == 'mean':
        collapsed = df.copy()
        collapsed.index = collapsed.index.str.replace(r'_(L|R)$', '', regex=True)
        # Group by the base parcel name and take the mean across L and R.
        return collapsed.groupby(collapsed.index).mean()
    if strategy == 'concat':
        halves = []
        for side in ('L', 'R'):
            half = df[df.index.str.endswith(f'_{side}')].copy()
            half.index = half.index.str.replace(f'_{side}$', '', regex=True)
            half.columns = [f"{c}_{side}" for c in half.columns]
            halves.append(half)
        # Outer-align on base name so parcels present in one hemisphere survive.
        return pd.concat(halves, axis=1).sort_index()
    raise ValueError(f"Invalid strategy: {strategy}")
        
def print_data_treeview(data, indent=0):
    """
    Print a nested ``data[preproc][gene][group]`` dict as an indented tree.

    Each leaf line shows the group name and the ``.shape`` of its DataFrame.
    ``indent`` adds that many two-space levels in front of every line.
    """
    def _pad(level):
        return "  " * (indent + level)

    for preproc, genes in data.items():
        print(f"{_pad(0)}{preproc}/")
        for gene, groups in genes.items():
            print(f"{_pad(1)}{gene}/")
            for group, frame in groups.items():
                print(f"{_pad(2)}{group}: DataFrame shape {frame.shape}")

In [None]:
# Per-subject global connectivity with hemispheres kept separate (merge=None).
# Layout: data[preproc][gene][group] -> DataFrame (rows = parcels, columns = subjects)
_loaders = {
    "CSF": MouseGlobalConnectivity(
        mat_file=DataDIR + "global_connectivity_allsubjs_CSF.mat",
        parcel_idx_file=DataDIR + "parcel_indices_424.csv",
        parcel_label_file=DataDIR + "parc_labels_424_LR.csv"
    ),
    "GSR": MouseGlobalConnectivity(
        mat_file=DataDIR + "global_connectivity_allsubjs_GSR.mat",
        parcel_idx_file=DataDIR + "parcel_indices_424.csv",
        parcel_label_file=DataDIR + "parc_labels_424_LR.csv"
    ),
}
data = {
    preproc: {
        gene: {
            group: loader.get_dataframe(gene, group, merge=None)
            for group in ("mutant", "wt")
        }
        for gene in ("shank3b", "cntnap2", "chd8", "mecp2")
    }
    for preproc, loader in _loaders.items()
}
print_data_treeview(data)

In [None]:
# Same layout as `data`, but each table is hemisphere-merged with merge="average".
# Layout: data_LR_merge[preproc][gene][group] -> DataFrame
_loaders = {
    "CSF": MouseGlobalConnectivity(
        mat_file=DataDIR + "global_connectivity_allsubjs_CSF.mat",
        parcel_idx_file=DataDIR + "parcel_indices_424.csv",
        parcel_label_file=DataDIR + "parc_labels_424_LR.csv"
    ),
    "GSR": MouseGlobalConnectivity(
        mat_file=DataDIR + "global_connectivity_allsubjs_GSR.mat",
        parcel_idx_file=DataDIR + "parcel_indices_424.csv",
        parcel_label_file=DataDIR + "parc_labels_424_LR.csv"
    ),
}
data_LR_merge = {
    preproc: {
        gene: {
            group: loader.get_dataframe(gene, group, merge="average")
            for group in ("mutant", "wt")
        }
        for gene in ("shank3b", "cntnap2", "chd8", "mecp2")
    }
    for preproc, loader in _loaders.items()
}
print_data_treeview(data_LR_merge)

In [None]:
# Preview: first 5 parcels of the unmerged CSF / shank3b mutant table
data["CSF"]["shank3b"]["mutant"].head(n=5)

In [None]:
# Per-parcel mutant-vs-WT tests on the unmerged (L/R separate) data.
CSF_shank3b_res, CSF_chd8_res, CSF_cntnap2_res, CSF_mecp2_res = [
    connectivity_test(data, "CSF", model)
    for model in ("shank3b", "chd8", "cntnap2", "mecp2")
]
GSR_shank3b_res, GSR_chd8_res, GSR_cntnap2_res, GSR_mecp2_res = [
    connectivity_test(data, "GSR", model)
    for model in ("shank3b", "chd8", "cntnap2", "mecp2")
]

In [None]:
# Ten parcels with the smallest MWU p-values (results are sorted by mwu_p)
CSF_shank3b_res.head(n=10)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def compare_lr_correlations(results_df, plot=True, title=None):
    """
    Compare left- vs right-hemisphere results parcel by parcel.

    Rows whose index ends in ``_L`` are paired with the row of the same base
    name ending in ``_R``. Parcels present in only one hemisphere are ignored.
    For ``conn_diff`` and ``mwu_p``, the Spearman correlation across paired
    parcels is computed. Optionally, L-vs-R scatter plots of ``conn_diff`` and
    ``-log10(mwu_p)`` are drawn.

    Parameters
    ----------
    results_df : pd.DataFrame
        Index with ``_L`` / ``_R`` suffixes; columns 'conn_diff' and 'mwu_p'
        (e.g. the output of ``connectivity_test`` on unmerged data).
    plot : bool
        If True, show the scatter plots.
    title : str or None
        Optional figure super-title; only used when ``plot`` is True.

    Returns
    -------
    pd.DataFrame
        One row with ``n_pairs`` (parcels present in both hemispheres) and the
        Spearman rho / p-value for ``conn_diff`` and for ``mwu_p``. For each
        statistic, pairs where either hemisphere's value is NaN are left out.
        Fewer than two usable pairs give NaN.
    """
    base_names = results_df.index.str.replace(r'_(L|R)$', '', regex=True)
    results_df = results_df.assign(base=base_names)

    # Re-index each hemisphere by base name so L and R rows can be paired.
    left_df = results_df[results_df.index.str.endswith('_L')].set_index('base')
    right_df = results_df[results_df.index.str.endswith('_R')].set_index('base')

    # Keep only parcels present in both hemispheres.
    common = left_df.index.intersection(right_df.index)
    left_df = left_df.loc[common]
    right_df = right_df.loc[common]

    def _paired_spearman(col):
        # Drop pairs with a NaN on either side. Otherwise a single untestable
        # parcel turns the whole correlation into NaN.
        lvals = left_df[col].to_numpy(dtype=float)
        rvals = right_df[col].to_numpy(dtype=float)
        keep = ~(np.isnan(lvals) | np.isnan(rvals))
        if keep.sum() < 2:
            return np.nan, np.nan
        rho, pval = spearmanr(lvals[keep], rvals[keep])
        return rho, pval

    rho_conn, p_conn = _paired_spearman('conn_diff')
    rho_pval, p_pval = _paired_spearman('mwu_p')

    if plot:
        fig, axs = plt.subplots(1, 2, figsize=(10, 5))

        # Scatter for conn_diff
        axs[0].scatter(left_df['conn_diff'], right_df['conn_diff'], alpha=0.7)
        axs[0].set_xlabel('Left conn_diff')
        axs[0].set_ylabel('Right conn_diff')
        axs[0].set_title(f'conn_diff L vs R\nSpearman r={rho_conn:.2f}, p={p_conn:.2g}')
        axs[0].axline((0, 0), slope=1, color='gray', linestyle='--', linewidth=1)

        # Scatter for -log10(mwu_p); clip avoids -log10(0) = inf
        left_logp = -np.log10(left_df['mwu_p'].clip(lower=1e-20))
        right_logp = -np.log10(right_df['mwu_p'].clip(lower=1e-20))
        axs[1].scatter(left_logp, right_logp, alpha=0.7)
        axs[1].set_xlabel('-log10(mwu_p) Left')
        axs[1].set_ylabel('-log10(mwu_p) Right')
        axs[1].set_title(f'-log10(mwu_p) L vs R\nSpearman r={rho_pval:.2f}, p={p_pval:.2g}')
        axs[1].axline((0, 0), slope=1, color='gray', linestyle='--', linewidth=1)

        plt.tight_layout()
        if title is not None:
            fig.suptitle(title)
            # Leave room above the panels for the super-title.
            plt.subplots_adjust(top=0.85)
        plt.show()

    res = {
        "n_pairs": len(common),
        "conn_diff_spearman_rho": rho_conn,
        "conn_diff_pval": p_conn,
        "mwu_p_spearman_rho": rho_pval,
        "mwu_p_pval": p_pval
    }
    return pd.DataFrame([res])

In [None]:
# Left/right hemisphere consistency of the unmerged per-parcel results:
# one figure per preprocessing x mouse model, then a combined summary table.
lr_inputs = {
    "CSF Shank3b": CSF_shank3b_res,
    "GSR Shank3b": GSR_shank3b_res,
    "CSF Chd8": CSF_chd8_res,
    "GSR Chd8": GSR_chd8_res,
    "CSF Cntnap2": CSF_cntnap2_res,
    "GSR Cntnap2": GSR_cntnap2_res,
    "CSF Mecp2": CSF_mecp2_res,
    "GSR Mecp2": GSR_mecp2_res,
}
lr_corr_summary = pd.concat(
    [compare_lr_correlations(res_df, title=name).assign(dataset=name)
     for name, res_df in lr_inputs.items()],
    ignore_index=True,
).set_index("dataset")
lr_corr_summary

In [None]:
# Per-parcel mutant-vs-WT tests on the hemisphere-averaged data (data_LR_merge).
CSF_merge_shank3b_res, CSF_merge_chd8_res, CSF_merge_cntnap2_res, CSF_merge_mecp2_res = [
    connectivity_test(data_LR_merge, "CSF", model)
    for model in ("shank3b", "chd8", "cntnap2", "mecp2")
]
GSR_merge_shank3b_res, GSR_merge_chd8_res, GSR_merge_cntnap2_res, GSR_merge_mecp2_res = [
    connectivity_test(data_LR_merge, "GSR", model)
    for model in ("shank3b", "chd8", "cntnap2", "mecp2")
]

# merge_res_dict[merge_type][mouse_model] -> connectivity_test result table
merge_res_dict = {
    "CSF_merge": dict(
        shank3b=CSF_merge_shank3b_res,
        chd8=CSF_merge_chd8_res,
        cntnap2=CSF_merge_cntnap2_res,
        mecp2=CSF_merge_mecp2_res,
    ),
    "GSR_merge": dict(
        shank3b=GSR_merge_shank3b_res,
        chd8=GSR_merge_chd8_res,
        cntnap2=GSR_merge_cntnap2_res,
        mecp2=GSR_merge_mecp2_res,
    ),
}

In [None]:
# 50 parcels with the smallest conn_diff (mutant mean - WT mean), CSF_merge / shank3b
merge_res_dict["CSF_merge"]["shank3b"].sort_values("conn_diff").iloc[:50]

In [None]:
# 50 parcels with the largest conn_diff (mutant mean - WT mean), CSF_merge / chd8
merge_res_dict["CSF_merge"]["chd8"].sort_values("conn_diff").iloc[-50:]

In [None]:
def compare_models_spearman(merge_results_dict, stat_type="conn_diff", merge_key="CSF_merge"):
    """
    Pairwise Spearman correlation of a per-parcel statistic across mouse models.

    Parameters
    ----------
    merge_results_dict : dict
        Outer keys = merge type (e.g. "CSF_merge", "GSR_merge");
        inner keys = model name; inner values = results DataFrames indexed by
        parcel that contain the ``stat_type`` column.
    stat_type : str
        Column to correlate. The default is "conn_diff"; pass "mwu_p" to
        correlate MWU p-values instead.
    merge_key : str
        Which merge type to read from ``merge_results_dict``.

    Returns
    -------
    pd.DataFrame
        Symmetric model x model matrix of Spearman rho with 1.0 on the
        diagonal. Each pair is computed on parcels present in both models with
        non-NaN values.
    """
    from itertools import combinations

    results_dict = merge_results_dict[merge_key]
    models = list(results_dict.keys())
    corr_df = pd.DataFrame(index=models, columns=models, dtype=float)

    for m1, m2 in combinations(models, 2):
        # Align on the parcel index and drop parcels missing in either model.
        aligned = pd.concat(
            [results_dict[m1][stat_type], results_dict[m2][stat_type]],
            axis=1, join='inner',
        ).dropna()
        rho, _ = spearmanr(aligned.iloc[:, 0], aligned.iloc[:, 1])
        corr_df.loc[m1, m2] = rho
        corr_df.loc[m2, m1] = rho

    for m in models:
        corr_df.loc[m, m] = 1.0

    return corr_df

In [None]:
rho_table = compare_models_spearman(merge_res_dict, merge_key="CSF_merge")
fig, ax = plt.subplots()
sns.heatmap(rho_table.astype(float), annot=True, cmap="coolwarm", vmin=-1, vmax=1, ax=ax)
# compare_models_spearman correlates conn_diff by default, not MWU p-values.
ax.set_title("Spearman correlation of conn_diff across models (CSF_merge)")
plt.show()

In [None]:
rho_table = compare_models_spearman(merge_res_dict, merge_key="GSR_merge")
fig, ax = plt.subplots()
sns.heatmap(rho_table.astype(float), annot=True, cmap="coolwarm", vmin=-1, vmax=1, ax=ax)
# compare_models_spearman correlates conn_diff by default, not MWU p-values.
ax.set_title("Spearman correlation of conn_diff across models (GSR_merge)")
plt.show()

In [None]:
def compare_models_topN_overlap(results_dict, N=50, merge_key=None):
    """
    Compare models by the overlap of their top-N parcels (lowest MWU p-values).

    Parameters
    ----------
    results_dict : dict
        If merge_key is None:
            Keys = model name (str)
            Values = results DataFrame (must have 'mwu_p' column)
        If merge_key is not None:
            Outer keys = merge type (e.g., "CSF_merge", "GSR_merge")
            Inner keys = model name (str)
            Inner values = results DataFrame (must have 'mwu_p' column)
    N : int
        Number of top parcels to compare.
    merge_key : str or None
        If not None, use this key to select the inner dict from results_dict.

    Returns
    -------
    pd.DataFrame
        Symmetric model x model matrix of overlap fractions (0-1), with 1.0 on
        the diagonal. Parcels with NaN ``mwu_p`` are never counted as "top".
        The denominator is N, or the smaller top-set size when a model has
        fewer than N testable parcels (NaN if a top set is empty).
    """
    from itertools import combinations

    # Handle merge_res_dict format if merge_key is provided
    if merge_key is not None:
        results_dict = results_dict[merge_key]

    models = list(results_dict.keys())
    overlap_df = pd.DataFrame(index=models, columns=models, dtype=float)

    # Precompute top-N sets, ignoring untestable (NaN p-value) parcels.
    top_sets = {
        model: set(df.dropna(subset=['mwu_p']).sort_values('mwu_p').head(N).index)
        for model, df in results_dict.items()
    }

    for m1, m2 in combinations(models, 2):
        # Equals N whenever both models have at least N testable parcels.
        denom = min(len(top_sets[m1]), len(top_sets[m2]))
        overlap_pct = len(top_sets[m1] & top_sets[m2]) / denom if denom else np.nan
        overlap_df.loc[m1, m2] = overlap_pct
        overlap_df.loc[m2, m1] = overlap_pct

    # Fill diagonal with 1.0
    for m in models:
        overlap_df.loc[m, m] = 1.0

    return overlap_df

In [None]:
# Fraction of shared top-50 parcels (lowest mwu_p) between models, CSF_merge results
compare_models_topN_overlap(merge_res_dict, N=50, merge_key="CSF_merge")

In [None]:
# Fraction of shared top-50 parcels (lowest mwu_p) between models, GSR_merge results
compare_models_topN_overlap(merge_res_dict, N=50, merge_key="GSR_merge")

In [None]:
# GENCIC structure-bias supplementary table, located relative to ProjDIR
# (set in the config cell) instead of a hard-coded absolute path.
GENCIC = pd.read_excel(
    os.path.join(ProjDIR, "results", "SupTabs.v57.xlsx"),
    sheet_name="Table-S1- Structure Bias",
    index_col=0,
)

In [None]:
# Peek at the GENCIC bias table
GENCIC.head(n=2)

In [None]:
# Structures present both in the merged fMRI results and in the GENCIC table.
CommonSTRs = merge_res_dict["CSF_merge"]["shank3b"].index.intersection(GENCIC.index)
# Explicit copy: later cells add columns to GENCIC_intersect, and this keeps
# those writes independent of GENCIC.
GENCIC_intersect = GENCIC.loc[CommonSTRs].copy()

In [None]:
# Annotate GENCIC_intersect with per-structure conn_diff and mwu_p for every
# mouse model x merge method. New columns: "{model}_{method}_conn_diff" and
# "{model}_{method}_mwu_p".
methods = ["CSF_merge", "GSR_merge"]
mouse_models = list(merge_res_dict["CSF_merge"].keys())

for method in methods:
    for model in mouse_models:
        res_df = merge_res_dict[method][model]
        for stat in ("conn_diff", "mwu_p"):
            col = f"{model}_{method}_{stat}"
            if stat in res_df.columns:
                # Label-aligned lookup; structures missing from res_df become NaN.
                # Assigning a whole float column avoids the object dtype a
                # pd.NA-initialised, cell-by-cell filled column would have.
                GENCIC_intersect[col] = res_df[stat].reindex(GENCIC_intersect.index).astype(float)
            else:
                GENCIC_intersect[col] = np.nan

In [None]:
# Peek at the annotated table
GENCIC_intersect.head(n=2)

In [None]:
# List all columns, including the newly added per-model annotations
GENCIC_intersect.columns.to_numpy()

In [None]:
import matplotlib.pyplot as plt

mousemodels = ["shank3b", "chd8", "cntnap2", "mecp2"]
methods = ["CSF_merge", "GSR_merge"]

# One panel per (mouse model, merge method): GENCIC Bias against conn_diff,
# annotated with the Spearman r / p over structures where both are present.
fig, axes = plt.subplots(len(mousemodels), len(methods), figsize=(10, 16), dpi=150, sharex=True, sharey=False)
fig.subplots_adjust(hspace=0.4, wspace=0.3)

bias = GENCIC_intersect["Bias"]
last_row = len(mousemodels) - 1
for row, mousemodel in enumerate(mousemodels):
    for col_idx, method in enumerate(methods):
        ax = axes[row, col_idx]
        diff = GENCIC_intersect[f"{mousemodel}_{method}_conn_diff"]
        keep = bias.notna() & diff.notna()
        if keep.sum() <= 1:
            ax.text(0.5, 0.5, "Not enough data", ha="center", va="center", fontsize=10)
        else:
            corr, p = spearmanr(bias[keep], diff[keep])
            ax.scatter(bias[keep], diff[keep], alpha=0.7, s=20)
            ax.annotate(f"r={corr:.2f}\np={p:.2g}", xy=(0.05, 0.85), xycoords="axes fraction", fontsize=10,
                        bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="gray", alpha=0.7))
        ax.set_title(f"{mousemodel} - {method}")
        if row == last_row:
            ax.set_xlabel("GENCIC Bias")
        if col_idx == 0:
            ax.set_ylabel("Conn Diff")

plt.suptitle("GENCIC Bias vs Mouse Model Conn Diff\n(Spearman r and p shown)", fontsize=16, y=0.92)
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from scipy.cluster.hierarchy import linkage, leaves_list

# Columns to correlate: GENCIC Bias plus every available conn_diff annotation.
cols = ["Bias"]
for mousemodel in ["shank3b", "chd8", "cntnap2", "mecp2"]:
    for method in ["CSF_merge", "GSR_merge"]:
        col = f"{mousemodel}_{method}_conn_diff"
        if col in GENCIC_intersect.columns:
            cols.append(col)

# Coerce to numeric (pd.NA -> NaN). Columns initialised with pd.NA and filled
# cell by cell are object dtype, and DataFrame.corr in pandas < 2.0 silently
# drops non-numeric columns.
df_corr = GENCIC_intersect[cols].apply(pd.to_numeric, errors="coerce")
df_corr = df_corr.dropna(how="all", subset=cols)

# Pairwise Spearman correlation
corr_matrix = df_corr.corr(method="spearman")

# The correlation matrix is symmetric, so one linkage orders rows and columns
# alike. NaN correlations (e.g. from a constant column) are set to 0 for the
# clustering step only, because linkage rejects non-finite input.
order = leaves_list(linkage(corr_matrix.fillna(0.0), method='average'))
corr_matrix_clustered = corr_matrix.iloc[order, order]

fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
sns.heatmap(
    corr_matrix_clustered, annot=True, cmap="vlag", center=0,
    linewidths=0.5, cbar_kws={"label": "Spearman r"}, ax=ax
)
ax.set_title("Clustered Spearman Correlation: GENCIC Bias & Mouse Model Conn Diff")
fig.tight_layout()
plt.show()

In [None]:
# Enrichment of GENCIC circuit structures among the most hyper- and hypo-connected structures (hypergeometric test)
from scipy.stats import hypergeom

def compute_hypergeometric_pvalue(N_total, N_set1, N_set2, N_common):
    """
    One-sided hypergeometric enrichment p-value, P(X >= N_common).

    X is the overlap between a fixed set of size ``N_set1`` and a random draw
    of ``N_set2`` items from a population of ``N_total``.
    """
    # hypergeom.sf(k, M, n, N) = P(X > k); shifting k down by one gives P(X >= N_common).
    threshold = N_common - 1
    return hypergeom.sf(threshold, N_total, N_set1, N_set2)

# GENCIC circuit structures ("Circuits.46" == 1) among the shared structures.
GENCIC_STRs = GENCIC_intersect[GENCIC_intersect["Circuits.46"] == 1].index.values
N_total_STR = 211  # nominal structure count; each test below uses the structures actually ranked
N_GENCIC = len(GENCIC_STRs)
N_top = 44
N_bottom = 44

for mousemodel in ["shank3b", "chd8", "cntnap2", "mecp2"]:
    for method in ["CSF_merge", "GSR_merge"]:
        # Rank by conn_diff, highest first. Structures without a value are
        # dropped: sort_values would put them last, where tail() would count
        # them as hypo-connected.
        ranked = GENCIC_intersect[f"{mousemodel}_{method}_conn_diff"].dropna().sort_values(ascending=False)
        n_ranked = len(ranked)
        gencic_ranked = set(GENCIC_STRs).intersection(ranked.index)
        top_hyper = ranked.head(N_top)
        bottom_hypo = ranked.tail(N_bottom)
        Common_hyper = gencic_ranked.intersection(top_hyper.index)
        Common_hypo = gencic_ranked.intersection(bottom_hypo.index)
        # Population and set sizes refer to the ranked structures only.
        pval_hyper = compute_hypergeometric_pvalue(n_ranked, len(gencic_ranked), len(top_hyper), len(Common_hyper))
        pval_hypo = compute_hypergeometric_pvalue(n_ranked, len(gencic_ranked), len(bottom_hypo), len(Common_hypo))
        print(f"{mousemodel} {method} | Hyper: {len(Common_hyper)} (p={pval_hyper:.4g}), Hypo: {len(Common_hypo)} (p={pval_hypo:.4g}) | n={n_ranked}")
