In [2]:
# Import necessary libraries

import os
import json
import sys
import time
import cProfile
import itertools
from itertools import combinations
import numpy as np
import numba as nb
from numba import njit
import mne
import nibabel
from multiprocessing import Pool
from hmmlearn import hmm
import networkx as nx
from sklearn.preprocessing import StandardScaler
import scipy.stats as stats

In [3]:
# Base directory for output files
base_dir = '/home/cerna3/neuroconn/data/out/subjects/'
participants = ['101', '102', '103', '104', '105', '111', '112', '113', '114', '115', '116', '117', '118', '119', '120',
                '202', '205', '206', '207', '208', '209', '210', '211', '214', '215', '216', '217', '218', '219', '221',
                '401', '402', '403', '404', '406', '407', '408', '409', '410', '411', '412', '413', '414', '415', '416']

modes = ['EC', 'EO']

optimal_states = {mode: [] for mode in modes}  # Store optimal states for each mode
# Participant ID that contributed each entry of optimal_states[mode]. Kept in lockstep with
# optimal_states so outlier IDs stay correct when a participant's file is missing or has no
# "Optimal state (Average):" line.
optimal_state_participants = {mode: [] for mode in modes}

# Iterate through participants and modes
for participant in participants:
    for mode in modes:
        file_path = os.path.join(base_dir, participant, mode, f"aic_bic_{participant}_{mode}.txt")

        if not os.path.exists(file_path):
            print(f"File missing for participant {participant}, mode {mode}")
            continue

        # Read file and extract the first "Optimal state (Average):" entry
        with open(file_path, 'r') as f:
            for line in f:
                if line.startswith("Optimal state (Average):"):
                    optimal_state = int(line.split()[-1])
                    optimal_states[mode].append(optimal_state)
                    optimal_state_participants[mode].append(participant)
                    break

# Calculate and print average and median optimal states for each mode
for mode, states in optimal_states.items():
    if not states:
        # Nothing was read for this mode; mean/median are undefined, so skip instead of crashing.
        print(f"No optimal states found for mode {mode}; skipping summary.")
        continue

    mean_optimal_state = np.mean(states)
    std_optimal_state = np.std(states)
    # Outliers: more than 3 standard deviations away from the mean
    outliers = [pid for pid, state in zip(optimal_state_participants[mode], states)
                if abs(state - mean_optimal_state) > 3 * std_optimal_state]
    median_optimal_state = float(np.median(states))
    print(f"Average Optimal State ({mode}): {mean_optimal_state:.2f}")
    print(f"Median Optimal State ({mode}): {median_optimal_state:.2f}")
    print(f"Number of outliers for {mode}: {len(outliers)}")
    if outliers:
        print(f"Outlier IDs for {mode}: {', '.join(outliers)}")

    # Save the median optimal state to a file
    output_file_path = os.path.join(base_dir, f"median_optimal_state_{mode}.txt")
    with open(output_file_path, 'w') as out_file:
        out_file.write(f"Median Optimal State ({mode}): {median_optimal_state:.2f}\n")

Average Optimal State (EC): 7.24
Median Optimal State (EC): 7.00
Number of outliers for EC: 0
Average Optimal State (EO): 6.73
Median Optimal State (EO): 5.00
Number of outliers for EO: 0


In [3]:
# HMM Fitting 

# Define the base directory where participant folders are located
# files_in / files_out are relative paths used to build dir_in / dir_out in the loop below;
# orth.npy is read from the files_out tree.
files_in = '../data/in/subjects/'
files_out = '../data/out/subjects/'
# Absolute root under which the fitting loop writes its .npy/.npz results.
output_dir_base = '/home/cerna3/neuroconn/data/out/subjects'

# Number of (participant, mode) combinations; only referenced by the commented-out
# progress-percentage code further down.
total_participants = len(participants) * len(modes)
# Running count of (participant, mode) combinations handled so far.
current_count = 0
# Wall-clock start of the whole fitting run, used for the elapsed-time reports.
start_time = time.time()

def check_correlation_range(corr_matrix):
    """Return True when no entry of ``corr_matrix`` lies below -1 or above 1, else False."""
    has_too_small = np.any(corr_matrix < -1)
    has_too_large = np.any(corr_matrix > 1)
    return not (has_too_small or has_too_large)

def validate_data(data, context):
    """Return ``data`` unchanged if it is non-empty and free of NaN/inf values.

    Checks run in the order empty -> NaN -> infinite; the first one that fails raises
    ``ValueError`` with a message that embeds ``context``.
    """
    if not data.size:
        problem = "Empty data"
    elif np.isnan(data).any():
        problem = "NaN values"
    elif np.isinf(data).any():
        problem = "Infinite values"
    else:
        return data
    raise ValueError(f"{problem} encountered in {context}.")

# Uncomment if troubleshooting is needed
# Initialize counters for empty data cases
# total_initial_empty_cases = 0
# total_replaced_empty_cases = 0
# total_remaining_empty_cases = 0

# Fallback so the summary line after the loops still works if nothing is iterated.
avg_time_per_participant = 0.0

# For every participant and mode: fit a Gaussian HMM on the mean features, save the decoded
# state sequence/probabilities, then derive temporal features and per-block correlation
# matrices. Any exception for one (participant, mode) is reported and the loop moves on.
for participant in participants:
    participant_start_time = time.time()  # Record start time for the participant
    for mode in modes:
        dir_in = files_in + participant + '/' + mode + '/'
        dir_out = files_out + participant + "/" + mode + '/'

        try:
            # Load orthogonalized data
            orthogonalized_data = np.load(dir_out + "orth.npy")
            orthogonalized_data = validate_data(orthogonalized_data, f"orthogonalized data for participant {participant}, mode {mode}")

            # Average over the last axis to obtain the 2-D feature matrix given to the HMM
            features = np.mean(orthogonalized_data, axis=2)
            features = validate_data(features, f"mean features for participant {participant}, mode {mode}")
            features = np.ma.masked_invalid(features).filled(np.mean(features, axis=0))
            scaler = StandardScaler()
            features = scaler.fit_transform(features)

            # Number of hidden states: the median optimal state for this mode, truncated to int
            median_optimal_state = np.median(optimal_states[mode])
            n_states = int(median_optimal_state)
            model = hmm.GaussianHMM(n_components=n_states, n_iter=50,
                                    covariance_type='full', tol=1e-7, verbose=False,
                                    params='st', init_params='stmc')
            model.fit(features)
            state_sequence = model.predict(features)
            state_probs = model.predict_proba(features)

            output_dir = os.path.join(output_dir_base, participant, mode)
            os.makedirs(output_dir, exist_ok=True)
            np.save(os.path.join(output_dir, f"{participant}_state_sequence.npy"), state_sequence)
            np.save(os.path.join(output_dir, f"{participant}_state_probs.npy"), state_probs)

            # CALCULATE TEMPORAL FEATURES

            # Compute fractional occupancy: fraction of time spent in each state
            fractional_occupancy = np.array([np.sum(state_sequence == i) / len(state_sequence) for i in range(n_states)])

            # Compute transition probabilities (rows with no outgoing transitions divide by
            # zero and become NaN, as before)
            transition_counts = np.zeros((n_states, n_states))
            for (i, j) in zip(state_sequence[:-1], state_sequence[1:]):
                transition_counts[i, j] += 1
            transition_probabilities = transition_counts / np.sum(transition_counts, axis=1, keepdims=True)

            # Compute mean lifetime (dwell time) in each state: average length of the
            # uninterrupted runs spent in that state
            mean_lifetime = np.zeros(n_states)
            for i in range(n_states):
                # Padding the boolean mask with False on both sides makes every run of state i
                # produce exactly one rising edge (index of its first sample) followed by one
                # falling edge (index just past its last sample).
                change_indices = np.where(np.diff(state_sequence == i, prepend=False, append=False))[0]
                run_starts = change_indices[0::2]
                run_ends = change_indices[1::2]
                # FIX: the previous np.diff(change_indices) + 1 mixed run lengths with the
                # gaps between runs and was off by one.
                segment_lengths = run_ends - run_starts
                # Compute mean segment length for state i
                mean_lifetime[i] = np.mean(segment_lengths) if len(segment_lengths) > 0 else 0

            # Mean Interval Length: average time between consecutive occurrences of each state
            # NOTE(review): consecutive samples inside one run contribute intervals of 0 here;
            # confirm whether only gaps between separate visits were intended.
            mean_interval_length = np.zeros(n_states)
            for k in range(n_states):
                # Boolean array where True is the state 'k'
                is_state_k = state_sequence == k
                # Time points where state is 'k'
                time_points_k = np.where(is_state_k)[0]
                # Compute time differences between consecutive occurrences of state 'k'
                intervals_k = np.diff(time_points_k) - 1
                # Compute the mean of these intervals, accounting for the case where state 'k' does not repeat
                mean_interval_length[k] = np.mean(intervals_k) if len(intervals_k) > 0 else 0

            np.savez(os.path.join(output_dir, f"{participant}_temporal_features.npz"),
                     fractional_occupancy=fractional_occupancy,
                     transition_probabilities=transition_probabilities,
                     mean_lifetime=mean_lifetime,
                     mean_interval_length=mean_interval_length)

            # CALCULATE SPATIAL FEATURES (FUNCTIONAL CONNECTIVITY)

            def calculate_functional_connectivity(orthogonalized_data, state_sequence, median_optimal_state):
                """Compute one correlation matrix per contiguous block of each decoded state.

                Returns the dicts of correlation matrices, positive and negative
                upper-triangle correlations (all keyed by ``(state, block_index)``), followed
                by the initial / replaced / remaining empty-case counters.
                Reads ``participant`` and ``mode`` from the enclosing loop for error messages.
                """
                correlation_matrices = {}
                positive_correlations = {}
                negative_correlations = {}

                initial_empty_case_count = 0  # Counter for initial empty data cases
                replaced_empty_case_count = 0  # Counter for replaced empty data cases

                for state in range(int(median_optimal_state)):
                    state_indices = np.where(state_sequence == state)[0]

                    # FIX: a state that is never decoded has no blocks. Previously this fell
                    # through to state_indices[[0]] on an empty array and raised IndexError,
                    # discarding every spatial feature of this participant/mode.
                    if state_indices.size == 0:
                        continue

                    # Split the state's time points into runs of consecutive indices
                    block_starts = np.where(np.diff(state_indices) > 1)[0] + 1
                    block_starts = np.insert(block_starts, 0, 0)
                    block_ends = np.append(block_starts[1:] - 1, len(state_indices) - 1)
                    state_blocks = zip(state_indices[block_starts], state_indices[block_ends] + 1)

                    for i, (start_index, end_index) in enumerate(state_blocks):
                        state_data = orthogonalized_data[:, :, start_index:end_index]

                        if state_data.size == 0:
                            initial_empty_case_count += 1  # Increment initial empty case count
                            # Replace empty data with the mean of the input data
                            state_data = np.mean(orthogonalized_data, axis=2, keepdims=True)
                            replaced_empty_case_count += 1

                        # Raises on NaN/inf, so no NaN imputation is needed afterwards
                        state_data = validate_data(state_data, f"state data for state {state}, block {i}, participant {participant}, mode {mode}")

                        # Select the time and sample dimensions for correlation
                        # NOTE(review): 102 is hard-coded; presumably the number of regions — confirm.
                        reshaped_data = state_data.swapaxes(1, 2).reshape(102, -1)

                        corr_matrix = np.corrcoef(reshaped_data)
                        if not check_correlation_range(corr_matrix):
                            raise ValueError(f"Correlation values out of range in state {state}, block {i} for participant {participant}, mode {mode}. Check data preprocessing.")

                        correlation_matrices[(state, i)] = corr_matrix

                        upper_tri_indices = np.triu_indices_from(corr_matrix, k=1)
                        upper_tri_values = corr_matrix[upper_tri_indices]
                        positive_correlations[(state, i)] = upper_tri_values[upper_tri_values > 0]
                        negative_correlations[(state, i)] = upper_tri_values[upper_tri_values < 0]

                remaining_empty_case_count = sum(1 for key, value in correlation_matrices.items() if value.size == 0)

                return correlation_matrices, positive_correlations, negative_correlations, initial_empty_case_count, replaced_empty_case_count, remaining_empty_case_count

            # CALCULATE SPATIAL FEATURES (FUNCTIONAL CONNECTIVITY)
            correlation_matrices, positive_correlations, negative_correlations, initial_empty_case_count, replaced_empty_case_count, remaining_empty_case_count = calculate_functional_connectivity(
                orthogonalized_data, state_sequence, median_optimal_state)

            # Uncomment if troubleshooting is needed
            # total_initial_empty_cases += initial_empty_case_count
            # total_replaced_empty_cases += replaced_empty_case_count
            # total_remaining_empty_cases += remaining_empty_case_count

            # np.savez appends ".npz", so these end up as "<name>.npy.npz" on disk;
            # load_all_correlation_matrices relies on that exact file name.
            correlation_matrices_file = os.path.join(output_dir, f"{participant}_correlation_matrices.npy")
            positive_correlations_file = os.path.join(output_dir, f"{participant}_positive_correlations.npy")
            negative_correlations_file = os.path.join(output_dir, f"{participant}_negative_correlations.npy")
            # Convert tuple keys to strings
            arrays_to_save = {}
            for key, value in correlation_matrices.items():
                key_str = "_".join(map(str, key))
                arrays_to_save[key_str] = value
            np.savez(correlation_matrices_file, **arrays_to_save)
            # The positive/negative dicts are passed positionally, i.e. stored whole as "arr_0".
            np.savez(positive_correlations_file, positive_correlations)
            np.savez(negative_correlations_file, negative_correlations)

        except Exception as e:
            print(f"Error processing participant {participant}, mode {mode}: {e}")

        current_count += 1
        participant_elapsed_time = (time.time() - participant_start_time) / 60
        total_elapsed_time = (time.time() - start_time) / 60
        avg_time_per_participant = total_elapsed_time / current_count
        #progress_percent = (current_count / total_participants) * 100  # Uncomment if the need to keep track of participant's processing progress arises
        sys.stdout.write(f"\rProcessing {participant} | Mode: {mode} | "
                         f"Participant Progress: {participant_elapsed_time:.2f} min | "
                         #f"Overall Progress: {progress_percent:.2f}% | " # Uncomment if the need to keep track of participant's processing progress arises
                         f"Avg Time/Participant: {avg_time_per_participant:.2f} min")
        sys.stdout.flush()

print("\nAll HMM fittings completed.")
total_time_taken = (time.time() - start_time) / 60
print(f"Total processing time: {total_time_taken:.2f} minutes. Average time per participant: {avg_time_per_participant:.2f} minutes.")

# Uncomment if troubleshooting is needed
# print(f"Total initial empty data cases: {total_initial_empty_cases}")
# print(f"Total replaced empty data cases: {total_replaced_empty_cases}")
# print(f"Total remaining empty data cases: {total_remaining_empty_cases}")

Processing 416 | Mode: EO | Participant Progress: 0.05 min | Avg Time/Participant: 0.03 min
All HMM fittings completed.
Total processing time: 2.36 minutes. Average time per participant: 0.03 minutes.


In [None]:
# Function to load all correlation matrices with checks for out-of-bounds values
def load_all_correlation_matrices(base_dir, participants):
    """Load every 2-D correlation matrix stored for the given participants.

    For each participant and each of the modes "EC" and "EO", reads
    ``<base_dir>/<participant>/<mode>/<participant>_correlation_matrices.npy.npz`` when the
    mode directory exists and appends every 2-D array in it to the result. Arrays with a
    value outside [-1, 1] are counted (details are printed for the first one only) but are
    still included.

    Args:
        base_dir: Root directory containing one sub-directory per participant.
        participants: Iterable of participant IDs (directory names).

    Returns:
        list of numpy arrays, in participant / mode / file-key order. Empty when nothing
        could be loaded.
    """
    all_matrices = []
    total_participants = len(participants)
    out_of_bounds_count = 0  # Counter for out-of-bounds matrices
    printed_out_of_bounds = False  # Flag to check if an out-of-bounds matrix has been printed

    for idx, participant in enumerate(participants):
        participant_dir = os.path.join(base_dir, participant)
        for mode in ["EC", "EO"]:
            mode_dir = os.path.join(participant_dir, mode)
            if not os.path.isdir(mode_dir):
                continue
            filename_pattern = f"{participant}_correlation_matrices.npy.npz"
            filepath = os.path.join(mode_dir, filename_pattern)
            if not os.path.exists(filepath):
                print(f"File not found: {filepath}")
                continue
            # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
            # execute arbitrary code; only load files from a trusted source.
            # The context manager closes the underlying zip file once the arrays are read.
            with np.load(filepath, allow_pickle=True) as data:
                for key in data.files:
                    matrix = data[key]
                    if matrix.ndim != 2:
                        print(f"Non-2D matrix found in file {filepath}, key {key}.")
                        continue
                    # Check for out-of-bounds values and print detailed statistics if found
                    if np.any(matrix < -1) or np.any(matrix > 1):
                        out_of_bounds_count += 1
                        if not printed_out_of_bounds:
                            print(f"Out-of-bounds values found in file {filepath}, key {key}.")
                            out_of_bounds_indices = np.where((matrix < -1) | (matrix > 1))
                            print(f"Out-of-bounds values in matrix: {matrix[out_of_bounds_indices]}")
                            print(f"Matrix {participant}, {mode}, {key}: min {np.min(matrix)}, max {np.max(matrix)}, mean {np.mean(matrix)}, std {np.std(matrix)}")
                            printed_out_of_bounds = True
                    all_matrices.append(matrix)
        print(f"Processed {idx + 1}/{total_participants} participants ({(idx + 1) / total_participants * 100:.2f}%)", end='\r')

    print("\nAll participants processed.")
    print(f"Total matrices loaded: {len(all_matrices)}")
    # Guard the percentage: with nothing loaded the original division raised ZeroDivisionError.
    out_of_bounds_percent = out_of_bounds_count / len(all_matrices) * 100 if all_matrices else 0.0
    print(f"Number of out-of-bounds matrices during loading: {out_of_bounds_count} ({out_of_bounds_percent:.2f}%)")
    return all_matrices

def determine_optimal_alpha(all_matrices):
    """Turn each matrix into a NetworkX graph and run the alpha search on the graphs.

    Matrices whose conversion raises are reported and skipped. Progress is printed after
    every converted matrix. Returns whatever
    ``aggregated_bootstrapping_and_alpha_threshold`` returns for the converted graphs.
    """
    print("Converting correlation matrices to NetworkX graphs...")
    n_total = len(all_matrices)
    graphs = []
    for position, adjacency in enumerate(all_matrices, start=1):
        try:
            graphs.append(nx.from_numpy_array(adjacency))
            print(f"Converting matrix {position}/{n_total} to graph ({position / n_total * 100:.2f}%)", end='\r')
        except Exception as err:
            print(f"Skipping matrix {position}/{n_total} due to error: {err}")
    print(f"\nGraph conversion completed. Total graphs converted: {len(graphs)}")
    return aggregated_bootstrapping_and_alpha_threshold(graphs)

@njit
def test_alpha_numba(threshold_array, alpha):
    """Return the mean of all entries of ``threshold_array`` that are >= ``alpha``.

    Returns 0 when no entry reaches ``alpha``.
    """
    # Keep only the entries that pass the threshold. The previous np.where call computed
    # row/column indices that were never used, adding a second full pass on every call.
    valid_connections = np.extract(threshold_array >= alpha, threshold_array)

    count = valid_connections.size
    return np.sum(valid_connections) / count if count > 0 else 0

def test_alpha_numba_wrapper(threshold_array, alpha):
    """Forward both arguments to ``test_alpha_numba`` and return its result unchanged."""
    return test_alpha_numba(threshold_array, alpha)

def aggregated_bootstrapping_and_alpha_threshold(windowed_graphs, num_iterations=10000, num_alphas=100, max_iterations=1000):
    """Pool edge weights of all graphs, bootstrap a median, and search for an alpha threshold.

    Steps:
      1. Collect the ``'weight'`` attribute of every edge of every graph.
      2. Resample the pooled weights with replacement ``num_iterations`` times, average the
         resamples element-wise and take the median of that average (``bootstrap_median``).
      3. Build one array per graph: ``(adjacency / bootstrap_median) ** 2``.
      4. Run a golden-section style search between the 5th and 95th percentile of the
         pooled weights (both divided by ``bootstrap_median``), scoring each probe with
         ``test_alpha_numba_wrapper`` averaged over all arrays.

    Args:
        windowed_graphs: Non-empty list of NetworkX graphs with weighted edges.
        num_iterations: Number of bootstrap resamples.
        num_alphas: Unused by the current implementation; kept for compatibility.
        max_iterations: Upper bound on search iterations. New, backward-compatible
            safeguard: the original loop had no bound and could run forever.

    Returns:
        ``(optimal_alpha, bootstrap_median)`` where ``optimal_alpha`` is the median of the
        two initial probes and the midpoint recorded at every iteration.

    Raises:
        ValueError: If there are no graphs, no edge weights, or the bootstrap median is 0.
    """
    if not windowed_graphs:
        raise ValueError("No valid graphs found for processing.")

    print("Starting aggregation of edge weights from all windowed graphs...")
    all_edge_weights = []
    total_graphs = len(windowed_graphs)
    edge_count = 0

    for i, G in enumerate(windowed_graphs):
        graph_edge_weights = [data['weight'] for _, _, data in G.edges(data=True)]
        all_edge_weights.extend(graph_edge_weights)
        edge_count += len(graph_edge_weights)
        print(f"Aggregating edge weights progress: {i + 1}/{total_graphs} ({(i + 1) / total_graphs * 100:.2f}%)", end='\r')
        sys.stdout.flush()

    all_edge_weights = np.array(all_edge_weights)
    print(f"\nTotal number of edge weights aggregated: {edge_count}")
    print("Aggregation of edge weights completed.")

    # Checked before bootstrapping: resampling an empty array would fail first with an
    # unrelated error, so the intended message was never shown.
    if len(all_edge_weights) == 0:
        raise ValueError("No edge weights found for bootstrapping.")

    print("Starting bootstrapping on aggregated edge weights...")
    bootstrap_weights = np.zeros_like(all_edge_weights)
    for i in range(num_iterations):
        random_indices = np.random.randint(0, len(all_edge_weights), size=len(all_edge_weights))
        bootstrap_sample = all_edge_weights[random_indices]
        bootstrap_weights += bootstrap_sample
        if (i + 1) % 100 == 0:
            print(f"Bootstrapping progress: {i + 1}/{num_iterations} ({(i + 1) / num_iterations * 100:.2f}%)", end='\r')
    bootstrap_median = np.median(bootstrap_weights / num_iterations)
    print("\nBootstrapping completed.")

    # Every later quantity is divided by the median; a zero median would only give inf/NaN.
    if bootstrap_median == 0:
        raise ValueError("Bootstrap median is zero; edge weights cannot be normalised.")

    alpha_start = np.percentile(all_edge_weights, 5) / bootstrap_median
    alpha_end = np.percentile(all_edge_weights, 95) / bootstrap_median

    # Pre-calculate threshold arrays here (after bootstrap_median is calculated)
    print("Pre-calculating threshold arrays...")
    threshold_arrays = []
    for G in windowed_graphs:
        connectivity_array = np.asarray(nx.to_numpy_array(G))
        threshold_arrays.append((connectivity_array / bootstrap_median) ** 2)

    print("Starting golden-section search for optimal alpha...")
    gr = (np.sqrt(5) + 1) / 2
    c = alpha_end - (alpha_end - alpha_start) / gr
    d = alpha_start + (alpha_end - alpha_start) / gr

    iteration_count = 0
    alpha_values = [c, d]
    fc_values = []
    fd_values = []
    while True:
        iteration_count += 1
        # Calculation using pre-calculated arrays
        fc = np.mean([test_alpha_numba_wrapper(arr, c) for arr in threshold_arrays])
        fd = np.mean([test_alpha_numba_wrapper(arr, d) for arr in threshold_arrays])
        alpha_values.append((c + d) / 2)
        fc_values.append(fc)
        fd_values.append(fd)

        if iteration_count > 1:
            relative_threshold = 0.01 * np.std(alpha_values)
            if abs(c - d) < relative_threshold and len(fc_values) > 5:
                _, p_value = stats.ttest_rel(fc_values[-5:], fd_values[-5:])
                # NaN p-value: the last five fc/fd pairs are identical, so the two probes
                # are indistinguishable. The original `p_value > 0.05` was False for NaN,
                # which made the loop impossible to leave once the probes coincided.
                if np.isnan(p_value) or p_value > 0.05:
                    break

        if iteration_count >= max_iterations:
            print(f"\nStopping alpha search after {iteration_count} iterations without meeting the convergence test.")
            break

        print(f"Testing alphas: {c:.4f}, {d:.4f} - Iteration {iteration_count}", end='\r')
        if fc < fd:
            alpha_end = d
            d = c
            c = alpha_end - (alpha_end - alpha_start) / gr
        else:
            alpha_start = c
            c = d
            d = alpha_start + (alpha_end - alpha_start) / gr

    optimal_alpha = np.median(alpha_values)
    print(f"\nOptimal alpha determined: {optimal_alpha:.4f}")
    print("\nTesting completed")
    return optimal_alpha, bootstrap_median

def save_alpha_and_median(base_dir, optimal_alpha, bootstrap_median):
    """Write both values to ``<base_dir>/alpha_and_median.json``, creating ``base_dir`` if needed."""
    target = os.path.join(base_dir, "alpha_and_median.json")
    payload = {'optimal_alpha': optimal_alpha, 'bootstrap_median': bootstrap_median}
    os.makedirs(base_dir, exist_ok=True)
    with open(target, 'w') as handle:
        json.dump(payload, handle)
    print(f"Optimal alpha and bootstrap median saved to: {target}")

# Start timing
start_time = time.time()

# Usage
# NOTE: base_dir is rebound here without the trailing slash used in the first cell.
base_dir = "/home/cerna3/neuroconn/data/out/subjects"  # Path to your data
# NOTE(review): max_matrices_to_load is assigned but not passed to any call below,
# so it does not limit how many matrices are loaded.
max_matrices_to_load = 50  # Adjust this value as needed
# Load every stored correlation matrix, then derive alpha and the bootstrap median from them.
all_matrices = load_all_correlation_matrices(base_dir, participants)
optimal_alpha, bootstrap_median = determine_optimal_alpha(all_matrices)

# End timing
end_time = time.time()

# Calculate elapsed time in minutes
elapsed_time_minutes = (end_time - start_time) / 60
    
print("Optimal Alpha:", optimal_alpha)
print("Bootstrap Median:", bootstrap_median)
print(f"Total time taken: {elapsed_time_minutes:.2f} minutes")

# Save the optimal alpha and bootstrap median
# Written to <base_dir>/alpha_and_median.json, which the thresholding cell reads back.
save_alpha_and_median(base_dir, optimal_alpha, bootstrap_median)

Processed 45/45 participants (100.00%)
All participants processed.
Total matrices loaded: 11513
Number of out-of-bounds matrices during loading: 0.00%
Converting correlation matrices to NetworkX graphs...
Converting matrix 11513/11513 to graph (100.00%)
Graph conversion completed. Total graphs converted: 11513
Starting aggregation of edge weights from all windowed graphs...
Aggregating edge weights progress: 11513/11513 (100.00%)
Total number of edge weights aggregated: 60477789

IOStream.flush timed out



Aggregation of edge weights completed.
Starting bootstrapping on aggregated edge weights...


In [7]:
# Thresholding phase: Apply the standardized optimal alpha value calculated from the aggregated and bootstrapped edges (of all participants) to threshold the corr_matrices for each participant

# Load optimal alpha and bootstrap median from JSON file
def load_alpha_and_median(json_path):
    """Read a JSON file and return its ``optimal_alpha`` and ``bootstrap_median`` entries as a tuple."""
    with open(json_path, "r") as handle:
        stored = json.load(handle)
    alpha = stored["optimal_alpha"]
    median = stored["bootstrap_median"]
    return alpha, median

# Updated file path for the .json
# This is the file written by save_alpha_and_median; loading it here rebinds
# optimal_alpha and bootstrap_median at module level for the thresholding code below.
json_path = "/home/cerna3/neuroconn/data/out/subjects/alpha_and_median.json"
optimal_alpha, bootstrap_median = load_alpha_and_median(json_path)

def process_participant(participant_dir, optimal_alpha, bootstrap_median):
    """Processes each participant's data, applies thresholding, and saves results.

    For the modes "EC" and "EO", loads every file in the mode directory whose name starts
    with ``<participant_id>_correlation_matrices``, thresholds the matrices and saves the
    thresholded data next to the input.

    Args:
        participant_dir: Directory of one participant; its last path component is the ID.
        optimal_alpha: Alpha threshold passed to ``threshold_functional_connectivity``.
        bootstrap_median: Normalisation constant passed along with it.

    Returns:
        ``(results, success, thresholded_correlation_matrices)``:
        - ``results[mode]`` holds ``'correlation_matrices'`` and ``'thresholded_matrices'``
          dicts keyed by integer tuples parsed from the file keys (empty dict if nothing
          was processed for that mode).
        - ``success`` is False if a mode directory is missing or thresholding/saving failed.
        - ``thresholded_correlation_matrices`` merges the thresholded matrices of both modes
          by key; since both modes use the same key scheme, EO entries overwrite EC entries
          with an identical key.
    """
    # normpath drops a trailing separator so the ID is never the empty string
    participant_id = os.path.basename(os.path.normpath(participant_dir))
    results = {}
    success = True
    thresholded_correlation_matrices = {}  # Initialize the dictionary to store thresholded matrices

    for mode in ["EC", "EO"]:
        mode_dir = os.path.join(participant_dir, mode)
        correlation_matrices = {}
        results[mode] = {}

        # A missing mode directory previously raised from os.listdir and aborted the whole
        # run; report it and mark the participant as not fully processed instead.
        if not os.path.isdir(mode_dir):
            print(f"\nMode directory not found for participant {participant_id}: {mode_dir}")
            success = False
            continue

        for filename in os.listdir(mode_dir):
            if not filename.startswith(f"{participant_id}_correlation_matrices"):
                continue

            filepath = os.path.join(mode_dir, filename)
            # Context manager closes the archive once all arrays are read
            with np.load(filepath) as data:
                for key in data.files:
                    try:
                        tuple_key = tuple(map(int, key.split('_')))
                        matrix = data[key]  # Directly use the matrix without Fisher's r-to-Z transformation
                        correlation_matrices[tuple_key] = matrix
                    except ValueError:
                        continue  # Skip keys that don't convert to integers

            try:
                thresholded_matrices, pos_corrs, neg_corrs = threshold_functional_connectivity(
                    correlation_matrices, optimal_alpha, bootstrap_median
                )
                results[mode]['correlation_matrices'] = correlation_matrices
                results[mode]['thresholded_matrices'] = thresholded_matrices
                thresholded_correlation_matrices.update(thresholded_matrices)  # Store thresholded matrices
                save_thresholded_data(mode_dir, participant_id, thresholded_matrices, pos_corrs, neg_corrs)
            except Exception as e:
                # Report the failure instead of swallowing it silently
                print(f"\nThresholding failed for participant {participant_id}, mode {mode}: {e}")
                success = False
                break

    return results, success, thresholded_correlation_matrices  # Return the thresholded matrices along with other results

def save_thresholded_data(mode_dir, participant_id, thresholded_matrices, pos_corrs, neg_corrs):
    """Write thresholded matrices and their positive/negative correlations to three '_raw' .npz files.

    Every tuple key of ``thresholded_matrices`` is turned into a string by joining its parts
    with "_"; the same key is looked up in ``pos_corrs`` and ``neg_corrs``.
    """
    matrices_by_name = {}
    positives_by_name = {}
    negatives_by_name = {}

    for key in thresholded_matrices:
        name = "_".join(str(part) for part in key)
        matrices_by_name[name] = thresholded_matrices[key]
        positives_by_name[name] = pos_corrs[key]
        negatives_by_name[name] = neg_corrs[key]

    # One archive per kind of data, each holding one array per key
    matrices_path = os.path.join(mode_dir, f"{participant_id}_thresholded_matrices_raw.npz")
    np.savez(matrices_path, **matrices_by_name)
    positives_path = os.path.join(mode_dir, f"{participant_id}_thresholded_pos_corrs_raw.npz")
    np.savez(positives_path, **positives_by_name)
    negatives_path = os.path.join(mode_dir, f"{participant_id}_thresholded_neg_corrs_raw.npz")
    np.savez(negatives_path, **negatives_by_name)

@njit
def apply_threshold(corr_matrix, optimal_alpha_squared, bootstrap_median):
    """Return a copy of ``corr_matrix`` that keeps only entries passing the threshold.

    An entry is kept when ``(entry / bootstrap_median) ** 2 >= optimal_alpha_squared``;
    every other position stays 0. Rows and columns are both indexed up to ``shape[0]``.
    """
    n = corr_matrix.shape[0]
    kept = np.zeros_like(corr_matrix)

    for row in range(n):
        for col in range(n):
            value = corr_matrix[row, col]
            scaled = value / bootstrap_median
            if scaled ** 2 >= optimal_alpha_squared:
                kept[row, col] = value

    return kept

def threshold_functional_connectivity(correlation_matrices, optimal_alpha, bootstrap_median):
    """Threshold every matrix and split the surviving values by sign.

    Returns three dicts sharing the keys of ``correlation_matrices``: the thresholded
    matrices, their strictly positive entries, and their strictly negative entries.
    """
    alpha_sq = optimal_alpha ** 2  # apply_threshold compares against the squared alpha
    thresholded, positives, negatives = {}, {}, {}

    for key, matrix in correlation_matrices.items():
        kept = apply_threshold(matrix, alpha_sq, bootstrap_median)
        thresholded[key] = kept
        positives[key] = kept[kept > 0]
        negatives[key] = kept[kept < 0]

    return thresholded, positives, negatives

# Start timing
start_time = time.time()

# Collecting thresholded matrices for all participants
# Only participants whose processing reported success get an entry here.
all_thresholded_matrices = {}

# Process participants and store results
# participants_data maps participant ID -> per-mode results dict returned by process_participant.
participants_data = {}
total_participants = len(participants)
# NOTE: the two lists below are not filled in this loop; thresholded_participants is
# re-initialised and populated in the edge-count summary further down.
processed_participants = [] 
thresholded_participants = []

# Participants are visited in numeric ID order.
for idx, participant in enumerate(sorted(participants, key=int)):
    participant_dir = os.path.join(base_dir, participant)
    results, success, thresholded_matrices = process_participant(participant_dir, optimal_alpha, bootstrap_median)
    participants_data[participant] = results

    if success:
        # Store thresholded matrices in a dictionary, if needed
        all_thresholded_matrices[participant] = thresholded_matrices

    # "\r" rewrites the same console line to show a running counter.
    sys.stdout.write(f"\rParticipants processed/thresholded: {idx + 1}/{len(participants)}")
    sys.stdout.flush()

print("\nAll participants processed.")

# EDGE COUNT SUMMARY

@njit
def count_edges(matrix):
    """Return how many entries strictly above the diagonal of ``matrix`` are non-zero."""
    n = matrix.shape[0]
    total = 0
    for row in range(n):
        # Columns to the right of the diagonal only, so each undirected pair is seen once
        for col in range(row + 1, n):
            if matrix[row, col] == 0:
                continue
            total += 1
    return total

# Initialize counters
total_unthresholded_edges = 0
total_thresholded_edges = 0
total_edge_counts = 0
thresholded_participants = []
no_modes_processed = []
one_mode_processed = []

# Classify each participant by how many modes have both raw and thresholded matrices,
# and accumulate edge counts for participants where more than one mode was processed.
for idx, participant in enumerate(sorted(participants_data.keys(), key=int)):
    data = participants_data[participant]
    processed_modes = [mode for mode, results in data.items() if 'correlation_matrices' in results and 'thresholded_matrices' in results]

    if not processed_modes:
        no_modes_processed.append(participant)
    elif len(processed_modes) == 1:
        # Participants with a single processed mode are listed but contribute no edge counts.
        one_mode_processed.append((participant, processed_modes[0]))
    else:
        thresholded_participants.append(participant)
        # Loop through each processed mode
        for mode in processed_modes:
            results = data[mode]
            unthresholded_graph_edges = 0
            thresholded_graph_edges = 0
            # Sum upper-triangle non-zero counts over every matrix of this mode.
            for key, corr_matrix in results['correlation_matrices'].items():
                unthresholded_graph_edges += count_edges(corr_matrix)
                thresholded_graph_edges += count_edges(results['thresholded_matrices'][key])

            # Update total counts for the participant in each mode
            total_unthresholded_edges += unthresholded_graph_edges
            total_thresholded_edges += thresholded_graph_edges
            total_edge_counts += 1  # This counts the number of mode entries processed, not participants

            sys.stdout.write(f"\rProcessing participant {participant}, Mode {mode}: {idx + 1}/{len(participants_data)}")
            sys.stdout.flush()

print("\nAll participants processed.")
if no_modes_processed:
    print("Participants with no modes processed:", ", ".join(no_modes_processed))
if one_mode_processed:
    print("Participants with only one mode processed:")
    for participant, mode in one_mode_processed:
        print(f"  - {participant} (Mode: {mode})")
if thresholded_participants:
    print("Participants with all modes processed and thresholded:", ", ".join(thresholded_participants))

# Calculate and print averages if applicable
# Averages are per (participant, mode) entry: each entry's total sums all of its matrices.
if total_edge_counts > 0:
    average_unthresholded = total_unthresholded_edges / total_edge_counts
    average_thresholded = total_thresholded_edges / total_edge_counts
    average_pruned = average_unthresholded - average_thresholded

    print(f"Average unthresholded edge count: {average_unthresholded:.2f}")
    print(f"Average thresholded edge count: {average_thresholded:.2f}")
    print(f"Average number of pruned edges: {average_pruned:.2f}")
else:
    print("No edge count data available.")
    
# End timing
end_time = time.time()

# Calculate elapsed time in minutes
# start_time is the value set at the top of the thresholding phase in this cell.
elapsed_time_minutes = (end_time - start_time) / 60  # Convert seconds to minutes
print(f"Total time taken: {elapsed_time_minutes:.2f} minutes")

Participants processed/thresholded: 45/45
All participants processed.
Processing participant 416, Mode EO: 45/45
All participants processed.
Participants with all modes processed and thresholded: 101, 102, 103, 104, 105, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 202, 205, 206, 207, 208, 209, 210, 211, 214, 215, 216, 217, 218, 219, 221, 401, 402, 403, 404, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416
Average unthresholded edge count: 658927.37
Average thresholded edge count: 646065.43
Average number of pruned edges: 12861.93
Total time taken: 2.14 minutes


In [8]:
# Functional Connectivity Calculation (within-network, between-network, and aggregated within/between)

def load_thresholded_matrices(participants, root_dir=None):
    """Load the saved thresholded matrices for all participants.

    For every participant and each mode ("EC", "EO") the archive
    ``<root>/<participant>/<mode>/<participant>_thresholded_matrices_raw.npz``
    is read if it exists; missing archives are skipped silently.

    Only archive entries whose name is a run of underscore-separated integers
    (for example ``"2_5"``) are kept; any other entry is ignored.

    Args:
        participants: Iterable of participant ID strings.
        root_dir: Directory containing one sub-directory per participant.
            Defaults to the notebook-level ``base_dir`` when omitted.

    Returns:
        dict: ``(participant, mode, tuple_of_ints) -> numpy array``.
    """
    root = base_dir if root_dir is None else root_dir
    thresholded_correlation_matrices = {}
    for participant in sorted(participants):
        for mode in ["EC", "EO"]:
            filepath = os.path.join(root, participant, mode,
                                    f"{participant}_thresholded_matrices_raw.npz")
            if not os.path.exists(filepath):
                continue
            # np.load returns an NpzFile that keeps the archive open; the
            # context manager closes it so file handles do not accumulate.
            with np.load(filepath) as data:
                for key in data.files:
                    try:
                        tuple_key = tuple(map(int, key.split('_')))
                        thresholded_correlation_matrices[(participant, mode, tuple_key)] = data[key]
                    except ValueError:
                        # Entry name is not of the integer form; not a matrix key.
                        continue
    return thresholded_correlation_matrices

# Load optimal alpha and bootstrap median from JSON file
def load_alpha_and_median(json_path):
    """Read ``optimal_alpha`` and ``bootstrap_median`` from a JSON file.

    Args:
        json_path: Path of a JSON file whose top-level object holds both keys.

    Returns:
        tuple: ``(optimal_alpha, bootstrap_median)`` as stored in the file.

    Raises:
        KeyError: If either key is absent from the JSON object.
    """
    with open(json_path, "r") as handle:
        payload = json.loads(handle.read())
    alpha = payload["optimal_alpha"]
    median = payload["bootstrap_median"]
    return alpha, median

# Updated file path for the .json
# NOTE(review): this absolute path repeats the directory already held in
# base_dir; building it from base_dir would keep the two in sync.
json_path = "/home/cerna3/neuroconn/data/out/subjects/alpha_and_median.json"
optimal_alpha, bootstrap_median = load_alpha_and_median(json_path)

# Start timer for the whole cell (read again at the very end of the cell)
start_time_1 = time.time()

# After processing all participants:
# reload every thresholded matrix from disk, keyed by (participant, mode, tuple key).
# participants_data comes from an earlier cell; only its keys are used here.
thresholded_correlation_matrices = load_thresholded_matrices(participants_data.keys())

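# Atlas parcel labels, hard-coded here rather than loaded from a file
# (100 entries: 50 left-hemisphere followed by 50 right-hemisphere, 7-network naming)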
labels = [
    '7Networks_LH_Cont_Cing_1-lh',
    '7Networks_LH_Cont_Par_1-lh',
    '7Networks_LH_Cont_PFCl_1-lh',
    '7Networks_LH_Cont_pCun_1-lh',
    '7Networks_LH_Default_Par_1-lh',
    '7Networks_LH_Default_Par_2-lh',
    '7Networks_LH_Default_pCunPCC_1-lh',
    '7Networks_LH_Default_pCunPCC_2-lh',
    '7Networks_LH_Default_PFC_1-lh',
    '7Networks_LH_Default_PFC_2-lh',
    '7Networks_LH_Default_PFC_3-lh',
    '7Networks_LH_Default_PFC_4-lh',
    '7Networks_LH_Default_PFC_5-lh',
    '7Networks_LH_Default_PFC_6-lh',
    '7Networks_LH_Default_PFC_7-lh',
    '7Networks_LH_Default_Temp_1-lh',
    '7Networks_LH_Default_Temp_2-lh',
    '7Networks_LH_DorsAttn_FEF_1-lh',
    '7Networks_LH_DorsAttn_Post_1-lh',
    '7Networks_LH_DorsAttn_Post_2-lh',
    '7Networks_LH_DorsAttn_Post_3-lh',
    '7Networks_LH_DorsAttn_Post_4-lh',
    '7Networks_LH_DorsAttn_Post_5-lh',
    '7Networks_LH_DorsAttn_Post_6-lh',
    '7Networks_LH_DorsAttn_PrCv_1-lh',
    '7Networks_LH_Limbic_OFC_1-lh',
    '7Networks_LH_Limbic_TempPole_1-lh',
    '7Networks_LH_Limbic_TempPole_2-lh',
    '7Networks_LH_SalVentAttn_FrOperIns_1-lh',
    '7Networks_LH_SalVentAttn_FrOperIns_2-lh',
    '7Networks_LH_SalVentAttn_Med_1-lh',
    '7Networks_LH_SalVentAttn_Med_2-lh',
    '7Networks_LH_SalVentAttn_Med_3-lh',
    '7Networks_LH_SalVentAttn_ParOper_1-lh',
    '7Networks_LH_SalVentAttn_PFCl_1-lh',
    '7Networks_LH_SomMot_1-lh',
    '7Networks_LH_SomMot_2-lh',
    '7Networks_LH_SomMot_3-lh',
    '7Networks_LH_SomMot_4-lh',
    '7Networks_LH_SomMot_5-lh',
    '7Networks_LH_SomMot_6-lh',
    '7Networks_LH_Vis_1-lh',
    '7Networks_LH_Vis_2-lh',
    '7Networks_LH_Vis_3-lh',
    '7Networks_LH_Vis_4-lh',
    '7Networks_LH_Vis_5-lh',
    '7Networks_LH_Vis_6-lh',
    '7Networks_LH_Vis_7-lh',
    '7Networks_LH_Vis_8-lh',
    '7Networks_LH_Vis_9-lh',
    '7Networks_RH_Cont_Cing_1-rh',
    '7Networks_RH_Cont_Par_1-rh',
    '7Networks_RH_Cont_Par_2-rh',
    '7Networks_RH_Cont_PFCl_1-rh',
    '7Networks_RH_Cont_PFCl_2-rh',
    '7Networks_RH_Cont_PFCl_3-rh',
    '7Networks_RH_Cont_PFCl_4-rh',
    '7Networks_RH_Cont_PFCmp_1-rh',
    '7Networks_RH_Cont_pCun_1-rh',
    '7Networks_RH_Default_Par_1-rh',
    '7Networks_RH_Default_pCunPCC_1-rh',
    '7Networks_RH_Default_pCunPCC_2-rh',
    '7Networks_RH_Default_PFCdPFCm_1-rh',
    '7Networks_RH_Default_PFCdPFCm_2-rh',
    '7Networks_RH_Default_PFCdPFCm_3-rh',
    '7Networks_RH_Default_PFCv_1-rh',
    '7Networks_RH_Default_PFCv_2-rh',
    '7Networks_RH_Default_Temp_1-rh',
    '7Networks_RH_Default_Temp_2-rh',
    '7Networks_RH_Default_Temp_3-rh',
    '7Networks_RH_DorsAttn_FEF_1-rh',
    '7Networks_RH_DorsAttn_Post_1-rh',
    '7Networks_RH_DorsAttn_Post_2-rh',
    '7Networks_RH_DorsAttn_Post_3-rh',
    '7Networks_RH_DorsAttn_Post_4-rh',
    '7Networks_RH_DorsAttn_Post_5-rh',
    '7Networks_RH_DorsAttn_PrCv_1-rh',
    '7Networks_RH_Limbic_OFC_1-rh',
    '7Networks_RH_Limbic_TempPole_1-rh',
    '7Networks_RH_SalVentAttn_FrOperIns_1-rh',
    '7Networks_RH_SalVentAttn_Med_1-rh',
    '7Networks_RH_SalVentAttn_Med_2-rh',
    '7Networks_RH_SalVentAttn_TempOccPar_1-rh',
    '7Networks_RH_SalVentAttn_TempOccPar_2-rh',
    '7Networks_RH_SomMot_1-rh',
    '7Networks_RH_SomMot_2-rh',
    '7Networks_RH_SomMot_3-rh',
    '7Networks_RH_SomMot_4-rh',
    '7Networks_RH_SomMot_5-rh',
    '7Networks_RH_SomMot_6-rh',
    '7Networks_RH_SomMot_7-rh',
    '7Networks_RH_SomMot_8-rh',
    '7Networks_RH_Vis_1-rh',
    '7Networks_RH_Vis_2-rh',
    '7Networks_RH_Vis_3-rh',
    '7Networks_RH_Vis_4-rh',
    '7Networks_RH_Vis_5-rh',
    '7Networks_RH_Vis_6-rh',
    '7Networks_RH_Vis_7-rh',
    '7Networks_RH_Vis_8-rh'
]
labels = np.array(labels)

# Initialize a dictionary to hold the network assignments.
# Keys are the seven network names; each value starts empty and is filled
# further down in this cell with (label_index, region_name) tuples.
networks = {
    'Visual': [],
    'Somatomotor': [],
    'DorsalAttention': [],
    'VentralAttention': [],
    'Limbic': [],
    'Frontoparietal': [],
    'Default': []
}

def get_network_name(label_name):
    """Map an atlas label to the name of the network it belongs to.

    The label is searched for known abbreviations in a fixed priority order;
    the first abbreviation found as a substring decides the result.

    Args:
        label_name: Atlas label, e.g. ``'7Networks_LH_Vis_1-lh'``.

    Returns:
        The network name, or ``None`` when no known abbreviation is present.
    """
    # Order matters: it reproduces the precedence of the substring checks.
    abbreviation_to_network = (
        ('Vis', 'Visual'),
        ('SomMot', 'Somatomotor'),
        ('DorsAttn', 'DorsalAttention'),
        ('SalVentAttn', 'VentralAttention'),
        ('VentAttn', 'VentralAttention'),
        ('Limbic', 'Limbic'),
        ('Cont', 'Frontoparietal'),
        ('Frontoparietal', 'Frontoparietal'),
        ('Default', 'Default'),
    )
    for abbreviation, network in abbreviation_to_network:
        if abbreviation in label_name:
            return network
    return None

def get_region_name(label_name):
    """Build a short ``<region>-<hemisphere>`` name from an atlas label.

    The text before the first ``-`` is split on ``_``. For Visual and
    Somatomotor labels the first two fields are dropped; for every other label
    the first three are dropped. The text after the first ``-`` is appended as
    the hemisphere suffix.

    Args:
        label_name: Atlas label, e.g. ``'7Networks_LH_Cont_Cing_1-lh'``.

    Returns:
        str: e.g. ``'Cing_1-lh'`` or ``'Vis_1-lh'``.
    """
    pieces = label_name.split('-')
    fields = pieces[0].split('_')
    leading_fields = 2 if get_network_name(label_name) in ('Visual', 'Somatomotor') else 3
    return f"{'_'.join(fields[leading_fields:])}-{pieces[1]}"

# Assign every atlas label to its network. The label's position in `labels`
# is stored with a short region name; that position is later used to index
# rows/columns of the correlation matrices.
for i, label in enumerate(labels):
    network_name = get_network_name(label)
    # Labels that match no known network (None) are left unassigned.
    if network_name:
        region_name = get_region_name(label)
        networks[network_name].append((i, region_name))

# Define calculate_network_connectivity function 
def calculate_network_connectivity(thresholded_correlation_matrices, networks):
    """
    List pairwise connectivity within and between networks for every matrix.

    Each result entry is a formatted string, not a number:
    ``"[<network>]: <region1> - <value:.2f> - <region2>"`` for within-network
    pairs (upper triangle only) and
    ``"[<net1>, <net2>]: <region1> - <value:.2f> - <region2>"`` for
    between-network pairs.

    Progress is reported once per participant. The participant total is taken
    from the keys of ``thresholded_correlation_matrices`` rather than from a
    notebook global, so the function can be used on any input.

    Args:
        thresholded_correlation_matrices (dict): ``(participant, mode, window)``
            -> square correlation matrix.
        networks (dict): Network name -> list of ``(region_index, region_name)``.

    Returns:
        tuple: ``(within_network_connectivity, between_network_connectivity)``,
            two dicts keyed like the input whose values are lists of the
            formatted strings described above.
    """
    print("Starting within/between network connectivity calculations...")
    within_network_connectivity = {}
    between_network_connectivity = {}
    networks_indices = {name: [region[0] for region in regions] for name, regions in networks.items()}
    network_pairs = list(itertools.combinations(networks.keys(), 2))

    # Which modes exist for each participant, derived from the input itself.
    modes_by_participant = {}
    for participant, mode, _ in thresholded_correlation_matrices:
        modes_by_participant.setdefault(participant, set()).add(mode)
    participant_count = len(modes_by_participant)
    participants_reported = set()
    current_count = 0

    # Keys are unique, so sorting them gives the same order as sorting items.
    for key in sorted(thresholded_correlation_matrices):
        participant, mode, window = key
        corr_matrix = thresholded_correlation_matrices[key]

        # Within-network connectivity: every region pair in the upper triangle.
        within_entries = []
        for network_name, regions in networks.items():
            indices = networks_indices[network_name]
            network_corr_matrix = corr_matrix[np.ix_(indices, indices)]
            for i, j in zip(*np.triu_indices_from(network_corr_matrix, k=1)):
                corr_value = network_corr_matrix[i, j]
                within_entries.append(
                    f"[{network_name}]: {regions[i][1]} - {corr_value:.2f} - {regions[j][1]}")
        within_network_connectivity[key] = within_entries

        # Between-network connectivity: every region of net1 against every region of net2.
        between_entries = []
        for net1, net2 in network_pairs:
            regions1, regions2 = networks[net1], networks[net2]
            between_corr_matrix = corr_matrix[np.ix_(networks_indices[net1], networks_indices[net2])]
            for i, (_, region1) in enumerate(regions1):
                for j, (_, region2) in enumerate(regions2):
                    corr_value = between_corr_matrix[i, j]
                    between_entries.append(
                        f"[{net1}, {net2}]: {region1} - {corr_value:.2f} - {region2}")
        between_network_connectivity[key] = between_entries

        # Report each participant exactly once, on its first matrix.
        if participant not in participants_reported:
            participants_reported.add(participant)
            if {"EC", "EO"} <= modes_by_participant[participant]:
                current_count += 1
                sys.stdout.write(f"\rWithin/Between network connectivity calculated for participant: {participant} for modes EC and EO ({current_count}/{participant_count})")
                sys.stdout.flush()
            else:
                print(f"\rError processing participant: {participant}. Mode EC or EO not successfully processed.")

    print("\nCompleted calculation of network connectivity for all participants.")
    return within_network_connectivity, between_network_connectivity

# Calculate within and between network connectivity
# start_time is (re)set here so the printed duration covers only this call;
# the cell-wide timer is start_time_1.
start_time = time.time()
within_network_conn_values, between_network_conn_values = calculate_network_connectivity(thresholded_correlation_matrices, networks)
print(f"Total time taken for calculation: {(time.time() - start_time) / 60:.2f} minutes")


def aggregate_network_connectivity(thresholded_correlation_matrices, networks, median_optimal_states):
    """
    Average network connectivity per state across all participants, modes and windows.

    For every matrix, the mean of the upper triangle of each network's
    sub-matrix (within) and the mean of each network-pair block (between) are
    collected under the matrix's state; NaN means are dropped. Each collection
    is then reduced to its mean, or to NaN when nothing was collected.

    Every state present in the result is averaged, including states found in
    the data that lie outside ``range(median_optimal_states)``.

    No per-matrix error handling is performed; exceptions propagate to the caller.

    Args:
        thresholded_correlation_matrices (dict): Keys are
            ``(participant, mode, (state, window_number))``; values are square
            correlation matrices.
        networks (dict): Network name -> list of ``(region_index, region_name)``.
        median_optimal_states (int): Number of states to pre-create; converted
            with ``int()``.

    Returns:
        tuple: ``(aggregated_within_network_connectivity,
            aggregated_between_network_connectivity)``.
            - within: ``state -> {network: mean or NaN}``
            - between: ``state -> {(net1, net2): mean or NaN}``
    """
    print("Starting aggregation of network connectivity...")
    networks_indices = {name: [region[0] for region in regions] for name, regions in networks.items()}
    network_pairs = list(itertools.combinations(networks.keys(), 2))

    def _empty_within():
        return {network: [] for network in networks}

    def _empty_between():
        return {pair: [] for pair in network_pairs}

    # Convert median_optimal_states to an integer
    median_optimal_states = int(median_optimal_states)

    # Pre-create the expected states so they appear in the result even if empty.
    aggregated_within_network_connectivity = {state: _empty_within() for state in range(median_optimal_states)}
    aggregated_between_network_connectivity = {state: _empty_between() for state in range(median_optimal_states)}

    # Which modes exist for each participant, derived from the input itself.
    modes_by_participant = {}
    for participant, mode, _ in thresholded_correlation_matrices:
        modes_by_participant.setdefault(participant, set()).add(mode)
    participant_count = len(modes_by_participant)
    participants_processed = set()
    current_count = 0
    max_message_length = 0

    # Keys are unique, so sorting them gives the same order as sorting items.
    for key in sorted(thresholded_correlation_matrices):
        participant, mode, (state, window_number) = key
        corr_matrix = thresholded_correlation_matrices[key]

        # States beyond the pre-created range are added on first sight.
        if state not in aggregated_within_network_connectivity:
            aggregated_within_network_connectivity[state] = _empty_within()
        if state not in aggregated_between_network_connectivity:
            aggregated_between_network_connectivity[state] = _empty_between()

        # Within-network connectivity
        for network_name, indices in networks_indices.items():
            network_corr_matrix = corr_matrix[np.ix_(indices, indices)]
            pair_values = network_corr_matrix[np.triu_indices_from(network_corr_matrix, k=1)]
            if pair_values.size == 0:
                continue  # fewer than two regions: no pair to average
            mean_corr = np.mean(pair_values)
            if not np.isnan(mean_corr):
                aggregated_within_network_connectivity[state][network_name].append(mean_corr)

        # Between-network connectivity
        for net1, net2 in network_pairs:
            between_corr_matrix = corr_matrix[np.ix_(networks_indices[net1], networks_indices[net2])]
            if between_corr_matrix.size == 0:
                continue
            mean_corr = np.mean(between_corr_matrix)
            if not np.isnan(mean_corr):
                aggregated_between_network_connectivity[state][(net1, net2)].append(mean_corr)

        # Progress: announce a participant once, when both modes are present.
        if participant not in participants_processed and {"EC", "EO"} <= modes_by_participant[participant]:
            current_count += 1
            participants_processed.add(participant)
            progress_message = f"Aggregated network connectivity calculated for participant: {participant} for modes EC and EO ({current_count}/{participant_count})."
            sys.stdout.write('\r' + ' ' * max_message_length + '\r')  # Clear the line
            sys.stdout.write(f"\r{progress_message}")
            sys.stdout.flush()
            max_message_length = max(max_message_length, len(progress_message))

    # Reduce every collected list to its mean; iterate over all states actually
    # present so dynamically added states are averaged as well.
    for per_state in (aggregated_within_network_connectivity, aggregated_between_network_connectivity):
        for collected in per_state.values():
            for name, samples in collected.items():
                collected[name] = np.mean(samples) if samples else np.nan

    print("\nCompleted aggregation of network connectivity for all participants.")

    return aggregated_within_network_connectivity, aggregated_between_network_connectivity

# Call the function and measure execution time
# NOTE(review): median_optimal_state is not assigned anywhere in the visible part
# of this cell; it must already exist in the kernel — confirm which cell sets it.
start_time = time.time()
aggregated_within_conn, aggregated_between_conn = aggregate_network_connectivity(thresholded_correlation_matrices, networks, median_optimal_state)
print(f"Total time taken for aggregation: {(time.time() - start_time) / 60:.2f} minutes")

def save_connectivity_results(participant, mode, within_network_conn_values,
                              between_network_conn_values, aggregated_within_conn,
                              aggregated_between_conn):
    """Write one participant/mode's connectivity results into its output folder.

    Four files are written under ``<base_dir>/<participant>/<mode>/`` (the
    folder is created if needed): the notebook-level ``networks`` assignment
    (``.npy``), the within-network values, the between-network values, and the
    aggregated within/between results (all ``.npz``). Keys of the two value
    dicts are flattened to strings by joining their parts with ``_``.

    Any exception raised while saving is caught and reported with ``print``;
    nothing is re-raised.

    Args:
        participant: Participant ID string.
        mode: Recording mode string.
        within_network_conn_values: dict of within-network entries to save.
        between_network_conn_values: dict of between-network entries to save.
        aggregated_within_conn: Aggregated within-network result.
        aggregated_between_conn: Aggregated between-network result.
    """
    mode_dir = os.path.join(base_dir, participant, mode)
    os.makedirs(mode_dir, exist_ok=True)  # Ensure the directory exists

    try:
        # Network assignment
        np.save(os.path.join(mode_dir, f"{participant}_{mode}_network_assignment.npy"), networks)

        # Within-network values, one archive entry per flattened key
        within_arrays_to_save = {
            "_".join(map(str, key)): value
            for key, value in within_network_conn_values.items()
        }
        np.savez(os.path.join(mode_dir, f"{participant}_{mode}_within_network_conn_raw.npz"),
                 **within_arrays_to_save)

        # Between-network values, same key flattening
        between_arrays_to_save = {
            "_".join(map(str, key)): value
            for key, value in between_network_conn_values.items()
        }
        np.savez(os.path.join(mode_dir, f"{participant}_{mode}_between_network_conn_raw.npz"),
                 **between_arrays_to_save)

        # Aggregated measures
        np.savez(os.path.join(mode_dir, f"{participant}_{mode}_aggregated_conn_raw.npz"),
                 within_conn=aggregated_within_conn,
                 between_conn=aggregated_between_conn)

        print(f"Successfully saved connectivity results for participant: {participant}, mode: {mode}")
    except Exception as e:
        print(f"\nError saving results for participant: {participant}, mode: {mode}. Reason: {str(e)}")
    

# Call the save function for each participant and mode
for participant, results in participants_data.items():
    for mode, data in results.items():
        # Only modes that carry both raw and thresholded matrices are saved.
        if 'correlation_matrices' in data and 'thresholded_matrices' in data:
            # Extract the connectivity values for the current participant and mode
            # (keys are (participant, mode, window) tuples)
            within_conn_values = {
                key: value for key, value in within_network_conn_values.items() if key[0] == participant and key[1] == mode
            }
            between_conn_values = {
                key: value for key, value in between_network_conn_values.items() if key[0] == participant and key[1] == mode
            }
            # Save the results
            # NOTE: the aggregated results are the same objects for every
            # participant/mode, so each folder receives an identical copy.
            save_connectivity_results(participant, mode, within_conn_values,
                                      between_conn_values, aggregated_within_conn,
                                      aggregated_between_conn)
            
# Cell-wide elapsed time, measured from start_time_1 set near the top of the cell.
end_time = time.time()  
elapsed_time_minutes = (end_time - start_time_1) / 60
print(f"Total time taken: {elapsed_time_minutes:.2f} minutes")

Starting within/between network connectivity calculations...
Within/Between network connectivity calculated for participant: 416 for modes EC and EO (45/45)
Completed calculation of network connectivity for all participants.
Total time taken for calculation: 3.54 minutes
Starting aggregation of network connectivity...
Aggregated network connectivity calculated for participant: 416 for modes EC and EO (45/45).
Completed aggregation of network connectivity for all participants.
Total time taken for aggregation: 0.46 minutes
Successfully saved connectivity results for participant: 101, mode: EC
Successfully saved connectivity results for participant: 101, mode: EO
Successfully saved connectivity results for participant: 102, mode: EC
Successfully saved connectivity results for participant: 102, mode: EO
Successfully saved connectivity results for participant: 103, mode: EC
Successfully saved connectivity results for participant: 103, mode: EO
Successfully saved connectivity results for pa

In [10]:
# SANITY CHECK!

def load_all_matrices(participants, base_dir):
    """Load every saved result needed by the sanity check.

    For each participant and each mode ("EC", "EO") the following archives are
    read from ``<base_dir>/<participant>/<mode>/`` when present:

    - ``<participant>_thresholded_matrices_raw.npz``: entries whose name is a
      run of underscore-separated integers are stored under
      ``(participant, mode, tuple_of_ints)``; other entries are skipped.
    - ``<participant>_<mode>_within_network_conn_raw.npz`` and
      ``..._between_network_conn_raw.npz``: every entry is stored under
      ``(participant, mode, entry_name)``.
    - ``<participant>_<mode>_aggregated_conn_raw.npz``: ``within_conn`` and
      ``between_conn`` are stored under ``(participant, mode, 'within')`` and
      ``(participant, mode, 'between')``. A message is printed when this
      archive is missing; the other archives are skipped silently.

    Each archive is closed as soon as its entries have been read.

    Args:
        participants: Iterable of participant ID strings.
        base_dir: Directory containing one sub-directory per participant.

    Returns:
        dict: Categories ``'thresholded'``, ``'within_network'``,
        ``'between_network'``, ``'aggregated_within'``, ``'aggregated_between'``,
        each mapping the keys above to the arrays as returned by NumPy.
    """
    matrices = {
        'thresholded': {},
        'within_network': {},
        'between_network': {},
        'aggregated_within': {},
        'aggregated_between': {}
    }

    # SECURITY NOTE: allow_pickle=True unpickles object arrays (the aggregated
    # results are saved as dicts, hence pickled). Only load trusted files.
    for participant in participants:
        participant_dir = os.path.join(base_dir, participant)
        for mode in ["EC", "EO"]:
            mode_dir = os.path.join(participant_dir, mode)

            # Load thresholded matrices
            thresholded_path = os.path.join(mode_dir, f"{participant}_thresholded_matrices_raw.npz")
            if os.path.exists(thresholded_path):
                with np.load(thresholded_path, allow_pickle=True) as data:
                    for key in data.files:
                        try:
                            tuple_key = (participant, mode, tuple(map(int, key.split('_'))))
                            matrices['thresholded'][tuple_key] = data[key]
                        except ValueError:
                            continue  # Skip keys that don't convert to integers

            # Load within-network and between-network connectivity
            for category, suffix in (('within_network', 'within_network_conn_raw'),
                                     ('between_network', 'between_network_conn_raw')):
                conn_path = os.path.join(mode_dir, f"{participant}_{mode}_{suffix}.npz")
                if os.path.exists(conn_path):
                    with np.load(conn_path, allow_pickle=True) as data:
                        for key in data.files:
                            matrices[category][(participant, mode, key)] = data[key]

            # Load aggregated connectivity
            aggregated_path = os.path.join(mode_dir, f"{participant}_{mode}_aggregated_conn_raw.npz")
            if os.path.exists(aggregated_path):
                with np.load(aggregated_path, allow_pickle=True) as data:
                    matrices['aggregated_within'][(participant, mode, 'within')] = data['within_conn']
                    matrices['aggregated_between'][(participant, mode, 'between')] = data['between_conn']
            else:
                print(f"File missing for participant {participant}, mode {mode}")

    return matrices

def check_bounds(matrix):
    """Return a truthy value if any element lies outside the interval [-1, 1].

    NaN elements compare false on both sides and are therefore never reported.
    """
    return np.any((matrix < -1) | (matrix > 1))

def _numeric_values(entry):
    """Yield numeric arrays contained in a loaded entry, whatever its storage form.

    Handles the forms the saved results take once loaded:
    numeric arrays (yielded as-is), string arrays of
    ``"<prefix> - <value> - <region>"`` lines (the middle field is parsed),
    and object arrays / dicts / sequences (walked recursively).
    Anything unrecognised yields nothing.
    """
    if isinstance(entry, dict):
        for value in entry.values():
            yield from _numeric_values(value)
    elif isinstance(entry, (list, tuple)):
        for value in entry:
            yield from _numeric_values(value)
    elif isinstance(entry, np.ndarray):
        if entry.dtype == object:
            # e.g. a 0-d array wrapping a dict that was saved with np.savez
            for value in entry.ravel():
                yield from _numeric_values(value)
        elif entry.dtype.kind == 'U':
            parsed = []
            for text in entry.ravel():
                fields = str(text).split(' - ')
                if len(fields) != 3:
                    continue
                try:
                    parsed.append(float(fields[1]))
                except ValueError:
                    continue
            if parsed:
                yield np.array(parsed)
        elif entry.dtype.kind in 'fiu':
            yield entry
    elif isinstance(entry, (int, float, np.number)):
        yield np.asarray(entry, dtype=float)

def analyze_and_print_results(matrices):
    """Print, per category, whether all stored correlations lie within [-1, 1].

    Every category is inspected regardless of how its entries are stored
    (2-D numeric matrices, formatted string lists, or pickled dicts), so a
    "within bounds" line means the values were actually examined. String
    entries carry values rounded to two decimals, so the check on those is
    limited to that precision.

    Args:
        matrices (dict): Must contain the categories ``'thresholded'``,
            ``'within_network'``, ``'between_network'``, ``'aggregated_within'``
            and ``'aggregated_between'``, each a dict of loaded entries.
    """
    categories_to_check = ['thresholded', 'within_network', 'between_network', 'aggregated_within', 'aggregated_between']

    for category in categories_to_check:
        # any() stops at the first offending array, like the original early break.
        out_of_bounds_found = any(
            check_bounds(values)
            for entry in matrices[category].values()
            for values in _numeric_values(entry)
        )
        if out_of_bounds_found:
            print(f"Some correlations in {category.replace('_', ' ')} are out of bounds")
        else:
            print(f"All correlations in {category.replace('_', ' ')} are within -1 to +1 bounds")
            
# Main code to load matrices and perform analysis
# NOTE(review): base_dir and participants repeat the values assigned in the
# configuration cell near the top of the notebook; re-assigning them here
# overwrites those kernel globals with identical content.
base_dir = '/home/cerna3/neuroconn/data/out/subjects/'
participants = ['101', '102', '103', '104', '105', '111', '112', '113', '114', '115', '116', '117', '118', '119', '120',
                '202', '205', '206', '207', '208', '209', '210', '211', '214', '215', '216', '217', '218', '219', '221',
                '401', '402', '403', '404', '406', '407', '408', '409', '410', '411', '412', '413', '414', '415', '416']
# Reload everything from disk so the check covers what was actually saved.
matrices = load_all_matrices(participants, base_dir)

# Analyze and print results
analyze_and_print_results(matrices)

All correlations in thresholded are within -1 to +1 bounds
All correlations in within network are within -1 to +1 bounds
All correlations in between network are within -1 to +1 bounds
All correlations in aggregated within are within -1 to +1 bounds
All correlations in aggregated between are within -1 to +1 bounds
