# Data loading

In [1]:
import numpy as np
import pandas as pd
import mne
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
path = Path('../data/asd/raw')
freqs = np.linspace(4, 12, 41)

# Recording metadata table (filename, diagnosis label, sampling rate, age, ...).
info = pd.read_csv(path / 'path_file.csv')
# For some reason, filenames in the CSV encode "й" as two code points ("и" + U+0306
# COMBINING BREVE); replace that pair with the single precomposed character.
info['fn'] = info['fn'].str.replace('и' + chr(774), 'й', regex=False)
# Keep only filenames that occur exactly once (every copy of a duplicated name is dropped).
info = info.groupby("fn").filter(lambda x: len(x) == 1).reset_index(drop=True)
typical = np.where(info['target'] == 'typical')[0]
asd = np.where(info['target'] == 'asd')[0]

# Keys are the row positions of `info` (0..n-1 after the reset above).
eegs = {i: mne.io.read_raw_fif(path / fn, verbose=False) for i, fn in info['fn'].items()}
ch_names = eegs[0].ch_names

# Sampling rate (Hz) assumed by the index and plotting code below.
sfreq = 125

# Each file's sampling rate must match the CSV, and every recording must use `sfreq`.
assert all(eegs[i].info['sfreq'] == row_sfreq for i, row_sfreq in info['sfreq'].items())
assert (info['sfreq'] == sfreq).all(), "all recordings must be sampled at `sfreq`"

# NOTE(review): filenames appear to contain participants' names; consider hiding the
# 'fn' column before sharing this notebook.
info

Unnamed: 0,fn,target,dataset_name,sfreq,age,seconds
0,sedrykyn_sasha_7_og_concat_19.raw.fif,asd,asd,125,7,47.0
1,roma gritchin _5_fon_open_19.raw.fif,asd,asd,125,5,33.0
2,boy5_asd_og_new_19.raw.fif,asd,asd,125,5,50.0
3,viflyancev_4_asd_fon__concat_19.raw.fif,asd,asd,125,4,58.0
4,andrey_matveev3_asd_new_19.raw.fif,asd,asd,125,3,50.0
...,...,...,...,...,...,...
322,artem_sarkis_6_fon_19.raw.fif,typical,typical,125,6,44.0
323,gusarova_9_norm_19.raw.fif,typical,typical,125,9,83.0
324,акулов миша 10 от даши_ог_19.raw.fif,typical,typical,125,10,42.0
325,жавнис_3_19.raw.fif,typical,typical,125,3,386.0


# Computing the alpha/beta power-ratio index

In [3]:
def get_powers(signal, fmin=4, fmax=8, sfreq=125, n_freqs=50, n_cycles=7.0):
    """Morlet-wavelet power averaged over a frequency band, per channel and time sample.

    Parameters
    ----------
    signal : ndarray, shape (n_channels, n_times)
        A single recording; it is wrapped as one epoch for ``tfr_array_morlet``.
    fmin, fmax : float
        Band edges in Hz (both included in the frequency grid).
    sfreq : float
        Sampling rate in Hz.
    n_freqs : int
        Number of evenly spaced frequencies between ``fmin`` and ``fmax``
        (50 matches the previous implicit ``np.linspace`` default).
    n_cycles : float
        Wavelet cycles per frequency (7.0 is MNE's default).

    Returns
    -------
    ndarray, shape (n_channels, n_times)
        Power averaged across the band's frequencies.
    """
    # tfr_array_morlet expects (n_epochs, n_channels, n_times); add a singleton epoch axis.
    epochs = signal[np.newaxis]
    band_freqs = np.linspace(fmin, fmax, n_freqs)
    power = mne.time_frequency.tfr_array_morlet(
        epochs, sfreq=sfreq, freqs=band_freqs, n_cycles=n_cycles,
        output='power', verbose=False,
    )
    # power is (n_epochs, n_channels, n_freqs, n_times): take the only epoch, average over freqs.
    return power[0].mean(axis=1)

def get_filters(age):
    """Return the age-dependent (low, high) alpha-band edges in Hz.

    Ages are bucketed with half-open intervals, so a fractional age (e.g. 4.5) lands
    in the bucket of its whole-year part:

    * age < 5       -> (4, 12)   (includes children younger than 2)
    * 5 <= age < 7  -> (5, 12)
    * 7 <= age < 10 -> (6, 13)
    * otherwise     -> (7, 13)   (10 and older; a NaN age also falls through here)

    Parameters
    ----------
    age : int or float
        Age in years.

    Returns
    -------
    tuple of (int, int)
        Lower and upper band edge in Hz.
    """
    # Ordered comparisons instead of `age in range(...)`: range membership only matches
    # whole numbers and sent fractional or under-2 ages to the oldest band.
    if age < 5:
        return 4, 12
    if age < 7:
        return 5, 12
    if age < 10:
        return 6, 13
    return 7, 13

def find_borders(arr):
    """Locate where runs of True begin and end in a 1-D boolean mask.

    Returns a pair ``(starts, ends)``, each being the tuple produced by ``np.where``.
    ``starts`` holds the index of the first True of each run; ``ends`` holds the index
    of the last True of each run (inclusive, not one past the end).
    """
    # Neighbour views of arr, padded with False/0 at the edge that was shifted in.
    prev = np.pad(arr, (1, 0))[:-1]
    nxt = np.pad(arr, (0, 1))[1:]
    starts = np.where(arr > prev)  # value rises relative to the previous sample
    ends = np.where(nxt < arr)     # value drops at the following sample
    return starts, ends

def index_1(eeg, age, sfreq=125):
    """Per-sample ratio of age-adjusted alpha-band power to beta-band (13-30 Hz) power.

    Both bands are measured with Morlet-wavelet power from ``get_powers``; the alpha
    band edges come from ``get_filters(age)``.

    Parameters
    ----------
    eeg : ndarray
        Recording passed straight to ``get_powers`` (channels x samples).
    age : int or float
        Age in years, used to pick the alpha band.
    sfreq : float
        Sampling rate in Hz, forwarded to ``get_powers`` (previously its implicit default).

    Returns
    -------
    ndarray
        Alpha power divided element-wise by beta power, one value per channel and sample.
    """
    low, high = get_filters(age)
    alpha_powers = get_powers(eeg, fmin=low, fmax=high, sfreq=sfreq)
    beta_powers = get_powers(eeg, fmin=13, fmax=30, sfreq=sfreq)
    return alpha_powers / beta_powers

def index_2(eeg, age, sfreq=125):
    """Per-sample alpha power normalised by each channel's average beta (13-30 Hz) PSD.

    Alpha power is time-resolved (Morlet wavelets via ``get_powers``), with age-dependent
    band edges from ``get_filters``. The beta reference is a single number per channel:
    the Welch PSD averaged over 13-30 Hz for the whole recording.

    NOTE(review): wavelet power and Welch PSD (a density per Hz) are on different scales,
    so absolute values of this index are presumably not comparable with ``index_1``.

    Parameters
    ----------
    eeg : ndarray, shape (n_channels, n_times)
        One recording.
    age : int or float
        Age in years, used to pick the alpha band.
    sfreq : float
        Sampling rate in Hz, used for both the wavelet and the Welch estimate.

    Returns
    -------
    ndarray, shape (n_channels, n_times)
    """
    low, high = get_filters(age)
    alpha_powers = get_powers(eeg, fmin=low, fmax=high, sfreq=sfreq)
    # psd_array_welch returns (psds, freqs); psds is (n_channels, n_freqs).
    psds, _ = mne.time_frequency.psd_array_welch(eeg, sfreq=sfreq, fmin=13, fmax=30, verbose=False)
    beta_power = psds.mean(axis=1)
    # Broadcast each channel's scalar beta power across all time samples.
    return alpha_powers / beta_power[:, None]

# Plotting

In [17]:
def plot_eeg(signal, times, title=None):
    """Plot multichannel EEG as stacked, overlapping traces on one shared amplitude scale.

    Relies on the module-level ``sfreq`` (Hz) and ``ch_names`` (one label per row of
    ``signal``).

    Parameters
    ----------
    signal : ndarray, shape (n_channels, n_times)
        EEG samples. Presumably in volts (MNE's convention); the scale bar multiplies
        by 1e6 and labels the result in microvolts. TODO confirm against the caller.
    times : ndarray, shape (n_times,)
        Sample times in seconds, assumed to start at 0.
    title : str, optional
        Figure title.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axes : ndarray of Axes
        One axis per channel, top to bottom.
    """
    n_seconds = len(times) // sfreq
    n_channels = len(signal)
    # Shared symmetric y-limit at the 99.5th percentile of |signal| so isolated artefact
    # spikes do not squash every trace.
    y_max = np.quantile(abs(signal), 0.995)

    # squeeze=False keeps `axes` an array even for a single channel; max(..., 1) avoids a
    # zero-width figure for recordings shorter than one second.
    fig, axes = plt.subplots(n_channels, 1, figsize=(3 * max(n_seconds, 1), 12), squeeze=False)
    axes = axes[:, 0]
    if title is not None:
        fig.suptitle(title, fontsize=16)

    for i, ax in enumerate(axes):
        ax.plot(times, signal[i], linewidth=1)
        # The top axis also carries the scale-bar tick labels, so its label uses another offset.
        ax.set_ylabel(f"{ch_names[i]}", rotation=0, labelpad=25 if i else -30, fontsize=13)

        ax.spines['top'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_xlim(0, n_seconds)
        ax.set_ylim(-y_max, y_max)
        ax.patch.set_alpha(0)  # transparent background so overlapping neighbours stay visible
        ax.set_xticks([])
        ax.set_yticks([])

    axes[-1].spines['bottom'].set_visible(True)
    axes[-1].set_xticks(np.arange(n_seconds))
    # Scale bar on the top channel only; the 1e6 factor converts volts to microvolts.
    axes[0].set_yticks([0, y_max])
    axes[0].set_yticklabels([0, f"{y_max * 1e6:.2f} µV"])
    fig.tight_layout()
    # Negative spacing makes adjacent traces overlap.
    fig.subplots_adjust(hspace=-0.3)

    return fig, axes

def plot_mask(ratios, threshold, times, axes, min_duration=.1):
    """Shade supra-threshold stretches per channel and append their share of the recording.

    Relies on the module-level ``sfreq`` (Hz), ``ch_names`` and ``find_borders``.

    Parameters
    ----------
    ratios : ndarray, shape (n_channels, n_times)
        Per-sample index values (e.g. from ``index_1``/``index_2``; presumably).
    threshold : float
        Samples with ``ratios > threshold`` are marked.
    times : ndarray, shape (n_times,)
        Sample times in seconds.
    axes : sequence of Axes
        One axis per channel, e.g. as returned by ``plot_eeg``.
    min_duration : float
        Marked stretches shorter than this many seconds are neither shaded nor counted.

    Side effects: draws red spans on each axis and rewrites its y-label to
    ``"<channel> (<fraction>)"``, where fraction is covered time / recording length.
    """
    # True recording length in seconds. Dividing by the floored whole-second count
    # (len(times) // sfreq) let the fraction exceed 1 and divided by zero for < 1 s.
    duration = len(times) / sfreq

    mask = ratios > threshold

    for i, ax in enumerate(axes):
        covered = 0.

        # find_borders returns np.where tuples; unpack the single index array from each.
        (starts,), (ends,) = find_borders(mask[i])
        for start, end in zip(times[starts], times[ends]):
            if end - start < min_duration:
                continue
            ax.axvspan(start, end, color='r', alpha=.15, ymin=.15, ymax=.85)
            covered += end - start
        fraction = covered / duration if duration else 0.
        ax.set_ylabel(f"{ch_names[i]} ({fraction:.2f})", rotation=0, labelpad=45 if i else -10, fontsize=13)