In [1]:
import numpy as np
import paddle
import paddle.nn as nn
import h5py
import matplotlib_inline
from paddle.io import DataLoader
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score
from sklearn.model_selection import KFold
import paddle.nn.functional as F
import copy
from Tsception_data_process import  PrepareData
from TSception import TSception
from MSBAM import MSBAM
from MSBAM_data_process import  DataDel

In [2]:
class process:
    def __init__(self,model = 1):
        self.model = model
    def train_one_epoch(self,data_loader, net , loss_func, optimizer):
        net.train()
        floss = 0
        for i, (x_batch, y_batch) in enumerate(data_loader()):
            out = net(x_batch)
            loss = loss_func(out, y_batch)
        #    avloss = paddle.mean(loss)
            floss = floss+loss
        #    _, pred = paddle.max(out, 1)
        #    pred_train.extend(pred.data.tolist())
        #    act_train.extend(y_batch.data.tolist())  #tolist() 返回列表或者数字
            optimizer.clear_grad()
            loss.backward()
            optimizer.step()
        #floss = floss/i+1
        return floss

    def get_model(self,num_classes=2,input_size=(1,28,512),sampling_rate=128,num_T=15,num_S=15,hidden=32,dropout_rate=0.8):
        if self.model == 1:

            model = TSception(
                num_classes=num_classes, input_size=input_size,
                sampling_rate=sampling_rate, num_T=num_T, num_S=num_S,
                hidden=hidden, dropout_rate=dropout_rate)
        else:
            model = MSBAM(2)
        return model

    def split_balance_class(self,data,label, k = 8, random = True):
        np.random.seed(888)
        KF = KFold(n_splits=k,shuffle=True)
        for idx_fold, (index_train, index_test) in enumerate(KF.split(data)):
            train_index,test_index = index_train,index_test
            break
        train = data[train_index]
        train_label = label[train_index]
        val = data[test_index]
        val_label = label[test_index]
        return train, train_label, val, val_label

    def normalize(self,train, test):

        """
            this function do standard normalization for EEG channel by channel
            :param train: training data
            :param test: testing data
            :return: normalized training and testing data
        """
        # data: sample x 1 x channel x data
        mean = 0
        std = 0
        for channel in range(train.shape[2]):
            mean = np.mean(train[:, :, channel, :])
            std = np.std(train[:, :, channel, :])
            train[:, :, channel, :] = (train[:, :, channel, :] - mean) / std
            test[:, :, channel, :] = (test[:, :, channel, :] - mean) / std
        return train, test


    def train(self,data_train, label_train, data_val, label_val,epochs = 100,batch_size = 64,save_path = None,crosv = False,k = None):

        CUDA = True
        loss_list = []
        #   save_name = '_sub' + str(subject) + '_trial' + str(fold)
        #   set_up(args)
        #    train_dataset = MyDataset(data_train,label_train)
        train_loader = self.get_dataloader(data_train,label_train,shuffle = True,batch_size=batch_size)

        #    val_dataset = MyDataset(data_val,label_val)
        val_loader = self.get_dataloader(data_val,label_val,shuffle = True,batch_size = batch_size)

        loss_fn = nn.CrossEntropyLoss()
        model = self.get_model()
        # para = get_trainable_parameter_num(model)
        # print('Model {} size:{}'.format(args.model, para))
        #可以省略

        if CUDA:
            paddle.set_device('gpu') if paddle.is_compiled_with_cuda() else paddle.set_device('cpu')
        if self.model == 1:
            lr = 0.001
        else:
            lr = 0.0001
        optimizer = paddle.optimizer.Adam(learning_rate=lr,parameters = model.parameters())

        for epoch in range(1, epochs+1):

            loss = self.train_one_epoch(data_loader=train_loader, net=model,loss_func = loss_fn, optimizer=optimizer)
            loss_list.append(loss.numpy())
    #    acc_train, f1_train, _ = get_metrics(y_pred=pred_train, y_true=act_train)
    #    print('epoch {}, loss={:.4f} acc={:.4f} f1={:.4f}'
    #          .format(epoch, loss_train, acc_train, f1_train))

    #    loss_val, pred_val, act_val = predict(
    #        data_loader=val_loader, net=model, loss_fn=loss_fn
    #    )
    #    acc_val, f1_val, _ = get_metrics(y_pred=pred_val, y_true=act_val)
    #    print('epoch {}, val, loss={:.4f} acc={:.4f} f1={:.4f}'.
    #          format(epoch, loss_val, acc_val, f1_val))
            print('epoch: {} --- loss:{}'.format(epoch,loss.numpy()))
            act,pre,ls = self.predict(val_loader,model,loss_fn)   #ls = loss
            acc,f1,_ = self.get_metrics(pre,act)
            print('  val: loss:{}  acc:{}  f1:{}'.format(ls.numpy(),acc,f1))
            if epoch % 100 ==0 and crosv == False:
                paddle.save(model.state_dict(),'./models-MSBAM/model_epoch_{}.pdparams'.format(epoch))
        if crosv ==False:
            paddle.save(model.state_dict(),'./models-MSBAM/model_final.pdparams')
        else:
            paddle.save(model.state_dict(),save_path+'/model_{}_fold.pdparams'.format(k))

#data_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)
    def predict(self,data_loader, net,loss_func):
        net.eval()
        avloss = 0
        pred_val = []
        act_val = []
        with paddle.no_grad():
            for i, data in enumerate(data_loader):
                out = net(data[0])
                loss =loss_func(out, data[1])
                out = F.softmax(out)
                pred = paddle.argmax(out, 1)
                avloss = avloss + loss
                pred_val.extend(pred.numpy())
                act_val.extend(data[1].numpy())
        return  act_val,pred_val,avloss/i+1

    def test(self,data, label, model_path ,batch_size=32,k=None,crosv = False):
        CUDA = True
        test_loader = self.get_dataloader(data, label,batch_size = batch_size, shuffle=False)
        model = self.get_model()
        if crosv:
            path = model_path+'/model_{}_fold.pdparams'.format(k)
        else:
            path = model_path
        model.load_dict(paddle.load(path))
    #CUDA = torch.cuda.is_available()  #自己添加的 TQ
        if CUDA:
            paddle.set_device('gpu') if paddle.is_compiled_with_cuda() else paddle.set_device('cpu')
        loss_fn = nn.CrossEntropyLoss()

        """
        if reproduce:
            model_name_reproduce = 'sub' + str(subject) + '_fold' + str(fold) + '.pth'
            data_type = 'model_{}_{}_{}'.format(args.dataset, args.data_format, args.label_type)
            save_path = osp.join(args.save_path, data_type)
            ensure_path(save_path)
            model_name_reproduce = osp.join(save_path, model_name_reproduce)
            model.load_state_dict(torch.load(model_name_reproduce,map_location=torch.device('cpu')))
    
        else:
            model.load_state_dict(torch.load(args.load_path))
        loss, pred, act = predict(
            data_loader=test_loader, net=model, loss_fn=loss_fn
        )
        """
        act,pred,loss = self.predict(test_loader,model,loss_fn)
        acc, f1, _ = self.get_metrics(y_pred=pred, y_true=act)
        print('>>> Test:  loss={:.4f} acc={:.4f} f1={:.4f}'.format(loss.item(), acc, f1))
        return act, pred, loss
    def trial_wise_voting(self,act, pred, num_segment_per_trial = 15):
        """
        this function do voting within each tiral to get the label of entire trial
        :param act: [num_sample] list
        :param pred: [num_sample] list
        :param num_segment_per_trial: how many samples per trial
        :param trial_in_fold: how many trials in this fold
        :return: trial-wise actual label and predicted label
        """
        trial_in_fold = len(pred)/num_segment_per_trial
        num_trial = int(len(act)/num_segment_per_trial)
        assert num_trial == trial_in_fold
        act_trial = np.reshape(act, (num_trial, num_segment_per_trial))   #4*15
        pred_trial = np.reshape(pred, (num_trial, num_segment_per_trial))
        #print("\n", "act:", act_trial)
        act_trial = np.mean(act_trial, axis=-1).tolist()
        pred_vote = []
        for trial in pred_trial:    #for each row
            index_0 = np.where(trial == 0)[0]
            index_1 = np.where(trial == 1)[0]
            if len(index_1) >= len(index_0):
                label = 1
            else:
                label = 0
            pred_vote.append(label)
        #print("\n","act_trial:",act_trial,'\npred_vote:',pred_vote)
        return act_trial, pred_vote


    def get_metrics(self,y_pred, y_true, classes=None):
        acc = accuracy_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred)
        if classes is not None:
            cm = confusion_matrix(y_true, y_pred, labels=classes)
        else:
            cm = confusion_matrix(y_true, y_pred)
        return acc, f1, cm

    def get_dataloader(self,data,label,shuffle,batch_size):
        dataset = MyDataset(data,label)
        loader = DataLoader(dataset,shuffle = shuffle,batch_size=batch_size)
        return loader
    def prepare_data(self,idx_train, idx_test, data, label):
        """
            1. get training and testing data according to the index
            2. numpy.array-->torch.tensor
            :param idx_train: index of training data
            :param idx_test: index of testing data
            :param data: (trial, segments, 1, channel, data)
            :param label: (trial, segments,)
            :return: data and label
        """
        data_train = data[idx_train]
        label_train = label[idx_train]
        data_test = data[idx_test]
        label_test = label[idx_test]

        data_train = np.concatenate(data_train, axis=0)
        label_train = np.concatenate(label_train, axis=0)

            # the testing data do not need to be concatenated, when doing leave-one-trial-out
        if len(data_test.shape)>4:
            data_test = np.concatenate(data_test, axis=0)
            label_test = np.concatenate(label_test, axis=0)
    
#if MSBAM dont need this -------!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        if self.model == 1:
            data_train, data_test = normalize(train=data_train, test=data_test)

        # Prepare the data format for training the model
        #data_train = torch.from_numpy(data_train).float()
        #abel_train = torch.from_numpy(label_train).long()

        #data_test = torch.from_numpy(data_test).float()
        #label_test = torch.from_numpy(label_test).long()
        return data_train, label_train, data_test, label_test

class MyDataset(paddle.io.Dataset):
    def __init__(self,data,label):
        self.x = data
        self.y = label
        assert self.x.shape[0] == self.y.shape[0]

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return len(self.y)

In [3]:
def log2txt(filename,content):
    """
    this function log the content to results.txt
    :param content: string, the content to log
    """
    file = open(filename, 'a')
    file.write(str(content) + '\n')
    file.close()
def corss_validation(VV,LL,data_path,model,Kfold = 10,depend = True,train = True,save_path = None,):
    if model == 1:
        log_file = "TSception_results.txt"
    else:
        log_file = "MSBAM_results.txt"
    file = open(log_file, 'a')
    file.write("---------results----------")
    file.close
    proce = process(model = model)
    """
    if model == 1:
        predata = PrepareData()
        VV,LL = predata.together(data_path)
    else:
        predata = DataDel()
        VV,LL = predata.getdata(data_path)
    """
    VV = VV.astype('float32')
    LL = LL.astype('int')
    np.random.seed(666)
    kf = KFold(n_splits=Kfold ,shuffle=True)
    for idx_fold, (index_train, index_test) in enumerate(kf.split(VV)):
        V,L,data_test,label_test = proce.prepare_data(index_train,index_test,VV,LL)
        V,L,val,val_l = proce.split_balance_class(V,L,k=8)
        if train:
            proce.train(V,L,val,val_l,epochs=150,crosv=True,save_path=save_path,k = idx_fold+1)
        act,pred,loss = proce.test(data_test,label_test,model_path = save_path,k = idx_fold+1,crosv=True)
        acc,f1,_ = proce.get_metrics(pred,act)  #pred true
        content = '{}_fold:  acc:{}  f1:{}'.format(idx_fold+1,acc,f1)
        log2txt(log_file,content)
        if depend:
            if model == 1:
                seg = 15
            else:
                seg = 12
            trial_act,trial_pred = proce.trial_wise_voting(act,pred,seg)  #act, pred, num_segment_per_trial 
            trial_acc,trial_f1,_ = proce.get_metrics(trial_pred,trial_act)
            content = '{}_fold:  trial_acc:{}  trial_f1:{} \n'.format(idx_fold+1,trial_acc,trial_f1)
            print(content)
            log2txt(log_file,content)
        

In [5]:
corss_validation(VV,LL,'DATA_MSBAM_DEAP_A/',model=2,save_path='MSBAM_K_fold')

W1120 09:19:53.540599   562 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1120 09:19:53.545145   562 device_context.cc:465] device: 0, cuDNN Version: 7.6.
  "When training, we now always track global mean and variance.")


epoch: 1 --- loss:[401.03143]
  val: loss:[1.8524892]  acc:0.5416666666666666  f1:0.667785234899329


  "When training, we now always track global mean and variance.")


epoch: 2 --- loss:[381.9377]
  val: loss:[1.8133776]  acc:0.5381944444444444  f1:0.6661087866108787


  "When training, we now always track global mean and variance.")


epoch: 3 --- loss:[340.12115]
  val: loss:[1.7907305]  acc:0.5370370370370371  f1:0.6669442131557036


  "When training, we now always track global mean and variance.")


epoch: 4 --- loss:[315.1444]
  val: loss:[1.7719383]  acc:0.5480324074074074  f1:0.6787330316742082


  "When training, we now always track global mean and variance.")


epoch: 5 --- loss:[296.9502]
  val: loss:[1.7602701]  acc:0.5532407407407407  f1:0.6897106109324759


  "When training, we now always track global mean and variance.")


epoch: 6 --- loss:[277.64407]
  val: loss:[1.7579997]  acc:0.5555555555555556  f1:0.6913183279742765


  "When training, we now always track global mean and variance.")


epoch: 7 --- loss:[260.5379]
  val: loss:[1.7425532]  acc:0.5549768518518519  f1:0.6872712484749899


  "When training, we now always track global mean and variance.")


epoch: 8 --- loss:[250.65637]
  val: loss:[1.732152]  acc:0.5578703703703703  f1:0.6866283839212469


  "When training, we now always track global mean and variance.")


epoch: 9 --- loss:[229.69954]
  val: loss:[1.7262111]  acc:0.5665509259259259  f1:0.701949860724234


  "When training, we now always track global mean and variance.")


epoch: 10 --- loss:[222.40126]
  val: loss:[1.7203166]  acc:0.5711805555555556  f1:0.7022900763358779


  "When training, we now always track global mean and variance.")


epoch: 11 --- loss:[214.10301]
  val: loss:[1.7178171]  acc:0.5734953703703703  f1:0.7031816351188079


  "When training, we now always track global mean and variance.")


epoch: 12 --- loss:[199.9915]
  val: loss:[1.719619]  acc:0.5798611111111112  f1:0.7091346153846154


  "When training, we now always track global mean and variance.")


epoch: 13 --- loss:[200.91685]
  val: loss:[1.714485]  acc:0.5850694444444444  f1:0.7084180561203741


  "When training, we now always track global mean and variance.")


epoch: 14 --- loss:[189.57465]
  val: loss:[1.711288]  acc:0.5833333333333334  f1:0.7009966777408637


  "When training, we now always track global mean and variance.")


epoch: 15 --- loss:[185.59071]
  val: loss:[1.7101222]  acc:0.5902777777777778  f1:0.7079207920792078


  "When training, we now always track global mean and variance.")


epoch: 16 --- loss:[180.7455]
  val: loss:[1.7083565]  acc:0.5856481481481481  f1:0.7055921052631579


  "When training, we now always track global mean and variance.")


epoch: 17 --- loss:[173.36946]
  val: loss:[1.7041032]  acc:0.5844907407407407  f1:0.699581589958159


  "When training, we now always track global mean and variance.")


epoch: 18 --- loss:[170.60855]
  val: loss:[1.7003131]  acc:0.59375  f1:0.7125307125307125


  "When training, we now always track global mean and variance.")


epoch: 19 --- loss:[165.6615]
  val: loss:[1.700494]  acc:0.5920138888888888  f1:0.7051442910915935


  "When training, we now always track global mean and variance.")


epoch: 20 --- loss:[164.56876]
  val: loss:[1.6972437]  acc:0.5931712962962963  f1:0.70793518903199


  "When training, we now always track global mean and variance.")


epoch: 21 --- loss:[161.15039]
  val: loss:[1.6945732]  acc:0.5995370370370371  f1:0.7123857024106401


  "When training, we now always track global mean and variance.")


epoch: 22 --- loss:[159.87468]
  val: loss:[1.6919625]  acc:0.6059027777777778  f1:0.7196377109921779


  "When training, we now always track global mean and variance.")


epoch: 23 --- loss:[152.2246]
  val: loss:[1.6890421]  acc:0.6105324074074074  f1:0.721094073767095


  "When training, we now always track global mean and variance.")


epoch: 24 --- loss:[148.8473]
  val: loss:[1.6869483]  acc:0.6082175925925926  f1:0.7201322860686235


  "When training, we now always track global mean and variance.")


epoch: 25 --- loss:[148.73637]
  val: loss:[1.6851771]  acc:0.6168981481481481  f1:0.71733561058924


  "When training, we now always track global mean and variance.")


epoch: 26 --- loss:[147.21408]
  val: loss:[1.6814905]  acc:0.6307870370370371  f1:0.7238095238095237


  "When training, we now always track global mean and variance.")


epoch: 27 --- loss:[145.51137]
  val: loss:[1.6791654]  acc:0.6359953703703703  f1:0.7230295024218405


  "When training, we now always track global mean and variance.")


epoch: 28 --- loss:[142.03044]
  val: loss:[1.6747274]  acc:0.6348379629629629  f1:0.7201773835920177


  "When training, we now always track global mean and variance.")


epoch: 29 --- loss:[139.87111]
  val: loss:[1.667811]  acc:0.6429398148148148  f1:0.7273530711444985


  "When training, we now always track global mean and variance.")


epoch: 30 --- loss:[138.45457]
  val: loss:[1.6665031]  acc:0.6446759259259259  f1:0.7283185840707964


  "When training, we now always track global mean and variance.")


epoch: 31 --- loss:[137.2665]
  val: loss:[1.6613358]  acc:0.6597222222222222  f1:0.743231441048035


  "When training, we now always track global mean and variance.")


epoch: 32 --- loss:[135.89622]
  val: loss:[1.65316]  acc:0.6603009259259259  f1:0.7392270102176811


  "When training, we now always track global mean and variance.")


epoch: 33 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 34 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 35 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 36 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 37 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 38 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 39 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 40 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 41 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 42 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 43 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 44 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 45 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 46 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 47 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 48 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 49 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 50 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 51 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 52 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 53 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


epoch: 54 --- loss:[nan]
  val: loss:[nan]  acc:0.4253472222222222  f1:0.0


  "When training, we now always track global mean and variance.")


KeyboardInterrupt: 

In [4]:
predata = DataDel()
VV,LL = predata.getdata('DATA_MSBAM_DEAP_A/')

In [3]:
predata = DataDel(data_path='data/data_preprocessed_matlab')
#import os
V,L = predata.getdata('DATA_MSBAM_DEAP_A/')
V = V.astype('float32')
L = L.astype('int')

In [8]:
def we():
    file = open('what.txt', 'a')
    file.write("---------results----------")
    file.close
we()

In [5]:
np.random.seed(666)
kf = KFold(n_splits=4, shuffle=True)
for idx_fold, (index_train, index_test) in enumerate(kf.split(V)):
    train_index,test_index = index_train,index_test
    break
len(train_index),len(test_index)

(960, 320)

In [9]:
V,L,data_test,lable_test = proce.prepare_data(train_index,test_index,V,L)
V.shape,L.shape,data_test.shape,lable_test.shape

((11520, 1, 640, 9, 9), (11520,), (3840, 1, 640, 9, 9), (3840,))

In [10]:
V,L,val,val_l = proce.split_balance_class(V,L,k=8)
V.shape,L.shape,val.shape,val_l.shape

((10080, 1, 640, 9, 9), (10080,), (1440, 1, 640, 9, 9), (1440,))

In [15]:
proce.train(V,L,val,val_l,epochs=300)

In [21]:
act,pred,loss = proce.test(val,val_l,'models-MSBAM/model_epoch_10.pdparams')
acc , f1 ,_ = proce.get_metrics(act,pred)
#a,b = trial_wise_voting(act,pred,num_segment_per_trial=12)
#acc , f1 ,_ = get_metrics(a,b)
acc,f1

>>> Test:  loss=1.7171 acc=0.5910 f1=0.6794


(0.5909722222222222, 0.6793685356559608)

In [25]:
path = '32'+'/model_{}_fold'.format(3)

In [26]:
path

'32/model_3_fold'