In [None]:
#from part1_preprocess import date2student_id
import tqdm
import numpy as np
from scipy.integrate import simps
import math
import os
import json

In [None]:
# EEG frequency bands used for feature extraction, as [low, high] limits
# (presumably in Hz, the unit of the sampling rate they are compared against).
# The trailing notes are the original author's whole-number ranges per band.
FREQ_BANDS = {
    "delta": [0.5, 4],   # noted as 1-3
    "theta": [4, 8],     # noted as 4-7
    "alpha": [8, 13],    # noted as 8-12
    "beta": [13, 30],    # noted as 13-30
    "gamma": [30, 50],
}

In [None]:
def bandpower(data, sf, band, method='welch', window_sec=None, relative=False):
    """Estimate the power of a signal inside one frequency band.

    The power spectral density (PSD) is estimated with Welch's method or the
    multitaper method and then integrated over ``band`` with Simpson's rule.

    Parameters
    ----------
    data : array-like
        Signal samples. The band mask is applied along the first axis of the
        PSD, so a 1-D signal is expected.
    sf : float
        Sampling frequency of ``data``.
    band : sequence of two numbers
        ``[low, high]`` limits of the band (both inclusive), in the same unit
        as ``sf``.
    method : {'welch', 'multitaper'}
        PSD estimator. ``'multitaper'`` requires the optional ``mne`` package.
    window_sec : float or None
        Welch segment length in seconds. When None, ``2 / low`` seconds are
        used, i.e. two cycles of the lowest frequency of the band. Ignored by
        the multitaper method.
    relative : bool
        If True, return the band power divided by the integral of the whole
        spectrum instead of the absolute band power.

    Returns
    -------
    float
        Absolute or relative band power.

    Raises
    ------
    ValueError
        If ``method`` is neither ``'welch'`` nor ``'multitaper'``.
    """
    from scipy.signal import welch
    # `simps` was deprecated in favour of `simpson` and removed in SciPy 1.14;
    # import locally with a fallback so the function works on old and new SciPy.
    try:
        from scipy.integrate import simpson
    except ImportError:
        from scipy.integrate import simps as simpson

    low, high = np.asarray(band)

    if method == 'welch':
        # Segment length in samples; welch() needs a whole number of samples.
        segment_sec = window_sec if window_sec is not None else 2 / low
        nperseg = int(segment_sec * sf)
        freqs, psd = welch(data, sf, nperseg=nperseg)
    elif method == 'multitaper':
        # Imported lazily so the Welch path does not require mne.
        from mne.time_frequency import psd_array_multitaper
        psd, freqs = psd_array_multitaper(data, sf, adaptive=True,
                                          normalization='full', verbose=0)
    else:
        raise ValueError(
            "method must be 'welch' or 'multitaper', got {!r}".format(method))

    # Spacing of the frequency grid, used as the integration step.
    freq_res = freqs[1] - freqs[0]

    # Boolean mask selecting the frequency bins inside the band.
    idx_band = np.logical_and(freqs >= low, freqs <= high)

    # Integrate the PSD over the band (Simpson's rule).
    bp = simpson(psd[idx_band], dx=freq_res)

    if relative:
        bp /= simpson(psd, dx=freq_res)
    return bp

In [None]:
# Disabled earlier version of get_bp, kept only for reference: the code below
# sits inside a string literal, so it is never executed. It cut each recording
# into 1000-sample windows (at most 10 per recording), did not trim the start
# or end of the recording, and stored the log of the absolute band power.
'''def get_bp(idx2eeg, out_path):
    print('extracting bp', end = ' ')
    idx2de = {}
    for idx in idx2eeg.keys():
        eeg = np.array(idx2eeg[idx])
        de_list = []
        for i in range(min(int(eeg.shape[1]/1000),10)):
            fs = 1000
            # tmp_data = eeg[:, i * 1000 : i * 1000 + 500]

            tmp_data = eeg[:, i * 1000 : i * 1000 + 1000]
            tmp_fs = []
            for channel_id in range(tmp_data.shape[0]):
                tmp_feature = []
                for band_item in FREQ_BANDS.values():
                    tmp_feature.append(math.log(bandpower(tmp_data[channel_id], fs, band_item)))
                tmp_fs.append(tmp_feature)
            de_list.append(tmp_fs)
        idx2de[idx] = de_list
    json.dump(idx2de, open(out_path, 'w'))
'''

In [None]:
# Set the time window and discard the leading and trailing parts of each recording.

In [None]:
def get_bp(idx2eeg, out_path, video_duration=60, fs=1000, discard_sec=5,
           window_samples=5000):
    """Extract log relative band-power features per window and save them as JSON.

    For every recording in ``idx2eeg`` the first and last ``discard_sec``
    seconds are dropped and the remainder is cut into consecutive,
    non-overlapping windows of ``window_samples`` samples. For each window,
    channel and band in ``FREQ_BANDS`` the natural log of the relative band
    power (``bandpower(..., relative=True)``) is stored.

    Parameters
    ----------
    idx2eeg : mapping
        Recording id -> 2-D array-like indexed as ``[channel][sample]``.
    out_path : str
        JSON file to write. Its parent directory is created if missing.
    video_duration : float
        Nominal length of one recording in seconds. No more windows are
        extracted than fit into ``video_duration - 2 * discard_sec`` seconds.
    fs : float
        Sampling frequency, in samples per second.
    discard_sec : float
        Seconds removed from both the start and the end of every recording.
        ``0`` keeps the whole recording.
    window_samples : int
        Length of one analysis window in samples (5000 = 5 s at the default
        ``fs``).

    Returns
    -------
    None
        The result is written to ``out_path`` as
        ``{id: [window][channel][band]}``.

    Notes
    -----
    Errors are reported on stdout instead of being raised, so a batch over
    several files keeps going; in that case nothing is reported as done.
    """
    print('Extracting bp...', end=' ')

    window_samples = int(window_samples)
    discard_samples = int(discard_sec * fs)
    # The cap must be expressed in windows, not seconds, to be comparable with
    # the number of windows available in the trimmed recording.
    usable_samples = int((video_duration - 2 * discard_sec) * fs)
    max_windows = max(usable_samples // window_samples, 0)

    idx2de = {}
    stage = 'starting'
    try:
        for idx, recording in idx2eeg.items():
            stage = 'recording {!r}'.format(idx)
            eeg = np.asarray(recording)

            # Trim both ends. An explicit end index is required because a
            # negative-zero stop (`[0:-0]`) would select nothing.
            n_samples = eeg.shape[1]
            trim_end = max(n_samples - discard_samples, discard_samples)
            eeg = eeg[:, discard_samples:trim_end]

            # Never read past the trimmed data nor past the nominal duration.
            n_windows = min(eeg.shape[1] // window_samples, max_windows)

            de_list = []
            for w in range(n_windows):
                window = eeg[:, w * window_samples:(w + 1) * window_samples]
                # One row per channel, one column per band: log of the
                # relative band power.
                de_list.append([
                    [math.log(bandpower(channel, fs, band, relative=True))
                     for band in FREQ_BANDS.values()]
                    for channel in window
                ])
            idx2de[idx] = de_list

        stage = 'writing {!r}'.format(out_path)
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(out_path, 'w') as file:
            json.dump(idx2de, file)

    except Exception as e:
        # Best effort: report where it failed and let the caller continue with
        # the next file, but do not claim success.
        print("An error occurred ({}): ".format(stage), e)
        return

    print('Done.')


In [None]:
# Compute and save the band-power features for every listed recording session.
# (`date` holds the session label used in the file names.)
for date in tqdm.tqdm(['LAB1-huqifan','LAB1-cangyueyang','LAB1-hongyurui','LAB1-fanhao','LAB1-dongyimeng','LAB1-houlinzhi']):
    # `with` closes the input file as soon as it is parsed instead of leaking
    # one open handle per session.
    with open('./x2eeg/'+date+'_idx2eeg.json') as eeg_file:
        idx2eeg = json.load(eeg_file)
    get_bp(idx2eeg, './discard_features_window=5/'+date+'_idx2de.json')