In [1]:
#from part1_preprocess import date2student_id
import tqdm
import numpy as np
from scipy.integrate import simps
import math
import os
import json

In [2]:
# Named frequency bands, each given as [low, high] band edges.
# get_bp iterates over FREQ_BANDS.values() in this insertion order, so the
# order of the entries fixes the order of the features in every feature vector.
# NOTE(review): the original trailing comments listed integer ranges
# (1-3, 4-7, 8-12, 13-30) next to delta/theta/alpha/beta -- presumably an
# alternative integer-valued band convention; confirm which one is intended.
FREQ_BANDS = dict(
    delta=[0.5, 4],
    theta=[4, 8],
    alpha=[8, 13],
    beta=[13, 30],
    gamma=[30, 50],
)

In [3]:
def bandpower(data, sf, band, method='welch', window_sec=None, relative=False):
    """Compute the power of a signal inside one frequency band.

    The power spectral density (PSD) is estimated with Welch's method or with
    a multitaper estimator, then integrated over ``band`` using Simpson's rule.

    Parameters
    ----------
    data : array-like
        Signal samples, passed unchanged to the PSD estimator (presumably a
        1-D time series -- confirm against callers).
    sf : float
        Sampling frequency of ``data``.
    band : sequence of two numbers
        ``[low, high]`` band edges, in the same frequency unit as ``sf``.
        Both edges are inclusive.
    method : {'welch', 'multitaper'}
        PSD estimator. ``'multitaper'`` requires the ``mne`` package, which is
        only imported when that method is actually requested.
    window_sec : float or None
        Welch only: segment length in seconds. When ``None`` the segment
        length is ``2 / low`` seconds.
    relative : bool
        If True, return the band power divided by the integral of the whole
        estimated spectrum.

    Returns
    -------
    float
        Absolute band power, or the relative band power if ``relative``.

    Raises
    ------
    ValueError
        If ``method`` is neither ``'welch'`` nor ``'multitaper'``.
    """
    # scipy.integrate.simps was removed in SciPy 1.14; resolve the integrator
    # locally so this function works regardless of the installed version.
    try:
        from scipy.integrate import simpson
    except ImportError:  # very old SciPy only provides the legacy name
        from scipy.integrate import simps as simpson

    if method not in ('welch', 'multitaper'):
        raise ValueError(
            "method must be 'welch' or 'multitaper', got {!r}".format(method)
        )

    low, high = np.asarray(band)

    if method == 'welch':
        from scipy.signal import welch

        # Segment length in samples: user-specified duration, or the default
        # of 2 / low seconds.
        if window_sec is not None:
            nperseg = window_sec * sf
        else:
            nperseg = (2 / low) * sf

        # welch() needs an integer number of samples per segment; truncate
        # explicitly instead of handing it a float.
        freqs, psd = welch(data, sf, nperseg=int(nperseg))
    else:
        # Imported lazily so that the Welch path does not depend on mne.
        from mne.time_frequency import psd_array_multitaper

        psd, freqs = psd_array_multitaper(data, sf, adaptive=True,
                                          normalization='full', verbose=0)

    # Spacing of the frequency axis, used as the integration step.
    freq_res = freqs[1] - freqs[0]

    # Boolean mask of the frequency bins that fall inside the band.
    idx_band = np.logical_and(freqs >= low, freqs <= high)

    # Integrate the PSD over the band (Simpson's rule).
    bp = simpson(psd[idx_band], dx=freq_res)

    if relative:
        # Normalise by the integral over the full spectrum.
        bp /= simpson(psd, dx=freq_res)
    return bp

In [None]:
# Disabled earlier version of get_bp, kept as a bare string literal so that it
# never runs. Compared with the active get_bp defined in this notebook it
# (1) calls bandpower without relative=True, (2) has no try/except around the
# extraction, and (3) hands an unclosed open() handle to json.dump.
'''def get_bp(idx2eeg, out_path):
    print('extracting bp', end = ' ')
    idx2de = {}
    for idx in idx2eeg.keys():
        eeg = np.array(idx2eeg[idx])
        de_list = []
        for i in range(min(int(eeg.shape[1]/1000),10)):
            fs = 1000
            # tmp_data = eeg[:, i * 1000 : i * 1000 + 500]

            tmp_data = eeg[:, i * 1000 : i * 1000 + 1000]
            tmp_fs = []
            for channel_id in range(tmp_data.shape[0]):
                tmp_feature = []
                for band_item in FREQ_BANDS.values():
                    tmp_feature.append(math.log(bandpower(tmp_data[channel_id], fs, band_item)))
                tmp_fs.append(tmp_feature)
            de_list.append(tmp_fs)
        idx2de[idx] = de_list
    json.dump(idx2de, open(out_path, 'w'))
'''

In [5]:
def get_bp(idx2eeg, out_path, fs=1000, window_len=1000, max_windows=10):
    """Extract log relative band-power features and write them to a JSON file.

    Each record is converted to an array whose axis 0 is iterated as channels
    and whose axis 1 is sliced into consecutive, non-overlapping windows of
    ``window_len`` samples. For every window, channel and band in
    ``FREQ_BANDS`` the natural log of the relative band power is computed.

    Parameters
    ----------
    idx2eeg : dict
        Maps a record id to a 2-D array-like (channels x samples).
    out_path : str
        Path of the JSON file to write.
    fs : int, optional
        Sampling frequency passed to ``bandpower`` (default 1000).
    window_len : int, optional
        Number of samples per window (default 1000).
    max_windows : int, optional
        Upper bound on the number of windows used per record (default 10).
        Trailing samples that do not fill a whole window are ignored.

    Returns
    -------
    None
        The result is written to ``out_path`` as
        ``{idx: [window][channel][band]}`` with bands in ``FREQ_BANDS`` order.

    Notes
    -----
    Errors are reported on stdout rather than raised, so a caller looping over
    many files keeps going. 'Done.' is printed only when the file was written
    successfully.
    """
    print('Extracting bp...', end=' ')
    idx2de = {}
    idx = None  # last record touched, reported if something goes wrong
    try:
        for idx, record in idx2eeg.items():
            eeg = np.array(record)
            # Whole windows only, capped at max_windows.
            n_windows = min(eeg.shape[1] // window_len, max_windows)
            de_list = []
            for i in range(n_windows):
                # Process window_len samples at a time.
                tmp_data = eeg[:, i * window_len:(i + 1) * window_len]
                # One feature vector per channel: log relative power per band.
                de_list.append([
                    [
                        math.log(bandpower(channel, fs, band_item, relative=True))
                        for band_item in FREQ_BANDS.values()
                    ]
                    for channel in tmp_data
                ])
            idx2de[idx] = de_list

        # The with-statement guarantees the file is closed.
        with open(out_path, 'w') as file:
            json.dump(idx2de, file)

    except Exception as e:
        # Deliberately best-effort: report and return so that a batch loop can
        # continue, but never claim success after a failure.
        print("An error occurred: ", e)
        print(f"Failed (last record: {idx!r}); "
              f"output at {out_path} may be missing or incomplete.")
        return

    print('Done.')


In [None]:
'''
'LAB1-huqifan','LAB1-cangyueyang','LAB1-hongyurui','LAB1-fanhao','LAB1-dongyimeng','LAB1-houlinzhi','LAB1-jiwenjun','LAB1-lujianing','LAB1-miaoshengze',
             'LAB1-wanfangwei','LAB1-wangxiaoting','LAB1-wangzhengni','LAB1-yangchen','LAB1-zhangxue','LAB1-liangqihang','LAB1-daisiwei',
             'LAB1-zhangyutong','LAB1-mengfanjie','LAB1-zhangchenxi','LAB1-liangyanshu','LAB1-zhaochensong','LAB1-chenxingyu','LAB1-chenrong'
'LAB2-huqifan','LAB2-cangyueyang','LAB2-hongyurui','LAB2-fanhao','LAB2-dongyimeng','LAB2-houlinzhi','LAB2-jiwenjun','LAB2-lujianing','LAB2-miaoshengze',
             'LAB2-wanfangwei','LAB2-wangxiaoting','LAB2-wangzhengni','LAB2-yangchen','LAB2-zhangxue','LAB2-liangqihang','LAB2-daisiwei',
             'LAB2-zhangyutong','LAB2-mengfanjie','LAB2-zhangchenxi','LAB2-liangyanshu','LAB2-zhaochensong','LAB2-chenxingyu','LAB2-chenrong'
'''
# 'guohongyang' is missing from the lists above


In [6]:
# Each entry names one input file './x2eeg/<entry>_idx2eeg.json'; the features
# extracted from it are written to './higher_features/<entry>_idx2de.json'.
for date in tqdm.tqdm(['LAB2-guohongyang','LAB2-zhaochensong','LAB2-chenxingyu','LAB2-chenrong','LAB1-zhaochensong','LAB1-chenxingyu','LAB1-chenrong'
]):
    # Use a context manager so the input file handle is closed on every
    # iteration instead of being leaked.
    with open('./x2eeg/'+date+'_idx2eeg.json') as eeg_file:
        idx2eeg = json.load(eeg_file)
    get_bp(idx2eeg, './higher_features/'+date+'_idx2de.json')

  0%|          | 0/7 [00:00<?, ?it/s]

Extracting bp... 

 14%|█▍        | 1/7 [01:22<08:14, 82.39s/it]

Done.
Extracting bp... 

 29%|██▊       | 2/7 [02:19<05:37, 67.45s/it]

Done.
Extracting bp... 

 43%|████▎     | 3/7 [03:47<05:06, 76.70s/it]

Done.
Extracting bp... 

 57%|█████▋    | 4/7 [04:56<03:40, 73.64s/it]

Done.
Extracting bp... 

 71%|███████▏  | 5/7 [06:02<02:22, 71.12s/it]

Done.
Extracting bp... 

 86%|████████▌ | 6/7 [07:26<01:15, 75.38s/it]

Done.
Extracting bp... 

100%|██████████| 7/7 [08:33<00:00, 73.38s/it]

Done.



