In [1]:
import os
import sys
import random
import mne
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from pytorch_metric_learning import losses
from torchsummary import summary

from sklearn.preprocessing import LabelEncoder, OneHotEncoder, LabelBinarizer, normalize
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import KFold, StratifiedKFold
from scipy.stats import mode

import pandas as pd
import logging
import time
import math
import json
from collections import Counter

from datetime import datetime, timezone

from spikingjelly.activation_based import neuron, functional, surrogate, layer

from einops import rearrange, repeat, einsum

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = "cpu"

In [2]:
# f = mne.io.read_raw_edf("./Yingxf/Yingxf.edf")
# f = mne.io.read_raw_edf("./Liucx/Liucx.edf")

data_path = "./SJTU/Dingdw/Dingdw.edf"
raw_file = mne.io.read_raw_edf(data_path)
seiz_start_time = [
    datetime(2023, 5, 19, 21, 7, 16, tzinfo=timezone.utc),
    datetime(2023, 5, 19, 22, 9, 58, tzinfo=timezone.utc),
    datetime(2023, 5, 19, 23, 47, 8, tzinfo=timezone.utc),
    datetime(2023, 5, 20, 0, 25, 59, tzinfo=timezone.utc),
    datetime(2023, 5, 20, 1, 2, 11, tzinfo=timezone.utc),
    datetime(2023, 5, 20, 1, 43, 13, tzinfo=timezone.utc),
    datetime(2023, 5, 20, 4, 48, 44, tzinfo=timezone.utc),
    datetime(2023, 5, 20, 6, 26, 6, tzinfo=timezone.utc)
]

seiz_end_time = [
    datetime(2023, 5, 19, 21, 9, 0, tzinfo=timezone.utc),
    datetime(2023, 5, 19, 22, 11, 18, tzinfo=timezone.utc),
    datetime(2023, 5, 19, 23, 49, 57, tzinfo=timezone.utc),
    datetime(2023, 5, 20, 0, 27, 19, tzinfo=timezone.utc),
    datetime(2023, 5, 20, 1, 4, 19, tzinfo=timezone.utc),
    datetime(2023, 5, 20, 1, 53, 56, tzinfo=timezone.utc),
    datetime(2023, 5, 20, 4, 54, 39, tzinfo=timezone.utc),
    datetime(2023, 5, 20, 6, 28, 4, tzinfo=timezone.utc)
]


# f = mne.io.read_raw_edf("./Chenzh/Chenzh.edf")
# f = mne.io.read_raw_edf("./Dingdw/Dingdw.edf")
# f = mne.io.read_raw_edf("./Dingdw/Suny.edf")
# f = mne.io.read_raw_edf("./Guljy/Guljy.edf")
# f = mne.io.read_raw_edf("./Xuyx/Xuyx.edf")
# f = mne.io.read_raw_edf("./Huangfs/Huangfs.edf")
# f = mne.io.read_raw_edf("./Majp/Majp.edf")

Extracting EDF parameters from E:\EEG\SJTU\Dingdw\Dingdw.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw_file = mne.io.read_raw_edf(data_path)
  raw_file = mne.io.read_raw_edf(data_path)


In [3]:
# https://blog.csdn.net/cskywit/article/details/137448871
# https://github.com/johnma2006/mamba-minimal/blob/master/model.py

class MambaBlock(nn.Module):
    def __init__(self, input_channels):
        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
        super().__init__()

        self.d_model = input_channels
        self.d_inner = self.d_model * 2
        self.dt_rank = math.ceil(self.d_model / 16)
        self.d_state = 16
            
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2)#, bias=args.bias)

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            # bias=args.conv_bias,
            kernel_size=3,
            groups=self.d_inner,
            padding=2,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        
        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        A = repeat(torch.arange(1, self.d_state + 1), 'n -> d n', d=self.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, self.d_model) #, bias=args.bias)
        

    def forward(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].
    
        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
    
        Returns:
            output: shape (b, l, d)
        
        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
            
        """
        (b, l, d) = x.shape
        
        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)

        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, 'b d_in l -> b l d_in')
        
        x = F.silu(x)

        y = self.ssm(x)
        
        y = y * F.silu(res)
        
        output = self.out_proj(y)

        return output

    
    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
    
        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
            
        """
        (d_in, n) = self.A_log.shape

        # Compute ∆ A B C D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)
        
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)
        
        (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
        
        y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
        
        return y

    
    def selective_scan(self, u, delta, A, B, C, D):
        """Does selective scan algorithm. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        This is the classic discrete state space formula:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
        except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).
    
        Args:
            u: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)
    
        Returns:
            output: shape (b, l, d_in)
    
        Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly.
            
        """
        (b, l, d_in) = u.shape
        n = A.shape[1]
        
        # Discretize continuous parameters (A, B)
        # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
        # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
        #   "A is the more important term and the performance doesn't change much with the simplification on B"
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
        
        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
        # Note that the below is sequential, while the official implementation does a much faster parallel scan that
        # is additionally hardware-aware (like FlashAttention).
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []    
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            ys.append(y)
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)
        
        y = y + u * D
    
        return y

class RMSNorm(nn.Module):
    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))


    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return output

In [4]:
def compute_overlap(interval1, interval2):
    # Determine if there is an intersection
    if interval1[1] < interval2[0] or interval2[1] < interval1[0]:
        return 0
    
    # compute the start and the end of the overlap
    start = max(interval1[0], interval2[0])
    end = min(interval1[1], interval2[1])
    
    # compute the size of overlap
    overlap = end - start + 1
    
    return overlap

def check_overlap(interval1_np_list, interval2):
    # Determin if there is an intersection of length 0.1 s
    flag = True
    for i in range( len(interval1_np_list[0]) ):
        if compute_overlap([interval1_np_list[0][i], interval1_np_list[1][i]], interval2) >= NEW_SAMP_RATE/10:
            return i

    return -1



def get_names( ch_num, l ):
    name_list = []
    for i in range(ch_num):
        name_list.append( l[i]['name'] )
        # print(l[i]['name'][0])
    return name_list


def top_n_elements(lists, n=6):
    # 创建一个 Counter 对象以统计所有子列表中的元素
    overall_counter = Counter()
    
    # 遍历每个子列表并更新计数
    for sublist in lists:
        overall_counter.update(get_names(n, sublist))

    # 获取前 n 名元素
    top_elements_with_counts = overall_counter.most_common(n)

    top_elements = [element for element, count in top_elements_with_counts]
    counts = [count for element, count in top_elements_with_counts]
    
    return top_elements, counts


In [5]:
class ListMerger:
    def __init__(self, list1, list2):
        self.list1 = list1
        self.list2 = list2

    def __getitem__(self, index):
        if isinstance(index, slice):
            # 处理切片
            return [self[i] for i in range(*index.indices(len(self)))]
        elif index < len(self.list1):
            return self.list1[index]
        else:
            return self.list2[index - len(self.list1)]

    def __len__(self):
        # 返回合并后的列表的总长度
        return len(self.list1) + len(self.list2)

    def __iter__(self):
        return iter(self.list1 + self.list2)

class MergedDataset(Dataset):
    def __init__(self, merger, labels, indices):
        """
        :param merger: ListMerger 实例，包含合并后的数据
        :param indices: 要使用的索引列表
        """
        self.merger = merger
        self.labels = labels
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # 获取索引
        index = self.indices[idx]
        # 返回合并后的特征和标签
        return self.merger[index], self.labels[index]

def do_overlap( original_seq, seg_len, overlap_len ):
    '''
    Input:
        original_seq: 2d tensor, [ch, seq_len]
        overlap_len: float, overlapping length of 2 successive segments
    Outputs:
        res: a list of tensors
    '''
    res = []

    for i in range( 0, original_seq.shape[1], seg_len-overlap_len ):
        if original_seq.shape[1] - i < seg_len:
            break
        res.append( original_seq[:, i:(i+seg_len)] )

    return res

def read_SJTU_data( raw, seiz_start_time, seiz_end_time, COMMON_CH, SEG_TIME=4, OVERLAP=True ):
    # f = mne.io.read_raw_edf(data_folder)
    f = raw.copy()
    f.pick( COMMON_CH )
    data = f.get_data()
    data = data * 1e4
    
    meas_start_time = f.info['meas_date']#.strftime('%Y-%m-%d %H:%M:%S')
    sfreq = f.info['sfreq']
    all_channels = f.info['ch_names'][:-1] # drop last channels 'MK'
    
    SPH = int(30 * 60 * sfreq)
    SEG_LEN = int( SEG_TIME * sfreq )

    segments_pre = []
    segments_inter = []

    total_len_pre = 0
    total_len_inter = 0

    last_end = -1
    seiz_start_bias_time = []
    seiz_end_bias_time = []
    for i in range( len(seiz_start_time) ):
        seiz_start_bias_time.append( int((seiz_start_time[i] - meas_start_time).total_seconds() * sfreq) )
        seiz_end_bias_time.append( int((seiz_end_time[i] - meas_start_time).total_seconds() * sfreq) )

        beg = seiz_start_bias_time[i] - SPH
        end = seiz_start_bias_time[i]
        # print(beg, end)
        if beg < 0:
            segments_pre.append( torch.tensor(data[:, 0:seiz_start_bias_time[i]], dtype=torch.float32) )
            total_len_pre += seizure_start[i]
        else:
            segments_pre.append( torch.tensor(data[:, beg:end+1], dtype=torch.float32) )
            total_len_pre += segments_pre[-1].shape[1]
            segments_inter.append( torch.tensor(data[:, (last_end+1) : beg], dtype=torch.float32) )
            total_len_inter += segments_inter[-1].shape[1]
        last_end = seiz_end_bias_time[i]
        
    segments_inter.append( torch.tensor(data[:, (last_end+1):], dtype=torch.float32) ) 
    total_len_inter += segments_inter[-1].shape[1]
    # seiz_start_bias_time, seiz_end_bias_time
    
    segments_pre_overlap = []
    segments_inter_no_overlap = []
    step_for_overlap = int( SEG_LEN / ( total_len_inter / total_len_pre ) )
    print("step_for_overlap:", step_for_overlap)
    print(total_len_inter, total_len_pre)
    while segments_inter or segments_pre:
        if segments_pre:
            # segments_pre_overlap.append( do_overlap( seg_pre, SEG_LEN, OVERLAP_TIME*NEW_SAMP_RATE ) )
            if OVERLAP:
                segments_pre_overlap.extend( do_overlap( segments_pre[0], SEG_LEN, SEG_LEN-step_for_overlap) ) #OVERLAP_TIME*NEW_SAMP_RATE ) )
            else:
                segments_pre_overlap.extend( do_overlap( segments_pre[0], SEG_LEN, 0) )
            segments_pre.pop(0)
            # print( len(segments_pre_overlap) )
        if segments_inter:
            # segments_inter_no_overlap.append( do_overlap( segments_inter[0], SEG_LEN, OVERLAP_TIME*NEW_SAMP_RATE ) )
            segments_inter_no_overlap.extend( do_overlap( segments_inter[0], SEG_LEN, 0 ) )
            segments_inter.pop(0)
            # print( len(segments_inter_no_overlap) )
    
    
    return segments_inter_no_overlap, segments_pre_overlap

    

In [6]:
## https://www.kaggle.com/code/debarshichanda/pytorch-supervised-contrastive-learning
class SupervisedContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.1):
        super(SupervisedContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, feature_vectors, labels):
        # Normalize feature vectors
        feature_vectors_normalized = F.normalize(feature_vectors, p=2, dim=1)
        # Compute logits
        logits = torch.div(
            torch.matmul(
                feature_vectors_normalized, torch.transpose(feature_vectors_normalized, 0, 1)
            ),
            self.temperature,
        )
        # print( logits.shape, labels.shape )
        return losses.NTXentLoss(temperature=self.temperature)(logits, labels)
# super_contras = SupervisedContrastiveLoss(0.1)
# super_contras( x, label )

In [7]:


def count_label_np(one_hot_data, label):
    labels = np.argmax(one_hot_data, axis=1)  # 将one-hot编码转换为标签
    count = np.sum(labels == label)  # 统计标签出现的次数
    return count  

def calculate_accuracy_np(y_true, y_pred):
    labels_true = y_true #np.argmax(y_true, axis=1)  # 将真实标签转换为标签
    labels_pred = y_pred #np.argmax(y_pred, axis=1)  # 将预测结果转换为标签
    accuracy = np.sum(labels_true.T == labels_pred) / len(labels_true)  # 计算准确率
    return accuracy

def calculate_sensitivity_np(y_true, y_pred, class_index):
    # labels_true = np.argmax(y_true, axis=1)  # 将真实标签转换为标签
    # labels_pred = np.argmax(y_pred, axis=1)  # 将预测结果转换为标签
    labels_true = y_true #np.argmax(y_true, axis=1)  # 将真实标签转换为标签
    labels_pred = y_pred #np.argmax(y_pred, axis=1)  # 将预测结果转换为标签
    positive_labels = (labels_true == class_index)  # 将指定类别的标签设为1，其他类别设为0
    true_positive = np.sum((labels_pred == labels_true) * positive_labels)  # 统计真阳性的数量
    false_negative = np.sum((labels_pred != labels_true) * positive_labels)  # 统计假阴性的数量
    if (true_positive + false_negative) == 0:
        return -1
    sensitivity = true_positive / (true_positive + false_negative)  # 计算灵敏度
    return sensitivity

In [8]:
# 1 dimentional version + Mamba
## Patient-Specific Seizure Prediction via Adder Network and Supervised Contrastive Learning
## ADDNet-SCL-1d

class OneDCNN(nn.Module):
    def __init__(self, input_channels=3):
        super(OneDCNN, self).__init__()
        self.input_channels = input_channels
      
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels=input_channels, out_channels=16, kernel_size=21, stride=1, padding=10),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=8, stride=8)
        )

        self.conv2_1 = nn.Sequential(
            nn.Conv1d(in_channels=16, out_channels=16, kernel_size=1, stride=1),
            nn.ReLU()
        )
        self.conv2_2 = nn.Sequential(
            nn.Conv1d(in_channels=16, out_channels=16, kernel_size=11, stride=1, padding=5),
            nn.ReLU(),
            nn.Conv1d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.pool3 = nn.MaxPool1d(kernel_size=4, stride=4)

        self.conv4_1 = nn.Sequential(
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=1, stride=2),
            nn.ReLU()
        )
        self.conv4_2 = nn.Sequential(
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        # self.conv5_1 = nn.Sequential(
        #     nn.Conv1d(in_channels=32, out_channels=32, kernel_size=1, stride=2),
        #     nn.ReLU()
        # )
        # self.conv5_2 = nn.Sequential(
        #     nn.Conv1d(in_channels=32, out_channels=32, kernel_size=5, stride=2, padding=2),
        #     nn.ReLU(),
        #     nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
        #     nn.ReLU()
        # )

        self.mixer = MambaBlock(32)
        self.norm = RMSNorm(32)

        self.adaptive_avg_pool = nn.AdaptiveAvgPool1d(output_size=1)
        
        self.classifier = nn.Sequential(
            nn.Linear(32, 2)
        )

    def forward(self, x):
        # x = x.unsqueeze(1)
        # print(x.shape)
        
        x = self.conv1(x)

        x = self.pool3( self.conv2_1(x)+self.conv2_2(x) )

        x = self.conv4_1(x)+self.conv4_2(x)

        # x = self.conv5_1(x)+self.conv5_2(x)

        # x = [batch, channels, seq_len]
        x = x.permute(0, 2, 1) # [batch, seq_len, channels]
        x = self.mixer(self.norm(x)) + x
        x = x.permute(0, 2, 1) # [batch, channels, seq_len]

        x = self.adaptive_avg_pool( x )
        
        x_digits = x.contiguous().view(x.size(0), -1)
        # print(f'x.shape = {x.shape}')
        x_res = self.classifier(x_digits)
        return x_res, x_digits


In [9]:


print('SJTU/'+os.path.basename(data_path).split('.')[0]+'_sel_ch_30iter_with_SMOTE.json')

with open('SJTU/'+os.path.basename(data_path).split('.')[0]+'_sel_ch_30iter_with_SMOTE.json', 'r') as json_file:
    loaded_data = json.load(json_file)
COMMON_CH, _ = top_n_elements( loaded_data, 8 ) # 8 channels are acceptable

log_file_name = 'SJTU/'+os.path.basename(data_path).split('.')[0]+'_'+str(len(COMMON_CH))+'chs_log_with_mamba_10fold.txt'
with open(log_file_name, "w") as f:
    f.write("AddNet-SCL With Mamba 1D - 10-fold\n" + str(len(COMMON_CH)) + " channels\n" )



# split the raw data to 4 sec segments
SEG_TIME = 4

logger = logging.getLogger('mne')
# logger.setLevel(logging.WARNING)
logger.setLevel(logging.ERROR) # 每读一个edf文件都会有一个通道名重复的警告

    
segments_inter, segments_pre = read_SJTU_data( raw_file, seiz_start_time, seiz_end_time, COMMON_CH, SEG_TIME=SEG_TIME, OVERLAP=True )

segments = ListMerger( segments_inter, segments_pre )
seg_labels = [0] * len(segments_inter) + [1] * len(segments_pre)
seg_labels = torch.nn.functional.one_hot(torch.tensor(seg_labels), num_classes=2)


kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
test_acc = []
test_sens = []
test_spec = []
for i, (train_index, test_index) in enumerate(kf.split(segments, seg_labels.argmax(axis=1))):
    # 获取训练和测试数据，都是np.array
    X_train_indices = train_index
    X_test_indices = test_index

    # 从训练集中划分出验证集
    val_size = int(len(X_train_indices) * 0.2)  # 20% 用作验证集
    val_indices = np.random.choice(X_train_indices, size=val_size, replace=False)
    train_indices = np.setdiff1d(X_train_indices, val_indices)

    # 创建数据集
    train_dataset = MergedDataset(segments, seg_labels, train_indices)
    val_dataset = MergedDataset(segments, seg_labels, val_indices)
    test_dataset = MergedDataset(segments, seg_labels, X_test_indices)

    # 创建数据加载器
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)#, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)#, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)#, pin_memory=True)
    # break

    best_test_accuracy = 0
    best_epoch = -1
    sens_p = 0
    spec_p = 0
    
    
    model = OneDCNN(len(COMMON_CH)).to(device)

    # criterion = nn.MSELoss().to(device)
    # criterion = nn.BCELoss().to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    super_contras = SupervisedContrastiveLoss(0.08).to(device)
    # criterion = WeightedMSELoss(4).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    
    num_epochs = 30
    # VOTE_NUM = 5
    y_test = np.array(seg_labels[X_test_indices])
    y_test = np.argmax( y_test, axis=1)
    y_val = np.array(seg_labels[val_indices])
    y_val = np.argmax( y_val, axis=1)
    # y_test_flat = y_test[VOTE_NUM - 1:]
    for epoch in range(num_epochs):
        start_time = time.time()
        # 训练阶段
        model.train()
        
        train_loss = 0.0
        train_loss_cross = 0.0
        train_loss_cont  = 0.0
        
        correct = 0
        for interation, (inputs, targets) in enumerate(train_loader):
            # print(interation)
            
            inputs = inputs.to(device)
            targets = targets.to(device).float()
            outputs, feats = model(inputs)
            loss_cross = criterion(outputs, targets)
            tars = targets.argmax(dim=1)
            # print(tars.shape)
            loss_cont = super_contras( feats, tars )
            loss = 0.5 * loss_cross + 0.5 * loss_cont
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # functional.reset_net( model )
    
            # print(outputs.shape, targets.shape)
            predicted = torch.argmax(outputs, dim=1) #(outputs > 0.5).float()
            targets = torch.argmax(targets, dim=1) #(targets>0.5).to(device)
            correct += (predicted == targets).sum().item()
    
            train_loss += loss.item()
            train_loss_cross += loss_cross.item()
            train_loss_cont  += loss_cont.item()
               
        train_loss /= len(train_loader)
        train_loss_cross /= len(train_loader)
        train_loss_cont /= len(train_loader)
        train_acc = correct / len(train_dataset)
    
    
        # 验证阶段
        model.eval()
        
        val_loss = 0.0
        val_loss_cross = 0.0
        val_loss_cont  = 0.0
        
        correct = 0
        total = 0
        TP = 0
        FN = 0
        TN = 0
        FP = 0
        predictions = []
        # targets = []
        with torch.inference_mode():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device).float()
                outputs, feats = model(inputs)
                predicted = torch.argmax(outputs, dim=1) #(outputs > 0.5).float()
                loss_cross = criterion(outputs, targets)
                tars = targets.argmax(dim=1)
                loss_cont = super_contras( feats, tars )
                loss = loss_cross + loss_cont
                
                val_loss += loss.item()
                val_loss_cross += loss_cross.item()
                val_loss_cont += loss_cont.item()
                
                predictions.extend( predicted.cpu() )
                
                # functional.reset_net( model )
    
            predictions = np.array( predictions )
            # final_predictions_flat = np.array( [ mode(predictions[ i-(VOTE_NUM-1) : i+1 ]).mode for i in range(VOTE_NUM-1, len(predictions)) ] )
            
            # 计算混淆矩阵的四个基本元素
            TP = np.sum((predictions == 1) & (y_val == 1))  # True Positives
            TN = np.sum((predictions == 0) & (y_val == 0))  # True Negatives
            FP = np.sum((predictions == 1) & (y_val == 0))  # False Positives
            FN = np.sum((predictions == 0) & (y_val == 1))  # False Negatives
            
                
                
        val_loss /= len(val_loader)
        val_loss_cross /= len(val_loader)
        val_loss_cont  /= len(val_loader)
        
        test_accuracy = (TP + TN) / (TP + TN + FP + FN)
        sens_p_epoch = TP / (TP + FN) if (TP+FN != 0 ) else 0
        spec_p_epoch = TN / (TN + FP) if (TN+FP != 0 ) else 0
    
        
        if test_accuracy > best_test_accuracy: # sens_p_epoch > sens_p:  
            best_test_accuracy = test_accuracy
            best_epoch = epoch
            sens_p = TP / (TP + FN) if (TP+FN != 0 ) else 0
            spec_p = TN / (TN + FP) if (TN+FP != 0 ) else 0
            torch.save(model, 'SJTU/'+os.path.basename(data_path).split('.')[0]+str(best_epoch)+'.pt')
        end_time = time.time()
        time_cost = end_time - start_time
        # print(f'\nEpoch [{epoch+1}/{num_epochs}], time cost: {time_cost:.4f},\n Train Cross Loss: {train_loss_cross:.4f}, Train Contras Loss: {train_loss_cont:.4f} Train acc: {train_acc: .4f},\n Val Cross loss: {val_loss_cross:.4f}, Val Contras loss: {val_loss_cont:.4f}, Test accuracy: {test_accuracy:.4f}, Sensitivity: {sens_p_epoch:.4f}, Specificity: {spec_p_epoch:.4f}, Test acc vote: {test_acc_vote:.4f}, Sens vote: {sens_p_epoch_vote:.4f}, Spec vote: {spec_p_epoch_vote:.4f}')
        training_info = f'\nFold [{i+1}/10] Epoch [{epoch+1}/{num_epochs}], time cost: {time_cost:.4f},\n Train Cross Loss: {train_loss_cross:.4f}, Train Contras Loss: {train_loss_cont:.4f} Train acc: {train_acc: .4f},\n Val Cross loss: {val_loss_cross:.4f}, Val Contras loss: {val_loss_cont:.4f}, Val accuracy: {test_accuracy:.4f}, Sensitivity: {sens_p_epoch:.4f}, Specificity: {spec_p_epoch:.4f}'
        print(training_info)
        with open(log_file_name, 'a') as f:
            f.write( training_info+'\n' )
    

    train_summary = ', '.join(COMMON_CH)
    train_summary = train_summary + f'\n***********")\nBest epoch: {best_epoch}, Acc: {best_test_accuracy}, Sensitivity for pre-ictal: {sens_p}, Specificity: {spec_p}\n***********\n'
    with open(log_file_name, 'a') as f:
        f.write( train_summary )
    print(COMMON_CH)
    # print("***********\nSensitivity for pre-ictal:", sens_p, ", sensitivity for inter:", sens_inter)
    print("\n***********")
    print("Best epoch:", best_epoch, "Acc:", best_test_accuracy, "Sensitivity for pre-ictal:", sens_p, "Specificity:", spec_p)
    print("***********\n")


    model_for_test = torch.load( 'SJTU/'+os.path.basename(data_path).split('.')[0]+str(best_epoch)+'.pt' )
    model_for_test.eval()
    test_loss = 0.0
    correct = 0
    # total = 0
    TP = 0
    FN = 0
    TN = 0
    FP = 0
    predictions = []
    outputs_list = []
    with torch.inference_mode():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device).float()
    
            # print( inputs.shape, targets.shape )
            
            outputs, feats = model_for_test(inputs)
    
            # print( outputs.shape, targets )
            # break
            
            predicted = torch.argmax(outputs, dim=1) #(outputs >= 0.5).float()
            loss = criterion(outputs, targets)
            targets = torch.argmax(targets, dim=1)
            test_loss += loss.item()
            # total += targets.size(0)
            positive_labels = ( targets == 1 )
            correct += (predicted == targets).sum().item()
            TP += ( positive_labels * ( predicted == targets ) ).sum().item()
            FN += ( positive_labels * ( predicted != targets ) ).sum().item()
            # functional.reset_net( model )
    
            predictions.extend(predicted.cpu())
            outputs_list.extend(outputs.cpu())
            
            # print(predicted.shape, targets.shape)
            
            # break
    predictions = np.array( predictions)
    TP = np.sum((predictions == 1) & (y_test == 1))  # True Positives
    TN = np.sum((predictions == 0) & (y_test == 0))  # True Negatives
    FP = np.sum((predictions == 1) & (y_test == 0))  # False Positives
    FN = np.sum((predictions == 0) & (y_test == 1))  # False Negatives
        
    
    test_loss /= len(test_loader)
    test_accuracy = correct / len(test_dataset)
    test_sensitivity = TP / (TP + FN) if (TP+FN != 0 ) else 0
    test_specificity = TN / (TN + FP) if (TN+FP != 0 ) else 0

    test_info = f'Test Loss: {test_loss:.4f}, Test accuracy: {test_accuracy:.4f}, Sensitivity: {test_sensitivity:.4f}, Specificity: {test_specificity:.4f}'
    with open(log_file_name, 'a') as f:
        f.write( test_info )
    print(test_info)

    test_acc.append( test_accuracy )
    test_sens.append( test_sensitivity )
    test_spec.append( test_specificity )

    # break # only use 80-20 split for channel experiments

avg_test_info = f'\n\nAvg test acc: {sum(test_acc)/len(test_acc):.4f}, avg test sensitivity: {sum(test_sens)/len(test_sens):.4f}, avg test specificity: {sum(test_spec)/len(test_spec):.4f} \n'
with open(log_file_name, 'a') as f:
    f.write( avg_test_info )
print( avg_test_info )
# print( "Avg test acc:", sum(test_acc)/len(test_acc), "avg test sensitivity:", sum(test_sens)/len(test_sens), "avg test specificity", sum(test_spec)/len(test_spec) )

SJTU/Dingdw_sel_ch_30iter_with_SMOTE.json
step_for_overlap: 119
7879544 1843208

Fold [1/10] Epoch [1/30], time cost: 15.8769,
 Train Cross Loss: 0.2893, Train Contras Loss: 1.5946 Train acc:  0.8910,
 Val Cross loss: 0.1844, Val Contras loss: 1.0789, Val accuracy: 0.9447, Sensitivity: 0.9546, Specificity: 0.9347

Fold [1/10] Epoch [2/30], time cost: 14.0364,
 Train Cross Loss: 0.1565, Train Contras Loss: 0.9023 Train acc:  0.9481,
 Val Cross loss: 0.1253, Val Contras loss: 0.6947, Val accuracy: 0.9593, Sensitivity: 0.9739, Specificity: 0.9445

Fold [1/10] Epoch [3/30], time cost: 13.8811,
 Train Cross Loss: 0.1332, Train Contras Loss: 0.7923 Train acc:  0.9537,
 Val Cross loss: 0.1358, Val Contras loss: 0.8148, Val accuracy: 0.9530, Sensitivity: 0.9818, Specificity: 0.9238

Fold [1/10] Epoch [4/30], time cost: 14.5079,
 Train Cross Loss: 0.1174, Train Contras Loss: 0.6930 Train acc:  0.9589,
 Val Cross loss: 0.1039, Val Contras loss: 0.6260, Val accuracy: 0.9587, Sensitivity: 0.9843, 