### 读取数据

In [203]:
from dataclasses import dataclass, field
from typing import List
import torch
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import LeavePOut
from collections import Counter
from sklearn.impute import KNNImputer
from sklearn.preprocessing import StandardScaler
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F

In [62]:
@dataclass
class Config:
  # path
  data_path: str = r"D:\Data\Group\2-nuclear_data\deeplearning"
  log_path: str = os.path.join(data_path, "logs")
  subjects: List[str] = field(default_factory=lambda: [
    'NP03', 'NP04', 'NP05', 
    'NP06', 'NP07', 'NP08', 
    'NP09', 'NP10', 'NP11', 
    'NP12', 'NP13', 'NP14', 
    'NP15', 'NP16', 'NP17', 
    'NP18', 'NP19', 'NP20', 
    'NP21', 'NP22', 'NP23', 
    'NP24', 'NP25', 'NP26',
    'NP27', 'NP28', 'NP29', 
    'NP30', 'NP31', 'NP32'])
  
  # data
  knn_k: int = 5
  smote_seed: int = 42

  # training hyper-params
  batch_size: int = 64
  max_epochs: int = 1
  lr_encoder: float = 1e-3
  lr_classifier: float = 1e-3
  lr_domain_discriminator: float = 1e-5
  clip_grad: float = 5.0

  # MCD iterations
  step1_iter: int = 1
  step2_iter: int = 4
  step3_iter: int = 1
  # step4_iter: int = 1
  lambda_GRL: float = 1.0

  # mwl level
  low_level = 1
  mid_level = 5
  high_level = 9

  # 任务定义
  num_classes: int = 3
  binary_threshold: int = 6

  # misc
  device: str = "cuda" if torch.cuda.is_available() else "cpu"


In [63]:
def normalize_by_rest_state(df: pd.DataFrame,
                            rest_duration_minutes: int,
                            sampling_rate: int,
                            label_column: str = 'MWL_Rating') -> pd.DataFrame:
    """
    基于实验开始前静息阶段进行 Z-score 标准化。

    参数：
    - df: DataFrame，包含按时间顺序排列的模态特征与标签
    - rest_duration_minutes: 静息阶段时长（分钟）
    - sampling_rate: 数据采样频率（每秒多少行）
    - label_column: 标签列名，默认是 'MWL_Rating'

    返回：
    - 标准化后的 DataFrame（标签列保持不变）
    """
    # 静息阶段数据行数
    rest_rows = rest_duration_minutes * 60 * sampling_rate
    # 分离标签
    features = df.drop(columns=label_column)
    labels = df[label_column]

    # 计算静息状态下每个特征的均值和标准差
    rest_means = features.iloc[:rest_rows].mean()
    rest_stds = features.iloc[:rest_rows].std()

    # 防止除以0
    rest_stds[rest_stds == 0] = 1e-8

    # Z-score 标准化
    normalized_features = (features - rest_means) / rest_stds
    # 合并标签列
    normalized_df = pd.concat([normalized_features, labels], axis=1)

    return normalized_df

In [64]:
class LabelClassifier:
    def __init__(self, 
                 cfg: Config):
        """
        :param low_start: basic 模式下低负荷的起始分数
        :param mid_start: basic 模式下中负荷的起始分数
        :param high_start: basic 模式下高负荷的起始分数
        """
        self.num_classes = cfg.num_classes
        self.low_start = cfg.low_level
        self.mid_start = cfg.mid_level
        self.high_start = cfg.high_level
        self.binary_threshold = cfg.binary_threshold
        
    def classify(self, rating):
        """
        将单个标签值分类为 0/1/2。
        :param x: 单个 MWL_Rating 值
        :return: 类别标签 0/1/2
        """
        if self.num_classes == 3:
            if rating < self.mid_start:
                return 0
            elif self.mid_start <= rating < self.high_start:
                return 1
            else:
                return 2
        elif self.num_classes == 2:
            if rating <= self.binary_threshold:
                return 0
            else:
                return 1

In [65]:
def load_eeg_data(cfg: Config):
    """
    加载多个被试的 EEG 数据，并统一处理标签和添加被试编号列。

    :param subjects: 被试编号列表
    :param base_path: 基础文件路径，包含所有被试的子文件夹
    :param low_start: 低类别最低分数
    :param mid_start: 中类别最低分数
    :param high_start: 高类别最低分数
    :return: 合并后的 DataFrame
    """
    all_data = []
    for subject in cfg.subjects:
        file_path = f'{cfg.data_path}/{subject}/20width-4step/combined_eeg_features.csv'
        df = pd.read_csv(file_path)
        # 特征归一化
        normalized_df = normalize_by_rest_state(df, 
                                                rest_duration_minutes=5, 
                                                sampling_rate=256)
        # 标签分界类
        classifier = LabelClassifier(cfg)
        # 统一标签处理
        normalized_df['MWL_Rating'] = \
            normalized_df['MWL_Rating'].apply(classifier.classify)
        # 添加被试编号列
        normalized_df['subject_id'] = subject

        all_data.append(normalized_df)
    # 合并所有数据并返回
    return pd.concat(all_data, ignore_index=True)

### function

In [66]:
cfg = Config()

In [67]:
full_df = load_eeg_data(cfg)

In [68]:
full_df[full_df['MWL_Rating'] == 2]

Unnamed: 0,relative_time,Fp1_delta_PSD,Fp1_theta_PSD,Fp1_alpha_PSD,Fp1_beta_PSD,Fpz_delta_PSD,Fpz_theta_PSD,Fpz_alpha_PSD,Fpz_beta_PSD,Fp2_delta_PSD,...,Oz_skew,Oz_kurt,O2_mean,O2_max,O2_min,O2_std,O2_skew,O2_kurt,MWL_Rating,subject_id
4477,-0.636473,0.368598,-0.073809,-0.039094,0.088054,0.038874,-0.050965,-0.031509,0.035434,-0.149886,...,0.347530,-0.689563,-0.300620,0.740404,-0.059547,1.359457,0.309993,-0.546179,2,NP06
4478,-0.635245,-0.013624,0.183893,0.063822,0.002991,-0.048108,0.010177,0.076055,0.293437,-0.167018,...,0.233352,-0.702112,-1.329780,0.740404,-0.285539,1.400870,0.172236,-0.556415,2,NP06
4479,-0.634018,0.039748,0.757297,0.369634,0.192470,0.078060,0.151745,0.244125,0.589142,0.056493,...,0.364277,-0.425940,0.183650,0.740404,-0.285539,0.895392,0.207846,-0.378532,2,NP06
4480,-0.632790,-0.018267,0.450931,0.220883,0.156591,0.013175,0.054770,0.079424,0.235782,-0.143915,...,0.552443,-0.206697,0.110880,0.740404,-0.285539,0.749128,0.345080,-0.229841,2,NP06
4481,-0.631563,-0.170893,-0.187284,-0.105012,-0.006205,-0.127755,-0.102222,-0.090215,-0.044121,-0.375649,...,-0.056532,-0.376496,1.115595,0.236956,-0.285539,0.474216,-0.215024,-0.314855,2,NP06
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
34450,1.681945,-0.128959,-0.130031,-0.106114,-0.088398,-0.133456,-0.128618,-0.113913,-0.092718,-0.136003,...,0.191339,-0.699942,0.058455,-0.419405,0.425400,-0.366983,0.508738,-0.552424,2,NP26
34451,1.683228,-0.135220,-0.131532,-0.115955,-0.100947,-0.137345,-0.132569,-0.121037,-0.103281,-0.141468,...,0.244570,-0.666639,-0.387212,-0.419405,0.425400,-0.378538,0.562881,-0.487795,2,NP26
34452,1.684510,-0.141286,-0.134908,-0.129167,-0.126950,-0.144531,-0.140792,-0.127829,-0.122956,-0.149108,...,0.284222,-0.610701,-0.248361,-0.419405,0.425400,-0.388985,0.641565,-0.405824,2,NP26
34453,1.685792,-0.142701,-0.137163,-0.133294,-0.139431,-0.146940,-0.144777,-0.133529,-0.135186,-0.152309,...,0.202846,-0.567411,-0.256658,-0.419405,0.458089,-0.432065,0.595052,-0.325192,2,NP26


In [69]:
data = full_df

In [70]:
x = data.drop(columns=["subject_id", "MWL_Rating"]).values
y = data["MWL_Rating"].values
groups = data["subject_id"].values  # 被试的编号
unique_subjects = np.unique(groups) # 取

In [71]:
print(f"x.shape={x.shape}, y.shape={y.shape}, groups={groups.shape}")


x.shape=(43792, 449), y.shape=(43792,), groups=(43792,)


In [72]:
lpo = LeavePOut(p=1)  # 留p个被试做测试集
all_y_true = []
all_y_pred = []
all_y_prob = []
subject_metrics = {
    "subject_id": [],
    "accuracy": [],
    "precision": [],
    "recall": [],
    "f1": [],
    "auc": []}

In [159]:
source_index, target_index = list(lpo.split(unique_subjects))[14]
source_subject = unique_subjects[source_index]
target_subject = unique_subjects[target_index]

In [160]:
x_train_origin, y_train_origin = x[np.isin(groups, source_subject)], y[np.isin(groups, source_subject)]
x_test, y_test = x[np.isin(groups, target_subject)], y[np.isin(groups, target_subject)]

In [178]:
print(x_train_origin.shape, y_train_origin.shape)
print(np.nanmean(x_train_origin), np.nanstd(x_train_origin))

(42367, 449) (42367,)
2.4384957519957036e-18 0.9996576736428038


In [162]:
x_test_clean_mask = ~np.isnan(x_test).any(axis=1)
y_test_clean_mask = ~np.isnan(y_test)
test_clean_mask = x_test_clean_mask & y_test_clean_mask

In [163]:
x_test = x_test[test_clean_mask]
y_test = y_test[test_clean_mask]
label_counter_test = Counter(y_test)

In [167]:
def impute_missing_values_by_knn(x, n_neighbors=5):
    """
    使用 KNN 算法填充缺失值。

    :param x: 特征数据
    :param n_neighbors: 用于 KNN 填充的邻居数量，默认为 5
    :return: 填充后的特征数据
    """
    knn_imp = KNNImputer(n_neighbors=n_neighbors)
    x_imp = knn_imp.fit_transform(x)
    return x_imp


In [168]:
x_train_imputed = impute_missing_values_by_knn(x_train_origin, cfg.knn_k)

In [177]:
print(np.nanmean(x_train_imputed), np.nanstd(x_train_imputed))

-9.040956013531597e-06 0.9996396869009675


In [184]:
x_train = x_train_imputed
y_train = y_train_origin

In [186]:
print(np.nanmean(x_train), np.nanstd(y_train))

-9.040956013531597e-06 0.559282926683751


### 对训练集和测试集的 x 进行标准化
1. 确保所有特征能够具备相同的尺度

In [185]:
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(x_train)
x_test_scaled = scaler.transform(x_test)

In [187]:
print(np.nanmean(x_train_scaled), np.nanstd(x_test_scaled))

-2.450304115116204e-19 1.0003704208493251


- 添加通道维度，把每个样本的特征序列看作一个时间序列

In [188]:
source_x = torch.tensor(x_train, dtype=torch.float32).unsqueeze(2).to(cfg.device)
source_y = torch.tensor(y_train, dtype=torch.long).to(cfg.device)
target_x = torch.tensor(x_test, dtype=torch.float32).unsqueeze(2).to(cfg.device)

In [190]:
assert not torch.isnan(source_x).any() and not torch.isinf(source_x).any()
assert not torch.isnan(target_x).any() and not torch.isinf(target_x).any()


In [191]:
input_dim = source_x.shape[1]

In [192]:
input_dim

449

In [194]:
class FeatureEncoder(nn.Module):
    def __init__(self, input_dim):
        super(FeatureEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels=input_dim, 
                      out_channels=256, 
                      kernel_size=1, 
                      stride=1),
            nn.ReLU(),
            nn.Conv1d(in_channels=256, 
                      out_channels=128, 
                      kernel_size=1, 
                      stride=1),
            nn.ReLU(),
            nn.Conv1d(in_channels=128, 
                      out_channels=128, 
                      kernel_size=1, 
                      stride=1),
            nn.ReLU(),
            nn.Dropout(0.2))

    def forward(self, x):
        features = self.encoder(x)
        features = features.squeeze(-1)  # 移除最后一个维度，变为 (batch_size, 128)
        return features

In [195]:
model_e = FeatureEncoder(input_dim).to(cfg.device)

In [196]:
model_e

FeatureEncoder(
  (encoder): Sequential(
    (0): Conv1d(449, 256, kernel_size=(1,), stride=(1,))
    (1): ReLU()
    (2): Conv1d(256, 128, kernel_size=(1,), stride=(1,))
    (3): ReLU()
    (4): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
    (5): ReLU()
    (6): Dropout(p=0.2, inplace=False)
  )
)

In [197]:
class Classifier(nn.Module):
    """
    模块二：分类器
    """
    def __init__(self, 
                 feature_dim_1, 
                 feature_dim_2, 
                 feature_dim_3, 
                 num_classes):
        super(Classifier, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(feature_dim_1, 
                      feature_dim_2),
            nn.ReLU(),
            nn.Linear(feature_dim_2, 
                      feature_dim_3),
            nn.ReLU(),
            nn.Linear(feature_dim_3, 
                      num_classes),
        )

    def forward(self, x):
        return self.net(x)

In [198]:
model_c1 = Classifier(128, 64, 32, 3).to(cfg.device)
model_c2 = Classifier(128, 256, 128, 3).to(cfg.device)

In [200]:
optimizer_e = optim.Adam(model_e.parameters(), 
                          lr=cfg.lr_encoder,
                          weight_decay=1e-4)
optimizer_c1 = optim.Adam(model_c1.parameters(), 
                          lr=cfg.lr_classifier,
                          weight_decay=1e-4)
optimizer_c2 = optim.Adam(model_c2.parameters(), 
                          lr=cfg.lr_classifier,
                          weight_decay=1e-4)

In [202]:
source_dataset = TensorDataset(source_x, source_y)
source_loader = DataLoader(source_dataset, batch_size=64, shuffle=True)
target_dataset = TensorDataset(target_x)
target_loader = DataLoader(target_dataset, batch_size=64, shuffle=True)

In [204]:
class LossManager:
    def __init__(self, num_classes=3):
        self.num_classes = num_classes

    @staticmethod
    def step1_loss(out1, out2, y):
        # 默认使用交叉熵
        return F.cross_entropy(out1, y) + F.cross_entropy(out2, y)

    # @staticmethod
    # def discrepancy_loss(out1, out2):
    #     return torch.mean(torch.abs(F.softmax(out1, dim=1) - F.softmax(out2, dim=1)))

    def double_classifier_loss_mse(self, out1, out2, y):
        y_onehot = torch.nn.functional.one_hot(y, num_classes=self.num_classes).float()
        p1 = torch.softmax(out1, dim=1)
        p2 = torch.softmax(out2, dim=1)
        return F.mse_loss(p1, y_onehot) + F.mse_loss(p2, y_onehot)

    @staticmethod
    def double_classifier_loss_ce(p1, p2, y_true):
        return F.cross_entropy(p1, y_true) + F.cross_entropy(p2, y_true)

    # @staticmethod
    # def mmd_loss(source, target):
    #     batch_size = int(source.size()[0])
    #     kernels = gaussian_kernel(source, target)
    #     xx = kernels[:batch_size, :batch_size]
    #     yy = kernels[batch_size:, batch_size:]
    #     xy = kernels[:batch_size, batch_size:]
    #     yx = kernels[batch_size:, :batch_size]

    #     return torch.mean(xx + yy - xy - yx)
    
    def discrepancy(self, out1, out2):
        return torch.mean(torch.abs(F.softmax(out1, dim=1) - F.softmax(out2, dim=1)))


In [205]:
loss_manager = LossManager()

In [206]:
ce_loss = nn.CrossEntropyLoss()
bce_loss = nn.BCEWithLogitsLoss()

In [209]:
source_iter = iter(source_loader)
target_iter = iter(target_loader)
# epoch‑level accumulators
cls_loss_epoch = dis_max_epoch = dis_min_epoch = dom_loss_epoch = 0.0
n_batches = len(source_loader)

In [None]:
# epoch‑level accumulators
cls_loss_epoch = dis_max_epoch = dis_min_epoch = dom_loss_epoch = 0.0
n_batches = len(source_loader)

<torch.utils.data.dataloader._SingleProcessDataLoaderIter at 0x1eb3309c750>

In [210]:
source_features, source_labels = next(source_iter)

In [214]:
print(f"source_feature.shape={source_features.shape}, source_labels.shape={source_labels.shape}")

source_feature.shape=torch.Size([64, 449, 1]), source_labels.shape=torch.Size([64])


In [217]:
source_features, source_labels = source_features.to(cfg.device), source_labels.to(cfg.device)

In [218]:
target_features = next(target_iter)[0]
target_features = target_features.to(cfg.device)

In [221]:
print(f"target_features.shape={target_features.shape}")

target_features.shape=torch.Size([64, 449, 1])


### step-1: 源域分类

In [223]:
model_e.train()
model_c1.train()
model_c2.train()

Classifier(
  (net): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=3, bias=True)
  )
)

In [233]:
features_extract_by_e = model_e(source_features)
p1_source, p2_source = model_c1(features_extract_by_e), model_c2(features_extract_by_e)
# loss_cls = ce_loss(p1_source, source_labels) + ce_loss(p2_source, source_labels)
loss_cls = loss_manager.double_classifier_loss_ce(p1_source, p2_source, source_labels)

In [234]:
optimizer_e.zero_grad()
optimizer_c1.zero_grad()
optimizer_c2.zero_grad()
loss_cls.backward()
optimizer_e.step()
optimizer_c1.step()
optimizer_c2.step()

In [235]:
loss_cls

tensor(2.0981, device='cuda:0', grad_fn=<AddBackward0>)

### step-2: 最大化分类器差异

In [236]:
model_e.eval()
model_c1.train()
model_c2.train()
# 将 step 2 修改为纯粹的对抗，避免因损失函数过于复杂而导致训练不稳定
with torch.no_grad():
    f_t = model_e(target_features)
p1_t = model_c1(f_t)
p2_t = model_c2(f_t)
loss_dis_max = -loss_manager.discrepancy(p1_t, p2_t)

optimizer_c1.zero_grad()
optimizer_c2.zero_grad()
loss_dis_max.backward()
optimizer_c1.step()
optimizer_c2.step()

### 最小化分类器差异

In [237]:
model_e.train()
model_c1.eval()
model_c2.eval()
for _ in range(4):
    f_t = model_e(target_features)
    p1_t, p2_t = model_c1(f_t), model_c2(f_t)
    loss_dis_min = loss_manager.discrepancy(p1_t, p2_t)
    optimizer_e.zero_grad()
    loss_dis_min.backward()
    optimizer_e.step()


In [238]:
cls_loss_epoch += loss_cls.item()
dis_max_epoch += (-loss_dis_max).item()
dis_min_epoch += loss_dis_min.item()
# dom_loss_epoch += loss_domain.item()

In [239]:
msg = (f"Classification loss: {loss_cls.item():.4f}\n"
       f"Discrepancy maximization: {(-loss_dis_max).item():.4f}\n"
       f"Discrepancy minimization: {loss_dis_min.item():.4f}\n")

In [240]:
print(msg)

Classification loss: 2.0981
Discrepancy maximization: 0.0210
Discrepancy minimization: 0.0190



In [None]:
for source_features, source_labels in source_loader:
  source_features, source_labels = source_features.to(cfg.device), source_labels.to(cfg.device)
  features = model_e(source_features)
  p1 = model_c1(features)
  p2 = model_c2(features)
  loss_cls = loss_manager.double_classifier_loss_ce(p1, p2, source_labels)

  optimizer_e.zero_grad()
  optimizer_c1.zero_grad()
  optimizer_c2.zero_grad()
  loss_cls.backward()
  optimizer_e.step()
  optimizer_c1.step()
  optimizer_c2.step()

In [243]:
source_features.shape

torch.Size([63, 449, 1])

In [244]:
for (target_features, ) in target_loader:
    print(f"The type of target_features: {type(target_features)}")
    target_features = target_features.to(cfg.device)

    # 将 step 2 修改为纯粹的对抗，避免因损失函数过于复杂而导致训练不稳定
    with torch.no_grad():
        f_t = model_e(target_features)
    p1, p2 = model_c1(f_t), model_c2(f_t)
    loss_dis_max = -loss_manager.discrepancy(p1, p2)

    optimizer_c1.zero_grad()
    optimizer_c2.zero_grad()
    loss_dis_max.backward()
    optimizer_c1.step()
    optimizer_c2.step()
    dis_max_epoch += (-loss_dis_max).item()

The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of target_features: <class 'torch.Tensor'>
The type of 