# ASD Circuit Search: Pareto Front Analysis

This notebook visualizes the results of the GENCIC circuit search pipeline.

**Pipeline**: Run via `Snakefile.circuit` (see `config/circuit_config.yaml`)

**Input files**:
1. ASD mutation bias: `results/STR_ISH/ASD_All_bias_addP_sibling.csv`
2. Pareto front: `results/CircuitSearch/ASD_SPARK_61/pareto_fronts/ASD_SPARK_61_size_46_pareto_front.csv`
3. Sibling null profiles: `results/CircuitSearch_Sibling_Summary/Mutability/size_46/sibling_profiles.npz`
4. Bootstrap pareto fronts: `results/CircuitSearch_Bootstrap/*/pareto_fronts/*_pareto_front.csv`

In [None]:
%load_ext autoreload
%autoreload 2
import sys
import os

ProjDIR = "/home/jw3514/Work/ASD_Circuits_CellType"
os.chdir(ProjDIR + "/notebooks_mouse_str")
sys.path.insert(1, ProjDIR + '/src')

from ASD_Circuits import *

# 1. Load ASD Bias and Pareto Front

In [None]:
# Load ASD mutation bias (with mutability-weighted sibling null p-values).
# Columns used in this notebook: 'EFFECT' (per-structure bias) and 'q-value'.
Spark_ASD_STR_Bias = pd.read_csv("../results/STR_ISH/ASD_All_bias_addP_sibling.csv", index_col=0)
# Report the maximum via idxmax/max rather than row 0, so the summary stays
# correct even if the CSV is not pre-sorted by EFFECT.
top_effect_structure = Spark_ASD_STR_Bias['EFFECT'].idxmax()
print(f"Bias: {len(Spark_ASD_STR_Bias)} structures, "
      f"top EFFECT={Spark_ASD_STR_Bias['EFFECT'].max():.4f} at {top_effect_structure}")
print(f"Significant: {(Spark_ASD_STR_Bias['q-value'] < 0.05).sum()} at q<0.05, "
      f"{(Spark_ASD_STR_Bias['q-value'] < 0.10).sum()} at q<0.10")
Spark_ASD_STR_Bias.head(3)

In [None]:
# Pareto front of the size-46 ASD circuit search, ordered from the highest
# mean-bias end (row 0) towards the highest-connectivity end.
pareto_csv_path = "../results/CircuitSearch/ASD_SPARK_61/pareto_fronts/ASD_SPARK_61_size_46_pareto_front.csv"
pareto_df = (
    pd.read_csv(pareto_csv_path)
    .sort_values('mean_bias', ascending=False)
    .reset_index(drop=True)
)
print(f"Pareto front: {len(pareto_df)} points (size 46)")
pareto_df.head()

In [None]:
# Ipsilateral connectome information matrix (structures x structures, indexed
# by the first CSV column).
# NOTE(review): no later cell in this notebook appears to reference InfoMat;
# confirm it is truly unused before removing this load.
infomat_csv_path = "../dat/allen-mouse-conn/ConnectomeScoringMat/InfoMat.Ipsi.csv"
InfoMat = pd.read_csv(infomat_csv_path, index_col=0)

In [None]:
# Parse the comma-separated structure list of every Pareto-front point
# (same row order as pareto_df, i.e. mean_bias descending).
pareto_structures = [structures.split(',') for structures in pareto_df['structures']]

# Select circuit: 3rd row (index 2) — high-bias knee of the Pareto front.
SELECTED_IDX = 2
assert 0 <= SELECTED_IDX < len(pareto_df), (
    f"SELECTED_IDX={SELECTED_IDX} is out of range for a {len(pareto_df)}-point Pareto front")
selected_structures = pareto_structures[SELECTED_IDX]
print(f"Selected circuit (index {SELECTED_IDX}):")
print(f"  Mean bias: {pareto_df.loc[SELECTED_IDX, 'mean_bias']:.4f}")
print(f"  Circuit score: {pareto_df.loc[SELECTED_IDX, 'circuit_score']:.4f}")
print(f"  {len(selected_structures)} structures")
print(f"  Regions: {RegionDistributionsList(selected_structures)}")

# 2. ASD Pareto Front

In [None]:
# Overview: the ASD Pareto front with the selected circuit marked by a red cross.
selected_score = pareto_df.loc[SELECTED_IDX, 'circuit_score']
selected_bias = pareto_df.loc[SELECTED_IDX, 'mean_bias']

fig, ax = plt.subplots(dpi=120, figsize=(5, 5))
ax.plot(
    pareto_df['circuit_score'], pareto_df['mean_bias'],
    marker='.', color='#542788', lw=2, markersize=8, ls='-', label='ASD',
)
ax.scatter(
    selected_score, selected_bias,
    marker='x', s=50, color='red', zorder=100, label='Selected Circuit',
)
ax.set_xlabel("Circuit Connectivity Score")
ax.set_ylabel("Mean Structure Bias")
ax.grid(True, alpha=0.3)
ax.legend()
plt.tight_layout()
plt.show()

# 3. ASD vs Sibling Null Circuits

Sibling null circuit profiles are generated by `Snakefile.circuit.sibling` using
mutability-weighted sibling gene sets. The aggregated profiles (NPZ) contain both the
mean sibling Pareto front and the individual Pareto fronts from each sibling-null SA search.

In [None]:
# Load sibling circuit profiles from the pipeline.
#   meanbias / meanSI : average sibling Pareto front (y = mean bias, x = circuit score)
#   topbias_sub       : per-iteration sibling fronts; the plotting cells index it as
#                       [iteration, 0 -> bias | 1 -> score, point]
sibling_npz_path = "../results/CircuitSearch_Sibling_Summary/Mutability/size_46/sibling_profiles.npz"
# np.load on an .npz returns an NpzFile that keeps the file handle open; read the
# arrays out inside a context manager so the handle is released.
with np.load(sibling_npz_path) as sibling_data:
    sib_meanbias = sibling_data['meanbias']
    sib_meanSI = sibling_data['meanSI']
    sib_topbias_sub = sibling_data['topbias_sub']

assert sib_topbias_sub.ndim == 3 and sib_topbias_sub.shape[1] >= 2, (
    f"unexpected topbias_sub shape {sib_topbias_sub.shape}; expected (iterations, >=2, points)")
print(f"Sibling profiles: {sib_topbias_sub.shape[0]} iterations, "
      f"{sib_topbias_sub.shape[2]} bias limit points each")

In [None]:
from matplotlib.lines import Line2D

fig, ax = plt.subplots(dpi=480, figsize=(4.2, 4))

# ASD pareto front, with the selected circuit marked and labelled
ax.plot(pareto_df['circuit_score'], pareto_df['mean_bias'],
        marker='.', color='#542788', lw=2, markersize=8, ls='-', label='ASD')
ax.scatter(pareto_df.loc[SELECTED_IDX, 'circuit_score'],
           pareto_df.loc[SELECTED_IDX, 'mean_bias'],
           marker='x', s=70, color='red', lw=2, zorder=100)
ax.text(pareto_df.loc[SELECTED_IDX, 'circuit_score'],
        0.01 + pareto_df.loc[SELECTED_IDX, 'mean_bias'],
        s="Selected\n Circuit", fontsize=12, ha='left')

# Individual sibling profiles (gray): x = index 1 (circuit score), y = index 0 (mean bias)
ax.plot(sib_topbias_sub[:, 1, :].T, sib_topbias_sub[:, 0, :].T,
        color='grey', lw=0.5, ls='-', alpha=0.05)

# Average sibling
ax.plot(sib_meanSI, sib_meanbias, marker='.', color='Orange', lw=2,
        markersize=8, ls='-', alpha=1, label='Average Sibling Circuit')

# The individual sibling lines are far too faint (alpha=0.05) to serve as their
# own legend entry, so add an opaque proxy handle rather than drawing a hidden
# duplicate line just to obtain one.
legend_handles, _ = ax.get_legend_handles_labels()
legend_handles.append(Line2D([], [], color='grey', lw=2, ls='-', label='Sibling Circuit'))
ax.legend(handles=legend_handles, loc='lower left', frameon=False)

ax.set_xlabel("Circuit Connectivity Score", fontsize=14)
ax.set_ylabel("Average Mutation Bias", fontsize=14)
ax.grid(True, alpha=0.2)
ax.set_ylim(0.05, 0.42)
plt.tight_layout()
plt.show()

# 4. Bootstrap ASD

## 4.1 Extract and Visualize Bootstrap Pareto Fronts

Bootstrap Pareto fronts are generated by `Snakefile.circuit.bootstrap`.
Each bootstrap replicate resamples the ASD mutations and reruns the full SA circuit-search pipeline on the resampled data.

In [None]:
import glob
import re

Boot_DIR = "../results/CircuitSearch_Bootstrap/"
# Circuit size analysed throughout this notebook (the main front is size 46).
CIRCUIT_SIZE = 46
_size_token = re.compile(r"_size_(\d+)_pareto_front\.csv$")


def _matches_circuit_size(path, size=CIRCUIT_SIZE):
    """Return True unless the file name encodes a circuit size other than `size`.

    Files without a ``_size_<n>_`` token are kept, since their size cannot be checked.
    """
    match = _size_token.search(os.path.basename(path))
    return match is None or int(match.group(1)) == size


# If a bootstrap directory holds fronts for several sizes, mixing them under one
# boot_id would corrupt the per-replicate fronts combined in the next cell.
all_pareto_files = sorted(glob.glob(Boot_DIR + "*/pareto_fronts/*_pareto_front.csv"))
pareto_files = [f for f in all_pareto_files if _matches_circuit_size(f)]
print(f"Found {len(pareto_files)} pareto front files")
if len(pareto_files) != len(all_pareto_files):
    print(f"  (skipped {len(all_pareto_files) - len(pareto_files)} files for other circuit sizes)")

In [None]:
# Read and combine all pareto front CSV files. Each row is tagged with its run id,
# taken from the directory layout <Boot_DIR>/<boot_id>/pareto_fronts/<file>.csv.
if not pareto_files:
    # pd.concat([]) would otherwise fail with an opaque "No objects to concatenate".
    raise FileNotFoundError(f"No pareto front CSV files found under {Boot_DIR}")

all_pareto_data = []
for pf_file in pareto_files:
    boot_id = os.path.basename(os.path.dirname(os.path.dirname(pf_file)))
    all_pareto_data.append(pd.read_csv(pf_file).assign(boot_id=boot_id))

combined_pareto_df = pd.concat(all_pareto_data, ignore_index=True)
print(f"Total rows: {len(combined_pareto_df)}, "
      f"boot IDs: {combined_pareto_df['boot_id'].nunique()}")

In [None]:
# Bootstrap 95% CI via interpolation
import re
from scipy import interpolate

# Bootstrap replicates are named "ASD_Boot<k>" (k without leading zeros). Use every
# replicate present -- not only k < 1000 -- ordered numerically by k, so slices such
# as boot_samples_in_data[:100] in later cells remain deterministic.
_boot_id_pattern = re.compile(r"^ASD_Boot(0|[1-9]\d*)$")
boot_samples_in_data = sorted(
    (bid for bid in combined_pareto_df['boot_id'].unique() if _boot_id_pattern.match(bid)),
    key=lambda bid: int(_boot_id_pattern.match(bid).group(1)))
assert boot_samples_in_data, "No ASD_Boot<k> replicates found in combined_pareto_df"

# Split the combined table once instead of re-filtering it for every replicate.
boot_rows = combined_pareto_df[combined_pareto_df['boot_id'].isin(boot_samples_in_data)]
boot_fronts = {bid: front.sort_values('circuit_score')
               for bid, front in boot_rows.groupby('boot_id')}

# Common score grid spanning every replicate's front.
circuit_score_grid = np.linspace(boot_rows['circuit_score'].min(),
                                 boot_rows['circuit_score'].max(), 100)

interpolated_bias_values = []
for boot_id in boot_samples_in_data:
    front = boot_fronts[boot_id]
    # interp1d needs >= 2 points. Outside a replicate's own score range the
    # interpolant returns NaN, so the nan-percentiles below ignore it there.
    if len(front) >= 2:
        f_interp = interpolate.interp1d(
            front['circuit_score'].values, front['mean_bias'].values,
            kind='linear', bounds_error=False, fill_value=np.nan)
        interpolated_bias_values.append(f_interp(circuit_score_grid))

interpolated_bias_array = np.array(interpolated_bias_values)
# NOTE(review): near the ends of the grid only a few replicates cover the score
# range, so the band there rests on very few samples -- interpret edges with care.
lower_ci = np.nanpercentile(interpolated_bias_array, 2.5, axis=0)
upper_ci = np.nanpercentile(interpolated_bias_array, 97.5, axis=0)
median_bias = np.nanmedian(interpolated_bias_array, axis=0)
print(f"Computed 95% CI from {len(interpolated_bias_values)} bootstrap samples")

In [None]:
# Bootstrap CI plot: shaded 95% band and median of the interpolated bootstrap
# fronts, with the full-data SPARK fronts (when present) drawn on top.
fig, ax = plt.subplots(dpi=120, figsize=(8, 6))

ax.fill_between(circuit_score_grid, lower_ci, upper_ci,
                color='grey', alpha=0.3, label='95% CI (Bootstrap)', zorder=1)
ax.plot(circuit_score_grid, median_bias,
        color='grey', lw=2, ls='--', label='Median Bootstrap', zorder=5)

spark_samples = ['ASD_SPARK_Main', 'ASD_SPARK_61']
colors_spark = ['#542788', '#d95f02']
available_run_ids = set(combined_pareto_df['boot_id'])
for spark_id, spark_color in zip(spark_samples, colors_spark):
    if spark_id not in available_run_ids:
        continue
    spark_front = (combined_pareto_df
                   .loc[combined_pareto_df['boot_id'] == spark_id]
                   .sort_values('circuit_score'))
    ax.plot(spark_front['circuit_score'], spark_front['mean_bias'],
            marker='o', markersize=6, lw=2, label=spark_id, color=spark_color, zorder=10)

ax.set_xlabel("Circuit Connectivity Score", fontsize=12)
ax.set_ylabel("Average Mutation Bias", fontsize=12)
ax.set_title("Bootstrap Pareto Fronts with 95% CI", fontsize=14)
ax.grid(True, alpha=0.3)
ax.legend(loc='best', fontsize=10)
plt.tight_layout()
plt.show()

In [None]:
# Individual bootstrap fronts (first 100 replicates, faint grey) overlaid with the
# 95% CI band, the median front and the full-data SPARK fronts.
def _run_front(run_id):
    """Rows of combined_pareto_df for one run, ordered by circuit score."""
    run_rows = combined_pareto_df.loc[combined_pareto_df['boot_id'] == run_id]
    return run_rows.sort_values('circuit_score')


fig, ax = plt.subplots(dpi=120, figsize=(8, 6))

for boot_id in boot_samples_in_data[:100]:
    boot_front = _run_front(boot_id)
    ax.plot(boot_front['circuit_score'], boot_front['mean_bias'],
            color='grey', lw=0.5, alpha=0.15, zorder=1)

ax.fill_between(circuit_score_grid, lower_ci, upper_ci,
                color='lightblue', alpha=0.5, label='95% CI (Bootstrap)', zorder=5)
ax.plot(circuit_score_grid, median_bias,
        color='navy', lw=2.5, ls='-', label='Median Bootstrap', zorder=8)

available_run_ids = set(combined_pareto_df['boot_id'])
for spark_id, spark_color in zip(spark_samples, colors_spark):
    if spark_id in available_run_ids:
        spark_front = _run_front(spark_id)
        ax.plot(spark_front['circuit_score'], spark_front['mean_bias'],
                marker='o', markersize=6, lw=2.5, label=spark_id, color=spark_color, zorder=10)

ax.set_xlabel("Circuit Connectivity Score", fontsize=12)
ax.set_ylabel("Average Mutation Bias", fontsize=12)
ax.set_title("Bootstrap Pareto Fronts: Individual Lines + 95% CI", fontsize=14)
ax.grid(True, alpha=0.3)
ax.legend(loc='best', fontsize=10)
plt.tight_layout()
plt.show()

In [None]:
# Persist every run's Pareto front in one CSV for downstream summaries.
summary_dir = "../results/CircuitSearch_Bootstrap_Summary/"
output_file = os.path.join(summary_dir, "all_pareto_fronts_size_46_combined.csv")
os.makedirs(summary_dir, exist_ok=True)
combined_pareto_df.to_csv(output_file, index=False)
print(f"Saved combined pareto front data to: {output_file}")

## 4.2 Combined: ASD + Sibling + Bootstrap

In [None]:
from matplotlib.lines import Line2D

fig, ax = plt.subplots(dpi=480, figsize=(4.2, 4))

# ASD pareto front, with the selected circuit marked
ax.plot(pareto_df['circuit_score'], pareto_df['mean_bias'],
        marker='.', color='#542788', lw=2, markersize=8, ls='-', label='ASD')
ax.scatter(pareto_df.loc[SELECTED_IDX, 'circuit_score'],
           pareto_df.loc[SELECTED_IDX, 'mean_bias'],
           marker='x', s=70, color='red', lw=2, zorder=100)

# Bootstrap 95% CI
ax.fill_between(circuit_score_grid, lower_ci, upper_ci,
                color='lightblue', alpha=0.5, label='95% CI (Bootstrap)', zorder=5)

# Individual sibling profiles: x = index 1 (circuit score), y = index 0 (mean bias)
ax.plot(sib_topbias_sub[:, 1, :].T, sib_topbias_sub[:, 0, :].T,
        color='grey', lw=0.5, ls='-', alpha=0.05)

# Average sibling
ax.plot(sib_meanSI, sib_meanbias, marker='.', color='Orange', lw=2,
        markersize=8, ls='-', alpha=1, label='Average Sibling Circuit')

ax.text(pareto_df.loc[SELECTED_IDX, 'circuit_score'],
        0.01 + pareto_df.loc[SELECTED_IDX, 'mean_bias'],
        s="Selected\n Circuit", zorder=1000, fontsize=12, ha='left')

# The individual sibling lines are too faint (alpha=0.05) to act as their own
# legend entry, so append an opaque proxy handle instead of drawing a hidden
# duplicate line.
legend_handles, _ = ax.get_legend_handles_labels()
legend_handles.append(Line2D([], [], color='grey', lw=2, ls='-', label='Sibling Circuit'))
ax.legend(handles=legend_handles, loc='lower left', frameon=False)

ax.set_xlabel("Circuit Connectivity Score", fontsize=14)
ax.set_ylabel("Average Mutation Bias", fontsize=14)
ax.grid(True, alpha=0.2)
ax.set_ylim(0.00, 0.44)
plt.tight_layout()
plt.show()

# 5. Structure Overlap Analysis: Selected vs Bootstrap Circuits

This section compares structure overlap between:
1. The selected circuit (from the main Pareto front)
2. Each bootstrap circuit: the point at the same rank as the selected circuit (index 2, i.e. the 3rd-highest mean bias) on that replicate's Pareto front

In [None]:
# Set form of the selected circuit, used for membership / overlap tests below.
selected_circuit_structures = {*selected_structures}
n_selected_unique = len(selected_circuit_structures)
print(f"Selected circuit has {n_selected_unique} structures")

In [None]:
# For each bootstrap replicate, take the Pareto point at the same rank as the
# selected circuit: row SELECTED_IDX after sorting by mean_bias descending (the
# ordering applied to pareto_df). Variable names keep the historical "3rd point"
# wording; with SELECTED_IDX = 2 this is the 3rd point.
bootstrap_3rd_point_structures = []
bootstrap_3rd_point_data = []

boot_ids = sorted(bid for bid in combined_pareto_df['boot_id'].unique()
                  if bid.startswith('ASD_Boot'))

for boot_id in boot_ids:
    boot_data = (combined_pareto_df[combined_pareto_df['boot_id'] == boot_id]
                 .sort_values('mean_bias', ascending=False)
                 .reset_index(drop=True))
    # Replicates whose front is too short have no comparable point; skip them.
    if len(boot_data) > SELECTED_IDX:
        third_point = boot_data.iloc[SELECTED_IDX]
        structures_set = set(third_point['structures'].split(','))
        bootstrap_3rd_point_structures.append(structures_set)
        bootstrap_3rd_point_data.append({
            'boot_id': boot_id, 'structures': structures_set,
            'mean_bias': third_point['mean_bias'],
            'circuit_score': third_point['circuit_score']
        })

print(f"Extracted point index {SELECTED_IDX} from {len(bootstrap_3rd_point_structures)}"
      f"/{len(boot_ids)} bootstrap samples")

In [None]:
# Compare every bootstrap circuit B with the selected circuit A:
#   jaccard_similarity = |A & B| / |A | B|
#   overlap_percentage = |A & B| / |A|   (a fraction in [0, 1], despite the name)
overlap_scores = []
for boot_record, boot_set in zip(bootstrap_3rd_point_data, bootstrap_3rd_point_structures):
    shared = selected_circuit_structures & boot_set
    either = selected_circuit_structures | boot_set
    overlap_scores.append({
        'boot_id': boot_record['boot_id'],
        'jaccard_similarity': len(shared) / len(either) if either else 0,
        'overlap_percentage': len(shared) / len(selected_circuit_structures),
        'n_intersection': len(shared),
        'n_selected': len(selected_circuit_structures),
        'n_bootstrap': len(boot_set),
    })

overlap_df = pd.DataFrame(overlap_scores)
print("Overlap statistics:")
print(f"  Mean Jaccard similarity: {overlap_df['jaccard_similarity'].mean():.4f}")
print(f"  Mean overlap percentage: {overlap_df['overlap_percentage'].mean():.4f}")
print(f"  Mean intersection size: {overlap_df['n_intersection'].mean():.2f}")

In [None]:
# Histogram of how many selected-circuit structures each bootstrap circuit shares.
# n_intersection is an integer count, so use unit-width bins centred on each
# integer; 30 fixed-width bins over a narrow integer range leave empty slivers
# between values and make the bar heights hard to read.
fig, ax = plt.subplots(dpi=120, figsize=(8, 5))
n_shared = overlap_df['n_intersection']
mean_intersection = n_shared.mean()
integer_bins = np.arange(n_shared.min() - 0.5, n_shared.max() + 1.5, 1.0)
ax.hist(n_shared, bins=integer_bins, edgecolor='black', alpha=0.7, color='#542788')
ax.axvline(mean_intersection, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_intersection:.1f}')
ax.set_xlabel('Number of Overlapping Structures', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.set_title('Selected vs Bootstrap Circuits: Shared Structures', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Distributions of Jaccard similarity and of the fraction of the selected circuit
# recovered by each bootstrap circuit. The 'overlap_percentage' column holds a
# fraction in [0, 1], so that panel is labelled as a fraction, not a percentage.
fig, axes = plt.subplots(1, 2, dpi=120, figsize=(12, 5))

panels = [
    # (column, x-axis label, panel title)
    ('jaccard_similarity', 'Jaccard Similarity', 'Jaccard Similarity'),
    ('overlap_percentage', 'Overlap Fraction (of Selected Circuit)',
     'Overlap: Fraction of Selected Circuit'),
]
for panel_ax, (column, xlabel, title) in zip(axes, panels):
    column_mean = overlap_df[column].mean()
    panel_ax.hist(overlap_df[column], bins=30, edgecolor='black', alpha=0.7, color='#542788')
    panel_ax.axvline(column_mean, color='red', linestyle='--', linewidth=2,
                     label=f'Mean: {column_mean:.3f}')
    panel_ax.set_xlabel(xlabel, fontsize=12)
    panel_ax.set_ylabel('Frequency', fontsize=12)
    panel_ax.set_title(title, fontsize=14)
    panel_ax.legend()
    panel_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 5.1 Top Structure Frequency in Bootstrap Circuits

In [None]:
# Rank the selected circuit's structures by ASD bias (EFFECT), highest first,
# and keep the 20 strongest.
TOP_N_BIAS = 20
structure_effects = ((struct, Spark_ASD_STR_Bias.loc[struct, 'EFFECT'])
                     for struct in selected_structures)
selected_circuit_bias = sorted(structure_effects, key=lambda pair: pair[1], reverse=True)

top20_structures_by_bias = [name for name, _ in selected_circuit_bias[:TOP_N_BIAS]]
top20_structures_by_bias_set = set(top20_structures_by_bias)

print("Top 20 structures by bias in the selected circuit:")
for rank, (struct, bias) in enumerate(selected_circuit_bias[:TOP_N_BIAS], start=1):
    print(f"  {rank}. {struct} (bias: {bias:.4f})")

In [None]:
# Fraction of bootstrap circuits that contain each of the top-20 structures.
# The hit count is computed directly: reconstructing it as int(freq * n) can
# truncate through float rounding (e.g. 0.29 * 100 == 28.999... -> 28).
# Iterating the ranked list (not the set) also gives a deterministic print order.
n_boot_circuits = len(bootstrap_3rd_point_structures)
top20_frequency = {}
for struct in top20_structures_by_bias:
    n_hits = sum(struct in boot_set for boot_set in bootstrap_3rd_point_structures)
    freq = n_hits / n_boot_circuits if n_boot_circuits else float('nan')
    top20_frequency[struct] = freq
    print(f"{struct}: {freq*100:.1f}% ({n_hits}"
          f"/{n_boot_circuits})")

In [None]:
# How often do the 10 highest-bias selected structures all appear together in a
# single bootstrap circuit?
top10_structures_by_bias = top20_structures_by_bias[:10]
top10_structures_by_bias_set = set(top10_structures_by_bias)
top10_frequency = {struct: top20_frequency[struct] for struct in top10_structures_by_bias_set}

total_bootstraps = len(bootstrap_3rd_point_structures)
n_with_all_top10 = len([boot_set for boot_set in bootstrap_3rd_point_structures
                        if top10_structures_by_bias_set <= boot_set])

print(f"\nTop 10 structures present together:")
print(f"  {n_with_all_top10}/{total_bootstraps} ({n_with_all_top10/total_bootstraps*100:.1f}%)")

In [None]:
# Enhanced bar chart: bootstrap frequency of the top-20 (by bias) selected
# structures, coloured by band: >=90% green, >=75% amber, otherwise red.
import matplotlib.patheffects as pe
import seaborn as sns

# Resolve the three band colours once (shade 8 of a 15-step dark palette) instead
# of rebuilding a seaborn palette for every bar.
band_color_high = sns.color_palette("dark:#17B978", 15).as_hex()[8]
band_color_mid = sns.color_palette("dark:#FFC300", 15).as_hex()[8]
band_color_low = sns.color_palette("dark:#FF5733", 15).as_hex()[8]

fig, ax = plt.subplots(dpi=150, figsize=(11, 7))

structures_list = top20_structures_by_bias
frequencies = [top20_frequency.get(s, 0) * 100 for s in structures_list]
structures_list_display = [s.replace("_", " ") for s in structures_list]

color_map = [
    band_color_high if f >= 90 else band_color_mid if f >= 75 else band_color_low
    for f in frequencies
]

bars = ax.barh(range(len(structures_list)), frequencies, color=color_map, alpha=0.87,
               edgecolor='k', linewidth=1.3, height=0.67, zorder=10)

# Percentage label just right of each bar, in the bar's colour with a white halo.
for bar, freq, bar_color in zip(bars, frequencies, color_map):
    ax.text(freq + 1, bar.get_y() + bar.get_height() / 2, f'{freq:.1f}%',
            va='center', ha='left', fontsize=11, weight='bold', color=bar_color,
            path_effects=[pe.withStroke(linewidth=2.7, foreground="white")])

ax.set_yticks(range(len(structures_list)))
ax.set_yticklabels(structures_list_display, fontsize=11, fontfamily='monospace')
ax.invert_yaxis()  # highest-bias structure on top
ax.set_xlabel('Frequency in Bootstrap Circuits (%)', fontsize=13, labelpad=13)
ax.set_title('Top 20 Selected-Circuit Structures (by Bias): Bootstrap Frequency', fontsize=13)
ax.set_xlim(0, 105)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.grid(True, alpha=0.18, axis='x', linestyle='-', zorder=1, color='#888C')
# rect reserves the right 12% of the figure; presumably room for labels near 100%.
plt.tight_layout(rect=[0, 0, 0.88, 1])
plt.show()

In [None]:
# Bootstrap frequency of every structure in the selected circuit, ordered by
# ASD bias (highest first), with a >=90% / >=75% stability summary.
all_structures_frequency = {
    struct: np.mean([1 if struct in boot_set else 0 for boot_set in bootstrap_3rd_point_structures])
    for struct in selected_circuit_structures
}
all_structures_bias = {
    struct: Spark_ASD_STR_Bias.loc[struct, 'EFFECT']
    for struct in selected_circuit_structures
}

all_structures_sorted = sorted(selected_circuit_structures,
                               key=all_structures_bias.__getitem__, reverse=True)
all_frequencies = [100 * all_structures_frequency[struct] for struct in all_structures_sorted]


def _stability_color(pct):
    """Bar colour for a bootstrap frequency given in percent."""
    if pct >= 90:
        return 'green'
    if pct >= 75:
        return 'orange'
    return 'red'


fig, ax = plt.subplots(dpi=120, figsize=(14, 8))
colors = [_stability_color(pct) for pct in all_frequencies]
bars = ax.barh(range(len(all_structures_sorted)), all_frequencies,
               color=colors, alpha=0.7, edgecolor='black')
ax.axvline(90, color='red', linestyle='--', linewidth=2, label='90% threshold', zorder=0)
ax.axvline(75, color='orange', linestyle='--', linewidth=1.5, alpha=0.5, label='75% threshold', zorder=0)

for row, pct in enumerate(all_frequencies):
    ax.text(pct + 1, row, f'{pct:.1f}%', va='center', fontsize=8)

ax.set_yticks(range(len(all_structures_sorted)))
ax.set_yticklabels(all_structures_sorted, fontsize=8)
ax.invert_yaxis()
ax.set_xlabel('Frequency in Bootstrap Circuits (%)', fontsize=12)
ax.set_title('All Selected Circuit Structures: Frequency in Bootstrap Circuits', fontsize=14)
ax.set_xlim(0, 105)
ax.legend()
ax.grid(True, alpha=0.3, axis='x')
plt.tight_layout()
plt.show()

n_total = len(all_structures_sorted)
n_90pct = len([pct for pct in all_frequencies if pct >= 90])
n_75pct = len([pct for pct in all_frequencies if pct >= 75])
print(f"\nSummary for all {n_total} structures in selected circuit:")
print(f"  >=90%: {n_90pct}/{n_total} ({n_90pct/n_total*100:.1f}%)")
print(f"  >=75%: {n_75pct}/{n_total} ({n_75pct/n_total*100:.1f}%)")
print(f"  <75%: {n_total-n_75pct}/{n_total} ({(n_total-n_75pct)/n_total*100:.1f}%)")