In [1]:
import time
import torch
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR

In [2]:
# Define the CNN classifier for windowed multi-channel EEG and create an instance of it.
class EEGNet(nn.Module):
    """Small 2-D CNN that maps one EEG window to ``n_classes`` unnormalized class scores.

    Two conv -> BatchNorm -> ReLU -> max-pool -> dropout stages downsample only the time
    axis: every kernel/stride has height 1, so the channel axis is kept intact until the
    fully connected head.

    Args:
        n_channels: number of rows (channels) in each input window. Default 33.
        window_size: number of samples (columns) in each input window. Default 3000.
        n_classes: number of output classes. Default 6.

    The defaults reproduce the original hard-coded ``fc1`` input size of 48576
    (32 feature maps x 33 channels x 46 time steps), and all parameter names are unchanged,
    so existing state_dict checkpoints still load.

    Raises:
        ValueError: if ``n_channels`` < 1 or ``window_size`` is too short for the conv/pool stack.
    """

    def __init__(self, n_channels=33, window_size=3000, n_classes=6):
        super().__init__()
        if n_channels < 1:
            raise ValueError("n_channels must be >= 1, got {}".format(n_channels))
        # Stage 1
        self.conv1 = nn.Conv2d(1, 16, kernel_size=(1, 4), stride=(1, 2))
        self.bn1 = nn.BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4))
        self.dropout1 = nn.Dropout(p=0.25)
        # Stage 2
        self.conv2 = nn.Conv2d(16, 32, kernel_size=(1, 2), stride=(1, 2))
        self.bn2 = nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4))
        self.dropout2 = nn.Dropout(p=0.25)
        # Classifier head: the flattened size is derived from the input shape instead of being
        # a magic number (32 feature maps x channels x remaining time steps).
        flat_features = 32 * n_channels * self._time_steps_after_features(window_size)
        self.fc1 = nn.Linear(flat_features, 128)
        self.dropout3 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(128, n_classes)

    @staticmethod
    def _time_steps_after_features(window_size):
        """Return the time-axis length left after conv1/pool1/conv2/pool2 (no padding, floor mode)."""
        def out_len(length, kernel, stride):
            return (length - kernel) // stride + 1

        length = out_len(window_size, 4, 2)  # conv1
        length = out_len(length, 4, 4)       # pool1
        length = out_len(length, 2, 2)       # conv2
        length = out_len(length, 4, 4)       # pool2
        # Once a stage's input is shorter than its kernel, this stays <= 0, so one check suffices.
        if length < 1:
            raise ValueError(
                "window_size={} is too short for EEGNet's conv/pool stack".format(window_size))
        return length

    def forward(self, x):
        """Return unnormalized class scores of shape (batch, n_classes).

        ``x`` must be a 3-D float tensor (batch, n_channels, window_size) whose sizes match
        those given to ``__init__``.
        """
        x = torch.unsqueeze(x, 1)  # add the single input feature map: (batch, 1, channels, samples)
        x = self.conv1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.pool1(x)
        x = self.dropout1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = torch.relu(x)
        x = self.pool2(x)
        x = self.dropout2(x)
        x = torch.flatten(x, 1)  # (batch, 32 * channels * time_steps)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.dropout3(x)
        x = self.fc2(x)
        return x


model = EEGNet()

In [3]:
import numpy as np

def sliding_window(signal, window_size, step_size):
    """Cut a (channels, samples) signal into overlapping windows along the sample axis.

    Args:
        signal: 2-D array-like (pandas DataFrame or NumPy array) with one row per channel
            and one column per sample.
        window_size: number of samples in each window; must be positive.
        step_size: number of samples between the starts of consecutive windows; must be positive.

    Returns:
        float64 array of shape (n_channels, window_size, n_windows) where
        ``windows[:, :, i] == signal[:, i*step_size : i*step_size + window_size]``.
        ``n_windows`` is 0 when the signal is shorter than one window.

    Raises:
        ValueError: if ``signal`` is not 2-D or ``window_size``/``step_size`` is not positive.
    """
    data = np.asarray(signal)  # positional slicing works the same for DataFrames and arrays
    if data.ndim != 2:
        raise ValueError("signal must be 2-D (channels, samples), got shape {}".format(data.shape))
    if window_size <= 0 or step_size <= 0:
        raise ValueError("window_size and step_size must be positive")

    n_channels, n_samples = data.shape
    # Floor division (not int() truncation toward zero): a signal shorter than one window must
    # give zero windows, not one window that cannot be filled.
    n_windows = max(0, (n_samples - window_size) // step_size + 1)
    windows = np.zeros((n_channels, window_size, n_windows))
    for i in range(n_windows):
        start = i * step_size
        windows[:, :, i] = data[:, start:start + window_size]
    return windows

# Raw recording. A raw string keeps the Windows backslashes literal; the old literal only
# worked because "\D", "\E", "\e", "\l" are unrecognized escapes (SyntaxWarning on Python 3.12+).
data_path = r"C:\Users\a1882\Desktop\EEG\eegdata\raw\lefthand_zyy_05_epocflex_2023.03.22t16.50.54+08.00.md.bp.csv"
signal = pd.read_csv(data_path, header=None)  # already a DataFrame; no re-wrap needed

# Window length and hop, in samples
window_size = 3000
step_size = 100

# Split the recording into overlapping windows
windows = sliding_window(signal, window_size, step_size)

# Report the windowed array's shape: (channels, window_size, n_windows)
print("滑窗后信号形状：", windows.shape)
#print(windows)

滑窗后信号形状： (33, 3000, 126)


In [4]:
save_path = "./cnn_3000_30e_26.pt"  # trained model checkpoint, relative to the notebook's working directory

In [11]:
# Load the trained weights onto the CPU and switch to eval mode (dropout off, BatchNorm uses
# running statistics).
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary code; only load
# trusted files (newer torch versions accept weights_only=True to restrict this).
model.load_state_dict(torch.load(save_path, map_location=torch.device('cpu')))
model.eval()

# Class index -> label name; must match the label order used during training.
label_map = {0: 'lefthand', 1: 'read', 2: 'rest', 3: 'walkbase', 4: 'walkl', 5: 'walkfocus'}

n_windows = windows.shape[2]
total_time = 0
with torch.no_grad():  # inference only: don't build an autograd graph
    for i in range(n_windows):
        start_time = time.perf_counter()  # monotonic, high-resolution clock for intervals
        # (channels, samples) window -> batch of one: (1, channels, samples)
        window = torch.from_numpy(windows[:, :, i]).float().unsqueeze(0)
        output = model(window)
        _, pred = torch.max(output, dim=1)  # index of the highest-scoring class
        inference_time = time.perf_counter() - start_time
        print('inference time: {:.2f} seconds'.format(inference_time))
        print(label_map[pred.item()])
        total_time += inference_time

# Average over the windows actually processed (previously hard-coded to 126).
average_time = total_time / n_windows if n_windows else 0.0

# Report the result
print('Average inference time: {:.2f} seconds'.format(average_time))

inference time: 0.33 seconds
lefthand
inference time: 0.09 seconds
lefthand
inference time: 0.05 seconds
lefthand
inference time: 0.05 seconds
lefthand
inference time: 0.13 seconds
lefthand
inference time: 0.07 seconds
lefthand
inference time: 0.02 seconds
lefthand
inference time: 0.02 seconds
lefthand
inference time: 0.02 seconds
lefthand
inference time: 0.02 seconds
lefthand
inference time: 0.02 seconds
lefthand
inference time: 0.02 seconds
lefthand
inference time: 0.02 seconds
lefthand
inference time: 0.03 seconds
lefthand
inference time: 0.07 seconds
lefthand
inference time: 0.03 seconds
lefthand
inference time: 0.02 seconds
lefthand
inference time: 0.02 seconds
lefthand
inference time: 0.03 seconds
lefthand
inference time: 0.02 seconds
lefthand
inference time: 0.02 seconds
lefthand
inference time: 0.02 seconds
lefthand
inference time: 0.02 seconds
lefthand
inference time: 0.03 seconds
lefthand
inference time: 0.02 seconds
lefthand
inference time: 0.02 seconds
lefthand
inference ti