1.导包

In [8]:
import numpy as np
import time
from sklearn.model_selection import KFold
from src import data_label as dt, create_xlsx as cx, gene_idx as bf, ao_files as f
from metabci.brainda.algorithms.decomposition.tdca import TDCA

In [9]:
class RawDecode:
    """Offline TDCA decoding of the recordings belonging to one session.

    Attributes (all set in ``__init__``):
        model:  classifier instance, built as ``TDCA(0, 1)``.
        day:    index used to pick one session's entries out of the file
                info returned by ``f.load_files()`` (0: raw, 1: day1, 2: day2).
        t:      time window forwarded to the preprocessing and
                reference-signal helpers (presumably seconds -- confirm).
        n_fold: number of cross-validation folds used by ``raw_analyze``.
    """

    def __init__(self, day):
        super(RawDecode, self).__init__()
        self.model = TDCA(0, 1)
        self.day = day  # 0: raw, 1: day1, 2: day2
        self.t = 3  # time window
        self.n_fold = 4

    def raw_classify(self, yf, n_split, raw_data, raw_label, predict=0):
        """Cross-validate ``self.model`` (TDCA only) on one data set.

        Parameters:
            yf: reference signal, forwarded to ``self.model.fit``.
            n_split: number of KFold splits.
            raw_data: samples; indexed along the first axis with the fold
                indices.
            raw_label: labels aligned with ``raw_data``.
            predict: training-subset mode. ``0`` (default) fits on the whole
                training fold. ``1`` keeps the training items selected by
                ``bf.generate_indices(len(train), 20, seed)``; ``3`` does the
                same with 40. Any other value behaves like ``0``.

        Returns:
            A one-element list holding the mean of the per-fold ``dt.loss``
            values, multiplied by 100 and rounded to two decimals.
        """
        # mode -> (requested subset size, base seed handed to generate_indices)
        subset_modes = {1: (20, 1), 3: (40, 3)}

        kf = KFold(n_splits=n_split, random_state=None)
        raw_loss = []
        iter_seed = 0
        for train_index, test_index in kf.split(raw_data, raw_label):
            # The seed offset advances on every fold, whether or not it is used.
            iter_seed += 0.1
            data_train, data_test = raw_data[train_index], raw_data[test_index]
            label_train, label_test = raw_label[train_index], raw_label[test_index]

            if predict in subset_modes:
                subset_len, base_seed = subset_modes[predict]
                picked = bf.generate_indices(len(data_train), subset_len, base_seed + iter_seed)
                data_train, label_train = data_train[picked], label_train[picked]

            # Fit on the (possibly reduced) training fold, score the held-out fold.
            label_pred = self.model.fit(data_train, label_train, yf).predict(data_test)
            raw_loss.append(dt.loss(label_test, label_pred))

        return [np.round(100 * np.mean(raw_loss), 2)]

    def _save_scores(self, sub, para, score, over):
        """Write one subject's TDCA score for paradigm ``para`` via ``cx.save_data_to_excel``."""
        result = {
            "sub": [sub],
            "method": ["TDCA"],
            para: score
        }
        cx.save_data_to_excel(result, para, overwrite=over)

    def raw_analyze(self, sub, filename, cap='new', file_version=0, over=False):
        """Decode one subject's raw recordings, paradigm by paradigm.

        Training and test data come from the same subject and the same
        paradigm (cross-validation inside each paradigm).

        Parameters:
            sub: subject identifier stored in the result table.
            filename: argument handed to the file loader.
            cap: forwarded unchanged to ``dt.filter_data``.
            file_version: ``0`` selects ``dt.fileload``; any other value
                selects ``dt.fileload_ant``.
            over: forwarded as ``overwrite`` to ``cx.save_data_to_excel``.
        """
        fs = 500
        raw_fileload = dt.fileload if file_version == 0 else dt.fileload_ant
        (pth_event_para1, pth_data_para1,
         pth_event_para2, pth_data_para2,
         pth_event_para3, pth_data_para3) = raw_fileload(filename)
        print(f'len of event: {len(pth_event_para1)}, {len(pth_event_para2)}, {len(pth_event_para3)}')
        print(f'len of data: {len(pth_data_para1)}, {len(pth_data_para2)}, {len(pth_data_para3)}')

        # Preprocessing
        data_para1, label_para1 = dt.filter_data(pth_event_para1, pth_data_para1, fs, self.t, cap)
        data_para2, label_para2 = dt.filter_data(pth_event_para2, pth_data_para2, fs, self.t, cap)
        # Reference signal: frequencies are read from the first para1 event
        # entry and reused for every paradigm; built with the same fs as above.
        freq = dt.get_freq(event=pth_event_para1[0])[0]
        yf = dt.reference_s(freq, fs, self.t)
        # Train the model and decode
        score1 = self.raw_classify(yf, self.n_fold, data_para1, label_para1)
        score2 = self.raw_classify(yf, self.n_fold, data_para2, label_para2)
        self._save_scores(sub, 'para1', score1, over)
        self._save_scores(sub, 'para2', score2, over)

        # The third paradigm is processed only when the loader returned entries for it.
        if pth_event_para3 and pth_data_para3:
            data_para3, label_para3 = dt.filter_data(pth_event_para3, pth_data_para3, fs, self.t, cap)
            mean_score3 = self.raw_classify(yf, self.n_fold, data_para3, label_para3)
            self._save_scores(sub, 'para3', mean_score3, over)

        print('_.'*30)

    def main(self):
        """Run ``raw_analyze`` for every file of session ``self.day`` and print the elapsed time."""
        # Load the file list and the per-file metadata of the selected session
        infos = f.load_files()
        files, caps, subs, versions = (infos[key][self.day] for key in ('files', 'caps', 'subs', 'versions'))
        no_need = []  # positions of files to leave out of the run
        idx = [i for i in range(len(files)) if i not in no_need]
        files, caps, subs, versions = [files[i] for i in idx], [caps[i] for i in idx], [subs[i] for i in idx], [versions[i] for i in idx]

        start = time.time()

        # Offline decoding analysis
        # The result files are overwritten for one designated subject and
        # appended to for all others (presumably the first subject of the
        # session -- confirm against the order returned by f.load_files()).
        overwrite_sub = 'S1' if self.day == 0 else 'S10'
        for file, cap, sub, v in zip(files, caps, subs, versions):
            print(sub)
            over = sub == overwrite_sub
            self.raw_analyze(sub, file, cap, v, over)

        end = time.time()
        print(f'Time used: {end - start:.2f}s')


In [10]:
if __name__ == "__main__":
    # Build a decoder for one session and run the offline analysis over
    # every file of that session; `day` selects the session.
    decoder = RawDecode(day=0)  # 0: raw, 1: day1, 2: day2
    decoder.main()

S1
len of event: 4, 4, 4
len of data: 4, 4, 4
数据已覆盖写入到 para1_results.xlsx
数据已覆盖写入到 para2_results.xlsx
数据已覆盖写入到 para3_results.xlsx
_._._._._._._._._._._._._._._._._._._._._._._._._._._._._._.
S2
len of event: 4, 4, 4
len of data: 4, 4, 4
数据已追加写入到 para1_results.xlsx
数据已追加写入到 para2_results.xlsx
数据已追加写入到 para3_results.xlsx
_._._._._._._._._._._._._._._._._._._._._._._._._._._._._._.
S3
len of event: 4, 4, 4
len of data: 4, 4, 4
数据已追加写入到 para1_results.xlsx
数据已追加写入到 para2_results.xlsx
数据已追加写入到 para3_results.xlsx
_._._._._._._._._._._._._._._._._._._._._._._._._._._._._._.
S4
len of event: 4, 4, 4
len of data: 4, 4, 4
数据已追加写入到 para1_results.xlsx
数据已追加写入到 para2_results.xlsx
数据已追加写入到 para3_results.xlsx
_._._._._._._._._._._._._._._._._._._._._._._._._._._._._._.
S5
len of event: 4, 4, 4
len of data: 4, 4, 4
数据已追加写入到 para1_results.xlsx
数据已追加写入到 para2_results.xlsx
数据已追加写入到 para3_results.xlsx
_._._._._._._._._._._._._._._._._._._._._._._._._._._._._._.
S6
len of event: 4, 4, 4
len of data: 4, 4, 4