
### 🧠 SAINT Protocol – fMRI-Guided DLPFC Targeting Pipeline

This pipeline implements the **SAINT (Stanford Accelerated Intelligent Neuromodulation Therapy)** protocol for identifying an optimal left dorsolateral prefrontal cortex (DLPFC) target for transcranial magnetic stimulation (TMS), based on each subject's resting-state fMRI connectivity with the subgenual anterior cingulate cortex (sgACC).

The workflow consists of **three sequential steps**, preceded by a one-time environment setup (Step 0). It is designed to be **modular, reproducible, and method-agnostic**: each step supports multiple correlation metrics.

---

### 🔢 Pipeline Overview

| Step | Name | Purpose |
|------|------|---------|
| **🧱 Step 0** | Environment Setup | Installs missing dependencies (`nibabel`, `plotly`, `kaleido`, etc.) |
| **🔬 Step 1** | Functional Parcellation | Parcellates the left DLPFC and bilateral sgACC into functional subunits using hierarchical agglomerative clustering (HAC) |
| **🔍 Step 2** | Optimal Subunit Selection | Ranks DLPFC subunits by anticorrelation with the sgACC, spatial concentration, and size |
| **🖼️ Step 3** | 3D Visualization | Renders the top-ranked DLPFC subunit as a smooth mesh with anatomical context |

> ⚠️ **Execution Order Matters**:  
> Run Step 0 once, then **Step 1 → Step 2 → Step 3** in sequence. Each step depends on the outputs of the previous one.

---

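### 📂 Required Directory Structure

Your data **must** be organized with one folder per subject directly inside `base_dir` (for example `G:/SAINT/Subjects/sub001/`). Each subject folder holds the preprocessed fMRI file and the ROI masks listed below; Step 1 writes its outputs to a `SAINT-PROTOCOL` subfolder inside each subject folder.

#### 🔑 Naming Rules: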
- Subject folders **must** be named `sub1`, `sub2`, `sub005`, etc. (prefix `sub` + number).
- **Mandatory fMRI file**: `filtered_func_data.nii.gz` in each subject folder.
- **ROI masks**: At least one of the following for each ROI:
  - `l_DLPFC_bin.nii.gz` or `l_DLPFC_func.nii.gz`
  - `sgACC_bin.nii.gz` or `sgACC_func.nii.gz`

> 💡 **Tip**: The pipeline automatically prefers `_bin.nii.gz` over `_func.nii.gz` if both exist.
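Before running Step 1 it can help to verify the layout. The sketch below is a hypothetical pre-flight check (`check_layout` and `pick_mask` are illustrative names, not pipeline functions) that applies the naming rules above: it keeps folders matching `sub` + digits, requires `filtered_func_data.nii.gz`, and resolves each mask with the `_bin` before `_func` preference.

```python
import os
import re

# Hypothetical pre-flight check (not part of the pipeline) for the layout above.
SUBJECT_RE = re.compile(r"^sub\d+$")  # same pattern the batch mode uses


def pick_mask(subject_dir, base_name):
    """Return <base_name>_bin.nii.gz if present, else <base_name>_func.nii.gz, else None."""
    for suffix in ("_bin", "_func"):
        path = os.path.join(subject_dir, f"{base_name}{suffix}.nii.gz")
        if os.path.exists(path):
            return path
    return None


def check_layout(base_dir):
    """Map each subject folder name to a list of problems (empty list = ready)."""
    report = {}
    for name in sorted(os.listdir(base_dir)):
        subject_dir = os.path.join(base_dir, name)
        if not (SUBJECT_RE.match(name) and os.path.isdir(subject_dir)):
            continue  # ignore files and folders that do not follow the naming rule
        problems = []
        if not os.path.exists(os.path.join(subject_dir, "filtered_func_data.nii.gz")):
            problems.append("missing filtered_func_data.nii.gz")
        for roi in ("l_DLPFC", "sgACC"):
            if pick_mask(subject_dir, roi) is None:
                problems.append(f"missing {roi}_bin.nii.gz or {roi}_func.nii.gz")
        report[name] = problems
    return report
```

Calling `check_layout(base_dir)` returns one entry per subject folder; an empty list means that subject is ready for Step 1.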

---

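### ⚙️ How to Configure

1. **Set your base directory** at the top of **Step 1** (`Algorithm 1`), and make sure Step 2 uses the same path:
   ```python
   base_dir = "G:/SAINT/Subjects"  # 👈 CHANGE THIS TO YOUR PATH
   ```
2. **Choose correlation method(s) and subjects** in the `__main__` block of Steps 1 and 2: with `run_all_methods = True`, all four methods (`spearman`, `pearson`, `kendall`, `crosscorr`) are run; otherwise only `corr_method` is used. Set `subject_id` (e.g. `"001"`) to process one subject, or leave it as `None` to process every `sub*` folder.

### ✅ Output Files (per subject)
After Steps 1 and 2, each subject's `SAINT-PROTOCOL` folder contains one pair of files per correlation method:
- `subXXX_SAINT_Protocol_Output_{method}.xlsx`: `Summary`, `DLPFC_Clusters`, and `sgACC_Clusters` sheets from Step 1, plus the `Optimal_DLPFC` sheet added by Step 2
- `subXXX_cluster_voxel_coords_{method}.npz`: full voxel coordinates and labels for every cluster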

### 🧱 Step 0: Environment Setup (Dependency Check and Installation)

Before running any part of the SAINT pipeline, ensure all required Python packages are installed. This step:

- **Checks** which packages are already present on your system.
- **Installs only the missing ones** (does not upgrade existing packages).
- Uses your current Python environment (respects virtual environments).

> ✅ This is a **safe, non-destructive** setup step. It will not modify already-installed packages.

#### 📦 Packages Managed:
- `nibabel` – neuroimaging data I/O  
- `numpy` – core numerical computing  
- `pandas` – data structures and Excel handling  
- `scipy` – correlation statistics (Spearman, Pearson, Kendall) and hierarchical clustering  
- `plotly` – 3D interactive visualization  
- `scikit-image` – `marching_cubes` for mesh generation  
- `kaleido` – required for saving Plotly figures as PNG/SVG  

Steps 1 and 2 read and write `.xlsx` files through pandas with the `openpyxl` engine, so `openpyxl` must be installed as well.

#### ▶️ How to Use:
Run this cell **once** at the beginning of your workflow. If all packages are present, it only prints a confirmation; if any are missing, it prints their names and installs them via `pip`.

> ⚠️ **Note**: Write access to your Python environment is required. If you encounter permission errors, run with appropriate privileges or in a virtual environment.
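A minimal alternative sketch, assuming Python 3.8+ (`importlib.metadata`): querying installed distributions by their pip names avoids the import-name mapping (`scikit-image` vs. `skimage`) used in the cell below. `find_missing` and `REQUIRED` are illustrative names; `REQUIRED` includes `openpyxl` because Steps 1 and 2 write Excel files with `engine='openpyxl'`.

```python
from importlib import metadata

# Sketch only: looks up pip *distribution* names, so no module-name mapping is needed.
REQUIRED = ["nibabel", "numpy", "pandas", "scipy", "plotly",
            "scikit-image", "kaleido", "openpyxl"]


def find_missing(dist_names):
    """Return (missing, found), where found maps pip name -> installed version."""
    missing, found = [], {}
    for name in dist_names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing, found
```

The `missing` list could then be handed to `[sys.executable, "-m", "pip", "install", *missing]`, the same pip invocation the cell below uses.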

In [1]:
import importlib
import subprocess
import sys
def get_installed_versions():
    """
    Returns a list of package names.
    If version is available, returns 'package==version'.
    If not installed or no version, returns just 'package'.
    """
    packages = {
        "nibabel": "nibabel",
        "numpy": "numpy",
        "pandas": "pandas",
        "scipy": "scipy",
        "plotly": "plotly",
        "scikit-image": "skimage",
        "kaleido": "kaleido",
        "openpyxl": "openpyxl",  # Excel engine pandas uses for the .xlsx outputs
    }

    result = []
    for pip_name, module_name in packages.items():
        try:
            mod = importlib.import_module(module_name)
            version = getattr(mod, "__version__", None)
            if version is not None:
                result.append(f"{pip_name}=={version}")
            else:
                result.append(pip_name)
        except ImportError:
            result.append(pip_name)  # not installed → show name only
    return result


def install_missing_packages(requirements_list):
    """
    Install packages from a list of strings like ['numpy==1.26', 'pandas==2.2'].
    Only installs if the package is NOT already installed (any version).
    
    Parameters:
    requirements_list (list of str): e.g. ['nibabel==5.2.1', 'kaleido==0.2.1']
    """
    to_install = []
    
    for req in requirements_list:
        # Extract the package name, e.g. 'nibabel==5.2.1' → 'nibabel'
        package_name = req.split("==")[0].split(">=")[0].split("<=")[0].split("!=")[0]
        
        # Corresponding import name (scikit-image → skimage)
        module_name = "skimage" if package_name == "scikit-image" else package_name
        
        try:
            importlib.import_module(module_name)
        except ImportError:
            to_install.append(req)
    
    if to_install:
        print(f"Installing missing packages: {to_install}")
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install"] + to_install
        )
    else:
        print("All packages already installed.")

install_missing_packages(get_installed_versions())  

All packages already installed.


### 🔬 Step 1: Functional Parcellation and Representative Time Series Extraction

In this step, the two regions of interest, the **left DLPFC** and the **bilateral sgACC**, are functionally parcellated **independently** using **hierarchical agglomerative clustering based on inter-voxel time series similarity**.

#### 📌 Key Stages:

**1. Time Series Extraction:**  
Voxel-wise BOLD time series are extracted from the preprocessed fMRI data within each ROI mask.

**2. Pairwise Similarity Computation:**  
A correlation matrix is computed across all voxels in each ROI using one of the following methods:
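- **Spearman rank correlation** (default when a single method is selected; robust to outliers)  
- **Pearson correlation** (linear dependence)  
- **Kendall's tau** (rank-based and robust, but computed pair by pair, so it becomes very slow for large ROIs)  
- **Zero-lag cross-correlation** (Pearson on z-scored time series; numerically identical to Pearson, because Pearson correlation is unchanged by z-scoring)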
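For the methods that reduce to Pearson correlation, the V × V matrix can be built in one vectorized call. This sketch (the name `corr_matrix` is illustrative, and it assumes SciPy ≥ 1.4 for `rankdata(..., axis=...)`) ranks each voxel's time series for Spearman and z-scores it for zero-lag cross-correlation before calling `np.corrcoef`. Kendall's tau has no such shortcut, which is why the pipeline uses a pairwise loop for it.

```python
import numpy as np
from scipy.stats import rankdata


def corr_matrix(ts, method="spearman"):
    """V x V correlation matrix for a (V voxels, T timepoints) array.

    Spearman = Pearson on within-series ranks (ties get average ranks);
    crosscorr = Pearson on z-scored series, i.e. the same values as Pearson.
    Constant series produce NaN rows, as with np.corrcoef itself.
    """
    ts = np.asarray(ts, dtype=float)
    if method == "spearman":
        ts = rankdata(ts, axis=1)
    elif method == "crosscorr":
        ts = (ts - ts.mean(axis=1, keepdims=True)) / ts.std(axis=1, keepdims=True)
    elif method != "pearson":
        raise ValueError(f"unsupported method for this sketch: {method}")
    return np.corrcoef(ts)
```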

**3. Hierarchical Agglomerative Clustering (HAC):**  
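- Distance is defined as `distance = 1 - ρ` (where ρ is the correlation coefficient).  
- Clustering uses **average linkage**.  
- The dendrogram is cut at a **distance threshold of 0.5** (`rho_threshold = 0.5`), so every merge inside a subunit joined two groups whose **mean between-group correlation was ρ ≥ 0.5**. Because the linkage is average rather than complete, individual voxel pairs within a subunit can still have ρ < 0.5.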
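The clustering stage can be sketched with SciPy as follows (`hac_subunits` is an illustrative name; the pipeline's own version is `hac_cut_by_rho`). It converts correlations to distances, builds an average-linkage tree, and cuts it at `1 - rho_threshold`.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform


def hac_subunits(rho, rho_threshold=0.5):
    """Cluster voxels from a V x V correlation matrix (average linkage on 1 - rho)."""
    dist = 1.0 - np.clip(rho, -1.0, 1.0)
    np.fill_diagonal(dist, 0.0)
    dist = (dist + dist.T) / 2.0  # enforce exact symmetry before condensing
    Z = linkage(squareform(dist, checks=False), method="average")
    # Every merge inside a returned cluster happened at distance <= 1 - rho_threshold
    return fcluster(Z, t=1.0 - rho_threshold, criterion="distance")
```

For example, stacking four noisy copies of one random signal and three of another, then calling `hac_subunits(np.corrcoef(ts))`, should assign each group its own label.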

**4. Representative Voxel Selection:**  
For each functional subunit:  
- The **median time series** of all voxels in the cluster is computed.  
- The **representative voxel** is selected as the one whose time series has the **highest correlation** (using the same method as step 2) with this median.
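A sketch of this selection rule (`representative_voxel` is an illustrative name; it uses Pearson correlation for brevity, whereas the pipeline reuses the method chosen in stage 2):

```python
import numpy as np


def representative_voxel(cluster_ts):
    """Row index of the voxel whose series correlates best with the cluster median."""
    cluster_ts = np.asarray(cluster_ts, dtype=float)
    if cluster_ts.shape[0] == 1:
        return 0  # a single-voxel cluster is its own representative
    median_ts = np.median(cluster_ts, axis=0)
    # Row 0 of the stacked array is the median; correlate it with every voxel
    corr = np.corrcoef(np.vstack([median_ts, cluster_ts]))[0, 1:]
    return int(np.nanargmax(corr))
```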

**5. Output Generation:**  
- A **summary Excel file** (`subXXX_SAINT_Protocol_Output_{method}.xlsx`) containing:  
  - Cluster IDs, sizes, representative voxel coordinates, and full time series.  
- A **compressed `.npz` file** (`subXXX_cluster_voxel_coords_{method}.npz`) storing:  
  - Full voxel coordinates for every cluster.  
  - Cluster labels and the correlation method used.
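Because clusters contain different numbers of voxels, the coordinates are stored as a 1-D object array of `(N_i, 3)` arrays. The sketch below (helper names are illustrative; the key names match those written by Algorithm 1) builds that array element by element so NumPy never stacks equal-sized clusters into one 3-D array, and reads it back. Loading requires `allow_pickle=True` because object arrays are pickled inside the `.npz`.

```python
import numpy as np


def pack_ragged(arrays):
    """1-D object array holding one (N_i x 3) coordinate array per cluster."""
    out = np.empty(len(arrays), dtype=object)
    for i, arr in enumerate(arrays):
        out[i] = np.asarray(arr, dtype=int)  # stored as an object, never broadcast
    return out


def save_clusters(path, coords_per_cluster, labels, method):
    np.savez_compressed(
        path,
        dlpfc_voxel_coords=pack_ragged(coords_per_cluster),
        dlpfc_cluster_labels=np.asarray(labels, dtype=int),
        correlation_method=method,
    )


def load_clusters(path):
    """Return ({cluster_label: (N, 3) coords}, correlation_method)."""
    with np.load(path, allow_pickle=True) as data:
        coords = {int(lab): arr for lab, arr in
                  zip(data["dlpfc_cluster_labels"], data["dlpfc_voxel_coords"])}
        return coords, str(data["correlation_method"])
```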


In [None]:
# Algorithm 1: correlation-based HAC (average linkage, cut at rho >= 0.5), representative voxel selection
import nibabel as nib
import numpy as np
import pandas as pd
import os
import re
from scipy.stats import spearmanr, pearsonr, kendalltau
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

# -------------------------
# CONFIGURATION
# -------------------------
rho_threshold = 0.5
base_dir = "G:/SAINT/Subjects"
# -------------------------
# Helper: find mask file (prefer *_bin, fallback to *_func)
# -------------------------
def find_mask_file(subject_dir, base_name):
    """Try bin first, then func."""
    bin_path = os.path.join(subject_dir, f"{base_name}_bin.nii.gz")
    func_path = os.path.join(subject_dir, f"{base_name}_func.nii.gz")
    if os.path.exists(func_path):
        return func_path
    elif os.path.exists(bin_path):
        return bin_path
    else:
        return None

# -------------------------
# Correlation matrix computation with multiple methods
# -------------------------
def compute_correlation_matrix(voxel_ts, method="spearman"):
    """
    Compute correlation matrix using specified method.
    Supported methods: 'spearman', 'pearson', 'kendall', 'crosscorr'
    """
    if voxel_ts.size == 0:
        return np.array([[]])
    
    V, T = voxel_ts.shape
    
    # Handle single-voxel case: correlation matrix is [[1.0]]
    if V == 1:
        return np.array([[1.0]])
    
    if method == "spearman":
        rho, _ = spearmanr(voxel_ts, axis=1)
    elif method == "pearson":
        # ‚úÖ Use np.corrcoef ‚Äî pearsonr does NOT support axis!
        rho = np.corrcoef(voxel_ts)
    elif method == "crosscorr":
        # Zero-lag cross-correlation = Pearson on z-scored time series
        ts_norm = (voxel_ts - voxel_ts.mean(axis=1, keepdims=True)) / (voxel_ts.std(axis=1, keepdims=True) + 1e-10)
        rho = np.corrcoef(ts_norm)
    elif method == "kendall":
        # Warning for large ROIs
        if V > 300:
            print(f"‚ö†Ô∏è Warning: Kendall‚Äôs tau is very slow for large ROIs (V={V}). Consider using spearman/pearson.")
        rho = np.eye(V)
        for i in range(V):
            for j in range(i + 1, V):
                tau, _ = kendalltau(voxel_ts[i], voxel_ts[j])
                rho[i, j] = rho[j, i] = tau if not np.isnan(tau) else 0.0
    else:
        raise ValueError(f"Unsupported correlation method: {method}. Choose from: spearman, pearson, kendall, crosscorr")
    
    rho = np.nan_to_num(rho, nan=0.0)
    return np.clip(rho, -1.0, 1.0)

# -------------------------
# Main processing function for ONE subject
# -------------------------
def process_subject(subject_id, base_dir="G:/SAINT/Subjects", corr_method="spearman"):
    subject_dir = os.path.join(base_dir, f"sub{subject_id}")
    fmri_path = os.path.join(subject_dir, "filtered_func_data.nii.gz")
    
    # Validate fMRI file
    if not os.path.exists(fmri_path):
        print(f"‚ö†Ô∏è Skipping sub{subject_id}: fMRI file not found.")
        return False

    # Find mask files
    dlpfc_mask_path = find_mask_file(subject_dir, "l_DLPFC")
    sgacc_mask_path = find_mask_file(subject_dir, "sgACC")
    
    if dlpfc_mask_path is None:
        print(f"‚ö†Ô∏è Skipping sub{subject_id}: DLPFC mask not found (needs l_DLPFC_bin.nii.gz or l_DLPFC_func.nii.gz).")
        return False
    if sgacc_mask_path is None:
        print(f"‚ö†Ô∏è Skipping sub{subject_id}: sgACC mask not found (needs sgACC_bin.nii.gz or sgACC_func.nii.gz).")
        return False

    print(f"\nüîç Processing sub{subject_id} with correlation method: {corr_method}...")
    
    # Load data
    fmri_img = nib.load(fmri_path)
    fmri_data = fmri_img.get_fdata()
    affine = fmri_img.affine

    dlpfc_mask = nib.load(dlpfc_mask_path).get_fdata().astype(bool)
    sgacc_mask = nib.load(sgacc_mask_path).get_fdata().astype(bool)
    brain_mask = np.any(fmri_data != 0, axis=3)

    print("Voxels: brain={}, DLPFC={}, sgACC={}".format(
        np.sum(brain_mask), np.sum(dlpfc_mask), np.sum(sgacc_mask)
    ))

    # Extract time series
    def extract_mask_timeseries(fmri_data, mask):
        inds = np.argwhere(mask)
        # Boolean indexing visits voxels in the same C order as np.argwhere -> shape (V, T)
        ts = fmri_data[mask]
        return inds, ts

    dlpfc_inds, dlpfc_ts = extract_mask_timeseries(fmri_data, dlpfc_mask)
    sgacc_inds, sgacc_ts = extract_mask_timeseries(fmri_data, sgacc_mask)

    print("DLPFC time series shape (V, T):", dlpfc_ts.shape)
    print("sgACC time series shape (V, T):", sgacc_ts.shape)

    # Voxel-by-voxel correlation matrix with the selected method
    rho_dlpfc = compute_correlation_matrix(dlpfc_ts, method=corr_method)
    rho_sgacc = compute_correlation_matrix(sgacc_ts, method=corr_method)

    # Distance matrix
    dist_dlpfc = 1.0 - rho_dlpfc
    dist_sgacc = 1.0 - rho_sgacc
    np.fill_diagonal(dist_dlpfc, 0.0)
    np.fill_diagonal(dist_sgacc, 0.0)

    # HAC clustering
    def hac_cut_by_rho(dist_matrix, rho_threshold=0.5, method='average'):
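        if dist_matrix.size == 0:
            return np.array([], dtype=int)
        if dist_matrix.shape[0] == 1:
            # linkage() needs at least two observations; a single voxel forms one cluster
            return np.array([1], dtype=int)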
        condensed = squareform(dist_matrix, checks=False)
        Z = linkage(condensed, method=method)
        max_dist = 1.0 - rho_threshold
        return fcluster(Z, t=max_dist, criterion='distance').astype(int)

    dlpfc_labels = hac_cut_by_rho(dist_dlpfc, rho_threshold=rho_threshold)
    sgacc_labels = hac_cut_by_rho(dist_sgacc, rho_threshold=rho_threshold)

    print("Number of DLPFC clusters:", len(np.unique(dlpfc_labels)))
    print("Number of sgACC clusters :", len(np.unique(sgacc_labels)))

    # Representative selection + keep all voxel coords
    def representatives_from_clusters(inds, voxel_ts, labels):
        reps = []
        for lab in np.unique(labels):
            mask = (labels == lab)
            cluster_ts = voxel_ts[mask]
            cluster_coords = inds[mask]  # all voxels in this cluster
            if cluster_ts.shape[0] == 0:
                continue
            median_ts = np.median(cluster_ts, axis=0)
            if cluster_ts.shape[0] == 1:
                chosen_voxel = cluster_coords[0]
                chosen_ts = cluster_ts[0]
            else:
                stack = np.vstack([median_ts, cluster_ts])
                # Use the same correlation method for representative selection
                corr_mat = compute_correlation_matrix(stack, method=corr_method)
                rho_with_median = corr_mat[0, 1:]
                chosen_idx = int(np.argmax(rho_with_median))
                chosen_voxel = cluster_coords[chosen_idx]
                chosen_ts = cluster_ts[chosen_idx]
            reps.append({
                "cluster_label": int(lab),
                "voxel_coords": cluster_coords,          # ‚Üê FULL list of voxel coordinates (Nx3)
                "rep_coords": tuple(map(int, chosen_voxel)),
                "timeseries": chosen_ts,
                "size": cluster_ts.shape[0]
            })
        return reps

    dlpfc_reps = representatives_from_clusters(dlpfc_inds, dlpfc_ts, dlpfc_labels)
    sgacc_reps = representatives_from_clusters(sgacc_inds, sgacc_ts, sgacc_labels)

    print("DLPFC representatives:", len(dlpfc_reps))
    print("sgACC representatives :", len(sgacc_reps))

    # -------------------------
    # SAVE 1: Excel summary (filename includes corr_method)
    # -------------------------
    output_dir = os.path.join(subject_dir, "SAINT-PROTOCOL")
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"sub{subject_id}_SAINT_Protocol_Output_{corr_method}.xlsx")

    def build_cluster_df(reps):
        if not reps:
            return pd.DataFrame()
        df = pd.DataFrame([{
            "ClusterID": r["cluster_label"],
            "Size": r["size"],
            "RepVoxel_X": r["rep_coords"][0],
            "RepVoxel_Y": r["rep_coords"][1],
            "RepVoxel_Z": r["rep_coords"][2]
        } for r in reps])
        ts_array = np.array([r["timeseries"] for r in reps])
        ts_df = pd.DataFrame(ts_array, columns=[f"T{i}" for i in range(ts_array.shape[1])])
        return pd.concat([df, ts_df], axis=1)

    dlpfc_df = build_cluster_df(dlpfc_reps)
    sgacc_df = build_cluster_df(sgacc_reps)

    summary_df = pd.DataFrame({
        "Region": ["DLPFC", "sgACC"],
        "N_Clusters": [len(dlpfc_reps), len(sgacc_reps)],
        "Correlation_Method": [corr_method, corr_method]
    })

    with pd.ExcelWriter(output_path, engine='openpyxl') as writer:
        summary_df.to_excel(writer, sheet_name='Summary', index=False)
        if not dlpfc_df.empty:
            dlpfc_df.to_excel(writer, sheet_name='DLPFC_Clusters', index=False)
        if not sgacc_df.empty:
            sgacc_df.to_excel(writer, sheet_name='sgACC_Clusters', index=False)

    # -------------------------
    # SAVE 2: Full voxel coordinates in .npz file (filename includes corr_method)
    # -------------------------
    voxel_coords_path = os.path.join(output_dir, f"sub{subject_id}_cluster_voxel_coords_{corr_method}.npz")
    
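    # Store each cluster's (N_i x 3) coordinates in a 1-D object array. Filling it
    # element by element stops NumPy from stacking equal-sized clusters into one
    # 3-D array (or failing on ragged input).
    def to_object_array(arrays):
        out = np.empty(len(arrays), dtype=object)
        for i, arr in enumerate(arrays):
            out[i] = arr
        return out

    dlpfc_voxel_arrays = to_object_array([r["voxel_coords"] for r in dlpfc_reps])
    sgacc_voxel_arrays = to_object_array([r["voxel_coords"] for r in sgacc_reps])

    # Object arrays are pickled inside the .npz; read them back with np.load(..., allow_pickle=True)
    np.savez_compressed(
        voxel_coords_path,
        dlpfc_voxel_coords=dlpfc_voxel_arrays,
        sgacc_voxel_coords=sgacc_voxel_arrays,
        dlpfc_cluster_labels=np.array([r["cluster_label"] for r in dlpfc_reps]),
        sgacc_cluster_labels=np.array([r["cluster_label"] for r in sgacc_reps]),
        correlation_method=corr_method,
    )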

    print(f"\n‚úÖ Results ({corr_method}) saved to:\n{output_path}")
    print(f"üì¶ Full voxel coordinates ({corr_method}) saved to:\n{voxel_coords_path}")
    print(f"- DLPFC clusters: {len(dlpfc_reps)}")
    print(f"- sgACC clusters: {len(sgacc_reps)}")
    return True

# -------------------------
# Main execution logic
# -------------------------
if __name__ == "__main__":
    # ⚙️ CONFIGURATION
    
    # Choose mode:
    run_all_methods = True  # Set to False to run only one method
    corr_method = "spearman"  # Used only if run_all_methods = False

    # List of all supported correlation methods
    all_methods = ["spearman", "pearson", "kendall", "crosscorr"]

    # Decide which methods to run
    methods_to_run = all_methods if run_all_methods else [corr_method]

    # Single subject or batch?
    subject_id = None  # Set to e.g., "001" for single subject; None for batch

    if subject_id is not None:
        # Single subject: run all selected methods
        for method in methods_to_run:
            print(f"\nüöÄ Running Algorithm 1 for sub{subject_id} with method: {method}")
            success = process_subject(subject_id, base_dir, corr_method=method)
            if success:
                print(f"\nüéâ Subject sub{subject_id} processed successfully with {method}!")
            else:
                print(f"\n‚ùå Failed to process sub{subject_id} with {method}.")
    else:
        # Batch mode: for each subject, run all selected methods
        subject_folders = [f for f in os.listdir(base_dir) if re.match(r'^sub\d+$', f)]
        subject_ids = sorted([f.replace('sub', '') for f in subject_folders], key=int)
        total_subjects = len(subject_ids)

        print(f"Found {total_subjects} subjects. Starting batch processing...\n")
        
        for sid in subject_ids:
            print(f"\n{'='*60}")
            print(f"Processing subject: sub{sid}")
            print(f"{'='*60}")
            
            completed_methods = 0
            for method in methods_to_run:
                try:
                    print(f"\n‚Üí Running method: {method}")
                    success = process_subject(sid, base_dir, corr_method=method)
                    if success:
                        completed_methods += 1
                        print(f"‚úÖ Completed {method} for sub{sid}")
                    else:
                        print(f"‚ö†Ô∏è Skipped {method} for sub{sid} (missing files)")
                except Exception as e:
                    print(f"‚ùå Error in {method} for sub{sid}: {e}")
            
            print(f"\nüìå Summary for sub{sid}: {completed_methods}/{len(methods_to_run)} methods completed.")

        print(f"\nüéâ Batch processing complete for {total_subjects} subjects!")
        print(f"Methods run: {', '.join(methods_to_run)}")

### 🔍 Step 2: Optimal DLPFC Subunit Selection

Building on the functional parcellation from **Algorithm 1**, this step identifies the **single optimal DLPFC subunit** for targeting, based on three biologically and clinically motivated criteria.

#### ⚠️ Prerequisite:
> **Algorithm 1 must be successfully executed first** for the same subject(s) and with the **same correlation method** (e.g., `spearman`, `pearson`, `kendall`, or `crosscorr`).  
> This step **depends entirely** on two output files generated by Algorithm 1:
> - `subXXX_SAINT_Protocol_Output_{method}.xlsx` (containing DLPFC/sgACC cluster time series)
> - `subXXX_cluster_voxel_coords_{method}.npz` (containing full voxel coordinates for spatial concentration)

If either file is missing, Algorithm 2 **skips** the subject. If the `.npz` file records a different correlation method than the one requested, Algorithm 2 prints a warning and continues.

#### üìå Key Components:

**1. Net Functional Coupling with sgACC:**  
- For each DLPFC subunit, compute the **size-weighted net correlation** between its representative time series and those of all sgACC subunits:  
  `NetCorrelation = Σ_j (ρ(DLPFC, sgACC_j) × size(sgACC_j))`  
- **Stronger anticorrelation (a more negative net correlation) is preferred**, as it aligns with the therapeutic hypothesis of the SAINT protocol.
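As a worked example: a DLPFC subunit whose representative time series has ρ = -1 with a 30-voxel sgACC subunit and ρ = +1 with a 10-voxel one gets NetCorrelation = (-1)(30) + (1)(10) = -20. The sketch below computes the same sum (`net_correlation` is an illustrative name, and it uses Pearson correlation for brevity instead of the configurable method).

```python
import numpy as np


def net_correlation(dlpfc_ts, sgacc_ts, sgacc_sizes):
    """Sum over sgACC subunits j of rho(DLPFC, sgACC_j) * size_j.

    dlpfc_ts: (T,) representative series; sgacc_ts: (K, T); sgacc_sizes: (K,).
    """
    rho = np.array([np.corrcoef(dlpfc_ts, s)[0, 1] for s in sgacc_ts])
    return float(np.sum(np.nan_to_num(rho) * np.asarray(sgacc_sizes, dtype=float)))
```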

**2. Spatial Concentration:**  
- Measures how compact a subunit is in 3D space:  
  `SpatialConcentration = (Number of voxels) / (Mean pairwise Euclidean distance)`  
- **Higher values indicate tighter, more focal clusters**, which are preferable for precise neuromodulation.
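A sketch of the metric (mirroring the `spatial_concentration` helper in Algorithm 2) on two illustrative 8-voxel shapes: a compact 2×2×2 block scores about 6.24, while a straight line of 8 voxels scores 8/3 ≈ 2.67. Coordinates are voxel indices, as stored by Algorithm 1, so distances are in voxel units rather than millimetres; the metric also grows with voxel count, so it is not fully independent of cluster size.

```python
from itertools import product

import numpy as np
from scipy.spatial.distance import pdist


def spatial_concentration(coords):
    """N_voxels / mean pairwise Euclidean distance (in voxel-index units)."""
    coords = np.asarray(coords, dtype=float)
    if coords.shape[0] < 2:
        return 0.0
    return coords.shape[0] / pdist(coords).mean()


cube = np.array(list(product([0, 1], repeat=3)))  # compact 2x2x2 block (8 voxels)
line = np.array([[i, 0, 0] for i in range(8)])    # same size, stretched out
```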

**3. Cluster Size:**  
- Larger subunits are easier to target reliably with TMS coils.  
- Size is included as a direct feature in the scoring.

**4. Composite Scoring & Selection:**  
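- Each metric is **min-max normalized** to [0, 1] when `USE_NORMALIZATION = True` (the default). The anticorrelation score is computed from the negated `NetCorrelation`, so the most negative net correlation receives the highest score. With normalization disabled, the raw metrics, which have very different scales, are combined directly.  
- A weighted final score is computed:  
  `FinalScore = α·AnticorrScore + β·SpatialScore + γ·SizeScore`  
  where:  
  - `α = 0.5` (anticorrelation strength)  
  - `β = 0.35` (spatial concentration)  
  - `γ = 0.15` (cluster size)  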
- The DLPFC subunit with the **highest FinalScore** is selected as optimal.
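For example, with three hypothetical subunits (NetCorrelation -20, 5, -10; SpatialConcentration 3, 6, 4; sizes 40, 10, 25), the normalized scores give FinalScores of 0.65, 0.35, and about 0.49, so the first subunit is selected. A sketch of the computation (function names are illustrative; the defaults match α, β, and γ above):

```python
import numpy as np


def min_max_norm(x):
    """Scale to [0, 1]; a constant vector maps to all ones (as in Algorithm 2)."""
    x = np.asarray(x, dtype=float)
    if x.size == 0 or x.max() == x.min():
        return np.ones_like(x)
    return (x - x.min()) / (x.max() - x.min())


def final_scores(net_corrs, spatial_concs, sizes, alpha=0.5, beta=0.35, gamma=0.15):
    # Negate NetCorrelation so that stronger anticorrelation earns a higher score
    anticorr = min_max_norm(-np.asarray(net_corrs, dtype=float))
    return (alpha * anticorr
            + beta * min_max_norm(spatial_concs)
            + gamma * min_max_norm(sizes))
```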

**5. Output:**  
- Results are saved in a new sheet **`Optimal_DLPFC`** within the same Excel file from Algorithm 1 (`subXXX_SAINT_Protocol_Output_{method}.xlsx`).  
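- The output includes: cluster ID, the representative-voxel coordinates (stored in the `Centroid_X/Y/Z` columns; these are voxel indices of each subunit's representative voxel, not a geometric centroid), raw metrics, normalized scores, weights, final score, and the correlation method used. Rows are sorted by `FinalScore`, best first.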

In [None]:
# =========================================
# Algorithm 2: Optimal DLPFC Subunit Selection
# Uses full voxel coordinates from .npz file (saved by Algorithm 1)
# Supports multiple correlation methods: 'spearman', 'pearson', 'kendall', 'crosscorr'
# Processes one subject or all subjects in batch
# =========================================

import pandas as pd
import numpy as np
import os
import re
from scipy.stats import spearmanr, pearsonr, kendalltau
from scipy.spatial.distance import pdist

# -------------------------
# CONFIGURATION
# -------------------------
base_dir = "G:/SAINT/Subjects"  # 👈 must match base_dir in Algorithm 1
USE_NORMALIZATION = True
alpha = 0.5   # weight for anticorrelation strength
beta = 0.35   # weight for spatial concentration
gamma = 0.15  # weight for cluster size


# -------------------------
# Helper: compute correlation based on method
# -------------------------
def compute_correlation(ts1, ts2, method="spearman"):
    """Compute correlation between two time series with given method."""
    # Ensure both time series have at least 2 time points
    if len(ts1) < 2 or len(ts2) < 2:
        return 0.0
    
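    if method not in ("spearman", "pearson", "kendall", "crosscorr"):
        # Validate before the try block so a typo in the method name raises
        # instead of being silently turned into a correlation of 0.0
        raise ValueError(f"Unsupported correlation method: {method}")

    try:
        if method == "spearman":
            rho, _ = spearmanr(ts1, ts2)
        elif method == "pearson":
            rho, _ = pearsonr(ts1, ts2)
        elif method == "kendall":
            rho, _ = kendalltau(ts1, ts2)
        else:  # "crosscorr": zero-lag cross-correlation = Pearson on z-scored series
            def zscore(x):
                return (x - x.mean()) / (x.std() + 1e-8)
            rho = np.corrcoef(zscore(ts1), zscore(ts2))[0, 1]

        return 0.0 if np.isnan(rho) else float(rho)
    except Exception:
        # Degenerate input (e.g. constant series): treat as no correlation
        return 0.0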

# -------------------------
# Helper functions
# -------------------------
def spatial_concentration(coords):
    """Compute spatial concentration: N_voxels / mean pairwise distance."""
    if coords.shape[0] < 2:
        return 0.0
    distances = pdist(coords, metric="euclidean")
    return coords.shape[0] / np.mean(distances)

def min_max_norm(x):
    """Min-Max normalize to [0, 1]."""
    x = np.array(x, dtype=float)
    if x.size == 0 or x.max() == x.min():
        return np.ones_like(x)
    return (x - x.min()) / (x.max() - x.min())

# -------------------------
# Main processing function for ONE subject
# -------------------------
def process_subject_algo2(subject_id, base_dir="G:/SAINT/Subjects", corr_method="spearman"):
    """Run Algorithm 2 for a single subject using full voxel data from Algorithm 1."""
    print(f"\nüîç Running Algorithm 2 for sub{subject_id} (method: {corr_method})...")
    
    output_dir = os.path.join(base_dir, f"sub{subject_id}", "SAINT-PROTOCOL")
    # Dynamically build input file names based on corr_method
    output_path = os.path.join(output_dir, f"sub{subject_id}_SAINT_Protocol_Output_{corr_method}.xlsx")
    voxel_coords_path = os.path.join(output_dir, f"sub{subject_id}_cluster_voxel_coords_{corr_method}.npz")
    
    # Check if required files exist
    if not os.path.exists(output_path):
        print(f"‚ö†Ô∏è Skipping sub{subject_id}: Algorithm 1 Excel output not found for method '{corr_method}'.")
        return False
    if not os.path.exists(voxel_coords_path):
        print(f"‚ö†Ô∏è Skipping sub{subject_id}: Full voxel coordinates (.npz) not found for method '{corr_method}'.")
        return False

    # Load cluster data from Excel
    try:
        dlpfc_df = pd.read_excel(output_path, sheet_name='DLPFC_Clusters')
        sgacc_df = pd.read_excel(output_path, sheet_name='sgACC_Clusters')
    except Exception as e:
        print(f"‚ö†Ô∏è Skipping sub{subject_id}: Error reading Excel: {e}")
        return False

    if dlpfc_df.empty:
        print(f"‚ö†Ô∏è Skipping sub{subject_id}: No DLPFC clusters found.")
        return False
    if sgacc_df.empty:
        print(f"‚ö†Ô∏è Skipping sub{subject_id}: No sgACC clusters found.")
        return False

    # Load full voxel coordinates from .npz
    try:
        voxel_data = np.load(voxel_coords_path, allow_pickle=True)
        dlpfc_voxel_arrays = voxel_data['dlpfc_voxel_coords']
        sgacc_voxel_arrays = voxel_data['sgacc_voxel_coords']
        dlpfc_labels_np = voxel_data['dlpfc_cluster_labels']
        sgacc_labels_np = voxel_data['sgacc_cluster_labels']
        saved_method = str(voxel_data['correlation_method']) if 'correlation_method' in voxel_data.files else corr_method
        if saved_method != corr_method:
            print(f"⚠️ Warning: .npz file was saved with method '{saved_method}', but you're using '{corr_method}'.")
    except Exception as e:
        print(f"⚠️ Skipping sub{subject_id}: Error loading .npz file: {e}")
        return False

    # Reconstruct dlpfc_reps and sgacc_reps with FULL voxel_coords
    def df_and_voxels_to_reps(df, voxel_arrays, labels_array):
        reps = []
        label_to_voxels = dict(zip(labels_array, voxel_arrays))
        ts_cols = [col for col in df.columns if col.startswith('T')]
        for _, row in df.iterrows():
            cluster_id = int(row["ClusterID"])
            ts = row[ts_cols].to_numpy(dtype=float)
            voxel_coords = label_to_voxels.get(cluster_id)
            if voxel_coords is None:
                print(f"‚ö†Ô∏è Warning: Cluster {cluster_id} not found in .npz file.")
                continue
            reps.append({
                "cluster_label": cluster_id,
                "rep_coords": (int(row["RepVoxel_X"]), int(row["RepVoxel_Y"]), int(row["RepVoxel_Z"])),
                "timeseries": ts,
                "size": int(row["Size"]),
                "voxel_coords": voxel_coords
            })
        return reps

    dlpfc_reps = df_and_voxels_to_reps(dlpfc_df, dlpfc_voxel_arrays, dlpfc_labels_np)
    sgacc_reps = df_and_voxels_to_reps(sgacc_df, sgacc_voxel_arrays, sgacc_labels_np)

    if not dlpfc_reps:
        print(f"‚ö†Ô∏è Skipping sub{subject_id}: No matching DLPFC clusters in .npz.")
        return False

    # Step 1: Compute metrics using FULL voxel_coords
    raw_net_corrs = []
    raw_sizes = []
    raw_spatial_concs = []
    cluster_labels = []
    centroids = []

    for d in dlpfc_reps:
        net_corr = 0.0
        for s in sgacc_reps:
            # ✅ Correlate representative time series with the selected method
            rho = compute_correlation(d["timeseries"], s["timeseries"], method=corr_method)
            net_corr += rho * s["size"]

        size = d["size"]
        spatial = spatial_concentration(d["voxel_coords"])
        centroid = d["rep_coords"]

        raw_net_corrs.append(net_corr)
        raw_sizes.append(size)
        raw_spatial_concs.append(spatial)
        cluster_labels.append(d["cluster_label"])
        centroids.append(centroid)

    # Step 2: Normalize
    if USE_NORMALIZATION:
        anticorr_scores = min_max_norm([-c for c in raw_net_corrs])
        spatial_scores = min_max_norm(raw_spatial_concs)
        size_scores = min_max_norm(raw_sizes)
    else:
        anticorr_scores = np.array([-c for c in raw_net_corrs])
        spatial_scores = np.array(raw_spatial_concs)
        size_scores = np.array(raw_sizes)

    # Step 3: Final score
    final_scores = (
        alpha * anticorr_scores +
        beta * spatial_scores +
        gamma * size_scores
    )

    # Step 4: Prepare results
    results = []
    for i, label in enumerate(cluster_labels):
        results.append({
            "Subject": f"sub{subject_id}",
            "DLPFC_Cluster": label,
            "Centroid_X": float(centroids[i][0]),
            "Centroid_Y": float(centroids[i][1]),
            "Centroid_Z": float(centroids[i][2]),
            "NetCorrelation": raw_net_corrs[i],
            "ClusterSize": raw_sizes[i],
            "SpatialConcentration": raw_spatial_concs[i],
            "AnticorrScore": anticorr_scores[i],
            "SpatialScore": spatial_scores[i],
            "SizeScore": size_scores[i],
            "Alpha": alpha,
            "Beta": beta,
            "Gamma": gamma,
            "FinalScore": final_scores[i],
            "Normalized": USE_NORMALIZATION,
            "CorrelationMethod": corr_method
        })

    df_algo2 = pd.DataFrame(results)
    df_algo2 = df_algo2.sort_values("FinalScore", ascending=False).reset_index(drop=True)

    # Step 5: Update Excel file (same file as input)
    with pd.ExcelFile(output_path, engine='openpyxl') as reader:
        existing_sheets = {sheet: reader.parse(sheet) for sheet in reader.sheet_names}

    existing_sheets['Optimal_DLPFC'] = df_algo2

    with pd.ExcelWriter(output_path, engine='openpyxl', mode='w') as writer:
        for sheet_name, sheet_df in existing_sheets.items():
            sheet_df.to_excel(writer, sheet_name=sheet_name, index=False)

    # Step 6: Print results
    best = df_algo2.iloc[0]
    print(f"‚úÖ Best DLPFC Subunit: Cluster {best['DLPFC_Cluster']}")
    print(f"   Centroid: ({best['Centroid_X']:.1f}, {best['Centroid_Y']:.1f}, {best['Centroid_Z']:.1f})")
    print(f"   NetCorrelation: {best['NetCorrelation']:.3f}")
    print(f"   SpatialConcentration: {best['SpatialConcentration']:.3f}")
    print(f"   FinalScore: {best['FinalScore']:.3f}")
    print(f"\nTop 3 DLPFC subunits:")
    print(df_algo2.head(3)[["DLPFC_Cluster", "Centroid_X", "Centroid_Y", "Centroid_Z", "NetCorrelation", "SpatialConcentration", "FinalScore"]])
    print(f"\nüìä Full results saved to sheet 'Optimal_DLPFC' in:\n{output_path}")
    
    return True

# -------------------------
# Main execution logic
# -------------------------
if __name__ == "__main__":
    # ⚙️ CONFIGURATION
    
    # Choose mode:
    run_all_methods = True  # Set to False to run only one method
    corr_method = "spearman"  # Used only if run_all_methods = False

    # List of all supported correlation methods
    all_methods = ["spearman", "pearson", "kendall", "crosscorr"]

    # Decide which methods to run
    methods_to_run = all_methods if run_all_methods else [corr_method]

    # Single subject or batch?
    subject_id = None  # Set to e.g., "001" for single subject; None for batch

    if subject_id is not None:
        # Single subject: run all selected methods
        for method in methods_to_run:
            print(f"\nüöÄ Running Algorithm 2 for sub{subject_id} with method: {method}")
            success = process_subject_algo2(subject_id, base_dir, corr_method=method)
            if success:
                print(f"\nüéâ Algorithm 2 completed successfully for sub{subject_id} ({method})!")
            else:
                print(f"\n‚ùå Algorithm 2 failed for sub{subject_id} with {method}.")
    else:
        # Batch mode: for each subject, run all selected methods
        subject_folders = [f for f in os.listdir(base_dir) if re.match(r'^sub\d+$', f)]
        subject_ids = sorted([f.replace('sub', '') for f in subject_folders], key=int)
        total_subjects = len(subject_ids)

        print(f"Found {total_subjects} subjects. Starting Algorithm 2 batch processing...\n")
        
        for sid in subject_ids:
            print(f"\n{'='*60}")
            print(f"Processing subject: sub{sid}")
            print(f"{'='*60}")
            
            completed_methods = 0
            for method in methods_to_run:
                try:
                    print(f"\n→ Running method: {method}")
                    success = process_subject_algo2(sid, base_dir, corr_method=method)
                    if success:
                        completed_methods += 1
                        print(f"✅ Completed {method} for sub{sid}")
                    else:
                        print(f"⚠️ Skipped {method} for sub{sid} (missing files)")
                except Exception as e:
                    print(f"❌ Error in {method} for sub{sid}: {e}")

            print(f"\n📌 Summary for sub{sid}: {completed_methods}/{len(methods_to_run)} methods completed.")

        print(f"\n🎉 Algorithm 2 batch processing complete for {total_subjects} subjects!")
        print(f"Methods run: {', '.join(methods_to_run)}")
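
The batch loop above discovers subjects by matching folder names against `^sub\d+$` and sorting the numeric part. A minimal sketch of that logic as a reusable helper follows; `find_subject_ids` is a hypothetical name, and the extra `os.path.isdir` check (which skips plain files whose names happen to match the pattern) is an assumption added here, not part of the original script.

```python
import os
import re


def find_subject_ids(base_dir):
    """Return subject IDs (digits after 'sub') for folders named sub<digits>.

    Hypothetical helper: mirrors the inline batch-mode discovery, but also
    skips regular files whose names match the pattern.
    """
    pattern = re.compile(r"^sub(\d+)$")
    subject_ids = []
    for name in os.listdir(base_dir):
        match = pattern.match(name)
        if match and os.path.isdir(os.path.join(base_dir, name)):
            subject_ids.append(match.group(1))  # keeps zero-padding, e.g. "005"
    # Numeric sort so that sub10 comes after sub9
    return sorted(subject_ids, key=int)


# Example usage (base_dir as set in Step 1):
# subject_ids = find_subject_ids(base_dir)
```

Zero-padding is preserved (`sub005` yields `"005"`), so `f"sub{sid}"` still rebuilds the original folder name.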

### 🖼️ Step 3: 3D Visualization of the Top DLPFC Cluster

This step renders the optimal DLPFC subunit selected by Algorithm 2 in 3D, as surface meshes extracted from binary masks with `skimage.measure.marching_cubes`, shown in context with the whole brain and the full DLPFC and sgACC masks.

#### ⚠️ Critical Dependency:
> **This visualization CANNOT run independently.**  
> It **requires successful execution of both Algorithm 1 and Algorithm 2** for the same subject(s) and the **same correlation method** (one of `spearman`, `pearson`, `kendall`, or `crosscorr`).

Specifically, it expects the following files in each subject's `SAINT-PROTOCOL` folder:
- `subXXX_SAINT_Protocol_Output_{method}.xlsx` → **must contain the `Optimal_DLPFC` sheet** (generated by Algorithm 2)
- `subXXX_cluster_voxel_coords_{method}.npz` → **full voxel coordinates** (generated by Algorithm 1)

It also reads `filtered_func_data.nii.gz` and the `l_DLPFC` and `sgACC` masks from the subject folder. Because both output file names include the method, a file produced with a different correlation method counts as missing. If any required input is missing or unreadable, the visualization is **skipped** for that subject/method.
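
Before a long batch run, it can help to list what is missing per subject up front rather than one skip message at a time. The sketch below is a hypothetical pre-flight check (`missing_step3_inputs` is not part of the pipeline). It rebuilds the same file names that `visualize_and_save_top_dlpfc` checks, plus the ROI masks, and tests existence only, so it does not confirm that the Excel file actually contains the `Optimal_DLPFC` sheet.

```python
import os


def missing_step3_inputs(base_dir, subject_id, corr_method):
    """Return a list describing Step 3 inputs that are missing.

    Hypothetical pre-flight check. Paths mirror those built in
    visualize_and_save_top_dlpfc(); only file existence is tested.
    """
    subject_dir = os.path.join(base_dir, f"sub{subject_id}")
    output_dir = os.path.join(subject_dir, "SAINT-PROTOCOL")
    required = [
        os.path.join(output_dir, f"sub{subject_id}_SAINT_Protocol_Output_{corr_method}.xlsx"),
        os.path.join(output_dir, f"sub{subject_id}_cluster_voxel_coords_{corr_method}.npz"),
        os.path.join(subject_dir, "filtered_func_data.nii.gz"),
    ]
    missing = [path for path in required if not os.path.exists(path)]
    # Each ROI needs at least one of its _bin / _func mask variants
    for roi in ("l_DLPFC", "sgACC"):
        variants = [os.path.join(subject_dir, f"{roi}{suffix}.nii.gz") for suffix in ("_bin", "_func")]
        if not any(os.path.exists(path) for path in variants):
            missing.append(f"{roi}_bin.nii.gz or {roi}_func.nii.gz")
    return missing


# Example usage (hypothetical subject ID):
# for method in ["spearman", "pearson", "kendall", "crosscorr"]:
#     print(method, missing_step3_inputs(base_dir, "001", method))
```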

#### 🖌️ Visualization Details:
- **Surface meshes** (not point clouds), extracted from each binary mask at the 0.5 iso-level
- **Color scheme**:
  - Light grey: Whole-brain background (nonzero fMRI voxels, opacity 0.08)
  - Gold: Full left DLPFC and sgACC masks (opacity 0.5)
  - Red: The **top-ranked DLPFC subunit** (opacity 0.9, highlighted)
- **Fixed 3D camera angle** intended for a frontal view; axes, grid, and tick labels are hidden
- **Annotation-based legend** (no extra markers in 3D space)
- **Fallbacks if meshing fails**: the top subunit is drawn as a point cloud, the whole-brain background as a Plotly isosurface, and a failed DLPFC or sgACC mesh is omitted
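
Meshing starts from a boolean volume built from the top subunit's voxel coordinates. A self-contained sketch of that rasterization step, with bounds checking, is shown below. It assumes the `.npz` coordinates are integer voxel indices in the same space as `filtered_func_data.nii.gz`; `coords_to_volume` is a hypothetical helper name.

```python
import numpy as np


def coords_to_volume(coords, shape):
    """Rasterize an (N, 3) array of voxel indices into a boolean volume.

    Sketch only: assumes integer-valued indices in the same voxel space as
    the fMRI image. Rows that fall outside `shape` are dropped.
    """
    coords = np.asarray(coords).reshape(-1, 3).astype(int)
    inside = np.all((coords >= 0) & (coords < np.asarray(shape)), axis=1)
    coords = coords[inside]
    volume = np.zeros(shape, dtype=bool)
    # Fancy indexing sets all listed voxels in a single assignment
    volume[coords[:, 0], coords[:, 1], coords[:, 2]] = True
    return volume
```

One caveat when meshing the result: `marching_cubes` leaves the surface open wherever a mask touches the edge of the array. Padding with `np.pad(volume, 1)` before meshing closes it, at the cost of shifting the returned vertices by one voxel (subtract 1 to restore the original coordinates).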

#### 💾 Output:
- Saves a PNG image at 2× scale (2200 × 1900 px) to the subject's `SAINT-PROTOCOL` folder (requires `kaleido`):  
  `subXXX_top_dlpfc_cluster_{method}.png`
- Also displays an **interactive 3D figure** in the notebook/environment

> 📌 **Note**: Always verify that Algorithms 1 and 2 completed successfully (with the same correlation method) before running this step.
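
One way to make that check explicit is to validate the `Optimal_DLPFC` sheet before using it. The script takes `df_algo2.iloc[0]`, which relies on Algorithm 2 having sorted the sheet by `FinalScore`; the hypothetical helper below selects the maximum `FinalScore` directly and raises a clear error when columns are missing or no row is scored. It is a sketch that assumes the column names written by Algorithm 2.

```python
import pandas as pd


def pick_top_cluster(df_algo2):
    """Return the DLPFC_Cluster label with the highest FinalScore.

    Hypothetical helper: unlike df_algo2.iloc[0], it does not depend on the
    sheet being pre-sorted, and it fails with a clear message on bad input.
    """
    required = {"DLPFC_Cluster", "FinalScore"}
    missing = required - set(df_algo2.columns)
    if missing:
        raise ValueError(f"Optimal_DLPFC sheet is missing columns: {sorted(missing)}")
    if df_algo2["FinalScore"].notna().sum() == 0:
        raise ValueError("Optimal_DLPFC sheet has no scored rows")
    # idxmax skips NaN scores and returns the index label of the best row
    best_row = df_algo2.loc[df_algo2["FinalScore"].idxmax()]
    return int(best_row["DLPFC_Cluster"])


# Example usage (same sheet Step 3 reads):
# df_algo2 = pd.read_excel(excel_path, sheet_name="Optimal_DLPFC")
# top_cluster_label = pick_top_cluster(df_algo2)
```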

# =========================================
# 3D Visualization: Top DLPFC Cluster from Algorithm 2
# Uses Mesh3d + marching_cubes with annotation legend
# Supports multiple correlation methods: 'spearman', 'pearson', 'kendall', 'crosscorr'
# Saves output as PNG in SAINT-PROTOCOL folder
# =========================================

import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
import nibabel as nib
import os
import re
from skimage import measure  # For marching_cubes


def find_mask_path(subject_dir, base_name):
    """Return the ROI mask path, preferring *_bin.nii.gz over *_func.nii.gz."""
    bin_path = os.path.join(subject_dir, f"{base_name}_bin.nii.gz")
    func_path = os.path.join(subject_dir, f"{base_name}_func.nii.gz")
    if os.path.exists(bin_path):
        return bin_path
    if os.path.exists(func_path):
        return func_path
    return None

def visualize_and_save_top_dlpfc(subject_id, base_dir="G:/SAINT/Subjects", corr_method="spearman"):
    """Visualize and save top DLPFC cluster using Mesh3d with annotation legend."""
    subject_dir = os.path.join(base_dir, f"sub{subject_id}")
    output_dir = os.path.join(subject_dir, "SAINT-PROTOCOL")
    
    # Paths (dynamically based on corr_method)
    excel_path = os.path.join(output_dir, f"sub{subject_id}_SAINT_Protocol_Output_{corr_method}.xlsx")
    npz_path = os.path.join(output_dir, f"sub{subject_id}_cluster_voxel_coords_{corr_method}.npz")
    fmri_path = os.path.join(subject_dir, "filtered_func_data.nii.gz")
    
    # Validate files
    if not os.path.exists(excel_path):
        print(f"⚠️ Skipping sub{subject_id}: Excel output not found for method '{corr_method}'.")
        return False
    if not os.path.exists(npz_path):
        print(f"⚠️ Skipping sub{subject_id}: .npz voxel data not found for method '{corr_method}'.")
        return False
    if not os.path.exists(fmri_path):
        print(f"⚠️ Skipping sub{subject_id}: fMRI data not found.")
        return False
    
    # Load top cluster label from Excel
    try:
        df_algo2 = pd.read_excel(excel_path, sheet_name='Optimal_DLPFC')
        top_cluster_label = int(df_algo2.iloc[0]['DLPFC_Cluster'])
    except Exception as e:
        print(f"⚠️ Skipping sub{subject_id}: Error reading Excel: {e}")
        return False

    print(f"✅ Loading top DLPFC cluster {top_cluster_label} for sub{subject_id} (method: {corr_method})")

    # Load full voxel coordinates from .npz
    try:
        voxel_data = np.load(npz_path, allow_pickle=True)
        dlpfc_voxel_arrays = voxel_data['dlpfc_voxel_coords']
        dlpfc_labels = voxel_data['dlpfc_cluster_labels']
        idx = np.where(dlpfc_labels == top_cluster_label)[0][0]
        top_cluster_voxels = dlpfc_voxel_arrays[idx]
    except Exception as e:
        print(f"⚠️ Skipping sub{subject_id}: Error loading .npz: {e}")
        return False

    # Load masks and data
    dlpfc_mask_path = find_mask_path(subject_dir, "l_DLPFC")
    sgacc_mask_path = find_mask_path(subject_dir, "sgACC")
    if dlpfc_mask_path is None or sgacc_mask_path is None:
        print(f"⚠️ Skipping sub{subject_id}: Mask files not found.")
        return False

    fmri_img = nib.load(fmri_path)
    fmri_data = fmri_img.get_fdata()
    brain_mask = np.any(fmri_data != 0, axis=3)
    dlpfc_mask = nib.load(dlpfc_mask_path).get_fdata().astype(bool)
    sgacc_mask = nib.load(sgacc_mask_path).get_fdata().astype(bool)

    # -----------------------------
    # Build binary volumes for meshing
    # -----------------------------
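    shape = brain_mask.shape
    # Cast to an integer (N, 3) array so the coordinates can index the volume
    top_cluster_voxels = np.asarray(top_cluster_voxels, dtype=int).reshape(-1, 3)
    top_cluster_vol = np.zeros(shape, dtype=bool)
    top_cluster_vol[top_cluster_voxels[:, 0], top_cluster_voxels[:, 1], top_cluster_voxels[:, 2]] = True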

    # -----------------------------
    # Create Mesh3d traces (for actual visualization)
    # -----------------------------
    traces = []

    # 1. Whole brain mesh (background)
    try:
        verts, faces, _, _ = measure.marching_cubes(brain_mask.astype(float), level=0.5)
        brain_mesh = go.Mesh3d(
            x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
            i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
            color='lightgrey', opacity=0.08, name='Whole Brain', showlegend=False
        )
        traces.append(brain_mesh)
    except (RuntimeError, ValueError):
        # marching_cubes raises ValueError if 0.5 is outside the data range and
        # RuntimeError if no surface is found. Isosurface expects flattened
        # coordinates with one value per grid point.
        gx, gy, gz = np.mgrid[0:shape[0], 0:shape[1], 0:shape[2]]
        brain_iso = go.Isosurface(
            x=gx.ravel(), y=gy.ravel(), z=gz.ravel(),
            value=brain_mask.astype(float).ravel(),
            isomin=0.5, isomax=1.0,
            opacity=0.08, colorscale=[[0, 'lightgrey'], [1, 'lightgrey']],
            showscale=False, showlegend=False, name='Whole Brain'
        )
        traces.append(brain_iso)

    # 2. Full DLPFC mask mesh (gold)
    if np.any(dlpfc_mask):
        try:
            verts, faces, _, _ = measure.marching_cubes(dlpfc_mask.astype(float), level=0.5)
            dlpfc_mesh = go.Mesh3d(
                x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
                i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
                color='gold', opacity=0.5, name='DLPFC (full mask)', showlegend=False
            )
            traces.append(dlpfc_mesh)
        except (RuntimeError, ValueError):
            pass  # Omit the DLPFC surface if no isosurface can be extracted

    # 3. Full sgACC mask mesh (gold)
    if np.any(sgacc_mask):
        try:
            verts, faces, _, _ = measure.marching_cubes(sgacc_mask.astype(float), level=0.5)
            sgacc_mesh = go.Mesh3d(
                x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
                i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
                color='gold', opacity=0.5, name='sgACC (full mask)', showlegend=False
            )
            traces.append(sgacc_mesh)
        except (RuntimeError, ValueError):
            pass  # Omit the sgACC surface if no isosurface can be extracted

    # 4. Top DLPFC cluster ‚Äî red (highlighted)
    if np.any(top_cluster_vol):
        try:
            verts, faces, _, _ = measure.marching_cubes(top_cluster_vol.astype(float), level=0.5)
            top_mesh = go.Mesh3d(
                x=verts[:,0], y=verts[:,1], z=verts[:,2],
                i=faces[:,0], j=faces[:,1], k=faces[:,2],
                color='#d7191c', opacity=0.9, name=f'DLPFC - Top Cluster {top_cluster_label}', showlegend=False
            )
            traces.append(top_mesh)
        except (RuntimeError, ValueError):
            # Fallback to scatter only for top cluster if meshing fails
            top_scatter = go.Scatter3d(
                x=top_cluster_voxels[:,0], y=top_cluster_voxels[:,1], z=top_cluster_voxels[:,2],
                mode='markers',
                marker=dict(size=3, color='#d7191c', opacity=0.9),
                name=f'DLPFC - Top Cluster {top_cluster_label} (fallback)',
                showlegend=False
            )
            traces.append(top_scatter)

    # -----------------------------
    # Final figure: no axes, no grid, no labels
    # -----------------------------
    fig = go.Figure(data=traces)
    fig.update_layout(
        title={
            'text': f"Top DLPFC Cluster from Algorithm 2<br><sub>Method: {corr_method} | Cluster {top_cluster_label} | Subject sub{subject_id}</sub>",
            'x': 0.02,
            'xanchor': 'left',
            'font': {'size': 16}
        },
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            aspectmode='data',
            camera=dict(
                eye=dict(x=-0.5, y=1.5, z=1),
                up=dict(x=0, y=1, z=1),
                center=dict(x=0, y=0, z=0)
            )
        ),
        width=1100, height=950,
        showlegend=False,
        margin=dict(l=0, r=0, t=50, b=0)
    )

    # -----------------------------
    # ADD CUSTOM LEGEND AS ANNOTATION (NO SCATTER3D!)
    # -----------------------------
    fig.add_annotation(
        x=0.02, y=0.98,
        text="<b>Legend:</b><br>"
             "<span style='color:lightgrey'>●</span> Brain (background)<br>"
             "<span style='color:gold'>●</span> DLPFC (full mask)<br>"
             "<span style='color:gold'>●</span> sgACC (full mask)<br>"
             "<span style='color:#d7191c'>●</span> DLPFC - Top Cluster {} ({} voxels)".format(top_cluster_label, len(top_cluster_voxels)),
        showarrow=False,
        xref="paper", yref="paper",
        align="left",
        bgcolor="rgba(255,255,255,0.8)",
        bordercolor="black",
        borderwidth=1,
        font=dict(size=14)
    )

    # -------------------------
    # SAVE IMAGE with method name
    # -------------------------
    image_base = os.path.join(output_dir, f"sub{subject_id}_top_dlpfc_cluster_{corr_method}")
    
    try:
        fig.write_image(f"{image_base}.png", scale=2)
        print(f"üíæ Image saved to: {image_base}.png")
    except Exception as e:
        print(f"‚ö†Ô∏è Warning: Could not save image (install 'kaleido'): {e}")

    # Also show interactively
    fig.show()
    return True

# -------------------------
# Main execution logic
# -------------------------
if __name__ == "__main__":
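    # ⚙️ CONFIGURATION
    # Reuse base_dir from Step 1 if it exists in this session; otherwise use the default path
    base_dir = globals().get("base_dir", "G:/SAINT/Subjects")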
    # Choose mode:
    run_all_methods = True  # Set to False to run only one method
    corr_method = "spearman"  # Used only if run_all_methods = False

    # List of all supported correlation methods
    all_methods = ["spearman", "pearson", "kendall", "crosscorr"]

    # Decide which methods to run
    methods_to_run = all_methods if run_all_methods else [corr_method]

    # Single subject or batch?
    subject_id = None  # Set to e.g., "001" for single subject; None for batch

    if subject_id is not None:
        for method in methods_to_run:
            print(f"\nüñºÔ∏è  Visualizing top DLPFC for sub{subject_id} with method: {method}")
            success = visualize_and_save_top_dlpfc(subject_id, base_dir, corr_method=method)
            if success:
                print(f"‚úÖ Visualization completed for {method}!")
            else:
                print(f"‚ùå Failed for {method}.")
    else:
        subject_folders = [f for f in os.listdir(base_dir) if re.match(r'^sub\d+$', f)]
        subject_ids = sorted([f.replace('sub', '') for f in subject_folders], key=int)
        total_subjects = len(subject_ids)

        print(f"Found {total_subjects} subjects. Starting batch visualization...\n")
        
        for sid in subject_ids:
            print(f"\n{'='*60}")
            print(f"Visualizing subject: sub{sid}")
            print(f"{'='*60}")
            
            completed_methods = 0
            for method in methods_to_run:
                try:
                    print(f"\n→ Running method: {method}")
                    success = visualize_and_save_top_dlpfc(sid, base_dir, corr_method=method)
                    if success:
                        completed_methods += 1
                        print(f"✅ Completed {method} for sub{sid}")
                    else:
                        print(f"⚠️ Skipped {method} for sub{sid}")
                except Exception as e:
                    print(f"❌ Error in {method} for sub{sid}: {e}")

            print(f"\n📌 Summary for sub{sid}: {completed_methods}/{len(methods_to_run)} methods visualized.")

        print(f"\n🎉 Batch visualization complete for {total_subjects} subjects!")
        print(f"Methods visualized: {', '.join(methods_to_run)}")
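
The `.npz` lookup in `visualize_and_save_top_dlpfc` uses `np.where(dlpfc_labels == top_cluster_label)[0][0]`; when the label from the Excel sheet is absent from the `.npz` (for instance, if the two files come from different runs), the resulting `IndexError` is caught and printed as a generic `.npz` loading error. A sketch of a lookup that returns `None` in that case follows. It assumes the key names used above (`dlpfc_voxel_coords`, `dlpfc_cluster_labels`) and an object array holding one (N, 3) coordinate array per cluster; `load_cluster_voxels` is a hypothetical name.

```python
import numpy as np


def load_cluster_voxels(npz_path, cluster_label):
    """Return the (N, 3) voxel indices stored for one DLPFC cluster, or None.

    Sketch assuming the Algorithm 1 .npz layout read in Step 3: an object
    array 'dlpfc_voxel_coords' aligned element-wise with 'dlpfc_cluster_labels'.
    """
    with np.load(npz_path, allow_pickle=True) as data:
        labels = data["dlpfc_cluster_labels"]
        matches = np.flatnonzero(labels == cluster_label)
        if matches.size == 0:
            return None  # label not present in this .npz
        coords = data["dlpfc_voxel_coords"][matches[0]]
    return np.asarray(coords, dtype=int).reshape(-1, 3)
```

A caller can then print a specific message (label not found) instead of an index error, and skip the subject/method as the other checks do.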