In [None]:
# This is for analyzing layer data 
import os
import numpy as np
import scipy
import scipy
from scipy.linalg import orthogonal_procrustes
from scipy.stats import spearmanr
import matplotlib.pyplot as plt
import nibabel as nib
import pyvista as pv
import yaspy
from brainspace.gradient import GradientMaps
from brainspace.gradient.kernels import compute_affinity
import brainspace.gradient.alignment as ga
from brainspace.utils.parcellation import map_to_labels
from scipy.optimize import linear_sum_assignment
from sklearn.linear_model import LinearRegression
import pingouin as pg
import pandas as pd

# ---- Run configuration flags and atlas loading for the whole notebook ----
do_glasser = False  # True: use the Glasser label files; False: use the Schaefer dlabel file below
add_bigbrain_data = False  # True: append BigBrain thickness as an extra column in the ex vivo run
do_parcellation = True#should be TRUE most of the time
combine_hemispheres = False  # True: concatenate lh and rh columns (ex vivo run only)
hcp_save_subpath = 'HCP'  # output folder for the HCP correlation matrices
exvivo_save_subpath = 'EXVIVO'  # output folder for the ex vivo correlation matrices



# Folder holding the BigBrain thickness maps; only read when add_bigbrain_data is True.
baseDir_bigbrain = '/Users/dennis.jungchildmind.org/Downloads/BigBrain/thickness/resample/'


if do_glasser:
    # One label array per hemisphere, taken from the first data array of each GIFTI label file.
    atlas_data_lh = nib.load('/Users/dennis.jungchildmind.org/Library/CloudStorage/OneDrive-ChildMindInstitute/parcellation/glasser/Glasser_2016.32k.L.label.gii').darrays[0].data
    atlas_data_rh = nib.load('/Users/dennis.jungchildmind.org/Library/CloudStorage/OneDrive-ChildMindInstitute/parcellation/glasser/Glasser_2016.32k.R.label.gii').darrays[0].data  
    num_parcels = 180
    # NOTE(review): yeo_network_names / yeo_network_colors are only defined in the else-branch below.
else:
    # Reference for the Schaefer parcellation files:
    #https://github.com/ThomasYeoLab/CBIG/blob/master/stable_projects/brain_parcellation/Schaefer2018_LocalGlobal/Parcellations/HCP/fslr32k/cifti/Schaefer2018_200Parcels_7Networks_order.dscalar.nii
    # NOTE(review): the URL above names the 200-parcel dscalar file, while the path loaded here
    # is the 400-parcel dlabel file.
    atlas_path = '/Users/dennis.jungchildmind.org/OneDrive - Child Mind Institute/parcellation/schaefer2018/Schaefer2018_400Parcels_7Networks_order.dlabel.nii'
    # First row of the CIFTI data matrix, cast to integer parcel labels.
    atlas = nib.load(atlas_path).get_fdata()[0].astype(int)
    # Split the label vector into two equal halves: first half -> lh, second half -> rh.
    # NOTE(review): assumes both hemispheres contribute the same number of entries - confirm.
    atlas_data_lh = atlas[:len(atlas)//2]
    atlas_data_rh = atlas[len(atlas)//2:]
    # Shift the rh labels so that the smallest non-zero label becomes 1 ...
    atlas_data_rh = atlas_data_rh - np.min(atlas_data_rh[atlas_data_rh != 0])+1
    # ... and reset everything that was 0 before the shift (now <= 0) back to 0.
    atlas_data_rh[atlas_data_rh <= 0] = 0
    # Takes the second '_'-separated token of the full path ('400Parcels') and strips 'Parcels'.
    # NOTE(review): breaks if any directory name in atlas_path contains an underscore.
    num_parcels = int(atlas_path.split('_')[1].split('Parcels')[0])  # Extract number of parcels from atlas path

    # Label table of the dlabel file: axis 0 of the CIFTI header, first element, second item.
    # NOTE(review): the atlas file is loaded a second time here.
    yeo_networks = nib.load(atlas_path).header.get_axis(0)
    yeo_network_data = yeo_networks.get_element(0)[1]

    # Each yeo_network_data[key] starts with a label name such as '17Networks_LH_VisCent_ExStr_1';
    # the third '_'-separated token is stored as the network name and the second item as its colour.
    # NOTE(review): the example name is from a 17-network file; the file loaded here is 7Networks.
    yeo_network_names = []
    yeo_network_colors = []
    for i,key in enumerate(yeo_network_data):
        if i == 0:
            # Placeholder for the first label-table entry so list positions line up with label order.
            yeo_network_names.append(np.nan)
            yeo_network_colors.append((0,0,0,0))
            continue#skip the first key
        # print(yeo_network_data[key][0])  # network name
        network_parts = yeo_network_data[key][0].split('_')
        #print(network_parts)
        yeo_network_names.append(network_parts[2])  # Append the network name (VisCent)
        yeo_network_colors.append(yeo_network_data[key][1])


In [None]:

def decompose_to_permutation_and_signs(R):
    """
    Approximate a square matrix R by a permutation combined with sign flips.

    Each row is paired with exactly one column so that the sum of the absolute
    values of the selected entries is maximal (Hungarian algorithm). The sign
    of every selected entry is then reported as the flip for that pairing.

    Parameters:
    -----------
    R : array-like (n, n)
        Matrix to decompose, e.g. the orthogonal matrix returned by a
        Procrustes solve.

    Returns:
    --------
    perm_idx : array (n,)
        perm_idx[i] is the column paired with row i.
    signs : array (n,)
        +1.0 or -1.0 per pairing. A selected entry that is exactly zero is
        reported as +1.0 so that applying the signs never zeroes a component.
    """
    R = np.asarray(R)

    # linear_sum_assignment minimises cost, so negate |R| to maximise the
    # total magnitude of the selected entries.
    row_ind, col_ind = linear_sum_assignment(-np.abs(R))

    # np.sign would give 0 for a zero entry; the contract is strictly +/-1.
    matched = R[row_ind, col_ind]
    signs = np.where(matched < 0, -1.0, 1.0)

    # The column picked for each row is the permutation.
    perm_idx = col_ind

    return perm_idx, signs



# Function to create eigenvalue plot
def plot_eigenvalues(gm_aligned, color='#2E86C1', filename=None, save=False, SAVEFOLDER='./figures/eigenvalues'):
    """Plot each component's eigenvalue as a fraction of the eigenvalue sum.

    Args:
        gm_aligned: Object exposing a ``lambdas_`` attribute (eigenvalues).
        color: Line/marker colour passed to matplotlib.
        filename: Prefix of the saved file name; only used when ``save`` is True.
        save: When True, write ``<SAVEFOLDER>/<filename>_eigenvalues.png`` at
            300 dpi, close the figure and return None.
        SAVEFOLDER: Output folder, created if it does not exist.

    Returns:
        ``(fig, ax)`` when ``save`` is False, otherwise None.
    """
    fig, ax = plt.subplots(figsize=(4, 4))

    # Normalise so the plotted values sum to one.
    lambdas = gm_aligned.lambdas_
    eigenvalues = lambdas / np.sum(lambdas)
    n_components = len(eigenvalues)
    ax.plot(
        range(1, n_components + 1),
        eigenvalues,
        'o-',
        color=color,
        linewidth=2,
        markersize=6,
    )

    # Axis cosmetics
    ax.set_xlabel('Component', fontsize=18)
    ax.set_ylabel('Variance explained', fontsize=18)
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.set_xlim(0, n_components)
    ax.yaxis.set_major_formatter(plt.FormatStrFormatter('%.2f'))
    ax.tick_params(axis='both', which='major', labelsize=20)

    plt.tight_layout()

    if not save:
        return fig, ax

    # Saving path: make sure the folder exists, write the PNG, release the figure.
    if not os.path.exists(SAVEFOLDER):
        os.makedirs(SAVEFOLDER)
    output_file = os.path.join(SAVEFOLDER, f'{filename}_eigenvalues.png')
    plt.savefig(output_file, dpi=300)
    plt.close()
    return



#parcellation of the data based on the given atlas (atlas_data).
def parcellate_data(data,atlas_data):
    """
    Average per-vertex values within each atlas parcel, ignoring outliers.

    For every parcel label 1..N (N = number of distinct non-zero labels in
    ``atlas_data``) the NaN entries are dropped, values lying 2 standard
    deviations or more away from the parcel mean are discarded, and the mean
    of the remaining values is stored.

    data: numpy array of per-vertex values, same length as atlas_data
    atlas_data: numpy array of integer parcel labels (0 = unlabelled)

    Returns a 1-D numpy array of length N. A parcel with no non-NaN values is
    reported as 0. A parcel whose values are all identical (standard
    deviation 0) is reported as that value; previously the strict ``<``
    comparison rejected every vertex of such a parcel and it was stored as 0.

    NOTE: labels are assumed to be contiguous (1..N); a label greater than N
    is never visited.
    """
    n_parcels = len(np.unique(atlas_data[atlas_data != 0]))
    data_parc = np.zeros(n_parcels)

    for i in range(n_parcels):
        parcel_data = data[atlas_data == i + 1]
        # Drop NaN vertices before computing any statistic.
        parcel_data = parcel_data[~np.isnan(parcel_data)]
        if parcel_data.size == 0:
            continue  # nothing usable: leave the pre-filled 0

        parcel_mean = np.mean(parcel_data)
        parcel_std = np.std(parcel_data)

        # Keep values strictly closer than 2 std to the parcel mean.
        inliers = parcel_data[np.abs(parcel_data - parcel_mean) < (2 * parcel_std)]

        if inliers.size > 0:
            data_parc[i] = np.mean(inliers)
        elif parcel_std == 0:
            # Constant parcel: no spread, so there are no outliers to remove.
            data_parc[i] = parcel_mean
        # Otherwise (non-finite statistics, e.g. Inf in the data) keep 0.
    return data_parc



def plot_correlation_matrix(corr, transpose, title, vmin,vmax,cmap):
    """Draw ``corr`` as a heatmap on the current matplotlib axes.

    Adds a colorbar and a bold title, and labels both axes 'Subject' when
    ``transpose`` is true and 'Parcel' otherwise. The diagonal is shown
    as-is. Returns the ``corr`` object that was passed in, unchanged.
    """
    # Plot from a copy so the caller's matrix is never touched by plotting.
    matrix_to_show = corr.copy()

    image = plt.imshow(matrix_to_show, cmap=cmap, vmin=vmin, vmax=vmax)
    plt.colorbar(image, fraction=0.046, pad=0.04)
    plt.title(f'{title}', fontsize=16, fontweight='bold')

    # Both axes index the same kind of entity.
    axis_name = 'Subject' if transpose else 'Parcel'
    plt.xlabel(axis_name, fontsize=16)
    plt.ylabel(axis_name, fontsize=16)

    return corr

def plot_subject_similarity(hemisphere_data, transpose=False, vmin = None, vmax = None, cmap='viridis'):
    """
    Plot Pearson correlation matrices for six thickness measures of one hemisphere.

    One panel is drawn per measure, side by side in a new figure. For each
    measure the rows of the data array (or of its transpose when
    ``transpose`` is True) are correlated pairwise with ``np.corrcoef``.

    Args:
        hemisphere_data (dict): Must provide 2-D arrays under the keys
            'total', 'infra', 'supra', 'relative', 'ratio_supra', 'ratio_infra'.
        transpose (bool): Transpose each array before correlating; also
            switches the axis labels from 'Parcel' to 'Subject'.
        vmin, vmax: Colour limits forwarded to the heatmap.
        cmap: Colormap forwarded to the heatmap.

    Returns:
        tuple of six dicts, in the key order listed above; each dict maps
        0 to the correlation matrix of that measure.
    """
    # (key in hemisphere_data, panel title), in plotting and return order.
    panels = (
        ('total', 'Total Thickness'),
        ('infra', 'Infra Thickness'),
        ('supra', 'Supra Thickness'),
        ('relative', 'Relative (Supra/Infra)'),
        ('ratio_supra', 'Ratio (Supra/Total)'),
        ('ratio_infra', 'Ratio (Infra/Total)'),
    )
    correlation_dicts = {key: {} for key, _ in panels}

    first_index = 0  # first row to include; 0 keeps every row
    plt.figure(figsize=(len(correlation_dicts)*4,4))

    for position, (measure, title) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)

        values = hemisphere_data[measure][first_index:]
        if transpose:
            values = values.T

        # Pearson correlation between rows. Alternatives that were tried here
        # earlier and are currently disabled: z-scoring the data first, cosine
        # similarity, Kendall's tau, distance correlation (dcor), Spearman.
        corr = np.corrcoef(values)

        correlation_dicts[measure][0] = plot_correlation_matrix(
            corr, transpose, title, vmin, vmax, cmap
        )

    plt.tight_layout()

    return tuple(correlation_dicts[key] for key, _ in panels)




def plot_subject_similarity_hcp(hemisphere_data, transpose=False, vmin = None, vmax = None, cmap='viridis'):
    """
    Plot the Pearson correlation matrix of total thickness for one hemisphere.

    Single-measure counterpart of ``plot_subject_similarity``: only the
    'total' entry of ``hemisphere_data`` is used.

    Args:
        hemisphere_data (dict): Must provide a 2-D array under the key 'total'.
        transpose (bool): Transpose the array before correlating; also
            switches the axis labels from 'Parcel' to 'Subject'.
        vmin, vmax: Colour limits forwarded to the heatmap.
        cmap: Colormap forwarded to the heatmap.

    Returns:
        dict mapping 0 to the correlation matrix.
    """
    # (key in hemisphere_data, panel title), in plotting order.
    panels = (
        ('total', 'Total Thickness'),
    )
    correlation_dicts = {key: {} for key, _ in panels}

    first_index = 0  # first row to include; 0 keeps every row
    plt.figure(figsize=(len(correlation_dicts)*4,4))

    for position, (measure, title) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)

        values = hemisphere_data[measure][first_index:]
        if transpose:
            values = values.T

        # Pearson correlation between rows. Disabled alternatives tried here
        # earlier: z-scoring first, distance correlation (dcor), Spearman.
        corr = np.corrcoef(values)

        correlation_dicts[measure][0] = plot_correlation_matrix(
            corr, transpose, title, vmin, vmax, cmap
        )

    plt.tight_layout()

    return correlation_dicts['total']


In [None]:
# Quick inspection of one subject's raw-intensity archive (I41, left hemisphere):
# print the shape of 'all_values' and of its entry at the middle position of 'dist_array'.
# NOTE(review): the NpzFile returned by np.load is never closed.
tmp = np.load('/Users/dennis.jungchildmind.org/Desktop/exvivo_postslurm/at_pial_surface/output_960um_method0/output_120um_max_960um_dist_method0/I41_new_confidence/lh/pial_120um_method0_manual_raw_intensity.npz')
print(tmp['all_values'].shape)
# Index 'all_values' along its first axis at the midpoint of dist_array.
# NOTE(review): presumably the first axis of all_values runs over dist_array - confirm.
dist_array = tmp['dist_array']
middle_index = len(dist_array) // 2
all_values_middle = tmp['all_values'][middle_index]
print("Value of all_values at middle_index:", all_values_middle.shape)

In [None]:

# === Main Analysis Block ===

def clean_data(data_array):
    """Return a plain ndarray copy of ``data_array`` with NaN/Inf entries set to 0."""
    # Mask every non-finite entry, then substitute 0 for the masked positions.
    invalid_hidden = np.ma.masked_invalid(data_array)
    zero_filled = invalid_hidden.filled(0)
    return np.array(zero_filled)

def load_thickness_hcp_data(base_dir, subject_dir, hemi, suffix):
    """Load the total-thickness map of one HCP subject and hemisphere.

    Reads ``<base_dir><subject_dir>/<subject_dir>.<L|R>.thickness.32k_6mm_fwhm_fs_LR.shape.gii``
    ('L' when ``hemi`` is 'lh', 'R' for any other value) and returns
    ``{'total': <first data array of the file>}``. ``suffix`` is accepted for
    signature parity with the ex vivo loader but is not used.
    """
    hemi_letter = 'L' if hemi == 'lh' else 'R'
    file_path = f"{base_dir}{subject_dir}/{subject_dir}.{hemi_letter}.thickness.32k_6mm_fwhm_fs_LR.shape.gii"
    gifti_image = nib.load(file_path)
    return {'total': gifti_image.darrays[0].data}

def load_thickness_exvivo_data(base_dir, subject_dir, hemi, suffix):
    """Load the infra, supra and total thickness maps of one ex vivo subject/hemisphere.

    Files are read from ``<base_dir><subject_dir>/`` and are named
    ``<hemi>.thickness.<part><suffix>.shape.gii`` where ``<part>`` is
    'wm.inf.' (infra), 'inf.pial.' (supra) or empty (total). Returns a dict
    with keys 'infra', 'supra', 'total' holding the first data array of each file.
    """
    # (output key, file-name fragment between 'thickness.' and the suffix)
    name_parts = (
        ('infra', 'wm.inf.'),
        ('supra', 'inf.pial.'),
        ('total', ''),
    )
    common_prefix = f'{base_dir}{subject_dir}/{hemi}.thickness.'

    loaded = {}
    for key, part in name_parts:
        gifti_path = f'{common_prefix}{part}{suffix}.shape.gii'
        loaded[key] = nib.load(gifti_path).darrays[0].data
    return loaded

def load_thickness_bigbrain_data(base_dir, hemi):
    """Load the infra, supra and total thickness maps from the BigBrain files.

    Reads ``<base_dir>/<hemi>.<range>.32k.shape.gii`` for the three layer
    ranges below and returns a dict with keys 'infra', 'supra', 'total'
    holding the first data array of each file.
    """
    layer_ranges = {
        'infra': '3-6',  # layers 4,5,6 (ex vivo style)
        'supra': '0-3',  # layers 1,2,3
        'total': '0-6',  # all layers
    }
    return {
        key: nib.load(f'{base_dir}/{hemi}.{layers}.32k.shape.gii').darrays[0].data
        for key, layers in layer_ranges.items()
    }

def calculate_derived_measurements(thickness_data):
    """Calculate derived thickness measurements.

    Args:
        thickness_data (dict): Arrays of identical shape under the keys
            'supra', 'infra' and 'total'.

    Returns:
        dict with four arrays:
            'relative'    : supra / infra
            'ratio_supra' : supra / total
            'ratio_infra' : infra / total
            'diff'        : (infra - supra) / (infra + supra)
        Wherever the denominator is 0, or the quotient is NaN/Inf, the
        result is 0.

    All four measures go through the same guarded division. 'diff' used a
    plain ``/`` before, which produced the same values after cleaning but
    emitted divide-by-zero / invalid-value RuntimeWarnings.
    """
    def _safe_ratio(numerator, denominator):
        # Divide only where the denominator is non-zero; untouched entries
        # keep the 0 pre-filled in `out`. clean_data then zeroes any NaN/Inf.
        quotient = np.divide(numerator, denominator,
                             out=np.zeros_like(numerator), where=denominator != 0)
        return clean_data(quotient)

    supra = thickness_data['supra']
    infra = thickness_data['infra']
    total = thickness_data['total']

    return {
        'relative': _safe_ratio(supra, infra),
        'ratio_supra': _safe_ratio(supra, total),
        'ratio_infra': _safe_ratio(infra, total),
        'diff': _safe_ratio(infra - supra, infra + supra),
    }

# --- Analysis Loop for HCP and Ex Vivo data ---
# Runs the same pipeline twice: first on the HCP thickness maps, then on the ex vivo maps.
# Per run: find subjects -> load per-vertex thickness -> (optionally) add BigBrain,
# combine hemispheres, parcellate -> plot correlation matrices -> save them with np.save.
# The *_corr_lh / *_corr_rh names assigned at the end stay defined as globals after the loop.
for run_hcp_data in [True, False]:

    # Set up directories
    if run_hcp_data:
        # For HCP
        baseDir = '/Users/dennis.jungchildmind.org/OneDrive - Child Mind Institute/hcp1200-thickness/'
        data_types = ['total']
    else:
        # For ex vivo
        baseDir = '/Users/dennis.jungchildmind.org/Downloads/exvivo/'
        # This setting is specifically for the ex vivo data
        data_types = ['total', 'infra', 'supra', 'relative', 'ratio_supra', 'ratio_infra', 'diff']

    hemispheres = {'lh': 'left', 'rh': 'right'}

    # Get all subject directories in the baseDir
    # NOTE(review): os.listdir returns entries in arbitrary order, so the column order of the
    # data arrays (and of the *_subjects lists) can differ between machines/runs.
    subject_dirs = [d for d in os.listdir(baseDir) if os.path.isdir(os.path.join(baseDir, d))]
        
    # Ex vivo-specific ordering of subjects
    if not run_hcp_data:
        # Move the subject containing 'I54' to the beginning of the list (should have both hemispheres)
        # NOTE(review): next() raises StopIteration when no directory name contains 'I54'.
        l54_idx = next(i for i, s in enumerate(subject_dirs) if 'I54' in s)
        print(f"Index of L54: {l54_idx}")
        subject_dirs.insert(0, subject_dirs.pop(l54_idx))

    # Initialize data dictionary for every hemisphere and type
    # Each entry starts as an empty list and becomes a 2-D array (one column per loaded map).
    data = {hemi: {dtype: [] for dtype in data_types} for hemi in hemispheres}

    # Helper lists to keep track of included subject names
    # (*_subjects: raw directory names; *_subjects_name: same with '_new_confidence' removed)
    lh_subjects, rh_subjects = [], []
    lh_subjects_name, rh_subjects_name = [], []

    # Preallocate (optional usage)
    # NOTE(review): not read anywhere else in this cell.
    all_values_middle_array = {hemi: {dtype: [] for dtype in data_types} for hemi in hemispheres}

    # -------- Data Loading --------
    for subjectDir in subject_dirs:
        for hemi in hemispheres:
            # Construct appropriate file path
            # This file only serves as an existence check: a hemisphere is loaded when it is present.
            if run_hcp_data:
                surf_file = f"{baseDir}{subjectDir}/{subjectDir}.{'L' if hemi == 'lh' else 'R'}.thickness.32k_6mm_fwhm_fs_LR.shape.gii"
            else:
                surf_file = f'{baseDir}{subjectDir}/{hemi}.pial.32k_fs_LR.surf.gii'
            
            if os.path.exists(surf_file):
                # Update subject lists
                if hemi == 'lh':
                    lh_subjects.append(subjectDir)
                    lh_subjects_name.append(subjectDir.replace('_new_confidence', ''))
                else:
                    rh_subjects.append(subjectDir)
                    rh_subjects_name.append(subjectDir.replace('_new_confidence', ''))

                # Suffix is consistent for both
                suffix = '32k_6mm_fwhm_fs_LR'

                # Load and process thickness data for subject and hemisphere
                if run_hcp_data:
                    thickness_data = load_thickness_hcp_data(baseDir, subjectDir, hemi, suffix)
                else:
                    thickness_data = load_thickness_exvivo_data(baseDir, subjectDir, hemi, suffix)
                    # Adds 'relative', 'ratio_supra', 'ratio_infra' and 'diff' to the loaded maps.
                    thickness_data.update(calculate_derived_measurements(thickness_data))

                # Store data for each type efficiently using numpy
                # Each map becomes one column; columns are appended in loading order.
                # NOTE(review): np.concatenate copies the whole array on every append.
                for dtype in data_types:
                    reshaped = thickness_data[dtype].reshape(-1, 1)
                    if len(data[hemi][dtype]) > 0:
                        data[hemi][dtype] = np.concatenate((data[hemi][dtype], reshaped), axis=1)
                    else:
                        data[hemi][dtype] = reshaped

    # -------- BigBrain Data (ex vivo only) --------
    # Appended as one extra (last) column per hemisphere, named 'bigbrain'.
    if add_bigbrain_data and not run_hcp_data:
        for hemi in hemispheres:
            thickness_data = load_thickness_bigbrain_data(baseDir_bigbrain, hemi)
            thickness_data.update(calculate_derived_measurements(thickness_data))
            for dtype in data_types:
                reshaped = thickness_data[dtype].reshape(-1, 1)
                if len(data[hemi][dtype]) > 0:
                    data[hemi][dtype] = np.concatenate((data[hemi][dtype], reshaped), axis=1)
                else:
                    data[hemi][dtype] = reshaped
        lh_subjects_name.append('bigbrain')
        rh_subjects_name.append('bigbrain')

    # -------- Combine Hemispheres if needed --------
    # Combine before parcellation because left/right hemispheres may have different parcellations
    # After this step both 'lh' and 'rh' hold lh columns followed by rh columns.
    # NOTE(review): dict.copy() is shallow, so data['rh'] and data['lh'] reference the same arrays
    # at this point; the parcellation step below rebinds each entry separately.
    if combine_hemispheres and not run_hcp_data:
        for key in data['lh'].keys():
            data['lh'][key] = np.concatenate((data['lh'][key], data['rh'][key]), axis=1)
            print(data['lh'][key].shape)
        data['rh'] = data['lh'].copy()
        lh_subjects_name = lh_subjects_name + rh_subjects_name
        rh_subjects_name = lh_subjects_name.copy()

    # -------- Parcellation if specified --------
    # Every column is reduced to one value per parcel; the result has one row per parcel
    # and one column per original column.
    if do_parcellation:
        for hemi in hemispheres:
            # Choose the appropriate atlas_data for each hemisphere
            atlas_data = atlas_data_lh if hemi == 'lh' else atlas_data_rh
            for dtype in data_types:
                print('hemi', hemi, 'dtype', dtype)
                tmp_data = []
                for i in range(data[hemi][dtype].shape[1]):
                    print('i', i)
                    parcellated = parcellate_data(data[hemi][dtype][:, i], atlas_data)
                    tmp_data.append(parcellated)
                data[hemi][dtype] = np.array(tmp_data).T

    # -------- Assemble Final Data --------
    lh_data_parc = data['lh']
    rh_data_parc = data['rh']

    # -------- Construct Structural Covariance Matrices --------
    # With transpose=False the rows of each data array are correlated with each other.
    # NOTE(review): the plot_subject_similarity* helpers return dicts of the form {0: matrix},
    # so np.save below stores pickled objects rather than plain arrays; reading them back needs
    # np.load(..., allow_pickle=True) - confirm this is intended.
    if run_hcp_data:
        # For HCP data
        total_hcp_corr_lh = plot_subject_similarity_hcp(lh_data_parc, transpose=False, vmin=-0.5, vmax=0.5, cmap='viridis')
        total_hcp_corr_rh = plot_subject_similarity_hcp(rh_data_parc, transpose=False, vmin=-0.5, vmax=0.5, cmap='viridis')
        # Save the correlation matrices
        if not os.path.exists(hcp_save_subpath):
            os.makedirs(hcp_save_subpath, exist_ok=True)
        np.save(f'{hcp_save_subpath}/total_hcp_corr_lh.npy', total_hcp_corr_lh)
        np.save(f'{hcp_save_subpath}/total_hcp_corr_rh.npy', total_hcp_corr_rh)
    else:
        # For ex vivo data
        # The 'diff' measure is loaded and parcellated above but is not among the six plotted/saved here.
        (total_corr_lh, infra_corr_lh, supra_corr_lh, relative_corr_lh,
         ratio_supra_corr_lh, ratio_infra_corr_lh) = plot_subject_similarity(
            lh_data_parc, transpose=False, vmin=-0.5, vmax=1, cmap='viridis'
        )
        (total_corr_rh, infra_corr_rh, supra_corr_rh, relative_corr_rh,
         ratio_supra_corr_rh, ratio_infra_corr_rh) = plot_subject_similarity(
            rh_data_parc, transpose=False, vmin=-0.5, vmax=1, cmap='viridis'
        )

        if not os.path.exists(exvivo_save_subpath):
            os.makedirs(exvivo_save_subpath, exist_ok=True)

        # Save the correlation matrices
        np.save(f'{exvivo_save_subpath}/total_corr_lh.npy', total_corr_lh)
        np.save(f'{exvivo_save_subpath}/total_corr_rh.npy', total_corr_rh)
        np.save(f'{exvivo_save_subpath}/infra_corr_lh.npy', infra_corr_lh)
        np.save(f'{exvivo_save_subpath}/infra_corr_rh.npy', infra_corr_rh)
        np.save(f'{exvivo_save_subpath}/supra_corr_lh.npy', supra_corr_lh)
        np.save(f'{exvivo_save_subpath}/supra_corr_rh.npy', supra_corr_rh)
        np.save(f'{exvivo_save_subpath}/relative_corr_lh.npy', relative_corr_lh)
        np.save(f'{exvivo_save_subpath}/relative_corr_rh.npy', relative_corr_rh)
        np.save(f'{exvivo_save_subpath}/ratio_supra_corr_lh.npy', ratio_supra_corr_lh)
        np.save(f'{exvivo_save_subpath}/ratio_supra_corr_rh.npy', ratio_supra_corr_rh)
        np.save(f'{exvivo_save_subpath}/ratio_infra_corr_lh.npy', ratio_infra_corr_lh)
        np.save(f'{exvivo_save_subpath}/ratio_infra_corr_rh.npy', ratio_infra_corr_rh)




In [None]:

def calculate_gradients_from_brainspace(data2plot, mask_indices, atlas_data, hemisphere_mask, n_components, g_sparsity = 0.9):
    """Fit BrainSpace gradients for one hemisphere and project them onto the surface.

    Args:
        data2plot: 2-D array handed to ``GradientMaps.fit`` (one row per parcel).
        mask_indices: Row indices to leave out of the fit, or NaN to use every row.
        atlas_data: Per-vertex parcel labels used to expand parcel values to vertices.
        hemisphere_mask: Boolean per-vertex mask forwarded to ``map_to_labels``.
        n_components: Number of gradient components to compute.
        g_sparsity: ``sparsity`` argument forwarded to ``GradientMaps.fit``.

    Returns:
        (gm, grad): the fitted ``GradientMaps`` object and a list with one
        per-vertex array per component (NaN where ``map_to_labels`` fills).

    Reads the module-level settings ``G_dimension_reduction`` and ``G_kernel``,
    which must be defined before this function is called.

    Fixes relative to the previous version:
      * ``np.nan_to_num(data2plot, 0)`` passed 0 as the ``copy`` argument and
        therefore overwrote NaNs in the caller's matrix in place; NaNs are
        now replaced on a copy.
      * With ``mask_indices`` given, the mask was built but the fit still ran
        on every row, so writing the gradients back into the masked positions
        could not match in length. The fit now runs on the kept rows only.
        NOTE(review): only rows are dropped; if ``data2plot`` is a square
        parcel-by-parcel matrix, confirm whether the matching columns should
        be removed as well.
    """
    gm = GradientMaps(n_components, approach=G_dimension_reduction, kernel=G_kernel)

    use_all_rows = np.isnan(mask_indices).all()
    if use_all_rows:
        print('no mask')
        gm.fit(np.nan_to_num(data2plot, nan=0.0), sparsity=g_sparsity)
    else:
        mask = np.ones(data2plot.shape[0], dtype=bool)
        mask[mask_indices] = False
        gm.fit(data2plot[mask], sparsity=g_sparsity)

    grad = []
    for j in range(n_components):
        if use_all_rows:
            component = gm.gradients_[:, j]
        else:
            # Re-insert the fitted values at the kept rows; excluded rows stay NaN.
            component = np.full(len(data2plot), np.nan)
            component[mask] = gm.gradients_[:, j]

        # Expand one value per parcel to one value per vertex.
        grad.append(map_to_labels(component, atlas_data, mask=hemisphere_mask,
                                  fill=np.nan))

    return gm, grad


# This is the main function used for gradient alignment -092525 DJ
def align_gradients(X,Y,reflection=False,rotation=False):
    """Align source gradients to target gradients with orthogonal Procrustes.

    Args:
        X: source gradients, array of shape (n_points, n_components).
        Y: target gradients, same number of rows as X.
        reflection (bool): enable alignment. On its own it only flips the
            sign of individual components.
        rotation (bool): together with ``reflection=True``, apply the full
            orthogonal transform (rotation and reflection).

    Returns:
        * reflection=True, rotation=False: ``(X_flipped, sign_transform)``
          where ``sign_transform`` is a diagonal +/-1 matrix taken from the
          signs of the diagonal of the Procrustes solution.
        * reflection=True, rotation=True: ``(X_aligned, R)`` with ``R`` the
          Procrustes solution.
        * otherwise (including rotation=True with reflection=False):
          ``(X, Y)`` unaligned, with NaNs replaced by 0.
        Aligned outputs are centred on X's column means, transformed, and
        shifted to Y's column means.

    Raises:
        ValueError: if X and Y have different lengths.

    NaNs are replaced by 0 on copies, so the caller's arrays are no longer
    modified in place. The Procrustes solve is skipped when no alignment is
    requested.
    """
    if len(X) != len(Y):
        raise ValueError("Lists must be same length")

    # NaN -> 0 on fresh arrays; the inputs stay untouched.
    X = np.where(np.isnan(X), 0, X)
    Y = np.where(np.isnan(Y), 0, Y)

    print('X',X.shape)
    print('Y',Y.shape)

    if not reflection:
        # No alignment requested: hand back the NaN-free source and target.
        return X, Y

    # Centre both sets so the transform is estimated without the offset.
    target_mean = np.mean(Y, axis=0)
    X_centered = X - np.mean(X, axis=0)
    Y_centered = Y - target_mean
    R, _ = orthogonal_procrustes(X_centered, Y_centered)

    if rotation:
        return X_centered @ R + target_mean, R

    # Sign flips only: keep just the sign of each component's diagonal entry.
    sign_transform = np.diag(np.sign(np.diag(R)))
    X_sign_corrected = X_centered @ sign_transform + target_mean
    return X_sign_corrected, sign_transform
   

def unmap_gradient(grad_all_aligned, mask_indices, atlas_data, hemisphere_mask, start_idx, end_idx):
    """Unmap gradients to atlas data

    Mirrors the structure of ``map_gradients`` but calls ``label_to_map``
    instead of ``map_to_labels``.

    NOTE(review): as written this function depends on names that are not
    parameters and are not defined in the visible part of the file:
    ``gm_all``, ``data2plot`` and ``label_to_map`` (``N_components`` is a
    notebook-level variable). The ``grad_all_aligned`` argument is never
    read, ``grad_all_unmapped`` is never filled, and the returned ``grad``
    holds the components of the last ``gm`` only. Confirm intended behaviour
    before using.
    """
    grad_all_unmapped = []
    for i, gm in enumerate(gm_all):
        grad = []
        for j in range(N_components):
            # Without a mask use the fitted component directly; with a mask start from an all-NaN vector.
            tmp_gm = gm.gradients_[:,j] if np.isnan(mask_indices).all() else \
                    np.full(len(data2plot[i]), np.nan)
            
            # Apply mask if needed
            if not np.isnan(mask_indices).all():
                # Rows listed in mask_indices stay NaN; all other rows receive the component values.
                mask = np.ones(data2plot[i].shape[0], dtype=bool)
                mask[mask_indices] = False
                tmp_gm[mask] = gm.gradients_[:,j]
            
            # Map to atlas labels
            # source_lab spans every integer from the smallest to the largest non-zero label in the slice.
            atlas_slice = atlas_data[start_idx:end_idx]
            nonzero = atlas_slice[atlas_slice != 0]
            grad.append(label_to_map(tmp_gm, atlas_slice, mask=hemisphere_mask,
                                    fill=np.nan, source_lab=np.arange(nonzero.min(), nonzero.max()+1)))
    return grad

def map_gradients(gm_all, data2plot, mask_indices, atlas_data, hemisphere_mask, start_idx, end_idx):
    """Expand every fitted gradient component from parcels to atlas vertices.

    Args:
        gm_all: Iterable of fitted objects exposing ``gradients_``.
        data2plot: Per-model input matrices; only their row counts are used,
            and only when a mask is given.
        mask_indices: Row indices excluded from the fit, or NaN for none.
        atlas_data: Label vector; the slice ``[start_idx:end_idx]`` is used.
        hemisphere_mask: Mask forwarded to ``map_to_labels``.
        start_idx, end_idx: Bounds of the atlas slice for this hemisphere.

    Returns:
        List (one entry per model) of lists (one array per component).

    The number of components is read from the module-level ``N_components``.
    """
    mapped_per_model = []
    for model_idx, model in enumerate(gm_all):
        component_maps = []
        for comp in range(N_components):
            if np.isnan(mask_indices).all():
                # No mask: the fitted component already covers every row.
                values = model.gradients_[:, comp]
            else:
                # Masked fit: start from NaN and fill in the rows that were kept.
                values = np.full(len(data2plot[model_idx]), np.nan)
                kept_rows = np.ones(data2plot[model_idx].shape[0], dtype=bool)
                kept_rows[mask_indices] = False
                values[kept_rows] = model.gradients_[:, comp]

            # Labels run from the smallest to the largest non-zero value in the slice.
            labels = atlas_data[start_idx:end_idx]
            labelled = labels[labels != 0]
            source_labels = np.arange(labelled.min(), labelled.max()+1)
            component_maps.append(
                map_to_labels(values, labels, mask=hemisphere_mask,
                              fill=np.nan, source_lab=source_labels)
            )
        mapped_per_model.append(component_maps)
    return mapped_per_model


def create_hemisphere_plots(grad_all_aligned, surf_file, hemi, N_components_plot,color_ranges):
    """Render lateral and medial surface views for each gradient component.

    Args:
        grad_all_aligned: Indexable whose element ``[0]`` is a 2-D array with
            one column per component.
        surf_file: Surface file handed to ``yaspy.Plotter``.
        hemi: Hemisphere identifier handed to ``yaspy.Plotter``.
        N_components_plot: Number of leading components to render.
        color_ranges: Mutable sequence; entry ``pc`` is overwritten with the
            colour limit used for component ``pc``.

    Returns:
        List with one ``[lateral_screenshot, medial_screenshot, overlay]``
        entry per component.
    """
    rendered = []

    for pc in range(N_components_plot):
        surface_plotter = yaspy.Plotter(surf_file, hemi=hemi)

        # Symmetric colour scale around zero, clipped at the 95th percentile of |value|.
        component_values = grad_all_aligned[0][:, pc]
        limit = np.percentile(np.abs(component_values), 95)
        color_ranges[pc] = limit

        overlay = surface_plotter.overlay(
            component_values, cmap='RdBu_r', alpha=1, vmin=-limit, vmax=limit
        )
        surface_plotter.border(component_values, alpha=0)

        lateral_view = surface_plotter.screenshot("lateral")
        medial_view = surface_plotter.screenshot("medial")
        rendered.append([lateral_view, medial_view, overlay])

    return rendered


In [None]:
# Gradient analysis parameters
G_sparsity = 0.9  # sparsity parameter for BrainSpace (default is 0.9)
# NOTE(review): a `global` statement at module level has no effect; the two names below are
# ordinary module-level variables read by calculate_gradients_from_brainspace.
global G_dimension_reduction, G_kernel
G_dimension_reduction = 'dm'
G_kernel = 'normalized_angle'

# Surface meshes used for rendering each hemisphere.
lh_surf = '/Users/dennis.jungchildmind.org/Downloads/HCP_S1200_Atlas_Z4_pkXDZ/S1200.L.white_MSMAll.32k_fs_LR.surf.gii'
rh_surf = '/Users/dennis.jungchildmind.org/Downloads/HCP_S1200_Atlas_Z4_pkXDZ/S1200.R.white_MSMAll.32k_fs_LR.surf.gii'

N_components = 10  # number of gradient components to compute
N_components_plot = 10  # number of components to render

# Vertices that carry a parcel label (non-zero atlas value).
mask_lh = atlas_data_lh != 0
mask_rh = atlas_data_rh != 0
# Input matrices: entry 0 of the total-thickness correlation dicts. After the analysis loop
# these names hold the results of its last (ex vivo) iteration.
data2plot_lh = total_corr_lh[0]
data2plot_rh = total_corr_rh[0]


# NaN mask indices mean "use every row" in calculate_gradients_from_brainspace.
mask_index_lh = np.nan
mask_index_rh = np.nan
do_reflection = False
do_rotation = False
gm_all_lh, grad_lh = calculate_gradients_from_brainspace(data2plot_lh, mask_index_lh, atlas_data_lh, mask_lh, N_components,G_sparsity)
gm_all_rh, grad_rh = calculate_gradients_from_brainspace(data2plot_rh, mask_index_rh, atlas_data_rh, mask_rh, N_components,G_sparsity)


# Left gradients are "aligned" to themselves; right gradients are aligned to the left ones.
# With do_reflection and do_rotation both False, align_gradients performs no alignment: it
# returns its inputs with NaNs set to 0, so R_lh / R_rh hold the target array, not a transform.
tmp_lh,R_lh = align_gradients(np.array(grad_lh).T, np.array(grad_lh).T, reflection=do_reflection, rotation=do_rotation)
tmp_rh,R_rh = align_gradients(np.array(grad_rh).T, np.array(grad_lh).T, reflection=do_reflection, rotation=do_rotation)

# Plot using Yaspy
surf_file_lh = lh_surf  # Replace with correct surface file path for left hemi
surf_file_rh = rh_surf  # Replace with correct surface file path for right hemi

# create_hemisphere_plots reads its gradients from element [0] of the object it is given.
grad_this = {}

grad_this[0] = tmp_lh
# Plot left hemisphere gradients
plotters_lh = create_hemisphere_plots(grad_this, surf_file_lh, hemi='lh', N_components_plot=N_components_plot, color_ranges=[0.1]*N_components_plot)
# Plot right hemisphere gradients (use aligned gradients if desired)
#plotters_rh, m_rh = create_hemisphere_plots(grad_all_aligned_hemi_rh, surf_file_rh, hemi='rh', N_components_plot=N_components_plot, color_ranges=[0.1]*N_components_plot)

# Plot all components in two rows of subplots: first row lateral, second row medial
num_pcas = N_components_plot  # or len(plotters_lh), as appropriate
fig, axs = plt.subplots(2, num_pcas, figsize=(2*num_pcas, 4))  # two rows: lateral, medial

for i in range(num_pcas):
    # First row: lateral views
    # (with a single column plt.subplots returns a 1-D axes array, hence the fallback indexing)
    ax_lateral = axs[0, i] if num_pcas > 1 else axs[0]
    ax_lateral.imshow(plotters_lh[i][0])  # 0 is lateral
    ax_lateral.set_title(f'Component {i+1}', fontsize=11)
    ax_lateral.axis('off')

    # Second row: medial views
    ax_medial = axs[1, i] if num_pcas > 1 else axs[1]
    ax_medial.imshow(plotters_lh[i][1])  # 1 is medial
  
    ax_medial.axis('off')

plt.tight_layout()
plt.show()

grad_this[0] = tmp_rh
# Plot right hemisphere gradients
plotters_rh = create_hemisphere_plots(grad_this, surf_file_rh, hemi='rh', N_components_plot=N_components_plot, color_ranges=[0.1]*N_components_plot)
# Plot right hemisphere gradients (use aligned gradients if desired)
#plotters_rh, m_rh = create_hemisphere_plots(grad_all_aligned_hemi_rh, surf_file_rh, hemi='rh', N_components_plot=N_components_plot, color_ranges=[0.1]*N_components_plot)

# Plot all components in two rows of subplots: first row lateral, second row medial
num_pcas = N_components_plot  # or len(plotters_rh), as appropriate
fig, axs = plt.subplots(2, num_pcas, figsize=(2*num_pcas, 4))  # two rows: lateral, medial

for i in range(num_pcas):
    # First row: lateral views
    ax_lateral = axs[0, i] if num_pcas > 1 else axs[0]
    ax_lateral.imshow(plotters_rh[i][0])  # 0 is lateral
    ax_lateral.set_title(f'Component {i+1}', fontsize=11)
    ax_lateral.axis('off')
    
    # Second row: medial views
    ax_medial = axs[1, i] if num_pcas > 1 else axs[1]
    ax_medial.imshow(plotters_rh[i][1])  # 1 is medial
    ax_medial.axis('off')

plt.tight_layout()
plt.show()



#save tmp_lh and tmp_rh as the template gradient maps
if run_hcp_data:
    #save the template gradient maps
    os.makedirs(hcp_save_subpath, exist_ok=True)
    np.save(f'{hcp_save_subpath}/hcp_template_grad_lh_sparsity_{G_sparsity}.npy', tmp_lh)
    np.save(f'{hcp_save_subpath}/hcp_template_grad_rh_sparsity_{G_sparsity}.npy', tmp_rh)
    #save images of the template gradient maps
    grad_this[0] = tmp_lh
    plotters_lh = create_hemisphere_plots(grad_this, surf_file_lh, hemi='lh', N_components_plot=N_components_plot, color_ranges=[0.1]*N_components_plot)
    grad_this[0] = tmp_rh
    plotters_rh = create_hemisphere_plots(grad_this, surf_file_rh, hemi='rh', N_components_plot=N_components_plot, color_ranges=[0.1]*N_components_plot)
    # Plot all PCAs (components) in a single row of subplots
    num_pcas = N_components_plot  # or len(plotters_lh), as appropriate
    fig, axs = plt.subplots(1, num_pcas, figsize=(2*num_pcas, 2))
    for i in range(num_pcas):
        ax = axs[i] if num_pcas > 1 else axs  # handle when only one axis
        ax.imshow(plotters_lh[i][0])
        ax.set_title(f'Component {i+1}', fontsize=11)
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(f'{hcp_save_subpath}/hcp_template_grad_lh_sparsity_{G_sparsity}.png')
    plt.show()
    #save right hemispere gradient maps
    fig, axs = plt.subplots(1, num_pcas, figsize=(2*num_pcas, 2))
    for i in range(num_pcas):
        ax = axs[i] if num_pcas > 1 else axs  # handle when only one axis
        ax.imshow(plotters_rh[i][0])
        ax.set_title(f'Component {i+1}', fontsize=11)
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(f'{hcp_save_subpath}/hcp_template_grad_rh_sparsity_{G_sparsity}.png')
    plt.show()
    plot_eigenvalues(gm_all_lh, filename=f'hcp_template_grad_lh_eigenvalues_sparsity{G_sparsity}',save=True,SAVEFOLDER=hcp_save_subpath)
    plot_eigenvalues(gm_all_rh, filename=f'hcp_template_grad_rh_eigenvalues_sparsity{G_sparsity}',save=True,SAVEFOLDER=hcp_save_subpath)
    #close all plots
    plt.close('all')



In [None]:
# Distance vs. structural covariance for one hemisphere: one 2-D histogram
# panel per metric (shared colour range), each overlaid with whichever of a
# linear or exponential fit to the binned means has the lower AIC.
base_dir = '/Users/dennis.jungchildmind.org/Downloads/HCP_S1200_Atlas_Z4_pkXDZ/schaefer400_gd'
surf_type = 'mid'  # 'mid','pial','white'
dist_type = 'gd'   # 'ed','gd'
distance_label = 'Euclidean Distance' if dist_type == 'ed' else 'Geodesic Distance'
cmap = 'gray_r'
font_size = 20
from scipy.stats import pearsonr, spearmanr
from sklearn.linear_model import LinearRegression
import scipy.stats as stats
from scipy.optimize import curve_fit
import warnings

# ---- Set hemisphere here ('lh' or 'rh')
hemisphere = 'lh'  # <-- change this to 'rh' for right hemisphere
if hemisphere not in ('lh', 'rh'):
    # Validate before touching the disk so a typo gives a clear error.
    raise ValueError("hemisphere must be 'lh' or 'rh'")
dist_matrix = np.load(f'{base_dir}/{surf_type}_{hemisphere}/{dist_type}_matrix.npy')

# (label, matrix, pretty_label) triples to plot, per hemisphere.
metric_list_lh = [
    ('Total', total_hcp_corr_lh[0], 'Total_HCP'),
    ('Total', total_corr_lh[0], 'Total'),
    ('Supra', supra_corr_lh[0], 'Supra'),
    ('Infra', infra_corr_lh[0], 'Infra'),
    ('Ratio Supra', ratio_supra_corr_lh[0], 'Ratio Supra'),
    ('Ratio Infra', ratio_infra_corr_lh[0], 'Ratio Infra'),
]
metric_list_rh = [
    ('Total', total_hcp_corr_rh[0], 'Total_HCP'),
    ('Total', total_corr_rh[0], 'Total'),
    ('Supra', supra_corr_rh[0], 'Supra'),
    ('Infra', infra_corr_rh[0], 'Infra'),
    ('Ratio Supra', ratio_supra_corr_rh[0], 'Ratio Supra'),
    ('Ratio Infra', ratio_infra_corr_rh[0], 'Ratio Infra')
]

metric_list = metric_list_lh if hemisphere == 'lh' else metric_list_rh
n_metrics = len(metric_list)


def exp_func(x, a, b, c):
    """Exponential decay with offset: y = a * exp(-b * x) + c."""
    return a * np.exp(-b * x) + c


def lin_func(x, m, c):
    """Straight line: y = m * x + c."""
    return m * x + c


def _aic_from_fit(y_true, y_pred, n_params):
    """Return the least-squares AIC, 2k + n*ln(RSS/n).

    np.inf is returned when the residual sum of squares is NaN or ~0, where
    the logarithm is undefined (same convention as the original cell).
    """
    rss = np.sum((y_true - y_pred) ** 2)
    if np.isnan(rss) or rss < 1e-10:
        return np.inf
    n_obs = len(y_true)
    return 2 * n_params + n_obs * np.log(rss / n_obs)


def _select_distance_fit(x_fit, y_fit, metric_name):
    """Fit exp_func and lin_func to the binned means and pick one by AIC.

    Returns:
        (label, colour, prediction) for the chosen model, or None when both
        fits failed. The exponential is chosen only when its AIC is lower by
        more than 2; otherwise the simpler linear model wins.
    """
    # --- exponential fit ---
    exp_success = True
    try:
        y_range = np.max(y_fit) - np.min(y_fit)
        x_range = np.max(x_fit) - np.min(x_fit)
        # Initial guesses: amplitude = data range, rate = 1 / x-span, offset = min.
        p0_exp = [y_range, 1.0/x_range if x_range > 0 else 0.01, np.min(y_fit)]
        # Only the decay rate is constrained (b >= 0).
        bounds_exp = ([-np.inf, 0, -np.inf], [np.inf, np.inf, np.inf])
        popt_exp, _ = curve_fit(exp_func, x_fit, y_fit, p0=p0_exp,
                                bounds=bounds_exp, maxfev=10000)
        y_pred_exp = exp_func(x_fit, *popt_exp)
        aic_exp = _aic_from_fit(y_fit, y_pred_exp, 3)
    except Exception as e:  # best effort: a failed fit must not abort the figure
        exp_success = False
        aic_exp = np.inf
        y_pred_exp = np.full_like(y_fit, np.nan)
        print(f"Exponential fit failed for metric {metric_name}: {e}")

    # --- linear fit ---
    lin_success = True
    try:
        popt_lin, _ = curve_fit(lin_func, x_fit, y_fit)
        y_pred_lin = lin_func(x_fit, *popt_lin)
        aic_lin = _aic_from_fit(y_fit, y_pred_lin, 2)
    except Exception as e:  # best effort, see above
        lin_success = False
        aic_lin = np.inf
        y_pred_lin = np.full_like(y_fit, np.nan)
        print(f"Linear fit failed for metric {metric_name}: {e}")

    if not exp_success and not lin_success:
        print(f"Both fits failed for metric {metric_name}")
        return None
    if not exp_success:
        return 'Linear fit', 'blue', y_pred_lin
    if not lin_success:
        return 'Exp fit', 'red', y_pred_exp

    # inf - inf yields NaN (both AICs undefined); every comparison below is
    # then False and the linear fallback is used, so silence the warning.
    with np.errstate(invalid='ignore'):
        aic_diff = np.abs(aic_exp - aic_lin)
    if aic_exp < aic_lin and aic_diff > 2:
        return f'Exp fit (AIC={aic_exp:.1f})', 'red', y_pred_exp
    if aic_lin < aic_exp and aic_diff > 2:
        return f'Linear fit (AIC={aic_lin:.1f})', 'blue', y_pred_lin
    # AIC difference too small: prefer the simpler (linear) model.
    return f'Linear fit (AIC~{aic_lin:.1f})', 'blue', y_pred_lin


fig, axs = plt.subplots(1, n_metrics, figsize=(5 * n_metrics, 6), squeeze=False)

# --- COLLECT ALL HISTOGRAMS TO FIND A SHARED COLOUR RANGE ---
all_hists = []
xedges_all = []
yedges_all = []
structural_cov_flat_norm_list = []
dist_matrix_flat_list = []

# The distance mask and histogram range do not depend on the metric.
mask = dist_matrix != 0
dist_matrix_flat = dist_matrix[mask]
range2d = [[np.min(dist_matrix_flat), np.max(dist_matrix_flat)], [-1, 1]]

for metric_name, structural_covariance, pretty_label in metric_list:
    structural_covariance_flat = structural_covariance[mask]

    # Rescale by the largest magnitude when values exceed 1, then clip to [-1, 1].
    peak = np.nanmax(np.abs(structural_covariance_flat))
    if peak > 1.0:
        structural_covariance_flat_norm = np.clip(structural_covariance_flat / peak, -1, 1)
    else:
        structural_covariance_flat_norm = np.clip(structural_covariance_flat, -1, 1)

    # Saved for the plotting loop below.
    dist_matrix_flat_list.append(dist_matrix_flat)
    structural_cov_flat_norm_list.append(structural_covariance_flat_norm)

    hist, xedges, yedges = np.histogram2d(
        dist_matrix_flat, structural_covariance_flat_norm, bins=[100, 50], range=range2d
    )
    all_hists.append(hist)
    xedges_all.append(xedges)
    yedges_all.append(yedges)

# Shared colour limits: 1st / 99th percentile of all bin counts.
all_hist_values = np.concatenate([h.flatten() for h in all_hists])
vmin = np.percentile(all_hist_values, 1)
vmax = np.percentile(all_hist_values, 99)

# --- PLOT WITH SHARED COLORBAR RANGE ---
for i, (metric_name, structural_covariance, pretty_label) in enumerate(metric_list):
    ax = axs[0, i]
    X, Y = np.meshgrid(xedges_all[i], yedges_all[i])
    pcm = ax.pcolormesh(X, Y, all_hists[i].T, cmap=cmap, vmin=vmin, vmax=vmax)
    cb = plt.colorbar(pcm, ax=ax)

    # Labels and formatting come first so that a panel whose fit is skipped
    # below is still titled and scaled like the others.
    if i == 0:
        cb.set_label('Counts', fontsize=font_size, rotation=270, labelpad=18)
        ax.set_ylabel(f'{hemisphere.upper()}\n\nStructural Covariance', fontsize=font_size)
    ax.set_xlabel(distance_label, fontsize=font_size)
    ax.set_title(pretty_label, fontsize=font_size)
    ax.set_ylim(0, 1)
    ax.tick_params(axis='both', labelsize=font_size)
    cb.ax.tick_params(labelsize=font_size)

    # Drop NaN / Inf pairs before fitting.
    dist_flat = dist_matrix_flat_list[i]
    sc_flat = structural_cov_flat_norm_list[i]
    valid_mask = np.isfinite(dist_flat) & np.isfinite(sc_flat)
    dist_flat = dist_flat[valid_mask]
    sc_flat = sc_flat[valid_mask]

    # Bin by distance (39 equal-width bins) and fit the per-bin means.
    bins = np.linspace(np.nanmin(dist_flat), np.nanmax(dist_flat), 40)
    bin_centers = (bins[:-1] + bins[1:]) / 2
    # np.digitize uses half-open bins, so the largest distance would get index
    # len(bins) and silently drop out; clip it into the last real bin.
    digitized = np.clip(np.digitize(dist_flat, bins), 1, len(bins) - 1)
    means = np.array([
        sc_flat[digitized == k].mean() if np.any(digitized == k) else np.nan
        for k in range(1, len(bins))
    ])
    valid = np.isfinite(means)
    x_fit = bin_centers[valid]
    y_fit = means[valid]
    n = len(y_fit)

    if n < 4:  # Need at least 4 points to fit
        print(f"Not enough valid points for metric {metric_name}")
        continue

    chosen = _select_distance_fit(x_fit, y_fit, metric_name)
    if chosen is None:
        continue
    fit_label, fit_color, y_pred = chosen

    # Plot the chosen fit
    ax.plot(x_fit, y_pred, color=fit_color, lw=3, label=fit_label, zorder=6)

    # Spearman's rho and its p-value use ALL paired data (not the binned means).
    rho, pval = spearmanr(dist_flat, sc_flat, nan_policy='omit')

    # R² of the chosen fit on the binned means.
    ss_res = np.sum((y_fit - y_pred) ** 2)
    ss_tot = np.sum((y_fit - np.mean(y_fit)) ** 2)
    r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0

    # Statistics box in the lower-left corner of the panel.
    legend_str = f"$R^2$={r_squared:.3f}\n$\\rho$={rho:.3f}\np={pval:.2e}"
    ax.text(0.05, 0.05, legend_str, transform=ax.transAxes,
            fontsize=font_size-2, verticalalignment='bottom',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.2))

plt.tight_layout()
plt.show()

In [None]:
# Sparsity sweep: for every sparsity level and every metric, compute gradients
# for both hemispheres, align them, and store the results in per-metric
# dictionaries keyed by the index of the sparsity level.

#gradient analysis parameters
G_sparsity = [0.9,0.8,0.7,0.6,0.5,0.4,0.3,0.2,0.1]  # sparsity levels swept below (brainspace default is 0.9)
G_dimension_reduction = 'pca'
G_kernel = 'normalized_angle'

lh_surf = '/Users/dennis.jungchildmind.org/Downloads/HCP_S1200_Atlas_Z4_pkXDZ/S1200.L.white_MSMAll.32k_fs_LR.surf.gii'
rh_surf = '/Users/dennis.jungchildmind.org/Downloads/HCP_S1200_Atlas_Z4_pkXDZ/S1200.R.white_MSMAll.32k_fs_LR.surf.gii'


# Metric keys; each one maps to the module-level variables
# `<metric>_corr_lh` and `<metric>_corr_rh`.
# 'total' must come before the metrics that are aligned to it below.
metrics = ['total_hcp','total', 'supra', 'infra', 'ratio_supra', 'ratio_infra']
N_components = 10
N_components_plot = 3

# parameters for alignment
align_across_hemi = False
align_to_total = True
# Kept True: per the original author's note, computing the reflection and
# reordering needs rotation enabled.
do_rotation = True
do_reflection = True

# No additional mask indices are used.
mask_index_lh = np.nan
mask_index_rh = np.nan

# Hemisphere-aligned gradients (lh_data_all / rh_data_all) and gradients
# additionally aligned to 'total' (lh_aligned_all / rh_aligned_all).
lh_data_all = {m: {} for m in metrics}
rh_data_all = {m: {} for m in metrics}
lh_aligned_all = {m: {} for m in metrics}
rh_aligned_all = {m: {} for m in metrics}
# transformation matrices returned by align_gradients
R_lh_all = {m: {} for m in metrics}
R_rh_all = {m: {} for m in metrics}
R_lh_aligned_all = {m: {} for m in metrics}
R_rh_aligned_all = {m: {} for m in metrics}

# Loop-invariant values, hoisted out of the sweep.
atlasDat = atlas  # not read in this cell; retained in case other cells use it
mask_lh = atlas_data_lh != 0
mask_rh = atlas_data_rh != 0

for i,gs in enumerate(G_sparsity):
    print('G_sparsity:',gs)
    for metric in metrics:
        print(metric)
        # Look the matrices up by name (replaces eval on a formatted string).
        data2plot_lh = globals()[f"{metric}_corr_lh"][0]
        data2plot_rh = globals()[f"{metric}_corr_rh"][0]

        # calculate gradients using brainspace toolbox; `gs` sets the sparsity
        gm_all_lh, grad_all_lh = calculate_gradients_from_brainspace(data2plot_lh, mask_index_lh, atlas_data_lh, mask_lh, N_components, gs)
        gm_all_rh, grad_all_rh = calculate_gradients_from_brainspace(data2plot_rh, mask_index_rh, atlas_data_rh, mask_rh, N_components, gs)

        source_lh = np.array(grad_all_lh).T
        source_rh = np.array(grad_all_rh).T

        # LH is always aligned to itself (so the data format is the same);
        # RH is aligned to LH only when align_across_hemi is set.
        if align_across_hemi:
            print('aligned across hemi')
            rh_reference = source_lh
        else:
            print('not aligned across hemi')
            rh_reference = source_rh
        grad_all_aligned_hemi_lh, R_lh = align_gradients(source_lh, source_lh, reflection=do_reflection, rotation=do_rotation)
        grad_all_aligned_hemi_rh, R_rh = align_gradients(source_rh, rh_reference, reflection=do_reflection, rotation=do_rotation)

        # Defaults for align_to_total == False: fall back to the
        # hemisphere-aligned gradients instead of storing empty lists, and
        # leave the to-total transforms empty because none was computed.
        grad_all_aligned_lh = grad_all_aligned_hemi_lh
        grad_all_aligned_rh = grad_all_aligned_hemi_rh
        R_lh_hcp = []
        R_rh_hcp = []

        if align_to_total:
            if metric != 'total' and metric != 'total_hcp':
                print('aligning to total')
                # NOTE(review): the reference is always the 'total' gradients of
                # the FIRST sparsity level ([0]), not of level i - confirm intended.
                _ , R_lh = align_gradients(grad_all_aligned_hemi_lh, lh_data_all['total'][0],reflection=do_reflection, rotation=do_rotation)
                _ , R_rh = align_gradients(grad_all_aligned_hemi_rh, rh_data_all['total'][0],reflection=do_reflection, rotation=do_rotation)

                # decompose the transformation matrix (R) into a permutation index and signs 102725 DJ
                perm_idx_lh, signs_lh = decompose_to_permutation_and_signs(R_lh)
                perm_idx_rh, signs_rh = decompose_to_permutation_and_signs(R_rh)

                # reorder the components and flip their signs 102725 DJ
                grad_all_aligned_lh = grad_all_aligned_hemi_lh[:,perm_idx_lh]*signs_lh
                grad_all_aligned_rh = grad_all_aligned_hemi_rh[:,perm_idx_rh]*signs_rh
            else:
                # 'total' / 'total_hcp' are aligned to themselves.
                grad_all_aligned_lh, R_lh = align_gradients(grad_all_aligned_hemi_lh, grad_all_aligned_hemi_lh,reflection=do_reflection, rotation=do_rotation)
                grad_all_aligned_rh, R_rh = align_gradients(grad_all_aligned_hemi_rh, grad_all_aligned_hemi_rh,reflection=do_reflection, rotation=do_rotation)

            # Record the to-total transform; previously these were never
            # assigned, so R_*_aligned_all only ever held empty lists.
            R_lh_hcp = R_lh
            R_rh_hcp = R_rh

        # Store aligned to "total" (sign flip + reordering of components)
        lh_aligned_all[metric][i] = grad_all_aligned_lh
        rh_aligned_all[metric][i] = grad_all_aligned_rh
        R_lh_aligned_all[metric][i] = R_lh_hcp
        R_rh_aligned_all[metric][i] = R_rh_hcp

        # Hemisphere-aligned gradients.
        # NOTE(review): R_lh / R_rh hold the most recent transform, i.e. the
        # to-total one when align_to_total is set (unchanged from the original).
        lh_data_all[metric][i] = grad_all_aligned_hemi_lh
        rh_data_all[metric][i] = grad_all_aligned_hemi_rh
        R_lh_all[metric][i] = R_lh
        R_rh_all[metric][i] = R_rh



    

In [None]:
def plot_component_correlation(
    data,
    n_components=3,
    sparsity_index=0,#first one in G_sparsity
    vmin=-0.8,
    vmax=0.8,
    figsize_per_component=(3,3),
    colormap='RdBu_r'
):
    """
    Plot, for each of the first ``n_components`` components, the correlation
    matrix of that component across all metrics in ``data``.

    One subplot per component is drawn in a single row. The diagonal of every
    correlation matrix is set to NaN so the self-correlations are not coloured.

    Args:
        data (dict): Maps metric name -> container indexable by sparsity index;
                     each entry is a 2-D array whose columns are components.
        n_components (int): Number of leading components to plot.
        sparsity_index (int): Which sparsity entry of every metric to use.
        vmin (float): Minimum value for color scale.
        vmax (float): Maximum value for color scale.
        figsize_per_component (tuple): Size of each subplot (width, height).
        colormap (str): matplotlib colormap.

    Returns:
        None

    Raises:
        IndexError: If a requested component does not exist in one of the arrays.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    def label_format(key):
        """Format keys for LaTeX axis labels: main_subscript"""
        if '_' in key:
            main, sub = key.split('_', 1)
            return fr"${main.capitalize()}_{{\mathrm{{{sub}}}}}$"
        else:
            return fr"${key.capitalize()}$"

    keys = list(data.keys())
    labels = [label_format(key) for key in keys]
    n_metrics = len(keys)

    # Report the shape of the array that is actually plotted (the selected
    # sparsity), not always the one stored at index 0.
    sample_shape = np.shape(data[keys[0]][sparsity_index])
    print(f"Example aligned gradients array shape (should be n_vertices x n_components): {sample_shape}")

    fig, axes = plt.subplots(
        1, n_components,
        figsize=(n_components * figsize_per_component[0], figsize_per_component[1]),
        squeeze=False
    )

    # The same vmin/vmax is used for every panel so colours are comparable.
    for component in range(n_components):
        # One column per metric, holding this component's values.
        try:
            comp_matrix = np.column_stack([
                data[key][sparsity_index][:, component] for key in keys
            ])
        except IndexError:
            print(f"Component {component} is out of bounds for available gradient array shape {sample_shape}.")
            raise

        print(f"Component matrix shape for component {component}: {comp_matrix.shape}")

        # Metric-by-metric correlation; blank the diagonal.
        corrmat = np.corrcoef(comp_matrix.T)
        np.fill_diagonal(corrmat, np.nan)

        ax = axes[0, component]
        im = ax.imshow(corrmat, cmap=colormap, vmin=vmin, vmax=vmax)
        ax.set_xticks(np.arange(n_metrics))
        ax.set_yticks(np.arange(n_metrics))
        ax.set_xticklabels(labels, rotation=90, ha='center')
        ax.set_yticklabels(labels)
        ax.set_title(f'Component {component+1}')

        # Every subplot gets its own colorbar.
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()


# Example usage: locate the position of the 0.9 sparsity level in G_sparsity and
# plot the first three components of `lh_aligned_all` at that level.
gs_sparsity = 0.9
matching_positions = [pos for pos, level in enumerate(G_sparsity) if level == gs_sparsity]
sparsity_index = matching_positions[0]
plot_component_correlation(
    lh_aligned_all,
    n_components=3,
    sparsity_index=sparsity_index,
    colormap='RdBu_r',
)

In [None]:
# This version takes ONE component of ONE data type, collects it across all
# sparsity levels, and plots the correlations in a single figure.
def plot_component_correlation(
    data,
    layer_type='total',
    component = 0,
    G_sparsity=G_sparsity,
    vmin=-0.9,
    vmax=0.9,
    figsize_per_component=(3,3),
    colormap='RdBu_r'
):
    """
    Correlate one gradient component across sparsity levels and plot the matrix.

    The first column of the compared set is ``data['total_hcp'][0][:, component]``
    (the HCP gradients at the first sparsity level). One further column is added
    per entry of ``G_sparsity``: the gradients of ``layer_type`` at that level,
    reordered and sign-flipped to match the first level of the same layer type.
    The diagonal of the resulting correlation matrix is set to NaN.

    Relies on the module-level ``align_gradients``,
    ``decompose_to_permutation_and_signs``, ``do_reflection`` and ``do_rotation``.

    Args:
        data (dict): Maps metric name -> container indexable by sparsity index;
                     each entry is a 2-D array whose columns are components.
                     Must contain the keys ``'total_hcp'`` and ``layer_type``.
        layer_type (str): Metric whose sparsity levels are compared.
        component (int): Zero-based component (column) to compare.
        G_sparsity (sequence): Sparsity levels; its length sets how many entries
                     of ``data[layer_type]`` are used, its values label the ticks.
        vmin (float): Minimum value for color scale.
        vmax (float): Maximum value for color scale.
        figsize_per_component (tuple): Figure size (width, height).
        colormap (str): matplotlib colormap.

    Returns:
        None

    Raises:
        IndexError: If ``component`` does not exist in one of the arrays.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(
        1, 1,
        figsize=( figsize_per_component[0], figsize_per_component[1])
    )

    layer_data = data[layer_type]
    # Shape used in the error message below (the handler previously referred
    # to an undefined name and would itself have raised NameError).
    sample_shape = np.shape(layer_data[0])

    # Column 0: HCP data at the first sparsity level.
    columns = [data['total_hcp'][0][:,component]]

    for sparsity_index in range(len(G_sparsity)):
        try:
            # Align this sparsity level to the first one of the same layer type.
            _ , R= align_gradients(layer_data[sparsity_index], layer_data[0],reflection=do_reflection, rotation=do_rotation)

            # decompose the transformation matrix (R) into a permutation index and signs 102725 DJ
            perm_idx, signs= decompose_to_permutation_and_signs(R)

            # reorder the components and flip their signs 102725 DJ
            grad_aligned_ = layer_data[sparsity_index][:,perm_idx]*signs

            columns.append(grad_aligned_[:,component])
        except IndexError:
            print(f"Component {component} is out of bounds for available gradient arrays. Sample shape: {sample_shape}")
            raise

    comp_matrix = np.column_stack(columns)

    # Level-by-level correlation; blank the diagonal.
    corrmat = np.corrcoef(comp_matrix.T)
    np.fill_diagonal(corrmat, np.nan)

    im = ax.imshow(corrmat, cmap=colormap, vmin=vmin, vmax=vmax)
    ax.set_xticks(np.arange(len(G_sparsity)+1))
    ax.set_yticks(np.arange(len(G_sparsity)+1))
    label_name_list = ['HCP 0.9'] + [f"EX {g}" for g in G_sparsity]
    ax.set_xticklabels(label_name_list, rotation=90, ha='center')
    ax.set_yticklabels(label_name_list)

    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    # A single title carries both pieces of information; setting the layer
    # name with plt.title afterwards used to replace the component title.
    ax.set_title(f'{layer_type.capitalize()} - Component {component+1}')
    plt.tight_layout()
    plt.show()


# Example usage: for the chosen layer type, draw one figure per component
# (components 0, 1 and 2) comparing the left-hemisphere gradients across sparsity levels.
LAYER_TYPE = 'total'
for component_index in (0, 1, 2):
    plot_component_correlation(
        lh_aligned_all,
        layer_type=LAYER_TYPE,
        component=component_index,
        G_sparsity=G_sparsity,
        colormap='RdBu_r',
    )

In [None]:
# Align every metric's gradients to the 'total' gradients of one hemisphere,
# show the transformation matrices, and build reordered / sign-flipped copies.
mask_lh = atlas_data_lh != 0
mask_rh = atlas_data_rh != 0

hemi = 'lh'# lh or rh
N_components = 10

do_rotation = True
do_reflection = True

# G_sparsity is a single number in the template cell but a list in the
# sparsity-sweep cell. Every other call site passes one value to
# calculate_gradients_from_brainspace, so use the first entry of a sequence.
# NOTE(review): confirm the first level is the one intended here.
sparsity_level = float(np.atleast_1d(G_sparsity)[0])

# Pick the matrices, atlas and mask of the requested hemisphere
# (anything other than 'lh' selects the right hemisphere).
if hemi == 'lh':
    total_corr = total_corr_lh[0]
    supra_corr = supra_corr_lh[0]
    infra_corr = infra_corr_lh[0]
    relative_corr = relative_corr_lh[0]
    ratio_supra_corr = ratio_supra_corr_lh[0]
    ratio_infra_corr = ratio_infra_corr_lh[0]
    mask_index = mask_index_lh
    atlas_data = atlas_data_lh
    mask = mask_lh
else:
    total_corr = total_corr_rh[0]
    supra_corr = supra_corr_rh[0]
    infra_corr = infra_corr_rh[0]
    relative_corr = relative_corr_rh[0]
    ratio_supra_corr = ratio_supra_corr_rh[0]
    ratio_infra_corr = ratio_infra_corr_rh[0]
    mask_index = mask_index_rh
    atlas_data = atlas_data_rh
    mask = mask_rh

# Gradients per metric.
gm_total, grad_total = calculate_gradients_from_brainspace(total_corr, mask_index, atlas_data, mask, N_components, sparsity_level)
gm_supra, grad_supra = calculate_gradients_from_brainspace(supra_corr, mask_index, atlas_data, mask, N_components, sparsity_level)
gm_infra, grad_infra = calculate_gradients_from_brainspace(infra_corr, mask_index, atlas_data, mask, N_components, sparsity_level)
gm_relative, grad_relative = calculate_gradients_from_brainspace(relative_corr, mask_index, atlas_data, mask, N_components, sparsity_level)
gm_ratio_supra, grad_ratio_supra  = calculate_gradients_from_brainspace(ratio_supra_corr, mask_index, atlas_data, mask, N_components, sparsity_level)
gm_ratio_infra, grad_ratio_infra = calculate_gradients_from_brainspace(ratio_infra_corr, mask_index, atlas_data, mask, N_components, sparsity_level)

# convert to numpy arrays and transpose
supra_grad = np.array(grad_supra).T
infra_grad = np.array(grad_infra).T
total_grad = np.array(grad_total).T
relative_grad = np.array(grad_relative).T
ratio_supra_grad = np.array(grad_ratio_supra).T
ratio_infra_grad = np.array(grad_ratio_infra).T

# Align each metric (first argument) to the 'total' gradients (second argument).
# Per the original note, R is the matrix in (A @ R) - B - presumably A is the
# source and B the reference; confirm against align_gradients.
_, R_total = align_gradients(total_grad, total_grad, reflection=do_reflection, rotation=do_rotation)
_, R_supra = align_gradients(supra_grad, total_grad, reflection=do_reflection, rotation=do_rotation)
_, R_infra = align_gradients(infra_grad, total_grad, reflection=do_reflection, rotation=do_rotation)
_, R_relative= align_gradients(relative_grad, total_grad, reflection=do_reflection, rotation=do_rotation)
_, R_ratio_supra = align_gradients(ratio_supra_grad, total_grad, reflection=do_reflection, rotation=do_rotation)
_, R_ratio_infra = align_gradients(ratio_infra_grad, total_grad, reflection=do_reflection, rotation=do_rotation)

R_matrices = [
    ("Supra", R_supra),
    ("Infra", R_infra),
    ("Relative", R_relative),
    ("Ratio Supra", R_ratio_supra),
    ("Ratio Infra", R_ratio_infra)
]

num_R = len(R_matrices)

# One annotated heat map per transformation matrix: 1 row, num_R columns.
fig, axs = plt.subplots(1, num_R, figsize=(4*num_R, 4))
if num_R == 1:
    axs = [axs]
for i, (ax, (title, R_mat)) in enumerate(zip(axs, R_matrices)):
    im = ax.imshow(R_mat, vmin=-1, vmax=1, cmap='bwr')
    n_comp_x = R_mat.shape[1]
    n_comp_y = R_mat.shape[0]
    ax.set_xlabel('Target PC', fontsize=15)
    ax.set_ylabel('Source PC', fontsize=15)
    ax.set_title(f'{title} alignment\nrotation/reflection', fontsize=15)
    # Tick labels count components from 1.
    ax.set_xticks(np.arange(n_comp_x))
    ax.set_xticklabels(np.arange(1, n_comp_x+1), fontsize=11)
    ax.set_yticks(np.arange(n_comp_y))
    ax.set_yticklabels(np.arange(1, n_comp_y+1), fontsize=11)
    # Write each value into its cell; white text on strongly coloured cells.
    for y in range(n_comp_y):
        for x in range(n_comp_x):
            val = R_mat[y, x]
            text_color = 'w' if abs(val) > 0.5 else 'black'
            ax.text(x, y, f"{val:.2f}", ha="center", va="center", color=text_color, fontsize=8)
    # Only the last subplot gets a colorbar, for a cleaner look.
    if i == num_R - 1:
        cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
plt.tight_layout()
plt.show()

# Reduce every transformation matrix to a component order plus a sign per component.
perm_idx_supra, signs_supra = decompose_to_permutation_and_signs(R_supra)
perm_idx_infra, signs_infra = decompose_to_permutation_and_signs(R_infra)
perm_idx_relative, signs_relative = decompose_to_permutation_and_signs(R_relative)
perm_idx_ratio_supra, signs_ratio_supra = decompose_to_permutation_and_signs(R_ratio_supra)
perm_idx_ratio_infra, signs_ratio_infra = decompose_to_permutation_and_signs(R_ratio_infra)

# Component order is printed 1-based.
print('supra',  perm_idx_supra+1, signs_supra)
print('infra', perm_idx_infra+1, signs_infra)
print('relative', perm_idx_relative+1, signs_relative)
print('ratio_supra', perm_idx_ratio_supra+1, signs_ratio_supra)
print('ratio_infra', perm_idx_ratio_infra+1, signs_ratio_infra)

# Reorder the columns and flip signs so the components line up with 'total'.
source_aligned_supra = supra_grad[:,perm_idx_supra]*signs_supra
source_aligned_infra = infra_grad[:,perm_idx_infra]*signs_infra
source_aligned_relative = relative_grad[:,perm_idx_relative]*signs_relative
source_aligned_ratio_supra = ratio_supra_grad[:,perm_idx_ratio_supra]*signs_ratio_supra
source_aligned_ratio_infra = ratio_infra_grad[:,perm_idx_ratio_infra]*signs_ratio_infra


In [None]:
# Plot the first `num_pcas` components of every metric for the selected
# hemisphere in one large figure (rows = metrics, columns = components) and,
# for ex vivo runs, save the aligned gradients to disk.
if hemi == 'lh':
    surf_file = lh_surf
else:
    surf_file = rh_surf

num_pcas = 3

# (gradients, per-column component labels, metric name). For the aligned sets
# the label is the 1-based index of the source component shown in that column.
grads_and_labels = [
    (total_grad, [f'{i+1}' for i in range(num_pcas)], 'total'),
    (source_aligned_supra, [f'{perm_idx_supra[i]+1}' for i in range(num_pcas)], 'supra'),
    (source_aligned_infra, [f'{perm_idx_infra[i]+1}' for i in range(num_pcas)], 'infra'),
    (source_aligned_relative, [f'{perm_idx_relative[i]+1}' for i in range(num_pcas)], 'relative'),
    (source_aligned_ratio_supra, [f'{perm_idx_ratio_supra[i]+1}' for i in range(num_pcas)], 'supra/total'),
    (source_aligned_ratio_infra, [f'{perm_idx_ratio_infra[i]+1}' for i in range(num_pcas)], 'infra/total')
]

# Save the grads_and_labels arrays and metadata using numpy
import numpy as np
if not run_hcp_data:
    os.makedirs(exvivo_save_subpath, exist_ok=True)
    # grads_and_labels is a list of (array, label_list, metric);
    # arrays and metadata are stored under separate keys.
    arrays = [item[0] for item in grads_and_labels]
    labels = [item[1] for item in grads_and_labels]
    metrics = [item[2] for item in grads_and_labels]
    np.savez(
        f'{exvivo_save_subpath}/grad_algnd2_total_grad_{hemi}_sparsity_{G_sparsity}.npz',
        arrays=arrays, labels=labels, metrics=metrics
    )


num_sets = len(grads_and_labels)  # rows
n_cols = num_pcas                # columns

# squeeze=False keeps `axs` two-dimensional for any row/column count, so the
# axs[row, col] indexing below always works.
fig, axs = plt.subplots(num_sets, n_cols, figsize=(2*n_cols, 2*num_sets), squeeze=False)
center_col = n_cols // 2

for row_idx, (grad_set, label_list, metric) in enumerate(grads_and_labels):
    grad_this = {0: grad_set}
    # Render exactly the components that are displayed: using the unrelated
    # global N_components_plot could render too few (IndexError below) or
    # needlessly many.
    plotters = create_hemisphere_plots(
        grad_this, surf_file, hemi=hemi,
        N_components_plot=num_pcas,
        color_ranges=[0.5] * num_pcas
    )
    for col_idx in range(n_cols):
        ax = axs[row_idx, col_idx]
        ax.imshow(plotters[col_idx][0])  # index 0 is the lateral screenshot
        title = f'Component {label_list[col_idx]}'
        if col_idx == center_col:
            # The metric name goes on its own line above the component label
            # of the centre column. Setting it as a separate title on the same
            # axes was overwritten by the component title.
            title = f'{metric}\n{title}'
        ax.set_title(title, fontsize=12)
        ax.axis('off')

plt.tight_layout()
plt.show()


# Reload grads_and_labels from the .npz archive for this hemisphere/sparsity
# (ex vivo runs only).
if not run_hcp_data:
    # Read from the configured exvivo_save_subpath instead of re-hard-coding the
    # folder name here: overriding it would silently change the setting for every
    # later cell and could point at a different folder than the one saved to.
    npz_path = f'{exvivo_save_subpath}/grad_algnd2_total_grad_{hemi}_sparsity_{G_sparsity}.npz'
    # NOTE: allow_pickle=True can execute pickled objects; only load archives
    # that this notebook wrote itself.
    with np.load(npz_path, allow_pickle=True) as data:
        arrays = data['arrays']
        labels = data['labels']
        metrics = data['metrics']
        grads_and_labels = list(zip(arrays, labels, metrics))

# Print the 'metric' of the second grads_and_labels entry (index 1)
if 'grads_and_labels' in locals() and len(grads_and_labels) > 1:
    print(grads_and_labels[1][2])
else:
    print("grads_and_labels is not properly loaded or too short.")



# Spatial similarity between ex vivo and HCP gradient maps. Only runs when the
# HCP template gradients have already been saved in the HCP folder.
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from enigmatoolbox.permutation_testing import spin_test, shuf_test

# --- Parameters block: set all user-controlled parameters here ---
hcp_grad_file = f'{hcp_save_subpath}/hcp_template_grad_{hemi}_sparsity_{G_sparsity}.npy'
# Which metrics to draw, and in which order (any subset of the names used below).
metric_labels = ['Total', 'Infra', 'Supra', 'Relative', 'Infra/Total','Supra/Total']  # Can be any length
bar_width = 0.12
colormap_name = 'tab20'
ylim_min = -0.5
ylim_max = 0.5
font_size = 16
ncomp = 3 #number of components to plot ; if None, all available components are plotted

# ---------------------------------------------------------------

# Load HCP gradient maps and calculate spatial similarity
if os.path.exists(hcp_grad_file):
    hcp_tmpl = np.load(hcp_grad_file)

    # Ex vivo gradient sets, keyed by the metric name used in metric_labels.
    layer_grads = [
        ('Total', total_grad),
        ('Supra', source_aligned_supra),
        ('Infra', source_aligned_infra),
        ('Relative', source_aligned_relative),
        ('Infra/Total', source_aligned_ratio_infra),
        ('Supra/Total', source_aligned_ratio_supra),
    ]
    # For each set: correlation of component k with HCP component k. This is
    # the diagonal of the cross-block of the joint correlation matrix.
    calc_corrs = [
        (name, np.diag(np.corrcoef(grads.T, hcp_tmpl.T)[:grads.shape[1], grads.shape[1]:]))
        for name, grads in layer_grads
    ]

    # Pair values with labels BY NAME. Pairing by position would attach the
    # wrong label whenever metric_labels is ordered differently from calc_corrs
    # (as it is here for 'Infra' and 'Supra').
    corr_by_label = dict(calc_corrs)
    unknown_labels = [lbl for lbl in metric_labels if lbl not in corr_by_label]
    if unknown_labels:
        print(f"Skipping metric labels without a correlation vector: {unknown_labels}")
    corr_diags = [(lbl, corr_by_label[lbl]) for lbl in metric_labels if lbl in corr_by_label]
    max_len = len(corr_diags)

    if not corr_diags:
        print("No correlation vectors specified, nothing to plot.")
    else:
        # Never ask for more components than the shortest vector provides.
        available = min(len(values) for _, values in corr_diags)
        num_components = available if ncomp is None else min(ncomp, available)
        x = np.arange(num_components)

        fig_height = max(4, max_len)
        fig, ax = plt.subplots(figsize=(num_components*3, fig_height))

        # Centre the group of bars on each component tick.
        n_groups = len(corr_diags)
        offsets = (np.arange(n_groups) - (n_groups - 1) / 2) * bar_width

        # plt.get_cmap replaces matplotlib.cm.get_cmap, which newer matplotlib
        # releases no longer provide.
        cmap = plt.get_cmap(colormap_name)
        if n_groups > cmap.N:
            colors = [cmap(i / n_groups) for i in range(n_groups)]
        else:
            colors = cmap(np.linspace(0, 1, n_groups))

        for idx, ((label, values), color) in enumerate(zip(corr_diags, colors)):
            # Absolute value: the sign of a gradient component is arbitrary.
            ax.bar(x + offsets[idx], np.abs(values[:num_components]), bar_width, label=label, color=color)

        ax.set_xlabel('Components', fontsize=font_size)
        ax.set_ylabel('Spatial similarity (|$\\rho$|)', fontsize=font_size)
        ax.set_xticks(x)
        ax.set_xticklabels([f'{i+1}' for i in x], fontsize=font_size)
        ax.tick_params(axis='y', labelsize=font_size)
        # Put the legend outside the plot to the right
        ax.legend(fontsize=13, loc='center left', bbox_to_anchor=(1.15, 0.5), ncol=1, frameon=False)
        ax.set_title('Spatial similarity between \nex vivo and HCP (total) gradient maps', fontsize=font_size)
        ax.set_ylim(0, ylim_max)

        # Remove all spines (box around plot) but keep the left tick marks
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.tick_params(axis='y', which='both', left=True, right=False)

        plt.tight_layout()
        plt.show()
else:
    print(f"HCP gradient file not found, skipping similarity plot: {hcp_grad_file}")





In [None]:
# Cross-correlate every total_grad component with every component of each
# layer-specific gradient set and show the matrices side by side.
aligned_sources = [
    ("Supra (aligned)", source_aligned_supra),
    ("Infra (aligned)", source_aligned_infra),
    ("Relative (aligned)", source_aligned_relative),
    ("Supra/Total (aligned)", source_aligned_ratio_supra),
    ("Infra/Total (aligned)", source_aligned_ratio_infra)
]

unaligned_sources = [
    ("Supra", supra_grad),
    ("Infra", infra_grad),
    ("Relative", relative_grad),
    ("Supra/Total", ratio_supra_grad),
    ("Infra/Total", ratio_infra_grad)
]

# Switch to unaligned_sources here to inspect the raw (unaligned) gradients.
sources_to_plot = aligned_sources
num_sources = len(sources_to_plot)

# Keep only the (total x source) cross block of each joint correlation matrix.
n_total = total_grad.shape[1]
labels = [name for name, _ in sources_to_plot]
corr_mats = [
    np.corrcoef(total_grad.T, grads.T)[:n_total, n_total:]
    for _, grads in sources_to_plot
]

# One row of panels, one panel per source.
fig, axs = plt.subplots(1, num_sources, figsize=(4*num_sources, 4))
axs = np.atleast_1d(axs)  # a single panel comes back as a bare Axes
last_panel = num_sources - 1
for i, (ax, mat, label) in enumerate(zip(axs, corr_mats, labels)):
    n_comp_y, n_comp_x = mat.shape
    im = ax.imshow(mat, cmap='bwr', vmin=-1, vmax=1)
    ax.set_xlabel(f'{label} PC',fontsize=15)
    ax.set_ylabel('Total PC',fontsize=15)
    ax.set_title(f'Correlation:\nTotal vs {label}', fontsize=15)
    # Tick labels are 1-based component numbers.
    ax.set_xticks(np.arange(n_comp_x))
    ax.set_xticklabels(np.arange(1, n_comp_x+1), fontsize=11)
    ax.set_yticks(np.arange(n_comp_y))
    ax.set_yticklabels(np.arange(1, n_comp_y+1), fontsize=11)
    # Write each value (2 decimals) into its cell; white text on strong colours.
    for (row, col), val in np.ndenumerate(mat):
        if abs(val) > 0.5:
            text_color = 'w'
        else:
            text_color = 'black'
        ax.text(col, row, f"{val:.2f}", ha="center", va="center", color=text_color, fontsize=8)
    # A single colorbar, attached to the last panel, keeps the row uncluttered.
    if i == last_panel:
        cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
plt.tight_layout()
plt.show()


In [None]:
# Show, per hemisphere, the matrix stored for each data type in R_lh_all /
# R_rh_all (described by the message printed below as the reflection and/or
# rotation matrix from the across-hemisphere alignment).
data_types = ['total','supra', 'infra',  'ratio_supra', 'ratio_infra']#,'diff']
title_labels = ['total','supra', 'infra', 'supra/total', 'infra/total']#,'diff']

print('Aligned across hemi reflection and/or rotation matrix being plotted')
plot_dat_lh = R_lh_all
plot_dat_rh = R_rh_all


# Plot both hemispheres
for hemi_name, plot_dat in [('left hemisphere', plot_dat_lh), ('right hemisphere', plot_dat_rh)]:
    print(f'Plotting {hemi_name}')
    # squeeze=False keeps `axes` 2-D even when only one data type is listed.
    fig, axes = plt.subplots(1, len(data_types), figsize=(5*len(data_types), 10), squeeze=False)

    for ax, dtype, title in zip(axes[0], data_types, title_labels):
        mat = np.asarray(plot_dat[dtype][0])
        n_rows, n_cols = mat.shape[:2]
        im = ax.imshow(mat, cmap='RdBu_r', vmin=-1.5, vmax=1.5)
        ax.set_title(title, fontsize=17, fontweight='bold')

        # Ticks follow the actual matrix size (labels are 1-based component
        # numbers) instead of assuming a fixed 10 x 10 matrix.
        ax.set_xticks(np.arange(n_cols))
        ax.set_xticklabels(np.arange(1, n_cols + 1), fontsize=16)
        ax.set_xlabel('Principal Component',fontsize=16)
        ax.set_yticks(np.arange(n_rows))
        ax.set_yticklabels(np.arange(1, n_rows + 1), fontsize=16)
        ax.set_ylabel('Principal Component',fontsize=16)

    plt.tight_layout()
    plt.show()

In [None]:
# Settings for the per-component correlation matrices across data types.
# data_types: keys looked up in the gradient dictionaries;
# tick_labels: the matching axis labels, in the same order.
data_types = ['total','supra', 'infra',  'ratio_supra', 'ratio_infra']#,'diff',]
tick_labels = ['total','supra', 'infra', 'supra/total', 'infra/total']

# Number of leading components to plot.
num_pca_plot = 3
# Output folder for the saved figure.
SAVEFOLDER = './figures/gradient_corr_mat'
# Helper: correlate one component across all data types and draw the matrix
def plot_correlation_matrix(data_dict, pc_index, ax, hemisphere):
    """Draw the between-data-type correlation matrix of one component on `ax`.

    For every name in the module-level `data_types`, column `pc_index` of
    `data_dict[name][0]` is taken; the columns are correlated with each other
    and the matrix is shown with a colorbar, using the module-level
    `tick_labels` on both axes.

    Parameters:
        data_dict: mapping from data type name to a sequence whose first item
            is a 2-D array (rows x components).
        pc_index: 0-based component (column) index.
        ax: matplotlib Axes to draw on.
        hemisphere: accepted for the caller's convenience; not used here.

    Returns:
        The image object created by `ax.imshow`.
    """
    # One column per data type, stacked side by side.
    columns = [
        data_dict[name][0][:, pc_index].reshape(-1, 1)
        for name in data_types
    ]
    stacked = np.concatenate(columns, axis=1)

    # Correlation between the columns, on a fixed [-1, 1] colour scale.
    corr_mat = np.corrcoef(stacked.T)
    im = ax.imshow(corr_mat, cmap='RdBu_r', vmin=-1, vmax=1)
    ax.set_title(f'PC {pc_index+1}', fontsize=17, fontweight='bold')

    # Same tick positions on both axes, labelled with the data type names.
    tick_positions = np.arange(len(tick_labels))
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(tick_labels, rotation=45, ha='right', fontsize=16)
    ax.set_yticklabels(tick_labels, fontsize=16)

    # Per-panel colorbar
    cbar = plt.colorbar(im, ax=ax)
    cbar.ax.tick_params(labelsize=14)

    return im




# Correlation matrices for the unaligned gradients: the first num_pca_plot
# panels are the left hemisphere, the next num_pca_plot the right hemisphere.
fig, axes = plt.subplots(1, 2 * num_pca_plot, figsize=(10*num_pca_plot, 4))
print('displaying unaligned data')

# (hemisphere tag, gradient dictionary, column offset of its first panel)
hemi_panels = (
    ('LH', lh_data_all, 0),
    ('RH', rh_data_all, num_pca_plot),
)
for pc_index in range(num_pca_plot):
    for hemi_tag, hemi_grads, col_offset in hemi_panels:
        plot_correlation_matrix(hemi_grads, pc_index, axes[pc_index + col_offset], hemi_tag)

plt.tight_layout()
# Make sure the output folder exists, then save and close the figure.
os.makedirs(SAVEFOLDER, exist_ok=True)
plt.savefig(os.path.join(SAVEFOLDER, f'unaligned_corr_mat_sparsity{G_sparsity}.png'), dpi=300)
plt.close()





In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Unaligned data: render the first N_components_plot gradients of every data
# type on both hemispheres and save one PNG per (data type, hemisphere).
print('displaying unaligned data (not reordered)')
N_components_plot = 3
labelsize = 10
SAVEFOLDER = f'./figures/gradients/sparsity{G_sparsity}'
# Create the output folder once, up front, instead of re-checking per figure.
os.makedirs(SAVEFOLDER, exist_ok=True)

for data_type in data_types:
    print(data_type)
    grad2plot_lh = lh_data_all[data_type]
    grad2plot_rh = rh_data_all[data_type]
    # Colour limit per component: 95th percentile of the values.
    # nanpercentile gives the same result as percentile on NaN-free data but
    # does not turn the whole colour range into NaN when NaNs are present.
    color_ranges_lh = [np.nanpercentile(grad2plot_lh[0][:, pc], 95) for pc in range(N_components_plot)]
    color_ranges_rh = [np.nanpercentile(grad2plot_rh[0][:, pc], 95) for pc in range(N_components_plot)]

    plotters_lh = create_hemisphere_plots(
        grad2plot_lh,
        '/Users/dennis.jungchildmind.org/Downloads/HCP_S1200_Atlas_Z4_pkXDZ/S1200.L.white_MSMAll.32k_fs_LR.surf.gii',
        'lh',
        N_components_plot,
        color_ranges_lh
    )

    plotters_rh = create_hemisphere_plots(
        grad2plot_rh,
        '/Users/dennis.jungchildmind.org/Downloads/HCP_S1200_Atlas_Z4_pkXDZ/S1200.R.white_MSMAll.32k_fs_LR.surf.gii',
        'rh',
        N_components_plot,
        color_ranges_rh
    )

    # Plot both hemispheres.
    # The loop variable is deliberately NOT called `hemi`: that would overwrite
    # the notebook-wide `hemi` setting with 'LH'/'RH', which no longer matches
    # the `hemi == 'lh'` checks used by other cells.
    for hemi_label, plotters, color_ranges in [('LH', plotters_lh, color_ranges_lh), ('RH', plotters_rh, color_ranges_rh)]:
        f, (ax1, ax2) = plt.subplots(2, N_components_plot, figsize=(2*N_components_plot, 5))

        for pc in range(N_components_plot):
            # Top row: first rendered image of this component (also the
            # mappable for the colorbar)
            im1 = ax1[pc].imshow(plotters[pc][0], cmap='RdBu_r',
                                 vmin=-color_ranges[pc], vmax=color_ranges[pc])
            ax1[pc].axis('off')
            ax1[pc].set_title(f'PC{pc+1}', fontsize=labelsize*2, fontweight='bold')

            # Bottom row: second rendered image of the same component
            im2 = ax2[pc].imshow(plotters[pc][1], cmap='RdBu_r',
                                 vmin=-color_ranges[pc], vmax=color_ranges[pc])
            ax2[pc].axis('off')

            # Horizontal colorbar between the two rows
            divider = make_axes_locatable(ax1[pc])
            cax = divider.append_axes("bottom", size="5%", pad=0.5)
            cbar = plt.colorbar(im1, cax=cax, orientation='horizontal')
            cbar.ax.tick_params(labelsize=labelsize)

        plt.tight_layout()
        plt.savefig(os.path.join(SAVEFOLDER, f'{data_type}_{hemi_label.lower()}_unaligned_unordered.png'), dpi=300)
        plt.close()
        


In [None]:
# Aligned data: render the first N_components_plot aligned (reordered)
# gradients of every data type on both hemispheres and save one PNG per
# (data type, hemisphere).
print('displaying aligned data (Reordered)')
# Create the output folder once, up front, instead of re-checking per figure.
os.makedirs(SAVEFOLDER, exist_ok=True)

for data_type in data_types:
    print(data_type)
    grad2plot_lh = lh_aligned_all[data_type]
    grad2plot_rh = rh_aligned_all[data_type]
    # Symmetric colour limit per component: largest absolute value, ignoring NaNs.
    color_ranges_lh = [np.nanmax(np.abs(grad2plot_lh[0][:, pc])) for pc in range(N_components_plot)]
    color_ranges_rh = [np.nanmax(np.abs(grad2plot_rh[0][:, pc])) for pc in range(N_components_plot)]

    plotters_lh = create_hemisphere_plots(
        grad2plot_lh,
        '/Users/dennis.jungchildmind.org/Downloads/HCP_S1200_Atlas_Z4_pkXDZ/S1200.L.white_MSMAll.32k_fs_LR.surf.gii',
        'lh',
        N_components_plot,
        color_ranges_lh
    )

    plotters_rh = create_hemisphere_plots(
        grad2plot_rh,
        '/Users/dennis.jungchildmind.org/Downloads/HCP_S1200_Atlas_Z4_pkXDZ/S1200.R.white_MSMAll.32k_fs_LR.surf.gii',
        'rh',
        N_components_plot,
        color_ranges_rh
    )

    # Plot both hemispheres.
    # The loop variable is deliberately NOT called `hemi`: that would overwrite
    # the notebook-wide `hemi` setting with 'LH'/'RH', which no longer matches
    # the `hemi == 'lh'` checks used by other cells.
    for hemi_label, plotters, color_ranges in [('LH', plotters_lh, color_ranges_lh), ('RH', plotters_rh, color_ranges_rh)]:
        f, (ax1, ax2) = plt.subplots(2, N_components_plot, figsize=(2*N_components_plot, 5))

        for pc in range(N_components_plot):
            # Top row: first rendered image of this component (also the
            # mappable for the colorbar)
            im1 = ax1[pc].imshow(plotters[pc][0], cmap='RdBu_r',
                                 vmin=-color_ranges[pc], vmax=color_ranges[pc])
            ax1[pc].axis('off')
            ax1[pc].set_title(f'PC{pc+1}', fontsize=labelsize*2, fontweight='bold')

            # Bottom row: second rendered image of the same component
            im2 = ax2[pc].imshow(plotters[pc][1], cmap='RdBu_r',
                                 vmin=-color_ranges[pc], vmax=color_ranges[pc])
            ax2[pc].axis('off')

            # Horizontal colorbar between the two rows
            divider = make_axes_locatable(ax1[pc])
            cax = divider.append_axes("bottom", size="5%", pad=0.5)
            cbar = plt.colorbar(im1, cax=cax, orientation='horizontal')
            cbar.ax.tick_params(labelsize=labelsize)

        plt.tight_layout()
        plt.savefig(os.path.join(SAVEFOLDER, f'{data_type}_{hemi_label.lower()}_aligned_reordered.png'), dpi=300)
        plt.close()
        

In [None]:
# Build one overview figure per hemisphere from the PNGs saved in SAVEFOLDER:
# top row = unordered (unaligned) images, bottom row = reordered (aligned)
# images, one column per data type.
num_cols = len(data_types)
num_rows = 2
title_fontsize = 10

# (file-name suffix, tag shown in the panel title), listed top row first
panel_rows = [
    ('unaligned_unordered', 'unordered'),
    ('aligned_reordered', 'reordered'),
]

for HEMI in ['lh','rh']:
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols*3, 3*num_rows))

    for i, data_type in enumerate(data_types):
        for row, (suffix, tag) in enumerate(panel_rows):
            # Read the previously saved image and show it without axes.
            panel_img = plt.imread(os.path.join(SAVEFOLDER, f'{data_type}_{HEMI}_{suffix}.png'))
            panel_ax = axes[row, i]
            panel_ax.imshow(panel_img)
            panel_ax.axis('off')
            panel_ax.set_title(f'{data_type} ({tag})', fontsize=title_fontsize, fontweight='bold')

    plt.tight_layout()
    plt.savefig(os.path.join(SAVEFOLDER, f'{HEMI}_combined_unordered_vs_reordered.png'), dpi=300, bbox_inches='tight')
    plt.close()
        





In [None]:
# Collapse the vertex-wise gradients to one value per parcel (median over the
# parcel's vertices), separately for each hemisphere.
def _median_per_parcel(vertex_grad, hemi_atlas):
    """Return a (parcels x components) array of per-parcel medians.

    Parcels are the consecutive labels from the smallest non-zero label in
    `hemi_atlas` up to its largest label; row k belongs to the k-th of them.
    """
    first_label = np.min(hemi_atlas[hemi_atlas != 0])
    last_label = np.max(hemi_atlas)
    parcel_grad = np.zeros((last_label - first_label + 1, vertex_grad.shape[1]))
    for row, parcel in enumerate(range(first_label, last_label + 1)):
        parcel_grad[row, :] = np.median(vertex_grad[hemi_atlas == parcel, :], axis=0)
    return parcel_grad


sparsity_idx = 0#0 is the 0.9 in the G_sparsity variable

# The first 32492 atlas entries are used for the left hemisphere, the rest for
# the right hemisphere.
atlasDat_lh = atlasDat[:32492]
atlasDat_rh = atlasDat[32492:]

# Unaligned gradients are used here (not reordered, not sign-flipped); swap in
# lh_aligned_all / rh_aligned_all to unmap the aligned versions instead.
lh_data_all_unmapped = {
    name: _median_per_parcel(grads[sparsity_idx], atlasDat_lh)
    for name, grads in lh_data_all.items()
}
rh_data_all_unmapped = {
    name: _median_per_parcel(grads[sparsity_idx], atlasDat_rh)
    for name, grads in rh_data_all.items()
}


In [None]:
# Gradient-versus-gradient scatter plots at the parcel level (not vertex
# level): the reference data type on x, every other selected data type on y,
# with marginal histograms. One figure per hemisphere.
import seaborn as sns
from scipy import stats
import matplotlib.pyplot as plt

# Configuration
nPC = 0# 0 is the first PC
font_size = 24
tick_size = 20
data_type_ref_x = 'total'


def get_title_name(data_type):
    """Return the panel title for `data_type` ('' for unknown names).

    Reads the module-level `nPC` to number the gradient (G1 for nPC == 0).
    """
    titles = {
        'diff': f'G{nPC+1}(Symmetry)',
        'ratio_supra': f'G{nPC+1}(Supra/Total Ratio)',
        'ratio_infra': f'G{nPC+1}(Infra/Total Ratio)',
        'relative': f'G{nPC+1}(Relative)',
        'he': f'G{nPC+1}(Hurst Exponent)',
        'supra': f'G{nPC+1}(Supra Thickness)',
        'infra': f'G{nPC+1}(Infra Thickness)',
        'total': f'G{nPC+1}(Total Thickness)'
    }
    return titles.get(data_type, '')


def get_data_for_hemisphere(hemi, data_type_ref_x, gradient_data, data_type, nPC, atlas, atlas_to_network):
    """Collect the x/y gradient columns and one colour per parcel for a hemisphere.

    Parameters:
        hemi: 'lh' selects the first half of `atlas`; any other value the second.
        data_type_ref_x: key of the reference (x-axis) data type in `gradient_data`.
        gradient_data: mapping from data type to a (parcels x components) array.
        data_type: key of the y-axis data type in `gradient_data`.
        nPC: 0-based component (column) index.
        atlas: per-vertex parcel labels for both hemispheres (0 = unlabeled).
        atlas_to_network: per-vertex colours, in the same vertex order as `atlas`.

    Returns:
        (gradient_ref, gradient, data_label_color), where data_label_color is a
        list with one colour tuple per parcel label from the smallest non-zero
        label to the largest label of the hemisphere.
    """
    gradient_ref = gradient_data[data_type_ref_x][:, nPC]
    gradient = gradient_data[data_type][:, nPC]

    # Slice the labels and the colours with the SAME vertex range so they stay
    # aligned vertex by vertex. (Starting the lh colour slice at index 1 would
    # shift the colours by one vertex relative to the labels.)
    half = len(atlas) // 2
    vertex_range = slice(0, half) if hemi == 'lh' else slice(half, None)
    atlas_dat = atlas[vertex_range]
    vertex_colors = np.array(atlas_to_network[vertex_range])

    labelled = atlas_dat[atlas_dat != 0]
    atlas_min = np.min(labelled)
    atlas_max = np.max(labelled)

    # Colour of a parcel = median colour over its vertices.
    data_label_color = [
        tuple(np.median(vertex_colors[atlas_dat == parcel, :], axis=0).tolist())
        for parcel in range(atlas_min, atlas_max + 1)
    ]
    return gradient_ref, gradient, data_label_color


# Plot settings
sns.set_style("white")
sns.set_context("paper", font_scale=1.2)

# Per-vertex network colour; identical for both hemispheres, so built once.
atlas_to_network = [yeo_network_colors[label] for label in atlas]

# Data types that never get their own panel (the reference is the x-axis).
excluded_types = {'diff', 'ratio_supra', 'ratio_infra', 'relative', data_type_ref_x}

for hemi in ['lh','rh']:

    gradient_data = lh_data_all_unmapped if hemi == 'lh' else rh_data_all_unmapped

    # NOTE: this rebinds the notebook-wide `data_types` list.
    data_types = [key for key in gradient_data if key not in excluded_types]
    n_plots = len(data_types)
    if n_plots == 0:
        print(f"{hemi}: no data types to plot.")
        continue

    # Create figure and grid: 3 grid columns per panel (2 for the scatter plot
    # and top histogram, 1 for the right histogram).
    fig = plt.figure(figsize=(8*n_plots, 8))
    gs = fig.add_gridspec(3, 3*n_plots, hspace=0.6, wspace=0.4)

    # Plot each data type
    for plot_idx, data_type in enumerate(data_types):
        # Get data
        gradient_ref, gradient, data_label_color = get_data_for_hemisphere(
            hemi, data_type_ref_x, gradient_data, data_type, nPC, atlas, atlas_to_network
        )

        # Pearson correlation over the parcels where both values are present
        valid_mask = ~np.isnan(gradient_ref) & ~np.isnan(gradient)
        r, p = stats.pearsonr(gradient_ref[valid_mask], gradient[valid_mask])
        print(f"{data_type}: r={r:.2f}, p={p:.2e}")

        # Plot layout
        col_start = 3 * plot_idx
        col_end = col_start + 2

        # Main scatter plot
        ax_scatter = fig.add_subplot(gs[1:, col_start:col_end])
        sns.scatterplot(x=gradient_ref, y=gradient, alpha=1, s=60,
                        c=data_label_color,edgecolor='None', ax=ax_scatter)

        # Add regression line and correlation text
        sns.regplot(x=gradient_ref, y=gradient, scatter=False, color='black',
                    line_kws={'linestyle': '-', 'linewidth': 8}, ax=ax_scatter)
        ax_scatter.text(0.05, 0.95, f'r = {r:.2f}\np = {p:.2e}',
                        transform=ax_scatter.transAxes, va='top',
                        fontsize=font_size, weight='bold', color='black')

        # Format scatter plot
        ax_scatter.set_xlim(np.nanmin(gradient_ref), np.nanmax(gradient_ref))
        ax_scatter.set_ylim(np.nanmin(gradient), np.nanmax(gradient))
        ax_scatter.set_ylabel(f'G{nPC+1} ({data_type})',
                              fontsize=font_size, labelpad=12, weight='bold')
        ax_scatter.set_xlabel(f'G{nPC+1} ({data_type_ref_x})', fontsize=font_size, labelpad=12, weight='bold')
        ax_scatter.set_title(get_title_name(data_type), fontsize=font_size+2, pad=20, weight='bold')
        ax_scatter.tick_params(axis='both', which='major', labelsize=tick_size)
        ax_scatter.grid(True, linestyle='--', alpha=0)

        for spine in ax_scatter.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(2.0)

        # Top histogram (distribution of the x values)
        ax_histx = fig.add_subplot(gs[0, col_start:col_end])
        sns.histplot(data=np.flip(gradient_ref), bins=50, kde=True, color='#8E44AD',
                     ax=ax_histx, stat='density', alpha=1)
        ax_histx.set(xlabel='', ylabel='')
        ax_histx.set_yticks([])
        ax_histx.tick_params(labelbottom=False, labelsize=tick_size)
        for spine in ax_histx.spines.values():
            spine.set_visible(False)

        # Right histogram (distribution of the y values)
        ax_histy = fig.add_subplot(gs[1:, col_end])
        sns.histplot(y=np.flip(gradient), bins=50, kde=True, color='#8E44AD',
                     ax=ax_histy, stat='density', alpha=1)
        ax_histy.set(xlabel='', ylabel='')
        ax_histy.set_xticks([])
        ax_histy.tick_params(labelleft=False, labelsize=tick_size)
        for spine in ax_histy.spines.values():
            spine.set_visible(False)

    plt.tight_layout()
    # One file per hemisphere; a shared file name would let the right
    # hemisphere overwrite the left one.
    plt.savefig(f'correlation_plot_{hemi}.png', dpi=300, bbox_inches='tight',
                facecolor='white', edgecolor='none')
    plt.show()


In [None]:
def scatter_with_histograms(
    x, y, xlabel_name, ylabel_name, marker_size=30, data_label_color=None, ax=None
):
    """
    Draw a scatter plot of y against x with a fitted line and the Spearman
    correlation, optionally with marginal histograms.

    If `ax` is given, only the scatter panel is drawn on it (for multi-panel
    layouts). If `ax` is None, a new 4x4-inch figure with a scatter panel, a
    top histogram of x and a right histogram of y is created; that figure is
    closed before returning, so it is only shown when the caller displays the
    returned `fig`.

    Parameters:
        x, y: equally long sequences of values.
        xlabel_name, ylabel_name: axis labels (ylabel_name is also used in the
            printed summary line).
        marker_size: scatter marker size; font sizes and the line width are
            derived from it.
        data_label_color: one colour per point; defaults to a uniform purple.
        ax: existing matplotlib Axes to draw on, or None.

    Returns: fig, axes_dict, stats_dict
        axes_dict has the key 'scatter' (plus 'histx' and 'histy' for the
        standalone figure); stats_dict is {'r': ..., 'p': ...} from Spearman.
    """
    from scipy import stats as sp_stats
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Helper function to format p-values
    def format_pvalue(p):
        """Format p-value: show actual value if >= 0.001, otherwise show '< 0.001'"""
        if p < 0.001:
            return "< 0.001"
        else:
            return f"= {p:.3f}"

    def sigma_window(values):
        """Return (mean - 3.5*std, mean + 3.5*std) of `values`, ignoring NaNs."""
        center = np.nanmean(values)
        half_width = 3.5 * np.nanstd(values)
        return (center - half_width, center + half_width)

    def draw_scatter_panel(target_ax, labelpad):
        """Draw points, fitted line, annotation and styling; return (x_lim, y_lim)."""
        sns.scatterplot(x=x, y=y, alpha=1, s=marker_size,
                        c=data_label_color, edgecolor='black', ax=target_ax)
        if len(x) > 1:
            # Least-squares line, drawn across the displayed x window.
            try:
                slope, intercept = np.polyfit(x, y, 1)
            except np.linalg.LinAlgError:
                slope, intercept = 0, np.nan
            x_lim = sigma_window(x)
            y_lim = sigma_window(y)
            x_line = np.linspace(x_lim[0], x_lim[1], 100)
            y_line = slope * x_line + intercept
            target_ax.plot(x_line, y_line, color='black', linestyle='--',
                           linewidth=max(1, marker_size // 10))
        else:
            # Too few points for a fit or a spread: use the data range.
            x_lim = (np.nanmin(x), np.nanmax(x))
            y_lim = (np.nanmin(y), np.nanmax(y))

        # Annotate correlation with formatted p-value
        target_ax.text(0.04, 0.97, f'r={r:.2f}, p {p_formatted}',
                       transform=target_ax.transAxes, va='top',
                       fontsize=font_size, weight='bold', color='black')

        target_ax.set_xlim(x_lim)
        target_ax.set_ylim(y_lim)
        target_ax.set_xlabel(xlabel_name, fontsize=font_size, labelpad=labelpad, weight='bold')
        target_ax.set_ylabel(ylabel_name, fontsize=font_size, labelpad=labelpad, weight='bold')
        target_ax.tick_params(axis='both', which='major', labelsize=tick_size)
        target_ax.grid(True, linestyle='--', alpha=0)
        for spine in target_ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(2.0)
        return x_lim, y_lim

    def strip_histogram_axes(hist_ax):
        """Remove axis labels and the frame from a marginal histogram."""
        hist_ax.set(xlabel='', ylabel='')
        for spine in hist_ax.spines.values():
            spine.set_visible(False)

    # Compute correlation
    r, p = sp_stats.spearmanr(x, y)
    print(f"{ylabel_name}: r={r:.2f}, p={p:.2e}")
    p_formatted = format_pvalue(p)
    stats_dict = {'r': r, 'p': p}

    # Colors and sizes
    if data_label_color is None:
        data_label_color = ['#8E44AD' for _ in x]
    font_size = max(int(marker_size * 0.50), 10)
    tick_size = max(int(marker_size * 0.50), 8)

    # If an axis is passed, just draw scatter (no histograms, as in a multi-panel layout)
    if ax is not None:
        draw_scatter_panel(ax, labelpad=2)
        return ax.figure, {'scatter': ax}, stats_dict

    # If ax is None, draw as standalone figure with marginal histograms
    fig = plt.figure(figsize=(4, 4))
    gs = fig.add_gridspec(3, 3, hspace=0.12, wspace=0.12)
    axes = {}

    # Main scatter plot
    ax_scatter = fig.add_subplot(gs[1:, 0:2])
    axes['scatter'] = ax_scatter
    x_lim, y_lim = draw_scatter_panel(ax_scatter, labelpad=1)

    # Top histogram: shares the scatter's x window; the y range is tripled so
    # the bars occupy only the lower third of the panel.
    ax_histx = fig.add_subplot(gs[0, 0:2])
    axes['histx'] = ax_histx
    sns.histplot(x=x, bins=50, kde=True, color='#8E44AD', ax=ax_histx, stat='density', alpha=1)
    ax_histx.set_xlim(x_lim)
    ax_histx.set_ylim(0, ax_histx.get_ylim()[1] * 3)
    strip_histogram_axes(ax_histx)
    ax_histx.set_yticks([])
    ax_histx.tick_params(labelbottom=False, labelsize=tick_size)

    # Right histogram: shares the scatter's y window; same tripling on x.
    ax_histy = fig.add_subplot(gs[1:, 2])
    axes['histy'] = ax_histy
    sns.histplot(y=y, bins=50, kde=True, color='#8E44AD', ax=ax_histy, stat='density', alpha=1)
    ax_histy.set_ylim(y_lim)
    ax_histy.set_xlim(0, ax_histy.get_xlim()[1] * 3)
    strip_histogram_axes(ax_histy)
    ax_histy.set_xticks([])
    ax_histy.tick_params(labelleft=False, labelsize=tick_size)

    # Lay out and close the new figure explicitly (it is the current figure).
    fig.tight_layout()
    plt.close(fig)

    return fig, axes, stats_dict

def _darker_color(color, factor=0.7):
    """Return a darkened copy of ``color`` as an RGB tuple.

    ``color`` is anything ``mcolors.to_rgb`` accepts (e.g. a hex string or an
    RGB sequence); each of the three channels is multiplied by ``factor``.
    """
    red, green, blue = mcolors.to_rgb(color)
    return (factor * red, factor * green, factor * blue)

In [None]:
# Aperiodic exponent from the sensor-space data, plotted against the first few
# gradient components of one layer type.
from scipy.stats import ranksums, spearmanr
import numpy as np
import matplotlib.pyplot as plt

exp_data_path = '/Users/dennis.jungchildmind.org/Downloads/exponent_result1.shape.gii'
#exp_data_path = '/Users/dennis.jungchildmind.org/Downloads/beta_power.shape.gii'
exp_data = nib.load(exp_data_path)
exp_data = exp_data.darrays[2]#0 is vertices, 1 is faces, 2 is exponent values

# The first 32492 entries are taken as the left hemisphere, the rest as the right.
exp_data_lh = exp_data.data[0:32492]
exp_data_rh = exp_data.data[32492:]


hemi = 'lh'
YLABEL_NAME = 'Aperiodic Exponent'
LAYER_TYPE = 'total'

# Pick atlas, gradient and exponent for the SAME hemisphere. The gradient used
# to be read from the left hemisphere unconditionally, which paired LH
# gradients with RH exponents whenever hemi was switched to 'rh'.
if hemi == 'lh':
    atlas_data = atlas_data_lh
    grad2plot = lh_data_all_unmapped[LAYER_TYPE]
    exp_parcellated = parcellate_data(exp_data_lh, atlas_data)
else:
    atlas_data = atlas_data_rh
    grad2plot = rh_data_all_unmapped[LAYER_TYPE]
    exp_parcellated = parcellate_data(exp_data_rh, atlas_data)

num_pcs_to_plot = 3

# One panel per gradient component, all in a single row.
main_fig, main_axes = plt.subplots(1, num_pcs_to_plot, figsize=(num_pcs_to_plot*3, num_pcs_to_plot))

for pc in range(num_pcs_to_plot):
    # grad2plot is indexed directly by component column (no parcellation here).
    grad_parcellated = grad2plot[:,pc]
    # Entries that are exactly 0 in the gradient are excluded from the plot.
    nonzero_mask = grad_parcellated != 0
    x = grad_parcellated[nonzero_mask]
    y = exp_parcellated[nonzero_mask]

    # Draw onto the pre-created axis so every component lands in main_fig.
    _, _, stats_dict = scatter_with_histograms(
        x, y, 'Gradient '+str(pc+1)+ ' score', YLABEL_NAME,
        marker_size=30, data_label_color=None, ax=main_axes[pc]
    )
plt.tight_layout()
plt.show()



In [None]:
# DFA exponent (per frequency band) versus gradient scores:
# one row per frequency band, one column per gradient component.
from scipy.stats import ranksums, spearmanr
import numpy as np
import matplotlib.pyplot as plt
import pickle


hemi = 'lh'

#for gradient (x-axis)
LAYER_TYPE = 'infra'

#for DFA (y-axis)
YLABEL_NAME = 'DFA Exponent'
FREQ_RANGES = [[0.1,4.0],[4.0,8.0],[8.0,15.0],[15.0,30.0],[30.0,60.0]]
num_pcs_to_plot = 3
n_freq = len(FREQ_RANGES)

# Hemisphere-dependent inputs do not change between bands, so select them once.
if hemi == 'lh':
    atlas_data = atlas_data_lh
    grad2plot = lh_data_all_unmapped[LAYER_TYPE]
    color_subset = yeo_network_colors[1:201]
else:
    atlas_data = atlas_data_rh
    grad2plot = rh_data_all_unmapped[LAYER_TYPE]
    color_subset = yeo_network_colors[201:]

# Make a big figure: each row = freq range, each col = one PC
# Use constrained_layout for better spacing
main_fig, main_axes = plt.subplots(
    n_freq, num_pcs_to_plot, 
    figsize=(num_pcs_to_plot*3, n_freq*3),  # Increased size for better visibility
    squeeze=False,
    constrained_layout=True
)

for fi, freq_range in enumerate(FREQ_RANGES):
    print(fi,freq_range)
    # Load DFA (Y-axis) data for this frequency
    dfa_data_path = f'/Users/dennis.jungchildmind.org/Desktop/MEG/112225_dfa_outputs/dfa_restin_f{freq_range[0]}_to_{freq_range[1]}_parcelonly.pkl'
    # Context manager so the file handle is closed after each band is read.
    with open(dfa_data_path, 'rb') as fh:
        dfa_data = pickle.load(fh)
    # Average over the first axis, then split the result into two halves
    # (first half -> LH, second half -> RH).
    dfa_data_mean = np.nanmean(dfa_data, axis=0)
    dfa_data_mean_lh = dfa_data_mean[0:int(num_parcels/2)]
    dfa_data_mean_rh = dfa_data_mean[int(num_parcels/2):]
    exp_parcellated = dfa_data_mean_lh if hemi == 'lh' else dfa_data_mean_rh
    
    for pc in range(num_pcs_to_plot):
        ax = main_axes[fi, pc]
        grad_parcellated = grad2plot[:,pc]
        # Entries that are exactly 0 in the gradient are excluded from the plot.
        nonzero_mask = grad_parcellated != 0
        x = grad_parcellated[nonzero_mask]
        y = exp_parcellated[nonzero_mask]
        # Filter the colours with the same mask so they stay aligned with the
        # points that are actually drawn.
        data_label_color = [c for c, keep in zip(color_subset, nonzero_mask) if keep]
        
        # Plot
        _, _, stats_dict = scatter_with_histograms(
            x, y, f'Gradient {pc+1} score', YLABEL_NAME,
            marker_size=30, data_label_color=data_label_color, ax=ax)
        
        # Force the subplot box to be square, but keep independent axis scales
        ax.set_box_aspect(1)

plt.show()

In [None]:
#Break HERE for now, since we don't wanna do the gene stuff yet
import abagen
import os
import pickle
from scipy.stats import zscore
# Volumetric atlas handed to abagen; downloaded from:
#https://github.com/ThomasYeoLab/CBIG/blob/master/stable_projects/brain_parcellation/Schaefer2018_LocalGlobal/Parcellations/MNI/Schaefer2018_400Parcels_7Networks_order_FSLMNI152_1mm.nii.gz
atlas_path = '/Users/dennis.jungchildmind.org/Downloads/Schaefer2018_400Parcels_7Networks_order_FSLMNI152_1mm.nii.gz'
#need to use abagen 0.1.3 and pandas needs to be 1.5.3 to avoid inplace error 11/25/2025
expression_cache_file = 'ahba_expression.pkl'
if not os.path.exists(expression_cache_file):
    # No cache yet: query abagen once and pickle the result for later runs.
    expression = abagen.get_expression_data(atlas_path, lr_mirror=True, return_donors=True)
    with open(expression_cache_file, 'wb') as f:
        pickle.dump(expression, f)
else:
    # Cache present: reuse the pickled result instead of calling abagen again.
    with open(expression_cache_file, 'rb') as f:
        expression = pickle.load(f)


In [None]:
# Inspect donor '9861': report how many keys its table has and list the keys
# whose name starts with 'htr' (case-insensitive).
gene_keys = list(expression['9861'].keys())
print(f"Number of keys: {len(gene_keys)}")
htr_keys = list(filter(lambda key: key.lower().startswith('htr'), gene_keys))
print(f"Keys that start with 'htr': {htr_keys}")


In [None]:

def get_gene_expression(expression, target_gene,do_zscore=True):
    """Collect one gene's expression values from every donor table.

    Args:
        expression: Dictionary mapping donor IDs to expression DataFrames
        target_gene: Column name to extract; matched by exact equality
        do_zscore: If True (default), z-score each donor's values down the
            rows (axis 0), ignoring NaNs, before they are combined

    Returns:
        2-D array with the per-donor columns concatenated along axis 1, or
        None when no donor table has a column named ``target_gene``.
        Donors lacking the column are skipped.
    """
    per_donor = []

    for donor_df in expression.values():
        hits = donor_df.columns[donor_df.columns == target_gene]
        if hits.size == 0:
            continue

        values = donor_df[hits].values
        if do_zscore:
            values = zscore(values, axis=0, nan_policy='omit')
        per_donor.append(values)

    return np.concatenate(per_donor, axis=1) if per_donor else None


In [None]:

#Gradients versus Gene Expression data at the Parcel Level
# Spearman correlation of every gene with one gradient component, per data type
# and hemisphere. Results are cached in .npy files and reloaded when present.
import seaborn as sns
from scipy import stats
import matplotlib.pyplot as plt
import numpy as np

correlations_file = 'gene_correlations.npy'
pvalues_file = 'gene_pvalues.npy'
gene_names_file = 'gene_names.npy'
data_types_file = 'data_types.npy'

cache_files = (correlations_file, pvalues_file, gene_names_file, data_types_file)

if all(os.path.exists(cache_path) for cache_path in cache_files):
    correlations = np.load(correlations_file)
    p_values = np.load(pvalues_file)
    ALL_GENES = np.load(gene_names_file, allow_pickle=True)
    data_types = np.load(data_types_file, allow_pickle=True)
else:
    ALL_GENES = expression['9861'].columns.tolist()
    print(ALL_GENES)

    n_genes = len(ALL_GENES)
    data_types = [key for key in gradient_data.keys() if key != 'diff']
    n_data_types = len(data_types)

    # Arrays to store correlation coefficients and p-values
    correlations = np.zeros((n_genes, n_data_types, 2)) # 2 for lh and rh
    p_values = np.zeros((n_genes, n_data_types, 2))

    nPC = 1 # column index into the gradient array (0 is the first PC)

    # Hemisphere lookups never change inside the gene loop, so build them once.
    # The unused per-vertex network colour list that used to be rebuilt for
    # every gene and hemisphere has been removed.
    half_parcels = int(num_parcels/2)
    gradients_by_hemi = {'lh': lh_data_all_unmapped, 'rh': rh_data_all_unmapped}
    gene_slice_by_hemi = {'lh': slice(None, half_parcels), 'rh': slice(half_parcels, None)}

    # Loop through each gene
    for gene_idx, gene in enumerate(ALL_GENES):
        # Donor-averaged expression for this gene
        gene_exp_data = get_gene_expression(expression, gene)
        the_gene_data = np.nanmean(gene_exp_data, axis=-1)

        # Process each hemisphere
        for hemi_idx, hemi in enumerate(['lh', 'rh']):
            gradient_data_hem = gradients_by_hemi[hemi]
            # NOTE(review): nan_to_num turns missing expression into 0, so the
            # isnan check on gene_ref below can never exclude a parcel -
            # confirm that treating missing parcels as 0 is intended.
            gene_ref = np.nan_to_num(the_gene_data[gene_slice_by_hemi[hemi]])

            # Process each data type
            for data_type_idx, data_type in enumerate(data_types):
                gradient = gradient_data_hem[data_type][:,nPC]

                # Calculate correlation
                valid_mask = ~np.isnan(gene_ref) & ~np.isnan(gradient)
                r, p = stats.spearmanr(gene_ref[valid_mask], gradient[valid_mask])

                # Store results
                correlations[gene_idx, data_type_idx, hemi_idx] = r
                p_values[gene_idx, data_type_idx, hemi_idx] = p

    # Save results through the same filename variables the loader checks, so
    # the save and load paths cannot drift apart.
    np.save(correlations_file, correlations)
    np.save(pvalues_file, p_values)

    # Also save metadata for interpreting results
    np.save(gene_names_file, np.array(ALL_GENES))
    np.save(data_types_file, np.array(data_types))



In [None]:
# p_values is indexed as (gene, layer type, hemisphere).
# Original author's note: both hemispheres have the same gene expression level
# because the expression data are mirrored.
#the second dimension is layer type [['total', 'supra', 'infra', 'ratio_supra', 'ratio_infra', 'relative']]
#the third dimension is hemisphere [lh,rh]
from scipy.stats import false_discovery_control

alpha = 0.05

# Benjamini-Hochberg correction across genes for layer-type index 0, hemisphere index 0.
corrected_pvals = false_discovery_control(p_values[:,0,0], method='bh')

# Indices of the genes whose corrected p-value falls below alpha.
significant_indices = np.flatnonzero(corrected_pvals < alpha)





In [None]:
# Composite expression score across the HSE gene list.
HSE_GENES = ['BEND5','C1QL2','CACNA1E','COL24A1','COL6A1','CRYM','KCNC3',
             'KCNH4','LGALS1','MFGE8','NEFH','SCN3B','SCN4B',
             'SNCG','SV2C','SYT2','TPBG','VAMP1']
hse_gene_data = []
for gene in HSE_GENES:
    gene_exp = get_gene_expression(expression, gene)
    if gene_exp is None:
        # get_gene_expression returns None for an unknown gene; fail here with
        # the gene name instead of an opaque shape error from np.stack.
        raise KeyError(f"Gene {gene!r} was not found in the expression data")
    hse_gene_data.append(gene_exp)


# Stack to (gene, ..., donor), then average over donors (last axis).
hse_gene_data = np.stack(hse_gene_data, axis=0)
print(hse_gene_data.shape)
the_gene_data = np.nanmean(hse_gene_data, axis=-1)
# Average (not sum) across genes so the values match the '(mean)' axis label.
the_gene_data = np.nanmean(the_gene_data, axis=0)
xlabel_name = 'HSE_GENES (mean)'
print(the_gene_data.shape)

In [None]:
#Gradients versus Gene Expression data at the Parcel Level
# For each hemisphere, draw one scatter panel (with marginal histograms) per
# gradient data type: gene expression on x, gradient score on y.
import seaborn as sns
from scipy import stats
import matplotlib.pyplot as plt

# x-axis variable: SST minus PVALB expression, averaged over the last axis.
# To plot a single gene instead, set xlabel_name to the gene symbol and use
# np.nanmean(get_gene_expression(expression, xlabel_name), axis=-1).
# (The HSE composite that used to be recomputed here was overwritten by this
# assignment before it was ever used, so it has been dropped.)
the_gene_data = np.nanmean(get_gene_expression(expression, "SST")-get_gene_expression(expression, "PVALB"), axis=-1)
xlabel_name = 'SST - PVALB'

# Configuration
nPC = 0# 0 is the first PC
font_size = 24
tick_size = 20


def get_title_name(data_type):
    """Return the label for ``data_type`` (uses the global ``nPC``); '' if unknown."""
    titles = {
        'diff': f'G{nPC+1}(Symmetry)',
        'ratio_supra': f'G{nPC+1}(Supra/Total Ratio)',
        'ratio_infra': f'G{nPC+1}(Infra/Total Ratio)', 
        'relative': f'G{nPC+1}(Relative)',
        'he': f'G{nPC+1}(Hurst Exponent)',
        'supra': f'G{nPC+1}(Supra Thickness)',
        'infra': f'G{nPC+1}(Infra Thickness)',
        'total': f'G{nPC+1}(Total Thickness)',
        'total_hcp': f'G{nPC+1}(Total Thickness(HCP))',
    }
    return titles.get(data_type, '')


def get_data_for_hemisphere(hemi, data_type_ref_x, gradient_data, data_type, nPC, atlas, atlas_to_network):
    """Return (gene values, gradient column, per-parcel colours) for one hemisphere.

    Args:
        hemi: 'lh' selects the first half of the global ``the_gene_data``;
            any other value selects the second half.
        data_type_ref_x: Unused; kept only so the signature stays unchanged.
        gradient_data: Mapping from data type to a 2-D gradient array.
        data_type: Key into ``gradient_data``.
        nPC: Column of the gradient array to return.
        atlas: Integer label per vertex; 0 is excluded from the parcel range.
        atlas_to_network: One colour per vertex, aligned with ``atlas``.

    Returns:
        gradient_ref: Gene values for the hemisphere with NaNs replaced by 0.
        gradient: ``gradient_data[data_type][:, nPC]``.
        data_label_color: One colour tuple per label from the smallest to the
            largest non-zero label, each the channel-wise median of that
            label's vertex colours.
    """
    gradient = gradient_data[data_type][:,nPC]

    # Split the gene vector in half instead of hard-coding 200 parcels.
    half = the_gene_data.shape[0] // 2
    gene_half = the_gene_data[:half] if hemi == 'lh' else the_gene_data[half:]
    gradient_ref = np.nan_to_num(gene_half)

    vertex_colors = np.array(atlas_to_network)
    atlas_min = np.min(atlas[atlas != 0])
    atlas_max = np.max(atlas[atlas != 0])
    data_label_color = [
        tuple(np.median(vertex_colors[atlas == parcel,:],axis=0).tolist())
        for parcel in range(atlas_min,atlas_max+1)
    ]
    return gradient_ref, gradient, data_label_color


for hemi in ['lh','rh']:

    if hemi == 'lh':
        gradient_data = lh_data_all_unmapped
        atlas = atlas_data_lh
    else:
        gradient_data = rh_data_all_unmapped
        atlas = atlas_data_rh

    data_types = [key for key in gradient_data.keys() if key != 'diff']
    n_plots = len(data_types)

    # Plot settings
    sns.set_style("white")
    sns.set_context("paper", font_scale=1.2)

    # Each data type gets 3 grid columns: 2 for scatter + top histogram, 1 for
    # the right histogram.
    fig = plt.figure(figsize=(8*n_plots, 8))
    gs = fig.add_gridspec(3, 3*n_plots, hspace=0.1, wspace=0.1)

    # Per-vertex colour lookup; depends only on the hemisphere, so build it once.
    atlas_to_network = [yeo_network_colors[label] for label in atlas]

    # Plot each data type
    for plot_idx, data_type in enumerate(data_types):
        # data_type_ref_x is not used by the helper and its only definition is
        # commented out, so pass None instead of an undefined name.
        gradient_ref, gradient, data_label_color = get_data_for_hemisphere(
            hemi, None, gradient_data, data_type, nPC, atlas, atlas_to_network
        )
        
        # Calculate correlation
        valid_mask = ~np.isnan(gradient_ref) & ~np.isnan(gradient)
        r, p = stats.spearmanr(gradient_ref[valid_mask], gradient[valid_mask])
        print(f"{data_type}: r={r:.2f}, p={p:.2e}")
        
        # Plot layout
        col_start = 3 * plot_idx
        col_end = col_start + 2
        
        # Main scatter plot
        ax_scatter = fig.add_subplot(gs[1:, col_start:col_end])
        sns.scatterplot(x=gradient_ref, y=gradient, alpha=1, s=150,
                        c=data_label_color,edgecolor='black', ax=ax_scatter)
        
        # Regression line: fit on the pairs where BOTH values are valid. The
        # previous code masked x and y independently, which misaligns (or
        # length-mismatches) the pairs as soon as only one side has a NaN.
        x_range = np.array([gradient_ref.min(), gradient_ref.max()])
        x_extended = np.array([x_range[0]*3.5, x_range[1]*3.5]) 
        slope, intercept = np.polyfit(gradient_ref[valid_mask], gradient[valid_mask], 1)
        y_extended = slope * x_extended + intercept
        ax_scatter.plot(x_extended, y_extended, color='black', linestyle='--', linewidth=4)
        ax_scatter.text(0.05, 0.95, f'r={r:.2f}, p={p:.2e}',
                        transform=ax_scatter.transAxes, va='top',
                        fontsize=font_size, weight='bold', color='black')
        
        # Centre both axes on the mean with a +/- 3.5 SD window
        x_mean = np.nanmean(gradient_ref)
        x_std = np.nanstd(gradient_ref)
        x_window = 3.5 * x_std
        x_lim = (x_mean - x_window, x_mean + x_window)
        ax_scatter.set_xlim(x_lim)
        
        y_mean = np.nanmean(gradient)
        y_std = np.nanstd(gradient)
        y_window = 3.5 * y_std
        y_lim = (y_mean - y_window, y_mean + y_window)
        ax_scatter.set_ylim(y_lim)
        
        ax_scatter.set_ylabel(f'{get_title_name(data_type)}', 
                            fontsize=font_size, labelpad=12, weight='bold')
        ax_scatter.set_xlabel(xlabel_name, fontsize=font_size, labelpad=12, weight='bold')
        ax_scatter.set_title(get_title_name(data_type), fontsize=font_size+2, pad=20, weight='bold')
        ax_scatter.tick_params(axis='both', which='major', labelsize=tick_size)
        ax_scatter.grid(True, linestyle='--', alpha=0)
        
        for spine in ax_scatter.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(2.0)
        
        # Top histogram - align with scatter plot x-axis
        ax_histx = fig.add_subplot(gs[0, col_start:col_end])
        sns.histplot(data=gradient_ref, bins=50, kde=True, color='#8E44AD',
                    ax=ax_histx, stat='density', alpha=1)
        ax_histx.set_xlim(x_lim)
        ylim = ax_histx.get_ylim()
        # Triple the density axis so the bars fill only the lower third.
        ax_histx.set_ylim(0, ylim[1]*3)
        ax_histx.set(xlabel='', ylabel='')
        ax_histx.set_yticks([])
        ax_histx.tick_params(labelbottom=False, labelsize=tick_size)
        for spine in ax_histx.spines.values():
            spine.set_visible(False)
        
        # Right histogram - align with scatter plot y-axis
        ax_histy = fig.add_subplot(gs[1:, col_end])
        sns.histplot(y=gradient, bins=50, kde=True, color='#8E44AD',
                    ax=ax_histy, stat='density', alpha=1)
        ax_histy.set_ylim(y_lim)
        xlim = ax_histy.get_xlim()
        ax_histy.set_xlim(0, xlim[1]*3)
        ax_histy.set(xlabel='', ylabel='')
        ax_histy.set_xticks([])
        ax_histy.tick_params(labelleft=False, labelsize=tick_size)
        for spine in ax_histy.spines.values():
            spine.set_visible(False)

    plt.tight_layout()
    # One file per hemisphere; a single shared filename let the right
    # hemisphere overwrite the left-hemisphere figure.
    plt.savefig(f'correlation_plot_{hemi}.png', dpi=300, bbox_inches='tight',
                facecolor='white', edgecolor='none')
    plt.show()


In [None]:
# Scatter each gene's expression against the MEG exponent for one hemisphere.
# - One row of subplots is created in advance, one panel per entry in GENES2PLOT.
# - Each panel is drawn by scatter_with_histograms on the axis passed via `ax`.
# - A tuple entry (gene1, gene2) is plotted as the difference gene1 - gene2.

import pickle

import seaborn as sns
from scipy import stats

GENES2PLOT = ['LAMP5','SNCG','VIP','SST','PVALB']
#GENES2PLOT = [('SST', 'PVALB'),('VIP', 'SST'),('VIP','PVALB')]


hemi = 'lh'

exp_type = 'dfa'#'ae' for aperiodic exponent or 'dfa' detrended fluctuation analysis

if exp_type == 'ae':
    YLABEL_NAME = 'Aperiodic Exponent'
    exp_data_path = '/Users/dennis.jungchildmind.org/Desktop/MEG/112325_fooof_exponent/fooof_brainic_parcelonly.pkl'
elif exp_type == 'dfa':
    YLABEL_NAME = 'DFA exponent'
    # Load DFA (Y-axis) data for this frequency
    freq_range = [8.0, 15.0]
    exp_data_path = f'/Users/dennis.jungchildmind.org/Desktop/MEG/112225_dfa_outputs/dfa_restin_f{freq_range[0]}_to_{freq_range[1]}_parcelonly.pkl'
else:
    # Fail with a clear message instead of a NameError on exp_data_path below.
    raise ValueError(f"Unknown exp_type {exp_type!r}; expected 'ae' or 'dfa'")

# Context manager so the pickle file handle is closed after reading.
with open(exp_data_path, 'rb') as fh:
    exp_data = pickle.load(fh)
# Average over the first axis, then split into two halves (first -> LH, second -> RH).
exp_data_mean = np.nanmean(exp_data, axis=0)
exp_data_mean_lh = exp_data_mean[0:int(num_parcels/2)]
exp_data_mean_rh = exp_data_mean[int(num_parcels/2):]


def _mean_expression(gene_name):
    """Return the expression of ``gene_name`` averaged over the last axis."""
    return np.nanmean(get_gene_expression(expression, gene_name), axis=-1)


# --- Handle tuple for subtraction, otherwise just plot each gene in list ---
processed_genes = []
xlabel_names = []

for gene in GENES2PLOT:
    if isinstance(gene, tuple):
        # Difference between first and second entry (gene1 - gene2)
        processed_genes.append(_mean_expression(gene[0]) - _mean_expression(gene[1]))
        xlabel_names.append(f"{gene[0]} - {gene[1]}")
    else:
        processed_genes.append(_mean_expression(gene))
        xlabel_names.append(gene)

num_genes_to_plot = len(processed_genes)
fig, main_axes = plt.subplots(1, num_genes_to_plot, figsize=(num_genes_to_plot**2, num_genes_to_plot))

if num_genes_to_plot == 1:
    main_axes = [main_axes]  # Ensure iterable if single plot

# Hemisphere-dependent inputs are the same for every gene, so select them once.
# Colour convention: entries 1..200 for lh, 201 onwards for rh.
if hemi == 'lh':
    exponent_parcellated = exp_data_mean_lh
    color_subset = yeo_network_colors[1:201]
else:
    exponent_parcellated = exp_data_mean_rh
    color_subset = yeo_network_colors[201:]

for idx, (the_gene_data, xlabel_name) in enumerate(zip(processed_genes, xlabel_names)):
    ax = main_axes[idx]
    ylabel_name = YLABEL_NAME
    # Split the gene vector in half rather than hard-coding 200 parcels.
    half = the_gene_data.shape[0] // 2
    the_gene_data = the_gene_data[:half] if hemi == 'lh' else the_gene_data[half:]

    # Keep only the entries where both the gene value and the exponent are present.
    valid_mask = ~np.isnan(the_gene_data) & ~np.isnan(exponent_parcellated)
    x = the_gene_data[valid_mask]
    y = exponent_parcellated[valid_mask]
    # Filter colours with the same mask so they stay aligned with x and y.
    data_label_color = [c for c, k in zip(color_subset, valid_mask) if k]

    _, _, stats_dict = scatter_with_histograms(
        x, y, xlabel_name, ylabel_name,
        marker_size=len(GENES2PLOT)*10, data_label_color=data_label_color, ax=ax
    )

plt.tight_layout()
plt.show()

In [None]:
# Spearman correlation of every gene with the parcellated MEG exponent, per
# hemisphere. Results are cached in .npy files and reloaded when present.
import os

from scipy import stats

corr_file = 'meg_exp_correlations.npy'
pval_file = 'meg_exp_pvalues.npy'
genes_file = 'meg_exp_gene_names.npy'

# Try to load precomputed data if available
if os.path.exists(corr_file) and os.path.exists(pval_file) and os.path.exists(genes_file):
    correlations = np.load(corr_file)
    p_values = np.load(pval_file)
    ALL_GENES = np.load(genes_file, allow_pickle=True).tolist()
    n_genes = len(ALL_GENES)
    print("Loaded correlation results from files.")
else:
    ALL_GENES = expression['9861'].columns.tolist()
    n_genes = len(ALL_GENES)

    # Arrays to store correlation coefficients and p-values
    correlations = np.zeros((n_genes, 2)) # 2 for lh and rh
    p_values = np.zeros((n_genes, 2))

    # The parcellated exponent does not depend on the gene, so compute it once
    # per hemisphere instead of once per gene per hemisphere.
    half_parcels = int(num_parcels/2)
    exp_parcellated_by_hemi = {
        'lh': parcellate_data(exp_data_lh, atlas_data_lh),
        'rh': parcellate_data(exp_data_rh, atlas_data_rh),
    }

    # Loop through each gene
    for gene_idx, gene in enumerate(ALL_GENES):
        # Donor-averaged expression for this gene
        gene_exp_data = get_gene_expression(expression, gene)
        the_gene_data = np.nanmean(gene_exp_data, axis=-1)

        # Process each hemisphere
        for hemi_idx, hemi in enumerate(['lh', 'rh']):
            exp_data = exp_parcellated_by_hemi[hemi]
            # NOTE(review): nan_to_num turns missing expression into 0, so the
            # isnan check on gene_ref below can never exclude a parcel -
            # confirm that treating missing parcels as 0 is intended.
            if hemi == 'lh':
                gene_ref = np.nan_to_num(the_gene_data[:half_parcels])
            else:
                gene_ref = np.nan_to_num(the_gene_data[half_parcels:])

            # Calculate correlation
            valid_mask = ~np.isnan(gene_ref) & ~np.isnan(exp_data)
            r, p = stats.spearmanr(gene_ref[valid_mask], exp_data[valid_mask])

            # Store results
            correlations[gene_idx, hemi_idx] = r
            p_values[gene_idx, hemi_idx] = p

    # Save results
    np.save(corr_file, correlations)
    np.save(pval_file, p_values)
    np.save(genes_file, np.array(ALL_GENES))
    print("Correlation results computed and saved to files.")



In [None]:
# Show where one target gene's LH correlation falls within the distribution of
# correlations over all genes, then FDR-correct the LH p-values.
# Here correlations / p_values are indexed as (gene, hemisphere), column 0 = lh.
from scipy.stats import false_discovery_control
from scipy.stats import gaussian_kde

TARGET_GENE = 'LAMP5'
target_gene_index = ALL_GENES.index(TARGET_GENE)
target_gene_corr = correlations[target_gene_index, 0]  # LH correlation for the target gene

# Distribution of LH correlations across all genes
corr_data = correlations[:,0]
corr_data = corr_data[~np.isnan(corr_data)]  # remove NaNs if any
# Both the histogram and the marker are shifted by the mean of the distribution.
corr_mean = np.mean(corr_data)

fig = plt.figure(figsize=(4,3))
ax = fig.add_subplot(111)
ax.hist(corr_data - corr_mean, bins=30, color='black', alpha=0.8)
# Label the marker so ax.legend() below has an entry to display.
ax.axvline(x=target_gene_corr - corr_mean, color='red', linestyle='-', linewidth=2,
           label=TARGET_GENE)
ax.set_xlabel("Spatial Correlation (r)", fontsize=18)
ax.set_xlim(-1,1)
ax.legend(fontsize=16)
# Make tick labels bigger
ax.tick_params(axis='both', which='major', labelsize=16)
ax.tick_params(axis='both', which='minor', labelsize=14)
# Remove top and right spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.tight_layout()
plt.show()

alpha = 0.05
# Benjamini-Hochberg correction over genes, using the LH column of p_values.
corrected_pvals = false_discovery_control(p_values[:,0], method='bh')

significant_indices = np.where(corrected_pvals < alpha)[0]
print(significant_indices)
print(p_values[significant_indices])
print(correlations[significant_indices])


In [None]:
import numpy as np
from scipy.stats import spearmanr
PC_INDEX = 0#of cortical thickness gradient
LAYER_TYPE = 'ratio_supra'
GENE_AXIS_LABEL = "SST-PVALB"
Y_AXIS_LABEL = "DFA Exponent"
# SST minus PVALB expression, averaged over the last axis.
the_gene_data = np.nanmean(get_gene_expression(expression, "SST")-get_gene_expression(expression, "PVALB"), axis=-1)

#GENE_AXIS_LABEL = "PVALB"
#the_gene_data = np.nanmean(get_gene_expression(expression, GENE_AXIS_LABEL), axis=-1)
# First half -> left hemisphere, second half -> right hemisphere. The second
# assignment used to repeat the LH line verbatim, so no RH split existed.
the_gene_data_lh = the_gene_data[:int(num_parcels/2)]
the_gene_data_rh = the_gene_data[int(num_parcels/2):]

'''
#MEG sensor space 
exp_data_path = '/Users/dennis.jungchildmind.org/Downloads/exponent_result1.shape.gii'
#exp_data_path = '/Users/dennis.jungchildmind.org/Downloads/offset_result1.shape.gii'
exp_data = nib.load(exp_data_path)
exp_data = exp_data.darrays[2]#0 is vertices, 1 is faces, 2 is exponent values

exp_data_lh = exp_data.data[0:32492]
exp_data_rh = exp_data.data[32492:]

exp_dat_lh_parc = parcellate_data(exp_data_lh, atlas_data_lh)
exp_dat_lh_parc = parcellate_data(exp_data_lh, atlas_data_lh)
'''


import pickle

import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

# Frequency bands (Hz); each one maps to a DFA result file on disk.
freq_ranges = [
    [0.1, 4.0],
    [4.0, 8.0],
    [8.0, 15.0],
    [15.0, 30.0],
    [30.0, 60.0]
]

font_size = 17

# One dict of correlation statistics per band, appended in loop order.
corr_stats = []

# One row of three panels per frequency band.
fig, axes = plt.subplots(len(freq_ranges), 3, figsize=(3*3, 3*len(freq_ranges)), squeeze=False)


def _style_gradient_panel(ax, ylabel, x_limits, stat_text, title, y_limits=None):
    """Apply the shared labels, limits, tick/spine styling, statistic text and title to one panel.

    ax        -- Axes that already holds the scatter.
    ylabel    -- y-axis label text.
    x_limits  -- (low, high) passed to set_xlim.
    stat_text -- string drawn in the top-right corner (axes coordinates).
    title     -- panel title.
    y_limits  -- optional (low, high) passed to set_ylim; left untouched when None.
    """
    ax.set_xlabel(f'Gradient {PC_INDEX+1} Scores', fontsize=16, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=16, fontweight='bold')
    ax.set_xlim(x_limits)
    if y_limits is not None:
        ax.set_ylim(*y_limits)
    ax.tick_params(axis='both', which='major', labelsize=16)
    ax.tick_params(axis='both', which='minor', labelsize=14)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.text(
        0.99, 0.99, stat_text,
        ha="right", va="top", transform=ax.transAxes,
        fontsize=font_size*0.7, weight='bold', color='black'
    )
    ax.set_box_aspect(1)
    ax.set_title(title, fontsize=15)


# --- Gradient axis and plotting window ---
# These do not depend on the frequency band, so they are computed once.
y = lh_data_all_unmapped[LAYER_TYPE][:, PC_INDEX]
y_mean = np.nanmean(y)
y_std = np.nanstd(y)
y_window = 3.5 * y_std
y_lim = (y_mean - y_window, y_mean + y_window)
n_half = int(num_parcels/2)

for i, freq_range in enumerate(freq_ranges):
    # --- Load E/I data for this freq_range ---
    dfa_data_path = f'/Users/dennis.jungchildmind.org/Desktop/MEG/112225_dfa_outputs/dfa_restin_f{freq_range[0]}_to_{freq_range[1]}_parcelonly.pkl'
    # NOTE: pickle can execute arbitrary code on load; only read files produced locally.
    with open(dfa_data_path, 'rb') as dfa_file:
        dfa_data = pickle.load(dfa_file)
    dfa_data_mean = np.nanmean(dfa_data, axis=0)
    exp_dat_lh_parc = dfa_data_mean[0:n_half]
    exp_dat_rh_parc = dfa_data_mean[n_half:]

    # --- Compose predictors: column 0 = DFA exponent, column 1 = gene axis ---
    X = np.column_stack([exp_dat_lh_parc, the_gene_data_lh])

    # Optionally check/print data for this freq_range.
    # NaN-aware min/max: NaNs are only dropped further down, and a plain min()/max()
    # would print nan as soon as one parcel is missing.
    print(f"\n[Freq {freq_range[0]}-{freq_range[1]} Hz]")
    print(f"Original gradient shape: {lh_data_all_unmapped[LAYER_TYPE].shape}")
    print(f"X (LIMITED) shape: {X.shape}")
    print(f"y shape: {y.shape}")
    print(f"Aperiodic Exponent min/max after clip: {np.nanmin(X[:, 0]):.3f}, {np.nanmax(X[:, 0]):.3f}")
    print(f"Gene SST-PVALB min/max after clip: {np.nanmin(X[:, 1]):.3f}, {np.nanmax(X[:, 1]):.3f}")

    # --- Remove NaNs: keep parcels with both predictors and the gradient present ---
    index = ~np.isnan(X).any(axis=1) & ~np.isnan(y)
    X_clean = X[index]
    y_clean = y[index]

    print(f"N regions after removing NaNs: {len(y_clean)}")

    # --- Correlations (Step 1) ---
    r_EI, p_EI = spearmanr(X_clean[:, 0], y_clean)
    r_gene, p_gene = spearmanr(X_clean[:, 1], y_clean)

    # --- Partial "multiple regression" (Step 2) ---
    # Each predictor is residualised on the other with OLS, then the residual is
    # rank-correlated with the (non-residualised) gradient.
    model_ei_gene = LinearRegression().fit(X_clean[:, [1]], X_clean[:, 0])
    ei_resid = X_clean[:, 0] - model_ei_gene.predict(X_clean[:, [1]])
    model_gene_ei = LinearRegression().fit(X_clean[:, [0]], X_clean[:, 1])
    gene_resid = X_clean[:, 1] - model_gene_ei.predict(X_clean[:, [0]])
    r_partial_ei, p_partial_ei = spearmanr(ei_resid, y_clean)
    r_partial_gene, p_partial_gene = spearmanr(gene_resid, y_clean)

    # Print to console for each band:
    print(f"\nUnivariate correlations (Spearman):")
    print(f"  E/I:  r={r_EI:.3f}, p={p_EI:.4f}")
    print(f"  Gene: r={r_gene:.3f}, p={p_gene:.4f}")
    print(f"\nPartial Spearman correlations:")
    print(f"  Gradient vs E/I (controlling for gene):    r={r_partial_ei:.3f}, p={p_partial_ei:.4f}")
    print(f"  Gradient vs Gene (controlling for E/I):    r={r_partial_gene:.3f}, p={p_partial_gene:.4f}")
    print("\nVariance explained (approximated by r²):")
    print(f"  E/I only:      r²={r_EI**2:.3f}")
    print(f"  Gene only:     r²={r_gene**2:.3f}")
    print(f"  E/I|Gene:      r²={r_partial_ei**2:.3f}")
    print(f"  Gene|E/I:      r²={r_partial_gene**2:.3f}")

    # --- Store stats if desired ---
    corr_stats.append(dict(
        freq_range=f"{freq_range[0]}-{freq_range[1]}Hz",
        r_EI=r_EI, p_EI=p_EI,
        r_gene=r_gene, p_gene=p_gene,
        r_partial_ei=r_partial_ei, p_partial_ei=p_partial_ei,
        r_partial_gene=r_partial_gene, p_partial_gene=p_partial_gene
    ))

    # --- Plotting (Step 3) ---
    ax_row = axes[i]

    # E/I vs gradient
    ax_row[0].scatter(y_clean, X_clean[:, 0], alpha=0.5, s=30)
    _style_gradient_panel(
        ax_row[0], Y_AXIS_LABEL, y_lim,
        f"r={r_EI:.2f}, p={p_EI:.2e}",
        f'{freq_range[0]}-{freq_range[1]}Hz: Grad vs. E/I',
    )

    # Gene vs gradient (fixed y-range of -3..3)
    ax_row[1].scatter(y_clean, X_clean[:, 1], alpha=0.5, s=30, color='orange')
    _style_gradient_panel(
        ax_row[1], GENE_AXIS_LABEL, y_lim,
        f"r={r_gene:.2f}, p={p_gene:.2e}",
        'Grad vs. Gene',
        y_limits=(-3, 3),
    )

    # Residualized/pseudo-multivariate fit
    ax_row[2].scatter(y_clean, ei_resid, alpha=0.5, s=30, color='green')
    _style_gradient_panel(
        ax_row[2], f'{Y_AXIS_LABEL}\n(resid E/I | Gene)', y_lim,
        f"r={r_partial_ei:.2f}, p={p_partial_ei:.2e}",
        'Grad vs. E/I (control gene)',
    )

plt.tight_layout()
plt.show()

In [None]:
"""
Organized analysis:
1. Prepare gene data matrix for defined genes & combinations.
2. Prepare DFA feature matrix for all frequency ranges.
3. Prepare cortical gradient matrices for all layer types.
4. Concatenate gene and DFA features.
5. Compute Spearman correlations with gradients.
6. Plot and display.
"""

# ==== Imports and Config ====
from scipy.stats import spearmanr
from statsmodels.stats.multitest import multipletests
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.colors as mcolors

# Significance level handed to the FDR correction further down.
ALPHA = 0.0001

# ---- CONFIGURATION ----
# Genes: an entry is either a single gene name or a tuple of gene names.
TARGET_GENES = ['PVALB', 'SST', 'VIP','LAMP5','SNCG']
#TARGET_GENES = sorted((list(expression['9861'].keys())))


def _label_from_gene(g):
    """Return the display label for one TARGET_GENES entry.

    A plain entry is returned unchanged; a tuple becomes its members joined by '-'.
    """
    if not isinstance(g, tuple):
        return g
    return f"{g[0]}-{g[1]}" if len(g) == 2 else "-".join(g)


TARGET_GENES_LABELS = list(map(_label_from_gene, TARGET_GENES))

# DFA exponents: one [low, high] band in Hz per label, in the same order.
TARGET_FREQ_RANGES = [[0.1, 4.0], [4.0, 8.0], [8.0, 15.0], [15.0, 30.0], [30.0, 60.0]]
TARGET_FREQ_LABELS = ['DF(Delta)', 'DF(Theta)', 'DF(Alpha)', 'DF(Beta)', 'DF(Gamma)']

# RRS features: same bands, separate row labels.
TARGET_RS_LABELS = ['RS(Delta)', 'RS(Theta)', 'RS(Alpha)', 'RS(Beta)', 'RS(Gamma)']

# Aperiodic exponent rows
TARGET_APERIOD_EXPONENTS = ['AE(brain)']#, 'AE(all)', 'AO(brain)', 'AO(all)']

# Gradients: dictionary keys and the matching tick labels, in the same order.
TARGET_LAYER_TYPES = ['total_hcp','total','supra', 'infra', 'ratio_supra', 'ratio_infra']
TARGET_LAYER_TYPES_LABELS = ['HCP: Total','Ex: Total', 'Ex: Supra', 'Ex: Infra', 'Ex: Supra/Total', 'Ex: Infra/Total']

PC_INDEX = 1   # Principal component for gradients
hemi = 'lh'    # Hemisphere selection; assumes 200 parcels per hemisphere

# ==== 1. PREPARE GENE DATA MATRIX ====
def get_parcel_gene_data(gene, hemi):
    """Return per-parcel gene expression for one hemisphere together with its label.

    Args:
        gene: a single gene name, or a tuple of gene names.
            * 2-tuple  -> first minus second, label "A-B".
            * >2 genes -> sum of all genes, label is the names joined by "-".
        hemi: 'lh' selects the first 200 parcels; any other value selects the rest.

    Returns:
        (vals, label): 1-D array of values for the selected hemisphere and the label string.

    Raises:
        ValueError: if ``gene`` is a tuple with fewer than two entries
            (previously this fell through to an UnboundLocalError).

    Each gene is reduced with a NaN-aware mean over the last axis of
    ``get_gene_expression(expression, name)``.
    """
    def _mean_expression(name):
        return np.nanmean(get_gene_expression(expression, name), axis=-1)

    if isinstance(gene, tuple):
        if len(gene) < 2:
            raise ValueError(
                f"gene tuple must contain at least two gene names, got {gene!r}"
            )
        per_gene = [_mean_expression(name) for name in gene]
        if len(gene) == 2:
            vals = per_gene[0] - per_gene[1]
            label = f"{gene[0]}-{gene[1]}"
        else:
            # Accumulate into a zero array shaped like the first gene's values.
            vals = np.zeros_like(per_gene[0])
            for gene_vals in per_gene:
                vals += gene_vals
            label = "-".join(gene)
    else:
        vals = _mean_expression(gene)
        label = gene

    # Split by hemisphere.
    # NOTE(review): the split point is hard-coded to 200 while the MEG loaders in this
    # notebook use int(num_parcels/2); confirm both agree for the atlas in use.
    if hemi == 'lh':
        vals = vals[:200]
    else:
        vals = vals[200:]
    return vals, label


# Collect one column per target gene (values) and the matching labels.
gene_data_matrix = []
gene_data_labels = []
for gene in TARGET_GENES:
    vals, label = get_parcel_gene_data(gene, hemi)
    gene_data_matrix.append(vals)
    gene_data_labels.append(label)
gene_data_matrix = np.stack(gene_data_matrix, axis=1)  # shape: (n_parcels, n_genes)

# ==== 2. PREPARE DFA FEATURE MATRIX ====
def _load_hemi_parcel_mean(pkl_path, hemi):
    """Load a pickled array, take the NaN-aware mean over axis 0 and return one hemisphere.

    'lh' returns the first int(num_parcels/2) entries; any other value returns the rest.
    The file is opened in a ``with`` block so the handle is always closed.
    """
    # NOTE: pickle can execute arbitrary code on load; only read files produced locally.
    with open(pkl_path, 'rb') as pkl_file:
        data = pickle.load(pkl_file)
    mean_vals = np.nanmean(data, axis=0)
    n_half = int(num_parcels/2)
    if hemi == 'lh':
        return mean_vals[0:n_half]
    return mean_vals[n_half:]


def get_dfa_matrix(freq_ranges, hemi):
    """Return the DFA feature matrix (parcels x frequency bands) for one hemisphere.

    freq_ranges: iterable of [low, high] bands; each one selects a result file by name.
    hemi: 'lh' or anything else for the second half of the parcels.
    """
    dfa_matrix = [
        _load_hemi_parcel_mean(
            f'/Users/dennis.jungchildmind.org/Desktop/MEG/112225_dfa_outputs/dfa_restin_f{freq_range[0]}_to_{freq_range[1]}_parcelonly.pkl',
            hemi,
        )
        for freq_range in freq_ranges
    ]
    return np.array(dfa_matrix).T  # parcels x freq


dfa_matrix_lh = get_dfa_matrix(TARGET_FREQ_RANGES, "lh")
dfa_matrix_rh = get_dfa_matrix(TARGET_FREQ_RANGES, "rh")

# ==== 2. PREPARE RRS FEATURE MATRIX ====
def get_rrs_matrix(freq_ranges, hemi):
    """Return the RRS feature matrix (parcels x frequency bands) for one hemisphere.

    Same contract as get_dfa_matrix, reading the rrs_restin_* result files instead.
    """
    rrs_matrix = [
        _load_hemi_parcel_mean(
            f'/Users/dennis.jungchildmind.org/Desktop/MEG/112625_rrs_outputs/rrs_restin_f{freq_range[0]}_to_{freq_range[1]}_parcelonly.pkl',
            hemi,
        )
        for freq_range in freq_ranges
    ]
    return np.array(rrs_matrix).T  # parcels x freq


rrs_matrix_lh = get_rrs_matrix(TARGET_FREQ_RANGES, "lh")
rrs_matrix_rh = get_rrs_matrix(TARGET_FREQ_RANGES, "rh")

# ==== 2.5. PREPARE APERIODIC EXPONENT MATRIX ====
def get_aperiodic_exponent_matrix(hemi, data_type):
    """Return one fooof result as a single-column matrix (parcels x 1) for one hemisphere.

    data_type is inserted into the file name (fooof_{data_type}_parcelonly.pkl).
    """
    ae_data_path = f'/Users/dennis.jungchildmind.org/Desktop/MEG/112325_fooof_exponent/fooof_{data_type}_parcelonly.pkl'
    ae_matrix = [_load_hemi_parcel_mean(ae_data_path, hemi)]
    return np.array(ae_matrix).T  # parcels x 1


ae_brainic_matrix_lh = get_aperiodic_exponent_matrix("lh",data_type='brainic')
ae_brainic_matrix_rh = get_aperiodic_exponent_matrix("rh",data_type='brainic')
ae_allic_matrix_lh = get_aperiodic_exponent_matrix("lh",data_type='allic')
ae_allic_matrix_rh = get_aperiodic_exponent_matrix("rh",data_type='allic')

ao_brainic_matrix_lh = get_aperiodic_exponent_matrix("lh",data_type='offset_brainic')
ao_brainic_matrix_rh = get_aperiodic_exponent_matrix("rh",data_type='offset_brainic')
ao_allic_matrix_lh = get_aperiodic_exponent_matrix("lh",data_type='offset_allic')
ao_allic_matrix_rh = get_aperiodic_exponent_matrix("rh",data_type='offset_allic')


# ==== 3. PREPARE GRADIENT MATRIX ====
def get_gradient_matrix(layer_types, PC_INDEX, hemi):
    """Stack one gradient component per layer type into an (n_parcels, n_layers) array.

    layer_types: keys looked up in the per-hemisphere gradient dictionary.
    PC_INDEX: column taken from each entry.
    hemi: 'lh' reads lh_data_all_unmapped; any other value reads rh_data_all_unmapped.
    """
    if hemi == 'lh':
        gradients_by_layer = lh_data_all_unmapped
    else:
        gradients_by_layer = rh_data_all_unmapped
    columns = [gradients_by_layer[name][:, PC_INDEX] for name in layer_types]
    return np.array(columns).T  # shape: (n_parcels, n_layers)

# Gradient component PC_INDEX for every target layer type, one matrix per hemisphere.
gradient_matrix_lh = get_gradient_matrix(TARGET_LAYER_TYPES, PC_INDEX, "lh")
gradient_matrix_rh = get_gradient_matrix(TARGET_LAYER_TYPES, PC_INDEX, "rh")
print("Gradient matrix (lh) shape:", gradient_matrix_lh.shape)

# ==== 4. CONCATENATE GENE AND DFA FEATURES ====
# Column order: genes, DFA bands, RRS bands, aperiodic exponent (brainic).
# The row labels used for plotting are concatenated in this same order.
gene_dfa_matrix_lh = np.concatenate([gene_data_matrix, dfa_matrix_lh, rrs_matrix_lh, ae_brainic_matrix_lh], axis=1)
# NOTE(review): gene_data_matrix was built once for `hemi` ('lh' in this cell), so the rh
# matrix pairs that hemisphere's gene columns with rh MEG features - confirm this is intended.
gene_dfa_matrix_rh = np.concatenate([gene_data_matrix, dfa_matrix_rh, rrs_matrix_rh, ae_brainic_matrix_rh], axis=1)
print("Gene/DFA matrix (lh) shape:", gene_dfa_matrix_lh.shape)
print("Gene/DFA matrix (rh) shape:", gene_dfa_matrix_rh.shape)

# ==== 5. COMPUTE SPEARMAN'S CORRELATION ====
def compute_spearman_matrix(X, Y):
    """
    Spearman correlation of every column of X with every column of Y.

    X: shape (n_parcels, n_x_features)
    Y: shape (n_parcels, n_y_features)
    Returns: (corrs, ps), each of shape (n_x_features, n_y_features).
    NaNs are handled by spearmanr's nan_policy='omit'.
    """
    out_shape = (X.shape[1], Y.shape[1])
    corrs = np.zeros(out_shape)
    ps = np.zeros(out_shape)
    # Row-major walk over every (X column, Y column) pair.
    for x_col, y_col in np.ndindex(*out_shape):
        rho, p_value = spearmanr(X[:, x_col], Y[:, y_col], nan_policy='omit')
        corrs[x_col, y_col] = rho
        ps[x_col, y_col] = p_value
    return corrs, ps

# Compute Spearman's correlation and p-values
# Rows follow the columns of gene_dfa_matrix_*, columns follow gradient_matrix_* (layer types).
spearman_corrs_lh, spearman_ps_lh = compute_spearman_matrix(gene_dfa_matrix_lh, gradient_matrix_lh)
spearman_corrs_rh, spearman_ps_rh = compute_spearman_matrix(gene_dfa_matrix_rh, gradient_matrix_rh)

# Apply multiple comparisons correction (FDR, Benjamini-Hochberg) across all lh p-values at once:
# the matrix is flattened, corrected as one family, then reshaped back.
pvals_lh_flat = spearman_ps_lh.flatten()
rej_lh, pvals_fdr_lh, _, _ = multipletests(pvals_lh_flat, alpha=ALPHA, method='fdr_bh')
spearman_ps_fdr_lh = pvals_fdr_lh.reshape(spearman_ps_lh.shape)
# NOTE: the first value returned by multipletests ("rej_lh") is True where the null
# hypothesis is rejected at ALPHA, i.e. where the correlation is significant after correction.
# So spearman_signif_lh is True for significant entries and False otherwise.
spearman_signif_lh = rej_lh.reshape(spearman_ps_lh.shape)

# Same for rh (corrected as its own family, separately from lh)
pvals_rh_flat = spearman_ps_rh.flatten()
rej_rh, pvals_fdr_rh, _, _ = multipletests(pvals_rh_flat, alpha=ALPHA, method='fdr_bh')
spearman_ps_fdr_rh = pvals_fdr_rh.reshape(spearman_ps_rh.shape)
spearman_signif_rh = rej_rh.reshape(spearman_ps_rh.shape)

print("LH Spearman correlation matrix shape:", spearman_corrs_lh.shape)
print("RH Spearman correlation matrix shape:", spearman_corrs_rh.shape)

# ==== 6. PLOTTING ====

# ---- Plotting helpers ----
small_font = 12
smaller_font = 10

def plot_fixed_size_imshow(ax, data, vmin=None, vmax=None, cmap='viridis', aspect_ratio=1, fixed_width=3, fixed_height=6, **imshow_kwargs):
    """
    Draw ``data`` with imshow on ``ax`` after resizing the Axes to a fixed size in inches.

    The Axes keeps its current lower-left corner; its width and height are set to
    ``fixed_width`` and ``fixed_height`` inches, expressed as fractions of the figure size.
    The image is drawn with aspect='auto', so it fills that Axes.

    Args:
        ax: Matplotlib Axes.
        data: Image array.
        vmin, vmax: Passed to imshow.
        cmap: Colormap passed to imshow.
        aspect_ratio: Accepted for backward compatibility only; it has no effect
            because the image is always drawn with aspect='auto'.
        fixed_width, fixed_height: Axes size in inches.
        imshow_kwargs: Additional kwargs for imshow (e.g. norm).
    Returns:
        The image object (from imshow)
    """
    fig_width, fig_height = ax.figure.get_size_inches()
    anchor = ax.get_position()
    # The figure spans (0, 0)-(1, 1), so inches are converted to figure fractions:
    # [left, bottom, width, height].
    ax.set_position([
        anchor.x0,
        anchor.y0,
        fixed_width / fig_width,
        fixed_height / fig_height,
    ])
    return ax.imshow(data, vmin=vmin, vmax=vmax, cmap=cmap, aspect='auto', **imshow_kwargs)

# ---- Plot: Spearman correlation matrix ----
# Row labels follow the column order of the concatenated feature matrix;
# column labels are the gradient layer types.
feature_row_labels = TARGET_GENES_LABELS + TARGET_FREQ_LABELS + TARGET_RS_LABELS + TARGET_APERIOD_EXPONENTS
layer_tick_positions = np.arange(len(TARGET_LAYER_TYPES_LABELS))
feature_tick_positions = np.arange(len(feature_row_labels))

fig0 = plt.figure(figsize=(3,3)) # Large enough so Axes can be fixed, but the image will remain fixed size
ax0 = fig0.add_subplot(111)
fixed_im_width = (0.2*len(TARGET_LAYER_TYPES_LABELS)+0.2) # inches
fixed_im_height = 2.2 # inches, taller for more y items

im0 = plot_fixed_size_imshow(
    ax0, spearman_corrs_lh,
    vmin=-0.5, vmax=0.5, cmap='coolwarm',
    fixed_width=fixed_im_width, fixed_height=fixed_im_height,
)
#ax0.set_xlabel('Cortical Gradients', fontsize=small_font)
ax0.set_xticks(layer_tick_positions)
ax0.set_xticklabels(TARGET_LAYER_TYPES_LABELS, rotation=90, fontsize=smaller_font)
ax0.set_yticks(feature_tick_positions)
ax0.set_yticklabels(feature_row_labels, fontsize=smaller_font)
# Narrow colourbar Axes attached to the right of the image.
divider0 = make_axes_locatable(ax0)
cax0 = divider0.append_axes("right", size=0.1, pad=0.08)
cbar0 = plt.colorbar(im0, cax=cax0, label='Spearman Correlation')
cbar0.ax.set_ylabel('Spearman Correlation', rotation=270, labelpad=15, fontsize=small_font)
cbar0.ax.tick_params(labelsize=smaller_font)
plt.show()

# ---- Plot: p-values (FDR) as bins ----
# Bin edges and the label shown for each bin (monochrome for significance).
p_bins = [0, 0.00001, 0.001, 0.01, 1.0001]  # set upper limit as slightly above 1
p_bin_labels = ['<0.00001', '<0.001', '<0.01', 'n.s.']
# Shades of gray: black for most significant, light gray for n.s.
p_colors = ['#222222', '#555555', '#bbbbbb', '#eeeeee']
n_p_bins = len(p_bin_labels)
# digitize returns 1-based bin numbers for values inside the edges; shift to 0-based.
binned_ps = np.digitize(spearman_ps_fdr_lh, p_bins) - 1
cmap = mcolors.ListedColormap(p_colors)
bounds = np.arange(n_p_bins+1)-0.5
norm = mcolors.BoundaryNorm(np.arange(n_p_bins+1), cmap.N)

fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111)
im = plot_fixed_size_imshow(
    ax, binned_ps,
    cmap=cmap, norm=norm,
    fixed_width=fixed_im_width, fixed_height=fixed_im_height,
)
#ax.set_xlabel('Cortical Gradients', fontsize=small_font)
ax.set_xticks(layer_tick_positions)
ax.set_xticklabels(TARGET_LAYER_TYPES_LABELS, rotation=90, fontsize=smaller_font)
ax.set_yticks(feature_tick_positions)
ax.set_yticklabels(feature_row_labels, fontsize=smaller_font)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size=0.1, pad=0.08)
tick_locs = (np.array(p_bins[:-1]) + np.array(p_bins[1:])) / 2
# Discrete colourbar: one segment per bin, ticks placed on the bin indices.
cbar = plt.colorbar(im, cax=cax, boundaries=bounds, values=np.arange(n_p_bins))
cbar.set_ticks(np.arange(n_p_bins))
cbar.set_ticklabels(p_bin_labels)
cbar.ax.set_yticks(np.arange(n_p_bins))
cbar.ax.set_yticklabels(p_bin_labels)
cbar.ax.set_ylabel('p-value', rotation=270, labelpad=15, fontsize=small_font)
cbar.ax.tick_params(labelsize=smaller_font)
plt.show()


In [None]:
# Rows are re-ordered by ascending correlation with the first gradient column;
# the same order is applied to the labels and to the binned p-values below.
sort_order = np.argsort(spearman_corrs_lh[:, 0])
sorted_spearman_corrs_lh = spearman_corrs_lh[sort_order]
ytick_labels = TARGET_GENES_LABELS + TARGET_FREQ_LABELS + TARGET_RS_LABELS + TARGET_APERIOD_EXPONENTS
ytick_labels = np.array(ytick_labels)[sort_order]

# ---- Plot: sorted Spearman correlation matrix ----
fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111)
im0 = plot_fixed_size_imshow(
    ax,
    sorted_spearman_corrs_lh,
    vmin=-0.5, vmax=0.5,
    cmap='coolwarm',
    fixed_width=fixed_im_width,
    fixed_height=fixed_im_height
)
ax.set_xticks(np.arange(len(TARGET_LAYER_TYPES_LABELS)))
ax.set_xticklabels(TARGET_LAYER_TYPES_LABELS, rotation=90, fontsize=smaller_font)
ax.set_yticks(np.arange(len(ytick_labels)))
ax.set_yticklabels(ytick_labels, fontsize=smaller_font)
divider0 = make_axes_locatable(ax)
cax0 = divider0.append_axes("right", size=0.1, pad=0.08)
cbar0 = plt.colorbar(im0, cax=cax0, label='Spearman Correlation')
cbar0.ax.set_ylabel('Spearman Correlation', rotation=270, labelpad=15, fontsize=small_font)
cbar0.ax.tick_params(labelsize=smaller_font)

plt.show()


# ---- Plot: sorted FDR p-values as bins ----
fig = plt.figure(figsize=(7,7))
ax0 = fig.add_subplot(111)

p_bins = [0, 0.0001, 0.001, 0.01, 1.0001]  # set upper limit as slightly above 1
p_bin_labels = ['<0.0001', '<0.001', '<0.01', 'n.s.']
# Shades of gray: black for most significant, light gray for n.s.
p_colors = ['#222222', '#555555', '#bbbbbb', '#eeeeee']
binned_ps = np.digitize(spearman_ps_fdr_lh, p_bins) - 1
cmap = mcolors.ListedColormap(p_colors)
bounds = np.arange(len(p_bin_labels)+1)-0.5
norm = mcolors.BoundaryNorm(np.arange(len(p_bin_labels)+1), cmap.N)
binned_ps = binned_ps[sort_order]

# norm must be passed: without it imshow rescales to the min/max bin present in the
# data, so a bin's colour would change whenever some bin does not occur.
im1 = plot_fixed_size_imshow(
    ax0,
    binned_ps,
    cmap=cmap,
    norm=norm,
    fixed_width=fixed_im_width,
    fixed_height=fixed_im_height
)
ax0.set_xticks(np.arange(len(TARGET_LAYER_TYPES_LABELS)))
ax0.set_xticklabels(TARGET_LAYER_TYPES_LABELS, rotation=90, fontsize=smaller_font)
ax0.set_yticks(np.arange(len(ytick_labels)))
ax0.set_yticklabels(ytick_labels, fontsize=smaller_font)
divider0 = make_axes_locatable(ax0)
cax0 = divider0.append_axes("right", size=0.1, pad=0.08)
# Discrete colourbar built the same way as the unsorted p-value plot:
# one segment per bin, ticks on the bin indices.
cbar0 = plt.colorbar(im1, cax=cax0, boundaries=bounds, values=np.arange(len(p_bin_labels)))
cbar0.set_ticks(np.arange(len(p_bin_labels)))
cbar0.set_ticklabels(p_bin_labels)
cbar0.ax.set_ylabel('p-value', rotation=270, labelpad=15, fontsize=small_font)
cbar0.ax.tick_params(labelsize=smaller_font)
plt.show()

In [None]:
"""
Organized analysis:
1. Prepare gene data matrix for defined genes & combinations.
2. Prepare DFA feature matrix for all frequency ranges.
3. Prepare cortical gradient matrices for all layer types.
4. Concatenate gene and DFA features.
5. Compute Spearman correlations with gradients.
6. Plot and display.
"""

# ==== Imports and Config ====
from scipy.stats import spearmanr
from statsmodels.stats.multitest import multipletests
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.colors as mcolors

ALPHA = 0.0001

# ---- CONFIGURATION ----
#GENES
TARGET_GENES = ['PVALB', 'SST', 'VIP','LAMP5','SNCG']
#TARGET_GENES = sorted((list(expression['9861'].keys())))

TARGET_GENES_LABELS = [_label_from_gene(g) for g in TARGET_GENES]
#DFA EXPONENTS
TARGET_FREQ_RANGES = [[0.1, 4.0], [4.0, 8.0], [8.0, 15.0], [15.0, 30.0], [30.0, 60.0]]
TARGET_FREQ_LABELS = ['DFA(Delta)', 'DFA(Theta)', 'DFA(Alpha)', 'DFA(Beta)', 'DFA(Gamma)']
#APERIOD EXPONENTS
TARGET_APERIOD_EXPONENTS = ['AE(brain)']
#GRADIENTS
if run_hcp_data:
    #HCP has total only
    TARGET_LAYER_TYPES = ['total']#     , 'supra', 'infra', 'ratio_supra', 'ratio_infra']
    TARGET_LAYER_TYPES_LABELS = ['Total']#, 'Supra', 'Infra', 'Supra/Total', 'Infra/Total']
else:
    TARGET_LAYER_TYPES = ['total','supra', 'infra', 'ratio_supra', 'ratio_infra']
    TARGET_LAYER_TYPES_LABELS = ['Total','Supra', 'Infra', 'Supra/Total', 'Infra/Total']

PC_INDEX = 0   # Principal component for gradients
hemi = 'lh'    # Hemisphere selection; assumes 200 parcels per hemisphere



gene_data_matrix = []
gene_data_labels = []
for gene in TARGET_GENES:
    vals, label = get_parcel_gene_data(gene, hemi)
    gene_data_matrix.append(vals)
    gene_data_labels.append(label)
gene_data_matrix = np.stack(gene_data_matrix, axis=1)  # shape: (n_parcels, n_genes)

# ==== 2. PREPARE DFA FEATURE MATRIX ====


dfa_matrix_lh = get_dfa_matrix(TARGET_FREQ_RANGES, "lh")
dfa_matrix_rh = get_dfa_matrix(TARGET_FREQ_RANGES, "rh")
rss_matrix_lh = get_rrs_matrix(TARGET_FREQ_RANGES, "lh")
rss_matrix_rh = get_rrs_matrix(TARGET_FREQ_RANGES, "rh")

ae_brainic_matrix_lh = get_aperiodic_exponent_matrix("lh",data_type='brainic')
ae_brainic_matrix_rh = get_aperiodic_exponent_matrix("rh",data_type='brainic')
ae_allic_matrix_lh = get_aperiodic_exponent_matrix("lh",data_type='allic')
ae_allic_matrix_rh = get_aperiodic_exponent_matrix("rh",data_type='allic')

ao_brainic_matrix_lh = get_aperiodic_exponent_matrix("lh",data_type='offset_brainic')
ao_brainic_matrix_rh = get_aperiodic_exponent_matrix("rh",data_type='offset_brainic')
ao_allic_matrix_lh = get_aperiodic_exponent_matrix("lh",data_type='offset_allic')
ao_allic_matrix_rh = get_aperiodic_exponent_matrix("rh",data_type='offset_allic')

#gradient_matrix_lh = get_gradient_matrix(TARGET_LAYER_TYPES, PC_INDEX, "lh")
#gradient_matrix_rh = get_gradient_matrix(TARGET_LAYER_TYPES, PC_INDEX, "rh")

# ==== 4. CONCATENATE GENE AND DFA FEATURES ====
gene_dfa_matrix_lh = np.concatenate([dfa_matrix_lh, rss_matrix_lh, ae_brainic_matrix_lh], axis=1)
gene_dfa_matrix_rh = np.concatenate([dfa_matrix_rh, rss_matrix_rh, ae_brainic_matrix_rh], axis=1)


# ==== 5. COMPUTE SPEARMAN'S CORRELATION ====
# Compute Spearman's correlation and p-values
spearman_corrs_lh, spearman_ps_lh = compute_spearman_matrix(gene_dfa_matrix_lh, gene_data_matrix)
spearman_corrs_rh, spearman_ps_rh = compute_spearman_matrix(gene_dfa_matrix_rh, gene_data_matrix)

# Apply multiple comparisons correction (FDR) across all p-values (lh)
pvals_lh_flat = spearman_ps_lh.flatten()
rej_lh, pvals_fdr_lh, _, _ = multipletests(pvals_lh_flat, alpha=ALPHA, method='fdr_bh')
spearman_ps_fdr_lh = pvals_fdr_lh.reshape(spearman_ps_lh.shape)
# NOTE: In statsmodels' multipletests, "rej_lh" is True (1) for *significant* (not rejected) null hypothesis, i.e., significant correlations.
# So, spearman_signif_lh == 1 means significant (null rejected), 0 means not significant (null not rejected).
spearman_signif_lh = rej_lh.reshape(spearman_ps_lh.shape)

# Same for rh
pvals_rh_flat = spearman_ps_rh.flatten()
rej_rh, pvals_fdr_rh, _, _ = multipletests(pvals_rh_flat, alpha=ALPHA, method='fdr_bh')
spearman_ps_fdr_rh = pvals_fdr_rh.reshape(spearman_ps_rh.shape)
spearman_signif_rh = rej_rh.reshape(spearman_ps_rh.shape)

print("LH Spearman correlation matrix shape:", spearman_corrs_lh.shape)
print("RH Spearman correlation matrix shape:", spearman_corrs_rh.shape)

# ==== 6. PLOTTING ====
# Rows are MEG features (DFA bands, RRS bands, aperiodic exponent); columns are genes.

# ---- Plotting helpers ----
small_font = 12
smaller_font = 10
YLABEL_NAME = TARGET_FREQ_LABELS + TARGET_RS_LABELS + TARGET_APERIOD_EXPONENTS
# ---- Plot: Spearman correlation matrix ----
fig0 = plt.figure(figsize=(3,3)) # Large enough so Axes can be fixed, but the image will remain fixed size
ax0 = fig0.add_subplot(111)
# Width scales with the number of columns actually drawn. Those are the genes, not the
# gradient layer types (the layer-type count collapses to 1 when run_hcp_data is set).
fixed_im_width = (0.2*len(TARGET_GENES_LABELS)+0.2) # inches
fixed_im_height = (0.2*len(YLABEL_NAME)+0.2)

im0 = plot_fixed_size_imshow(
    ax0,
    spearman_corrs_lh,
    vmin=-0.5, vmax=0.5,
    cmap='coolwarm',
    fixed_width=fixed_im_width,
    fixed_height=fixed_im_height
)
#ax0.set_xlabel('Cortical Gradients', fontsize=small_font)
ax0.set_xticks(np.arange(len(TARGET_GENES_LABELS)))
ax0.set_xticklabels(TARGET_GENES_LABELS, rotation=90, fontsize=smaller_font)
ax0.set_yticks(np.arange(len(YLABEL_NAME)))
ax0.set_yticklabels(YLABEL_NAME, fontsize=smaller_font)
divider0 = make_axes_locatable(ax0)
cax0 = divider0.append_axes("right", size=0.1, pad=0.08)
cbar0 = plt.colorbar(im0, cax=cax0, label='Spearman Correlation')
cbar0.ax.set_ylabel('Spearman Correlation', rotation=270, labelpad=15, fontsize=small_font)
cbar0.ax.tick_params(labelsize=smaller_font)
plt.show()

# ---- Plot: p-values (FDR) as bins ----
# Bins and corresponding colorbar labels (monochrome for significance)
p_bins = [0, 0.0001, 0.001, 0.01, 1.0001]  # set upper limit as slightly above 1
p_bin_labels = ['<0.0001', '<0.001', '<0.01', 'n.s.']
# Shades of gray: black for most significant, light gray for n.s.
p_colors = ['#222222', '#555555', '#bbbbbb', '#eeeeee']
# digitize returns 1-based bin numbers for values inside the edges; shift to 0-based.
binned_ps = np.digitize(spearman_ps_fdr_lh, p_bins) - 1
cmap = mcolors.ListedColormap(p_colors)
bounds = np.arange(len(p_bin_labels)+1)-0.5
norm = mcolors.BoundaryNorm(np.arange(len(p_bin_labels)+1), cmap.N)

fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111)
im = plot_fixed_size_imshow(
    ax,
    binned_ps,
    cmap=cmap,
    norm=norm,
    fixed_width=fixed_im_width,
    fixed_height=fixed_im_height
)
#ax.set_xlabel('Cortical Gradients', fontsize=small_font)
ax.set_xticks(np.arange(len(TARGET_GENES_LABELS)))
ax.set_xticklabels(TARGET_GENES_LABELS, rotation=90, fontsize=smaller_font)
ax.set_yticks(np.arange(len(YLABEL_NAME)))
ax.set_yticklabels(YLABEL_NAME, fontsize=smaller_font)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size=0.1, pad=0.08)
tick_locs = (np.array(p_bins[:-1]) + np.array(p_bins[1:])) / 2
# Discrete colourbar: one segment per bin, ticks placed on the bin indices.
cbar = plt.colorbar(im, cax=cax, boundaries=bounds, values=np.arange(len(p_bin_labels)))
cbar.set_ticks(np.arange(len(p_bin_labels)))
cbar.set_ticklabels(p_bin_labels)
cbar.ax.set_yticks(np.arange(len(p_bin_labels)))
cbar.ax.set_yticklabels(p_bin_labels)
cbar.ax.set_ylabel('p-value', rotation=270, labelpad=15, fontsize=small_font)
cbar.ax.tick_params(labelsize=smaller_font)
plt.show()


In [None]:
"""
Organized analysis:
1. Prepare gene data matrix for defined genes & combinations.
2. Prepare DFA feature matrix for all frequency ranges.
3. Prepare cortical gradient matrices for all layer types.
4. Concatenate gene and DFA features.
5. Compute Spearman correlations with gradients.
6. Plot and display.
"""

# ==== Imports and Config ====
from scipy.stats import spearmanr
from statsmodels.stats.multitest import multipletests
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.colors as mcolors

ALPHA = 0.0001

# ---- CONFIGURATION ----
#GENES
TARGET_GENES = ['PVALB', 'SST', 'VIP','LAMP5','SNCG']
#TARGET_GENES = sorted((list(expression['9861'].keys())))

TARGET_GENES_LABELS = [_label_from_gene(g) for g in TARGET_GENES]
#DFA EXPONENTS
# Frequency bands as [low, high] pairs (presumably Hz - TODO confirm) and the
# display label used for each band's DFA feature.
TARGET_FREQ_RANGES = [[0.1, 4.0], [4.0, 8.0], [8.0, 15.0], [15.0, 30.0], [30.0, 60.0]]
TARGET_FREQ_LABELS = ['DFA(Delta)', 'DFA(Theta)', 'DFA(Alpha)', 'DFA(Beta)', 'DFA(Gamma)']
# Display labels for the aperiodic measures, listed in the same order in which
# the matrices are concatenated into exp_matrix_lh / exp_matrix_rh below
# (AE = data_type without the 'offset_' prefix, AO = data_type with it).
TARGET_APERIOD_EXPONENTS = ['AE(brain)','AE(all)','AO(brain)','AO(all)']
# Layer types (and their display labels) used for the gradients
if run_hcp_data:
    #HCP has total only
    TARGET_LAYER_TYPES = ['total']#     , 'supra', 'infra', 'ratio_supra', 'ratio_infra']
    TARGET_LAYER_TYPES_LABELS = ['Total']#, 'Supra', 'Infra', 'Supra/Total', 'Infra/Total']
else:
    TARGET_LAYER_TYPES = ['total','supra', 'infra', 'ratio_supra', 'ratio_infra']
    TARGET_LAYER_TYPES_LABELS = ['Total','Supra', 'Infra', 'Supra/Total', 'Infra/Total']

PC_INDEX = 0   # Principal component for gradients
hemi = 'lh'    # Hemisphere selection; assumes 200 parcels per hemisphere



# ==== 1. PREPARE GENE EXPRESSION MATRIX ====
# One column per gene in TARGET_GENES for the hemisphere selected above.
# NOTE(review): gene_data_matrix / gene_data_labels are not used in the rest
# of this section - confirm whether another cell still needs them.
gene_data_matrix = []
gene_data_labels = []
for gene in TARGET_GENES:
    vals, label = get_parcel_gene_data(gene, hemi)
    gene_data_matrix.append(vals)
    gene_data_labels.append(label)
gene_data_matrix = np.stack(gene_data_matrix, axis=1)  # shape: (n_parcels, n_genes)

# ==== 2. PREPARE DFA / RRS / APERIODIC FEATURE MATRICES ====
# Each feature family is loaded separately for the left and right hemisphere.


dfa_matrix_lh = get_dfa_matrix(TARGET_FREQ_RANGES, "lh")
dfa_matrix_rh = get_dfa_matrix(TARGET_FREQ_RANGES, "rh")
rrs_matrix_lh = get_rrs_matrix(TARGET_FREQ_RANGES, "lh")
rrs_matrix_rh = get_rrs_matrix(TARGET_FREQ_RANGES, "rh")

# Aperiodic exponent for the two data types 'brainic' and 'allic'
# (presumably brain-only vs. all independent components - TODO confirm).
ae_brainic_matrix_lh = get_aperiodic_exponent_matrix("lh",data_type='brainic')
ae_brainic_matrix_rh = get_aperiodic_exponent_matrix("rh",data_type='brainic')
ae_allic_matrix_lh = get_aperiodic_exponent_matrix("lh",data_type='allic')
ae_allic_matrix_rh = get_aperiodic_exponent_matrix("rh",data_type='allic')

# Aperiodic offset, requested from the same loader through the 'offset_*' data types.
ao_brainic_matrix_lh = get_aperiodic_exponent_matrix("lh",data_type='offset_brainic')
ao_brainic_matrix_rh = get_aperiodic_exponent_matrix("rh",data_type='offset_brainic')
ao_allic_matrix_lh = get_aperiodic_exponent_matrix("lh",data_type='offset_allic')
ao_allic_matrix_rh = get_aperiodic_exponent_matrix("rh",data_type='offset_allic')

# Gradient features are currently disabled in this analysis.
#gradient_matrix_lh = get_gradient_matrix(TARGET_LAYER_TYPES, PC_INDEX, "lh")
#gradient_matrix_rh = get_gradient_matrix(TARGET_LAYER_TYPES, PC_INDEX, "rh")

# ==== 4. CONCATENATE APERIODIC EXPONENT / OFFSET FEATURES ====
# Column order: AE(brain), AE(all), AO(brain), AO(all) - matches TARGET_APERIOD_EXPONENTS
# (assumes each loader returns a single column - TODO confirm).

exp_matrix_lh = np.concatenate([ae_brainic_matrix_lh, ae_allic_matrix_lh, ao_brainic_matrix_lh, ao_allic_matrix_lh], axis=1)
exp_matrix_rh = np.concatenate([ae_brainic_matrix_rh, ae_allic_matrix_rh, ao_brainic_matrix_rh, ao_allic_matrix_rh], axis=1)

# ==== 5. COMPUTE SPEARMAN'S CORRELATION ====
# Correlate the [DFA | RRS] feature columns with the aperiodic columns; the helper
# returns a matrix of correlations and a matching matrix of p-values.
spearman_corrs_lh, spearman_ps_lh = compute_spearman_matrix(np.concatenate([dfa_matrix_lh, rrs_matrix_lh], axis=1), exp_matrix_lh)
spearman_corrs_rh, spearman_ps_rh = compute_spearman_matrix(np.concatenate([dfa_matrix_rh, rrs_matrix_rh], axis=1), exp_matrix_rh)

# Apply multiple comparisons correction (Benjamini-Hochberg FDR) across all p-values (lh).
# The correction is run per hemisphere, not pooled across hemispheres.
pvals_lh_flat = spearman_ps_lh.flatten()
rej_lh, pvals_fdr_lh, _, _ = multipletests(pvals_lh_flat, alpha=ALPHA, method='fdr_bh')
spearman_ps_fdr_lh = pvals_fdr_lh.reshape(spearman_ps_lh.shape)
# NOTE: the first value returned by statsmodels' multipletests ("rej_lh") is True where the
# null hypothesis IS rejected at the given alpha, i.e. where the correlation is significant.
# So spearman_signif_lh is True for significant entries and False for non-significant ones.
spearman_signif_lh = rej_lh.reshape(spearman_ps_lh.shape)

# Same for rh
pvals_rh_flat = spearman_ps_rh.flatten()
rej_rh, pvals_fdr_rh, _, _ = multipletests(pvals_rh_flat, alpha=ALPHA, method='fdr_bh')
spearman_ps_fdr_rh = pvals_fdr_rh.reshape(spearman_ps_rh.shape)
spearman_signif_rh = rej_rh.reshape(spearman_ps_rh.shape)

print("LH Spearman correlation matrix shape:", spearman_corrs_lh.shape)
print("RH Spearman correlation matrix shape:", spearman_corrs_rh.shape)

# ==== 6. PLOTTING ====
# Two left-hemisphere figures: the Spearman correlation matrix, and the
# FDR-corrected p-values binned into significance levels.

# ---- Plotting helpers ----
small_font = 12
smaller_font = 10

# Row labels: DFA bands followed by RS bands (the column order of the first
# matrix given to compute_spearman_matrix). Column labels: aperiodic measures.
YLABEL_NAME = TARGET_FREQ_LABELS + TARGET_RS_LABELS

# Fixed image size (inches): 0.2 per label plus a 0.2 margin. The width has to
# follow the labels actually drawn on the x axis (the aperiodic measures); it
# previously used the layer-type labels, which distorted the cells whenever
# that list had a different length (e.g. a single entry for HCP data).
fixed_im_width = 0.2 * len(TARGET_APERIOD_EXPONENTS) + 0.2
fixed_im_height = 0.2 * len(YLABEL_NAME) + 0.2

# ---- Plot: Spearman correlation matrix ----
# The figure only needs to be large enough to hold the fixed-size image.
fig0 = plt.figure(figsize=(3, 3))
ax0 = fig0.add_subplot(111)

im0 = plot_fixed_size_imshow(
    ax0,
    spearman_corrs_lh,
    vmin=-0.5, vmax=0.5,
    cmap='coolwarm',
    fixed_width=fixed_im_width,
    fixed_height=fixed_im_height
)
ax0.set_xticks(np.arange(len(TARGET_APERIOD_EXPONENTS)))
ax0.set_xticklabels(TARGET_APERIOD_EXPONENTS, rotation=90, fontsize=smaller_font)
ax0.set_yticks(np.arange(len(YLABEL_NAME)))
ax0.set_yticklabels(YLABEL_NAME, fontsize=smaller_font)
divider0 = make_axes_locatable(ax0)
cax0 = divider0.append_axes("right", size=0.1, pad=0.08)
cbar0 = plt.colorbar(im0, cax=cax0)
cbar0.ax.set_ylabel('Spearman Correlation', rotation=270, labelpad=15, fontsize=small_font)
cbar0.ax.tick_params(labelsize=smaller_font)
plt.show()

# ---- Plot: p-values (FDR) as bins ----
# Bin edges and the colorbar label of each bin; the last edge sits slightly
# above 1 so that p == 1 still falls into the final ('n.s.') bin.
p_bins = [0, 0.0001, 0.001, 0.01, 1.0001]
p_bin_labels = ['<0.0001', '<0.001', '<0.01', 'n.s.']
# Shades of gray: darkest for the most significant bin, lightest for n.s.
p_colors = ['#222222', '#555555', '#bbbbbb', '#eeeeee']
# np.digitize returns 1-based bin numbers for values inside the edges, so
# subtracting 1 yields indices 0..3 into p_bin_labels / p_colors.
binned_ps = np.digitize(spearman_ps_fdr_lh, p_bins) - 1
cmap = mcolors.ListedColormap(p_colors)
# Colorbar segment boundaries centred on the integer bin indices.
bounds = np.arange(len(p_bin_labels)+1)-0.5
norm = mcolors.BoundaryNorm(np.arange(len(p_bin_labels)+1), cmap.N)

fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111)
im = plot_fixed_size_imshow(
    ax,
    binned_ps,
    cmap=cmap,
    norm=norm,
    fixed_width=fixed_im_width,
    fixed_height=fixed_im_height
)
ax.set_xticks(np.arange(len(TARGET_APERIOD_EXPONENTS)))
ax.set_xticklabels(TARGET_APERIOD_EXPONENTS, rotation=90, fontsize=smaller_font)
ax.set_yticks(np.arange(len(YLABEL_NAME)))
ax.set_yticklabels(YLABEL_NAME, fontsize=smaller_font)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size=0.1, pad=0.08)
cbar = plt.colorbar(im, cax=cax, boundaries=bounds, values=np.arange(len(p_bin_labels)))
# One tick per bin, labelled with the bin's significance level.
cbar.set_ticks(np.arange(len(p_bin_labels)))
cbar.set_ticklabels(p_bin_labels)
cbar.ax.set_ylabel('p-value', rotation=270, labelpad=15, fontsize=small_font)
cbar.ax.tick_params(labelsize=smaller_font)
plt.show()


In [None]:
def hist_exponent_matrix(freq_ranges, freq_labels, xlabel_name, hemi):
    """Plot overlaid histograms, one per frequency band, for one hemisphere.

    For every ``[low, high]`` pair in ``freq_ranges`` the matching pickled RRS
    output file is loaded, the rows belonging to the requested hemisphere are
    selected and flattened, and the values are drawn as a histogram. All bands
    share the same 100 bin edges so the histograms are directly comparable.

    Parameters
    ----------
    freq_ranges : sequence of [low, high]
        Frequency bands; the two numbers are interpolated into the file name.
    freq_labels : sequence of str
        Legend label per band; must have at least one entry per band.
    xlabel_name : str
        Text of the x-axis label.
    hemi : str
        ``'lh'`` selects the first ``num_parcels / 2`` rows of each loaded
        array; any other value selects the remaining rows.

    Notes
    -----
    Reads the module-level ``num_parcels``. Draws on a new pyplot figure and
    does not call ``plt.show()``.
    """
    import pickle

    import matplotlib.pyplot as plt
    import numpy as np

    half = int(num_parcels/2)
    slices = slice(0, half) if hemi == 'lh' else slice(half, None)

    # To plot the DFA exponents instead, load
    # '/Users/dennis.jungchildmind.org/Desktop/MEG/112225_dfa_outputs/dfa_restin_f{lo}_to_{hi}_parcelonly.pkl'
    files = [f'/Users/dennis.jungchildmind.org/Desktop/MEG/112625_rrs_outputs/rrs_restin_f{fr[0]}_to_{fr[1]}_parcelonly.pkl' for fr in freq_ranges]

    vals = []
    for path in files:
        # Close every file deterministically instead of leaking the handle.
        # NOTE(review): pickle must only be used on trusted, locally produced files.
        with open(path, 'rb') as fh:
            vals.append(pickle.load(fh)[slices, :].ravel())

    # Common bin edges computed over the pooled values of all bands.
    bins = np.histogram_bin_edges(np.concatenate(vals), bins=100)
    n_freq = len(freq_labels)
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in matplotlib 3.9.
    jet = plt.get_cmap('jet', n_freq)
    color_palette = [jet(i) for i in range(n_freq)]

    plt.figure(figsize=(8, 4))
    for i, v in enumerate(vals):
        plt.hist(v, bins=bins, alpha=0.7, edgecolor='white', color=color_palette[i], label=freq_labels[i])
    plt.xlabel(xlabel_name, fontsize=14)
    plt.ylabel('Count', fontsize=14)
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.xticks(fontsize=13)
    plt.yticks(fontsize=13)
    plt.legend(fontsize=12, frameon=False, loc='upper right')
    plt.tight_layout()


# DFA variant (requires switching the file list inside hist_exponent_matrix to the DFA pickles):
#hist_exponent_matrix(TARGET_FREQ_RANGES, TARGET_FREQ_LABELS, xlabel_name="DFA exponent", hemi="lh")
# RS histograms for the left hemisphere. The legend uses the RS labels; it
# previously reused the 'DFA(...)' labels although the plotted values are RS.
# NOTE(review): assumes TARGET_RS_LABELS has one entry per frequency band - confirm.
hist_exponent_matrix(TARGET_FREQ_RANGES, TARGET_RS_LABELS, xlabel_name="RS", hemi="lh")

In [None]:
from scipy.stats import spearmanr
import numpy as np

# Path analysis / mediation using Spearman correlation.
# Column convention used in this cell: X_clean[:, 0] is the E/I measure,
# X_clean[:, 1] is the gene measure and y_clean is the gradient.
#
# Non-ASCII glyphs in the printed labels are written as escapes
# (\u2192 arrow, \u03c1 rho, \u03b2 beta): the literal characters had been
# corrupted into mojibake, and escapes survive any re-encoding of this file.
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler

print("Testing mediation: Gene \u2192 E/I \u2192 Gradient (Spearman correlation)")
print("="*50)

# Bivariate Spearman correlations between the three variables
rho_gene_gradient, _ = spearmanr(X_clean[:, 1], y_clean)  # c path (total effect)
rho_gene_EI, _ = spearmanr(X_clean[:, 1], X_clean[:, 0])  # a path
rho_EI_gradient, _ = spearmanr(X_clean[:, 0], y_clean)    # b path (bivariate)

print(f"Gene \u2192 Gradient (total):    \u03c1={rho_gene_gradient:.3f}")
print(f"Gene \u2192 E/I (a path):        \u03c1={rho_gene_EI:.3f}")
print(f"E/I \u2192 Gradient (bivariate): \u03c1={rho_EI_gradient:.3f}")

# Partial association of E/I with the gradient, controlling for genes:
# remove the linear gene effect from both variables, then correlate the residuals.
model_EI_gene = LinearRegression().fit(X_clean[:, [1]], X_clean[:, 0])
EI_resid = X_clean[:, 0] - model_EI_gene.predict(X_clean[:, [1]])

model_grad_gene = LinearRegression().fit(X_clean[:, [1]], y_clean)
grad_resid = y_clean - model_grad_gene.predict(X_clean[:, [1]])

# NOTE(review): the residuals come from an ordinary (linear) regression on the raw
# values; only the final correlation of the residuals is rank-based.
rho_partial, p_partial = spearmanr(EI_resid, grad_resid)
print(f"\nE/I \u2192 Gradient (controlling for genes, Spearman): \u03c1={rho_partial:.3f}, p={p_partial:.4f}")

# Compare standardized betas (linear regression on z-scored predictors and outcome)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_clean)
# The same scaler object is refit on y here, so afterwards it holds y's statistics.
y_scaled = scaler.fit_transform(y_clean.reshape(-1, 1)).ravel()

model_full = LinearRegression().fit(X_scaled, y_scaled)
print("\nStandardized \u03b2 coefficients:")
print(f"  \u03b2_EI = {model_full.coef_[0]:.3f}")
print(f"  \u03b2_gene = {model_full.coef_[1]:.3f}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import spearmanr

# Three-panel figure summarising the gene / E/I / gradient relationships,
# followed by a printed interpretation.
# Non-ASCII glyphs are written as escapes (\u2192 arrow, \u03c1 rho) because
# the literal characters had been corrupted into mojibake in this file.

# Compute Spearman correlations (EI_resid / grad_resid come from the mediation cell)
rho_gene_EI, p_gene_EI = spearmanr(X_clean[:, 1], X_clean[:, 0])
rho_EI_gradient, p_EI_gradient = spearmanr(X_clean[:, 0], y_clean)
rho_gene_gradient, p_gene_gradient = spearmanr(X_clean[:, 1], y_clean)
rho_partial, p_partial_spearman = spearmanr(EI_resid, grad_resid)

# Increase font sizes globally. This has to happen BEFORE the axes are created:
# tick label sizes are read from rcParams when an Axes is constructed, so updating
# afterwards left the ticks of this figure at the old size on the first run.
# The change is global and also affects figures created by later cells.
plt.rcParams.update({'xtick.labelsize': 16, 'ytick.labelsize': 16, 'axes.labelsize': 20, 'axes.titlesize': 22})

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# 1. Gene vs E/I (the key relationship), with a least-squares line
axes[0].scatter(X_clean[:, 1], X_clean[:, 0], alpha=0.5, s=50, c='purple')
z = np.polyfit(X_clean[:, 1], X_clean[:, 0], 1)
p = np.poly1d(z)
x_line = np.linspace(X_clean[:, 1].min(), X_clean[:, 1].max(), 100)
axes[0].plot(x_line, p(x_line), 'r-', lw=2)
axes[0].set_xlabel('Gene Expression (SST-PVALB)', fontsize=20)
axes[0].set_ylabel('Aperiodic Exponent', fontsize=20)
# The original title glyphs between "Gene" and "E/I" were lost to the encoding
# corruption and cannot be recovered, so a neutral wording is used.
axes[0].set_title(f'Gene vs E/I\nSpearman \u03c1={rho_gene_EI:.3f}', fontsize=22)
axes[0].grid(False)

# 2. The masking effect: E/I vs gradient, points colored by gene expression level
scatter = axes[1].scatter(X_clean[:, 0], y_clean, c=X_clean[:, 1], cmap='viridis', alpha=0.6, s=50)
axes[1].set_xlabel('Aperiodic Exponent', fontsize=20)
axes[1].set_ylabel('Gradient Score', fontsize=20)
axes[1].set_title(f'E/I \u2192 Gradient (colored by gene)\nSpearman \u03c1={rho_EI_gradient:.3f}', fontsize=22)
cbar = plt.colorbar(scatter, ax=axes[1])
cbar.set_label('Gene Expression', fontsize=18, rotation=-90, labelpad=10)
cbar.ax.tick_params(labelsize=16)
axes[1].grid(False)

# 3. After controlling for genes (residuals), with a least-squares line
axes[2].scatter(EI_resid, grad_resid, alpha=0.5, s=50, c='orange')
z = np.polyfit(EI_resid, grad_resid, 1)
p = np.poly1d(z)
x_line = np.linspace(EI_resid.min(), EI_resid.max(), 100)
axes[2].plot(x_line, p(x_line), 'r--', lw=2)
axes[2].set_xlabel('Aperiodic Exponent (residual)', fontsize=20)
axes[2].set_ylabel('Gradient (residual)', fontsize=20)
axes[2].set_title(f'E/I \u2192 Gradient (controlling gene)\nSpearman \u03c1={rho_partial:.3f}, p={p_partial_spearman:.3f}', fontsize=22)
axes[2].grid(False)
# Zero reference lines for the residual axes
axes[2].axhline(0, color='gray', linestyle='--', alpha=0.5)
axes[2].axvline(0, color='gray', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()

# NOTE(review): the interpretation sentences below are fixed text; only the
# numbers are taken from the data, so check that the wording matches the values.
print("\n" + "="*60)
print("INTERPRETATION")
print("="*60)
print(f"1. Gene and E/I are negatively correlated (Spearman \u03c1={rho_gene_EI:.3f})")
print("   \u2192 Higher inhibitory genes = Lower 1/f exponent")
print(f"\n2. Gene predicts gradient strongly (Spearman \u03c1={rho_gene_gradient:.3f})")
print("   \u2192 SST-PVALB expression organizes structure")
print("\n3. E/I does NOT predict gradient independently")
print(f"   \u2192 Raw: Spearman \u03c1={rho_EI_gradient:.3f}")
print(f"   \u2192 Controlling for genes: Spearman \u03c1={rho_partial:.3f} (p={p_partial_spearman:.3f})")
print("\n4. All variance is explained by GENES, not E/I balance")
print("="*60)

In [None]:
import numpy as np
from sklearn.linear_model import LinearRegression
from scipy.stats import pearsonr

# Regress the supragranular gradient on an E/I measure and a gene measure,
# partition the explained variance, and visualise the result.
# Non-ASCII glyphs in labels are written as escapes (\u00b2 superscript two,
# \u03b2 beta) because the literal characters had been corrupted into mojibake.
import matplotlib.pyplot as plt

# Data prep: gene measure = SST minus PVALB expression, NaN-aware mean over the last axis
the_gene_data = np.nanmean(get_gene_expression(expression, "SST")-get_gene_expression(expression, "PVALB"), axis=-1)
exp_dat_lh_parc = parcellate_data(exp_data_lh, atlas_data_lh)
# Left hemisphere = first half of the parcels
the_gene_data_lh = the_gene_data[:int(num_parcels/2)]

# Column 0: E/I measure, column 1: gene measure
X = np.column_stack([exp_dat_lh_parc, the_gene_data_lh])
y = lh_data_all_unmapped['supra']

# Keep only rows without NaN in either predictor or in the outcome
index = ~np.isnan(X).any(axis=1) & ~np.isnan(y)
X_clean = X[index]
y_clean = y[index]

print(f"N regions: {len(y_clean)}")

# === STEP 1: Basic correlations (diagnostic) ===
r_EI, p_EI = pearsonr(X_clean[:, 0], y_clean)
r_gene, p_gene = pearsonr(X_clean[:, 1], y_clean)

print("\nUnivariate correlations (NOT corrected for spatial autocorrelation):")
print(f"  E/I:  r={r_EI:.3f}, p={p_EI:.4f}")
print(f"  Gene: r={r_gene:.3f}, p={p_gene:.4f}")

# === STEP 2: Multiple regression ===
model = LinearRegression()
model.fit(X_clean, y_clean)
r_squared = model.score(X_clean, y_clean)
beta_EI, beta_gene = model.coef_

print("\nMultiple Regression:")
print(f"  R\u00b2 = {r_squared:.3f}")
print(f"  \u03b2_EI = {beta_EI:.3f}")
print(f"  \u03b2_gene = {beta_gene:.3f}")

# === STEP 3: Variance partitioning ===
# Unique share of a predictor = full-model R2 minus the R2 of the other predictor alone;
# shared share = what the two single-predictor models explain in common.
model_EI = LinearRegression().fit(X_clean[:, [0]], y_clean)
model_gene = LinearRegression().fit(X_clean[:, [1]], y_clean)

R2_EI = model_EI.score(X_clean[:, [0]], y_clean)
R2_gene = model_gene.score(X_clean[:, [1]], y_clean)

unique_EI = r_squared - R2_gene
unique_gene = r_squared - R2_EI
shared = R2_EI + R2_gene - r_squared

print("\nVariance Partitioning:")
print(f"  Total R\u00b2:     {r_squared:.3f}")
print(f"  Unique E/I:   {unique_EI:.3f}")
print(f"  Unique Gene:  {unique_gene:.3f}")
print(f"  Shared:       {shared:.3f}")

# === STEP 4: Visualization ===
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# E/I vs gradient
axes[0].scatter(X_clean[:, 0], y_clean, alpha=0.5, s=30)
axes[0].set_xlabel('E/I Balance (1/f exponent)')
axes[0].set_ylabel('Gradient Value')
axes[0].set_title(f'E/I vs Gradient\nr={r_EI:.3f}')

# Gene vs gradient
axes[1].scatter(X_clean[:, 1], y_clean, alpha=0.5, s=30, color='orange')
axes[1].set_xlabel('SST-PVALB Expression')
axes[1].set_ylabel('Gradient Value')
axes[1].set_title(f'Gene vs Gradient\nr={r_gene:.3f}')

# Predicted vs actual, with the identity line as reference
y_pred = model.predict(X_clean)
axes[2].scatter(y_clean, y_pred, alpha=0.5, s=30, color='green')
axes[2].plot([y_clean.min(), y_clean.max()], [y_clean.min(), y_clean.max()], 'r--', lw=2)
axes[2].set_xlabel('Actual Gradient')
axes[2].set_ylabel('Predicted Gradient')
axes[2].set_title(f'Full Model\nR\u00b2={r_squared:.3f}')

plt.tight_layout()
plt.show()

# Variance partitioning pie chart
fig, ax = plt.subplots(figsize=(6, 6))
sizes = [unique_EI, unique_gene, shared, 1-r_squared]
labels = [f'Unique E/I\n({unique_EI:.2f})',
          f'Unique Gene\n({unique_gene:.2f})',
          f'Shared\n({shared:.2f})',
          f'Unexplained\n({1-r_squared:.2f})']
# 'lightgray' is a named color; the former '#lightgray' is not a valid color
# specification and made ax.pie raise.
colors = ['#ff9999', '#66b3ff', '#99ff99', 'lightgray']
# ax.pie rejects negative wedge sizes, and the shared component is negative when
# the predictors act as suppressors. Such components are drawn as empty wedges;
# the labels keep the true values.
wedge_sizes = np.clip(sizes, 0, None)
if np.any(np.asarray(sizes) < 0):
    print("Note: negative variance components are drawn as zero-size wedges in the pie chart.")
ax.pie(wedge_sizes, labels=labels, colors=colors, autopct='%1.1f%%', startangle=90)
ax.set_title('Variance Partitioning', fontsize=16)
plt.show()

In [None]:
import os
import numpy as np
import nibabel as nib
from sklearn.linear_model import LinearRegression, Lasso, Ridge, RidgeCV
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_squared_error
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import yaspy
# Load and process BigBrain layer data
# Directory containing the per-layer '<hemi>.<k-1>-<k>.32k.shape.gii' files
# that load_layers_bb reads.
big_brain_layer_dir = '/Users/dennis.jungchildmind.org/Downloads/BigBrain/thickness/resample/'
# Earlier single-sample experiment, kept for reference:
#supra_sample = lh_data['supra'][0]
#infra_sample = lh_data['infra'][0]
#total_sample = supra_sample+infra_sample#lh_data['total'][0]#this total is little different from the supra_sample + infra_sample

def load_layers_bb(hemi):
    """Read the six BigBrain layer files of one hemisphere into a single array.

    For k = 1..6 the file ``{hemi}.{k-1}-{k}.32k.shape.gii`` inside
    ``big_brain_layer_dir`` is loaded and the data of its first data array is
    taken. The six arrays are stacked along axis 1, so column ``k-1`` of the
    result holds layer ``k``.
    """
    per_layer = []
    for upper in range(1, 7):
        gii_name = f'{hemi}.{upper - 1}-{upper}.32k.shape.gii'
        gii_path = os.path.join(big_brain_layer_dir, gii_name)
        per_layer.append(nib.load(gii_path).darrays[0].data)
    return np.stack(per_layer, axis=1)

# Load the six-layer arrays for both hemispheres (layers are stacked along axis 1)
bb_layers_lh = load_layers_bb('lh')
bb_layers_rh = load_layers_bb('rh')

# Only the left-hemisphere shape is reported here.
print(f"BigBrain layer data shape: {bb_layers_lh.shape}")

# Prepare template data (BigBrain data) for regression analysis
def prepare_template_data(layer_data, target_layer):
    """Build the regression design matrix and target from per-layer data.

    Parameters
    ----------
    layer_data : ndarray, 2-D
        One row per vertex, one column per layer (column 0 is the first layer).
    target_layer : int or list of int
        Column index (or indices) of ``layer_data`` used as regression target.
        A scalar gives a 1-D target, a list gives a 2-D target.

    Returns
    -------
    x : ndarray, shape (n_rows, 2)
        Column 0 is the supragranular sum (columns 0-2 of ``layer_data``, i.e.
        L1-L3 as defined for the ex vivo data); column 1 is the row index,
        used as positional information.
    y : ndarray
        ``layer_data[:, target_layer]``.
    """
    supra = np.sum(layer_data[:, 0:3], axis=1)
    vertex_id = np.arange(layer_data.shape[0])
    # Alternative design that also includes the infragranular sum (columns 3 onward):
    #x = np.column_stack([supra, infra, vertex_id])
    x = np.column_stack([supra, vertex_id])
    y = layer_data[:, target_layer]
    return x, y

# BigBrain data as template data for the model fitting:
# fit one linear model per hemisphere and per layer, predicting that layer from
# the design matrix built by prepare_template_data, and record the in-sample fit.
r2_all = {'lh': [], 'rh': []}
mse_all = {'lh': [], 'rh': []}
mlr_models = {'lh': [], 'rh': []}
bb_layers = {'lh': bb_layers_lh, 'rh': bb_layers_rh}

# The loop variable is deliberately NOT called `hemi`: that name is the
# notebook-wide hemisphere selection and was silently left at 'rh' by this loop.
for hemi_name in ['lh', 'rh']:
    for layer in range(0,6):
        # The target is passed as a one-element list, so y_template is 2-D (one column).
        x_template, y_template = prepare_template_data(bb_layers[hemi_name],[layer])

        # Fit multiple linear regression model
        mlr_model = LinearRegression(fit_intercept=True,tol=1e-6,copy_X=True)
        #mlr_model = RidgeCV(alphas=np.logspace(-3, 3, 100), cv=10)
        #mlr_model = RandomForestRegressor(n_estimators=100, random_state=42)
        mlr_model.fit(x_template, y_template)
        y_pred = mlr_model.predict(x_template)

        # Calculate and store metrics (evaluated on the training data itself)
        r2_all[hemi_name].append(r2_score(y_template, y_pred))
        mse_all[hemi_name].append(mean_squared_error(y_template, y_pred))
        mlr_models[hemi_name].append(mlr_model)
        # Include the layer so the twelve progress lines can be told apart.
        print(f'{hemi_name.upper()} L{layer + 1} Model performance - R2: {r2_all[hemi_name][-1]:.3f}, MSE: {mse_all[hemi_name][-1]:.3f}')
#plt.plot(y_template,y_pred,'.')

# Grouped bar chart of the R2 per layer, left and right hemisphere side by side
plt.figure(figsize=(5,4))
x = np.arange(6)
width = 0.35
plt.bar(x - width/2, r2_all['lh'], width, label='LH', color='#2ecc71')  # Emerald green
plt.bar(x + width/2, r2_all['rh'], width, label='RH', color='#3498db')  # Bright blue
plt.xticks(x, ['L1', 'L2', 'L3', 'L4', 'L5', 'L6'])

plt.ylim(0, 1)
# Written as an escape because the literal superscript had been corrupted into mojibake.
plt.ylabel('R\u00b2 Score')
plt.xlabel('Cortical Layer')
plt.title('MLR Performance')
plt.legend(loc='upper right', framealpha=1.0)
plt.tight_layout()
plt.show()

In [None]:
# Based on the previous section: fit, per hemisphere, the linear model that
# predicts the BigBrain column at index 3 (layer 4) from the template design
# matrix. These models are later used to compute infra_cleaned.
def _fit_layer4_template(layer_data):
    """Fit the index-3 template model on ``layer_data`` and print its in-sample fit."""
    # A RandomForestRegressor(n_estimators=100, random_state=42) was tried as an alternative.
    model = LinearRegression(fit_intercept=True, tol=1e-6, copy_X=True)
    x_tmpl, y_tmpl = prepare_template_data(layer_data, 3)
    model.fit(x_tmpl, y_tmpl)
    y_hat = model.predict(x_tmpl)
    print('r2_score', r2_score(y_tmpl, y_hat))
    print('mean_squared_error', mean_squared_error(y_tmpl, y_hat))
    return model, x_tmpl, y_tmpl, y_hat


mlr_model_lh, x_template_lh, y_template_lh, y_pred_lh = _fit_layer4_template(bb_layers_lh)
mlr_model_rh, x_template_rh, y_template_rh, y_pred_rh = _fit_layer4_template(bb_layers_rh)


# Process samples
def process_samples(mlr_model,supra_sample, infra_sample):
    """Predict layer 4 for every sample and subtract it from the infragranular data.

    Parameters
    ----------
    mlr_model : object with ``predict(X)``
        Fitted model; ``predict`` receives an (n_vertices, 2) matrix whose
        columns are the sample's supragranular values and the vertex index,
        and must return one value per vertex.
    supra_sample, infra_sample : ndarray, shape (n_vertices, n_samples)
        Supragranular and infragranular values, one column per sample.

    Returns
    -------
    predicted_x4 : ndarray, shape (n_vertices, n_samples)
        Model prediction for each sample.
    infra_cleaned : ndarray, shape (n_vertices, n_samples)
        ``infra_sample - predicted_x4``.
    """
    n_vertices, n_samples = supra_sample.shape
    predicted_x4 = np.zeros((n_vertices, n_samples))
    infra_cleaned = np.zeros((n_vertices, n_samples))
    vertex_id = np.arange(n_vertices)

    for i in range(n_samples):
        # Same two-column design as prepare_template_data: [supra, vertex index]
        X = np.column_stack([supra_sample[:, i], vertex_id])
        sample_pred = mlr_model.predict(X)

        predicted_x4[:, i] = sample_pred
        infra_cleaned[:, i] = infra_sample[:, i] - predicted_x4[:, i]

        # Report the range of THIS sample only; taking min/max of the whole array
        # mixed in earlier samples and the still-zero columns of later ones.
        print(f"Sample {i+1}: Layer 4 range [{predicted_x4[:, i].min():.3f}, {predicted_x4[:, i].max():.3f}]")

    return predicted_x4, infra_cleaned


# Supra-/infragranular arrays stored under key 0 of each hemisphere's data
# (process_samples expects one row per vertex and one column per sample).
supra_sample_lh = lh_data['supra'][0]
infra_sample_lh = lh_data['infra'][0]
supra_sample_rh = rh_data['supra'][0]
infra_sample_rh = rh_data['infra'][0]
# Predict layer 4 with the hemisphere-specific model and subtract it from the infragranular data.
predicted_x4_lh, infra_cleaned_lh = process_samples(mlr_model_lh,supra_sample_lh, infra_sample_lh)
predicted_x4_rh, infra_cleaned_rh = process_samples(mlr_model_rh,supra_sample_rh, infra_sample_rh)



In [None]:

# Total values stored under key 0 of each hemisphere's data
total_sample_lh = lh_data['total'][0]
total_sample_rh = rh_data['total'][0]

# Cleaned infragranular values relative to the total, element-wise.
# NOTE(review): entries where the total is 0 would yield inf/NaN - confirm how
# such vertices (if any) should be handled.
ratio_cleaned_lh = infra_cleaned_lh / total_sample_lh
ratio_cleaned_rh = infra_cleaned_rh / total_sample_rh


# Calculate and plot cleaned data
#cleaned_data_lh = np.mean(infra_cleaned_lh, axis=1)
#cleaned_data_rh = np.mean(infra_cleaned_rh, axis=1)

# Mean over axis 1 (the sample axis), giving one value per vertex
data2plot_lh = np.mean(ratio_cleaned_lh, axis=1) 
data2plot_rh = np.mean(ratio_cleaned_rh, axis=1)

#data2plot_lh = np.mean(infra_cleaned_lh, axis=1)
#data2plot_rh = np.mean(infra_cleaned_rh, axis=1)

cmap = 'RdBu_r'

#data2plot = np.mean(predicted_x4,axis=1)

# Visualization
# Only the right hemisphere is drawn below (the surface file is the 'R' white
# surface); data2plot_lh is computed above but not plotted in this cell.
surf_path = '/Users/dennis.jungchildmind.org/Downloads/HCP_S1200_Atlas_Z4_pkXDZ/S1200.R.white_MSMAll.32k_fs_LR.surf.gii'
plotter = yaspy.Plotter(surf_path, hemi='rh')
# Color range fixed to 0.4-0.6
overlay = plotter.overlay(data2plot_rh, cmap=cmap, alpha=1,vmin=0.4,vmax=0.6)
# NOTE(review): border is added with alpha=0, presumably so it is invisible - confirm intent.
plotter.border(data2plot_rh, alpha=0)

# Create and display figure: lateral view on top, medial view below
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6,4))
ax1.imshow(plotter.screenshot("lateral"))
ax1.axis('off')

# Horizontal colorbar attached below the lateral view
divider = make_axes_locatable(ax1)
cax = divider.append_axes("bottom", size="3%", pad=0.05)
cbar = plt.colorbar(overlay, cax=cax, orientation='horizontal')
cbar.ax.tick_params(labelsize="small")

ax2.imshow(plotter.screenshot("medial"))
ax2.axis('off')

plt.tight_layout()
plt.show()


In [None]:
# Parcellate the three derived vertex-wise measures (infra_cleaned, ratio_cleaned,
# predicted_x4) per hemisphere and compute a partial-correlation matrix for each.
# Pattern follows the existing one for the raw layer types, e.g.
#lh_data_parc['total'][i] = parcellate_data(lh_data['total'][i].T, atlas, 'lh').T
def _parcellate_T(samples_by_vertex, hemi_tag):
    """Parcellate already-transposed data with the global atlas and transpose the result back."""
    return parcellate_data(samples_by_vertex, atlas, hemi_tag).T


def _partial_corr_entry(parcellated):
    """Return ``{0: matrix}`` using the partial-correlation settings shared by all measures."""
    return {0: calculate_correlation_matrix(parcellated, transpose=False, partial_corr=True, shrink_it=False, first_index=1)}


# --- infra_cleaned ---
lh_data_parc['infra_cleaned'] = {}
rh_data_parc['infra_cleaned'] = {}

lh_data_parc['infra_cleaned'][0] = _parcellate_T(infra_cleaned_lh.T, 'lh')
rh_data_parc['infra_cleaned'][0] = _parcellate_T(infra_cleaned_rh.T, 'rh')

infra_cleaned_corr_lh = _partial_corr_entry(lh_data_parc['infra_cleaned'][0])
infra_cleaned_corr_rh = _partial_corr_entry(rh_data_parc['infra_cleaned'][0])
print(infra_cleaned_corr_lh[0].shape)
print(infra_cleaned_corr_rh[0].shape)

# --- ratio_cleaned (infra_cleaned divided by total, element-wise) ---
lh_data_parc['ratio_cleaned'] = {}
rh_data_parc['ratio_cleaned'] = {}
lh_data_parc['ratio_cleaned'][0] = _parcellate_T(infra_cleaned_lh.T/total_sample_lh.T, 'lh')
rh_data_parc['ratio_cleaned'][0] = _parcellate_T(infra_cleaned_rh.T/total_sample_rh.T, 'rh')

ratio_cleaned_corr_lh = _partial_corr_entry(lh_data_parc['ratio_cleaned'][0])
ratio_cleaned_corr_rh = _partial_corr_entry(rh_data_parc['ratio_cleaned'][0])
print(ratio_cleaned_corr_lh[0].shape)
print(ratio_cleaned_corr_rh[0].shape)

# --- predicted_x4 ---
lh_data_parc['predicted_x4'] = {}
rh_data_parc['predicted_x4'] = {}
lh_data_parc['predicted_x4'][0] = _parcellate_T(predicted_x4_lh.T, 'lh')
rh_data_parc['predicted_x4'][0] = _parcellate_T(predicted_x4_rh.T, 'rh')

predicted_x4_corr_lh = _partial_corr_entry(lh_data_parc['predicted_x4'][0])
predicted_x4_corr_rh = _partial_corr_entry(rh_data_parc['predicted_x4'][0])
print(predicted_x4_corr_lh[0].shape)
print(predicted_x4_corr_rh[0].shape)







