In [36]:
import os
import random
from matplotlib import axis
import scipy.io as sio
import numpy as np
import math
import mne
from mne.preprocessing import ICA, create_eog_epochs, create_ecg_epochs
import joblib
import torch
from torch.utils.data import TensorDataset
from sklearn.model_selection import KFold, train_test_split
import model as dl  # Ensure this module contains necessary utility functions
import logging
# Configure the root logger for the notebook: INFO level and a
# "timestamp - level - message" record format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

In [None]:
def ensure_dir(directory):
    """Create ``directory`` (including missing parents) if it does not exist.

    ``os.makedirs(..., exist_ok=True)`` replaces a separate
    ``os.path.exists`` pre-check: the call is idempotent, has no
    check-then-create race, and no longer silently accepts a path that
    exists but is a regular file.

    Args:
        directory: Path of the directory to create.

    Raises:
        FileExistsError: If ``directory`` exists and is not a directory.
    """
    os.makedirs(directory, exist_ok=True)


In [38]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU when available, otherwise the CPU
seed = 22
# Seeding helper from the project-local `model` module (presumably seeds the
# RNGs used by the notebook -- confirm in model.py).
dl.seed_everything(seed)
# EEG data parameters
# Epoch length in samples: importAndCropData cuts every recording into
# non-overlapping windows of this many samples.
duration = 2500
# Cut-off values for the band-pass filter call that is currently commented out
# in importAndCropData. NOTE(review): `high_pass` is also returned by
# importAndCropData and stored by the caller as `srate` -- confirm that 250 is
# really meant to be the sampling rate.
low_pass = 0.5
high_pass = 250



# Channel (row) indices for each brain region.
# The list selected through `partition` is passed to importAndCropData, which
# overwrites those channels with zeros; "all" maps to an empty list, so no
# channel is zeroed.
regions = {
    "prefrontal": [0, 1, 2, 3, 10, 11, 16],
    "central": [4, 5, 17],
    "temporal": [12, 13, 14, 15],
    "parietal": [6, 7, 18],
    "occipital": [8, 9],
    "all":[]
}

# Key into `regions` used for this run; it also names the output folder
# under EEGData/.
partition = "all"

In [39]:
def importAndCropData(file_paths, duration, labels,patition):
    """Load EDF recordings and cut each one into fixed-length epochs.

    For every file the first 19 channels are read, the channels listed in
    ``patition`` are overwritten with zeros, trailing samples that do not
    fill a whole epoch are dropped, and the remainder is cut into
    non-overlapping windows of ``duration`` samples. No filtering is applied.
    A file that cannot be read or processed is logged and skipped; its labels
    are only recorded once its epochs have been built, so data and labels
    always stay aligned.

    Args:
        file_paths: Sequence of EDF file paths.
        duration: Epoch length in samples.
        labels: Per-file labels, indexed like ``file_paths``; each label is
            repeated once for every epoch cut from its file.
        patition: Channel indices to zero out. An empty sequence leaves all
            channels untouched.

    Returns:
        Tuple ``(EEG, label, high_pass)``:
            * ``EEG`` -- array of shape ``(total_epochs, 19, duration)``.
            * ``label`` -- 1-D array with one entry per epoch.
            * ``high_pass`` -- the module-level ``high_pass`` value, passed
              through unchanged. NOTE(review): it is not read from the
              recordings, yet the caller in this notebook stores it as
              ``srate``; confirm that is intended.

    Raises:
        ValueError: If no file yielded any epoch.
    """
    channels = 19  # only the first 19 channels of each recording are used
    EEG_list = []
    label = []
    for i, file in enumerate(file_paths):
        try:
            raw = mne.io.read_raw_edf(file, preload=True, encoding='latin1',verbose='Warning')
            data = raw.get_data()[0:channels]
            # A recording with fewer channels could still satisfy the reshape
            # below by coincidence and yield a wrong epoch count, so reject it
            # explicitly.
            if data.shape[0] != channels:
                raise ValueError(
                    f"expected at least {channels} channels, found {data.shape[0]}")
            for channel in patition:
                data[channel] = 0
            epochs = data.shape[1] // duration
            if epochs == 0:
                # Not even one full epoch available.
                logging.warning(
                    f"Skipping file {file}: {data.shape[1]} samples is shorter "
                    f"than one epoch of {duration} samples")
                continue
            data_crop = data[:, 0:epochs * duration]
            # (channels, epochs * duration) -> (epochs, channels, duration)
            data_new = data_crop.reshape(channels, epochs, duration).transpose(1, 0, 2)
            file_labels = [labels[i]] * epochs
        except Exception as e:
            logging.error(f"Error processing file {file}: {e}")
            continue

        # Commit data and labels together, only after every step succeeded.
        EEG_list.append(data_new)
        label += file_labels
        logging.info(f"Processed file {file}: {epochs} epochs")

    if not EEG_list:
        raise ValueError("No data was loaded. Please check the file paths and formats.")

    EEG = np.concatenate(EEG_list)
    label = np.array(label)
    logging.info(f"Total epochs: {EEG.shape[0]}, Normal: {np.sum(label == 1)}, "
            f"MCI: {np.sum(label == 0)}")
    return EEG,label,high_pass

In [40]:
import os
import warnings
# Suppress RuntimeWarning messages for the rest of the session.
warnings.filterwarnings("ignore", category=RuntimeWarning)
# Dataset layout: <base_dir>/<group>/<recording>.edf
# (base folder: diabetes cognitive-impairment vs. control EEG data;
#  sub-folders: cognitively normal / cognitively impaired).
base_dir = '糖尿病认知障碍与对照脑电数据'
normal_dir = os.path.join(base_dir, '认知正常')
impaired_dir = os.path.join(base_dir, '认知障碍')

# Collect every EDF file of each group. The listings are sorted because
# os.listdir returns entries in arbitrary order, and the shuffle below can only
# be reproducible if its input order is stable.
normal_files = sorted(os.path.join(normal_dir, f) for f in os.listdir(normal_dir) if f.endswith('.edf'))
impaired_files = sorted(os.path.join(impaired_dir, f) for f in os.listdir(impaired_dir) if f.endswith('.edf'))

all_files = normal_files + impaired_files
if not all_files:
    raise FileNotFoundError(f"No .edf files found under {base_dir!r}")

# Labels must follow the order of all_files (normal first, then impaired):
# normal -> 1, impaired -> 0, matching the "Normal: label == 1 / MCI: label == 0"
# summary logged by importAndCropData. Building the blocks from the other
# group's length mislabels files whenever the two groups differ in size.
label_single = np.concatenate([np.ones(len(normal_files)), np.zeros(len(impaired_files))],axis=0)
# Shuffle files and labels together so every file keeps its own label.
combined = list(zip(all_files, label_single))
random.shuffle(combined)
all_files[:], label_single[:] = zip(*combined)
original_data,labels,srate = importAndCropData(all_files, duration, label_single,regions[partition])

2024-11-25 20:02:08,392 - INFO - Processed file 糖尿病认知障碍与对照脑电数据\认知障碍\张立志.edf: 116 epochs
2024-11-25 20:02:08,524 - INFO - Processed file 糖尿病认知障碍与对照脑电数据\认知障碍\郭艳丽.edf: 94 epochs
2024-11-25 20:02:08,675 - INFO - Processed file 糖尿病认知障碍与对照脑电数据\认知障碍\贾国强.edf: 108 epochs
2024-11-25 20:02:08,865 - INFO - Processed file 糖尿病认知障碍与对照脑电数据\认知障碍\王翠兰.edf: 135 epochs
2024-11-25 20:02:09,017 - INFO - Processed file 糖尿病认知障碍与对照脑电数据\认知障碍\赵长勇.edf: 120 epochs
2024-11-25 20:02:09,170 - INFO - Processed file 糖尿病认知障碍与对照脑电数据\认知正常\余洪涛.edf: 122 epochs
2024-11-25 20:02:09,320 - INFO - Processed file 糖尿病认知障碍与对照脑电数据\认知障碍\郭秀荣.edf: 108 epochs
2024-11-25 20:02:09,476 - INFO - Processed file 糖尿病认知障碍与对照脑电数据\认知正常\果春胜.edf: 124 epochs
2024-11-25 20:02:09,593 - INFO - Processed file 糖尿病认知障碍与对照脑电数据\认知障碍\陈艳杰.edf: 94 epochs
2024-11-25 20:02:09,730 - INFO - Processed file 糖尿病认知障碍与对照脑电数据\认知正常\张文波.edf: 120 epochs
2024-11-25 20:02:09,880 - INFO - Processed file 糖尿病认知障碍与对照脑电数据\认知正常\孟庆珊.edf: 120 epochs
2024-11-25 20:02:09,991 - INFO - P

In [41]:
# Number of cross-validation folds, shared by the splitter call and the loop
# below so the two cannot drift apart.
n_folds = 10

# Epochs and labels are selected with the same indices, so they must have the
# same length; otherwise epochs would be silently paired with wrong labels.
if len(labels) != len(original_data):
    raise ValueError(
        f"labels ({len(labels)}) and epochs ({len(original_data)}) are misaligned")

# NOTE(review): Split_Sets only receives the epoch array, so it cannot group
# epochs by recording; epochs of one recording may land in both the training
# and the test set. Confirm whether a recording-wise split is required.
train_indices, test_indices = dl.Split_Sets(n_folds, original_data)

# Ensure output directories exist
output_root = os.path.join("EEGData", str(partition))
train_dir = os.path.join(output_root, "TrainData")
valid_dir = os.path.join(output_root, "ValidData")
test_dir = os.path.join(output_root, "TestData")
for split_dir in (train_dir, valid_dir, test_dir):
    ensure_dir(split_dir)

for fold in range(n_folds):
    try:
        # Split into training and test sets
        train_idx = train_indices[fold]
        test_idx = test_indices[fold]

        train_data = original_data[train_idx,:, : ]
        train_labels = labels[train_idx]
        test_data = original_data[test_idx,:, : ]
        test_labels = labels[test_idx]

        # Further split training data into train and validation sets
        # (10 % validation, stratified by label, fixed random_state).
        train_data_split, valid_data_split, train_labels_split, valid_labels_split = train_test_split(
            train_data, train_labels, test_size=0.1, random_state=seed, stratify=train_labels
        )
        # Convert to PyTorch tensors
        train_tensor = torch.from_numpy(train_data_split).float() # (samples, channels, duration)
        train_labels_tensor = torch.from_numpy(train_labels_split).long()

        valid_tensor = torch.from_numpy(valid_data_split).float()
        valid_labels_tensor = torch.from_numpy(valid_labels_split).long()

        test_tensor = torch.from_numpy(test_data).float()
        test_labels_tensor = torch.from_numpy(test_labels).long()

        # Create TensorDatasets
        train_dataset = TensorDataset(train_tensor, train_labels_tensor)
        valid_dataset = TensorDataset(valid_tensor, valid_labels_tensor)
        test_dataset = TensorDataset(test_tensor, test_labels_tensor)

        # Save datasets (one file per split and fold; folds are numbered from 1)
        torch.save(train_dataset, os.path.join(train_dir, f"train_data_{fold + 1}_fold_with_seed_{seed}.pth"))
        torch.save(valid_dataset, os.path.join(valid_dir, f"valid_data_{fold + 1}_fold_with_seed_{seed}.pth"))
        torch.save(test_dataset, os.path.join(test_dir, f"test_data_{fold + 1}_fold_with_seed_{seed}.pth"))

        logging.info(f"Fold {fold + 1} data saved successfully.")
    except Exception as e:
        # Keep processing the remaining folds, but record the full traceback
        # so a failed fold can actually be diagnosed.
        logging.exception(f"Error processing fold {fold + 1}: {e}")


2024-11-25 20:02:21,673 - INFO - Fold 1 data saved successfully.
2024-11-25 20:02:24,202 - INFO - Fold 2 data saved successfully.
2024-11-25 20:02:26,874 - INFO - Fold 3 data saved successfully.
2024-11-25 20:02:29,367 - INFO - Fold 4 data saved successfully.
2024-11-25 20:02:32,090 - INFO - Fold 5 data saved successfully.
2024-11-25 20:02:34,800 - INFO - Fold 6 data saved successfully.
2024-11-25 20:02:39,138 - INFO - Fold 7 data saved successfully.
2024-11-25 20:02:41,941 - INFO - Fold 8 data saved successfully.
2024-11-25 20:02:44,783 - INFO - Fold 9 data saved successfully.
2024-11-25 20:02:51,566 - INFO - Fold 10 data saved successfully.


In [42]:
# Scratch inspection: the first epoch of the dataset. original_data is indexed
# (epoch, channel, sample), so this is a (channels, duration) array.
z=original_data[0]