In [None]:
# BMINT MODEL TRAINING
# Author: YanLi@Fudan University
# SVM

In [None]:
import numpy as np
import mne
import os
import re
# %matplotlib notebook
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler,MinMaxScaler
from scipy.signal import welch,sosfiltfilt,detrend,butter

from sklearn.svm import SVC
from sklearn.metrics import classification_report,accuracy_score,roc_curve
from sklearn.metrics import confusion_matrix,auc,RocCurveDisplay,plot_confusion_matrix
from sklearn.model_selection import cross_val_score
from sklearn.utils import class_weight
import math
import warnings
from sklearn.exceptions import DataConversionWarning
# Ignore scikit-learn's DataConversionWarning warnings
warnings.filterwarnings("ignore", category=DataConversionWarning)

In [None]:
# Load the pre-split dataset archive. Four arrays are read from it:
#   train_s  - training segments labelled seizure
#   train_ns - training segments labelled non-seizure
#   test_s   - test segments labelled seizure
#   test_ns  - test segments labelled non-seizure
# The NpzFile returned by np.load keeps the archive open until it is closed,
# so every array is pulled out inside a `with` block; the extracted arrays
# stay usable after the archive is closed.
with np.load('dataset.npz') as data:
    xtrs_r = data['train_s']
    xtrns_r = data['train_ns']
    xtes_r = data['test_s']
    xtens_r = data['test_ns']

# Training labels as (n, 1) column vectors: seizure is 1, non-seizure is 0.
ytrs = np.ones((xtrs_r.shape[0], 1))
ytrns = np.zeros((xtrns_r.shape[0], 1))

# Test labels, same encoding: seizure is 1, non-seizure is 0.
ytes = np.ones((xtes_r.shape[0], 1))
ytens = np.zeros((xtens_r.shape[0], 1))

# Assemble the full training and test sets; seizure rows come first, followed
# by the non-seizure rows, and the labels are stacked in the same order.
X_train = np.concatenate([xtrs_r, xtrns_r])
y_train = np.concatenate([ytrs, ytrns])

X_test = np.concatenate([xtes_r, xtens_r])
y_test = np.concatenate([ytes, ytens])

# Quick sanity check of the label arrays (shape and seizure/non-seizure layout).
print(y_train.shape, y_test.shape)
plt.plot(y_train)

In [None]:
# Feature extraction: four time-domain features per segment -
# variance, energy, nonlinear energy and Shannon entropy.
def shannon_entropy(X):
    """Return the Shannon entropy (in bits) of the value histogram of ``X``.

    The values are binned into ``ceil(1 + log2(n))`` equal-width bins
    (Sturges' rule, ``n = len(X)``), the bin counts are turned into
    probabilities by dividing by ``n``, and ``-sum(p * log2(p))`` is taken
    over the non-empty bins.

    Parameters
    ----------
    X : array-like
        The samples of one segment.

    Returns
    -------
    float
        The histogram entropy; ``0.0`` for an empty input.
    """
    n = len(X)
    if n == 0:
        # log2(0) is undefined (math.log2 raises ValueError); an empty
        # segment has no distribution, so report zero entropy instead.
        return 0.0

    # Number of histogram bins from Sturges' rule.
    k = int(math.ceil(1 + math.log2(n)))
    # Number of samples falling into each bin.
    hist, _ = np.histogram(X, bins=k)
    # Probability of each non-empty bin; empty bins contribute nothing to the
    # entropy and must be dropped before taking the logarithm.
    probs = hist[hist > 0] / float(n)
    return -np.sum(probs * np.log2(probs))

def _time_domain_features(X):
    """Return the four time-domain features of every row of ``X`` as columns.

    Columns, in order: variance, energy (sum of squares), nonlinear energy
    (sum of ``x[n]**2 - x[n-1] * x[n+1]`` over the interior samples) and
    Shannon entropy (via ``shannon_entropy``). All reductions run along
    axis 1.
    """
    variance = np.var(X, axis=1)
    energy = np.sum(np.square(X), axis=1)
    # Interior samples only: the first and last sample have no two neighbours.
    nonlinear_energy = np.sum(X[:, 1:-1]**2 - X[:, :-2]*X[:, 2:], axis=1)
    entropy = np.apply_along_axis(shannon_entropy, axis=1, arr=X)
    return np.column_stack((variance, energy, nonlinear_energy, entropy))


def feature_extraction(X_A, X_B):
    """Compute the time-domain feature matrices for two data sets.

    Parameters
    ----------
    X_A, X_B : numpy.ndarray
        Signal arrays reduced along axis 1; presumably 2-D with one segment
        per row and one sample per column - confirm against the dataset.

    Returns
    -------
    tuple of numpy.ndarray
        ``(features_A, features_B)``; each has the columns variance, energy,
        nonlinear energy and Shannon entropy, computed by the same helper so
        both sets are guaranteed to be processed identically.
    """
    X_train_features = _time_domain_features(X_A)
    X_test_features = _time_domain_features(X_B)
    return X_train_features, X_test_features

In [None]:
# Convert the raw training and test segments into feature matrices and print
# their shapes as a quick sanity check.
feature_sets = feature_extraction(X_train, X_test)
X_train_features, X_test_features = feature_sets
print(X_train_features.shape, X_test_features.shape)

In [None]:
# Train the SVM model on the extracted features and evaluate it on the test
# features.
target_names = ['non-seizure', 'seizure']
# Numeric class labels in the same order as target_names (0 = non-seizure,
# 1 = seizure). Passing them explicitly keeps the report and the confusion
# matrix 2x2 even if one class never occurs in the test labels/predictions.
class_labels = [0, 1]

# scikit-learn expects 1-D label vectors; flatten the column-vector labels so
# fit() and the metrics do not have to convert them (the source of
# DataConversionWarning).
y_train_flat = np.ravel(y_train)
y_test_flat = np.ravel(y_test)

# NOTE(review): the features are passed to the SVC unscaled although their
# magnitudes likely differ a lot (energy vs. entropy) - consider whether a
# scaler should be fitted on the training features first.
svm = SVC(class_weight='balanced') # can change different kernel kernel='linear'
svm.fit(X_train_features, y_train_flat)
ypred = svm.predict(X_test_features)
# Signed distance to the decision boundary, used as the ROC score.
y_score = svm.decision_function(X_test_features)

# Predicted labels (dots) against the true labels (line), sample by sample.
plt.plot(ypred, 'o', color='black', label='ypred');
plt.plot(y_test, label='ytest')
plt.legend()

print("accuracy is:", 100*accuracy_score(y_test_flat, ypred))
print(classification_report(y_test_flat, ypred, labels=class_labels, target_names=target_names))

# ROC curve and AUC. The display object is deliberately not named `display`,
# which would shadow IPython's built-in display() for the rest of the session.
fpr, tpr, thresholds = roc_curve(y_test_flat, y_score)
roc_auc = auc(fpr, tpr)
roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name='svm')
roc_display.plot()

# Confusion matrix: rows are the true class, columns the predicted class,
# both in class_labels order.
confusion = confusion_matrix(y_test_flat, ypred, labels=class_labels)
TP = confusion[1, 1]
TN = confusion[0, 0]
FP = confusion[0, 1]
FN = confusion[1, 0]
Accuracy = (TP+TN) / float(TP+TN+FP+FN)
Sensitivity = TP / float(TP+FN)   # true-positive rate
Specificity = TN / float(TN+FP)   # true-negative rate
print('acc', Accuracy, 'sensitivity',Sensitivity,'specificity',Specificity)

In [None]:
import seaborn as sns

# Apply the seaborn theme before anything is drawn, so the figure looks the
# same on the first run of this cell as on every re-run.
sns.set(font_scale=1)

# Draw on a dedicated figure instead of whatever the "current" axes happen to
# be (outside the inline backend that could be the previously drawn plot).
fig, ax = plt.subplots()
sns.heatmap(confusion, annot=True, fmt='g', cmap='Blues', ax=ax)
ax.set_title('Confusion Matrix\n\n');
ax.set_xlabel('\nPredicted Values')
ax.set_ylabel('Actual Values ');

## Tick labels - the list must follow the sorted class order of the confusion
## matrix (0 = non-seizure first, then 1 = seizure).
ax.xaxis.set_ticklabels(['non-seizure', 'seizure'])
ax.yaxis.set_ticklabels(['non-seizure', 'seizure'])

## Save the confusion-matrix figure to disk.
fig.savefig('matrix.png', dpi=500)

In [None]:
# Save the trained SVM model to disk.
from joblib import dump, load
directory = 'svm'
# exist_ok=True creates the folder when it is missing and is a no-op when it
# already exists, without the separate exists() check (which could race with
# another process creating the folder in between).
os.makedirs(directory, exist_ok=True)

# Build the target path from `directory` so the two cannot drift apart.
dump(svm, directory + '/svm.joblib')