In [1]:
import nibabel as nib
import ants
from ants.plotting import plot
import numpy as np


In [2]:
import h5py
import nibabel as nib
import numpy as np
import os
import pickle
from joblib import Parallel, delayed
from matplotlib import pyplot as plt
from scipy import io as sio
from scipy.signal import resample
import yaspy
from fooof import FOOOF 
from utils import mtspectrumc

def _print_banner(title):
    """Print `title` between two 60-character '=' rules (log section header)."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


def _load_atlas_labels(atlas_path):
    """Load the Schaefer-2018 400-parcel 7-network fs_LR 4k surface atlas (both hemispheres).

    Parameters
    ----------
    atlas_path : str
        Directory containing the ``L.`` / ``R.`` ``.label.gii`` files.

    Returns
    -------
    atlas_both_hemi : np.ndarray
        Per-vertex labels, left-hemisphere vertices followed by right-hemisphere ones.
    total_label : np.ndarray
        Sorted unique labels of each hemisphere with the smallest one dropped
        (presumably the 0 / unlabelled background -- verify), LH labels first.
    """
    base_name = 'Schaefer2018_400Parcels_7Networks_order_4k_fslr.label.gii'
    hemi_data = [
        nib.load(os.path.join(atlas_path, f'{hemi}.{base_name}')).darrays[0].data
        for hemi in ('L', 'R')
    ]
    atlas_both_hemi = np.concatenate(hemi_data)
    total_label = np.concatenate([np.unique(labels)[1:] for labels in hemi_data])
    return atlas_both_hemi, total_label


def _load_mat_file(filepath):
    """Open a MATLAB ``.mat`` file with scipy.io, or with h5py for v7.3 (HDF5) files.

    ``scipy.io.loadmat`` raises ``NotImplementedError`` for v7.3 files, which is the
    cue to fall back to h5py. Structs are loaded as attribute-style ``mat_struct``
    objects (``struct_as_record=False``) with singleton dimensions squeezed.

    Returns
    -------
    (data, kind)
        ``kind`` is ``'scipy'`` (``data`` is a dict) or ``'h5py'`` (``data`` is an
        open ``h5py.File`` that the caller must close).
    """
    print(f"Loading {os.path.basename(filepath)}...")
    try:
        data = sio.loadmat(filepath, struct_as_record=False, squeeze_me=True)
        print("  ✓ Loaded with scipy.io")
        return data, 'scipy'
    except NotImplementedError:
        data = h5py.File(filepath, 'r')
        print("  ✓ Loaded with h5py")
        return data, 'h5py'


def _read_mat(filepath, extractor):
    """Load `filepath` and return ``extractor(data, kind)``, closing h5py files afterwards.

    `extractor` must copy everything it needs into NumPy arrays, because h5py
    datasets cannot be read once their file is closed.
    """
    data, kind = _load_mat_file(filepath)
    try:
        return extractor(data, kind)
    finally:
        if kind == 'h5py':
            data.close()


def _as_list_of_arrays(cell):
    """Convert a scipy-loaded MATLAB cell array into a list of arrays.

    With ``squeeze_me=True`` a 1x1 cell array is squeezed down to its only element,
    so a lone numeric ndarray is wrapped in a list rather than iterated row by row.
    """
    if isinstance(cell, np.ndarray) and cell.dtype != object:
        return [cell]
    return list(cell)


def _get_class_struct(comp_class):
    """Return the ``class`` sub-struct of a scipy-loaded ``comp_class`` ``mat_struct``.

    ``class`` is a Python keyword, so several spellings are tried before falling
    back to the first field whose name contains 'class'.

    Raises
    ------
    KeyError
        If no class-like field exists.
    """
    for attr in ('class_', '_class', 'class'):
        struct = getattr(comp_class, attr, None)
        if struct is not None:
            return struct
    for field_name in getattr(comp_class, '_fieldnames', []):
        if 'class' in field_name.lower():
            return getattr(comp_class, field_name)
    raise KeyError("comp_class has no 'class' field")


def _extract_comp_class(data, kind):
    """Read ``comp_class`` from an icaclass file.

    Returns
    -------
    (fsample, brainic_index, trials)
        Sampling rate as float, 0-based brain-IC indices (MATLAB's ``brain_ic`` minus 1),
        and the list of per-trial IC time-series arrays.
    """
    comp_class = data['comp_class']
    if kind == 'scipy':
        fsample = float(comp_class.fsample)
        class_struct = _get_class_struct(comp_class)
        brainic_index = np.array(class_struct.brain_ic).flatten().astype(int) - 1
        trials = _as_list_of_arrays(comp_class.trial)
    else:
        fsample = float(comp_class['fsample'][0, 0])
        brainic_ref = comp_class['class']['brain_ic'][0, 0]
        brainic_index = data[brainic_ref][:].flatten().astype(int) - 1
        # h5py sees MATLAB v7.3 arrays with their dimensions reversed; transpose back.
        trials = [data[ref][:].T for ref in comp_class['trial'][:].flatten()]
    return fsample, brainic_index, trials


def _extract_source_moments(data, kind):
    """Return the list of per-source moment matrices (``source.avg.mom``) from an icamne file."""
    if kind == 'scipy':
        return _as_list_of_arrays(data['source'].avg.mom)
    mom_refs = data['source']['avg']['mom'][:]
    return [data[ref][:].T for ref in mom_refs.flatten()]


def _resample_trial(trial_data, original_fs, target_fs):
    """Resample a 2-D (channels x samples) array along axis 1 from `original_fs` to `target_fs`."""
    n_samples = trial_data.shape[1]
    n_samples_new = int(np.round(n_samples * target_fs / original_fs))
    return resample(trial_data, n_samples_new, axis=1)


def _downsample_trials(trials, original_fs, target_fs, label):
    """Resample every trial with `_resample_trial`, logging each shape change."""
    resampled = []
    for i, trial in enumerate(trials):
        trial_ds = _resample_trial(trial, original_fs, target_fs)
        resampled.append(trial_ds)
        print(f"  {label} {i+1}: {trial.shape} -> {trial_ds.shape}")
    return resampled


def _select_ics(matrix, brainic_index, ic_axis, name):
    """Keep only the brain ICs of 2-D `matrix`, returned with the ICs on axis `ic_axis`.

    The ICs are looked for on `ic_axis` first. If that axis is too short for every
    index in `brainic_index`, the other axis is used and the result is transposed
    so the ICs still end up on `ic_axis`.

    Raises
    ------
    ValueError
        If `brainic_index` is empty or neither axis is long enough.
    """
    if brainic_index.size == 0:
        raise ValueError(f"No brain ICs to select from {name} (brain_ic is empty)")
    n_needed = int(brainic_index.max()) + 1
    other_axis = 1 - ic_axis
    if matrix.shape[ic_axis] >= n_needed:
        return np.take(matrix, brainic_index, axis=ic_axis)
    if matrix.shape[other_axis] >= n_needed:
        return np.take(matrix, brainic_index, axis=other_axis).T
    raise ValueError(f"Cannot select brain ICs from {name} with shape {matrix.shape}")


def _power_spectrum(voxel_data, params):
    """Multitaper spectrum of one source's signals via ``mtspectrumc``; returns ``(S, f)``."""
    S, f, _Serr = mtspectrumc(voxel_data, params)
    return S, f


def _parcellate(spectra, atlas_both_hemi, labels):
    """Average per-vertex spectra within each atlas parcel.

    Row i of `spectra` is assumed to belong to atlas vertex i (LH then RH).
    Labels with no vertex are skipped. The returned label array therefore maps
    each returned spectrum back to its parcel.
    """
    if spectra.shape[0] != atlas_both_hemi.shape[0]:
        print(
            f"WARNING: {spectra.shape[0]} source spectra but {atlas_both_hemi.shape[0]} "
            "atlas vertices; a one-to-one vertex/source correspondence is assumed."
        )
    parcel_data, parcel_labels = [], []
    for label in labels:
        vertex_idx = np.flatnonzero(atlas_both_hemi == label)
        # Test the number of matches, not their index sum: a parcel whose only
        # vertex is vertex 0 must not be dropped.
        if vertex_idx.size == 0:
            continue
        parcel_data.append(spectra[vertex_idx].mean(axis=0))
        parcel_labels.append(label)
    return parcel_data, np.asarray(parcel_labels)


def _fit_fooof(freqs, spectrum, freq_range):
    """Fit FOOOF to one spectrum; return ``(exponent, offset, error, r_squared)``."""
    fm = FOOOF()
    fm.fit(freqs, spectrum, freq_range)
    return (
        fm.get_params('aperiodic_params', 'exponent'),
        fm.get_params('aperiodic_params', 'offset'),
        fm.error_,
        fm.r_squared_,
    )


def calculate_source_1over4exp(
    subject,
    session,
    session_num,
    basedir_preproc,
    basedir_meg,
    atlas_path,
    save_path=None,
    n_jobs=24,
    fdownsample=200,
    trial_index=0,
    freq_range=(1, 50),
):
    """Estimate the parcel-wise aperiodic (1/f) exponent of MEG source power spectra.

    Pipeline for one subject/session:

    1. Load the Schaefer-2018 400-parcel (7-network) fs_LR 4k surface atlas.
    2. Load the ``icaclass``, ``icaclass_vs`` and ``icamne`` MATLAB files
       (scipy.io, or h5py for v7.3 files).
    3. Resample the IC time series to ``fdownsample`` Hz and keep only brain ICs.
    4. Project the brain-IC time series of trial ``trial_index`` into source space
       (``source moment @ IC time series``) for every source location.
    5. Compute one multitaper spectrum per source with ``mtspectrumc``.
    6. Average the spectra over the vertices of each atlas parcel.
    7. Fit FOOOF per parcel over ``freq_range``.

    Parameters
    ----------
    subject, session, session_num :
        Identify the inputs ``{subject}_MEG_{session_num}-{session}_<kind>.mat``.
    basedir_preproc : str
        Root containing ``<subject>/MEG/<session>/icamne``; also the default output root.
    basedir_meg : str
        Root containing ``<subject>/MEG/<session>/icaclass``.
    atlas_path : str
        Directory holding the ``L.``/``R.`` Schaefer ``.label.gii`` files.
    save_path : str, optional
        Output directory. Defaults to ``<basedir_preproc>/<subject>/MEG/source_timeseries``.
    n_jobs : int
        joblib workers for the spectrum and FOOOF steps.
    fdownsample : float
        Target sampling rate (Hz) of the IC time series; also ``Fs`` for ``mtspectrumc``.
    trial_index : int
        Index of the trial used for the spectra (was hard-coded to 0).
    freq_range : sequence of two floats
        Frequency range (Hz) passed to ``FOOOF.fit`` (was hard-coded to [1, 50]).

    Returns
    -------
    dict
        The same dict is pickled to ``source_timeseries_{session_num}_{session}.pkl``.
        ``exp_list``, ``offset_list``, ``error_list`` and ``r2_list`` have one entry
        per element of ``parcel_data``. ``parcel_labels`` holds the atlas label of
        each of those entries. ``atlas_labels`` holds every atlas label considered.
        The dict also contains ``f``, ``atlas_path`` and subject/session metadata.

    Raises
    ------
    ValueError
        If the icamne file has no source moments, no brain ICs are listed,
        ICs cannot be selected from an array, or no parcel matches any vertex.
    """
    atlas_both_hemi, total_label = _load_atlas_labels(atlas_path)

    # ---- File paths
    file_prefix = f'{subject}_MEG_{session_num}-{session}'
    icaclass_dir = os.path.join(basedir_meg, subject, 'MEG', session, 'icaclass')
    icaclass_file = os.path.join(icaclass_dir, f'{file_prefix}_icaclass.mat')
    icaclass_vs_file = os.path.join(icaclass_dir, f'{file_prefix}_icaclass_vs.mat')
    icamne_file = os.path.join(
        basedir_preproc, subject, 'MEG', session, 'icamne', f'{file_prefix}_icamne.mat'
    )
    if save_path is None:
        save_path = os.path.join(basedir_preproc, subject, 'MEG', 'source_timeseries')
    os.makedirs(save_path, exist_ok=True)
    # NOTE(review): despite its name, this pickle holds spectra and FOOOF fits,
    # not source time series.
    save_filepath = os.path.join(save_path, f'source_timeseries_{session_num}_{session}.pkl')
    print(f"Output will be saved to: {save_filepath}")

    # ---- Load inputs; each h5py file is closed as soon as its fields are extracted.
    _print_banner("LOADING DATA FILES")
    fsample, brainic_index, ica_trials = _read_mat(icaclass_file, _extract_comp_class)
    fsample_vs, brainic_index_vs, ica_trials_vs = _read_mat(icaclass_vs_file, _extract_comp_class)
    source_sig_all = _read_mat(icamne_file, _extract_source_moments)
    if not source_sig_all:
        raise ValueError(f"No source moments found in {icamne_file}")

    _print_banner("EXTRACTING FIELDS")
    print("\nSource data:")
    print(f"  Number of source locations: {len(source_sig_all)}")
    print(f"  First source shape: {source_sig_all[0].shape} (expected: 3 x n_ICs)")
    print("\nICA class data:")
    print(f"  Sample rate: {fsample} Hz")
    print(f"  Number of brain ICs: {len(brainic_index)}")
    print(f"  Brain IC indices: {brainic_index}")
    print(f"  Number of trials: {len(ica_trials)}")
    print("\nICA class VS data:")
    print(f"  Sample rate: {fsample_vs} Hz")
    print(f"  Number of brain ICs: {len(brainic_index_vs)}")

    # ---- Downsample the IC time series.
    _print_banner("DOWNSAMPLING ICA TRIAL DATA")
    print(f"Original sampling rate: {fsample} Hz")
    print(f"Target sampling rate: {fdownsample} Hz")
    print(f"Downsampling ratio: {fsample/fdownsample:.2f}x")
    ica_trials = _downsample_trials(ica_trials, fsample, fdownsample, "Trial")
    ica_trials_vs = _downsample_trials(ica_trials_vs, fsample_vs, fdownsample, "VS Trial")

    # ---- Keep brain ICs: sources as (3 x n_brain_ICs), trials as (n_brain_ICs x n_timepoints).
    _print_banner("SELECTING BRAIN COMPONENTS")
    source_sig = [
        _select_ics(src, brainic_index, 1, f"source {i}")
        for i, src in enumerate(source_sig_all)
    ]
    ica_trial_sig = [
        _select_ics(trial, brainic_index, 0, f"trial {i}")
        for i, trial in enumerate(ica_trials)
    ]
    # NOTE(review): the VS trials are prepared (and shape-checked) like the main
    # trials, but no step below uses them.
    _ica_trial_sig_vs = [
        _select_ics(trial, brainic_index_vs, 0, f"VS trial {i}")
        for i, trial in enumerate(ica_trials_vs)
    ]
    print(f"  First source shape: {source_sig[0].shape} (expected: 3 x n_brain_ICs)")
    for i, trial in enumerate(ica_trial_sig):
        print(f"  Trial {i+1} shape: {trial.shape} (expected: n_brain_ICs x n_timepoints)")

    # ---- Project the chosen trial into source space: one array per source location.
    _print_banner("CREATING SOURCE LEVEL TRIAL SIGNALS")
    chosen_trial = ica_trial_sig[trial_index]
    voxel_data_list = [src @ chosen_trial for src in source_sig]
    print(
        f"  Source: {source_sig[0].shape} @ Trial: {chosen_trial.shape} "
        f"= Result: {voxel_data_list[0].shape}"
    )
    print(f"  Trial index {trial_index} of {len(ica_trial_sig)}: {len(voxel_data_list)} sources")

    # ---- (1) Multitaper power spectrum per source.
    params = {
        'tapers': [4, 7],
        'pad': 0,
        'Fs': fdownsample,  # the IC time series were resampled to this rate above
        'fpass': [1, 100],
        'trialave': True,
    }
    spectra = Parallel(n_jobs=n_jobs, verbose=5)(
        delayed(_power_spectrum)(voxel_data, params) for voxel_data in voxel_data_list
    )
    S_list, f_list = zip(*spectra)
    freqs = f_list[0]

    # ---- (2) Average spectra within each atlas parcel.
    parcel_data, parcel_labels = _parcellate(np.array(S_list), atlas_both_hemi, total_label)
    print(f"Parcels with vertices: {len(parcel_data)} of {len(total_label)} atlas labels")
    if not parcel_data:
        raise ValueError("No atlas parcel contains any source vertex")

    # ---- (3) FOOOF aperiodic fit per parcel.
    fits = Parallel(n_jobs=n_jobs, verbose=5)(
        delayed(_fit_fooof)(freqs, spectrum, list(freq_range)) for spectrum in parcel_data
    )
    exp_list, offset_list, error_list, r2_list = (list(column) for column in zip(*fits))

    # Save the outputs needed by downstream steps.
    output_data = {
        'exp_list': exp_list,
        'offset_list': offset_list,
        'error_list': error_list,
        'r2_list': r2_list,
        'parcel_data': parcel_data,
        'f': freqs,
        'atlas_labels': total_label,
        'parcel_labels': parcel_labels,
        'atlas_path': atlas_path,
        'subject': subject,
        'session': session,
        'session_num': session_num,
    }
    with open(save_filepath, 'wb') as fh:
        pickle.dump(output_data, fh)

    print(f"Results saved to {save_filepath}")
    return output_data

# Example usage for SLURM (adapt params as needed for your run script):
# output = calculate_source_1over4exp(
#     subject='100307',
#     session='Restin',
#     session_num=3,
#     basedir_preproc='/data4/BrainED_project/shared/HCP_MEG_PREPROC',
#     basedir_meg='/data4/BrainED_project/shared/HCP_MEG',
#     atlas_path='/data4/BrainED_project/djung/parcellation/4k_fslr',
#     save_path=None,
#     n_jobs=24,
#     fdownsample=200,
# )


ModuleNotFoundError: No module named 'fooof'

In [None]:
# Read the warped template image (per its file name) and print its ANTs header summary.
mapped_image_path = '/Users/dennis.jungchildmind.org/Desktop/subcortical_test/I38_new_confidence/warped_template.nii.gz'
mapped_image = ants.image_read(filename=mapped_image_path)
print(mapped_image)


In [None]:
# Voxel values of the warped image as a NumPy array, used for inspection below.
mapped_image_np = mapped_image.numpy()

In [None]:
# Grid dimensions of the warped image.
print(np.shape(mapped_image_np))

In [None]:
# Sorted distinct voxel values (np.unique flattens its input anyway).
print(np.unique(mapped_image_np.ravel()))

In [None]:
# Load the resampled original image, print its header, and plot slices along axis 2.
original_image_path = '/Users/dennis.jungchildmind.org/Desktop/subcortical_test/I38_new_confidence/orig_img_resampled.nii'
original_image = ants.image_read(filename=original_image_path)
print(original_image)
plot(original_image, axis=2)

In [None]:
# `template_path` was previously only defined in a later cell, so this raised
# NameError on a fresh kernel. Define it here explicitly.
# NOTE(review): this uses the T1 template (first path in document order); a later
# cell switches to the T2 template -- confirm which one this inspection is meant for.
template_path = '/Users/dennis.jungchildmind.org/Desktop/atlas/mni_icbm152_nlin_asym_09c_nifti/mni_icbm152_t1_tal_nlin_asym_09c.nii'
tmp = nib.load(template_path)

In [None]:
# Shape of the template volume and how many distinct intensity values it holds.
# (The stray `pyplot` re-import was removed: it is already imported at the top
# and is not used in this cell.)
tmp_data = tmp.get_fdata()
print(tmp_data.shape)
print(np.unique(tmp_data).shape)


In [None]:
# Skull-strip the T1 MNI152 2009c asymmetric template and split it into two halves
# by array index, plotting and saving each half.
template_path = '/Users/dennis.jungchildmind.org/Desktop/atlas/mni_icbm152_nlin_asym_09c_nifti/mni_icbm152_t1_tal_nlin_asym_09c.nii'
# NOTE(review): the original comment here was truncated ("only has t"); presumably
# only this T1-named brain mask is available -- confirm.
headmask_path = '/Users/dennis.jungchildmind.org/Desktop/atlas/mni_icbm152_nlin_asym_09c_nifti/mni_icbm152_t1_tal_nlin_asym_09c_mask.nii'
template = ants.image_read(template_path)
headmask = ants.image_read(headmask_path)
# Keep template intensities where the mask equals 1 and zero everything else (removes the skull).
masked_template = ants.mask_image(template, headmask, level=1, binarize=False)
data = masked_template.numpy()

# The hemisphere boundary is the middle index of array axis 0.
# NOTE(review): assumes axis 0 is the left-right axis with the left hemisphere at
# low indices, and that the midline falls at shape[0] // 2. With an odd size, the
# middle slice goes to the right half. Check masked_template.direction/origin.
center_x = data.shape[0] // 2


def keep_axis0_slab(image, index_slice):
    """Return `image` multiplied by a mask that is 1 only on `index_slice` of array axis 0.

    The mask inherits `image`'s origin, spacing and direction, so the product is
    taken voxel-for-voxel on the same grid.
    """
    slab_mask = np.zeros_like(image.numpy())
    slab_mask[index_slice, :, :] = 1
    slab_mask_img = ants.from_numpy(
        slab_mask, origin=image.origin, spacing=image.spacing, direction=image.direction
    )
    return image * slab_mask_img


left_hemisphere = keep_axis0_slab(masked_template, slice(None, center_x))
plot(left_hemisphere, axis=1)
ants.image_write(left_hemisphere, 'mni_icbm152_t1_tal_nlin_asym_09c_left_hemi.nii')

right_hemisphere = keep_axis0_slab(masked_template, slice(center_x, None))
plot(right_hemisphere, axis=1)
ants.image_write(right_hemisphere, 'mni_icbm152_t1_tal_nlin_asym_09c_right_hemi.nii')

In [None]:


# Define file paths
# Define file paths: the T2 MNI152 2009c template, the brain mask from the same
# release (T1-named file, reused here), and the Schaefer-400 + Tian S1 subcortex
# atlas (MNI152NLin2009cAsym 1 mm, per its file name).
template_path = '/Users/dennis.jungchildmind.org/Desktop/atlas/mni_icbm152_nlin_asym_09c_nifti/mni_icbm152_t2_tal_nlin_asym_09c.nii'
headmask_path = '/Users/dennis.jungchildmind.org/Desktop/atlas/mni_icbm152_nlin_asym_09c_nifti/mni_icbm152_t1_tal_nlin_asym_09c_mask.nii'
atlas_path = '/Users/dennis.jungchildmind.org/Desktop/subcortical_test/Schaefer2018_400Parcels_7Networks_order_Tian_Subcortex_S1_3T_MNI152NLin2009cAsym_1mm.nii.gz'

# Read images
template = ants.image_read(template_path)
headmask = ants.image_read(headmask_path)
atlas = ants.image_read(atlas_path)

# Atlas-derived masks are later applied voxel-wise to the template, so the grids must match.
# Raise instead of calling exit(): in Jupyter, exit() shuts down the kernel rather
# than just stopping this cell.
if atlas.shape != template.shape:
    raise ValueError(
        "The shape of the atlas and template are not the same: "
        f"atlas {atlas.shape} vs template {template.shape}"
    )
print("The shape of the atlas and template are the same")


In [None]:
# Skull-strip the T2 template: voxels where the mask equals 1 keep their intensity,
# all others become 0. Then show the stripped template and the atlas along axis 2.
masked_template = ants.mask_image(template, headmask, level=1, binarize=False)
for image_to_show in (masked_template, atlas):
    plot(image_to_show, axis=2)

In [None]:
# Build left/right hemisphere masks from the atlas label values.
# Hard-coded split: labels 1-8 and 217-416 -> right, 9-216 -> left.
# NOTE(review): presumably Tian S1's RH then LH subcortical blocks followed by
# Schaefer-400 LH (200) and RH (200) parcels -- verify against the atlas label table.
atlas_label = atlas.numpy()
rh_label = [*range(1, 9), *range(217, 417)]
lh_label = [*range(9, 217)]

left_hemisphere_mask = np.isin(atlas_label, lh_label)
right_hemisphere_mask = np.isin(atlas_label, rh_label)

# Wrap the boolean masks as float32 ANTs images on the atlas grid.
atlas_geometry = dict(origin=atlas.origin, spacing=atlas.spacing, direction=atlas.direction)
left_hemi_mask_ants = ants.from_numpy(left_hemisphere_mask.astype('float32'), **atlas_geometry)
right_hemi_mask_ants = ants.from_numpy(right_hemisphere_mask.astype('float32'), **atlas_geometry)


In [None]:
# Restrict the skull-stripped T2 template to each hemisphere's atlas voxels and plot each.
masked_template_lh = ants.mask_image(masked_template, left_hemi_mask_ants, level=1, binarize=False)
plot(masked_template_lh, axis=2)
masked_template_rh = ants.mask_image(masked_template, right_hemi_mask_ants, level=1, binarize=False)
plot(masked_template_rh, axis=2)

# Write both hemisphere-only templates to the current working directory.
for hemi_image, out_name in (
    (masked_template_lh, 'mni_icbm152_t2_tal_nlin_asym_09c_masked_lh_only.nii.gz'),
    (masked_template_rh, 'mni_icbm152_t2_tal_nlin_asym_09c_masked_rh_only.nii.gz'),
):
    ants.image_write(hemi_image, out_name)

print(f"Left hemisphere mask: {left_hemi_mask_ants.shape}")
print(f"Right hemisphere mask: {right_hemi_mask_ants.shape}")
print(f"Left hemisphere voxels: {left_hemisphere_mask.sum()}")
print(f"Right hemisphere voxels: {right_hemisphere_mask.sum()}")

In [17]:
# NOTE(review): `dists_dict` is not defined in any visible cell, so this cell fails
# on a fresh kernel -- add the cell that builds it before this one.
for dists_key in ('dist_lh', 'label'):
    print(np.array(dists_dict[dists_key]))


[        nan 43.17040495 45.49635934 27.22282983 20.5721383  42.28919479
 34.69781011 42.19293438 38.51152827 26.79548571 32.63160857 10.1579598
  4.61876665 35.47979882 19.18625484 25.7091647  29.84380739 43.03754274
 35.33918616 56.24260294 42.04502685 58.89166708 72.24595763 27.7690414
 64.72923622 72.03813689 44.65383935 59.05703344 82.23463442 27.44741097
 76.18864038 80.68663779 39.36571949 66.2874954  70.74577119 79.41141228
 31.18349471 56.60987574 47.37821486 72.60733547 79.42190782 69.0794426
 55.19709977 70.15999031 73.00640895 76.61737237 69.86386435 73.14514234
 42.52455668 46.98636966 27.13232676 31.0549871  30.19653251 53.59332156
 36.19690075 40.95275916 53.59547952 40.84026802 53.68958106 55.70575286
 59.85947819 50.75139931 57.59001885 56.9255993  52.10064949 40.96213839
 52.44423685 62.37445128 65.68359929 55.86807572 51.96142655 51.00856329
 54.48740379 63.47207078 66.93354088 61.78242829 65.70614992 61.65689369
 62.24928045 68.82978459 63.96824179 67.48086643 65.08