In [50]:
import numpy as np
import pandas as pd
import mne
import matplotlib.pyplot as plt
%matplotlib inline

In [51]:
# Age-group label -> ages (in whole years) belonging to that group.
# Bounds below are inclusive; '10+' covers ages 10 through 19.
age_groups = {
    label: range(first_age, last_age + 1)
    for label, first_age, last_age in [
        ('2-4', 2, 4),
        ('5-6', 5, 6),
        ('7-9', 7, 9),
        ('10+', 10, 19),
    ]
}

In [52]:
from pathlib import Path

# Directory holding the recordings and their index file `path_file.csv`.
path = Path('../data/asd/raw')
# 41 evenly spaced values from 4 to 12 (step 0.2).
freqs = np.linspace(4, 12, 41)

info = pd.read_csv(path / 'path_file.csv')
# For some reason, filenames encode й as 2 unicode characters
# (и followed by code point 774); collapse them into the single letter.
info['fn'] = info['fn'].str.replace('и' + chr(774), 'й')

# Row positions of each target class.
typical = np.flatnonzero(info['target'] == 'typical')
asd = np.flatnonzero(info['target'] == 'asd')

# One Raw object per row of `info`, in the same order.
eegs = [
    mne.io.read_raw_fif(path / file_name, verbose=False)
    for file_name in info['fn']
]

# The sampling rate stored in each file must match the one listed in `info`.
assert all(
    eegs[row].info['sfreq'] == expected
    for row, expected in info['sfreq'].items()
)
info

Unnamed: 0,fn,target,dataset_name,sfreq,age,seconds
0,sedrykyn_sasha_7_og_concat_19.raw.fif,asd,asd,125,7,47.0
1,roma gritchin _5_fon_open_19.raw.fif,asd,asd,125,5,33.0
2,boy5_asd_og_new_19.raw.fif,asd,asd,125,5,50.0
3,viflyancev_4_asd_fon__concat_19.raw.fif,asd,asd,125,4,58.0
4,andrey_matveev3_asd_new_19.raw.fif,asd,asd,125,3,50.0
...,...,...,...,...,...,...
330,artem_sarkis_6_fon_19.raw.fif,typical,typical,125,6,44.0
331,gusarova_9_norm_19.raw.fif,typical,typical,125,9,83.0
332,акулов миша 10 от даши_ог_19.raw.fif,typical,typical,125,10,42.0
333,жавнис_3_19.raw.fif,typical,typical,125,3,386.0


In [53]:
# Sampling rate used for every sample <-> second conversion in this notebook.
sfreq = 125
# Channel layout is taken from the first recording and reused for `eeg_info`.
ch_names = eegs[0].ch_names
eeg_info = mne.create_info(ch_names, sfreq=sfreq, ch_types='eeg')

# (n_channels, n_samples) of every recording. Read from the Raw metadata
# rather than `eeg[:][0].shape`, which would load each full recording from
# disk only to look at its shape.
shapes = pd.DataFrame([(len(eeg.ch_names), eeg.n_times) for eeg in eegs])
# Positions of recordings that do not have exactly 19 channels.
to_drop = np.where(shapes[0] != 19)[0]
# The duration listed in `info` must agree with the actual sample count.
assert np.all(shapes[1] == info['seconds'] * sfreq)

In [54]:
def plot_eeg(eeg, title=None):
    """Plot every channel of a recording as a stack of overlapping traces.

    Uses the notebook-level `sfreq` (samples per second) and `ch_names`
    (channels to draw, one subplot row each). The figure is created but
    neither shown nor returned; callers save or close it through `plt`.

    Parameters
    ----------
    eeg : recording object
        Must expose `n_times` and support `eeg[:]` / `eeg[ch_name]`
        indexing that returns a `(data, times)` pair.
    title : str, optional
        Figure title; omitted when None.
    """
    n_seconds = eeg.n_times // sfreq
    # Shared y-limit: the 99.5th percentile of the absolute amplitude over
    # all channels, rather than the absolute maximum.
    y_max = np.quantile(abs(eeg[:][0]), 0.995)

    # One row per channel (was a hard-coded 19) so the grid always matches
    # the channel list that is iterated below.
    fig, axes = plt.subplots(len(ch_names), 1, figsize=(n_seconds, 12))
    if title is not None:
        fig.suptitle(title, fontsize=16)

    for ch_name, ax in zip(ch_names, axes):
        signal, times = eeg[ch_name]
        ax.plot(times, signal.flatten(), linewidth=1)
        ax.set_ylabel(ch_name, rotation=0, labelpad=25, fontsize=13)

    # The top row keeps y tick labels, so its channel label uses a different pad.
    axes[0].set_ylabel(ch_names[0], rotation=0, labelpad=-30, fontsize=13)

    for ax in axes:
        ax.spines['top'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_xlim(0, n_seconds)
        ax.set_ylim(-y_max, y_max)
        # Transparent background: rows overlap because of the negative hspace.
        ax.patch.set_alpha(0)
        ax.set_xticks([])
        ax.set_yticks([])

    # Only the bottom row carries the time axis, only the top row the scale.
    axes[-1].spines['bottom'].set_visible(True)
    axes[-1].set_xticks(np.arange(n_seconds))
    axes[0].set_yticks([0, y_max])
    # MNE keeps EEG data in volts, so multiplying by 1e6 yields microvolts;
    # the label previously claimed "mV".
    axes[0].set_yticklabels([0, f"{y_max * 1e6:.2f} µV"])

    fig.subplots_adjust(hspace=-0.3)

In [57]:
# Preview images: the unfiltered excerpt goes under 'non-filt', the band-pass
# filtered version of the same excerpt under 'filt'.
image_path = Path('../data/asd/images')
orig_path = image_path / 'non-filt'
filt_path = image_path / 'filt'
# The loop variable is `out_dir`, not `path`, so the data directory `path`
# set when loading the recordings is not overwritten.
for out_dir in (image_path, orig_path, filt_path):
    out_dir.mkdir(parents=True, exist_ok=True)

# Arguments passed to `Raw.filter` for each age group.
filters = {'2-4': (4, 12), '5-6': (5, 12), '7-9': (6, 13), '10+': (7, 13)}

# Length of the plotted excerpt, in seconds.
window_seconds = 10

dfs = []
starts = []

for ag_name, ag in age_groups.items():
    # Only recordings at least one second longer than the window are
    # eligible: np.random.randint(window_seconds, n_seconds) raises when
    # n_seconds <= window_seconds.
    eligible = info[info['age'].isin(ag)
                    & (info['seconds'] >= window_seconds + 1)]
    dfs.append(eligible.sample(10))

    # exist_ok keeps the cell re-runnable once the folders exist.
    (orig_path / ag_name).mkdir(parents=True, exist_ok=True)
    (filt_path / ag_name).mkdir(parents=True, exist_ok=True)
    for i in dfs[-1].index:
        n_seconds = int(info.loc[i, 'seconds'])
        # Random excerpt [end - window_seconds, end], in whole seconds.
        end = np.random.randint(window_seconds, n_seconds)
        starts.append(end - window_seconds)
        eeg = eegs[i].copy().load_data().crop(end - window_seconds, end)
        plot_eeg(eeg, f'Subject {i}')
        plt.savefig(orig_path / ag_name / f'subject-{i}.png')
        plt.close()
        # `filter` works on the copy made above; `eegs[i]` stays unfiltered.
        plot_eeg(eeg.filter(*filters[ag_name]), f'Subject {i}')
        plt.savefig(filt_path / ag_name / f'subject-{i}.png')
        plt.close()

# Record which subjects were drawn and where each excerpt starts.
pd.concat(dfs).assign(start=starts).to_csv(image_path / 'info.csv')

In [228]:
# Scratch check: draw one random excerpt end (in seconds) for recording 1,
# using the same call as the preview-image loop.
np.random.randint(low=10, high=info.loc[1, 'seconds'])

16

In [32]:
# Index table for the cropped dataset: original row, original file name and
# class of every kept recording, plus a randomly assigned new `id`.
cropped_info = (
    pd.DataFrame({'orig_id': info.index})
    .assign(orig_fn=info['fn'], type=info['dataset_name'])
    # Remove the recordings listed in `to_drop`.
    .drop(index=to_drop)
    # Random permutation of 0..n-1, so the new id does not follow file order.
    .assign(id=lambda kept: np.random.permutation(len(kept)))
)

In [30]:
def crop(eeg, seconds, start=None):
    """Cut a fixed-length excerpt out of a recording.

    Parameters
    ----------
    eeg : recording object
        Must support `eeg[:]` returning a `(data, times)` pair where `data`
        is a 2-D array indexed as (channel, sample).
    seconds : int or float
        Excerpt length in seconds; converted to samples with the
        notebook-level `sfreq`.
    start : int, optional
        Index of the first sample of the excerpt. When None, a start is
        drawn uniformly from every position where the excerpt still fits.

    Returns
    -------
    2-D array with `seconds * sfreq + 1` columns (both end samples included).

    Raises
    ------
    ValueError
        If no start is given and the recording is shorter than the excerpt.
    """
    signals = eeg[:][0]
    ticks = int(seconds * sfreq)
    # `start or ...` would also discard an explicit start of 0, so test for
    # None instead.
    if start is None:
        last_start = signals.shape[1] - ticks - 1
        if last_start < 0:
            raise ValueError(
                f'recording has {signals.shape[1]} samples, '
                f'need at least {ticks + 1}'
            )
        # randint's upper bound is exclusive; +1 makes the excerpt ending on
        # the very last sample reachable.
        start = np.random.randint(0, last_start + 1)
    return signals[:, start:start + ticks + 1]

# One 30-second excerpt per kept recording, keyed by its new id. The index of
# `cropped_info` is used as the position of the source recording in `eegs`.
cropped_eegs = {
    new_id: mne.io.RawArray(crop(eegs[orig_id], 30), eeg_info, verbose=False)
    for orig_id, new_id in zip(cropped_info.index, cropped_info['id'])
}

In [37]:
# Order the table by the new id and replace the index with 0..n-1.
# After this the index no longer equals the position in `eegs`.
cropped_info = (
    cropped_info
    .sort_values(by='id')
    .reset_index(drop=True)
)

In [39]:
# Directory that receives the cropped recordings and their `info.csv`.
out_path = Path('../data/asd/cropped')

In [41]:
# Nothing else in the notebook creates the output directory, so create it
# here before writing.
out_path.mkdir(parents=True, exist_ok=True)

cropped_info.to_csv(out_path / 'info.csv')

# `eeg_id` rather than `id`, so the builtin is not shadowed at notebook scope.
# overwrite=True keeps a re-run consistent with info.csv, which `to_csv`
# always replaces.
for eeg_id, eeg in cropped_eegs.items():
    eeg.save(out_path / f'{eeg_id}.raw.fif', overwrite=True)