In [None]:
import mne
import matplotlib.pyplot as plt
from collections import defaultdict
import numpy as np

In [None]:
# Location of the recording file that the loading cell reads.
# The path is relative, so it is resolved against the kernel's working directory.
recording_path = 'recordings/recording_ssvep4_gel.raw.fif'

In [None]:
# Read the recording with preload=True and rescale its data by a factor of 1e-6.
# NOTE(review): presumably the file stores microvolts and this converts to volts — confirm.
raw = mne.io.read_raw_fif(recording_path, preload=True).rescale(1e-6)
# Replace the measurement info with a freshly created one that reuses the channel
# names and sampling frequency but declares every channel as 'eeg'.
# NOTE(review): anything else held in the original info is presumably dropped by
# this assignment — confirm that is intended.
raw.info = mne.create_info(ch_names=raw.ch_names, sfreq=raw.info['sfreq'], ch_types='eeg')
# Display the new info as the cell output.
raw.info

In [None]:
# Visual sanity check: plot the loaded recording, then render the figure.
raw.plot()
plt.show()

In [None]:
# Show the annotation descriptions next to their onsets; the next cell relies on
# descriptions starting with 'cue' or 'stimulus'.
raw.annotations.description, raw.annotations.onset

In [None]:
# Maps each cue label (second space-separated token of a 'cue ...' description)
# to the onsets of the stimuli presented while that cue was the most recent one.
label_onsets = defaultdict(list)
current_cue = None
stimulus_duration = None

# Estimate the stimulus presentation time from the first stimulus annotation,
# assuming every presentation lasts equally long.
for i in range(len(raw.annotations) - 1):
    if raw.annotations[i]['description'].startswith('stimulus'):
        assert raw.annotations[i + 1]['description'].startswith('cue')
        # Duration = gap between this stimulus and the annotation that follows it.
        # int() truncates toward zero, so any fractional part of the gap is dropped.
        stimulus_duration = int(raw.annotations[i + 1]['onset'] - raw.annotations[i]['onset'])
        break

# Fail here with a clear message instead of much later inside the cropping cell,
# where a missing duration would only surface as an opaque TypeError.
assert stimulus_duration is not None, (
    'could not determine the stimulus duration: no stimulus annotation '
    'followed by another annotation was found'
)
# A truncated duration of 0 would make every cropped segment degenerate.
assert stimulus_duration > 0, (
    'the gap between the first stimulus and the following cue truncates to '
    f'{stimulus_duration}; expected a positive duration'
)

print(stimulus_duration)

for annotation in raw.annotations:
    if annotation['description'].startswith('cue'):
        # The label is the token right after the 'cue' prefix.
        current_cue = annotation['description'].split(' ')[1]
    elif annotation['description'].startswith('stimulus'):
        current_onset = annotation['onset']
        # We must already have a cue (every stimulus follows a cue); enforce it so
        # a stray stimulus is not silently collected under the key None.
        assert current_cue is not None, (
            f'stimulus at onset {current_onset} appears before any cue annotation'
        )
        label_onsets[current_cue].append(current_onset)

label_onsets

In [None]:
# For each label, crop the raw data
def crop_raw_data(raw, label_onsets, stimulus_duration):
    """Cut one segment out of ``raw`` per stimulus onset, grouped by label.

    Parameters
    ----------
    raw
        Recording object providing ``copy()``; ``crop(start, end)`` is called
        on a fresh copy for every onset rather than on ``raw`` itself.
    label_onsets : dict
        Maps each label to an iterable of onset times.
    stimulus_duration
        Added to each onset to obtain the end of the segment.

    Returns
    -------
    dict
        Maps each label to the list of cropped copies, in the same order as
        that label's onsets.
    """
    return {
        label: [
            raw.copy().crop(start, start + stimulus_duration)
            for start in onsets
        ]
        for label, onsets in label_onsets.items()
    }
# Crop the recording at every collected stimulus onset, then display the
# resulting label -> segments mapping as the cell output.
cropped = crop_raw_data(raw, label_onsets, stimulus_duration)
cropped

In [None]:
def filter_cropped_data(cropped, l_freq=None, h_freq=None):
    """Filter every cropped segment, keeping the per-label grouping.

    Parameters
    ----------
    cropped : dict
        Maps each label to a list of segments providing ``copy()``.
    l_freq, h_freq
        Forwarded positionally, in this order, to ``filter`` on a copy of each
        segment; both default to ``None``.

    Returns
    -------
    dict
        Maps each label to the list of filtered copies, in input order. The
        filter is invoked on copies, not on the segments passed in.
    """
    return {
        label: [segment.copy().filter(l_freq, h_freq) for segment in segments]
        for label, segments in cropped.items()
    }
# Filter every cropped segment with l_freq=8 and h_freq=30, then display the result.
# NOTE(review): presumably an 8-30 Hz band-pass — confirm against the filter's documentation.
filtered = filter_cropped_data(cropped, l_freq=8, h_freq=30)
filtered

In [None]:
def epoch_filtered_data(filtered, window_size, window_overlap):
    """Split every filtered segment into fixed-length epochs, grouped by label.

    Parameters
    ----------
    filtered : dict
        Maps each label to a list of segments.
    window_size
        Passed as ``duration`` to ``mne.make_fixed_length_epochs``.
    window_overlap
        Passed as ``overlap`` to ``mne.make_fixed_length_epochs``.
        NOTE(review): MNE documents ``overlap`` in seconds, not as a fraction
        of the window — confirm callers pass it accordingly.

    Returns
    -------
    dict
        Maps each label to a list holding one epochs object per input segment,
        in input order. Epochs are created with ``preload=True``.
    """
    epochs_data = {}
    for label, segments in filtered.items():
        epochs_data[label] = [
            mne.make_fixed_length_epochs(
                segment,
                duration=window_size,
                overlap=window_overlap,
                preload=True,
            )
            for segment in segments
        ]
    return epochs_data
# Epoch every filtered segment with window_size=2 and window_overlap=0.5, then
# display the result.
# NOTE(review): MNE documents the overlap of fixed-length epochs in seconds; confirm
# 0.5 is meant as 0.5 s rather than 50 % of the window.
epochs = epoch_filtered_data(filtered, window_size=2, window_overlap=0.5)
epochs

In [None]:
def convert_epochs_to_array(epochs):
    """Stack the data of all epochs objects of a label into one array per label.

    Parameters
    ----------
    epochs : dict
        Maps each label to a list of objects providing ``get_data()``.

    Returns
    -------
    dict
        Maps each label to the ``get_data()`` results of its epochs objects,
        concatenated along axis 0 in list order.
        NOTE(review): presumably each array is (n_epochs, n_channels, n_times),
        so axis 0 indexes epochs — confirm.
    """
    data_arrays = {}
    for label, epochs_list in epochs.items():
        # One array per epochs object, joined along the first axis.
        per_segment = [segment_epochs.get_data() for segment_epochs in epochs_list]
        data_arrays[label] = np.concatenate(per_segment, axis=0)
    return data_arrays
# Stack the epochs of every label into a single array per label, then display
# the array stored under the label '15.0'.
# NOTE(review): the key is hard-coded; a recording without a '15.0' cue raises KeyError here.
data_arrays = convert_epochs_to_array(epochs)
data_arrays['15.0']

In [None]:
# Inspect the first entry along axis 0 of the '15.0' array (the axis the
# per-label arrays were concatenated on).
data_arrays['15.0'][0]