In [67]:
import pandas as pd
import os
import os.path as op
import h5py
import numpy as np
import matplotlib.pyplot as plt
from dtw import *
from sklearn.decomposition import PCA
from sklearn.preprocessing import RobustScaler
from scipy.stats.mstats import winsorize
from tqdm.notebook import trange, tqdm
from scipy.ndimage import gaussian_filter1d
from mne.stats import permutation_cluster_test
from matplotlib import colors
import scipy.io
from mne.io import read_epochs_eeglab

In [64]:
# Name of the preprocessing pipeline whose derivatives folder is read by load_bursts.
pipeline = 'NEARICA'
# Age groups handled in this notebook.
ages = ['9m', '12m', 'adult']

# Channel names belonging to each channel cluster; the keys are the cluster names
# ('C', 'P') that are passed to load_bursts below.
cluster_chans = dict(
    C=['E16', 'E20', 'E21', 'E22', 'E41', 'E49', 'E50', 'E51'],
    P=['E26', 'E27', 'E28', 'E31', 'E40', 'E42', 'E45', 'E46'],
)

def read_h5_vector(dataset):
    """Read a numeric HDF5 dataset into a flat 1-D numpy array.

    ``np.ravel`` is used rather than ``np.squeeze`` so that a dataset holding a
    single value still yields a 1-element array instead of a 0-d array (a 0-d
    array cannot be iterated, which previously crashed single-burst subjects).
    It also flattens both (N, 1) and (1, N) storage layouts to length N.
    """
    return np.ravel(dataset[:])


def read_h5_strings(f, dataset):
    """Decode an HDF5 dataset of object references to character-code arrays.

    Parameters
    ----------
    f : h5py.File (or any mapping supporting ``f[ref][:]``)
        File that the references point into.
    dataset : array-like of references
        One reference per burst.

    Returns
    -------
    list of str
        One string per reference, built by applying ``chr`` to each stored code.
    """
    return [''.join(chr(c) for c in np.ravel(f[ref][:])) for ref in np.ravel(dataset[:])]


def load_bursts(pipeline, age, clus_name, data_dir='/home/bonaiuto/dev_beta_umd/data'):
    """Load every burst for one age group and channel cluster into a dict of arrays.

    For each participant listed in ``<data_dir>/<age>/data/participants.tsv``
    this reads
    ``<data_dir>/<age>/derivatives/<pipeline>/<subject>/processed_data/processed_<clus_name>_bursts.mat``
    with h5py (presumably a MATLAB v7.3 file; TODO confirm) and concatenates the
    fields of its ``bursts`` group across participants. Participants without
    that file are skipped.

    Parameters
    ----------
    pipeline : str
        Name of the derivatives sub-folder (e.g. 'NEARICA').
    age : str
        Age-group folder name (e.g. '9m', '12m', 'adult'); also stored per burst.
    clus_name : str
        Cluster name used in the burst file name (e.g. 'C', 'P').
    data_dir : str, optional
        Root data directory; the default keeps the original hard-coded location.

    Returns
    -------
    dict of numpy.ndarray
        Keys, in order: age, subject, chan, cluster, condition, epoch,
        fwhm_freq, fwhm_time, peak_adjustment, peak_amp_base, peak_amp_iter,
        peak_freq, peak_time, polarity, trial, waveform, waveform_times.
        Per-burst fields are 1-D arrays of length n_bursts; ``waveform`` is
        (n_bursts, n_samples). ``waveform_times`` comes from the last file read
        and is assumed identical across participants (not checked).
    """
    string_fields = ('cluster', 'condition', 'epoch')
    numeric_fields = ('chan', 'fwhm_freq', 'fwhm_time', 'peak_adjustment',
                      'peak_amp_base', 'peak_amp_iter', 'peak_freq', 'peak_time',
                      'polarity', 'trial')
    # Key order is preserved from the original: downstream cells turn these
    # dicts into DataFrames, so it becomes the column order.
    field_order = ('age', 'subject', 'chan', 'cluster', 'condition', 'epoch',
                   'fwhm_freq', 'fwhm_time', 'peak_adjustment', 'peak_amp_base',
                   'peak_amp_iter', 'peak_freq', 'peak_time', 'polarity', 'trial',
                   'waveform', 'waveform_times')
    bursts = {field: [] for field in field_order}

    subjects = pd.read_csv(op.join(data_dir, age, 'data/participants.tsv'), sep='\t')
    for subject in subjects['participant_id']:
        print('loading {}'.format(subject))
        fname = op.join(data_dir, age, 'derivatives', pipeline, subject,
                        'processed_data/processed_{}_bursts.mat'.format(clus_name))
        if not os.path.exists(fname):
            # Deliberate best-effort: subjects without a burst file are skipped.
            continue
        with h5py.File(fname, 'r') as f:
            group = f['bursts']
            # Count bursts from the flattened trial vector so age/subject stay
            # aligned with the other fields whatever the stored orientation.
            n_bursts = read_h5_vector(group['trial']).size
            bursts['age'].extend([age] * n_bursts)
            bursts['subject'].extend([subject] * n_bursts)
            for field in numeric_fields:
                bursts[field].extend(read_h5_vector(group[field]))
            for field in string_fields:
                bursts[field].extend(read_h5_strings(f, group[field]))
            # Bursts run along the second axis of the stored waveform matrix
            # (the original indexed it as [:, i]); read it once and take rows
            # of the transpose.
            bursts['waveform'].extend(group['waveform'][:].T)
            # Shared time axis: overwritten by each file, so the last one wins.
            bursts['waveform_times'] = group['waveform_times'][:]

    return {field: np.array(values) for field, values in bursts.items()}

In [65]:
def filter_bursts(bursts, idx):
    """Return a new burst dictionary restricted to the bursts selected by ``idx``.

    Every per-burst array is indexed with ``idx`` along its first axis (rows of
    ``waveform``). ``waveform_times`` is the shared time axis and is passed
    through unchanged: it is the same object, not a copy.

    Parameters
    ----------
    bursts : dict
        Burst dictionary as returned by ``load_bursts``.
    idx : array-like
        Indices of the bursts to keep (e.g. the output of ``np.where(...)[0]``).

    Returns
    -------
    dict
        New dictionary with the same keys in the same order.
    """
    per_burst_keys = ('age', 'subject', 'chan', 'cluster', 'condition', 'epoch',
                      'fwhm_freq', 'fwhm_time', 'peak_adjustment', 'peak_amp_base',
                      'peak_amp_iter', 'peak_freq', 'peak_time', 'polarity', 'trial')
    selected = {key: bursts[key][idx] for key in per_burst_keys}
    selected['waveform'] = bursts['waveform'][idx, :]
    selected['waveform_times'] = bursts['waveform_times']
    return selected

In [68]:
# Load the 9-month bursts for the 'C' and then the 'P' channel cluster.
c_bursts_9m, p_bursts_9m = [load_bursts(pipeline, '9m', cluster) for cluster in ('C', 'P')]

loading sub-010
loading sub-012
loading sub-013
loading sub-014
loading sub-015
loading sub-016
loading sub-018
loading sub-019
loading sub-020
loading sub-021
loading sub-023
loading sub-024
loading sub-025
loading sub-026
loading sub-027
loading sub-028
loading sub-029
loading sub-030
loading sub-031
loading sub-032
loading sub-033
loading sub-036
loading sub-037
loading sub-038
loading sub-041
loading sub-042
loading sub-043
loading sub-044
loading sub-010
loading sub-012
loading sub-013
loading sub-014
loading sub-015
loading sub-016
loading sub-018
loading sub-019
loading sub-020
loading sub-021
loading sub-023
loading sub-024
loading sub-025
loading sub-026
loading sub-027
loading sub-028
loading sub-029
loading sub-030
loading sub-031
loading sub-032
loading sub-033
loading sub-036
loading sub-037
loading sub-038
loading sub-041
loading sub-042
loading sub-043
loading sub-044


In [69]:
# Load the 12-month bursts for the 'C' and then the 'P' channel cluster.
c_bursts_12m, p_bursts_12m = [load_bursts(pipeline, '12m', cluster) for cluster in ('C', 'P')]

loading sub-001
loading sub-002
loading sub-003
loading sub-004
loading sub-005
loading sub-006
loading sub-007
loading sub-008
loading sub-009
loading sub-010
loading sub-011
loading sub-012
loading sub-013
loading sub-014
loading sub-015
loading sub-016
loading sub-017
loading sub-018
loading sub-019
loading sub-020
loading sub-021
loading sub-022
loading sub-023
loading sub-024
loading sub-025
loading sub-026
loading sub-027
loading sub-028
loading sub-029
loading sub-030
loading sub-031
loading sub-032
loading sub-033
loading sub-001
loading sub-002
loading sub-003
loading sub-004
loading sub-005
loading sub-006
loading sub-007
loading sub-008
loading sub-009
loading sub-010
loading sub-011
loading sub-012
loading sub-013
loading sub-014
loading sub-015
loading sub-016
loading sub-017
loading sub-018
loading sub-019
loading sub-020
loading sub-021
loading sub-022
loading sub-023
loading sub-024
loading sub-025
loading sub-026
loading sub-027
loading sub-028
loading sub-029
loading 

In [70]:
# Load the adult bursts for the 'C' and then the 'P' channel cluster.
c_bursts_adult, p_bursts_adult = [load_bursts(pipeline, 'adult', cluster) for cluster in ('C', 'P')]

loading sub-001
loading sub-002
loading sub-003
loading sub-004
loading sub-005
loading sub-006
loading sub-007
loading sub-008
loading sub-009
loading sub-010
loading sub-011
loading sub-012
loading sub-013
loading sub-014
loading sub-015
loading sub-016
loading sub-017
loading sub-018
loading sub-019
loading sub-020
loading sub-021
loading sub-022
loading sub-001
loading sub-002
loading sub-003
loading sub-004
loading sub-005
loading sub-006
loading sub-007
loading sub-008
loading sub-009
loading sub-010
loading sub-011
loading sub-012
loading sub-013
loading sub-014
loading sub-015
loading sub-016
loading sub-017
loading sub-018
loading sub-019
loading sub-020
loading sub-021
loading sub-022


In [71]:
# Amplitude outlier rejection, applied separately to each cluster x age dataset.
# NOTE: this cell overwrites its inputs, so running it twice filters twice --
# re-run the loading cells before re-running this one.
OUTLIER_N_SD = 2.5


def remove_amplitude_outliers(bursts, n_sd=OUTLIER_N_SD):
    """Return ``bursts`` restricted to bursts with a non-outlying ``peak_amp_base``.

    A burst is kept when ``|amp - median(amp)| < n_sd * std(amp)``. The
    inequality is strict, and ``np.std`` uses its default ``ddof=0``. The centre
    is the median, while the spread is the (non-robust) standard deviation,
    exactly as in the original per-dataset expressions.

    Parameters
    ----------
    bursts : dict
        Burst dictionary as returned by ``load_bursts``.
    n_sd : float
        Rejection threshold, in standard deviations.

    Returns
    -------
    dict
        Filtered burst dictionary produced by ``filter_bursts``.
    """
    amp = bursts['peak_amp_base']
    keep = np.where(np.abs(amp - np.median(amp)) < n_sd * np.std(amp))[0]
    return filter_bursts(bursts, keep)


c_bursts_9m = remove_amplitude_outliers(c_bursts_9m)
p_bursts_9m = remove_amplitude_outliers(p_bursts_9m)

c_bursts_12m = remove_amplitude_outliers(c_bursts_12m)
p_bursts_12m = remove_amplitude_outliers(p_bursts_12m)

c_bursts_adult = remove_amplitude_outliers(c_bursts_adult)
p_bursts_adult = remove_amplitude_outliers(p_bursts_adult)

In [72]:
# Time-align infant burst waveforms to the adult waveform shape, separately for the
# C and P clusters:
#   1. take the median waveform across bursts for each age group,
#   2. divide each median by its own maximum,
#   3. run DTW (dtw-python, step pattern rabinerJuangStepPattern(5, "c")) with the
#      infant median as the query and the adult median as the reference,
#   4. use warp(..., index_reference=False) -- query indices that bring the query onto
#      the reference time axis -- to resample every individual infant burst.
# NOTE(review): the *_burst_times_* variables are assigned here but not used in this cell.
c_burst_times_9m=c_bursts_9m['waveform_times']
c_burst_times_12m=c_bursts_12m['waveform_times']
c_burst_times_adult=c_bursts_adult['waveform_times']
# Median waveform across bursts (axis 0 indexes bursts).
med_c_9m=np.percentile(c_bursts_9m['waveform'], 50, axis=0)
med_c_12m=np.percentile(c_bursts_12m['waveform'], 50, axis=0)
med_c_adult=np.percentile(c_bursts_adult['waveform'], 50, axis=0)
# Scale each median so its maximum is 1 before DTW.
norm_c_adult=med_c_adult/np.max(med_c_adult)
norm_c_12m=med_c_12m/np.max(med_c_12m)
norm_c_9m=med_c_9m/np.max(med_c_9m)
# Align each infant median (query) to the adult median (reference).
alignment_c_9m=dtw(norm_c_9m, norm_c_adult, keep_internals=True, step_pattern=rabinerJuangStepPattern(5,"c"))
aligned_c_9m_idx=warp(alignment_c_9m,index_reference=False)
alignment_c_12m=dtw(norm_c_12m, norm_c_adult, keep_internals=True, step_pattern=rabinerJuangStepPattern(5,"c"))
aligned_c_12m_idx=warp(alignment_c_12m,index_reference=False)
# Apply the median-derived warp to every individual burst (column re-indexing).
aligned_c_9m_bursts=c_bursts_9m['waveform'][:,aligned_c_9m_idx]
aligned_c_12m_bursts=c_bursts_12m['waveform'][:,aligned_c_12m_idx]
# Adult bursts are not warped; the last sample is dropped, presumably so their length
# matches the length of the warp index used for the infants -- TODO confirm against dtw.warp.
aligned_c_adult_bursts=c_bursts_adult['waveform'][:,:-1]

# Same procedure for the P cluster.
p_burst_times_9m=p_bursts_9m['waveform_times']
p_burst_times_12m=p_bursts_12m['waveform_times']
p_burst_times_adult=p_bursts_adult['waveform_times']
med_p_9m=np.percentile(p_bursts_9m['waveform'], 50, axis=0)
med_p_12m=np.percentile(p_bursts_12m['waveform'], 50, axis=0)
med_p_adult=np.percentile(p_bursts_adult['waveform'], 50, axis=0)
norm_p_adult=med_p_adult/np.max(med_p_adult)
norm_p_12m=med_p_12m/np.max(med_p_12m)
norm_p_9m=med_p_9m/np.max(med_p_9m)
alignment_p_9m=dtw(norm_p_9m, norm_p_adult, keep_internals=True, step_pattern=rabinerJuangStepPattern(5,"c"))
aligned_p_9m_idx=warp(alignment_p_9m,index_reference=False)
alignment_p_12m=dtw(norm_p_12m, norm_p_adult, keep_internals=True, step_pattern=rabinerJuangStepPattern(5,"c"))
aligned_p_12m_idx=warp(alignment_p_12m,index_reference=False)
aligned_p_9m_bursts=p_bursts_9m['waveform'][:,aligned_p_9m_idx]
aligned_p_12m_bursts=p_bursts_12m['waveform'][:,aligned_p_12m_idx]
aligned_p_adult_bursts=p_bursts_adult['waveform'][:,:-1]

# Pool aligned bursts across ages (vstack requires equal sample counts per row).
aligned_c_bursts=np.vstack([aligned_c_9m_bursts,aligned_c_12m_bursts,aligned_c_adult_bursts])
aligned_p_bursts=np.vstack([aligned_p_9m_bursts,aligned_p_12m_bursts,aligned_p_adult_bursts])

In [73]:
# PCA of aligned burst waveforms: one PCA per cluster (C, P), fitted on the pooled
# bursts of all three age groups after each group is robust-scaled on its own.
# NOTE(review): per-group scaling removes per-sample median/IQR differences between
# ages before the PCA -- confirm this is intended.
N_PCS = 20
# Fixed seed: with svd_solver='auto', PCA may select the randomized SVD solver for
# large inputs, which without a random_state gives run-to-run differences in the PCs.
PCA_RANDOM_STATE = 42


def robust_scale_groups(groups):
    """Fit and apply a separate RobustScaler to each array in ``groups``; return the list."""
    return [RobustScaler().fit_transform(group) for group in groups]


def winsorize_pcs(pcs, limits=(.01, .01)):
    """Winsorize every column of ``pcs`` in place at the given (lower, upper) fractions.

    Returns the same array for convenience.
    """
    for col in range(pcs.shape[1]):
        pcs[:, col] = winsorize(pcs[:, col], limits=limits)
    return pcs


burst_std_c_9m, burst_std_c_12m, burst_std_c_adult = robust_scale_groups(
    [aligned_c_9m_bursts, aligned_c_12m_bursts, aligned_c_adult_bursts])
burst_std_c = np.vstack([burst_std_c_9m, burst_std_c_12m, burst_std_c_adult])
c_pca = PCA(n_components=N_PCS, random_state=PCA_RANDOM_STATE).fit(burst_std_c)
pcs_c_9m = c_pca.transform(burst_std_c_9m)
pcs_c_12m = c_pca.transform(burst_std_c_12m)
pcs_c_adult = c_pca.transform(burst_std_c_adult)

burst_std_p_9m, burst_std_p_12m, burst_std_p_adult = robust_scale_groups(
    [aligned_p_9m_bursts, aligned_p_12m_bursts, aligned_p_adult_bursts])
burst_std_p = np.vstack([burst_std_p_9m, burst_std_p_12m, burst_std_p_adult])
p_pca = PCA(n_components=N_PCS, random_state=PCA_RANDOM_STATE).fit(burst_std_p)
pcs_p_9m = p_pca.transform(burst_std_p_9m)
pcs_p_12m = p_pca.transform(burst_std_p_12m)
pcs_p_adult = p_pca.transform(burst_std_p_adult)

# Clip the extreme 1% at each end of every PC score, per cluster x age group.
for pcs in (pcs_c_9m, pcs_p_9m, pcs_c_12m, pcs_p_12m, pcs_c_adult, pcs_p_adult):
    winsorize_pcs(pcs)

pca_components_c_9m = pd.DataFrame(pcs_c_9m)
pca_components_p_9m = pd.DataFrame(pcs_p_9m)
pca_components_c_12m = pd.DataFrame(pcs_c_12m)
pca_components_p_12m = pd.DataFrame(pcs_p_12m)
pca_components_c_adult = pd.DataFrame(pcs_c_adult)
pca_components_p_adult = pd.DataFrame(pcs_p_adult)

pcs_c = np.vstack([pcs_c_9m, pcs_c_12m, pcs_c_adult])
pcs_p = np.vstack([pcs_p_9m, pcs_p_12m, pcs_p_adult])

In [74]:
# One row per 9-month burst: every per-burst field except the waveform arrays,
# followed by one column per principal-component score (PC1, PC2, ...).
# pd.DataFrame(dict) keeps per-column dtypes (float, str), whereas
# DataFrame.from_dict(orient='index').transpose() would make every column object dtype
# and would silently NaN-pad columns of unequal length (the constructor raises instead).
pca_c_9m_dict = {key: values for key, values in c_bursts_9m.items()
                 if key not in ('waveform', 'waveform_times')}
for i in range(pcs_c_9m.shape[1]):
    pca_c_9m_dict['PC{}'.format(i + 1)] = pcs_c_9m[:, i]
df_c_9m = pd.DataFrame(pca_c_9m_dict)
pca_p_9m_dict = {key: values for key, values in p_bursts_9m.items()
                 if key not in ('waveform', 'waveform_times')}
for i in range(pcs_p_9m.shape[1]):
    pca_p_9m_dict['PC{}'.format(i + 1)] = pcs_p_9m[:, i]
df_p_9m = pd.DataFrame(pca_p_9m_dict)

In [75]:
# One row per 12-month burst: every per-burst field except the waveform arrays,
# followed by one column per principal-component score (PC1, PC2, ...).
# pd.DataFrame(dict) keeps per-column dtypes (float, str), whereas
# DataFrame.from_dict(orient='index').transpose() would make every column object dtype
# and would silently NaN-pad columns of unequal length (the constructor raises instead).
pca_c_12m_dict = {key: values for key, values in c_bursts_12m.items()
                  if key not in ('waveform', 'waveform_times')}
for i in range(pcs_c_12m.shape[1]):
    pca_c_12m_dict['PC{}'.format(i + 1)] = pcs_c_12m[:, i]
df_c_12m = pd.DataFrame(pca_c_12m_dict)
pca_p_12m_dict = {key: values for key, values in p_bursts_12m.items()
                  if key not in ('waveform', 'waveform_times')}
for i in range(pcs_p_12m.shape[1]):
    pca_p_12m_dict['PC{}'.format(i + 1)] = pcs_p_12m[:, i]
df_p_12m = pd.DataFrame(pca_p_12m_dict)

In [76]:
# One row per adult burst: every per-burst field except the waveform arrays,
# followed by one column per principal-component score (PC1, PC2, ...).
# pd.DataFrame(dict) keeps per-column dtypes (float, str), whereas
# DataFrame.from_dict(orient='index').transpose() would make every column object dtype
# and would silently NaN-pad columns of unequal length (the constructor raises instead).
pca_c_adult_dict = {key: values for key, values in c_bursts_adult.items()
                    if key not in ('waveform', 'waveform_times')}
for i in range(pcs_c_adult.shape[1]):
    pca_c_adult_dict['PC{}'.format(i + 1)] = pcs_c_adult[:, i]
df_c_adult = pd.DataFrame(pca_c_adult_dict)
pca_p_adult_dict = {key: values for key, values in p_bursts_adult.items()
                    if key not in ('waveform', 'waveform_times')}
for i in range(pcs_p_adult.shape[1]):
    pca_p_adult_dict['PC{}'.format(i + 1)] = pcs_p_adult[:, i]
df_p_adult = pd.DataFrame(pca_p_adult_dict)

In [78]:
# Stack the per-age tables of each cluster (9m, 12m, adult order) and write one CSV
# per cluster: C first, then P. `df` is left holding the P table, as before.
output_tables = [
    ('/home/bonaiuto/dev_beta_umd/output/bursts_c.csv', [df_c_9m, df_c_12m, df_c_adult]),
    ('/home/bonaiuto/dev_beta_umd/output/bursts_p.csv', [df_p_9m, df_p_12m, df_p_adult]),
]
for out_path, age_frames in output_tables:
    df = pd.concat(age_frames, ignore_index=True, sort=False)
    df.to_csv(out_path)