In [6]:
import mne

# 设置编码为 'latin1'
raw = mne.io.read_raw_edf('糖尿病认知障碍与对照脑电数据/认知障碍/陈艳杰.edf', preload=True, encoding='latin1')
import warnings
import mne

# 忽略 RuntimeWarning 警告
warnings.filterwarnings("ignore", category=RuntimeWarning)
# 查看数据的基本信息
print(raw.info)
raw.ch_names

Extracting EDF parameters from /home/lyq/Desktop/DL/NEW_STCGRU/糖尿病认知障碍与对照脑电数据/认知障碍/陈艳杰.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 236999  =      0.000 ...   473.998 secs...
<Info | 8 non-empty values
 bads: []
 ch_names: EEG Fp1-Ref, EEG Fp2-Ref, EEG F3-Ref, EEG F4-Ref, EEG C3-Ref, ...
 chs: 43 EEG
 custom_ref_applied: False
 highpass: 0.0 Hz
 lowpass: 250.0 Hz
 meas_date: 2024-01-11 10:31:24 UTC
 nchan: 43
 projs: []
 sfreq: 500.0 Hz
 subject_info: 3 items (dict)
>


['EEG Fp1-Ref',
 'EEG Fp2-Ref',
 'EEG F3-Ref',
 'EEG F4-Ref',
 'EEG C3-Ref',
 'EEG C4-Ref',
 'EEG P3-Ref',
 'EEG P4-Ref',
 'EEG O1-Ref',
 'EEG O2-Ref',
 'EEG F7-Ref',
 'EEG F8-Ref',
 'EEG T3-Ref',
 'EEG T4-Ref',
 'EEG T5-Ref',
 'EEG T6-Ref',
 'EEG Fz-Ref',
 'EEG Cz-Ref',
 'EEG Pz-Ref',
 'POL E',
 'POL PG1',
 'POL PG2',
 'EEG A1-Ref',
 'EEG A2-Ref',
 'POL T1',
 'POL T2',
 'POL X1',
 'POL X2',
 'POL X3',
 'POL X4',
 'POL X5',
 'POL X6',
 'POL X7',
 'POL SpO2',
 'POL EtCO2',
 'POL DC03',
 'POL DC04',
 'POL DC05',
 'POL DC06',
 'POL Pulse',
 'POL CO2Wave',
 'POL $A1',
 'POL $A2']

In [7]:
data = raw.get_data()

In [8]:
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [27]:

class STCGRU(nn.Module):
    def __init__(self):
        super(STCGRU, self).__init__()
        # 大卷积层
        self.cnn_layer_1_large = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=1,
                    kernel_size=(1, 77), stride=(1,3),bias=False),
            nn.BatchNorm2d(num_features=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2)),
            
            nn.Conv2d(in_channels=1, out_channels=1,
                    kernel_size=(1, 39), stride=(1,3), bias=False),
            nn.BatchNorm2d(num_features=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2)),
        )

        # 小卷积层
        self.cnn_layer_1_small = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=1,
                    kernel_size=(1, 21), stride=(1,3), bias=False),
            nn.BatchNorm2d(num_features=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2)),
            
            nn.Conv2d(in_channels=1, out_channels=1,
                    kernel_size=(1,11), stride=(1,3), bias=False),
            nn.BatchNorm2d(num_features=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2)),
        )

        # 一维卷积层 SCNN
        self.cnn_layer_2 = nn.Sequential(
            nn.Conv1d(in_channels=19, out_channels=19,
                kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(num_features=19),
            nn.ReLU(),
            nn.Conv1d(in_channels=19, out_channels=38,
                kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(num_features=38),
            nn.ReLU(),
            nn.Conv1d(in_channels=38, out_channels=76,
                    kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(num_features=76),
            nn.ReLU(),
        )

        self.fc_layer_1 = nn.Linear(58, 60)
        self.gru = nn.GRU(input_size=76, hidden_size=32,
                        num_layers=1, batch_first=True, bidirectional=True)
        self.flatten = nn.Flatten()
        self.fc_layer_2 = nn.Linear(32 * 2, 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # 通过两个卷积层
        out1 = self.cnn_layer_1_large(torch.unsqueeze(x, dim=1))
        out2 = self.cnn_layer_1_small(torch.unsqueeze(x, dim=1))

        # 确保尺寸匹配
        print("out1 shape:", out1.shape)  # 检查 out1 形状
        print("out2 shape:", out2.shape)  # 检查 out2 形状

        # 合并输出
        out = torch.cat((out1, out2), dim=3)

        # 改变维度以匹配 cnn_layer_2 的输入
        out = out.view(out.size(0), out.size(2), -1)

        out = self.cnn_layer_2(out)

        # 全连接层和GRU
        out = self.fc_layer_1(out)
        out = out.permute(0, 2, 1)
        h0 = torch.zeros(1 * 2, x.size(0), 32).to(device)

        # GRU 层
        _, hn = self.gru(out, h0)

        # 展平和输出
        out = self.flatten(hn.permute(1, 0, 2))
        out = self.fc_layer_2(out)
        out = self.sigmoid(out)
        return out


In [28]:
import torch
stcgru = STCGRU().to(device)
test = torch.randn( 1, 19, 1250).to(device)
output = stcgru(test)

out1 shape: torch.Size([1, 1, 19, 26])
out2 shape: torch.Size([1, 1, 19, 32])


In [29]:
print(stcgru)

STCGRU(
  (cnn_layer_1_large): Sequential(
    (0): Conv2d(1, 1, kernel_size=(1, 77), stride=(1, 3), bias=False)
    (1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): AvgPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0)
    (4): Conv2d(1, 1, kernel_size=(1, 39), stride=(1, 3), bias=False)
    (5): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): AvgPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0)
  )
  (cnn_layer_1_small): Sequential(
    (0): Conv2d(1, 1, kernel_size=(1, 21), stride=(1, 3), bias=False)
    (1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): AvgPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0)
    (4): Conv2d(1, 1, kernel_size=(1, 11), stride=(1, 3), bias=False)
    (5): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): AvgPool2d(kernel_s