# Convert stages to epochs 

In [None]:
# Imports
import sys
import warnings
warnings.filterwarnings("ignore")
import os
import numpy as np
import mne

## Global variables

In [None]:
# CHANGE HERE: per-run configuration for one subject / task.

# Subject number (zero-padded to two digits when the subject folder name is built)
SUBJ_NUM = 44

# Type of the test (math / video)
TASK_TYPE = 'math'

# Epoch duration in seconds
epoch_dur = 3
# Overlap between consecutive epochs in seconds; must satisfy 0 <= OVLP < epoch_dur
OVLP = 0

# Output directory (relative to the current working directory) and dataset name
TGT_DIR = 'data/'
ds = 'ds02'

# Fail fast on an invalid epoching configuration, before any data is loaded.
if epoch_dur <= 0:
    raise ValueError(f"epoch_dur must be positive, got {epoch_dur}")
if not 0 <= OVLP < epoch_dur:
    raise ValueError(f"OVLP must satisfy 0 <= OVLP < epoch_dur ({epoch_dur}), got {OVLP}")


In [None]:
# Build the subject folder name ('<NN>_<task>', subject number left-padded
# with a single '0' when below 10) and the paths of its working sub-folders
# under <cwd>/work_data.

print(os.getcwd())

subj_dir_name = f"{'0' if SUBJ_NUM < 10 else ''}{SUBJ_NUM}_{TASK_TYPE}"

root_dir_path = os.path.join(os.getcwd(), 'work_data', subj_dir_name)
init_dir_path, raw_dir_path, ep_dir_path, ft_dir_path = (
    os.path.join(root_dir_path, sub_name)
    for sub_name in ('initial_data', 'stages_raw', 'stages_epochs', 'features')
)

# One path per output line, same order as the assignments above.
print(init_dir_path, raw_dir_path, ep_dir_path, ft_dir_path, sep='\n')


In [None]:
# Create the subject's working folders if they do not exist yet.
# os.makedirs(..., exist_ok=True) also creates missing parents (e.g.
# <cwd>/work_data), which os.mkdir cannot do, and avoids the race between
# an exists() check and the mkdir call.
for _dir in (root_dir_path, init_dir_path, raw_dir_path, ep_dir_path, ft_dir_path):
    os.makedirs(_dir, exist_ok=True)


In [None]:
# Frequency bands: (low Hz, high Hz, display label, one-letter code)

bands = [
    (0.9, 4, 'Delta (0.9-4 Hz)', 'D'),
    (4, 8, 'Theta (4-8 Hz)', 'T'),
    (8, 14, 'Alpha (8-14 Hz)', 'A'),
    (14, 25, 'Beta (14-25 Hz)', 'B'),
    (25, 40, 'Gamma (25-40 Hz)', 'G'),
]

# One-letter band codes, in band order
str_freq = [code for *_, code in bands]

In [None]:
# Scalp regions: (channel names, short code, full region name)

regions = [
    (['Fp1', 'Fp2'], 'Fp', 'Pre-frontal'),
    (['F7', 'F3', 'FC5'], 'LF', 'Left Frontal'),
    (['Fz', 'FC1', 'FC2'], 'MF', 'Midline Frontal'),
    (['F4', 'F8', 'FC6'], 'RF', 'Right Frontal'),
    (['T7', 'CP5', 'P7'], 'LT', 'Left Temporal'),
    (['T8', 'CP6', 'P8'], 'RT', 'Right Temporal'),
    (['C3', 'Cz', 'C4'], 'Cen', 'Central'),
    (['P3', 'Pz', 'P4', 'CP1', 'CP2'], 'Par', 'Parietal'),
    (['O1', 'Oz', 'O2'], 'Occ', 'Occipital'),
]

# Number of epochs per slice used to measure physiological features, coherence and PLV
SLICE_LEN = 10

# Counts of frequency bands (from str_freq) and scalp regions
n_freq, n_regions = len(str_freq), len(regions)


# Loading raw data

In [None]:
# Load the filtered and cropped per-stage raw recordings, their stage types,
# and the main baseline recording from <root>/stages_raw.
# (Same value as the path-initialisation cell; recomputed so this cell stands alone.)
raw_dir_path = os.path.join(root_dir_path, 'stages_raw')

# Stage types, one integer per stage. ndmin=1 keeps a single-stage file as a
# 1-element array; without it loadtxt returns a 0-d array and len() fails.
stage_types = np.loadtxt(os.path.join(raw_dir_path, 'stage_types.txt'), dtype=int, ndmin=1)

# Initialize n_stages & n_types
n_stages = len(stage_types)
n_types = len(np.unique(stage_types))

# Load stage files st_1_raw.fif ... st_<n_stages>_raw.fif (not preloaded).
stages_raw = []
for _st in range(n_stages):
    stages_raw.append(mne.io.read_raw_fif(os.path.join(raw_dir_path, f'st_{_st + 1}_raw.fif')))
    # Print (n_channels, n_samples) from metadata; get_data() here would read
    # the whole recording from disk just to report its shape.
    print((len(stages_raw[_st].ch_names), stages_raw[_st].n_times))

baseline_main_raw = mne.io.read_raw_fif(os.path.join(raw_dir_path, 'bl_main_raw.fif'))

# Global variables taken from the baseline recording.
# NOTE(review): assumes all stage recordings share the baseline's sampling rate
# and EEG channel set -- confirm.
samp_rate = baseline_main_raw.info['sfreq']
ch_names = baseline_main_raw.copy().pick_types(eeg=True).ch_names
n_channels = len(ch_names)
print(n_channels, ch_names)


# Epoching data

In [None]:
# Cut every stage into fixed-length epochs of `epoch_dur` seconds, label each
# epoch with its stage type, and concatenate all stages into one Epochs object.
#
# make_fixed_length_events places each event at the START of a window: the
# first at 0.5 s, then every (epoch_dur - OVLP) s, the last one such that
# event + epoch_dur still fits in the recording. The epoch window is therefore
# [0, epoch_dur) relative to the event. A window centred on the event
# (tmin=-epoch_dur/2) begins before the data for the first event whenever
# epoch_dur > 1 s, so MNE dropped that epoch in every stage.
# The number of samples per epoch is unchanged.
kwargs = dict(baseline=None, tmin=0, tmax=epoch_dur - 1/samp_rate, picks='eeg', preload=True)

st_epoch_events = []
stages_epochs = []
st_types_list = []

for _st in range(n_stages):
    st_epoch_events.append(mne.make_fixed_length_events(stages_raw[_st], start=0.5, duration=epoch_dur, overlap=OVLP))
    stages_epochs.append(mne.Epochs(stages_raw[_st], st_epoch_events[_st].astype(int), **kwargs))
    # One label per epoch kept; with preload=True any dropped epochs are
    # already excluded from len().
    st_types_list += [stage_types[_st]] * len(stages_epochs[_st])

    print(stages_epochs[_st].get_data().shape)

epochs_st_all = mne.concatenate_epochs(stages_epochs, add_offset=False)


## Saving ...

In [None]:
# Save the concatenated epoch data (x.npy) and the per-epoch stage labels
# (y.npy) into <cwd>/<TGT_DIR>/<subject>/<ds>/.
tgt_fld = os.path.join(os.getcwd(), TGT_DIR, subj_dir_name, ds)

# exist_ok=True: re-running the notebook must not fail on an existing folder,
# and this avoids the race of a separate exists() check.
os.makedirs(tgt_fld, exist_ok=True)

# x and y must align row for row; refuse to write an inconsistent dataset.
if len(epochs_st_all) != len(st_types_list):
    raise ValueError(
        f"{len(epochs_st_all)} epochs but {len(st_types_list)} stage labels"
    )

np.save(os.path.join(tgt_fld, 'x.npy'), epochs_st_all.get_data())
np.save(os.path.join(tgt_fld, 'y.npy'), np.asarray(st_types_list))

In [None]:
# Report which subject/task folder was processed, then exit.
# NOTE(review): `exit()` is the interactive site helper; in a plain script
# `sys.exit()` is the portable choice -- confirm how this notebook is run.
print(f"processed {subj_dir_name}")
exit()