In [1]:
import optuna
import torch
import torch.optim as optim
import torch.nn as nn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import scipy.io as sio
import numpy as np
import os


In [None]:

class AEA(nn.Module):
    """Amplitude-based electrode attention.

    Produces one weight per electrode from that electrode's peak-to-peak
    amplitude (max - min over the last axis). The amplitude profile is pooled
    to ``E_channels`` rows. It is then fed through a shared convolution stack
    (conv -> ReLU -> conv -> sigmoid) along the electrode axis, once for an
    average-pooled view and once for a max-pooled view. The two sigmoid
    outputs are summed.

    Args:
        E_channels: number of electrodes; sets the pooled height and the
            kernel height of ``temporal_fc``.
        sampleLength: unused; kept for interface compatibility.
        i: kernel height of ``fc1``/``fc2`` along the electrode axis. Padding
            is ``i // 2``, so any odd ``i`` preserves the electrode count.
        classes: unused; kept for interface compatibility.

    The input is assumed to be ``(batch, 1, electrodes, time)`` (``fc1``
    requires a single input channel) — TODO confirm against the data loader.
    The output has shape ``(batch, 1, E_channels, 1)``, with values in
    ``[0, 2]``, and broadcasts against the input.
    """

    def __init__(self, E_channels=30, sampleLength=384, i=3, classes=5):
        super(AEA, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((E_channels, 1))
        self.max_pool = nn.AdaptiveMaxPool2d((E_channels, 1))
        # Not used by forward(): its output was previously computed and thrown
        # away. It is kept (and kept in this registration order) so existing
        # state_dicts and seeded initialisations stay compatible.
        self.temporal_fc = nn.Conv2d(1, 1, kernel_size=(E_channels, 1), stride=(1, 1), padding=(0, 0))
        pad = (i // 2, 0)  # "same" height for odd i
        self.fc1 = nn.Conv2d(1, 32, kernel_size=(i, 1), padding=pad)
        self.relu = nn.ReLU()
        self.fc2 = nn.Conv2d(32, 1, kernel_size=(i, 1), padding=pad)
        self.sigmoid = nn.Sigmoid()

    def _excite(self, score):
        """Shared conv -> ReLU -> conv -> sigmoid applied to a pooled amplitude profile."""
        return self.sigmoid(self.fc2(self.relu(self.fc1(score))))

    def forward(self, x):
        """Return per-electrode weights of shape ``(batch, 1, E_channels, 1)``."""
        # Peak-to-peak amplitude per electrode over the last (time) axis.
        # max - min is never negative, so no abs() is needed.
        amplitude = torch.max(x, dim=-1, keepdim=True)[0] - torch.min(x, dim=-1, keepdim=True)[0]

        weight_avg = self._excite(self.avg_pool(amplitude))
        weight_max = self._excite(self.max_pool(amplitude))

        # NOTE(review): if the input already has E_channels rows, both adaptive
        # pools are identity maps. The two branches then coincide, and the
        # result is simply 2 * sigmoid(...). Confirm this is intended.
        return weight_avg + weight_max

In [None]:
class AEA_ICNN(torch.nn.Module):
    """EEG classifier: AEA electrode re-weighting followed by a compact CNN.

    Pipeline:
      1. electrode attention;
      2. spatial conv spanning all electrodes;
      3. grouped temporal conv;
      4. ReLU;
      5. BatchNorm;
      6. average pool over the remaining time axis;
      7. linear layer;
      8. log-softmax.

    Args:
        classes: number of output classes.
        sampleChannel: number of electrodes (input height). It is also passed
            to the attention module.
        sampleLength: number of time samples (input width). The pooling window
            is sized from it, so inputs are expected to be exactly this long.
        N1: number of spatial filters produced by ``pointwise``.
        d: depth multiplier; ``depthwise`` emits ``N1 * d`` feature maps.
        kernelLength: temporal kernel width of ``depthwise``.

    The input is assumed to be ``(batch, 1, sampleChannel, sampleLength)`` —
    TODO confirm. Returns log-probabilities of shape ``(batch, classes)``;
    pair them with ``nn.NLLLoss``.
    """

    def __init__(self, classes=2, sampleChannel=30, sampleLength=384, N1=16, d=2, kernelLength=64):
        super(AEA_ICNN, self).__init__()
        n_maps = N1 * d  # defaults give 32, as before

        # Attention sized for the same electrode count as the rest of the network.
        self.AEA = AEA(E_channels=sampleChannel, sampleLength=sampleLength)
        # Spatial filtering: the kernel spans every electrode, so the height collapses to 1.
        self.pointwise = torch.nn.Conv2d(1, N1, (sampleChannel, 1))
        # Temporal convolution grouped per spatial filter (d output maps per group).
        self.depthwise = torch.nn.Conv2d(N1, n_maps, (1, kernelLength), groups=N1)
        self.activ = torch.nn.ReLU()
        # track_running_stats=False: batch statistics are used in eval mode as well.
        self.batchnorm = torch.nn.BatchNorm2d(n_maps, track_running_stats=False)
        # The window equals the full "valid" output width of depthwise, i.e. a
        # global average over time for inputs of length sampleLength.
        self.GAP = torch.nn.AvgPool2d((1, sampleLength - kernelLength + 1))
        self.fc = torch.nn.Linear(n_maps, classes)
        # Despite the attribute name, this is LogSoftmax: outputs are log-probabilities.
        self.softmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, inputdata):
        """Return per-class log-probabilities of shape ``(batch, classes)``."""
        # Re-weight each electrode by its attention weight (broadcast over time).
        weighted = inputdata * self.AEA(inputdata)
        features = self.pointwise(weighted)
        features = self.depthwise(features)
        features = self.activ(features)
        features = self.batchnorm(features)
        features = self.GAP(features)
        logits = self.fc(torch.flatten(features, 1))
        return self.softmax(logits)