In [2]:
from scipy.signal import hilbert
import numpy as np
import scipy.io as scio
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as func
from torch.autograd import Function
from torch.autograd import Variable
import math
import random
import sys
import os
from parameters import *


def add_delay(data):
    """Shift the two audio columns against the EEG columns by ``delay`` samples.

    Column 0 of ``data`` is the target audio, column 1 the non-target audio and
    columns 2+ the EEG channels. With a non-negative global ``delay`` audio
    row ``t`` is paired with EEG row ``t + delay``; with a negative one audio
    row ``t - delay`` is paired with EEG row ``t``. Either way ``abs(delay)``
    rows are dropped so both sides keep the same length.

    Returns:
        A new DataFrame with a fresh 0-based index and integer column labels.
    """
    audio = data.iloc[:, 0]
    audio_other = data.iloc[:, 1]
    eeg = data.iloc[:, 2:]

    if delay < 0:
        shift = -delay
        audio = audio.iloc[shift:]
        audio_other = audio_other.iloc[shift:]
        eeg = eeg.iloc[:eeg.shape[0] - shift, :]
    else:
        audio = audio.iloc[:audio.shape[0] - delay]
        audio_other = audio_other.iloc[:audio_other.shape[0] - delay]
        eeg = eeg.iloc[delay:, :]

    # Re-index every piece from 0 so the column-wise concat aligns row by row.
    pieces = [piece.reset_index(drop=True) for piece in (audio, audio_other, eeg)]
    return pd.concat(pieces, axis=1, ignore_index=True)


def read(name, fb_index = None):
    """Load and stack every trial of subject ``name`` for all ``ConType`` conditions.

    For each condition in the global ``ConType`` and each of the
    ``trail_number`` trials, the file
    ``<data_document_path>/<condition>/<name>Tra<k>[_<fb_index>].csv`` is read
    (the ``_<fb_index>`` suffix is only used when ``bands_number != 1``).
    Its columns 0 and 1 are the two audio streams and columns 2+ the EEG.

    When ``isDS`` is set and the trial's label in ``./csv/<name><condition>.csv``
    (column ``isFM``) equals 2, the two audio columns are swapped. Each trial
    is then shifted with ``add_delay`` and all trials are stacked row-wise in
    (condition, trial) order.

    Args:
        name: subject identifier, e.g. ``"S2"``.
        fb_index: frequency-band index used in the file name when
            ``bands_number != 1``.

    Returns:
        A DataFrame with a 0-based index and columns
        ``[audio_0, audio_1, eeg...]``; an empty DataFrame if no trial was read.
    """
    trials = []
    for condition in ConType:
        labels = pd.read_csv("./csv/" + name + condition + ".csv")
        for k in range(trail_number):
            # Read the trial file.
            filename = data_document_path + "/" + condition + "/" + name + "Tra" + str(k + 1)
            if bands_number != 1:
                filename = filename + "_" + str(fb_index)
            filename = filename + ".csv"
            raw = pd.read_csv(filename, header=None)

            eeg = raw.iloc[:, 2:]
            sound = raw.iloc[:, 0]
            sound_other = raw.iloc[:, 1]

            # Swap the left/right audio order for trials labelled 2 (adds the
            # auxiliary direction information used by the D+S / FM+S models).
            if isDS and labels.iloc[k, isFM] == 2:
                sound, sound_other = sound_other, sound

            trial = pd.concat(
                [pd.DataFrame(sound), pd.DataFrame(sound_other), pd.DataFrame(eeg)],
                axis=1, ignore_index=True)
            # Apply the audio/EEG time delay.
            trials.append(add_delay(trial))

    if not trials:
        return pd.DataFrame()
    # Concatenate once at the end instead of growing a frame inside the loop,
    # which re-copied all previous trials on every iteration.
    return pd.concat(trials, axis=0, ignore_index=True)


def timeSplit(data, name):
    """Cut the stacked trials in ``data`` into (possibly overlapping) windows.

    ``data`` is expected to be the output of ``read`` + ``change``: for each
    condition in ``ConType`` and each of the ``trail_number`` trials,
    ``cell_number - abs(delay)`` consecutive rows (the trial length left
    after ``add_delay``), stacked in (condition, trial) order.

    For every trial a contiguous run of windows, starting at a random
    position, becomes test data. The run length is derived from
    ``test_percent``; a trial picked by ``isBeyoudTrail`` is entirely test data;
    nothing is test data with ``isALLTrain``. Windows within
    ``overlap_distance`` positions of the test run are dropped because they
    share samples with test windows. All remaining windows are training data.

    With ``isDS`` each window keeps the trial label from
    ``./csv/<name><ConType>.csv`` (column ``isFM``). Otherwise every window is
    emitted twice: unchanged with direction 1, and with columns 0 and
    ``channel_number - 1`` swapped with direction 2. After ``change`` these are
    the two audio columns, assuming ``channel_number`` equals the column
    count.

    Returns:
        ``(train, test)``: two lists of dicts with keys ``"data"`` (a
        ``(window_length, channel_number)`` array), ``"direction"`` and
        ``"index"`` (position within its list).
    """
    # Trial length after add_delay trimmed abs(delay) rows from each trial.
    # Kept local: decrementing the global cell_number made every later call
    # (main calls this once per frequency band) use a shrinking length.
    trial_len = cell_number - abs(delay)
    window_lap = window_length * (1 - overlap)
    # How many neighbouring windows share samples with a given window.
    overlap_distance = math.floor(1 / (1 - overlap)) - 1
    # Number of windows per trial.
    window_number = math.floor((trial_len - window_length) / window_lap) + 1

    selection_trails = []
    if isBeyoudTrail:
        selection_trails = random.sample(
            range(trail_number), math.ceil(trail_number * test_percent))

    train_windows, train_direction = [], []
    test_windows, test_direction = [], []

    def add_window(windows, directions, window, direction):
        # CNN: D+S or CNN: FM+S model -> keep the trial's label.
        if isDS:
            windows.append(window)
            directions.append(direction)
        # CNN: S model -> both audio orders, labelled 1 and 2.
        else:
            swapped = window.copy()
            swapped[:, [0, channel_number - 1]] = swapped[:, [channel_number - 1, 0]]
            windows.append(window)
            windows.append(swapped)
            directions.extend([1, 2])

    for cond_idx, condition in enumerate(ConType):
        labels = pd.read_csv("./csv/" + name + condition + ".csv")

        for k in range(trail_number):
            # Test fraction for this trial. This is a local copy; overwriting the
            # global test_percent corrupted later calls.
            trial_test_percent = test_percent
            if isBeyoudTrail:
                trial_test_percent = 1 if k in selection_trails else 0
            if isALLTrain:
                trial_test_percent = 0
            test_window_length = math.floor(
                (trial_len * trial_test_percent - window_length) / window_lap)
            if trial_test_percent != 0:
                test_window_length = max(0, test_window_length)
            test_window_length += 1
            # Random left/right bounds (inclusive) of the test run; left > right
            # means this trial has no test windows.
            test_left = random.randint(0, window_number - test_window_length)
            test_right = test_left + test_window_length - 1

            direction = labels.iloc[k, isFM]
            # read() stacks trials condition by condition, so the offset must
            # include the condition index (previously only k was used).
            trial_start = (cond_idx * trail_number + k) * trial_len
            for i in range(window_number):
                left = math.floor(trial_start + i * window_lap)
                right = math.floor(left + window_length)
                if (test_left > test_right
                        or test_left - i > overlap_distance
                        or i - test_right > overlap_distance):
                    add_window(train_windows, train_direction,
                               data.iloc[left:right, :].to_numpy(), direction)
                elif test_left <= i <= test_right:
                    add_window(test_windows, test_direction,
                               data.iloc[left:right, :].to_numpy(), direction)
                # Any other window overlaps the test run and is discarded.

    def to_samples(windows, directions):
        # Stack all windows once and restore the (window, time, channel) layout.
        if windows:
            stacked = np.concatenate(windows, axis=0)
        else:
            stacked = np.empty((0, channel_number))
        stacked = stacked.reshape(-1, window_length, channel_number)
        return [{"data": stacked[i], "direction": directions[i], "index": i}
                for i in range(stacked.shape[0])]

    return (to_samples(train_windows, train_direction),
            to_samples(test_windows, test_direction))


def change(train_data):
    """Move column 1 (the non-target audio) to the last position.

    Input columns:  ``[target_audio, other_audio, eeg...]``.
    Output columns: ``[target_audio, eeg..., other_audio]``, relabelled
    ``0..n-1``; the row index is kept.
    """
    n_cols = train_data.shape[1]
    order = [0] + list(range(2, n_cols)) + [1]
    return train_data.iloc[:, order].set_axis(range(n_cols), axis=1)


def main(name="S2"):
    """Build and save the windowed train/test sets of subject ``name``.

    For each of the ``bands_number`` frequency bands the subject's trials are
    read, reordered with ``change`` and windowed with ``timeSplit``. The
    windows of all bands are pooled and saved to
    ``./data_new/CNN1_<name>.npy`` as a 1-D object array ``[train, test]``.
    """
    print("start!")

    # Extract each frequency band and pool the resulting train/test windows.
    train = []
    test = []
    for band in range(1, bands_number + 1):
        band_data = change(read(name, band))
        band_train, band_test = timeSplit(band_data, name)
        train += band_train
        test += band_test

    # Build the [train, test] container explicitly as a 2-element object
    # array. np.array([train, test]) raises on NumPy >= 1.24 when the two sets
    # differ in length, and silently becomes a 2-D (2, n) array when they are
    # equal.
    result = np.empty(2, dtype=object)
    result[0] = np.array(train)
    result[1] = np.array(test)
    np.save("./data_new/CNN1_" + name, result)

    print("finish!")


if __name__ == "__main__":
    # An optional single command-line argument selects the subject.
    cli_args = sys.argv[1:]
    if len(cli_args) == 1:
        main(cli_args[0])
    else:
        main()


start!
finish!


In [3]:
from scipy.signal import hilbert
import numpy as np
import scipy.io as scio
import pandas as pd
import torch
import torch.nn as nn
import matplotlib

matplotlib.use('agg')
import matplotlib.pyplot as plt
import torch.nn.functional as func
from torch.autograd import Function
from torch.autograd import Variable
import math
import random
import sys
from parameters import *
import seaborn as sns


def vali_split(train, percent=None, overlap_ratio=None):
    """Split a sequence of windows into training and validation windows.

    A contiguous run of ``floor(len(train) * percent)`` windows, starting at a
    random position, becomes the validation set. Windows within
    ``floor(1 / (1 - overlap_ratio)) - 1`` positions of that run are dropped
    because they may share samples with validation windows. All other
    windows are returned as training data.

    Args:
        train: indexable sequence of windows exposing ``.shape[0]``.
        percent: validation fraction; defaults to the global ``vali_percent``.
        overlap_ratio: window overlap fraction; defaults to the global
            ``overlap``.

    Returns:
        ``(train_windows, vali_windows)`` as NumPy arrays.
    """
    percent = vali_percent if percent is None else percent
    overlap_ratio = overlap if overlap_ratio is None else overlap_ratio

    window_number = train.shape[0]
    # Length and random (inclusive) bounds of the validation run.
    vali_window_length = math.floor(window_number * percent)
    vali_window_left = random.randint(0, window_number - vali_window_length)
    vali_window_right = vali_window_left + vali_window_length - 1
    # Guard distance: neighbouring windows closer than this share samples.
    overlap_distance = math.floor(1 / (1 - overlap_ratio)) - 1

    train_window = []
    vali_window = []
    for i in range(window_number):
        # An empty validation run (left > right) excludes nothing; without
        # this check the guard band around the random position still dropped
        # training windows.
        if (vali_window_left > vali_window_right
                or vali_window_left - i > overlap_distance
                or i - vali_window_right > overlap_distance):
            train_window.append(train[i])
        elif vali_window_left <= i <= vali_window_right:
            vali_window.append(train[i])

    return np.array(train_window), np.array(vali_window)


def heatmap(_, epoch, number, matrix_num, title='crossmodal attention matrix', x_label='No.', y_label='No.'):
    """Render element 0 of ``_`` as an annotated heatmap and save it as a PNG.

    Args:
        _: indexable whose element 0 is a tensor; it is detached, moved to the
            CPU, rounded to 2 decimals and drawn through ``pd.DataFrame``
            (so it must be at most 2-D).
        epoch, number, matrix_num: only used to build the file name
            ``./picture/<names[0]>/<epoch>_<number>_<matrix_num>.png``.
        title, x_label, y_label: plot title and axis labels.
    """
    matrix = _[0].cpu().detach().numpy()
    fig, ax = plt.subplots(figsize=(20, 20))
    try:
        # Draw explicitly on the axes created above instead of pyplot's
        # implicit "current axes".
        sns.heatmap(pd.DataFrame(np.round(matrix, 2)), ax=ax, xticklabels=True,
                    yticklabels=True, square=True, center=0)
        ax.set_title(title, fontsize=18)
        ax.set_ylabel(y_label, fontsize=18)
        # Rows of the matrix run along the y axis, as in the matrix's own layout.
        ax.set_xlabel(x_label, fontsize=18)

        fig.savefig("./picture/" + names[0] + "/" + str(epoch) + "_" + str(number) + "_"
                    + str(matrix_num) + ".png")
    finally:
        # pyplot keeps every figure alive until it is closed; this function is
        # called many times per epoch, so unclosed figures leak memory.
        plt.close(fig)


def isLeftPredicted(result):
    """Decide whether one model output row favours class 0 (stream "A").

    The row is detached from the autograd graph and given a leading batch
    axis. It is then scored with the global ``loss_func`` against target 0
    and target 1; class 0 wins when its loss is smaller.

    Returns:
        The comparison of the two losses (a NumPy boolean for scalar losses).
    """
    batched = result.detach().cpu().unsqueeze(0).to(device)

    def loss_for(target_class):
        target = torch.tensor([target_class]).to(device)
        return loss_func(batched, target).cpu().detach().numpy()

    return loss_for(0) < loss_for(1)


def train(train, epoch, targetIndex):
    """Run one optimisation epoch over ``train`` and return the mean batch loss.

    Args:
        train: array of window dicts with keys ``"data"`` (a
            ``(window_length, channel_number)`` array), ``"direction"``
            (1-based class label) and ``"index"``. Windows are consumed in
            order, in batches of the global ``batch_size``; a trailing partial
            batch is skipped.
        epoch: epoch number, only used in dump file names.
        targetIndex: sample ``"index"`` values whose attention matrices are
            dumped as CSV/heatmaps under ``./picture/<names[0]>/``.

    Returns:
        The mean of the per-batch losses, or ``float("nan")`` when ``train``
        does not hold a single full batch.
    """
    # Ensure training-mode behaviour (e.g. dropout) even if an earlier
    # evaluation pass switched the model to eval mode.
    myNet.train()

    n_batches = math.floor(train.shape[0] / batch_size)
    losses = 0
    for turn in range(n_batches):
        optimzer.zero_grad()
        batch = train[turn * batch_size:(turn + 1) * batch_size]
        # Model input: (batch, 1, channel_number, window_length).
        batchData = np.stack([sample["data"].T[np.newaxis] for sample in batch])
        # Directions are 1-based; the loss expects 0-based class indices.
        allTarget = np.array([sample["direction"] - 1 for sample in batch])
        x = torch.tensor(batchData, dtype=torch.float32).to(device)
        out, aeA, audio, ae_fcA, eaA, aeB, eaB = myNet(x)

        for i, sample in enumerate(batch):
            itemIndex = sample["index"]
            if itemIndex not in targetIndex:
                continue

            isLeft = isLeftPredicted(out[i])
            isAAttend = allTarget[i] == 0
            prefix = "./picture/" + names[0] + "/" + str(epoch) + "_" + str(itemIndex)
            pd.DataFrame(ae_fcA[-1].cpu().detach().numpy()).to_csv(prefix + "_qkv.csv")

            # The else branch previously also said "attend=A".
            attend_str = " attend=A" if isAAttend else " attend=B"
            predict_str = " predict=A" if isLeft else " predict=B"
            title_str = attend_str + predict_str
            # (file tag, matrices, cross-modal block, query modality, key/value modality)
            for tag, maps, cm_name, query, kv in (
                    ("aeA", aeA, "cmA", "audio", "eeg"),
                    ("eaA", eaA, "cmA", "eeg", "audio"),
                    ("aeB", aeB, "cmB", "audio", "eeg"),
                    ("eaB", eaB, "cmB", "eeg", "audio")):
                title = cm_name + " q=" + query + " kv=" + kv + title_str
                heatmap(maps[0], epoch, itemIndex, tag + "_first", title=title,
                        x_label=kv, y_label=query)
                heatmap(maps[-1], epoch, itemIndex, tag + "_last", title=title,
                        x_label=kv, y_label=query)

            heatmap([ae_fcA[-1]], epoch, itemIndex, "qkv", title="qkv weight")
            heatmap(audio[0], epoch, itemIndex, "wavA_before", title="wavA_before")
            heatmap(audio[1], epoch, itemIndex, "wavA_after", title="wavA_after")

        loss = loss_func(out, torch.tensor(allTarget, dtype=torch.long).to(device))
        losses = losses + loss.cpu().detach().numpy()
        loss.backward()
        optimzer.step()

    # NOTE(review): `scheduler` presumably is ReduceLROnPlateau (it takes
    # `metrics`); a constant metric never registers improvement, so the LR is
    # reduced every `patience` epochs -- confirm this is intended.
    scheduler.step(metrics=0.1)

    return losses / n_batches if n_batches else float("nan")


def test(cv):
    """Return the mean batch loss of the current model on ``cv`` (no training).

    The model is put in eval mode (this affects layers such as dropout and
    batch-norm, if ``myNet`` has any) and gradients are disabled. The previous
    train/eval mode is restored afterwards. A trailing partial batch is skipped.

    Returns:
        The mean per-batch loss, or ``float("nan")`` when ``cv`` holds fewer
        than ``batch_size`` windows (e.g. an empty test set); this used to
        raise ZeroDivisionError.
    """
    n_batches = math.floor(cv.shape[0] / batch_size)
    if n_batches == 0:
        return float("nan")

    was_training = myNet.training
    myNet.eval()
    losses = 0
    try:
        with torch.no_grad():
            for turn in range(n_batches):
                batch = cv[turn * batch_size:(turn + 1) * batch_size]
                # Model input: (batch, 1, channel_number, window_length).
                batchData = np.stack([sample["data"].T[np.newaxis] for sample in batch])
                allTarget = np.array([sample["direction"] - 1 for sample in batch])
                x = torch.tensor(batchData, dtype=torch.float32).to(device)
                out = myNet(x)[0]
                loss = loss_func(out, torch.tensor(allTarget, dtype=torch.long).to(device))
                losses = losses + loss.cpu().numpy()
    finally:
        myNet.train(was_training)

    return losses / n_batches


def trainEpoch(data, test_data):
    """Train for up to ``max_epoch`` epochs with optional early stopping.

    Each epoch re-splits ``data[0]`` into training and validation windows
    (``vali_split``), shuffles the training part and trains one epoch. It
    then prints ``"<epoch> <train loss> <validation loss> <test loss>
    early_stop_number: <n>"``. ``n`` counts consecutive epochs whose validation
    loss exceeded the best one seen so far (which starts at 100). With
    ``isEarlyStop`` training stops once ``n >= 10`` after ``min_epoch``.

    Five distinct indices below ``floor(0.9 * len(data[0]))`` are drawn once
    and printed; ``train`` dumps attention maps for samples with those indices.
    """
    best_vali_loss = 100
    patience_count = 0

    targetIndex = np.random.choice(a=math.floor(data[0].shape[0] * 0.9), size=5, replace=False)
    print(targetIndex)

    for epoch in range(max_epoch):
        # Fresh train/validation split of the non-test windows every epoch.
        train_data, cv_data = vali_split(data[0].copy())
        np.random.shuffle(train_data)

        loss_train = train(train_data, epoch, targetIndex)
        loss = test(cv_data)
        loss2 = test(test_data)

        if loss > best_vali_loss:
            patience_count = patience_count + 1
        else:
            patience_count = 0
            best_vali_loss = loss

        print(" ".join(str(value) for value in (epoch, loss_train, loss, loss2))
              + " early_stop_number: " + str(patience_count))

        if isEarlyStop and epoch > min_epoch and patience_count >= 10:
            break


def testEpoch(test_data):
    """Measure direction-classification accuracy on ``test_data``.

    The evaluation is repeated 10 times. Each pass prints
    ``"<correct> <wrong>"``, and the accuracy of the last pass is printed at
    the end. A trailing partial batch is skipped. The model is used in
    whatever train/eval mode it is currently in. NOTE(review): in train mode,
    stochastic layers such as dropout make the 10 passes differ -- confirm
    this is intended.
    """
    t_num = 0
    f_num = 0
    n_batches = math.floor(test_data.shape[0] / batch_size)
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for _ in range(10):
            t_num = 0
            f_num = 0
            for turn in range(n_batches):
                batch = test_data[turn * batch_size:(turn + 1) * batch_size]
                # Model input: (batch, 1, channel_number, window_length).
                batchData = np.stack([sample["data"].T[np.newaxis] for sample in batch])
                allTarget = np.array([sample["direction"] - 1 for sample in batch])
                x = torch.tensor(batchData, dtype=torch.float32).to(device)
                out = myNet(x)[0]

                for i in range(len(batch)):
                    if isLeftPredicted(out[i]) == (allTarget[i] == 0):
                        t_num = t_num + 1
                    else:
                        f_num = f_num + 1

            print(str(t_num) + " " + str(f_num))

    total = t_num + f_num
    if total == 0:
        # Previously a ZeroDivisionError when test_data held no full batch.
        print("no complete test batch; accuracy undefined")
    else:
        print(str(t_num / total))


def main(name="S2"):
    """Pretrain and/or train the model for subject ``name``, then test it.

    ``./data_new/CNN1_<name>.npy`` holds a ``[train, test]`` pair. The windows
    used for testing are its test windows, or its training windows when only
    pretraining (``isALLTrain and need_pretrain and not need_train``).

    With ``need_pretrain`` the model is first trained on the pooled windows of
    subjects ``S1..S<people_number>``, seeded with a reference subject derived
    from ``name``'s number. ``name`` itself is excluded only when
    ``isALLTrain``. With ``need_train`` it is then trained on ``name``'s own
    windows, with the learning rate scaled by 0.1 if pretraining ran. Finally
    ``testEpoch`` reports the accuracy.
    """
    subject_number = int(name[1:])
    data_file = "./data_new/CNN1_" + name + ".npy"

    # Load the evaluation data first.
    data = np.load(data_file, allow_pickle=True)
    if isALLTrain and need_pretrain and not need_train:
        test_data = data[0]
    else:
        test_data = data[1]

    # Pool the other subjects' data and pretrain on it.
    if need_pretrain:
        print("pretrain start!")
        basic_name = "S" + str(subject_number % (people_number - 1) + 1)
        pooled = np.load("./data_new/CNN1_" + basic_name + ".npy", allow_pickle=True)
        for k in range(1, people_number + 1):
            other = "S" + str(k)
            if other == basic_name or (isALLTrain and other == name):
                continue
            loaded = np.load("./data_new/CNN1_" + other + ".npy", allow_pickle=True)
            pooled[0] = np.hstack((loaded[0], pooled[0]))
            pooled[1] = np.hstack((loaded[1], pooled[1]))
        trainEpoch(pooled, test_data)
        print()

    # Train on the subject's own data.
    if need_train:
        if need_pretrain:
            # Fine-tune with a 10x smaller learning rate after pretraining.
            for group in optimzer.param_groups:
                group['lr'] *= 0.1

        print("train start!")
        trainEpoch(np.load(data_file, allow_pickle=True), test_data)
        print()

    print("test start!")
    testEpoch(test_data)


if __name__ == "__main__":
    # Optional first CLI argument: a subject name such as "S3"; anything not
    # starting with "S" is ignored and the default subject is used.
    subject = sys.argv[1] if len(sys.argv) > 1 else ""
    if subject.startswith("S"):
        main(subject)
    else:
        main()

train start!
[ 85 231 101 535 214]




KeyboardInterrupt: 

In [17]:
import os
import glob
import time
import shutil
import logging
from importlib import reload
import re

# Experiment label and subject list; both are filled in from parameters.py by init().
label = ""
names = []
# Root logger; a placeholder until init() configures it.
logger = 0
# Directory that receives one sub-directory of logs per experiment label.
result_document = "./result/"
# Every available subject: "S1" ... "S18".
all_names = ["S" + str(i) for i in range(1, 19)]


def run_split():
    """Start ``CNN_split.py`` in the background (nohup) for every subject in ``names``.

    Each run's stdout/stderr goes to
    ``<result_document><label>/CNN_split_<subject>_<label>.log``.
    """
    log_dir = result_document + label
    for subject in names:
        command = ("nohup ~/anaconda3/bin/python -u CNN_split.py " + subject + " > "
                   + log_dir + "/CNN_split_" + subject + "_" + label + ".log 2>&1 &")
        os.system(command)


def run_train():
    """Start ``CNN.py`` in the background (nohup) for every subject in ``names``.

    Each run's stdout/stderr goes to
    ``<result_document><label>/CNN_<subject>_<label>.log``.
    """
    log_dir = result_document + label
    for subject in names:
        command = ("nohup ~/anaconda3/bin/python -u CNN.py " + subject + " > "
                   + log_dir + "/CNN_" + subject + "_" + label + ".log 2>&1 &")
        os.system(command)


def __get_last_line(filename):
    """Return the last line of ``filename`` as bytes, including its newline if any.

    The file is read backwards from the end in doubling chunks until a chunk
    spans at least two lines, so large logs are not read in full.

    Returns:
        The last line, or ``None`` if the file is empty or does not exist
        (the latter is also reported through ``logger.error``).
    """
    try:
        size = os.path.getsize(filename)
        if size == 0:
            return None
        # Seeking relative to the end of the file requires binary mode.
        with open(filename, 'rb') as handle:
            step = 2
            while step < size:
                handle.seek(-step, os.SEEK_END)
                tail = handle.readlines()
                if len(tail) >= 2:
                    # The chunk may start mid-line; only the final line is complete.
                    return tail[-1]
                step *= 2
            # The next chunk would be larger than the file: read it whole.
            handle.seek(0)
            return handle.readlines()[-1]
    except FileNotFoundError:
        logger.error(filename + ' not found!')
        return None


def search_all_files_return_by_time_reversed(path, reverse=True):
    """Return every entry matching ``<path>/*``, ordered by ``os.path.getctime``.

    Newest first by default (``reverse=True``). The raw timestamp is used as
    the sort key rather than a local-time string formatted to whole seconds.
    That string discarded sub-second differences and could misorder entries
    across daylight-saving changes.
    """
    return sorted(glob.glob(os.path.join(path, '*')), key=os.path.getctime, reverse=reverse)


def MonitorSplit():
    """Block until every subject's split log ends with a line containing "finish".

    Polls ``<result_document><label>/CNN_split_<subject>_<label>.log`` for all
    subjects in ``names`` every 60 seconds. A missing or still-empty log
    counts as "not finished"; previously it crashed on ``None.decode()``.
    """
    while True:
        all_done = True
        for subject in names:
            filename = result_document + label + "/CNN_split_" + subject + "_" + label + ".log"
            last_line = __get_last_line(filename)
            if last_line is None or "finish" not in last_line.decode():
                all_done = False
                break
        if all_done:
            break

        time.sleep(60)


def MonitorTrain():
    """Block until every subject's training log ends with an accuracy line.

    A subject counts as finished when the last line of
    ``<result_document><label>/CNN_<subject>_<label>.log`` starts with "0." or
    "1." (presumably the accuracy printed at the end of training). Logs are
    polled every 60 seconds. A missing or empty log counts as "not finished";
    previously it crashed on ``None.decode()``.
    """
    while True:
        all_done = True
        for subject in names:
            filename = result_document + label + "/CNN_" + subject + "_" + label + ".log"
            last_line = __get_last_line(filename)
            if last_line is None or not last_line.decode().startswith(("0.", "1.")):
                all_done = False
                break

        if all_done:
            break

        time.sleep(60)


def MonitorParameters():
    """Wait for a parameter file in ``./parameters`` and install it as ``./parameters.py``.

    Polls every 60 seconds. Once the directory has entries, the first one
    returned by ``search_all_files_return_by_time_reversed`` (newest by ctime)
    is moved over ``./parameters.py`` and the function returns.
    """
    while True:
        file_list = search_all_files_return_by_time_reversed("./parameters")
        if file_list:
            # os.replace overwrites the destination in one atomic step on all
            # platforms, so there is no moment without a parameters.py
            # (unlike remove + rename).
            os.replace(file_list[0], "./parameters.py")
            return
        time.sleep(60)


def init():
    """(Re)load ``parameters.py`` and configure logging for the new experiment.

    Re-imports ``parameters`` to pick up ``label`` and ``names``. Creates
    ``<result_document><label>/`` and copies ``parameters.py`` into it. The
    root logger is then set up to write INFO messages to
    ``<result_document><label>/logger.txt`` and to the console.
    """
    import parameters
    reload(parameters)
    global label
    global names
    global logger
    label = parameters.label
    names = parameters.names

    os.makedirs(result_document + label, exist_ok=True)
    shutil.copy("./parameters.py", result_document + label)

    # Step 1: get the (root) logger.
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Close and detach handlers left by a previous call (loopMonitor calls
    # init() once per experiment). The old approach, reload(logging),
    # orphaned the previous root logger and left its log files open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    # Step 2: a handler that writes to the log file.
    logfile = result_document + label + "/logger.txt"
    file_handler = logging.FileHandler(logfile, mode='w')
    file_handler.setLevel(logging.DEBUG)

    # Step 3: a handler that writes to the console.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    # Step 4: the output format shared by both handlers.
    formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s")
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

    # Step 5: attach the handlers to the logger.
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    logger.info("model label: " + label)


def output_result():
    """Log the last line of every subject's training log in one INFO message.

    The last line is presumably the accuracy printed at the end of training.
    A subject whose log is missing or empty is reported as
    ``"<subject>: no result"`` instead of crashing on ``None.decode()``.
    """
    lines = []
    for subject in names:
        filename = result_document + label + "/CNN_" + subject + "_" + label + ".log"
        last_line = __get_last_line(filename)
        lines.append(subject + ": no result\n" if last_line is None else last_line.decode())
    logger.info("result: \n" + "".join(lines))

def grid_search(filename, pattern, parameter_range):
    """Substitute each value of ``parameter_range`` for ``pattern`` in ``parameters.py``.

    For each replacement string in turn, ``parameters.py`` is read, every
    match of the regex ``pattern`` is replaced, and the result is printed and
    written back. Nothing runs between substitutions, so only the last value
    remains in the file.

    NOTE(review): ``filename`` is unused -- the file is always
    ``parameters.py`` (the visible caller passes the literal ``"filename"``).

    Returns:
        None.
    """
    regex = re.compile(pattern)
    for replacement in parameter_range:
        with open("parameters.py", "r", encoding="utf-8") as file:
            string = file.read()
        string = regex.sub(replacement, string)
        print(string)

        # Close (and flush) before the next iteration re-reads the file; the
        # previous version never closed its write handles.
        with open("parameters.py", "w", encoding="utf-8") as file:
            file.write(string)
    return None

def loopMonitor():
    """Run experiments forever, one per parameter file dropped into ``./parameters``.

    Each round does the following in order:
    1. wait for and install the next parameter file, then call ``init``;
    2. launch the split jobs and wait for them;
    3. launch the training jobs and wait for them;
    4. log all subjects' results.
    """
    while True:
        MonitorParameters()
        init()

        # (stage name, launcher, seconds to wait before polling, blocking monitor)
        stages = (
            ("split", run_split, 10, MonitorSplit),
            ("train", run_train, 60, MonitorTrain),
        )
        for stage, launch, grace_seconds, wait_until_done in stages:
            logger.info(stage + " start!")
            launch()
            time.sleep(grace_seconds)
            wait_until_done()
            logger.info(stage + " finish!")

        output_result()


if __name__ == "__main__":
    # Raw string: "\d" and "\." in a plain literal are invalid escape
    # sequences (DeprecationWarning, SyntaxWarning since Python 3.12).
    pattern = r"lr=(-?\d+)(\.\d+)?"
    parameter_range = ["lr=0.1", "lr=0.5"]
    grid_search("filename", pattern, parameter_range)


from scipy.signal import hilbert
import numpy as np
import scipy.io as scio
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.nn.functional as func
from torch.autograd import Function
from torch.autograd import Variable
import math
import random
import sys
import os
from modules.transformer import TransformerEncoder

# Parameters in use.
# Input data selection:
# label: identifier of this training run.
# ConType: acoustic conditions whose data is used; with ConType = ["No", "Low", "High"] the data of all three conditions is mixed together for training.
# names: subjects whose data is used in this run.
label = "cm_textCNN_4cm_heatmap"
ConType = ["No"]
# names = ["S1", "S2", "S3", "S4", "S5", "S6", "S7", "S8", "S9", "S10", "S11", "S12", "S13", "S14", "S15", "S16", "S17",
#          "S18"]
# names = ["S1", "S2", "S3", "S4", "S5", "S6", "S7", "S8", "S9", "S10", "S11", "S12", "S13", "S14", "S15", "S16", "S17", "S18"]
names = ["S2"]

# Path of the data directory.
# data_document_path = "../dataset"
data_document_path = "../dataset_16"
# data_document_path = "../dataset_csp"
# data_docum
# NOTE(review): this fragment is cut off mid-comment above; loss_func and
# device are not defined anywhere in the visible text, so the line below
# depends on definitions from the missing part of this parameters file.

loss_func = loss_func.to(device)
from scipy.signal import hilbert
import numpy as np
import scipy.io as scio
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.nn.functional as func
from torch.autograd import Function
from torch.autograd import Variable
import math
import random
import sys
import os
from modules.transformer import TransformerEncoder

# Parameters in use.
# Input data selection:
# label: identifier of this training run.
# ConType: acoustic conditions whose data is used; with ConType = ["No", "Low", "High"] the data of all three conditions is mixed together for training.
# names: subjects whose data is used in this run.
label = "cm_textCNN_4cm_heatmap"
ConType = ["No"]
# names = ["S1", "S2", "S3", "S4", "S5", "S6", "S7", "S8", "S9", "S10", "S11", "S12", "S13", "S14", "S15", "S16", "S17",
#          "S18"]
# names = ["S1", "S2", "S3", "S4", "S5", "S6", "S7", "S8", "S9", "S10", "S11", "S12", "S13", "S14", "S15", "S16", "S17", "S18"]
names = ["S2"]

# Path of the data directory.
# data_document_path = "../dataset"
data_document_path = "../dataset_16"
# data_document_pat
# NOTE(review): this fragment is cut off mid-comment above; loss_func and
# device are not defined anywhere in the visible text, so the line below
# depends on definitions from the missing part of this parameters file.

loss_func = loss_func.to(device)
