In [1]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR


In [2]:
class Transformer(nn.Module):
    def __init__(self, d_model, nhead, num_layers):
        super(Transformer, self).__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, 6)

    def forward(self, x):
        x = self.transformer_encoder(x)
        x = x.mean(dim=1)  # 取平均作为输出
        x = self.fc(x)
        return x

In [3]:
d_model = 32  # Transformer模型中特征的维度
nhead = 4  # 多头自注意力头数
num_layers = 4  # Transformer编码器层数
num_classes = 6  # 分类类别数

In [30]:
import numpy as np

def sliding_window(signal, window_size, step_size):
    n_channels, n_samples = signal.shape
    n_windows = int((n_samples - window_size) / step_size) + 1
    windows = np.zeros((n_channels, window_size, n_windows))
    for i in range(n_windows):
        windows[:, :,i ] = signal.iloc[:, i*step_size:i*step_size+window_size]
    return windows

signal = pd.read_csv("C:\\Users\\a1882\Desktop\EEG\eegdata\\raw\\lefthand_zyy_05_epocflex_2023.03.22t16.50.54+08.00.md.bp.csv")
signal = pd.DataFrame(signal)
# 定义滑窗大小和滑动步长
window_size = 1000
step_size = 100

# 对信号进行滑窗处理
windows = sliding_window(signal, window_size, step_size)

# 输出滑窗后的信号形状
print("滑窗后信号形状：", windows.shape)
#print(windows)

滑窗后信号形状： (32, 1000, 146)


In [31]:
import onnx
import onnxruntime
save_path = "C:\\Users\\a1882\Desktop\\EEG\\model_saved\\transformer_1000_100e_1.pt"

onnx_file_name = "./cnn_3000_100e.onnx"
torch.load(save_path, map_location=torch.device('cpu'))

OrderedDict([('encoder_layer.self_attn.in_proj_weight',
              tensor([[-0.0299,  0.0612, -0.1074,  ...,  0.1874, -0.1438,  0.1018],
                      [-0.2022,  0.1797, -0.0896,  ...,  0.0619,  0.1461, -0.0290],
                      [ 0.1697,  0.1872, -0.1305,  ...,  0.0373, -0.1238,  0.0025],
                      ...,
                      [ 0.0434, -0.0207, -0.1276,  ..., -0.2075, -0.1074,  0.0820],
                      [-0.0797,  0.0602, -0.1537,  ..., -0.1581, -0.0481, -0.1953],
                      [-0.1915,  0.0174,  0.1310,  ...,  0.0938, -0.1755, -0.0699]])),
             ('encoder_layer.self_attn.in_proj_bias',
              tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
  

In [32]:
import onnx
# 我们可以使用异常处理的方法进行检验
try:
    # 当我们的模型不可用时，将会报出异常
    onnx.checker.check_model(onnx_file_name)
except onnx.checker.ValidationError as e:
    print("The model is invalid: %s"%e)
else:
    # 模型可用时，将不会报出异常，并会输出“The model is valid!”
    print("The model is valid!")

The model is valid!


In [33]:
for i in range(windows.shape[2]):
    signal = pd.DataFrame(windows[:,:,i]).T
    model = Transformer(d_model=32, nhead=4, num_layers=4)
    model.load_state_dict(torch.load(save_path))
    model.eval()
    input = torch.from_numpy(signal.values)
    input = input.unsqueeze(0)
    with torch.no_grad():
        output = model(input.float())
        #print(output)
    # probabilities = torch.nn.functional.softmax(output, dim=1)
    # print(probabilities)

    _, pred = torch.max(output, dim=1)  # 找到预测分数最大的类别，得到预测类别
    label_map = {0: 'lefthand', 1:'read' ,  2:'rest', 3: 'walkbase', 4: 'walkl' ,5: 'walkfocus'}
    print(label_map[pred.item()])

walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
rest
rest
rest
rest
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
walkfocus
