In [1]:
import os
import sys
import random
import mne
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from pytorch_metric_learning import losses
from torchsummary import summary

from sklearn.preprocessing import LabelEncoder, OneHotEncoder, LabelBinarizer, normalize
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import KFold, StratifiedKFold
from scipy.stats import mode

import pandas as pd
import logging
import time
import math
import json
from collections import Counter

from spikingjelly.activation_based import neuron, functional, surrogate, layer

from einops import rearrange, repeat, einsum

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = "cpu"

In [2]:
# https://blog.csdn.net/cskywit/article/details/137448871
# https://github.com/johnma2006/mamba-minimal/blob/master/model.py

class MambaBlock(nn.Module):
    def __init__(self, input_channels):
        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
        super().__init__()

        self.d_model = input_channels
        self.d_inner = self.d_model * 2
        self.dt_rank = math.ceil(self.d_model / 16)
        self.d_state = 16
            
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2)#, bias=args.bias)

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            # bias=args.conv_bias,
            kernel_size=3,
            groups=self.d_inner,
            padding=2,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
        
        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        A = repeat(torch.arange(1, self.d_state + 1), 'n -> d n', d=self.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, self.d_model) #, bias=args.bias)
        

    def forward(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].
    
        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
    
        Returns:
            output: shape (b, l, d)
        
        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
            
        """
        (b, l, d) = x.shape
        
        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)

        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, 'b d_in l -> b l d_in')
        
        x = F.silu(x)

        y = self.ssm(x)
        
        y = y * F.silu(res)
        
        output = self.out_proj(y)

        return output

    
    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
    
        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
            
        """
        (d_in, n) = self.A_log.shape

        # Compute ∆ A B C D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)
        
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)
        
        (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
        
        y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
        
        return y

    
    def selective_scan(self, u, delta, A, B, C, D):
        """Does selective scan algorithm. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        This is the classic discrete state space formula:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
        except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).
    
        Args:
            u: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)
    
        Returns:
            output: shape (b, l, d_in)
    
        Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly.
            
        """
        (b, l, d_in) = u.shape
        n = A.shape[1]
        
        # Discretize continuous parameters (A, B)
        # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
        # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
        #   "A is the more important term and the performance doesn't change much with the simplification on B"
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
        
        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
        # Note that the below is sequential, while the official implementation does a much faster parallel scan that
        # is additionally hardware-aware (like FlashAttention).
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []    
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            ys.append(y)
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)
        
        y = y + u * D
    
        return y

class RMSNorm(nn.Module):
    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))


    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return output

In [3]:
def compute_overlap(interval1, interval2):
    # Determine if there is an intersection
    if interval1[1] < interval2[0] or interval2[1] < interval1[0]:
        return 0
    
    # compute the start and the end of the overlap
    start = max(interval1[0], interval2[0])
    end = min(interval1[1], interval2[1])
    
    # compute the size of overlap
    overlap = end - start + 1
    
    return overlap

def check_overlap(interval1_np_list, interval2):
    # Determin if there is an intersection of length 0.1 s
    flag = True
    for i in range( len(interval1_np_list[0]) ):
        if compute_overlap([interval1_np_list[0][i], interval1_np_list[1][i]], interval2) >= NEW_SAMP_RATE/10:
            return i

    return -1



def get_names( ch_num, l ):
    name_list = []
    for i in range(ch_num):
        name_list.append( l[i]['name'][0] )
        # print(l[i]['name'][0])
    return name_list


def top_n_elements(lists, n=6):
   
    overall_counter = Counter()
    
 
    for sublist in lists:
        overall_counter.update(get_names(n, sublist))

 
    top_elements_with_counts = overall_counter.most_common(n)

    top_elements = [element for element, count in top_elements_with_counts]
    counts = [count for element, count in top_elements_with_counts]
    
    return top_elements, counts


(4,)

In [5]:
class ListMerger:
    def __init__(self, list1, list2):
        self.list1 = list1
        self.list2 = list2

    def __getitem__(self, index):
        if isinstance(index, slice):
            # 处理切片
            return [self[i] for i in range(*index.indices(len(self)))]
        elif index < len(self.list1):
            return self.list1[index]
        else:
            return self.list2[index - len(self.list1)]

    def __len__(self):
        # 返回合并后的列表的总长度
        return len(self.list1) + len(self.list2)

    def __iter__(self):
        return iter(self.list1 + self.list2)

class MergedDataset(Dataset):
    def __init__(self, merger, labels, indices):
        """
        :param merger: ListMerger 实例，包含合并后的数据
        :param indices: 要使用的索引列表
        """
        self.merger = merger
        self.labels = labels
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # 获取索引
        index = self.indices[idx]
        # 返回合并后的特征和标签
        return self.merger[index], self.labels[index]

def do_overlap( original_seq, seg_len, overlap_len ):
    '''
    Input:
        original_seq: 2d tensor, [ch, seq_len]
        overlap_len: float, overlapping length of 2 successive segments
    Outputs:
        res: a list of tensors
    '''
    res = []

    for i in range( 0, original_seq.shape[1], seg_len-overlap_len ):
        if original_seq.shape[1] - i < seg_len:
            break
        res.append( original_seq[:, i:(i+seg_len)] )

    return res

def read_edf_data_separately( data_folder, patient_to_use_list, COMMON_CH, SEG_TIME ):
    # These files have no common channels with other files
    FILE_EXCLUDED = ['chb12_27.edf', 'chb12_28.edf', 'chb12_29.edf'] 

    SAMPLE_RATE = 256
    DOWN_SAMP_RATE = 1
    NEW_SAMP_RATE = int(SAMPLE_RATE / DOWN_SAMP_RATE)
    
    SEG_LEN = NEW_SAMP_RATE * SEG_TIME
    
    # STEP_FOR_OVERLAP = int(NEW_SAMP_RATE * (SEG_TIME - OVERLAP_TIME))
    # STEP_NO_OVERLAP  = SEG_LEN
    
    SPH = 30 * 60 * NEW_SAMP_RATE# Seizure Prediction Horizon (SPH)
    # PIL = (15+SPH) * 60 * NEW_SAMP_RATE# Pre-ictal interval (PIL)
    # POST_ICTAL_LEN = 15 * 60 * NEW_SAMP_RATE
    
    raws = []
    file_names = []
    labels = []
    sec_seizure = {}

    seizure_cnt = 0
    
    for patient_to_use in patient_to_use_list:
        patient_path_folder = os.path.join(data_folder, patient_to_use)
        # extract seizure information from summary
        summary_path = os.path.join(patient_path_folder, patient_to_use+"-summary.txt")
        with open(summary_path, 'r') as file:
            content = file.read()
        sections = content.split('\n\n')
        for section in sections:
            if ("Seizure Start Time" in section) or ("Seizure 1 Start Time" in section):
                lines = section.split('\n')
                temp_name = ''
                for i, line in enumerate(lines):
                    if "File Name:" in line:
                        temp_name = line.replace('File Name: ', '')
                        sec_seizure[temp_name] = {}
                        sec_seizure[temp_name]['start'] = []
                        sec_seizure[temp_name]['end'] = []
                    elif ("Seizure" in line) and ("Start Time" in line):
                        parts = line.split(":")
                        sec_seizure[temp_name]['start'].append( int(''.join(filter(str.isdigit, parts[1]))) )
                        parts_end = lines[i+1].split(":")
                        sec_seizure[temp_name]['end'].append( int(''.join(filter(str.isdigit, parts_end[1]))) )
                        seizure_cnt += 1
                        
                        if seizure_cnt == SEIZURE_FOR_TEST:
                            sec_seizure[temp_name]['for_test'] = len(sec_seizure[temp_name]['start'])-1

                        # print(sec_seizure[temp_name])
        # read edf data
        for file in os.listdir(patient_path_folder):
            if file.endswith('.edf') and file not in FILE_EXCLUDED:
                file_path = os.path.join(patient_path_folder, file)
                raw = mne.io.read_raw_edf(file_path)
                raws.append(raw.pick(COMMON_CH))
                file_names.append(file)
                if file in sec_seizure:
                    labels.append({**sec_seizure[file], 'name':file})
                else:
                    labels.append('')
            
    # split the data into 15 sec segments
    segments_inter = []
    segments_pre   = []
    # seg_labels = []
    
    one_patient_segs_inter = []
    one_patient_segs_pre = []
    # one_patient_labels = []
    total_len_inter = 0
    total_len_pre = 0
    
    for index, (raw, label, file_name) in enumerate(zip(raws, labels, file_names)):
        # raw.load_data()
        # raw.filter(1, 45, fir_design='firwin')
        data = raw.get_data()#[:, ::DOWN_SAMP_RATE]
        data = data * 1e4

        if label == '':
            one_patient_segs_inter.append( torch.tensor(data, dtype=torch.float32) ) 
            total_len_inter += data.shape[1]
        else:
            # print(label)
            # split data with 4 sec overlapping
            seizure_start = np.array( label['start'] ) * NEW_SAMP_RATE
            seizure_end   = np.array( label['end'] ) * NEW_SAMP_RATE
            # print( seizure_start-SPH, seizure_start)

            last_end = -1
            for sez_i in range( len( seizure_start ) ):
                beg = seizure_start[sez_i] - SPH
                end = seizure_start[sez_i]
                # print( beg, end )

                # If the gap between the start of seizure and start of the file is less than SPH,
                # you should manually modify the summary file first. Add the seizure start and end
                # time of current file to the last file but add the length of the last file and the 
                # gap between the end of the last file and the start of the current file
                if beg < 0:
                    # temp_data = one_patient_segs_inter.pop().numpy()
                    # total_len_inter -= temp_data.shape[1]
                    
                    # one_patient_segs_pre.append( torch.tensor(temp_data[:, beg:], dtype=torch.float32) )
                    # total_len_pre += one_patient_segs_pre[-1].shape[1]
                    # print( one_patient_segs_pre[-1].shape[1], SEG_LEN )
                    one_patient_segs_pre.append( torch.tensor(data[:, 0:seizure_start[sez_i]], dtype=torch.float32) )
                    total_len_pre += seizure_start[sez_i]

                    # one_patient_segs_inter.append( torch.tensor(temp_data[:, 0:-SEG_LEN], dtype=torch.float32) )
                    # total_len_inter += one_patient_segs_inter[-1].shape[1]
                else:
                    one_patient_segs_pre.append( torch.tensor(data[:, beg:end+1], dtype=torch.float32) )
                    total_len_pre += one_patient_segs_pre[-1].shape[1]
                    one_patient_segs_inter.append( torch.tensor(data[:, (last_end+1) : beg], dtype=torch.float32) )
                    total_len_inter += one_patient_segs_inter[-1].shape[1]
                
                last_end = seizure_end[sez_i]

            one_patient_segs_inter.append( torch.tensor(data[:, (last_end+1):], dtype=torch.float32) ) 
            total_len_inter += one_patient_segs_inter[-1].shape[1]

        # Finish one patient or finish the last file
        if (index != 0) and (file_name[0:5] != file_names[index-1][0:5]):
            # segments.append( one_patient_segs )
            # seg_labels.append( one_patient_labels )
            segments_inter.extend( one_patient_segs_inter )
            segments_pre.extend( one_patient_segs_pre )
            # seg_labels.extend( one_patient_labels )
            one_patient_segs_inter = []
            one_patient_segs_pre = []
            # one_patient_labels = []
            # print(index, file_name, file_names[index-1][0:5])
        if index == len(raws)-1:
            # segments.append( one_patient_segs )
            # seg_labels.append( one_patient_labels )
            segments_inter.extend( one_patient_segs_inter )
            segments_pre.extend( one_patient_segs_pre )
            # seg_labels.extend( one_patient_labels )
            one_patient_segs_inter = []
            one_patient_segs_pre = []
            # one_patient_labels = []
            # print(index, file_name, file_names[index-1][0:5])

    segments_pre_overlap = []
    segments_inter_no_overlap = []
    step_for_overlap = int( SEG_LEN / ( total_len_inter / total_len_pre ) )
    print("step_for_overlap:", step_for_overlap)
    print(total_len_inter, total_len_pre)
    while segments_inter or segments_pre:
        if segments_pre:
            # segments_pre_overlap.append( do_overlap( seg_pre, SEG_LEN, OVERLAP_TIME*NEW_SAMP_RATE ) )
            segments_pre_overlap.extend( do_overlap( segments_pre[0], SEG_LEN, SEG_LEN-step_for_overlap) ) #OVERLAP_TIME*NEW_SAMP_RATE ) )
            segments_pre.pop(0)
            # print( len(segments_pre_overlap) )
        if segments_inter:
            # segments_inter_no_overlap.append( do_overlap( segments_inter[0], SEG_LEN, OVERLAP_TIME*NEW_SAMP_RATE ) )
            segments_inter_no_overlap.extend( do_overlap( segments_inter[0], SEG_LEN, 0 ) )
            segments_inter.pop(0)
            # print( len(segments_inter_no_overlap) )
    
    
    return segments_inter_no_overlap, segments_pre_overlap
    

In [7]:
## https://www.kaggle.com/code/debarshichanda/pytorch-supervised-contrastive-learning
class SupervisedContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.1):
        super(SupervisedContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, feature_vectors, labels):
        # Normalize feature vectors
        feature_vectors_normalized = F.normalize(feature_vectors, p=2, dim=1)
        # Compute logits
        logits = torch.div(
            torch.matmul(
                feature_vectors_normalized, torch.transpose(feature_vectors_normalized, 0, 1)
            ),
            self.temperature,
        )
        # print( logits.shape, labels.shape )
        return losses.NTXentLoss(temperature=self.temperature)(logits, labels)
# super_contras = SupervisedContrastiveLoss(0.1)
# super_contras( x, label )

In [8]:


def count_label_np(one_hot_data, label):
    labels = np.argmax(one_hot_data, axis=1)  
    count = np.sum(labels == label)  
    return count  

def calculate_accuracy_np(y_true, y_pred):
    labels_true = y_true #np.argmax(y_true, axis=1)  
    labels_pred = y_pred #np.argmax(y_pred, axis=1)  
    accuracy = np.sum(labels_true.T == labels_pred) / len(labels_true)  
    return accuracy

def calculate_sensitivity_np(y_true, y_pred, class_index):
    # labels_true = np.argmax(y_true, axis=1)  
    # labels_pred = np.argmax(y_pred, axis=1)  
    labels_true = y_true #np.argmax(y_true, axis=1)  
    labels_pred = y_pred #np.argmax(y_pred, axis=1)  
    positive_labels = (labels_true == class_index)  
    true_positive = np.sum((labels_pred == labels_true) * positive_labels)  
    false_negative = np.sum((labels_pred != labels_true) * positive_labels)  
    if (true_positive + false_negative) == 0:
        return -1
    sensitivity = true_positive / (true_positive + false_negative)  
    return sensitivity

In [12]:
# 1 dimentional version + Mamba
## Patient-Specific Seizure Prediction via Adder Network and Supervised Contrastive Learning
## ADDNet-SCL-1d

class OneDCNN(nn.Module):
    def __init__(self, input_channels=3):
        super(OneDCNN, self).__init__()
        self.input_channels = input_channels
      
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels=input_channels, out_channels=16, kernel_size=21, stride=1, padding=10),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=8, stride=8)
        )

        self.conv2_1 = nn.Sequential(
            nn.Conv1d(in_channels=16, out_channels=16, kernel_size=1, stride=1),
            nn.ReLU()
        )
        self.conv2_2 = nn.Sequential(
            nn.Conv1d(in_channels=16, out_channels=16, kernel_size=11, stride=1, padding=5),
            nn.ReLU(),
            nn.Conv1d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.pool3 = nn.MaxPool1d(kernel_size=4, stride=4)

        self.conv4_1 = nn.Sequential(
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=1, stride=2),
            nn.ReLU()
        )
        self.conv4_2 = nn.Sequential(
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        # self.conv5_1 = nn.Sequential(
        #     nn.Conv1d(in_channels=32, out_channels=32, kernel_size=1, stride=2),
        #     nn.ReLU()
        # )
        # self.conv5_2 = nn.Sequential(
        #     nn.Conv1d(in_channels=32, out_channels=32, kernel_size=5, stride=2, padding=2),
        #     nn.ReLU(),
        #     nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
        #     nn.ReLU()
        # )

        self.mixer = MambaBlock(32)
        self.norm = RMSNorm(32)

        self.adaptive_avg_pool = nn.AdaptiveAvgPool1d(output_size=1)
        
        self.classifier = nn.Sequential(
            nn.Linear(32, 2)
        )

    def forward(self, x):
        # x = x.unsqueeze(1)
        # print(x.shape)
        
        x = self.conv1(x)

        x = self.pool3( self.conv2_1(x)+self.conv2_2(x) )

        x = self.conv4_1(x)+self.conv4_2(x)

        # x = self.conv5_1(x)+self.conv5_2(x)

        # x = [batch, channels, seq_len]
        x = x.permute(0, 2, 1) # [batch, seq_len, channels]
        x = self.mixer(self.norm(x)) + x
        x = x.permute(0, 2, 1) # [batch, channels, seq_len]

        x = self.adaptive_avg_pool( x )
        
        x_digits = x.contiguous().view(x.size(0), -1)
        # print(f'x.shape = {x.shape}')
        x_res = self.classifier(x_digits)
        return x_res, x_digits


In [18]:
patient_index_list = [5, 6, 8, 12, 13, 15, 16, 19, 1, 2, 3]
# patient_index_list = [1, 2, 3]
for patient_index in range(7, 24+1):
    if patient_index in patient_index_list:
        continue
    # 设置数据文件夹路径
    data_folder = 'E:\\EEG\\chb-mit-scalp-eeg-database-1.0.0'
    # patients_folder = []
    # for i in range(1, 25):
    #     if i < 10:
    #         patients_folder.append("chb0"+str(i))
    #     else:
    #         patients_folder.append("chb"+str(i))
    # os.path.join(data_folder, patients_folder[0])
    
    # NUM_OF_PATIENTS_USED = 1000 # Use all data
    PATIENT_TO_USE = ['chb'+str(patient_index).zfill(2)] #06' #('chb0' + str(patient_index)) if patient_index < 10 else 'chb'+str(patient_index) #chb04'
    print(PATIENT_TO_USE)

    with open(PATIENT_TO_USE[0]+'_sel_ch_30iter_with_SMOTE.json', 'r') as json_file:
        loaded_data = json.load(json_file)
    
    COMMON_CH, _ = top_n_elements( loaded_data, 6 ) # 8 channels are acceptable

    log_file_name = 'chb'+str(patient_index).zfill(2)+'_'+str(len(COMMON_CH))+'chs_log_with_mamba_10fold.txt'
    with open(log_file_name, "w") as f:
        f.write("AddNet-SCL With Mamba 1D - 10-fold\n" + str(len(COMMON_CH)) + " channels\n" )
    
    SAMPLE_RATE = 256
    DOWN_SAMP_RATE = 1
    NEW_SAMP_RATE = int(SAMPLE_RATE / DOWN_SAMP_RATE)
    
    # split the raw data to 5 sec segments
    SEG_TIME = 4
    SEG_LEN = NEW_SAMP_RATE * SEG_TIME

    
    # STEP_FOR_OVERLAP = int(NEW_SAMP_RATE * (SEG_TIME - OVERLAP_TIME))
    # STEP_NO_OVERLAP  = SEG_LEN
    
    # SPH = 15 * 60 * NEW_SAMP_RATE# Seizure Prediction Horizon (SPH)
    # PIL = (15+SPH) * 60 * NEW_SAMP_RATE# Pre-ictal interval (PIL)
    # POST_ICTAL_LEN = 15 * 60 * NEW_SAMP_RATE
    
    # # Different patients have different channels, so pick the commont channels
    # COMMON_CH_TOTAL = ['P4-O2', 'T7-P7', 'C4-P4', 'FZ-CZ', 'F3-C3', 'P7-O1', 'CZ-PZ', 'P8-O2', 'F4-C4', 'P3-O1', 'FP2-F8', 'F8-T8', 'FP1-F3', 'F7-T7', 'C3-P3', 'FP1-F7', 'FP2-F4']
    
    
    
    # These files have no common channels with other files
    FILE_EXCLUDED = ['chb12_27.edf', 'chb12_28.edf', 'chb12_29.edf'] 
    
    logger = logging.getLogger('mne')
    # logger.setLevel(logging.WARNING)
    logger.setLevel(logging.ERROR) # 每读一个edf文件都会有一个通道名重复的警告
    
    
    SEIZURE_FOR_TEST = 4
    # 导入数据
    raws = []
    file_names = []
    labels = []
    sec_seizure = {}
    
    seizure_cnt = 0
        
    segments_inter, segments_pre = read_edf_data_separately( data_folder, PATIENT_TO_USE, COMMON_CH, SEG_TIME )
    
    segments = ListMerger( segments_inter, segments_pre )
    seg_labels = [0] * len(segments_inter) + [1] * len(segments_pre)
    seg_labels = torch.nn.functional.one_hot(torch.tensor(seg_labels), num_classes=2)
    
    
    kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
    test_acc = []
    test_sens = []
    test_spec = []
    for i, (train_index, test_index) in enumerate(kf.split(segments, seg_labels.argmax(axis=1))):
        # 获取训练和测试数据，都是np.array
        X_train_indices = train_index
        X_test_indices = test_index
    
        # 从训练集中划分出验证集
        val_size = int(len(X_train_indices) * 0.2)  # 20% 用作验证集
        val_indices = np.random.choice(X_train_indices, size=val_size, replace=False)
        train_indices = np.setdiff1d(X_train_indices, val_indices)
    
        # 创建数据集
        train_dataset = MergedDataset(segments, seg_labels, train_indices)
        val_dataset = MergedDataset(segments, seg_labels, val_indices)
        test_dataset = MergedDataset(segments, seg_labels, X_test_indices)
    
        # 创建数据加载器
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)#, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)#, pin_memory=True)
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)#, pin_memory=True)
        # break
    
        best_test_accuracy = 0
        best_epoch = -1
        sens_p = 0
        spec_p = 0
        
        
        model = OneDCNN(len(COMMON_CH)).to(device)
        if patient_index == 1:
            summary(model, input_size=(len(COMMON_CH), 1024))
        # criterion = nn.MSELoss().to(device)
        # criterion = nn.BCELoss().to(device)
        criterion = nn.CrossEntropyLoss().to(device)
        super_contras = SupervisedContrastiveLoss(0.08).to(device)
        # criterion = WeightedMSELoss(4).to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
        
        num_epochs = 30
        # VOTE_NUM = 5
        y_test = np.array(seg_labels[X_test_indices])
        y_test = np.argmax( y_test, axis=1)
        y_val = np.array(seg_labels[val_indices])
        y_val = np.argmax( y_val, axis=1)
        # y_test_flat = y_test[VOTE_NUM - 1:]
        for epoch in range(num_epochs):
            start_time = time.time()
            # 训练阶段
            model.train()
            
            train_loss = 0.0
            train_loss_cross = 0.0
            train_loss_cont  = 0.0
            
            correct = 0
            for interation, (inputs, targets) in enumerate(train_loader):
                # print(interation)
                
                inputs = inputs.to(device)
                targets = targets.to(device).float()
                outputs, feats = model(inputs)
                loss_cross = criterion(outputs, targets)
                tars = targets.argmax(dim=1)
                # print(tars.shape)
                loss_cont = super_contras( feats, tars )
                loss = 0.5 * loss_cross + 0.5 * loss_cont
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                # functional.reset_net( model )
        
                # print(outputs.shape, targets.shape)
                predicted = torch.argmax(outputs, dim=1) #(outputs > 0.5).float()
                targets = torch.argmax(targets, dim=1) #(targets>0.5).to(device)
                correct += (predicted == targets).sum().item()
        
                train_loss += loss.item()
                train_loss_cross += loss_cross.item()
                train_loss_cont  += loss_cont.item()
                   
            train_loss /= len(train_loader)
            train_loss_cross /= len(train_loader)
            train_loss_cont /= len(train_loader)
            train_acc = correct / len(train_dataset)
        
        
            # 验证阶段
            model.eval()
            
            val_loss = 0.0
            val_loss_cross = 0.0
            val_loss_cont  = 0.0
            
            correct = 0
            total = 0
            TP = 0
            FN = 0
            TN = 0
            FP = 0
            predictions = []
            # targets = []
            with torch.inference_mode():
                for inputs, targets in val_loader:
                    inputs = inputs.to(device)
                    targets = targets.to(device).float()
                    outputs, feats = model(inputs)
                    predicted = torch.argmax(outputs, dim=1) #(outputs > 0.5).float()
                    loss_cross = criterion(outputs, targets)
                    tars = targets.argmax(dim=1)
                    loss_cont = super_contras( feats, tars )
                    loss = loss_cross + loss_cont
                    
                    val_loss += loss.item()
                    val_loss_cross += loss_cross.item()
                    val_loss_cont += loss_cont.item()
                    
                    predictions.extend( predicted.cpu() )
                    
                    # functional.reset_net( model )
        
                predictions = np.array( predictions )
                # final_predictions_flat = np.array( [ mode(predictions[ i-(VOTE_NUM-1) : i+1 ]).mode for i in range(VOTE_NUM-1, len(predictions)) ] )
                
                # 计算混淆矩阵的四个基本元素
                TP = np.sum((predictions == 1) & (y_val == 1))  # True Positives
                TN = np.sum((predictions == 0) & (y_val == 0))  # True Negatives
                FP = np.sum((predictions == 1) & (y_val == 0))  # False Positives
                FN = np.sum((predictions == 0) & (y_val == 1))  # False Negatives
                
                    
                    
            val_loss /= len(val_loader)
            val_loss_cross /= len(val_loader)
            val_loss_cont  /= len(val_loader)
            
            test_accuracy = (TP + TN) / (TP + TN + FP + FN)
            sens_p_epoch = TP / (TP + FN) if (TP+FN != 0 ) else 0
            spec_p_epoch = TN / (TN + FP) if (TN+FP != 0 ) else 0
        
            
            if test_accuracy > best_test_accuracy: # sens_p_epoch > sens_p:  
                best_test_accuracy = test_accuracy
                best_epoch = epoch
                sens_p = TP / (TP + FN) if (TP+FN != 0 ) else 0
                spec_p = TN / (TN + FP) if (TN+FP != 0 ) else 0
                torch.save(model, PATIENT_TO_USE[0]+str(best_epoch)+'.pt')
            end_time = time.time()
            time_cost = end_time - start_time
            # print(f'\nEpoch [{epoch+1}/{num_epochs}], time cost: {time_cost:.4f},\n Train Cross Loss: {train_loss_cross:.4f}, Train Contras Loss: {train_loss_cont:.4f} Train acc: {train_acc: .4f},\n Val Cross loss: {val_loss_cross:.4f}, Val Contras loss: {val_loss_cont:.4f}, Test accuracy: {test_accuracy:.4f}, Sensitivity: {sens_p_epoch:.4f}, Specificity: {spec_p_epoch:.4f}, Test acc vote: {test_acc_vote:.4f}, Sens vote: {sens_p_epoch_vote:.4f}, Spec vote: {spec_p_epoch_vote:.4f}')
            training_info = f'\nFold [{i+1}/10] Epoch [{epoch+1}/{num_epochs}], time cost: {time_cost:.4f},\n Train Cross Loss: {train_loss_cross:.4f}, Train Contras Loss: {train_loss_cont:.4f} Train acc: {train_acc: .4f},\n Val Cross loss: {val_loss_cross:.4f}, Val Contras loss: {val_loss_cont:.4f}, Val accuracy: {test_accuracy:.4f}, Sensitivity: {sens_p_epoch:.4f}, Specificity: {spec_p_epoch:.4f}'
            print(training_info)
            with open(log_file_name, 'a') as f:
                f.write( training_info+'\n' )
        

        train_summary = ', '.join(COMMON_CH)
        train_summary = train_summary + f'\n***********")\nBest epoch: {best_epoch}, Acc: {best_test_accuracy}, Sensitivity for pre-ictal: {sens_p}, Specificity: {spec_p}\n***********\n'
        with open(log_file_name, 'a') as f:
            f.write( train_summary )
        print(COMMON_CH)
        # print("***********\nSensitivity for pre-ictal:", sens_p, ", sensitivity for inter:", sens_inter)
        print("\n***********")
        print("Best epoch:", best_epoch, "Acc:", best_test_accuracy, "Sensitivity for pre-ictal:", sens_p, "Specificity:", spec_p)
        print("***********\n")
    
    
        model_for_test = torch.load( PATIENT_TO_USE[0]+str(best_epoch)+'.pt' )
        model_for_test.eval()
        test_loss = 0.0
        correct = 0
        # total = 0
        TP = 0
        FN = 0
        TN = 0
        FP = 0
        predictions = []
        outputs_list = []
        with torch.inference_mode():
            for inputs, targets in test_loader:
                inputs = inputs.to(device)
                targets = targets.to(device).float()
        
                # print( inputs.shape, targets.shape )
                
                outputs, feats = model_for_test(inputs)
        
                # print( outputs.shape, targets )
                # break
                
                predicted = torch.argmax(outputs, dim=1) #(outputs >= 0.5).float()
                loss = criterion(outputs, targets)
                targets = torch.argmax(targets, dim=1)
                test_loss += loss.item()
                # total += targets.size(0)
                positive_labels = ( targets == 1 )
                correct += (predicted == targets).sum().item()
                TP += ( positive_labels * ( predicted == targets ) ).sum().item()
                FN += ( positive_labels * ( predicted != targets ) ).sum().item()
                # functional.reset_net( model )
        
                predictions.extend(predicted.cpu())
                outputs_list.extend(outputs.cpu())
                
                # print(predicted.shape, targets.shape)
                
                # break
        predictions = np.array( predictions)
        TP = np.sum((predictions == 1) & (y_test == 1))  # True Positives
        TN = np.sum((predictions == 0) & (y_test == 0))  # True Negatives
        FP = np.sum((predictions == 1) & (y_test == 0))  # False Positives
        FN = np.sum((predictions == 0) & (y_test == 1))  # False Negatives
            
        
        test_loss /= len(test_loader)
        test_accuracy = correct / len(test_dataset)
        test_sensitivity = TP / (TP + FN) if (TP+FN != 0 ) else 0
        test_specificity = TN / (TN + FP) if (TN+FP != 0 ) else 0

        test_info = f'Test Loss: {test_loss:.4f}, Test accuracy: {test_accuracy:.4f}, Sensitivity: {test_sensitivity:.4f}, Specificity: {test_specificity:.4f}'
        with open(log_file_name, 'a') as f:
            f.write( test_info )
        print(test_info)
    
        test_acc.append( test_accuracy )
        test_sens.append( test_sensitivity )
        test_spec.append( test_specificity )
    
        # break # only use 80-20 split for channel experiments
    
    avg_test_info = f'\n\nAvg test acc: {sum(test_acc)/len(test_acc):.4f}, avg test sensitivity: {sum(test_sens)/len(test_sens):.4f}, avg test specificity: {sum(test_spec)/len(test_spec):.4f} \n'
    with open(log_file_name, 'a') as f:
        f.write( avg_test_info )
    print( avg_test_info )
    # print( "Avg test acc:", sum(test_acc)/len(test_acc), "avg test sensitivity:", sum(test_sens)/len(test_sens), "avg test specificity", sum(test_spec)/len(test_spec) )

['chb07']
step_for_overlap: 23
60329725 1382403

Fold [1/10] Epoch [1/30], time cost: 75.8651,
 Train Cross Loss: 0.2338, Train Contras Loss: 1.3188 Train acc:  0.9030,
 Val Cross loss: 0.1426, Val Contras loss: 0.8427, Val accuracy: 0.9474, Sensitivity: 0.9853, Specificity: 0.9089

Fold [1/10] Epoch [2/30], time cost: 68.0297,
 Train Cross Loss: 0.1439, Train Contras Loss: 0.8184 Train acc:  0.9465,
 Val Cross loss: 0.1183, Val Contras loss: 0.7026, Val accuracy: 0.9577, Sensitivity: 0.9888, Specificity: 0.9262

Fold [1/10] Epoch [3/30], time cost: 69.4112,
 Train Cross Loss: 0.1233, Train Contras Loss: 0.6780 Train acc:  0.9559,
 Val Cross loss: 0.0939, Val Contras loss: 0.5947, Val accuracy: 0.9680, Sensitivity: 0.9827, Specificity: 0.9531

Fold [1/10] Epoch [4/30], time cost: 68.7489,
 Train Cross Loss: 0.1098, Train Contras Loss: 0.5922 Train acc:  0.9614,
 Val Cross loss: 0.0895, Val Contras loss: 0.5269, Val accuracy: 0.9707, Sensitivity: 0.9723, Specificity: 0.9691

Fold [1/10]

KeyboardInterrupt: 

In [None]:
# for inputs, targets in train_loader:
#     print(inputs)

In [None]:
summary( model, (8, 1024) )

In [None]:
type( segments_pre )

In [None]:
len(segments_pre) / len(segments_inter), test_sens

In [None]:
print( sum(test_acc)/len(test_acc), sum(test_sens)/len(test_sens), sum(test_spec)/len(test_spec) )

In [None]:
model_for_test = torch.load( PATIENT_TO_USE[0]+str(best_epoch)+'.pt' )
model_for_test.eval()
test_loss = 0.0
correct = 0
# total = 0
TP = 0
FN = 0
predictions = []
outputs_list = []
with torch.inference_mode():
    for inputs, targets in test_loader:
        inputs = inputs.to(device)
        targets = targets.to(device).float()

        # print( inputs.shape, targets.shape )
        
        outputs, feats = model_for_test(inputs)

        # print( outputs.shape, targets )
        # break
        
        predicted = torch.argmax(outputs, dim=1) #(outputs >= 0.5).float()
        loss = criterion(outputs, targets)
        targets = torch.argmax(targets, dim=1)
        test_loss += loss.item()
        # total += targets.size(0)
        positive_labels = ( targets == 1 )
        correct += (predicted == targets).sum().item()
        # print( (predicted == targets).sum().item() )
        TP += ( positive_labels * ( predicted == targets ) ).sum().item()
        FN += ( positive_labels * ( predicted != targets ) ).sum().item()
        # functional.reset_net( model )

        predictions.extend(predicted.cpu())
        outputs_list.extend(outputs.cpu())
        
        # print(predicted.shape, targets.shape)
        
        # break
predictions = np.array( predictions)
TP = np.sum((predictions == 1) & (y_test == 1))  # True Positives
TN = np.sum((predictions == 0) & (y_test == 0))  # True Negatives
FP = np.sum((predictions == 1) & (y_test == 0))  # False Positives
FN = np.sum((predictions == 0) & (y_test == 1))  # False Negatives
test_loss /= len(test_loader)
test_accuracy = (TP + TN) / (TP + TN + FP + FN)
test_sensitivity = TP / (TP + FN) if (TP+FN != 0 ) else 0
test_specificity = TN / (TN + FP) if (TN+FP != 0 ) else 0

print(f'Test Loss: {test_loss:.4f}, Test accuracy: {test_accuracy:.4f}, Sensitivity: {test_sensitivity:.4f}, Specificity: {test_specificity:.4f}')

In [None]:
predicted.shape, targets.shape, correct, len(test_loader), len(train_loader)

In [None]:
import itertools

large_list1 = list(range(1000000))

large_list2 = list(range(1000000, 2000000))

merged_list = ListMerger(large_list1, large_list2)

merged_list[1000000-2 : 1000000+2], large_list1[-2:], large_list2[0:2]

In [None]:
# 准备数据
X1 = np.random.rand(10, 3)  # 100 个样本，10 个特征
X2 = np.random.rand(10, 3)  # 另一个特征集
y = np.random.randint(0, 2, size=(100,))  # 100 个标签（0 或 1）

# 创建 ListMerger 实例
X = ListMerger(X1, X2)

# 创建 KFold 实例
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# KFold 交叉验证
for i, (train_index, test_index) in enumerate(kf.split(X)):
    print( type(train_index) )
    print(f"Fold {i}:")
    print(f"  Train: index={train_index}")
    print(f"  Test:  index={test_index}")

In [None]:
seg_labels_for_train.shape, seg_labels_for_test.shape, pre_len_for_test

In [None]:
target_pre_start, target_pre_end, seg_labels[target_pre_start-1:target_pre_end+2]

In [None]:
pre_len_for_test  * 15 / 60

In [None]:
len(seg_labels), index  #_for_test

In [None]:
896/128

In [None]:
functional.reset_net( model )
model.eval()
val_loss = 0.0
correct = 0
total = 0
TP = 0
FN = 0
predictions = []
outputs_list = []
with torch.inference_mode():
    for inputs, targets in val_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # print( inputs.shape, targets.shape )
        
        outputs = model(inputs)

        # print( outputs.shape, targets )
        # break
        
        predicted = (outputs >= 0.5).float()
        loss = criterion(outputs, targets)
        val_loss += loss.item()
        total += targets.size(0)
        positive_labels = ( targets == 1 )
        correct += (predicted == targets).sum().item()
        TP += ( positive_labels * ( predicted == targets ) ).sum().item()
        FN += ( positive_labels * ( predicted != targets ) ).sum().item()
        functional.reset_net( model )

        predictions.extend(predicted.cpu())
        outputs_list.extend(outputs.cpu())
        
        # print(predicted.shape, targets.shape)
        
        # break
predictions = np.array( predictions)
TP = np.sum((predictions == 1) & (y_test == 1))  # True Positives
TN = np.sum((predictions == 0) & (y_test == 0))  # True Negatives
FP = np.sum((predictions == 1) & (y_test == 0))  # False Positives
FN = np.sum((predictions == 0) & (y_test == 1))  # False Negatives
# outputs_list = np.array( outputs_list)
val_loss /= len(val_loader)
test_accuracy = (TP + TN) / (TP + TN + FP + FN)
sens_p_epoch = TP / (TP + FN) if (TP+FN != 0 ) else sens_p
if test_accuracy > best_test_accuracy:
    best_test_accuracy = test_accuracy
    sens_p = TP / (TP + FN) if (TP+FN != 0 ) else sens_p
print(f'Epoch [{epoch+1}/{num_epochs}], Test Loss: {val_loss:.4f}, Test accuracy: {test_accuracy:.4f}, Sensitivity: {sens_p_epoch:.4f}')


In [None]:
# predictions.shape, seg_labels_for_test.shape, outputs_list
print( step_on_both_sides, pre_len_for_test )
corrects_list = predictions == seg_labels_for_test
TPaaa = (predictions == 1) & (y_test == 1) 
TPaaa[120:120+120], outputs_list[60:60+120]
segments.min()
(outputs_list > 0.5).sum() / len(outputs_list)

In [None]:
a = np.ones((1, 10))
b = np.ones((1, 10))
b[0][4] = 0
( a == 1 ) & ( b == 1 )

In [None]:
with torch.inference_mode():
    print(inputs)
    print(model(inputs))

In [None]:
functional.reset_net( model )
model.eval()
val_loss = 0.0
correct = 0
total = 0
TP = 0
FN = 0

out_list = []
with torch.inference_mode():
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)
        out_list.extend(outputs.cpu())

        predicted = (outputs > 0.5).float()
        loss = criterion(outputs, targets)
        val_loss += loss.item()
        total += targets.size(0)
        positive_labels = ( predicted == 1 )
        correct += (predicted == targets).sum().item()
        TP += ( positive_labels * ( predicted == targets ) ).sum().item()
        FN += ( positive_labels * ( predicted != targets ) ).sum().item()
        functional.reset_net( model )

out_list = np.array(out_list)
val_loss /= len(train_loader)
test_accuracy = correct / total
sens_p_epoch = TP / (TP + FN) if (TP+FN != 0 ) else sens_p
if test_accuracy > best_test_accuracy:
    best_test_accuracy = test_accuracy
    sens_p = TP / (TP + FN) if (TP+FN != 0 ) else sens_p
print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {val_loss:.4f}, Train accuracy: {test_accuracy:.4f}, Sensitivity: {sens_p_epoch:.4f}')


In [None]:
(out_list > 0.5).sum() / len(out_list)
seg_labels_for_train

In [None]:
# model = OneDCNN(6).to(device)
# torch.save(model, 'test.pt')

In [None]:
def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}
get_parameter_number(model)

In [None]:
22 * 21 * 20 / 3 /2/1

In [None]:
feature_extract_parameters = 0
for name, parameter in model.named_parameters():
    print(name, parameter.numel())
    if 'features' in name:
        feature_extract_parameters += parameter.numel()
print( feature_extract_parameters )

In [None]:
classifier_parameters = 0
for name, parameter in model.named_parameters():
    print(name, parameter.numel())
    if 'classifier' in name:
        classifier_parameters += parameter.numel()
print( classifier_parameters )

In [None]:
for iii in range(segments.shape[1]):
    segments[:, iii, :] = normalize( segments[:, iii, :].squeeze(), axis=1, norm='max' )
np.max( segments[10, 1, :] ), np.min(segments[10, 1, :]), np.mean(segments[10, 1, :])

In [None]:
seg_labels_for_train.shape, seg_labels_for_test.shape

In [None]:
outputs.shape, targets.unsqueeze(1).shape#, (predicted == targets).sum().item()

In [None]:
predicted.shape, targets.shape

In [None]:
# # for inputs, targets in train_loader:
# #     print(targets.shape)
# #     targets = targets.flatten().to(device)
# #     print(targets.shape)
# #     inputs = inputs.to(device)
# #     outputs = model(inputs)
# #     print("outputs.shape =", outputs.shape)
# #     targets = targets.to(device)
# #     print(targets.unsqueeze(1).shape)
# #     break
# model00 = OneDCNN().to(device)
# ccc = 0
# for inputs, targets in train_loader:
#     inputs = inputs.to(device)
#     print(inputs.shape)
#     targets = targets.to(device)
#     # print(targets.shape)
#     outputs, feature = model(inputs)
#     print(inputs.shape, outputs.shape, targets.shape)
#     print(feature[0, :, :])

#     ccc += 1
#     if ccc >= 2:
#         break
    

In [None]:
np.sum(seg_labels_for_train == 1)# / len(seg_labels_for_train)

In [None]:
# temp = segments_for_train[:, i, :].squeeze()
# temp_nor = normalize(temp, axis=1)#(temp - temp.min(axis=1, keepdims=True)) / ( temp.max(axis=1, keepdims=True) - temp.min(axis=1, keepdims=True) )
# temp_nor.mean(axis=1)

In [None]:
# outputs.shape, targets.unsqueeze(1).shape

In [None]:
# orig_seg_labels[0:5], seg_labels[0:5

In [None]:
np.sum(seg_labels_for_test == 1)# / seg_labels_for_test.shape[0]

In [None]:
VOTE_NUM = 5
y_test = np.array(seg_labels_for_test)
y_test_flat = y_test[VOTE_NUM - 1:]
for epoch in range(50, 100):
    start_time = time.time()
    # 训练阶段
    model.train()
    train_loss = 0.0
    correct = 0
    for interation, (inputs, targets) in enumerate(train_loader):
        # print(interation)
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        functional.reset_net( model )

        # print(outputs.shape, targets.shape)
        predicted = (outputs > 0.5).float()
        targets = (targets>0.5).to(device)
        correct += (predicted == targets).sum().item()

        train_loss += loss.item()
    train_loss /= len(train_loader)
    train_acc = correct / len(seg_labels_for_train)


    # 验证阶段
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    TP = 0
    FN = 0
    TN = 0
    FP = 0
    predictions = []
    # targets = []
    with torch.inference_mode():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            predicted = (outputs > 0.5).float()
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            predictions.extend( predicted.cpu() )
            
            functional.reset_net( model )

        predictions = np.array( predictions )
        final_predictions_flat = np.array( [ mode(predictions[ i-(VOTE_NUM-1) : i+1 ]).mode for i in range(VOTE_NUM-1, len(predictions)) ] )
        
        # 计算混淆矩阵的四个基本元素
        TP = np.sum((predictions == 1) & (y_test == 1))  # True Positives
        TN = np.sum((predictions == 0) & (y_test == 0))  # True Negatives
        FP = np.sum((predictions == 1) & (y_test == 0))  # False Positives
        FN = np.sum((predictions == 0) & (y_test == 1))  # False Negatives
        
        TP_vote = np.sum((final_predictions_flat == 1) & (y_test_flat == 1))  # True Positives
        TN_vote = np.sum((final_predictions_flat == 0) & (y_test_flat == 0))  # True Negatives
        FP_vote = np.sum((final_predictions_flat == 1) & (y_test_flat == 0))  # False Positives
        FN_vote = np.sum((final_predictions_flat == 0) & (y_test_flat == 1))  # False Negatives
            
            
        
            # total += targets.size(0)
            # positive_labels = ( targets == 1 )
            # negative_labels = ( targets == 0 )
            # correct += (predicted == targets).sum().item()
            # TP += ( positive_labels * ( predicted == targets ) ).sum().item()
            # FN += ( positive_labels * ( predicted != targets ) ).sum().item()
            # TN += ( negative_labels * ( predicted == targets ) ).sum().item()
            # FP += ( negative_labels * ( predicted != targets ) ).sum().item()
            
            # print( predicted.shape, targets.shape )
            
            
    val_loss /= len(val_loader)
    test_accuracy = (TP + TN) / (TP + TN + FP + FN)
    sens_p_epoch = TP / (TP + FN) if (TP+FN != 0 ) else 0
    spec_p_epoch = TN / (TN + FP) if (TN+FP != 0 ) else 0

    test_acc_vote = (TP_vote + TN_vote) / (TP_vote + TN_vote + FP_vote + FN_vote)
    sens_p_epoch_vote = TP_vote / (TP_vote + FN_vote) if (TP_vote+FN_vote != 0 ) else 0
    spec_p_epoch_vote = TN_vote / (TN_vote + FP_vote) if (TN_vote+FP_vote != 0 ) else 0
    
    if test_accuracy > best_test_accuracy:
        best_test_accuracy = test_accuracy
        sens_p = TP / (TP + FN) if (TP+FN != 0 ) else 0
        spec_p = TN / (TN + FP) if (TN+FP != 0 ) else 0
    end_time = time.time()
    time_cost = end_time - start_time
    print(f'\nEpoch [{epoch+1}/{num_epochs}], time cost: {time_cost:.4f}, Train Loss: {train_loss:.4f}, Train acc: {train_acc: .4f}, Val loss: {val_loss:.4f}, Test accuracy: {test_accuracy:.4f}, Sensitivity: {sens_p_epoch:.4f}, Specificity: {spec_p_epoch:.4f}, Test acc vote: {test_acc_vote:.4f}, Sens vote: {sens_p_epoch_vote:.4f}, Spec vote: {spec_p_epoch_vote:.4f}')



# clf = DecisionTreeClassifier(random_state=0)
# clf.fit(segments_for_train_pca, seg_labels_for_train)

# y_pred = clf.predict(segments_for_test_pca)

# # seg_labels_for_test = np.array(seg_labels_for_test)
# # true_positive = np.sum(( seg_labels_for_test == 2 ) * ( y_pred == 2 ))
# # false_negtive = 

# sens_p = calculate_sensitivity_np(seg_labels_for_test, y_pred, 1)
# # # sens_i = calculate_sensitivity_np(seg_labels_for_test, y_pred, 0)
# sens_inter = calculate_sensitivity_np(seg_labels_for_test, y_pred, 0)
# accuracy = calculate_accuracy_np(seg_labels_for_test, y_pred)
print(COMMON_CH)
# print("***********\nSensitivity for pre-ictal:", sens_p, ", sensitivity for inter:", sens_inter)
print("\n***********")
print("Acc:", best_test_accuracy, "Sensitivity for pre-ictal:", sens_p, "Specificity:", spec_p)
print("***********\n")

In [None]:
VOTE_NUM = 5
y_test = np.array(seg_labels_for_test)
y_test_flat = y_test[VOTE_NUM - 1:]
for epoch in range(100, 200):
    start_time = time.time()
    # 训练阶段
    model.train()
    train_loss = 0.0
    correct = 0
    for interation, (inputs, targets) in enumerate(train_loader):
        # print(interation)
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        functional.reset_net( model )

        # print(outputs.shape, targets.shape)
        predicted = (outputs > 0.5).float()
        targets = (targets>0.5).to(device)
        correct += (predicted == targets).sum().item()

        train_loss += loss.item()
    train_loss /= len(train_loader)
    train_acc = correct / len(seg_labels_for_train)


    # 验证阶段
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    TP = 0
    FN = 0
    TN = 0
    FP = 0
    predictions = []
    # targets = []
    with torch.inference_mode():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            predicted = (outputs > 0.5).float()
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            predictions.extend( predicted.cpu() )
            
            functional.reset_net( model )

        predictions = np.array( predictions )
        final_predictions_flat = np.array( [ mode(predictions[ i-(VOTE_NUM-1) : i+1 ]).mode for i in range(VOTE_NUM-1, len(predictions)) ] )
        
        # 计算混淆矩阵的四个基本元素
        TP = np.sum((predictions == 1) & (y_test == 1))  # True Positives
        TN = np.sum((predictions == 0) & (y_test == 0))  # True Negatives
        FP = np.sum((predictions == 1) & (y_test == 0))  # False Positives
        FN = np.sum((predictions == 0) & (y_test == 1))  # False Negatives
        
        TP_vote = np.sum((final_predictions_flat == 1) & (y_test_flat == 1))  # True Positives
        TN_vote = np.sum((final_predictions_flat == 0) & (y_test_flat == 0))  # True Negatives
        FP_vote = np.sum((final_predictions_flat == 1) & (y_test_flat == 0))  # False Positives
        FN_vote = np.sum((final_predictions_flat == 0) & (y_test_flat == 1))  # False Negatives
            
            
        
            # total += targets.size(0)
            # positive_labels = ( targets == 1 )
            # negative_labels = ( targets == 0 )
            # correct += (predicted == targets).sum().item()
            # TP += ( positive_labels * ( predicted == targets ) ).sum().item()
            # FN += ( positive_labels * ( predicted != targets ) ).sum().item()
            # TN += ( negative_labels * ( predicted == targets ) ).sum().item()
            # FP += ( negative_labels * ( predicted != targets ) ).sum().item()
            
            # print( predicted.shape, targets.shape )
            
            
    val_loss /= len(val_loader)
    test_accuracy = (TP + TN) / (TP + TN + FP + FN)
    sens_p_epoch = TP / (TP + FN) if (TP+FN != 0 ) else 0
    spec_p_epoch = TN / (TN + FP) if (TN+FP != 0 ) else 0

    test_acc_vote = (TP_vote + TN_vote) / (TP_vote + TN_vote + FP_vote + FN_vote)
    sens_p_epoch_vote = TP_vote / (TP_vote + FN_vote) if (TP_vote+FN_vote != 0 ) else 0
    spec_p_epoch_vote = TN_vote / (TN_vote + FP_vote) if (TN_vote+FP_vote != 0 ) else 0
    
    if test_accuracy > best_test_accuracy:
        best_test_accuracy = test_accuracy
        sens_p = TP / (TP + FN) if (TP+FN != 0 ) else 0
        spec_p = TN / (TN + FP) if (TN+FP != 0 ) else 0
    end_time = time.time()
    time_cost = end_time - start_time
    print(f'\nEpoch [{epoch+1}/{num_epochs}], time cost: {time_cost:.4f}, Train Loss: {train_loss:.4f}, Train acc: {train_acc: .4f}, Val loss: {val_loss:.4f}, Test accuracy: {test_accuracy:.4f}, Sensitivity: {sens_p_epoch:.4f}, Specificity: {spec_p_epoch:.4f}, Test acc vote: {test_acc_vote:.4f}, Sens vote: {sens_p_epoch_vote:.4f}, Spec vote: {spec_p_epoch_vote:.4f}')



# clf = DecisionTreeClassifier(random_state=0)
# clf.fit(segments_for_train_pca, seg_labels_for_train)

# y_pred = clf.predict(segments_for_test_pca)

# # seg_labels_for_test = np.array(seg_labels_for_test)
# # true_positive = np.sum(( seg_labels_for_test == 2 ) * ( y_pred == 2 ))
# # false_negtive = 

# sens_p = calculate_sensitivity_np(seg_labels_for_test, y_pred, 1)
# # # sens_i = calculate_sensitivity_np(seg_labels_for_test, y_pred, 0)
# sens_inter = calculate_sensitivity_np(seg_labels_for_test, y_pred, 0)
# accuracy = calculate_accuracy_np(seg_labels_for_test, y_pred)
print(COMMON_CH)
# print("***********\nSensitivity for pre-ictal:", sens_p, ", sensitivity for inter:", sens_inter)
print("\n***********")
print("Acc:", best_test_accuracy, "Sensitivity for pre-ictal:", sens_p, "Specificity:", spec_p)
print("***********\n")

In [None]:
VOTE_NUM = 3
y_test = np.array(seg_labels_for_test)
y_test_flat = y_test[VOTE_NUM - 1:]
for epoch in range(200, 210):
    start_time = time.time()
    # 训练阶段
    model.train()
    train_loss = 0.0
    correct = 0
    for interation, (inputs, targets) in enumerate(train_loader):
        # print(interation)
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        functional.reset_net( model )

        # print(outputs.shape, targets.shape)
        predicted = (outputs > 0.5).float()
        targets = (targets>0.5).to(device)
        correct += (predicted == targets).sum().item()

        train_loss += loss.item()
    train_loss /= len(train_loader)
    train_acc = correct / len(seg_labels_for_train)


    # 验证阶段
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    TP = 0
    FN = 0
    TN = 0
    FP = 0
    predictions = []
    # targets = []
    with torch.inference_mode():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            predicted = (outputs > 0.5).float()
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            predictions.extend( predicted.cpu() )
            
            functional.reset_net( model )

        predictions = np.array( predictions )
        final_predictions_flat = np.array( [ mode(predictions[ i-(VOTE_NUM-1) : i+1 ]).mode for i in range(VOTE_NUM-1, len(predictions)) ] )
        
        # 计算混淆矩阵的四个基本元素
        TP = np.sum((predictions == 1) & (y_test == 1))  # True Positives
        TN = np.sum((predictions == 0) & (y_test == 0))  # True Negatives
        FP = np.sum((predictions == 1) & (y_test == 0))  # False Positives
        FN = np.sum((predictions == 0) & (y_test == 1))  # False Negatives
        
        TP_vote = np.sum((final_predictions_flat == 1) & (y_test_flat == 1))  # True Positives
        TN_vote = np.sum((final_predictions_flat == 0) & (y_test_flat == 0))  # True Negatives
        FP_vote = np.sum((final_predictions_flat == 1) & (y_test_flat == 0))  # False Positives
        FN_vote = np.sum((final_predictions_flat == 0) & (y_test_flat == 1))  # False Negatives
            
            
        
            # total += targets.size(0)
            # positive_labels = ( targets == 1 )
            # negative_labels = ( targets == 0 )
            # correct += (predicted == targets).sum().item()
            # TP += ( positive_labels * ( predicted == targets ) ).sum().item()
            # FN += ( positive_labels * ( predicted != targets ) ).sum().item()
            # TN += ( negative_labels * ( predicted == targets ) ).sum().item()
            # FP += ( negative_labels * ( predicted != targets ) ).sum().item()
            
            # print( predicted.shape, targets.shape )
            
            
    val_loss /= len(val_loader)
    test_accuracy = (TP + TN) / (TP + TN + FP + FN)
    sens_p_epoch = TP / (TP + FN) if (TP+FN != 0 ) else 0
    spec_p_epoch = TN / (TN + FP) if (TN+FP != 0 ) else 0

    test_acc_vote = (TP_vote + TN_vote) / (TP_vote + TN_vote + FP_vote + FN_vote)
    sens_p_epoch_vote = TP_vote / (TP_vote + FN_vote) if (TP_vote+FN_vote != 0 ) else 0
    spec_p_epoch_vote = TN_vote / (TN_vote + FP_vote) if (TN_vote+FP_vote != 0 ) else 0
    
    if test_accuracy > best_test_accuracy:
        best_test_accuracy = test_accuracy
        sens_p = TP / (TP + FN) if (TP+FN != 0 ) else 0
        spec_p = TN / (TN + FP) if (TN+FP != 0 ) else 0
    end_time = time.time()
    time_cost = end_time - start_time
    print(f'\nEpoch [{epoch+1}/{num_epochs}], time cost: {time_cost:.4f}, Train Loss: {train_loss:.4f}, Train acc: {train_acc: .4f}, Val loss: {val_loss:.4f}, Test accuracy: {test_accuracy:.4f}, Sensitivity: {sens_p_epoch:.4f}, Specificity: {spec_p_epoch:.4f}, Test acc vote: {test_acc_vote:.4f}, Sens vote: {sens_p_epoch_vote:.4f}, Spec vote: {spec_p_epoch_vote:.4f}')



# clf = DecisionTreeClassifier(random_state=0)
# clf.fit(segments_for_train_pca, seg_labels_for_train)

# y_pred = clf.predict(segments_for_test_pca)

# # seg_labels_for_test = np.array(seg_labels_for_test)
# # true_positive = np.sum(( seg_labels_for_test == 2 ) * ( y_pred == 2 ))
# # false_negtive = 

# sens_p = calculate_sensitivity_np(seg_labels_for_test, y_pred, 1)
# # # sens_i = calculate_sensitivity_np(seg_labels_for_test, y_pred, 0)
# sens_inter = calculate_sensitivity_np(seg_labels_for_test, y_pred, 0)
# accuracy = calculate_accuracy_np(seg_labels_for_test, y_pred)
print(COMMON_CH)
# print("***********\nSensitivity for pre-ictal:", sens_p, ", sensitivity for inter:", sens_inter)
print("\n***********")
print("Acc:", best_test_accuracy, "Sensitivity for pre-ictal:", sens_p, "Specificity:", spec_p)
print("***********\n")

In [None]:
print("\n***********")
print("Acc:", best_test_accuracy, "Sensitivity for pre-ictal:", sens_p, "Specificity:", spec_p)
print("***********\n")

In [None]:
torch.save(model, PATIENT_TO_USE+str(epoch)+'.pt')

In [None]:
## torch.save(model.state_dict(), str(epoch)+'_state_dict.pt')