# In [None]:
# =============================================================================
# 1. Imports and Global Settings
# =============================================================================
import warnings
import logging
import datetime
import time
import pickle
import numpy as np
import os
import re
import sys
import random
import itertools
import pandas as pd
from scipy import stats
import networkx as nx
from scipy.sparse.csgraph import connected_components
from scipy.sparse import csr_matrix
from dataclasses import dataclass
import concurrent.futures
from numba import jit, prange
from typing import Dict, List, Tuple, Optional
from functools import partial
from sklearn.model_selection import KFold
from IPython.display import clear_output

# =============================================================================
# 2. Data Classes
# =============================================================================
@dataclass
class ConvergenceStats:
    """Sliding-window convergence monitor for NBS permutations.

    Every call to :meth:`update` records one component size. Once at least
    ``min_iterations`` sizes have been recorded, the mean and variance of the
    most recent ``window_size`` sizes are logged, and the monitor is flagged
    as converged when that variance is below ``alpha``.
    """
    window_size: int = 50
    alpha: float = 0.05
    min_iterations: int = 100

    def __post_init__(self):
        # History containers are plain instance attributes rather than
        # dataclass fields, so they are not part of __init__/__eq__/__repr__.
        self.component_sizes = []
        self.running_mean = []
        self.running_var = []
        self.converged = False

    def update(self, size: float) -> bool:
        """Record ``size`` and report whether convergence was detected.

        Returns:
            True when the variance of the current window is below ``alpha``
            (``converged`` is then set as well); False otherwise, including
            while fewer than ``min_iterations`` sizes have been recorded.
        """
        self.component_sizes.append(size)

        # Too little history yet: nothing is logged and nothing is tested.
        if len(self.component_sizes) < self.min_iterations:
            return False

        window = self.component_sizes[-self.window_size:]
        window_mean = np.mean(window)
        window_var = np.var(window)
        self.running_mean.append(window_mean)
        self.running_var.append(window_var)

        if not (window_var < self.alpha):
            return False

        self.converged = True
        return True

# =============================================================================
# 3. Helper Functions (Statistical, Parsing, and Edge Utilities)
# =============================================================================
def calculate_effect_sizes(component: set, data1: np.ndarray, data2: np.ndarray) -> dict:
    """
    Calculate multiple effect size metrics for NBS components.

    For every entry of ``component`` (used as a column index into both data
    arrays) two metrics are computed:

    * ``cohens_d``: difference of the column means divided by
      ``sqrt((var1 + var2) / 2)``, using NumPy's default (population)
      variance. When that pooled SD is zero or NaN the value is reported as
      ``0.0`` instead of dividing by zero, matching how ``compute_test_stats``
      treats degenerate columns.
    * ``rank_biserial``: ``2 * U / (n1 * n2) - 1`` where ``U`` is the
      statistic returned by ``scipy.stats.mannwhitneyu``.

    Args:
        component: Set of nodes (column indices) in the component.
        data1, data2: Original connectivity data for both groups; rows are
            observations, columns are indexed by node.

    Returns:
        Dictionary mapping each node to
        ``{'cohens_d': ..., 'rank_biserial': ...}``.
    """
    effect_sizes = {}
    # Group sizes do not depend on the node, so compute them once.
    n1, n2 = len(data1), len(data2)

    for node in component:
        values1 = data1[:, node]
        values2 = data2[:, node]

        # Cohen's d, guarded against a zero/NaN pooled standard deviation
        # (constant columns would otherwise yield inf/nan and a RuntimeWarning).
        pooled_sd = np.sqrt((np.var(values1) + np.var(values2)) / 2)
        if pooled_sd > 0:
            d = (np.mean(values1) - np.mean(values2)) / pooled_sd
        else:
            d = 0.0

        # Rank-biserial correlation derived from the Mann-Whitney U statistic.
        u_stat, _ = stats.mannwhitneyu(values1, values2)
        rank_biserial = 2 * (u_stat / (n1 * n2)) - 1

        effect_sizes[node] = {
            'cohens_d': d,
            'rank_biserial': rank_biserial
        }

    return effect_sizes


def validate_nbs_results(results: dict, data1: np.ndarray, data2: np.ndarray, alpha: float = 0.05) -> dict:
    """
    Validate NBS results using additional statistical tests and account for dynamic nature of data.

    For every significant component, 5 shuffled folds are drawn for each
    group. On each pair of held-out folds a two-sided Mann-Whitney U test is
    run per node, and the fold score is the fraction of nodes with
    ``p < alpha``.

    The two groups are split independently: indices produced for ``data1``
    are never used to index ``data2``, so groups of different sizes are
    handled correctly.

    Args:
        results: NBS analysis results; must contain 'significant_components'
            and may contain 'time_windows'.
        data1, data2: Original connectivity data (rows are observations,
            columns are indexed by node). Each needs at least 5 rows for the
            5-fold split.
        alpha: Significance level.

    Returns:
        Dictionary keyed by component index with
        'cross_validation_score' (mean fold score), 'stability' (standard
        deviation of the fold scores) and 'dynamic_size' (component size
        times the number of unique time windows, 1 if none are given).
    """
    validation = {}
    # Loop-invariant: number of distinct time windows recorded in the results.
    n_time_windows = len(np.unique(results.get('time_windows', [1])))

    for idx, component in enumerate(results['significant_components']):
        cv_scores = []
        # No random_state is set, so the folds differ between calls.
        kf = KFold(n_splits=5, shuffle=True)

        # Pair the k-th held-out fold of group 1 with the k-th of group 2.
        fold_pairs = zip(kf.split(data1), kf.split(data2))
        for (_, test_idx1), (_, test_idx2) in fold_pairs:
            test1 = data1[test_idx1]
            test2 = data2[test_idx2]

            # Component stability: which nodes still differ on held-out data.
            component_hits = []
            for node in component:
                _, p = stats.mannwhitneyu(
                    test1[:, node],
                    test2[:, node],
                    alternative='two-sided'
                )
                component_hits.append(p < alpha)

            cv_scores.append(np.mean(component_hits))

        validation[idx] = {
            'cross_validation_score': np.mean(cv_scores),
            'stability': np.std(cv_scores),
            'dynamic_size': len(component) * n_time_windows
        }

    return validation


def identify_compensatory_connections(
    aging_results: dict, 
    taichi_results: dict,
    yac_means: dict,  # For consistency, parameter name remains yac_means
    comparison: str,
    labels: np.ndarray,
    edge_indices_reverse: dict,
    group1_means_all: dict,
    group2_means_all: dict
) -> dict:
    """
    Identify compensatory and deterioration connections specifically for TCOA vs OAC comparison.
    YAC means are used as reference points.
    Only processes mechanism categories for TCOA comparison, not for OAC vs YAC.

    For every key present in both NBS result dictionaries, each aging
    component is compared with each tai-chi component. When the two
    components overlap, every node pair listed in the aging result's
    'component_edges' for that component is classified with
    ``determine_mechanism`` using the per-edge OAC, TCOA and YAC means.

    Args:
        aging_results: NBS results per key; each value provides
            'significant_components' (sets) and 'component_edges'.
        taichi_results: NBS results per key for the intervention comparison.
        yac_means: Per-key collection of YAC means indexed by edge index.
            Keys whose entry is missing or empty are skipped.
        comparison: Comparison identifier; anything other than
            'oac_vs_tcoa' returns the empty result immediately.
        labels: Region labels indexed by node, passed to ``get_network_name``.
        edge_indices_reverse: Mapping from node pairs to edge indices.
        group1_means_all: Per-key OAC means indexed by edge index.
        group2_means_all: Per-key TCOA means indexed by edge index.

    Returns:
        Dictionary with 'compensatory_components' (list of per-edge records),
        'effect_sizes' (edge index -> OAC/TCOA means) and 'mechanisms'
        (edge index -> mechanism label).
    """
    results = {
        'compensatory_components': [],
        'effect_sizes': {},
        'mechanisms': {}
    }

    if comparison != 'oac_vs_tcoa':
        return results

    for key, aging_nbs_result in aging_results.items():
        taichi_nbs_result = taichi_results.get(key)
        if not taichi_nbs_result:
            continue

        yac_mean = yac_means.get(key)
        # Explicit None/empty check: `not yac_mean` raises ValueError when the
        # means are stored as a NumPy array with more than one element.
        if yac_mean is None or len(yac_mean) == 0:
            continue

        for aging_comp_idx, aging_component in enumerate(aging_nbs_result['significant_components']):
            for tc_component in taichi_nbs_result['significant_components']:
                if not aging_component.intersection(tc_component):
                    continue

                # The component's node pairs are processed once per
                # overlapping component pair. Iterating over every shared
                # element would only append identical records repeatedly.
                for node1, node2 in aging_nbs_result['component_edges'][aging_comp_idx]:
                    network1 = get_network_name(labels[node1])
                    network2 = get_network_name(labels[node2])
                    if not network1 or not network2:
                        continue

                    node_networks = sorted([network1, network2])

                    edge_idx = edge_indices_reverse.get((node1, node2))
                    if edge_idx is None:
                        edge_idx = edge_indices_reverse.get((node2, node1))
                    if edge_idx is None:
                        continue

                    oac_mean = group1_means_all[key][edge_idx]
                    tcoa_mean = group2_means_all[key][edge_idx]
                    yac_mean_edge = yac_mean[edge_idx]

                    # Distance of each older group from the YAC reference.
                    oac_distance = abs(oac_mean - yac_mean_edge)
                    tcoa_distance = abs(tcoa_mean - yac_mean_edge)

                    mechanism = determine_mechanism(
                        oac_distance, tcoa_distance,
                        oac_mean, tcoa_mean, yac_mean_edge
                    )

                    results['compensatory_components'].append({
                        'edge': edge_idx,
                        'node1': node1,
                        'node2': node2,
                        'aging_component': aging_comp_idx,
                        'mechanism': mechanism,
                        'key': key,
                        'network_pair': '-'.join(node_networks)
                    })

                    results['effect_sizes'][edge_idx] = {
                        'aging': oac_mean,
                        'intervention': tcoa_mean
                    }
                    results['mechanisms'][edge_idx] = mechanism

    return results


def determine_mechanism(
    oac_distance: float, 
    tcoa_distance: float, 
    oac_mean: float,
    tcoa_mean: float,
    yac_mean: float
) -> str:
    """
    Classify an edge by how TCOA sits relative to OAC and the YAC reference.

    Each ordering of the three means has two labels: the first is used when
    TCOA is strictly closer to YAC than OAC is
    (``tcoa_distance < oac_distance``), the second otherwise.

    * TCOA above YAC above OAC: 'Enhancement' / 'Exacerbation'
    * TCOA below YAC below OAC: 'Normalization' / 'Decompensation'
    * otherwise, both strictly on the same side of YAC:
      'Restoration' / 'Deterioration'
    * otherwise: 'Alternative' / 'Maladaptive'

    Args:
        oac_distance: Distance between OAC and YAC means.
        tcoa_distance: Distance between TCOA and YAC means.
        oac_mean: Mean connectivity value for OAC group.
        tcoa_mean: Mean connectivity value for TCOA group.
        yac_mean: Mean connectivity value for YAC group.

    Returns:
        str: The identified mechanism type.
    """
    closer_to_reference = tcoa_distance < oac_distance

    # YAC lies strictly between the groups, with TCOA on the high side.
    if tcoa_mean > yac_mean > oac_mean:
        return 'Enhancement' if closer_to_reference else 'Exacerbation'

    # YAC lies strictly between the groups, with TCOA on the low side.
    if tcoa_mean < yac_mean < oac_mean:
        return 'Normalization' if closer_to_reference else 'Decompensation'

    # A positive product means both groups deviate from YAC in the same direction.
    deviations_aligned = (tcoa_mean - yac_mean) * (oac_mean - yac_mean) > 0
    if deviations_aligned:
        return 'Restoration' if closer_to_reference else 'Deterioration'
    return 'Alternative' if closer_to_reference else 'Maladaptive'


def get_edge_indices(n_nodes):
    """Map consecutive edge indices to node pairs ``(i, j)`` with ``i < j``.

    Pairs are numbered row by row over the upper triangle:
    (0, 1), (0, 2), ..., (1, 2), ...
    """
    upper_triangle_pairs = itertools.combinations(range(n_nodes), 2)
    return dict(enumerate(upper_triangle_pairs))


def get_edge_indices_reverse(n_nodes):
    """Map node pairs to edge indices, accepting either node order.

    Edges are numbered row by row over the upper triangle, and both
    ``(i, j)`` and ``(j, i)`` resolve to the same index.
    """
    pair_to_index = {}
    for edge_number, (low, high) in enumerate(itertools.combinations(range(n_nodes), 2)):
        pair_to_index[(low, high)] = edge_number
        pair_to_index[(high, low)] = edge_number  # same edge, reversed lookup
    return pair_to_index


def get_network_name(label_name):
    """Return the network a label belongs to, or None if none is recognised.

    The label is searched for known substrings in a fixed priority order;
    the first one found decides the result.
    """
    markers_in_priority_order = (
        ('Vis', 'Visual'),
        ('SomMot', 'Somatomotor'),
        ('DorsAttn', 'DorsalAttention'),
        ('SalVentAttn', 'VentralAttention'),
        ('VentAttn', 'VentralAttention'),
        ('Limbic', 'Limbic'),
        ('Cont', 'Frontoparietal'),
        ('Default', 'Default'),
    )
    for marker, network in markers_in_priority_order:
        if marker in label_name:
            return network
    return None


def get_region_name(label_name):
    """Build ``<region>-<suffix>`` from a label of the form ``<prefix>-<suffix>``.

    The part before the first '-' is split on '_'. For Visual and
    Somatomotor labels the first two tokens are dropped; for every other
    label the first three are dropped. The remaining tokens are re-joined
    with '_' and the text after the first '-' is appended.
    """
    pieces = label_name.split('-')
    leading_tokens = 2 if get_network_name(label_name) in ('Visual', 'Somatomotor') else 3
    region_tokens = pieces[0].split('_')[leading_tokens:]
    suffix = pieces[1]
    return '-'.join(('_'.join(region_tokens), suffix))


def parse_within_conn_value(value):
    """Parse ``"[network]: region1 - connectivity - region2"``.

    Returns:
        Tuple ``(network, region1, connectivity, region2)`` with the
        connectivity converted to float.

    Raises:
        ValueError: If the string does not match the expected layout, or the
            connectivity field is not a valid float.
    """
    parsed = re.match(r'\[(.*?)\]: (.*?) - (.*?) - (.*)', value)
    if parsed is None:
        raise ValueError(f"Unable to parse within network connectivity value: {value}")

    network, region1, raw_connectivity, region2 = parsed.groups()
    return network, region1, float(raw_connectivity), region2


def parse_between_conn_value(value):
    """Parse ``"[net1, net2]: region1 - connectivity - region2"``.

    Returns:
        Tuple ``([network_a, network_b], region1, connectivity, region2)``
        where the two network names are sorted and the connectivity is a
        float. The regions keep their original order.

    Raises:
        ValueError: If the string does not match the expected layout, or the
            connectivity field is not a valid float.
    """
    parsed = re.match(r'\[(.*?), (.*?)\]: (.*?) - (.*?) - (.*)', value)
    if parsed is None:
        raise ValueError(f"Unable to parse between network connectivity value: {value}")

    first_net, second_net, region1, raw_connectivity, region2 = parsed.groups()
    network_pair = sorted([first_net, second_net])
    return network_pair, region1, float(raw_connectivity), region2


def parse_connectivity_data(connectivity_data, connectivity_type, participants):
    """Parse raw connectivity strings into tuples for the requested participants.

    Args:
        connectivity_data: Mapping whose keys start with the participant
            identifier and whose values are lists of connectivity strings.
        connectivity_type: 'within' or 'between'; selects the string parser.
            Any other value leaves the result empty.
        participants: Participants to process, in order. Those without any
            key in ``connectivity_data`` are skipped.

    Returns:
        Mapping from the original keys to lists of parsed tuples
        ``(network_or_pair, region1, connectivity, region2)``.
    """
    print("\nParsing connectivity data...")

    # Bucket the keys by their leading participant identifier.
    keys_by_participant = {}
    for key in connectivity_data:
        keys_by_participant.setdefault(key[0], []).append(key)
    total_participants = len(keys_by_participant)

    parsed_data = {}
    participants_processed = 0
    type_is_known = connectivity_type in ('within', 'between')

    for participant in participants:
        own_keys = keys_by_participant.get(participant)
        if own_keys is None:
            continue  # no data recorded for this participant

        if type_is_known:
            for key in own_keys:
                entries = []
                for raw_value in connectivity_data[key]:
                    if connectivity_type == 'within':
                        network, region1, conn_value, region2 = parse_within_conn_value(raw_value)
                        entries.append((network, region1, conn_value, region2))
                    else:
                        networks_pair, region1, conn_value, region2 = parse_between_conn_value(raw_value)
                        entries.append((networks_pair, region1, conn_value, region2))
                parsed_data[key] = entries

        participants_processed += 1
        sys.stdout.write(f"\rProcessed participant {participant} ({participants_processed}/{total_participants})")
        sys.stdout.flush()

    print("\nFinished parsing connectivity data.\n")
    return parsed_data


def generate_within_network_edge_labels(networks):
    """List every unordered pair of region names within each network.

    ``networks`` maps a network name to a sequence of ``(index, name)``
    entries. Pairs are emitted network by network, in the order the regions
    are listed.
    """
    edge_labels = []
    for regions in networks.values():
        names = [region[1] for region in regions]
        edge_labels.extend(itertools.combinations(names, 2))
    return edge_labels


def generate_between_network_edge_labels(networks):
    """List region-name pairs for every pair of distinct networks.

    Networks are visited in sorted name order. For each network pair, every
    region of the first network is combined with every region of the second.
    """
    edge_labels = []
    for first_net, second_net in itertools.combinations(sorted(networks), 2):
        edge_labels.extend(
            (first_region[1], second_region[1])
            for first_region, second_region
            in itertools.product(networks[first_net], networks[second_net])
        )
    return edge_labels


def standardize_edge(edge):
    """Return the edge's endpoints as a tuple in ascending order."""
    endpoints = list(edge)
    endpoints.sort()
    return tuple(endpoints)


def generate_region_pairs(labels):
    """Return every pair ``(labels[i], labels[j])`` with ``i < j``, ordered by ``i`` then ``j``."""
    return list(itertools.combinations(labels, 2))


def should_include_edge(diff, corr1, corr2, change_type):
    """Check whether an edge satisfies the sign pattern of ``change_type``.

    Each known change type has a rule on the signs of ``diff``, ``corr1``
    and ``corr2``; unknown change types yield False. Rules are evaluated
    lazily, so only the selected one touches its operands.
    """
    rules = {
        'IMPC': lambda: (diff > 0) and (corr1 > 0) and (corr2 > 0),
        'DMPC': lambda: (diff < 0) and (corr1 > 0) and (corr2 > 0),
        'SNPC': lambda: (corr1 < 0) and (corr2 > 0),
        'SPNC': lambda: (corr1 > 0) and (corr2 < 0),
        'IMNC': lambda: (diff < 0) and (corr1 < 0) and (corr2 < 0),
        'DMNC': lambda: (diff > 0) and (corr1 < 0) and (corr2 < 0),
    }
    rule = rules.get(change_type)
    if rule is None:
        return False
    return rule()


def determine_favored_group(diff, corr1, corr2, change_type, comparison):
    """Return '<group> favored' for a known change type, or None otherwise.

    The two group names are looked up in the module-level ``group_names``
    mapping under ``comparison``. For IMPC, DMPC, IMNC and DMNC the first
    group is favored when the change type's sign pattern holds and the
    second group otherwise; for SNPC and SPNC it is the other way round.
    """
    group1_name, group2_name = group_names[comparison]
    first_favored = f'{group1_name} favored'
    second_favored = f'{group2_name} favored'

    # change type -> (sign pattern, verdict if it holds, verdict if it does not)
    verdicts = {
        'IMPC': (lambda: diff > 0 and corr1 > 0 and corr2 > 0, first_favored, second_favored),
        'DMPC': (lambda: diff < 0 and corr1 > 0 and corr2 > 0, first_favored, second_favored),
        'SNPC': (lambda: corr1 < 0 and corr2 > 0, second_favored, first_favored),
        'IMNC': (lambda: diff < 0 and corr1 < 0 and corr2 < 0, first_favored, second_favored),
        'DMNC': (lambda: diff > 0 and corr1 < 0 and corr2 < 0, first_favored, second_favored),
        'SPNC': (lambda: corr1 > 0 and corr2 < 0, second_favored, first_favored),
    }

    entry = verdicts.get(change_type)
    if entry is None:
        return None
    pattern_holds, verdict_if_true, verdict_if_false = entry
    if pattern_holds():
        return verdict_if_true
    return verdict_if_false

# =============================================================================
# 4. Data Loading Functions
# =============================================================================
def load_connectivity_data(base_dir, participants, modes):
    """Load raw connectivity matrices with comprehensive validation.

    For each participant and mode the archive
    ``<base_dir>/<participant>/<mode>/<participant>_correlation_matrices.npy.npz``
    is read if it exists. Every 2-D array in it is processed as follows:

    * the last two rows and columns are dropped
      (NOTE(review): presumably non-cortical/auxiliary regions — confirm);
    * values outside [-1, 1] are reported and clipped;
    * NaN/Inf entries are reported and replaced by the mean of the finite
      entries (a matrix with no finite entry is skipped);
    * the array key must look like ``"<state>_<window>"`` with integer
      parts, otherwise the matrix is skipped with a warning.

    Loading is best-effort: an error while reading one archive is printed
    and that archive is skipped.

    Args:
        base_dir: Root directory containing one folder per participant.
        participants: Iterable (with ``len``) of participant identifiers.
        modes: Iterable of mode names (sub-folders of each participant).

    Returns:
        Nested dictionary
        ``{participant: {mode: {(state, window): matrix}}}``. A mode key is
        only present when its archive was opened successfully.
    """
    print("\nLoading raw connectivity matrices...")

    connectivity_data = {}
    total_participants = len(participants)
    out_of_bounds_count = 0
    total_matrices = 0

    for i, participant in enumerate(participants, 1):
        connectivity_data[participant] = {}
        matrices_loaded = 0

        for mode in modes:
            matrix_path = os.path.join(base_dir, participant, mode,
                                       f"{participant}_correlation_matrices.npy.npz")

            if not os.path.exists(matrix_path):
                continue

            try:
                # NOTE(security): allow_pickle=True lets NumPy unpickle object
                # arrays, which can execute arbitrary code. Only load archives
                # from trusted sources; switch to allow_pickle=False if the
                # archives hold plain numeric arrays.
                data = np.load(matrix_path, allow_pickle=True)
                connectivity_data[participant][mode] = {}

                # Close the underlying zip file handle once all arrays are read.
                with data:
                    for key in data.files:
                        matrix = data[key]
                        if matrix.ndim != 2:
                            print(f"Warning: Non-2D matrix found in {matrix_path}, key {key}")
                            continue
                        matrix = matrix[:-2, :-2]
                        if np.any(matrix < -1) or np.any(matrix > 1):
                            out_of_bounds_count += 1
                            print(f"\nOut-of-bounds values in {participant}, {mode}, {key}")
                            print(f"Range: [{np.min(matrix):.3f}, {np.max(matrix):.3f}]")
                            matrix = np.clip(matrix, -1, 1)
                        if np.isnan(matrix).any() or np.isinf(matrix).any():
                            print(f"\nWarning: NaN/Inf values in {participant}, {mode}, {key}")
                            valid_mask = ~(np.isnan(matrix) | np.isinf(matrix))
                            if valid_mask.any():
                                # Impute non-finite entries with the mean of the finite ones.
                                mean_val = matrix[valid_mask].mean()
                                matrix = np.nan_to_num(matrix, nan=mean_val, posinf=mean_val, neginf=mean_val)
                            else:
                                print(f"Error: No valid values in matrix for {participant}, {mode}, {key}")
                                continue

                        try:
                            state, window = map(int, key.split('_'))
                        except ValueError:
                            print(f"Warning: Invalid key format {key} in {matrix_path}")
                            continue
                        connectivity_data[participant][mode][(state, window)] = matrix
                        matrices_loaded += 1
                        total_matrices += 1

            except Exception as e:
                # Deliberate best-effort: report the failure and move on.
                print(f"\nError loading {matrix_path}: {str(e)}")
                continue

        sys.stdout.write(f"\rLoading matrices: {i}/{total_participants} participants | "
                        f"Current: {participant} | Matrices loaded: {matrices_loaded}")
        sys.stdout.flush()

    # Avoid ZeroDivisionError when nothing could be loaded.
    out_of_bounds_pct = (out_of_bounds_count / total_matrices * 100) if total_matrices else 0.0

    print(f"\nFinished loading {total_matrices} matrices.")
    print(f"Out-of-bounds matrices encountered: {out_of_bounds_count}")
    print(f"Out-of-bounds percentage: {out_of_bounds_pct:.2f}%")

    return connectivity_data

# =============================================================================
# 5. Connectivity Calculation Functions
# =============================================================================
def calculate_network_connectivity(connectivity_data, networks, labels):
    """Describe within- and between-network connections as formatted strings.

    Args:
        connectivity_data: Nested mapping
            ``{participant: {mode: {(state, window): matrix}}}``.
        networks: Mapping from network name to a sequence of
            ``(matrix_index, region_name)`` entries.
        labels: Not used by this function.

    Returns:
        Tuple ``(within, between)`` of dictionaries keyed by
        ``(participant, mode, (state, window))``. ``within`` holds strings
        ``"[network]: region1 - value - region2"`` for each upper-triangle
        pair inside a network; ``between`` holds strings
        ``"[net1, net2]: region1 - value - region2"`` for each region pair
        across two networks (networks taken in sorted name order). Values
        are formatted with two decimals.
    """
    print("\nCalculating network connectivity...")

    within_network_connectivity = {}
    between_network_connectivity = {}

    # Matrix indices per network, and the sorted network order, are the same
    # for every matrix.
    index_lookup = {name: [entry[0] for entry in members] for name, members in networks.items()}
    ordered_names = sorted(networks.keys())
    total_participants = len(connectivity_data)

    for processed, (participant, modes_data) in enumerate(connectivity_data.items(), 1):
        for mode, windows_data in modes_data.items():
            for (state, window), matrix in windows_data.items():
                key = (participant, mode, (state, window))

                within_entries = []
                for network_name, members in networks.items():
                    member_idx = index_lookup[network_name]
                    network_matrix = matrix[np.ix_(member_idx, member_idx)]
                    rows, cols = np.triu_indices_from(network_matrix, k=1)
                    for row, col in zip(rows, cols):
                        region1 = members[row][1]
                        region2 = members[col][1]
                        conn_value = network_matrix[row, col]
                        within_entries.append(
                            f"[{network_name}]: {region1} - {conn_value:.2f} - {region2}")

                between_entries = []
                for position, net1 in enumerate(ordered_names):
                    for net2 in ordered_names[position + 1:]:
                        between_matrix = matrix[np.ix_(index_lookup[net1], index_lookup[net2])]
                        for ii, member1 in enumerate(networks[net1]):
                            for jj, member2 in enumerate(networks[net2]):
                                region1 = member1[1]
                                region2 = member2[1]
                                conn_value = between_matrix[ii, jj]
                                between_entries.append(
                                    f"[{net1}, {net2}]: {region1} - {conn_value:.2f} - {region2}")

                within_network_connectivity[key] = within_entries
                between_network_connectivity[key] = between_entries

        sys.stdout.write(f"\rProcessing: {processed}/{total_participants} participants | "
                        f"Current: {participant}")
        sys.stdout.flush()

    sys.stdout.write("\nFinished calculating network connectivity.\n")
    sys.stdout.flush()

    return within_network_connectivity, between_network_connectivity


def organize_connectivity_data(participant_groups, modes, parsed_connectivity, participants):
    """
    Organize connectivity data for NBS analysis using parsed connectivity data.

    Each parsed connection list is ordered by ``(region1, region2)`` and its
    connectivity values are stored as a NumPy vector under
    ``organized[group][mode]["state_<s>_window_<w>"][participant]``.
    The lists in ``parsed_connectivity`` are left untouched: a sorted copy
    is used.

    NOTE(review): the vector order is the lexicographic order of the region
    name strings; confirm this matches the edge ordering assumed by the
    downstream NBS code.

    Args:
        participant_groups: Dictionary mapping participants to their groups.
        modes: List of modes (e.g., ['EC', 'EO']).
        parsed_connectivity: Parsed connectivity data (within or between),
            keyed by ``(participant, mode, (state, window))`` with lists of
            ``(network, region1, value, region2)`` tuples as values.
        participants: List of all participants. Participants without parsed
            data or without a (truthy) group are skipped.

    Returns:
        Dictionary of organized data ready for NBS analysis.
    """
    print("\nOrganizing connectivity data for NBS analysis...")
    organized_data = {group: {mode: {} for mode in modes} for group in set(participant_groups.values())}

    # Group the parsed keys by their leading participant identifier.
    participant_keys = {}
    for key in parsed_connectivity:
        participant_keys.setdefault(key[0], []).append(key)

    total_participants = len(participant_keys)
    participants_processed = 0

    for participant in participants:
        group = participant_groups.get(participant)
        if participant not in participant_keys or not group:
            continue
        for mode in modes:
            for key in participant_keys[participant]:
                if key[1] != mode:
                    continue
                _, _, (state, window) = key
                # sorted() instead of list.sort(): do not reorder the
                # caller's parsed data as a side effect.
                connections = sorted(parsed_connectivity[key], key=lambda conn: (conn[1], conn[3]))
                vector_form = [conn[2] for conn in connections]
                state_window = f"state_{state}_window_{window}"
                window_bucket = organized_data[group][mode].setdefault(state_window, {})
                window_bucket[participant] = np.array(vector_form)
        participants_processed += 1
        sys.stdout.write(f"\rOrganized data for participant {participant} ({participants_processed}/{total_participants})")
        sys.stdout.flush()

    print("\nFinished organizing connectivity data.\n")
    return organized_data

# =============================================================================
# 6. Permutation and Statistical Functions
# =============================================================================
def process_single_permutation(combined_data, n1, threshold):
    """Run one permutation and return the size of its largest component.

    ``combined_data`` is shuffled in place along its first axis, so callers
    that need the original order must pass a copy. The first ``n1`` rows
    form the first pseudo-group and the rest form the second. Edges whose
    absolute t-statistic exceeds ``threshold`` are placed into a symmetric
    adjacency matrix, and the node count of the largest connected component
    is returned (isolated nodes count as components of size 1).

    NOTE(review): the node count is derived from the edge count assuming the
    statistic vector enumerates the full upper triangle row by row — confirm
    this holds for the vectors passed in.
    """
    np.random.shuffle(combined_data)  # in place; mutates the caller's array

    t_stats = compute_test_stats(combined_data[:n1], combined_data[n1:])[0]
    suprathreshold = np.abs(t_stats) > threshold

    n_edges = len(t_stats)
    n_nodes = int((1 + np.sqrt(1 + 8 * n_edges)) / 2)

    # Upper-triangle coordinates in the same row-major order as the edge vector.
    rows, cols = np.triu_indices(n_nodes, k=1)
    selected = suprathreshold[:len(rows)]

    adj_matrix = np.zeros((n_nodes, n_nodes))
    adj_matrix[rows[selected], cols[selected]] = 1
    adj_matrix[cols[selected], rows[selected]] = 1

    graph = nx.from_numpy_array(adj_matrix)
    return max((len(nodes) for nodes in nx.connected_components(graph)), default=0)


def compute_test_stats(data1, data2):
    """Compute column-wise two-sample statistics for two groups.

    Rows are observations; all statistics are taken along axis 0.

    Returns:
        Tuple ``(t_stats, mean_diff, var1, var2, cohens_d)`` where
        ``mean_diff`` is ``mean(data1) - mean(data2)``, the variances use
        ``ddof=1`` (zeros when a group has a single row), ``t_stats`` is the
        pooled-variance t-statistic and ``cohens_d`` is ``mean_diff``
        divided by ``sqrt((var1 + var2) / 2)``. Entries whose denominator is
        zero or NaN are reported as 0.
    """
    with warnings.catch_warnings():
        for pattern in ('Degrees of freedom <= 0 for slice',
                        'invalid value encountered in divide'):
            warnings.filterwarnings('ignore', category=RuntimeWarning, message=pattern)

        n1, n2 = len(data1), len(data2)

        mean1 = np.mean(data1, axis=0)
        mean2 = np.mean(data2, axis=0)
        mean_diff = mean1 - mean2

        var1 = np.var(data1, axis=0, ddof=1) if n1 > 1 else np.zeros_like(mean1)
        var2 = np.var(data2, axis=0, ddof=1) if n2 > 1 else np.zeros_like(mean2)

        def _ratio_or_zero(denominator):
            # Divide only where the denominator is positive and not NaN.
            ratio = np.zeros_like(mean_diff)
            usable = (denominator > 0) & (~np.isnan(denominator))
            ratio[usable] = mean_diff[usable] / denominator[usable]
            return ratio

        # Standard error of the mean difference from the pooled variance.
        pooled_var = ((n1 - 1) * var1 + (n2 - 1) * var2) / max(n1 + n2 - 2, 1)
        standard_error = np.sqrt(pooled_var * (1/n1 + 1/n2))
        t_stats = _ratio_or_zero(standard_error)

        # Cohen's d uses the root of the averaged group variances.
        cohens_d = _ratio_or_zero(np.sqrt((var1 + var2) / 2))

        return t_stats, mean_diff, var1, var2, cohens_d

# =============================================================================
# 7. Network Based Statistic (NBS) Class and Methods
# =============================================================================
class NetworkBasedStatistic:
    """Network-Based Statistic (NBS) with permutation testing.

    Edges whose absolute t statistic exceeds ``primary_threshold`` form a
    graph; each connected component's size (number of nodes) is compared with
    the maximal sizes returned for permuted data by
    ``process_single_permutation`` (defined elsewhere in this file).

    Note:
        The convergence tracker is created once in ``__init__`` and is never
        reset by ``fit``; use a fresh instance for every call to ``fit``.
    """

    def __init__(self,
                 max_permutations: int = 5000,
                 primary_threshold: float = 1.5,
                 convergence_window: int = 50,
                 convergence_alpha: float = 0.05,
                 min_iterations: int = 100,
                 window_idx: Optional[int] = None,
                 total_windows: Optional[int] = None,
                 state_window: Optional[str] = None):
        """Store the NBS settings and the optional window-tracking labels.

        Args:
            max_permutations: Upper bound on the number of permutations.
            primary_threshold: |t| cut-off used to select edges.
            convergence_window: ``window_size`` passed to ``ConvergenceStats``.
            convergence_alpha: ``alpha`` passed to ``ConvergenceStats``.
            min_iterations: ``min_iterations`` passed to ``ConvergenceStats``.
            window_idx: 1-based index of the window, used only for progress output.
            total_windows: Number of windows, used only for progress output.
            state_window: Window label, used only for progress output.
        """
        self.max_permutations = max_permutations
        self.primary_threshold = primary_threshold
        self.convergence_stats = ConvergenceStats(
            window_size=convergence_window,
            alpha=convergence_alpha,
            min_iterations=min_iterations
        )
        self.window_idx = window_idx
        self.total_windows = total_windows
        self.state_window = state_window

        if self.window_idx is not None and self.total_windows is not None:
            logging.info(f"\nStarting window {self.window_idx}/{self.total_windows}: {self.state_window}")

    def _empty_result(self):
        """Return the result dictionary used when nothing significant is found.

        Carries the same keys as ``_calculate_final_results`` (including
        ``component_all_connections``) so callers can treat both uniformly.
        """
        return {
            'significant_components': [],
            'component_pvals': [],
            'effect_sizes': [],
            'component_edges': [],
            'component_all_connections': [],
            'n_nodes': 0,
            'current_permutation': 0,
            'convergence_info': {
                'n_permutations': 0,
                'converged': False,
                'running_mean': [],
                'running_var': []
            }
        }

    def _process_chunk(self, chunk_size, combined_data, n1):
        """Run ``chunk_size`` permutations and return their results as a list.

        Each permutation receives its own copy of ``combined_data`` so the
        shared array is never modified by ``process_single_permutation``.
        """
        results = []
        for _ in range(chunk_size):
            perm_data = combined_data.copy()
            result = process_single_permutation(perm_data, n1, self.primary_threshold)
            results.append(result)
        return results

    def _find_components(self, adjacency_matrix: np.ndarray) -> Tuple[List[set], List[int], List[float], List[List[Tuple[int, int]]]]:
        """Identify and characterize the connected components of a graph.

        Returns:
            ``(components, sizes, densities, component_edges)``. ``components``
            lists every connected component (isolated nodes included), whereas
            the other three lists only contain entries for components with at
            least one edge, so they are not index-aligned with ``components``.
            Four empty lists are returned for an empty/all-zero matrix or when
            an error occurs (the error is logged).
        """
        if adjacency_matrix.size == 0 or not np.any(adjacency_matrix):
            return [], [], [], []

        try:
            G = nx.from_numpy_array(adjacency_matrix)
            components = list(nx.connected_components(G))

            if not components:
                return [], [], [], []

            sizes = []
            densities = []
            component_edges = []

            for comp in components:
                if not comp:
                    continue

                subgraph = G.subgraph(comp).copy()

                if subgraph.number_of_edges() > 0:
                    sizes.append(len(comp))
                    densities.append(nx.density(subgraph))
                    component_edges.append(list(subgraph.edges()))

            return components, sizes, densities, component_edges
        except Exception as e:
            logging.error(f"Error in finding components: {str(e)}")
            return [], [], [], []

    @staticmethod
    def _edge_index(i: int, j: int, n_nodes: int) -> int:
        """Map a node pair to its position in the flattened upper triangle.

        The ordering matches ``np.triu_indices(n_nodes, k=1)``, which is how
        ``fit`` lays out the per-edge arrays. The pair may be given in either
        order.
        """
        if i > j:
            i, j = j, i
        # Rows 0..i-1 contribute (n-1) + (n-2) + ... + (n-i) entries.
        return i * (2 * n_nodes - i - 1) // 2 + (j - i - 1)

    def _calculate_final_results(self, orig_components, perm_max_sizes, t_stats, orig_effects, mean_diff,
                                   adj_matrix, G, n_nodes, processed_perms):
        """Assemble the result dictionary from the permutation null distribution.

        Args:
            orig_components: Connected components (sets of node ids) of ``G``.
            perm_max_sizes: One maximal component size per permutation.
            t_stats: Per-edge t statistics (upper-triangle order).
            orig_effects: Per-edge Cohen's d values (upper-triangle order).
            mean_diff: Per-edge mean differences (upper-triangle order).
            adj_matrix: Thresholded adjacency matrix (not used here).
            G: Graph built from ``adj_matrix``.
            n_nodes: Number of nodes in ``G``.
            processed_perms: Number of permutations that were run.

        Returns:
            Dict with the significant components (p < 0.05), their p-values,
            effect summaries, supra-threshold edges, all node pairs inside each
            component, and the convergence information.
        """
        significant_components = []
        component_pvals = []
        component_effect_sizes = []
        component_edges_list = []
        component_all_connections = []

        null_sizes = np.asarray(perm_max_sizes)
        n_perms = len(perm_max_sizes)
        t_stats = np.asarray(t_stats)
        orig_effects = np.asarray(orig_effects)
        mean_diff = np.asarray(mean_diff)

        for component in orig_components:
            if len(component) <= 1:
                continue

            subgraph = G.subgraph(component).copy()
            size = len(component)
            # Fraction of permutations whose largest component is at least as big.
            p_value = np.sum(null_sizes >= size) / n_perms if n_perms > 0 else 1.0
            if not p_value < 0.05:
                continue

            nodes = list(component)
            all_connections = [(a, b) for pos, a in enumerate(nodes) for b in nodes[pos + 1:]]

            component_edges = list(subgraph.edges())
            # The statistic arrays are indexed by edge, not by node: translate
            # each supra-threshold edge of the component into its edge index.
            edge_ids = [self._edge_index(u, v, n_nodes) for u, v in component_edges]

            significant_components.append(component)
            component_pvals.append(p_value)
            component_edges_list.append(component_edges)
            component_all_connections.append(all_connections)

            component_effect_sizes.append({
                'cohens_d': np.mean(orig_effects[edge_ids]),
                't_stat': np.mean(t_stats[edge_ids]),
                'density': nx.density(subgraph),
                'size': size,
                'mean_effect': np.mean(mean_diff[edge_ids])
            })

        return {
            'significant_components': significant_components,
            'component_pvals': component_pvals,
            'effect_sizes': component_effect_sizes,
            'component_edges': component_edges_list,
            'component_all_connections': component_all_connections,
            'n_nodes': n_nodes,
            'current_permutation': processed_perms,
            'convergence_info': {
                'n_permutations': len(perm_max_sizes),
                'converged': self.convergence_stats.converged,
                'running_mean': self.convergence_stats.running_mean,
                'running_var': self.convergence_stats.running_var
            }
        }

    def _write_progress(self, n1, n_edges, processed):
        """Overwrite the current console line with the permutation progress."""
        sys.stdout.write('\r' + ' ' * 100)
        sys.stdout.write('\r')
        progress_msg = (f"Processing window {self.window_idx}/{self.total_windows}: {self.state_window} | "
                        f"Matrix shape: {n1}x{n_edges} | Threshold: {self.primary_threshold:.3f} | "
                        f"Permutations: {processed}/{self.max_permutations}")
        sys.stdout.write(progress_msg)
        sys.stdout.flush()

    def fit(self, data1: np.ndarray, data2: np.ndarray) -> Dict:
        """Run NBS on two groups and return the significant components.

        Args:
            data1: Array of shape ``(n1, n_edges)`` for the first group.
            data2: Array of shape ``(n2, n_edges)`` for the second group.
                ``n_edges`` is interpreted as the upper triangle (without
                diagonal) of a square node-by-node matrix.

        Returns:
            The dictionary built by ``_calculate_final_results``, or
            ``_empty_result()`` when a group is empty or no edge exceeds the
            primary threshold.
        """
        n1, n_edges = data1.shape
        n2 = data2.shape[0]

        if n1 == 0 or n2 == 0:
            return self._empty_result()

        t_stats, mean_diff, _, _, cohens_d = compute_test_stats(data1, data2)
        significant_edges = np.abs(t_stats) > self.primary_threshold
        if not np.any(significant_edges):
            return self._empty_result()

        # Recover the node count from n_edges = n_nodes * (n_nodes - 1) / 2.
        n_nodes = int((1 + np.sqrt(1 + 8 * n_edges)) / 2)
        edge_indices = np.triu_indices(n_nodes, k=1)

        adj_matrix = np.zeros((n_nodes, n_nodes))
        adj_matrix[edge_indices[0][significant_edges], edge_indices[1][significant_edges]] = 1
        adj_matrix = adj_matrix + adj_matrix.T

        G = nx.from_numpy_array(adj_matrix)
        orig_components = list(nx.connected_components(G))

        if not orig_components:
            return self._empty_result()

        combined_data = np.vstack([data1, data2])
        chunk_size = 50
        perm_max_sizes = []
        processed = 0

        with concurrent.futures.ThreadPoolExecutor() as executor:
            remaining_perms = self.max_permutations

            if self.window_idx is not None:
                sys.stdout.write('\r' + ' ' * 100)
                sys.stdout.flush()

            # Chunks are run one at a time: convergence is evaluated after
            # every chunk, so the next chunk is only submitted if still needed.
            while remaining_perms > 0 and not self.convergence_stats.converged:
                current_chunk = min(chunk_size, remaining_perms)
                remaining_perms -= current_chunk

                chunk_results = executor.submit(
                    self._process_chunk,
                    current_chunk,
                    combined_data,
                    n1
                ).result()
                perm_max_sizes.extend(chunk_results)
                processed += len(chunk_results)

                if self.window_idx is not None:
                    self._write_progress(n1, n_edges, processed)

                # ``any`` stops feeding the tracker as soon as it converges.
                if any(self.convergence_stats.update(size) for size in chunk_results):
                    break

            if self.window_idx is not None:
                sys.stdout.write('\n')
                sys.stdout.flush()

        return self._calculate_final_results(
            orig_components, perm_max_sizes, t_stats, cohens_d, mean_diff,
            adj_matrix, G, n_nodes, processed
        )

# =============================================================================
# 8. Dynamic Group Differences Calculation Functions
# =============================================================================
def _save_nbs_checkpoint(results_filename, payload):
    """Atomically pickle ``payload`` to ``results_filename``; log instead of raising on failure."""
    tmp_filename = f"{results_filename}.tmp"
    try:
        # Write to a sibling file first so an interruption mid-write cannot
        # corrupt the previous checkpoint, then swap it in with one rename.
        with open(tmp_filename, 'wb') as f:
            pickle.dump(payload, f)
        os.replace(tmp_filename, results_filename)
        logging.info(f"Intermediate results saved to {results_filename}")
    except Exception as e:
        logging.error(f"Error saving intermediate results: {str(e)}")


def calculate_dynamic_group_differences_nbs(
    group1_data, group2_data, comparison, edge_labels, alpha=0.05, 
    state_window_fraction=1.0, max_permutations=5000, group3_data=None,
    results_filename=None,
    existing_differences=None,
    existing_nbs_results=None,
    existing_group1_means_all=None,
    existing_group2_means_all=None,
    existing_convergence_stats=None,
    existing_total_state_windows=0,
    existing_group3_means_all=None,
    processed_windows=None
):
    """Calculate group differences using Network-Based Statistics with a global threshold.

    The primary threshold is loaded from ``optimized_threshold_{comparison}.pkl``
    when that file exists; otherwise it is computed with
    ``optimize_global_threshold`` over all windows and cached to that file.
    Every (mode, state window) pair is then analysed with
    ``NetworkBasedStatistic`` and, when ``results_filename`` is given, a
    checkpoint is written after each window.

    Args:
        group1_data, group2_data: ``{mode: {state_window: {participant: array}}}``.
        comparison: Label used in messages and in the threshold cache filename.
        edge_labels: Currently unused by this function.
        alpha: Currently unused by this function.
        state_window_fraction: When < 1.0, a random sample of that fraction of
            each mode's windows is analysed (at least one window).
        max_permutations: Permutation budget per window.
        group3_data: Optional third group with the same layout; only its
            per-edge means are recorded.
        results_filename: Optional path of the checkpoint pickle.
        existing_*: Previously accumulated outputs to resume from.
        processed_windows: Set of ``(mode, state_window)`` keys to skip.

    Returns:
        ``(differences, nbs_results, group1_means_all, group2_means_all,
        convergence_stats, total_state_windows)``, with ``group3_means_all``
        appended when ``group3_data`` is given.
    """
    log_filename = setup_logging()
    logging.info("Starting NBS-based dynamic group differences calculation...")

    differences = existing_differences if existing_differences is not None else {}
    nbs_results = existing_nbs_results if existing_nbs_results is not None else {}
    group1_means_all = existing_group1_means_all if existing_group1_means_all is not None else {}
    group2_means_all = existing_group2_means_all if existing_group2_means_all is not None else {}
    convergence_stats = existing_convergence_stats if existing_convergence_stats is not None else {}
    total_state_windows = existing_total_state_windows
    processed_windows = processed_windows if processed_windows is not None else set()

    if group3_data is not None:
        group3_means_all = existing_group3_means_all if existing_group3_means_all is not None else {}
    else:
        group3_means_all = None

    modes = list(group1_data.keys())
    total_modes = len(modes)

    logging.info("\nCollecting data across all modes for global threshold optimization...")
    data1_all_windows = []
    data2_all_windows = []
    total_windows = 0

    for mode in modes:
        group1_mode_data = group1_data[mode]
        group2_mode_data = group2_data[mode]
        common_windows = set(group1_mode_data.keys()) & set(group2_mode_data.keys())
        total_windows += len(common_windows)
        for state_window in common_windows:
            group1_arrays = list(group1_mode_data[state_window].values())
            group2_arrays = list(group2_mode_data[state_window].values())

            # A window without participants in either group cannot contribute
            # to the threshold search (and would break the shape check below).
            if not group1_arrays or not group2_arrays:
                continue

            if all(arr.shape == group1_arrays[0].shape for arr in group1_arrays + group2_arrays):
                data1_all_windows.append(np.array(group1_arrays))
                data2_all_windows.append(np.array(group2_arrays))

    logging.info(f"Total windows across all modes: {total_windows}")

    threshold_filename = f'optimized_threshold_{comparison}.pkl'
    try:
        # NOTE(review): pickle.load executes arbitrary code from the file; this
        # cache must only ever be a file written by this function.
        with open(threshold_filename, 'rb') as f:
            optimal_threshold = pickle.load(f)
        logging.info(f"\nLoaded optimized threshold: {optimal_threshold:.3f} for comparison: {comparison}")
    except FileNotFoundError:
        logging.info("Calculating global optimal threshold...")
        optimal_threshold = optimize_global_threshold(data1_all_windows, data2_all_windows, total_windows)
        logging.info(f"\nUsing global threshold: {optimal_threshold:.3f} for all windows")
        with open(threshold_filename, 'wb') as f:
            pickle.dump(optimal_threshold, f)
        logging.info(f"Optimized threshold saved to {threshold_filename}")

    logging.info("\nStarting mode-specific processing with the optimized threshold...")

    for mode_idx, mode in enumerate(modes, 1):
        mode_message = f"\nProcessing mode: {mode} [{mode_idx}/{total_modes}] for comparison: {comparison}"
        logging.info(mode_message)
        print(mode_message)

        group1_mode_data = group1_data[mode]
        group2_mode_data = group2_data[mode]

        common_windows = set(group1_mode_data.keys()) & set(group2_mode_data.keys())

        if group3_data is not None:
            group3_mode_data = group3_data[mode]
            common_windows = common_windows & set(group3_mode_data.keys())

        # Window names are ordered by the integers found at positions 1 and 3
        # of their '_'-separated parts.
        state_windows = sorted(common_windows, key=lambda x: (int(x.split('_')[1]), int(x.split('_')[3])))

        total_mode_windows = len(state_windows)

        if state_window_fraction < 1.0:
            num_windows = max(1, int(total_mode_windows * state_window_fraction))
            state_windows = random.sample(state_windows, num_windows)
            logging.info(f"Analyzing {num_windows}/{total_mode_windows} windows for {mode}")
        else:
            logging.info(f"Analyzing {total_mode_windows} windows for {mode}")

        for window_idx, state_window in enumerate(state_windows, 1):
            key = (mode, state_window)
            if key in processed_windows:
                continue
            progress_message = f"Processing window {window_idx}/{len(state_windows)}: {state_window} | "
            sys.stdout.write('\r' + progress_message)
            sys.stdout.flush()

            logging.info(f"\nStarting window {window_idx}/{len(state_windows)}: {state_window}")

            group1_arrays = list(group1_mode_data[state_window].values())
            group2_arrays = list(group2_mode_data[state_window].values())

            if group3_data is not None:
                group3_arrays = list(group3_data[mode][state_window].values())
            else:
                group3_arrays = None

            if not group1_arrays or not group2_arrays:
                logging.warning(f"Skipping window {state_window} - No participant data in one of the groups")
                continue

            if not all(arr.shape == group1_arrays[0].shape for arr in group1_arrays + group2_arrays):
                warning_msg = f"Skipping window {state_window} - Inconsistent array shapes"
                logging.warning(warning_msg)
                continue

            group1_values = np.array(group1_arrays)
            group2_values = np.array(group2_arrays)
            if group3_arrays is not None:
                group3_values = np.array(group3_arrays)

            # The console line is rebuilt piece by piece as information becomes available.
            matrix_shape_str = f"Matrix shape: {group1_values.shape[0]}x{group1_values.shape[1]} | "
            progress_message += matrix_shape_str
            sys.stdout.write('\r' + progress_message)
            sys.stdout.flush()

            threshold_str = f"Threshold: {optimal_threshold:.3f} | "
            progress_message += threshold_str
            sys.stdout.write('\r' + progress_message)
            sys.stdout.flush()

            nbs = NetworkBasedStatistic(
                max_permutations=max_permutations,
                primary_threshold=optimal_threshold,
                convergence_window=50,
                convergence_alpha=0.05,
                min_iterations=100,
                window_idx=window_idx,
                total_windows=len(state_windows),
                state_window=state_window
            )

            result = nbs.fit(group1_values, group2_values)

            n_components = len(result['significant_components']) if result['significant_components'] else 0
            current_perm = result.get('current_permutation', 0)

            perm_comp_str = f"Permutations: {current_perm}/{max_permutations} | Components: {n_components}"
            progress_message += perm_comp_str
            sys.stdout.write('\r' + progress_message)
            sys.stdout.flush()

            print()
            logging.info(f"Completed window {window_idx}/{len(state_windows)}: {state_window} | "
                         f"Permutations: {current_perm}/{max_permutations} | Components: {n_components}")

            if result['significant_components']:
                nbs_results[key] = result
                convergence_stats[key] = result['convergence_info']

                group1_means = np.mean(group1_values, axis=0)
                group2_means = np.mean(group2_values, axis=0)
                diff_values = group1_means - group2_means

                # Per-edge values are stored as {edge_index: value} dictionaries.
                edge_indices = range(len(diff_values))
                differences[key] = dict(zip(edge_indices, diff_values))
                group1_means_all[key] = dict(zip(edge_indices, group1_means))
                group2_means_all[key] = dict(zip(edge_indices, group2_means))

                if group3_arrays is not None:
                    group3_means = np.mean(group3_values, axis=0)
                    group3_means_all[key] = dict(zip(edge_indices, group3_means))

                total_state_windows += 1
                logging.info(f"Found {n_components} significant components for {state_window}")

            processed_windows.add(key)
            if results_filename is not None:
                if group3_data is not None:
                    payload = (differences, nbs_results, group1_means_all, group2_means_all,
                               convergence_stats, total_state_windows, group3_means_all, processed_windows)
                else:
                    payload = (differences, nbs_results, group1_means_all, group2_means_all,
                               convergence_stats, total_state_windows, processed_windows)
                _save_nbs_checkpoint(results_filename, payload)

        final_msg = f"\nAnalysis complete for mode {mode}. Total windows with significant components: {total_state_windows}"
        logging.info(final_msg)
        print(final_msg)

    logging.info(f"\nAnalysis complete. Total windows with significant components: {total_state_windows}")
    logging.info(f"Complete log available in: {log_filename}")

    print(f"\nAnalysis complete. Total windows with significant components: {total_state_windows}")
    print(f"Complete log available in: {log_filename}")

    if group3_data is not None:
        return (differences, nbs_results, group1_means_all, group2_means_all, 
                convergence_stats, total_state_windows, group3_means_all)
    else:
        return (differences, nbs_results, group1_means_all, group2_means_all, 
                convergence_stats, total_state_windows)

# =============================================================================
# 9. Logging and Threshold Optimization Functions
# =============================================================================
def setup_logging():
    """Route root-logger output to a fresh timestamped log file.

    A file named ``nbs_analysis_<YYYYmmdd_HHMMSS>.log`` is opened in the
    current working directory and installed as the only handler of the root
    logger (level INFO). Handlers that were attached before are removed and
    closed so repeated calls do not leak open log files.

    Returns:
        The name of the log file that was created.
    """
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    log_filename = f'nbs_analysis_{timestamp}.log'

    f_handler = logging.FileHandler(log_filename)
    f_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    f_handler.setFormatter(formatter)

    logger = logging.getLogger()
    # Detach and close previous handlers instead of just dropping the list,
    # otherwise every earlier FileHandler keeps its file descriptor open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    logger.addHandler(f_handler)
    logger.setLevel(logging.INFO)

    return log_filename


def optimize_global_threshold(data1_all_windows: List[np.ndarray], 
                              data2_all_windows: List[np.ndarray],
                              total_windows: int,
                              thresholds: Optional[list] = None) -> float:
    """Pick the primary t-statistic threshold with the most stable component count.

    For every candidate threshold and every window, 20 bootstrap resamples of
    both groups are drawn (with replacement, via ``np.random``), pooled-variance
    t statistics are computed per edge, edges with ``|t| > threshold`` form an
    undirected graph, and the number of connected components (isolated nodes
    included) is recorded. The threshold whose component counts have the
    smallest variance across all windows and resamples is returned; ties go to
    the earliest threshold.

    Args:
        data1_all_windows: One ``(n1, n_edges)`` array per window for group 1.
        data2_all_windows: One ``(n2, n_edges)`` array per window for group 2.
        total_windows: Accepted for backward compatibility but ignored; the
            window count is taken from ``data1_all_windows``.
        thresholds: Candidate thresholds; defaults to ``[1.5, 2.0, 2.5, 3.0]``.

    Returns:
        The selected threshold (an element of ``thresholds``).

    Raises:
        ValueError: If ``thresholds`` is empty.
    """
    if thresholds is None:
        thresholds = [1.5, 2.0, 2.5, 3.0]
    if len(thresholds) == 0:
        raise ValueError("thresholds must contain at least one value")

    num_bootstrap_iterations = 20
    n_windows = len(data1_all_windows)
    total_iterations = len(thresholds) * n_windows
    current_iteration = 0
    stability_scores = []

    with warnings.catch_warnings():
        # Degenerate resamples (zero variance, tiny groups) are expected here.
        warnings.filterwarnings('ignore', category=RuntimeWarning)

        for threshold in thresholds:
            all_windows_components = []

            for data1, data2 in zip(data1_all_windows, data2_all_windows):
                n1, n_edges = data1.shape
                n2 = data2.shape[0]

                # Graph layout depends only on the window, not on the resample.
                n_nodes = int((1 + np.sqrt(1 + 8 * n_edges)) / 2)
                rows, cols = np.triu_indices(n_nodes, k=1)

                bootstrap_indices = [(
                    np.random.choice(n1, size=n1, replace=True),
                    np.random.choice(n2, size=n2, replace=True)
                ) for _ in range(num_bootstrap_iterations)]

                for idx1, idx2 in bootstrap_indices:
                    data1_bootstrap = data1[idx1]
                    data2_bootstrap = data2[idx2]

                    mean_diff = np.mean(data1_bootstrap, axis=0) - np.mean(data2_bootstrap, axis=0)
                    var1 = np.var(data1_bootstrap, axis=0, ddof=1)
                    var2 = np.var(data2_bootstrap, axis=0, ddof=1)
                    pooled_var = ((len(idx1) - 1) * var1 + (len(idx2) - 1) * var2) / (len(idx1) + len(idx2) - 2)
                    pooled_std = np.sqrt(pooled_var * (1/len(idx1) + 1/len(idx2)))

                    # Edges with zero (or NaN) standard error keep t = 0.
                    valid_mask = pooled_std > 0
                    t_stats = np.zeros(n_edges)
                    t_stats[valid_mask] = mean_diff[valid_mask] / pooled_std[valid_mask]

                    significant_edges = np.abs(t_stats) > threshold
                    sig_rows = rows[significant_edges]
                    sig_cols = cols[significant_edges]
                    graph = csr_matrix(
                        (np.ones(len(sig_rows)), (sig_rows, sig_cols)),
                        shape=(n_nodes, n_nodes)
                    )
                    # directed=False treats the upper-triangle entries as
                    # undirected edges; only the component count is needed.
                    n_components = connected_components(graph, directed=False, return_labels=False)
                    all_windows_components.append(n_components)

                current_iteration += 1
                progress = (current_iteration / total_iterations) * 100
                sys.stdout.write(f"\rOptimizing threshold: {progress:.2f}% complete")
                sys.stdout.flush()

            stability_scores.append(np.var(all_windows_components))

        sys.stdout.write('\n')

    optimal_idx = np.argmin(stability_scores)
    optimal_threshold = thresholds[optimal_idx]

    print(f"Selected global threshold: {optimal_threshold:.3f}")
    return optimal_threshold

# =============================================================================
# 10. Summary and Excel Output Functions
# =============================================================================
def quantification_summary_to_excel_nbs(
    differences,
    nbs_results,
    group1_means_all,
    group2_means_all,
    edge_labels,
    networks,
    labels,
    change_type,
    comparison,
    writer=None,
    group3_means_all=None,
    alpha=0.05,
    existing_all_dfs=None,
    existing_summary_data=None,
    processed_keys=None,
    summary_filename=None,
    edge_indices_reverse=None
):
    """Generate summary with proper network pair extraction, including weighted values."""
    all_dfs = existing_all_dfs if existing_all_dfs is not None else []
    processed_keys = processed_keys if processed_keys is not None else set()
    summary_list = []

    network_names = sorted(list(networks.keys()))
    categories = [
        'Overall',
        'Intra-Network',
        'Inter-Network',
    ]
    intra_network_categories = [f'Intra-Network ({net})' for net in network_names]
    categories.extend(intra_network_categories)

    inter_network_pairs = []
    for i in range(len(network_names)):
        for j in range(i + 1, len(network_names)):
            net1, net2 = sorted([network_names[i], network_names[j]])
            inter_network_pairs.append(f'{net1}-{net2}')

    inter_network_categories = [f'Inter-Network ({pair})' for pair in inter_network_pairs]
    categories.extend(inter_network_categories)
    
    unique_state_windows = len(set([key[1] for key in nbs_results.keys()]))

    summary_data = existing_summary_data if existing_summary_data is not None else {}
    for category in categories:
        if category not in summary_data:
            summary_data[category] = {
                'Connection Count': 0,
                'Total Connection Count': 0,
                'Connection Percentage': 0,
                'Mean_Component_Size': [],
                'Mean_Component_Density': [],
                'Mean_Effect_Size': [],
                'Mean_T_Statistic': [],
                'Mean_P_Value': [],
                f'{group_names[comparison][0]} favored': 0,
                f'{group_names[comparison][1]} favored': 0,
                f'{group_names[comparison][0]} weighted value (average)': 0,
                f'{group_names[comparison][1]} weighted value (average)': 0,
                f'{group_names[comparison][0]} weighted value (total)': 0,
                f'{group_names[comparison][1]} weighted value (total)': 0
            }
            if comparison == 'oac_vs_tcoa' and group3_means_all is not None:
                summary_data[category].update({
                    'Compensatory_Connections': 0,
                    'Deterioration_Connections': 0,
                    'Restoration_Count': 0,
                    'Alternative_Count': 0,
                    'Normalization_Count': 0,
                    'Enhancement_Count': 0,
                    'Deterioration_Count': 0,
                    'Maladaptive_Count': 0,
                    'Exacerbation_Count': 0,
                    'Decompensation_Count': 0
                })

    n_regions = len(labels)
    total_possible_overall = (n_regions * (n_regions - 1)) // 2
    summary_data['Overall']['Total Connection Count'] = total_possible_overall * unique_state_windows

    total_possible_intra = 0
    for net in network_names:
        n = len(networks[net])
        total_intra = n * (n - 1) // 2
        total_possible_intra += total_intra
        category = f'Intra-Network ({net})'
        summary_data[category]['Total Connection Count'] = total_intra * unique_state_windows

    summary_data['Intra-Network']['Total Connection Count'] = total_possible_intra * unique_state_windows
    total_possible_inter = total_possible_overall - total_possible_intra
    summary_data['Inter-Network']['Total Connection Count'] = total_possible_inter * unique_state_windows

    for pair in inter_network_pairs:
        net1, net2 = pair.split('-')
        n1 = len(networks[net1])
        n2 = len(networks[net2])
        total_inter = n1 * n2
        category = f'Inter-Network ({pair})'
        summary_data[category]['Total Connection Count'] = total_inter * unique_state_windows

    all_dfs = existing_all_dfs if existing_all_dfs is not None else []
    processed_keys = processed_keys if processed_keys is not None else set()

    keys = differences.keys()

    for key in keys:
        if key in processed_keys:
            continue
        mode_key, state_window = key
        if key not in nbs_results or not nbs_results[key]['significant_components']:
            continue

        result = nbs_results[key]
        group1_means = group1_means_all[key]
        group2_means = group2_means_all[key]

        for comp_idx, component in enumerate(result['significant_components']):
            edges_data = []
            all_connections = result['component_all_connections'][comp_idx]
            
            for connection in all_connections:
                i, j = connection
                edge_idx = edge_indices_reverse.get((i, j))
                if edge_idx is None:
                    edge_idx = edge_indices_reverse.get((j, i))
                if edge_idx is None:
                    continue

                diff = differences[key][edge_idx]
                g1_mean = group1_means[edge_idx]
                g2_mean = group2_means[edge_idx]

                if should_include_edge(diff, g1_mean, g2_mean, change_type):
                    region1, region2 = labels[i], labels[j]
                    net1 = get_network_name(region1)
                    net2 = get_network_name(region2)

                    if region1 > region2:
                        region1, region2 = region2, region1
                        net1, net2 = net2, net1

                    network_pair = f"{net1}-{net2}"

                    is_direct = any(edge == (i,j) or edge == (j,i) 
                                  for edge in result['component_edges'][comp_idx])

                    edge_data = {
                        'Edge': edge_idx,
                        'Region1': region1,
                        'Region2': region2,
                        'Network_Pair': network_pair,
                        'Is_Direct_Connection': is_direct,
                        'Difference': diff,
                        'Group1_Mean': g1_mean,
                        'Group2_Mean': g2_mean,
                        'Component_Size': result['effect_sizes'][comp_idx]['size'],
                        'Component_PValue': result['component_pvals'][comp_idx],
                        'Cohens_D': result['effect_sizes'][comp_idx]['cohens_d'],
                        'T_Statistic': result['effect_sizes'][comp_idx]['t_stat'],
                        'Component_Density': result['effect_sizes'][comp_idx]['density'],
                        'State_Window': state_window,
                        'Mode': mode_key,
                        'Component_ID': f"{mode_key}_{state_window}_comp_{comp_idx}"
                    }
                    edges_data.append(edge_data)

            if edges_data:
                df = pd.DataFrame(edges_data)
                df['Favored_Group'] = df.apply(
                    lambda row: determine_favored_group(
                        row['Difference'],
                        row['Group1_Mean'],
                        row['Group2_Mean'],
                        change_type,
                        comparison
                    ),
                    axis=1
                )
                df['Mechanism'] = None
                if comparison == 'oac_vs_tcoa' and group3_means_all is not None:
                    for idx, row in df.iterrows():
                        edge_idx = row['Edge']
                        yac_mean = group3_means_all[key][edge_idx]
                        oac_mean = group1_means[edge_idx]
                        tcoa_mean = group2_means[edge_idx]

                        oac_distance = abs(oac_mean - yac_mean)
                        tcoa_distance = abs(tcoa_mean - yac_mean)

                        mechanism = None
                        if tcoa_distance < oac_distance:
                            if (tcoa_mean - yac_mean) * (oac_mean - yac_mean) > 0:
                                mechanism = 'Restoration'
                            else:
                                mechanism = 'Alternative'
                        else:
                            if (tcoa_mean - yac_mean) * (oac_mean - yac_mean) > 0:
                                mechanism = 'Deterioration'
                            else:
                                mechanism = 'Maladaptive'

                        if tcoa_mean > yac_mean > oac_mean:
                            mechanism = 'Enhancement' if tcoa_distance < oac_distance else 'Exacerbation'
                        elif tcoa_mean < yac_mean < oac_mean:
                            mechanism = 'Normalization' if tcoa_distance < oac_distance else 'Decompensation'

                        df.at[idx, 'Mechanism'] = mechanism

                    mechanism_types = {
                        'compensatory': ['Restoration', 'Alternative', 'Enhancement', 'Normalization'],
                        'deterioration': ['Deterioration', 'Maladaptive', 'Exacerbation', 'Decompensation']
                    }
                    for mech in mechanism_types['compensatory'] + mechanism_types['deterioration']:
                        df[mech] = df['Mechanism'].apply(lambda x: 1 if x == mech else 0)
                        
                all_dfs.append(df)

                for idx, row in df.iterrows():
                    region1 = row['Region1']
                    region2 = row['Region2']
                    net1 = get_network_name(region1)
                    net2 = get_network_name(region2)
                    net1, net2 = sorted([net1, net2])

                    categories_to_update = ['Overall']
                    if net1 == net2:
                        categories_to_update.extend([
                            'Intra-Network',
                            f'Intra-Network ({net1})'
                        ])
                    else:
                        categories_to_update.extend([
                            'Inter-Network',
                            f'Inter-Network ({net1}-{net2})'
                        ])

                    for category in categories_to_update:
                        data = summary_data[category]
                        data['Connection Count'] += 1
                        data[row['Favored_Group']] += 1
                        data['Mean_Component_Size'].append(row['Component_Size'])
                        data['Mean_Component_Density'].append(row['Component_Density'])
                        data['Mean_Effect_Size'].append(row['Cohens_D'])
                        data['Mean_T_Statistic'].append(row['T_Statistic'])
                        data['Mean_P_Value'].append(row['Component_PValue'])
                        group_name = row['Favored_Group'].replace(' favored', '')
                        data[f'{group_name} weighted value (total)'] += abs(row['Difference'])

                        if comparison == 'oac_vs_tcoa' and row['Mechanism']:
                            if row['Mechanism'] in mechanism_types['compensatory']:
                                data['Compensatory_Connections'] += 1
                            else:
                                data['Deterioration_Connections'] += 1
                            data[f'{row["Mechanism"]}_Count'] += 1

                processed_keys.add(key)
                if summary_filename is not None:
                    try:
                        with open(summary_filename, 'wb') as f:
                            pickle.dump((all_dfs, summary_data, processed_keys), f)
                        print(f"Intermediate summary data saved to {summary_filename}")
                    except Exception as e:
                        print(f"Error saving intermediate summary data: {str(e)}")

                if isinstance(all_dfs, list) and len(all_dfs) > 0:
                    final_df = pd.concat(all_dfs, ignore_index=True)
                        
    for category in categories:
        data = summary_data[category]
        total_conn = data['Total Connection Count']
        conn_count = data['Connection Count']
        data['Connection Percentage'] = (conn_count / total_conn) * 100 if total_conn > 0 else 0

        if isinstance(final_df, pd.DataFrame) and not final_df.empty:
            data['Direct_Connections'] = len(final_df[final_df['Is_Direct_Connection']])
            data['Indirect_Connections'] = len(final_df[~final_df['Is_Direct_Connection']])
        else:
            data['Direct_Connections'] = 0
            data['Indirect_Connections'] = 0

        for key in ['Mean_Component_Size', 'Mean_Component_Density', 'Mean_Effect_Size', 'Mean_T_Statistic', 'Mean_P_Value']:
            values = data[key]
            data[key] = np.mean(values) if values else 0

        for group in [group_names[comparison][0], group_names[comparison][1]]:
            favored_conn = data[f'{group} favored']
            total_weighted_value = data[f'{group} weighted value (total)']
            data[f'{group} weighted value (average)'] = (total_weighted_value / favored_conn) if favored_conn > 0 else 0

    summary_list = []
    for category in categories:
        data = summary_data[category]
        summary_entry = {
            'Category': category,
            'Connection Count': data['Connection Count'],
            'Total Connection Count': data['Total Connection Count'],
            'Connection Percentage': data['Connection Percentage'],
            'Mean_Component_Size': data['Mean_Component_Size'],
            'Mean_Component_Density': data['Mean_Component_Density'],
            'Mean_Effect_Size': data['Mean_Effect_Size'],
            'Mean_T_Statistic': data['Mean_T_Statistic'],
            'Mean_P_Value': data['Mean_P_Value'],
            f'{group_names[comparison][0]} favored': data[f'{group_names[comparison][0]} favored'],
            f'{group_names[comparison][1]} favored': data[f'{group_names[comparison][1]} favored'],
            f'{group_names[comparison][0]} weighted value (average)': data[f'{group_names[comparison][0]} weighted value (average)'],
            f'{group_names[comparison][1]} weighted value (average)': data[f'{group_names[comparison][1]} weighted value (average)'],
            f'{group_names[comparison][0]} weighted value (total)': data[f'{group_names[comparison][0]} weighted value (total)'],
            f'{group_names[comparison][1]} weighted value (total)': data[f'{group_names[comparison][1]} weighted value (total)']
        }

        if comparison == 'oac_vs_tcoa':
            summary_entry.update({
                'Compensatory_Connections': data.get('Compensatory_Connections', 0),
                'Deterioration_Connections': data.get('Deterioration_Connections', 0),
                'Restoration_Count': data.get('Restoration_Count', 0),
                'Alternative_Count': data.get('Alternative_Count', 0),
                'Normalization_Count': data.get('Normalization_Count', 0),
                'Enhancement_Count': data.get('Enhancement_Count', 0),
                'Deterioration_Count': data.get('Deterioration_Count', 0),
                'Maladaptive_Count': data.get('Maladaptive_Count', 0),
                'Exacerbation_Count': data.get('Exacerbation_Count', 0),
                'Decompensation_Count': data.get('Decompensation_Count', 0)
            })
        summary_list.append(summary_entry)

    try:
        if isinstance(all_dfs, list) and len(all_dfs) > 0:
            final_df = pd.concat(all_dfs, ignore_index=True)
        else:
            final_df = pd.DataFrame()

        if len(summary_list) > 0:
            df_summary_sheet = pd.DataFrame(summary_list)
        else:
            df_summary_sheet = pd.DataFrame()

        if summary_filename is not None:
            try:
                with open(summary_filename, 'wb') as f:
                    pickle.dump((all_dfs, summary_data, processed_keys), f)
                print(f"Intermediate summary data saved to {summary_filename}")
            except Exception as e:
                print(f"Error saving intermediate summary data: {str(e)}")

        if final_df.empty:
            print(f"No data to write for {comparison}, {change_type}.")

        return final_df, df_summary_sheet

    except Exception as e:
        print(f"Error in final DataFrame creation: {str(e)}")
        return pd.DataFrame(), pd.DataFrame()

# =============================================================================
# 11. Global Variables
# =============================================================================
# Maps each comparison key to its (group1, group2) label pair. Insertion
# order matters: main() iterates this mapping to decide the order in which
# the comparisons are run.
group_names = dict(
    oac_vs_yac=('OAC', 'YAC'),
    oac_vs_tcoa=('OAC', 'TCOA'),
)

# =============================================================================
# 12. Main Function and Execution Block
# =============================================================================
def _dump_with_backup(payload, filename, backup_filename, noun, saved_message):
    """Pickle *payload* to *filename*, falling back to *backup_filename*.

    Failures are reported on stdout instead of being raised, so a failed
    write never aborts the analysis run.

    Args:
        payload: Object to pickle.
        filename: Primary output path.
        backup_filename: Path tried when writing *filename* fails.
        noun: Lower-case word used in the status messages
            (e.g. ``'results'`` or ``'summary'``).
        saved_message: Message printed after a successful primary write.
    """
    try:
        with open(filename, 'wb') as f:
            pickle.dump(payload, f)
        print(saved_message)
    except Exception as e:
        print(f"Error saving {noun} to {filename}: {str(e)}")
        try:
            with open(backup_filename, 'wb') as f:
                pickle.dump(payload, f)
            print(f"{noun.capitalize()} saved to backup file: {backup_filename}")
        except Exception as e:
            print(f"Error saving backup: {str(e)}")


def _load_summary_checkpoint(df_filename, comparison, change_type):
    """Load a resumable summary checkpoint from *df_filename*.

    The same file name holds two different payloads over the life of a run:
    the summary function writes ``(all_dfs, summary_data, processed_keys)``
    checkpoints with ``all_dfs`` being a list, and ``main`` later overwrites
    the file with ``(final_df, summary_sheet_df, processed_keys)``. Only the
    first form can be resumed from, so anything else is discarded and the
    summary is recomputed from scratch.

    Returns:
        Tuple ``(all_dfs, summary_data, processed_keys)``; empty containers
        when no usable checkpoint exists.
    """
    try:
        # NOTE: pickle is only safe here because these files are written by
        # earlier runs of this script; never point this at untrusted files.
        with open(df_filename, 'rb') as f:
            all_dfs, summary_data, processed_keys = pickle.load(f)
    except Exception as e:
        print(f"Could not load summary from {df_filename}: {str(e)}")
        return [], {}, set()

    if not isinstance(all_dfs, list):
        # File holds the finished DataFrames written by main(), not a
        # checkpoint; feeding those back in as "existing" state would
        # corrupt the summary, so start over.
        print(f"{df_filename} holds a finished summary, not a checkpoint; recomputing.")
        return [], {}, set()

    print(f"Loaded existing summary data for {comparison}, {change_type} from {df_filename}")
    return all_dfs, summary_data, processed_keys


def main():
    """
    Main execution function with NBS implementation.

    Loads per-participant connectivity matrices, organises them by group,
    mode and state/window, runs the NBS group comparison for every entry of
    ``group_names``, builds one summary per change type, writes all
    summaries to ``quantification_summary_dynamic_nbs.xlsx`` and finally
    runs the compensatory-connection analysis when YAC reference means are
    available. Intermediate results are checkpointed to pickle files in the
    current working directory so an interrupted run can be resumed.
    """
    start_time = time.time()
    print("Starting the NBS-based dynamic functional connectivity analysis...")

    # ---------------------------
    # 1. Define Pilot Test or Full Analysis
    # ---------------------------
    pilot_test = False  # True: reduced pilot run; False: full analysis

    YACs_full = ['101', '102', '103', '104', '105', '111', '112', '113', '114', '115', '116', '117', '118', '119', '120']
    OACs_full = ['202', '205', '206', '207', '208', '209', '210', '211', '214', '215', '216', '217', '218', '219', '221']
    TCOAs_full = ['401', '402', '403', '404', '406', '407', '408', '409', '410', '411', '412', '413', '414', '415', '416']

    # NOTE(review): primary_threshold is currently not passed to the
    # analysis (alpha is hard-coded to 0.05 in the call below) -- confirm
    # whether it should be wired in.
    primary_threshold = 0.05

    if pilot_test:
        mode_title = "Pilot Testing"
        state_window_fraction = 0.2
        max_permutations = 2500
        YACs = YACs_full[:5]
        OACs = OACs_full[:5]
        TCOAs = TCOAs_full[:5]
    else:
        mode_title = "Full Analysis"
        state_window_fraction = 1.0
        max_permutations = 5000
        YACs = YACs_full
        OACs = OACs_full
        TCOAs = TCOAs_full

    print(f"\n{mode_title} Mode Activated:")
    print(f"  YACs: {YACs}")
    print(f"  OACs: {OACs}")
    print(f"  TCOAs: {TCOAs}")
    print(f"  State Window Fraction: {state_window_fraction}")
    print(f"  Max Permutations: {max_permutations}")

    participants = YACs + OACs + TCOAs

    participant_groups = {p: 'YAC' for p in YACs}
    participant_groups.update({p: 'OAC' for p in OACs})
    participant_groups.update({p: 'TCOA' for p in TCOAs})

    # ---------------------------
    # 2. Define Labels
    # ---------------------------
    labels = [
        '7Networks_LH_Cont_Cing_1-lh',
        '7Networks_LH_Cont_Par_1-lh',
        '7Networks_LH_Cont_PFCl_1-lh',
        '7Networks_LH_Cont_pCun_1-lh',
        '7Networks_LH_Default_Par_1-lh',
        '7Networks_LH_Default_Par_2-lh',
        '7Networks_LH_Default_pCunPCC_1-lh',
        '7Networks_LH_Default_pCunPCC_2-lh',
        '7Networks_LH_Default_PFC_1-lh',
        '7Networks_LH_Default_PFC_2-lh',
        '7Networks_LH_Default_PFC_3-lh',
        '7Networks_LH_Default_PFC_4-lh',
        '7Networks_LH_Default_PFC_5-lh',
        '7Networks_LH_Default_PFC_6-lh',
        '7Networks_LH_Default_PFC_7-lh',
        '7Networks_LH_Default_Temp_1-lh',
        '7Networks_LH_Default_Temp_2-lh',
        '7Networks_LH_DorsAttn_FEF_1-lh',
        '7Networks_LH_DorsAttn_Post_1-lh',
        '7Networks_LH_DorsAttn_Post_2-lh',
        '7Networks_LH_DorsAttn_Post_3-lh',
        '7Networks_LH_DorsAttn_Post_4-lh',
        '7Networks_LH_DorsAttn_Post_5-lh',
        '7Networks_LH_DorsAttn_Post_6-lh',
        '7Networks_LH_DorsAttn_PrCv_1-lh',
        '7Networks_LH_Limbic_OFC_1-lh',
        '7Networks_LH_Limbic_TempPole_1-lh',
        '7Networks_LH_Limbic_TempPole_2-lh',
        '7Networks_LH_SalVentAttn_FrOperIns_1-lh',
        '7Networks_LH_SalVentAttn_FrOperIns_2-lh',
        '7Networks_LH_SalVentAttn_Med_1-lh',
        '7Networks_LH_SalVentAttn_Med_2-lh',
        '7Networks_LH_SalVentAttn_Med_3-lh',
        '7Networks_LH_SalVentAttn_ParOper_1-lh',
        '7Networks_LH_SalVentAttn_PFCl_1-lh',
        '7Networks_LH_SomMot_1-lh',
        '7Networks_LH_SomMot_2-lh',
        '7Networks_LH_SomMot_3-lh',
        '7Networks_LH_SomMot_4-lh',
        '7Networks_LH_SomMot_5-lh',
        '7Networks_LH_SomMot_6-lh',
        '7Networks_LH_Vis_1-lh',
        '7Networks_LH_Vis_2-lh',
        '7Networks_LH_Vis_3-lh',
        '7Networks_LH_Vis_4-lh',
        '7Networks_LH_Vis_5-lh',
        '7Networks_LH_Vis_6-lh',
        '7Networks_LH_Vis_7-lh',
        '7Networks_LH_Vis_8-lh',
        '7Networks_LH_Vis_9-lh',
        '7Networks_RH_Cont_Cing_1-rh',
        '7Networks_RH_Cont_Par_1-rh',
        '7Networks_RH_Cont_Par_2-rh',
        '7Networks_RH_Cont_PFCl_1-rh',
        '7Networks_RH_Cont_PFCl_2-rh',
        '7Networks_RH_Cont_PFCl_3-rh',
        '7Networks_RH_Cont_PFCl_4-rh',
        '7Networks_RH_Cont_PFCmp_1-rh',
        '7Networks_RH_Cont_pCun_1-rh',
        '7Networks_RH_Default_Par_1-rh',
        '7Networks_RH_Default_pCunPCC_1-rh',
        '7Networks_RH_Default_pCunPCC_2-rh',
        '7Networks_RH_Default_PFCdPFCm_1-rh',
        '7Networks_RH_Default_PFCdPFCm_2-rh',
        '7Networks_RH_Default_PFCdPFCm_3-rh',
        '7Networks_RH_Default_PFCv_1-rh',
        '7Networks_RH_Default_PFCv_2-rh',
        '7Networks_RH_Default_Temp_1-rh',
        '7Networks_RH_Default_Temp_2-rh',
        '7Networks_RH_Default_Temp_3-rh',
        '7Networks_RH_DorsAttn_FEF_1-rh',
        '7Networks_RH_DorsAttn_Post_1-rh',
        '7Networks_RH_DorsAttn_Post_2-rh',
        '7Networks_RH_DorsAttn_Post_3-rh',
        '7Networks_RH_DorsAttn_Post_4-rh',
        '7Networks_RH_DorsAttn_Post_5-rh',
        '7Networks_RH_DorsAttn_PrCv_1-rh',
        '7Networks_RH_Limbic_OFC_1-rh',
        '7Networks_RH_Limbic_TempPole_1-rh',
        '7Networks_RH_SalVentAttn_FrOperIns_1-rh',
        '7Networks_RH_SalVentAttn_Med_1-rh',
        '7Networks_RH_SalVentAttn_Med_2-rh',
        '7Networks_RH_SalVentAttn_TempOccPar_1-rh',
        '7Networks_RH_SalVentAttn_TempOccPar_2-rh',
        '7Networks_RH_SomMot_1-rh',
        '7Networks_RH_SomMot_2-rh',
        '7Networks_RH_SomMot_3-rh',
        '7Networks_RH_SomMot_4-rh',
        '7Networks_RH_SomMot_5-rh',
        '7Networks_RH_SomMot_6-rh',
        '7Networks_RH_SomMot_7-rh',
        '7Networks_RH_SomMot_8-rh',
        '7Networks_RH_Vis_1-rh',
        '7Networks_RH_Vis_2-rh',
        '7Networks_RH_Vis_3-rh',
        '7Networks_RH_Vis_4-rh',
        '7Networks_RH_Vis_5-rh',
        '7Networks_RH_Vis_6-rh',
        '7Networks_RH_Vis_7-rh',
        '7Networks_RH_Vis_8-rh'
    ]
    labels = np.array(labels)

    # ---------------------------
    # 3. Initialize Networks Dictionary
    # ---------------------------
    networks = {
        'Visual': [],
        'Somatomotor': [],
        'DorsalAttention': [],
        'VentralAttention': [],
        'Limbic': [],
        'Frontoparietal': [],
        'Default': []
    }

    # Each network collects (label index, region name) pairs for its regions.
    for i, label in enumerate(labels):
        network_name = get_network_name(label)
        if network_name:
            region_name = get_region_name(label)
            networks[network_name].append((i, region_name))

    # ---------------------------
    # 4. Load and Process Data
    # ---------------------------
    base_dir = '/home/cerna3/neuroconn/data/out/subjects/'
    modes = ['EC', 'EO']

    # Fixed seeds so repeated runs draw the same random numbers.
    random.seed(42)
    np.random.seed(42)

    connectivity_data = load_connectivity_data(base_dir, participants, modes)

    print("\nOrganizing connectivity data for NBS analysis...")
    # organized_data[group][mode][state_window_str][participant] -> edge vector
    organized_data = {group: {mode: {} for mode in modes} for group in set(participant_groups.values())}

    total_participants = len(participants)
    participants_processed = 0

    for participant in participants:
        group = participant_groups.get(participant)
        if participant not in connectivity_data or not group:
            continue
        participant_data = connectivity_data[participant]
        for mode in modes:
            if mode not in participant_data:
                continue
            for state_window, matrix in participant_data[mode].items():
                state_window_str = f"state_{state_window[0]}_window_{state_window[1]}"
                # Keep only the strict upper triangle: one value per edge.
                vector_form = matrix[np.triu_indices_from(matrix, k=1)]
                organized_data[group][mode].setdefault(state_window_str, {})[participant] = vector_form
        participants_processed += 1
        sys.stdout.write(f"\rOrganized data for participant {participant} ({participants_processed}/{total_participants})")
        sys.stdout.flush()

    print("\nFinished organizing connectivity data.\n")

    # Enumerate region pairs (i < j) in row-major order, i.e. the same order
    # np.triu_indices_from(..., k=1) uses for the edge vectors above.
    n_regions = labels.shape[0]
    edge_labels = []
    edge_indices_reverse = {}
    for idx, (i, j) in enumerate(itertools.combinations(range(n_regions), 2)):
        edge_labels.append((labels[i], labels[j]))
        edge_indices_reverse[(i, j)] = idx
        edge_indices_reverse[(j, i)] = idx

    change_types = ['IMPC', 'DMNC', 'SNPC', 'DMPC', 'IMNC', 'SPNC']

    total_steps = len(group_names) * len(change_types)
    completed_steps = 0

    aging_results = {}
    taichi_results = {}
    group1_means_all_global = {}
    group2_means_all_global = {}
    group3_means_all_global = {}

    # ---------------------------
    # 5. Run Analysis
    # ---------------------------
    for comparison_index, (comparison, (group1, group2)) in enumerate(group_names.items(), 1):
        print(f"\nPerforming comparison {comparison_index}/{len(group_names)}: {group1} vs {group2}")
        logging.info(f"\nStarting comparison {comparison_index}/{len(group_names)}: {group1} vs {group2}")

        # Only the OAC-vs-TCOA comparison carries the YAC reference group.
        is_taichi = comparison == 'oac_vs_tcoa'

        group1_data = organized_data[group1]
        group2_data = organized_data[group2]

        # Resume from a previous run's checkpoint when one can be read.
        results_filename = f'results_{comparison}.pkl'
        try:
            # NOTE: pickle is only safe here because the file is written by
            # earlier runs of this script; never load untrusted files.
            with open(results_filename, 'rb') as f:
                loaded = pickle.load(f)
            if is_taichi:
                (differences, nbs_results, group1_means_all, group2_means_all,
                 convergence_stats, total_state_windows, group3_means_all,
                 processed_windows) = loaded
            else:
                (differences, nbs_results, group1_means_all, group2_means_all,
                 convergence_stats, total_state_windows,
                 processed_windows) = loaded
            print(f"Loaded existing results for {comparison} from {results_filename}")
        except Exception as e:
            print(f"Could not load results from {results_filename}: {str(e)}")
            differences = {}
            nbs_results = {}
            group1_means_all = {}
            group2_means_all = {}
            convergence_stats = {}
            total_state_windows = 0
            group3_means_all = {}
            processed_windows = set()

        nbs_kwargs = dict(
            alpha=0.05,
            state_window_fraction=state_window_fraction,
            max_permutations=max_permutations,
            results_filename=results_filename,
            existing_differences=differences,
            existing_nbs_results=nbs_results,
            existing_group1_means_all=group1_means_all,
            existing_group2_means_all=group2_means_all,
            existing_convergence_stats=convergence_stats,
            existing_total_state_windows=total_state_windows,
            processed_windows=processed_windows,
        )
        if is_taichi:
            nbs_kwargs['group3_data'] = organized_data['YAC']
            nbs_kwargs['existing_group3_means_all'] = group3_means_all

        results = calculate_dynamic_group_differences_nbs(
            group1_data,
            group2_data,
            comparison,
            edge_labels,
            **nbs_kwargs
        )

        # Unpack first so a result of unexpected shape fails before anything
        # is written to the checkpoint file.
        if is_taichi:
            (differences, nbs_results, group1_means_all, group2_means_all,
             convergence_stats, total_state_windows, group3_means_all) = results
        else:
            (differences, nbs_results, group1_means_all, group2_means_all,
             convergence_stats, total_state_windows) = results
            group3_means_all = None

        _dump_with_backup(
            tuple(results) + (processed_windows,),
            results_filename,
            f'results_{comparison}_backup.pkl',
            'results',
            f"Saved results for {comparison} to {results_filename}",
        )

        if is_taichi:
            taichi_results = dict(nbs_results)
            group1_means_all_global.update(group1_means_all)
            group2_means_all_global.update(group2_means_all)
            group3_means_all_global.update(group3_means_all)
        elif comparison == 'oac_vs_yac':
            aging_results = dict(nbs_results)
            group1_means_all_global.update(group1_means_all)
            group2_means_all_global.update(group2_means_all)

        for change_type in change_types:
            df_filename = f'df_summary_{comparison}_{change_type}.pkl'
            existing_all_dfs, existing_summary_data, processed_keys = _load_summary_checkpoint(
                df_filename, comparison, change_type
            )

            print("Computing summary...")

            df_summary, df_summary_sheet = quantification_summary_to_excel_nbs(
                differences,
                nbs_results,
                group1_means_all,
                group2_means_all,
                edge_labels,
                networks,
                labels,
                change_type,
                comparison,
                writer=None,
                group3_means_all=group3_means_all,
                existing_all_dfs=existing_all_dfs,
                existing_summary_data=existing_summary_data,
                processed_keys=processed_keys,
                summary_filename=df_filename,
                edge_indices_reverse=edge_indices_reverse
            )

            # Replaces the checkpoint with the finished DataFrames; the
            # Excel export below reads this payload back.
            _dump_with_backup(
                (df_summary, df_summary_sheet, processed_keys),
                df_filename,
                f'df_summary_{comparison}_{change_type}_backup.pkl',
                'summary',
                f"Saved summary DataFrames for {comparison}, {change_type} to {df_filename}",
            )

            completed_steps += 1
            progress = (completed_steps / total_steps) * 100
            sys.stdout.write(f"\rOverall Progress: {completed_steps}/{total_steps} ({progress:.2f}%)")
            sys.stdout.flush()

    print("\nAll comparisons and change types processed.")

    # ---------------------------
    # 6. Export Summaries to Excel
    # ---------------------------
    with pd.ExcelWriter('quantification_summary_dynamic_nbs.xlsx', engine='xlsxwriter') as writer:
        for comparison in group_names:
            for change_type in change_types:
                df_filename = f'df_summary_{comparison}_{change_type}.pkl'
                if not os.path.exists(df_filename):
                    continue
                with open(df_filename, 'rb') as f:
                    df_summary, df_summary_sheet, _ = pickle.load(f)
                sheet_name = f'{comparison}_{change_type.lower()}'
                df_summary.to_excel(writer, sheet_name=sheet_name, index=False)
                sheet_name_summary = f'{comparison}_{change_type.lower()}_summary'
                df_summary_sheet.to_excel(writer, sheet_name=sheet_name_summary, index=False)
                print(f"Sheets '{sheet_name}' and '{sheet_name_summary}' written to Excel.")
    print("\nAnalysis complete. Results saved to 'quantification_summary_dynamic_nbs.xlsx'")

    total_time_taken = (time.time() - start_time) / 60
    print(f"Total processing time: {total_time_taken:.2f} minutes.")

    # ---------------------------
    # 7. Compensatory Analysis
    # ---------------------------
    # YAC reference means are only collected during the OAC-vs-TCOA
    # comparison; without them the analysis is skipped.
    if group3_means_all_global:
        compensatory_results = identify_compensatory_connections(
            aging_results=aging_results,
            taichi_results=taichi_results,
            yac_means=group3_means_all_global,
            comparison='oac_vs_tcoa',
            labels=labels,
            edge_indices_reverse=edge_indices_reverse,
            group1_means_all=group1_means_all_global,
            group2_means_all=group2_means_all_global
        )

        compensatory_filename = 'compensatory_results.pkl'
        with open(compensatory_filename, 'wb') as f:
            pickle.dump(compensatory_results, f)
        print(f"Compensatory results saved to {compensatory_filename}")
    else:
        print("Not enough data to perform compensatory analysis.")

# Run the full analysis only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()