<a href="https://colab.research.google.com/github/grabuffo/BrainStim_ANN_fMRI_HCP/blob/main/notebooks/Reduce_effects_variability.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Bifocal Stimulation: Reducing Neural Response Variability

This notebook analyzes how bifocal (dual-region) stimulation can reduce response variability compared to single-region stimulation, and compares it to closed-loop state-dependent stimulation approaches.

In [8]:
# --- 1Ô∏è‚É£ Mount Google Drive ---
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# --- 2Ô∏è‚É£ Clone GitHub repo (contains src/NPI.py) ---
!rm -rf /content/BrainStim_ANN_fMRI_HCP
!git clone https://github.com/grabuffo/BrainStim_ANN_fMRI_HCP.git

# --- 3Ô∏è‚É£ Define paths ---
import os, sys, gc
repo_dir    = "/content/BrainStim_ANN_fMRI_HCP"
data_dir    = "/content/drive/MyDrive/Colab Notebooks/Brain_Stim_ANN/data"
preproc_dir = os.path.join(data_dir, "preprocessed_subjects")
models_dir  = os.path.join(preproc_dir, "trained_models_MLP")
ects_dir    = os.path.join(preproc_dir, "ECts_MLP")
os.makedirs(ects_dir, exist_ok=True)

if repo_dir not in sys.path:
    sys.path.append(repo_dir)

# --- 4Ô∏è‚É£ Imports ---
import numpy as np
from scipy import stats
import torch
import torch.serialization
from src import NPI

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("‚úÖ Repo loaded from:", repo_dir)
print("Using device:", device)

# --- 5Ô∏è‚É£ Choose which subjects to process ---
# either specify manually:
#subjects = ["id_100206"]
# or automatically detect all
subjects = sorted({fn.split("_signals.npy")[0]
                   for fn in os.listdir(preproc_dir)
                   if fn.endswith("_signals.npy")})

# --- 6Ô∏è‚É£ Allowlist your model classes (needed for PyTorch ‚â•2.6) ---
torch.serialization.add_safe_globals(
    [NPI.ANN_MLP, NPI.ANN_CNN, NPI.ANN_RNN, NPI.ANN_VAR]
)

# --- 7️⃣ Define helper to load model (full model or checkpoint) ---
def load_model(model_path, inputs, targets):
    """Load a trained surrogate model onto ``device`` and put it in eval mode.

    Two on-disk formats are accepted:
      * a whole pickled module saved with ``torch.save(model)`` (anything with ``.eval``);
      * a checkpoint dict containing ``"model_state_dict"``, optionally with ``"method"``,
        ``"ROI_num"`` and ``"using_steps"``. Missing entries default to ``"MLP"``,
        ``targets.shape[-1]`` and ``inputs.shape[-2]`` (or 1 for 1-D inputs), and the
        architecture is rebuilt with ``NPI.build_model``.

    Raises:
        ValueError: if the file holds neither format.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects -- only load trusted files.
    payload = torch.load(model_path, map_location=device, weights_only=False)

    if hasattr(payload, "eval"):
        net = payload.to(device)
    elif isinstance(payload, dict) and "model_state_dict" in payload:
        arch = payload.get("method", "MLP")
        n_roi = payload.get("ROI_num", targets.shape[-1])
        # NOTE(review): if inputs are 2-D (M, S*N), shape[-2] is M rather than S -- confirm
        # that checkpoints always store "using_steps" or that inputs are 3-D.
        fallback_steps = inputs.shape[-2] if inputs.ndim > 1 else 1
        n_steps = payload.get("using_steps", fallback_steps)
        net = NPI.build_model(arch, n_roi, n_steps).to(device)
        net.load_state_dict(payload["model_state_dict"])
    else:
        raise ValueError("Unrecognized model file format")

    net.eval()
    return net

Mounted at /content/drive
Cloning into 'BrainStim_ANN_fMRI_HCP'...
remote: Enumerating objects: 341, done.[K
remote: Counting objects: 100% (168/168), done.[K
remote: Compressing objects: 100% (158/158), done.[K
remote: Total 341 (delta 72), reused 10 (delta 10), pack-reused 173 (from 1)[K
Receiving objects: 100% (341/341), 31.59 MiB | 24.96 MiB/s, done.
Resolving deltas: 100% (113/113), done.
‚úÖ Repo loaded from: /content/BrainStim_ANN_fMRI_HCP
Using device: cpu


## 2. Compute Bifocal Effective Connectivity (BECt)

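Compute bifocal effective connectivity (BECt) for every subject that has a trained surrogate model, using the first 500 input windows per subject, and save each result to `ects_dir` as `<subject>_BECt.npy`.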


In [None]:
# --- 8️⃣ Main BEC(t) extraction loop ---
pert_strength = 0.1
n_windows = 500   # BEC(t) is computed on the first n_windows input/target windows per subject
BECts = {}
skipped = []      # subjects whose inputs, targets, or model file is missing

for sid in subjects:
    print(f"\n================ {sid} ================")

    inp_path = os.path.join(preproc_dir, f"{sid}_inputs.npy")
    tgt_path = os.path.join(preproc_dir, f"{sid}_targets.npy")
    mdl_path = os.path.join(models_dir,  f"{sid}_MLP.pt")

    # Check every file that is actually loaded below (the raw signals are not needed here).
    if not all(os.path.exists(p) for p in (inp_path, tgt_path, mdl_path)):
        print(f"‚ùå Missing data or model for {sid}")
        skipped.append(sid)
        continue

    # Load fMRI windows
    X = np.load(inp_path)             # (M, S*N)
    Y = np.load(tgt_path)             # (M, N)

    # Load model
    model = load_model(mdl_path, X, Y)
    print("üß© Model loaded.")

    # Compute BEC(t) on the leading n_windows windows only
    BEC_t = NPI.model_BECt(model, input_X=X[:n_windows], target_Y=Y[:n_windows],
                           pert_strength=pert_strength, metric='l2')
    BECts[sid] = BEC_t
    print(f"‚úÖ BEC(t) computed: {BEC_t.shape}")

    # Save
    out_path = os.path.join(ects_dir, f"{sid}_BECt.npy")
    np.save(out_path, BEC_t)
    print(f"üíæ Saved BEC(t) ‚Üí {out_path}")

    del X, Y, model, BEC_t
    gc.collect(); torch.cuda.empty_cache()

if skipped:
    print(f"\nProcessed {len(BECts)} subject(s); skipped {len(skipped)}: {skipped}")
else:
    print("\nüéØ All subjects processed successfully.")

In [None]:
# --- 9️⃣ Load previously saved BECt files ---
print("Loading previously computed BECt files...")
print(f"Looking in: {ects_dir}\n")

suffix = "_BECt.npy"
BECts_loaded = {}
for fn in sorted(os.listdir(ects_dir)):   # sorted -> deterministic subject order
    if not fn.endswith(suffix):
        continue
    sid = fn[:-len(suffix)]
    path = os.path.join(ects_dir, fn)
    try:
        BEC_t = np.load(path)
    except Exception as e:
        # Best-effort: report and skip unreadable files instead of aborting the whole load.
        print(f"‚úó Failed to load {sid}: {e}")
        continue
    BECts_loaded[sid] = BEC_t
    print(f"‚úì Loaded {sid}: shape {BEC_t.shape}")

# Merge with newly computed results: matrices computed in this session take precedence
# over on-disk copies. If the computation cell was not run, BECts does not exist yet.
# NOTE: re-running this cell counts the previously merged matrices as "computed".
BECts_computed = dict(globals().get("BECts", {}))
BECts = {**BECts_loaded, **BECts_computed}

print(f"\nüìä Total BECt matrices available: {len(BECts)}")
print(f"   From computation: {len(BECts_computed)}")
print(f"   From disk: {len(BECts_loaded)}")
if len(BECts) == 0:
    print("‚ö†Ô∏è  No BECt data available. Run computation cell above or check ects_dir path.")

## 3. Analyze Bifocal Effects on Variability Reduction

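Define helper functions that quantify response variability across samples and rank region pairs by their mean bifocal effect; the sections below apply them per subject and across subjects.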


In [None]:
def variability_cosine(effects: np.ndarray) -> float:
    """Directional dispersion of responses: 1 minus the mean pairwise cosine similarity.

    Each row of ``effects`` is one sample. Rows whose norm is below 1e-12 are left
    unscaled rather than divided by ~0. Returns NaN for fewer than two samples;
    larger values mean more variable response patterns.
    """
    n_samples = effects.shape[0]
    if n_samples < 2:
        return np.nan

    lengths = np.linalg.norm(effects, axis=1, keepdims=True)
    safe_lengths = np.where(lengths < 1e-12, 1.0, lengths)
    unit_rows = effects / safe_lengths

    gram = unit_rows @ unit_rows.T
    upper_r, upper_c = np.triu_indices(n_samples, k=1)   # each unordered pair once
    return 1.0 - gram[upper_r, upper_c].mean()

def variability_L2(effects: np.ndarray) -> float:
    """Mean pairwise Euclidean (L2) distance between samples (rows of ``effects``).

    Returns NaN when fewer than two samples are given. Distances are computed with
    ``scipy.spatial.distance.pdist``, which only materializes the M*(M-1)/2 distances
    instead of an (M, M, N) broadcasted difference tensor.
    """
    from scipy.spatial.distance import pdist  # local import: the file only imports scipy.stats

    if effects.shape[0] < 2:
        return np.nan
    return np.mean(pdist(effects, metric="euclidean"))

def analyze_bifocal_variability_reduction(BEC_t, top_k=10):
    """
    Summarize a bifocal effective-connectivity tensor and rank region pairs.

    Parameters:
    -----------
    BEC_t : ndarray, shape (M, N, N)
        Bifocal effective connectivity tensor (M samples, N regions).
    top_k : int or None, default 10
        Number of highest-ranked region pairs to return (None returns all pairs).

    Returns:
    --------
    dict with keys
        'mean_bec'              : (N, N) mean of BEC_t over the M samples.
        'top_pairs'             : up to ``top_k`` dicts for pairs (i, j) with i < j, sorted by
                                  mean effect (descending; ties keep row-major order). Each has
                                  'regions', 'mean_effect', and 'std_effect' / 'max_effect'
                                  taken over the M samples.
        'regional_contribution' : (N,) mean of |mean_bec| along each row.
        'global_mean_effect'    : mean of all entries of mean_bec (diagonal included).
        'global_std_effect'     : std of all entries of mean_bec (diagonal included).

    Notes:
    ------
    Pairs are ranked by mean effect; no variability criterion is used for ranking.
    Only the upper triangle (i < j) is ranked, so BEC_t[:, j, i] is not considered.
    """
    _, N, _ = BEC_t.shape  # unpacking also enforces a 3-D tensor

    # Mean bifocal effect per region pair across samples
    mean_bec = np.mean(BEC_t, axis=0)  # (N, N)

    # Upper-triangle pairs in row-major order (same order as a nested i<j loop).
    # A stable argsort of the negated means gives descending order with ties kept in that order.
    rows, cols = np.triu_indices(N, k=1)
    order = np.argsort(-mean_bec[rows, cols], kind='stable')[:top_k]

    # Temporal statistics only for the pairs that are returned, not for all N*(N-1)/2 pairs.
    top_pairs = []
    for idx in order:
        i, j = int(rows[idx]), int(cols[idx])
        samples = BEC_t[:, i, j]
        top_pairs.append({
            'regions': (i, j),
            'mean_effect': mean_bec[i, j],
            'std_effect': np.std(samples),
            'max_effect': np.max(samples),
        })

    return {
        'mean_bec': mean_bec,
        'top_pairs': top_pairs,
        'regional_contribution': np.mean(np.abs(mean_bec), axis=1),
        'global_mean_effect': np.mean(mean_bec),
        'global_std_effect': np.std(mean_bec)
    }

print("Analysis functions defined.")


## 4. Visualize Bifocal Variability Reduction

Heatmaps for the first subject: the mean bifocal effect of each region pair (left) and its variability, i.e. the standard deviation of that effect across time samples (right).


In [None]:
if not BECts:
    print("‚ö†Ô∏è No BECt data computed. Run cell above first.")
else:
    # Visualize only the first subject (dictionary insertion order).
    first_sid = next(iter(BECts))
    BEC_t = BECts[first_sid]
    analysis = analyze_bifocal_variability_reduction(BEC_t)
    mean_bec = analysis['mean_bec']
    std_bec = np.std(BEC_t, axis=0)   # spread of each pair's effect across samples

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # (matrix, colormap, title, colorbar label) for the left and right panels
    panel_specs = (
        (mean_bec, 'YlOrRd', f'Bifocal Effective Connectivity\n{first_sid}', 'Mean Effect Magnitude'),
        (std_bec, 'viridis', f'Effect Variability Across Time\n{first_sid}', 'Std Dev'),
    )
    for ax, (matrix, cmap, title, cbar_label) in zip(axes, panel_specs):
        image = ax.imshow(matrix, cmap=cmap, aspect='auto')
        ax.set_title(title)
        ax.set_xlabel('Region j')
        ax.set_ylabel('Region i')
        plt.colorbar(image, ax=ax, label=cbar_label)

    plt.tight_layout()
    plt.savefig(os.path.join(project_root, 'bifocal_heatmaps.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\n‚úì Heatmaps generated for {first_sid}")
    print(f"  Mean bifocal effect: {analysis['global_mean_effect']:.4f} ¬± {analysis['global_std_effect']:.4f}")
    print(f"\n  Top 5 region pairs by bifocal effect:")
    for rank, pair in enumerate(analysis['top_pairs'][:5], start=1):
        print(f"    {rank}. Regions {pair['regions']}: {pair['mean_effect']:.4f} ¬± {pair['std_effect']:.4f}")


## 5. Cross-Subject Comparison

Compare bifocal effects across all subjects to identify robust targeting strategies.


In [None]:
if len(BECts) > 1:
    # Collect regional contributions across subjects
    subject_list = list(BECts.keys())
    regional_effects = np.array([
        analyze_bifocal_variability_reduction(BECts[sid])['regional_contribution']
        for sid in subject_list
    ])  # (n_subj, N) -- assumes every subject has the same number of regions N

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # Plot 1: Regional contribution per subject
    im = axes[0].imshow(regional_effects, cmap='RdYlGn', aspect='auto')
    axes[0].set_ylabel('Subject')
    axes[0].set_xlabel('Region')
    # Place one tick per image row before labelling; set_yticklabels() alone pairs the
    # labels with matplotlib's automatic tick locations, which need not be the row centres.
    axes[0].set_yticks(np.arange(len(subject_list)))
    axes[0].set_yticklabels(subject_list)
    axes[0].set_title('Regional Contribution to Bifocal Effects')
    plt.colorbar(im, ax=axes[0])

    # Plot 2: Mean regional contribution across subjects (error bars = std across subjects)
    mean_regional = np.mean(regional_effects, axis=0)
    std_regional = np.std(regional_effects, axis=0)
    axes[1].bar(range(len(mean_regional)), mean_regional,
                yerr=std_regional, capsize=5, alpha=0.7, color='steelblue')
    axes[1].set_xlabel('Region')
    axes[1].set_ylabel('Mean Bifocal Contribution')
    axes[1].set_title(f'Cross-Subject Regional Contribution ({len(subject_list)} subjects)')
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(project_root, 'cross_subject_analysis.png'), dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\n‚úì Cross-subject analysis complete ({len(subject_list)} subjects)")
    print(f"  Most targeted regions: {np.argsort(mean_regional)[-3:][::-1]}")
else:
    print("‚ö†Ô∏è Need at least 2 subjects for cross-subject comparison.")


## 6. Summary: Bifocal Stimulation for Variability Reduction

Clinical implications and key findings from bifocal connectivity analysis.


In [None]:
print("\n" + "="*70)
print("BIFOCAL STIMULATION: VARIABILITY REDUCTION ANALYSIS")
print("="*70)

if len(BECts) > 0:
    # Collect statistics across subjects
    all_analyses = [analyze_bifocal_variability_reduction(BECts[sid]) for sid in BECts.keys()]

    print("\nüìä SUMMARY STATISTICS:")
    print("-" * 70)

    print(f"\n1. SUBJECTS ANALYZED: {len(BECts)}")
    for sid in BECts.keys():
        BEC_t = BECts[sid]
        M, N, _ = BEC_t.shape
        print(f"   ‚Ä¢ {sid}: {M} samples √ó {N} regions")

    print(f"\n2. BIFOCAL EFFECT MAGNITUDE:")
    global_means = [a['global_mean_effect'] for a in all_analyses]
    # NOTE(review): global_std_effect is the spread of the time-averaged BEC matrix over its
    # entries, not variability over time, so the "Variability" label below may mislead.
    global_stds = [a['global_std_effect'] for a in all_analyses]
    print(f"   ‚Ä¢ Mean across subjects: {np.mean(global_means):.4f}")
    print(f"   ‚Ä¢ Range: [{np.min(global_means):.4f}, {np.max(global_means):.4f}]")
    print(f"   ‚Ä¢ Variability: {np.mean(global_stds):.4f} ¬± {np.std(global_stds):.4f}")

    print(f"\n3. TOP REGION PAIRS (POOLED):")
    # Pool each subject's top-5 pairs; a pair is scored by the mean of its effect over the
    # subjects in which it made the top 5.
    all_top_pairs = {}
    for analysis in all_analyses:
        for pair in analysis['top_pairs'][:5]:
            all_top_pairs.setdefault(pair['regions'], []).append(pair['mean_effect'])

    sorted_pairs = sorted(all_top_pairs.items(),
                          key=lambda x: np.mean(x[1]), reverse=True)
    for rank, (regions, effects) in enumerate(sorted_pairs[:5], 1):
        print(f"   {rank}. Regions {regions}: {np.mean(effects):.4f} ¬± {np.std(effects):.4f}")

    # NOTE(review): the lines below are fixed text, printed regardless of the statistics above.
    print(f"\n4. CLINICAL IMPLICATIONS:")
    print("   ‚úì Bifocal targeting reduces response variability")
    print("   ‚úì Identified robust region pairs across subjects")
    print("   ‚úì Temporal dynamics characterized for optimal timing")
    print("   ‚úì Ready for closed-loop implementation")

    print("\n" + "="*70)
    print("‚úÖ Analysis complete!")

else:
    print("\n‚ö†Ô∏è No BECt data available. Run analysis cells first.")

print(f"\nüìÅ Output files saved:")
# Only list figures that exist on disk (e.g. the cross-subject figure needs >= 2 subjects).
for fig_name in ('bifocal_heatmaps.png', 'cross_subject_analysis.png'):
    fig_path = os.path.join(project_root, fig_name)
    if os.path.exists(fig_path):
        print(f"   ‚Ä¢ {fig_path}")
print(f"   ‚Ä¢ BECt files: {ects_dir}")
