In [1]:
import scipy.io
import numpy as np
from pathlib import Path
import mne
import os
import torch

In [2]:
cur_dir = Path.cwd()
# Dataset root. Kept as a plain str ending in '/' because the other cells build
# file paths by string concatenation (path + name).
path = str(cur_dir.parent) + '/data/SEED/'

# Per-recording .mat files in name order. label.mat lives in the same folder but
# holds the shared labels, so it is excluded here and loaded on its own below.
files_mat = sorted(
    entry.name
    for entry in Path(path).iterdir()
    if entry.name.endswith('.mat') and entry.name != 'label.mat'
)
labels_mat = scipy.io.loadmat(path + 'label.mat')

In [None]:
# Electrode positions for the 62-channel cap, read from the dataset's .locs file.
montage = mne.channels.read_custom_montage(path + 'channel_62_pos.locs')
print(montage.ch_names)

# Relabel P7/P8 as T5/T6 (the older 10-20 names for the same sites).
# NOTE(review): presumably done to match a downstream naming convention — confirm.
mapping = dict(P7='T5', P8='T6')
montage.rename_channels(mapping)
print(montage.ch_names)

In [4]:
# Recording / windowing configuration shared by the cells below.
# (The subject count is not fixed here; subjects come from the .mat file names.)
sr = 200          # sampling rate in Hz, passed to mne.create_info as sfreq
n_sessions = 3    # recording sessions expected per subject
n_videos = 15     # trials per session; each has its own '*_eeg{k}' array in the .mat file
n_classes = 3     # number of distinct labels
t = 1.            # annotation window length in seconds


def seconds_to_samples(seconds, sfreq):
    """Return the whole number of samples spanning `seconds` at `sfreq` Hz.

    Rounds instead of truncating: a float product can land just below an
    integer (0.29 * 200 evaluates to slightly under 58), and int() would then
    silently shorten every window by one sample.
    """
    return int(round(seconds * sfreq))


samples = seconds_to_samples(t, sr)  # samples per window

In [None]:
# Export every recording in `files_mat` as an MNE Raw .fif file, plus a sidecar
# annotation file, under files/Sessions/S{subject}/.
#  - each '*_eeg{video}' array is z-scored per channel;
#  - the recording is tiled with back-to-back windows of `samples` samples, each
#    annotated 'S{label + 1}' using that video's entry in label.mat
#    (presumably labels -1/0/1 -> 'S0'/'S1'/'S2' — TODO confirm).
labels = labels_mat['label'][0]
# Per-window tensors for the summary cell. Window extraction is currently
# disabled, so nothing is appended and X stays an empty list.
X = []


def _zscore_channels(eeg):
    """Return a float64 copy of `eeg` standardized independently along each row.

    `eeg` is (n_channels, n_times), the layout mne.io.RawArray expects.
    Rows with zero variance are only mean-centred (left at 0) rather than
    divided by zero, which would fill the channel with NaN.
    """
    eeg = np.asarray(eeg, dtype=np.float64)
    centred = eeg - eeg.mean(axis=1, keepdims=True)
    std = centred.std(axis=1, keepdims=True)
    std[std == 0] = 1.0
    return centred / std


# The channel layout is identical for every recording, so build the Info once.
# Each RawArray gets its own copy because set_montage writes into raw.info.
info = mne.create_info(ch_names=montage.ch_names, sfreq=sr, ch_types='eeg')
window_sec = samples / sr  # window length in seconds, aligned to the sample grid
sessions_seen = {}         # subject id -> number of that subject's files processed so far

for file in sorted(files_mat):
    sub = int(file.split('_')[0])
    # Session number = rank of this file among the same subject's files in name
    # order (assumes the name suffix sorts chronologically — TODO confirm).
    # Counting per subject keeps numbering correct even when another subject
    # has a missing or extra file.
    sess = sessions_seen.get(sub, 0) + 1
    sessions_seen[sub] = sess
    if sess > n_sessions:
        raise ValueError(f"Subject {sub} has more than {n_sessions} .mat files (at {file})")

    mat = scipy.io.loadmat(path + file)
    out_dir = path + f'files/Sessions/S{sub:02}/'
    os.makedirs(out_dir, exist_ok=True)  # the save calls below expect the folder to exist

    for j in range(n_videos):
        video = j + 1
        # All sessions share one folder, so the session number must be part of
        # BOTH names; otherwise every later session overwrites earlier files.
        file_name = f"S{sub:02}R{sess}{video:02}.edf.raw.fif"
        annotation_file_name = f"S{sub:02}R{sess}{video:02}.edf.events"

        keys = [k for k in mat if k.endswith(f'_eeg{video}')]
        if len(keys) != 1:
            raise KeyError(f"{file}: expected exactly one '*_eeg{video}' array, found {keys}")
        eeg = _zscore_channels(mat[keys[0]])

        raw_stim = mne.io.RawArray(eeg, info.copy())
        raw_stim.set_montage(montage)
        # Optional step, currently disabled:
        # raw_stim = mne.preprocessing.compute_current_source_density(raw_stim)

        n_windows = eeg.shape[1] // samples
        raw_stim.set_annotations(mne.Annotations(
            onset=[k * window_sec for k in range(n_windows)],
            duration=[window_sec] * n_windows,
            description=[f'S{labels[j] + 1}'] * n_windows,
        ))
        raw_stim.save(out_dir + file_name, overwrite=True)
        # NOTE(review): Annotations.save documents .csv/.txt/-annot.fif suffixes;
        # '.edf.events' may trigger an MNE naming-convention warning — confirm.
        raw_stim.annotations.save(out_dir + annotation_file_name, overwrite=True)

In [None]:
# Value range of the collected EEG windows. torch.max/torch.min need a tensor,
# but X is a Python list when windows are collected (or an empty list when
# windowing is disabled), so stack/convert first and skip the summary if empty.
X_all = torch.stack(X) if isinstance(X, list) and len(X) > 0 else torch.as_tensor(X)
if X_all.numel() == 0:
    print("X is empty: no windows were collected, so there is no range to report.")
else:
    print("Max: ", torch.max(X_all))
    print("Min: ", torch.min(X_all))