In [None]:
# BMINT MODEL TRAINING
# Author: YanLi@Fudan University
# CNN (LeNet-5) seizure vs. non-seizure classifier trained on STFT spectrogram images

In [None]:
import torch.nn as nn
import copy
import time
from tqdm import tqdm
import torch, gc
from torch.utils.data import TensorDataset, DataLoader
import shutil
import numpy as np
import matplotlib.pyplot as plt
import os 
from scipy import signal
import torch.optim as optim

In [None]:
# Load the raw segments. The context manager closes the underlying .npz file handle;
# indexing an NpzFile reads each member fully into memory, so the arrays stay valid.
with np.load('dataset.npz') as npz:
    xtrs_r = npz['train_s']     # training segments labelled 1 below
    xtrns_r = npz['train_ns']   # training segments labelled 0 below
    xtes_r = npz['test_s']      # test segments labelled 1 below
    xtens_r = npz['test_ns']    # test segments labelled 0 below

fs = 256        # sampling frequency passed to the STFT
nseg = 128      # STFT window length (samples)
overlap = 116   # STFT window overlap (samples)

upsize = 2                  # pixel-repeat factor applied along both image axes
imagesize = 22 * upsize     # 22 kept frequency bins x 22 frames, then upsampled


def _stft_image(segment):
    """Convert one 1-D segment into a (1, imagesize, imagesize) float32 STFT magnitude image.

    Only frequency bins strictly between 1 and 46 (in the units of ``fs``) are kept,
    and each pixel is repeated ``upsize`` times along both axes.
    """
    freqs, _, zxx = signal.stft(segment, fs, nperseg=nseg, noverlap=overlap, padded=False)
    image = np.abs(zxx[(1 < freqs) & (freqs < 46)])
    image = image.repeat(upsize, axis=0).repeat(upsize, axis=1)
    # reshape raises if the kept band / frame count does not yield a square image
    return image.reshape((1, imagesize, imagesize)).astype('float32')


def _build_split(pos_segments, neg_segments, split_name):
    """Featurize positive (label 1) then negative (label 0) segments.

    Returns:
        (images, labels): a float32 tensor of shape (N, 1, imagesize, imagesize)
        and an int64 label tensor of shape (N,), positives first.
    """
    images, labels = [], []
    for segments, label in ((pos_segments, 1), (neg_segments, 0)):
        for segment in tqdm(segments, desc=f'{split_name} (label {label})'):
            images.append(_stft_image(segment))
            labels.append(label)
    images_t = torch.as_tensor(np.array(images, dtype=np.float32))
    labels_t = torch.tensor(labels, dtype=torch.long)
    return images_t, labels_t


x_train, y_train = _build_split(xtrs_r, xtrns_r, 'train')
x_test, y_test = _build_split(xtes_r, xtens_r, 'test')

# Inverse-frequency class weights for the weighted cross-entropy loss.
# y_train is already a LongTensor; clone() copies it without the warning torch.tensor(tensor) emits.
y_train_tensor = y_train.clone()
# minlength=2 guarantees one count per class, so the weight vector always matches the 2 logits.
class_counts = torch.bincount(y_train_tensor, minlength=2)
if (class_counts == 0).any():
    raise ValueError(f'training set is missing a class, counts per label: {class_counts.tolist()}')
class_weights = 1.0 / class_counts.float()
print(class_weights)

In [None]:
# Package the tensors into datasets and mini-batch loaders:
# shuffled batches for training, fixed-order batches for evaluation.
my_dataset = TensorDataset(x_train, y_train)  # training set
train_dataloader = DataLoader(dataset=my_dataset, num_workers=4, pin_memory=False,
                              batch_size=256, shuffle=True)
a = len(train_dataloader.dataset)

my_dataset = TensorDataset(x_test, y_test)  # re-bound to the held-out test set
test_dataloader = DataLoader(dataset=my_dataset, num_workers=4, pin_memory=False,
                             batch_size=256)
b = len(test_dataloader.dataset)

use_gpu = torch.cuda.is_available()
dataloaders = dict(train=train_dataloader, val=test_dataloader)
dataset_sizes = dict(train=a, val=b)

# Per-epoch history buffers: training/validation loss and accuracy.
t_loss, v_loss, t_acc, v_acc = [], [], [], []

In [None]:
# LeNet-5 style network
class LeNet(nn.Module):
    """LeNet-5-style two-class CNN for single-channel square images.

    Args:
        image_size: side length of the expected (N, 1, image_size, image_size)
            input. The default 44 reproduces the original hard-coded 1296-unit
            flatten layer (16 channels x 9 x 9).

    Raises:
        ValueError: if ``image_size`` < 12, which leaves no spatial extent after
            the second pooling stage.
    """

    def __init__(self, image_size=44):
        super(LeNet, self).__init__()
        # conv1 (5x5, padding=2) keeps the side length, each 2x2/stride-2 max-pool
        # floors it to half, and conv2 (5x5, no padding) shrinks it by 4.
        side = (image_size // 2 - 4) // 2
        if side < 1:
            raise ValueError(f"image_size must be >= 12, got {image_size}")
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5), stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5), stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(in_features=16 * side * side, out_features=120),
            nn.ReLU(),
            nn.Linear(in_features=120, out_features=32),
            nn.ReLU(),
            nn.Linear(in_features=32, out_features=2)
        )

    def forward(self, x):
        """Return unnormalised logits of shape (N, 2) for an input of shape (N, 1, H, W)."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        result = self.classifier(x)
        return result

In [None]:
# Training / evaluation helpers
def _safe_div(num, den):
    """Return num / den, or 0.0 when den is 0 (e.g. a class absent from a loader)."""
    return num / den if den else 0.0


def _run_epoch(model, loader, criterion, optimizer, device, training):
    """Run one pass over ``loader`` and return its statistics.

    In training mode the model is set to ``train()`` and updated with ``optimizer``.
    Otherwise it is set to ``eval()`` and autograd is disabled. Class 1 is the
    positive class.

    Returns:
        dict with keys ``loss`` (mean per-batch loss), ``correct``, ``total``,
        ``tp``, ``fn``, ``tn``, ``fp``.
    """
    model.train(training)
    stats = dict(loss=0.0, correct=0, total=0, tp=0, fn=0, tn=0, fp=0)
    # Evaluation needs no autograd graph, which saves memory and time.
    with torch.set_grad_enabled(training):
        for inputs, target in loader:
            inputs, target = inputs.to(device), target.to(device)
            if training:
                optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, target)
            if training:
                loss.backward()
                optimizer.step()

            stats["loss"] += loss.item()
            predicted = output.argmax(dim=1)
            stats["total"] += target.size(0)
            stats["correct"] += (predicted == target).sum().item()
            stats["tp"] += ((predicted == 1) & (target == 1)).sum().item()
            stats["fn"] += ((predicted == 0) & (target == 1)).sum().item()
            stats["tn"] += ((predicted == 0) & (target == 0)).sum().item()
            stats["fp"] += ((predicted == 1) & (target == 0)).sum().item()
    stats["loss"] = _safe_div(stats["loss"], len(loader))
    return stats


def train(model, train_loader, test_loader, criterion, optimizer, epoch=None, num_epochs=None):
    """Train ``model`` for one epoch on ``train_loader``, then evaluate it on ``test_loader``.

    Class 1 is treated as the positive class. Sensitivity or specificity whose
    denominator is zero (a class absent from a loader) is reported as 0.0.

    Args:
        model: network to train. Batches are moved to the device of its parameters.
        train_loader, test_loader: iterables of (inputs, integer target) batches.
        criterion: loss function applied to (model output, target).
        optimizer: optimizer over ``model``'s parameters (used only for training).
        epoch, num_epochs: used only in the progress lines. When omitted, they fall
            back to notebook-level ``epoch`` / ``num_epochs`` globals, which the
            original implementation read implicitly.

    Returns:
        (model, score, sen, spe, acc) from the test pass, where score = sen * spe.
        The model is left in eval mode.
    """
    if epoch is None:
        epoch = globals().get("epoch")
    if num_epochs is None:
        num_epochs = globals().get("num_epochs")
    label = f"Epoch {epoch + 1}/{num_epochs}" if epoch is not None else "Epoch ?"

    # Follow the model's device instead of hard-coding .cuda(), so the CPU fallback works.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    for phase, loader, training in (("Train", train_loader, True), ("Test", test_loader, False)):
        stats = _run_epoch(model, loader, criterion, optimizer, device, training)
        acc = _safe_div(stats["correct"], stats["total"])
        sen = _safe_div(stats["tp"], stats["tp"] + stats["fn"])
        spe = _safe_div(stats["tn"], stats["tn"] + stats["fp"])
        # Report the epoch-average loss, not just the final batch's loss.
        print(f"{phase} {label}, Acc: {acc:.4f}, Sen: {sen:.4f}, Spe: {spe:.4f}, Loss: {stats['loss']:.4f}")

    # After the loop, acc/sen/spe hold the test-phase metrics.
    score = sen * spe
    return model, score, sen, spe, acc

In [None]:
 # 设置训练参数
num_epochs = 50
batch_size = 128
learning_rate = 0.0001

# 创建模型实例
lenet = LeNet()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
optimizer = optim.Adam(lenet.parameters(), lr=learning_rate)

lenet.to(device)

best_model = None
b_score = 0
b_acc = 0
b_sen = 0
b_spe = 0
# 训练循环
for epoch in range(num_epochs):
    trained_model, score, sen, spe, acc = train(lenet, train_dataloader, test_dataloader, criterion, optimizer)
    if score > b_score:
        best_model = copy.deepcopy(trained_model)
        b_score = score
        b_acc = acc
        b_sen = sen
        b_spe = spe
print(f"Final metrics , Score: {b_score:.4f}, Acc: {b_acc:.4f}, Sen: {b_sen:.4f}, Spe: {b_spe:.4f}")

In [None]:
# Export the best checkpoint to ONNX on the CPU.
if best_model is None:
    raise RuntimeError("training did not select a model; nothing to export")

device = torch.device('cpu')
best_model.to(device)
# Export inference behaviour explicitly instead of relying on the mode left by training.
best_model.eval()

# Dummy input that fixes the traced shape: (batch=1, channel=1, imagesize, imagesize).
input_tensor = torch.randn(1, 1, imagesize, imagesize)

directory = 'cnn'
os.makedirs(directory, exist_ok=True)
torch.onnx.export(best_model, input_tensor, os.path.join(directory, "lenet5.onnx"), opset_version=11)

In [None]:
# Score the whole test set in a single forward pass and tally the binary
# confusion counts (label 1 is the positive class).
with torch.no_grad():
    outputs = best_model(x_test.to(device))
    ypred = outputs.cpu().argmax(dim=1)

    pred_pos, pred_neg = ypred == 1, ypred == 0
    true_pos, true_neg = y_test == 1, y_test == 0
    TP = (pred_pos & true_pos).sum()
    FN = (pred_neg & true_pos).sum()
    TN = (pred_neg & true_neg).sum()
    FP = (pred_pos & true_neg).sum()

    sen = TP / (TP + FN)   # recall on the positive class
    spe = TN / (TN + FP)   # recall on the negative class
    for row in ((sen, spe), (TP, FN, TN, FP), (ypred.sum(), y_test.sum())):
        print(*row)

In [None]:
# Evaluate the test-set predictions: classification report and confusion matrix.
# plot_confusion_matrix is no longer imported. It was removed in scikit-learn 1.2,
# so importing it raised ImportError, and it was never used here.
from sklearn.metrics import classification_report,accuracy_score,roc_curve
from sklearn.metrics import confusion_matrix,auc,RocCurveDisplay
import seaborn as sns

target_names = ['non-seizure', 'seizure']  # index order matches labels 0 / 1
# labels=[0, 1] keeps the report and matrix two-class even if a class never
# occurs in y_test or ypred, so target_names and the 2x2 unpacking stay valid.
print(classification_report(y_test, ypred, labels=[0, 1], target_names=target_names))
confusion = confusion_matrix(y_test, ypred, labels=[0, 1])

fig, ax = plt.subplots()
sns.heatmap(confusion, annot=True, fmt='g', cmap='Blues', ax=ax)
ax.set_title('Confusion Matrix\n\n')
ax.set_xlabel('\nPredicted Values')
ax.set_ylabel('Actual Values ')

# Tick labels must follow the class index order (0 = non-seizure, 1 = seizure),
# not alphabetical order.
ax.xaxis.set_ticklabels(target_names)
ax.yaxis.set_ticklabels(target_names)

# Rows are actual labels and columns are predicted labels, so ravel() gives tn, fp, fn, tp.
TN, FP, FN, TP = confusion.ravel()

Acc = (TP + TN) / float(TP + TN + FP + FN)
Sen = TP / float(TP + FN)
Spe = TN / float(TN + FP)
print('acc', Acc, 'sen', Sen, 'spe', Spe)

# Save this specific figure, and make sure the output directory exists.
os.makedirs('cnn', exist_ok=True)
fig.savefig('cnn/matrix.png', dpi=500)