In [2]:
from pathlib import Path

from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import mne
import mne_connectivity

import acareeg

In [3]:
# Set to True to regenerate connectivity files even when they already exist on disk.
recompute = False

# Subject table: keep the site, the subject ID and the per-age recording file
# names, drop fully-empty rows and exclude the Boston site. The table is built
# as a single chain ending in .assign() so that the derived column is added to
# a fresh frame rather than to a filtered slice (which raises pandas'
# SettingWithCopyWarning).
mastersheet = (
    acareeg.eegip.get_mastersheet()
    .reset_index()[["Site", "ID", '6m_EEGLAB_QCR_File',
                    '12m_EEGLAB_QCR_File', '18m_EEGLAB_QCR_File']]
    # NOTE(review): how="all" only drops rows where Site and ID are missing too;
    # if the intent is "no recording at any age", a subset= on the three file
    # columns may be what is wanted -- confirm.
    .dropna(how="all")
    .loc[lambda df: df.Site != "Boston"]
    # Numeric subject number = last three characters of a string ID;
    # non-string IDs are passed through unchanged.
    .assign(subject_no=lambda df: [int(subject[-3:]) if isinstance(subject, str) else subject
                                   for subject in df["ID"]])
)

# Frequency bands: fmin[i]/fmax[i] are the edges of bands[i]; they are passed
# as fmin/fmax to the spectral connectivity estimator.
# NOTE(review): "broadband" starts at 4 while "theta" starts at 3 -- confirm
# that the differing lower edge is intentional.
fmin = (3, 6, 9, 30, 4)
fmax = (5, 9, 30, 100, 100)
bands = ("theta", "alpha", "beta", "gamma", "broadband")
# Connectivity measures to compute (passed as `method` to the estimator).
con_names = ("ciplv", )
mode = 'multitaper'

# When use_same_template_age is True, sources for every recording age are
# computed with the template of age common_template_age (months); otherwise the
# recording's own age is used.
common_template_age = 12
use_same_template_age = True
# Passed as `tmax` to acareeg.eegip.get_resting_state_epochs
# (presumably the epoch length in seconds -- TODO confirm).
tmax = 1.0
# Recordings with fewer epochs are skipped; also the size of each random epoch subset.
min_nb_epochs = 20
# Seed the global NumPy RNG used for the random epoch subsets.
np.random.seed(324234)
# Output directory for the connectivity CSV files, and root of the input dataset.
con_path = Path("/Volumes/usc_data/ElementSE/eegip/con_paper/")
bids_root = Path("/Volumes/usc_data/ElementSE/eegip")
subjects_dir = "."

In [None]:
# For every subject and recording age: estimate label time series in source
# space, then compute spectral connectivity on repeated random subsets of
# min_nb_epochs epochs and save one long-format CSV per connectivity measure.
#
# NOTE(review): the subsets are drawn from the global NumPy RNG, so the epochs
# drawn for a given recording depend on how many recordings were processed
# before it in the same kernel session (e.g. a resumed run skips some).

# File-name suffix identifying the template used; constant for the whole run.
template_suffix = f"_{common_template_age}m-template" if use_same_template_age else ""

for index, row in tqdm(mastersheet.iterrows(), total=len(mastersheet)):
    dataset = row.Site.lower()
    for age in [6, 12, 18]:

        # Output file of each connectivity measure for this subject/age, built
        # once and reused for both the "already done" check and the final save.
        paths_out = {
            con_name: con_path / f"{row.subject_no}_{dataset}_{age}_{con_name}{template_suffix}.csv"
            for con_name in con_names
        }

        done = all(path_out.exists() for path_out in paths_out.values())
        if done and not recompute:
            print(f"skip {row.subject_no}_{dataset}_{age}")
            continue

        # Validate the mastersheet entry before the expensive epoch loading:
        # recordings without a usable file name are skipped either way.
        file_name = row[f"{age}m_EEGLAB_QCR_File"]
        if not isinstance(file_name, str) or "qcr.se" not in file_name:
            # "qcr.se" because .set is currently truncated for Washington
            continue

        epochs = acareeg.eegip.get_resting_state_epochs(row.subject_no, dataset, age,
                                                        bids_root=bids_root,
                                                        tmax=tmax)

        if epochs is None or len(epochs) < min_nb_epochs:
            continue

        epochs.set_eeg_reference(projection=True)

        sources_age = common_template_age if use_same_template_age else age

        # Currently, we don't include the volume sources because there is no way to simulate
        # with mixed source models.
        label_ts, anat_label = acareeg.infantmodels.compute_sources(epochs, sources_age,
                                                                    subjects_dir=subjects_dir,
                                                                    return_labels=True, return_xr=False, loose=0,
                                                                    fixed=True, inv_method="eLORETA", pick_ori=None,
                                                                    lambda2=1e-4, minimal_snr=None, verbose=False,
                                                                    include_vol_src=False)

        label_names = np.array([label.name for label in anat_label])
        sfreq = epochs.info['sfreq']

        # Convert once per recording instead of once per random subset; the
        # first axis is indexed by epoch below.
        label_ts_arr = np.array(label_ts)
        # Two random subsets for every min_nb_epochs epochs available.
        nb_iters = int(len(epochs)/min_nb_epochs*2)

        for con_name, path_out in paths_out.items():
            if path_out.exists() and not recompute:
                continue

            con_df = []
            for _ in tqdm(range(nb_iters), leave=False):
                # Random subset of min_nb_epochs distinct epochs.
                inds = np.random.choice(np.arange(len(epochs)), min_nb_epochs, False).astype(int)
                con = mne_connectivity.spectral_connectivity_epochs(label_ts_arr[inds],
                                                                    method=con_name,
                                                                    mode=mode, sfreq=sfreq, fmin=fmin,
                                                                    fmax=fmax, faverage=True, verbose=False)

                dfs = []
                # Move the last axis of the dense array first so that each
                # iteration yields the region-by-region matrix of one band.
                for mat, band in zip(con.get_data("dense").transpose(2, 0, 1), bands):
                    # NaN-out the upper triangle and the diagonal so that only
                    # the strictly-lower-triangle pairs survive dropna().
                    mat = pd.DataFrame(mat) + np.triu(mat * np.nan)
                    mat.index = label_names
                    mat.columns = label_names
                    df = mat.reset_index().melt(id_vars="index").dropna()
                    df.columns = ["region1", "region2", "con"]
                    df["con_name"] = con_name
                    df["band"] = band
                    df["age"] = age
                    dfs.append(df)

                con_df.append(pd.concat(dfs))

            # Write to a temporary file and rename it, so an interrupted run
            # cannot leave a truncated CSV that the existence checks above
            # would later treat as a finished result.
            tmp_path = path_out.with_name(path_out.name + ".tmp")
            pd.concat(con_df).to_csv(tmp_path)
            tmp_path.replace(path_out)
