In [2]:
from dataclasses import dataclass, field
from typing import List
import torch
import os
import logging
import sys
import numpy as np
from datetime import datetime
import pandas as pd
from typing import Optional, Dict, List

In [3]:
@dataclass
class Config:
  """Experiment configuration: data locations, subject/modal lists and hyper-parameters.

  ``log_path`` defaults to ``None`` and is resolved in ``__post_init__`` to
  ``<data_path>/logs``, so overriding ``data_path`` moves the log directory
  with it. Pass ``log_path`` explicitly to log somewhere else.
  """
  # path
  data_path: str = r"D:\Data\Group\2-nuclear_data\deeplearning"

  # None -> derived from data_path in __post_init__
  log_path: Optional[str] = None

  subjects: List[str] = field(default_factory=lambda: [
    'NP03', 'NP04', 'NP05',
    'NP06', 'NP07', 'NP08',
    'NP09', 'NP10', 'NP11',
    'NP12', 'NP13', 'NP14',
    'NP15', 'NP16', 'NP17',
    'NP18', 'NP19', 'NP20',
    'NP21', 'NP22', 'NP23',
    'NP24', 'NP25', 'NP26',
    'NP27', 'NP28', 'NP29',
    'NP30', 'NP31', 'NP32'])

  modal_types: List[str] = field(default_factory=lambda: [
    'eeg',
    'ecg',
    'eda',
    'eye'
  ])

  # data
  knn_k: int = 5
  smote_seed: int = 42

  # training hyper-params
  batch_size: int = 128
  max_epochs: int = 40
  lr_encoder: float = 1e-3
  lr_classifier: float = 1e-3
  lr_domain_discriminator: float = 9e-4
  clip_grad: float = 5.0

  # MCD iterations
  step1_iter: int = 1
  step2_iter: int = 4
  step3_iter: int = 1
  lambda_GRL: float = 0.7

  # task definition
  num_classes: int = 3
  binary_threshold: int = 6

  # misc
  # Evaluated once, when the class body is executed.
  device: str = "cuda" if torch.cuda.is_available() else "cpu"

  # mwl level
  # Annotated so they are real dataclass fields (settable through the
  # constructor, included in repr/eq). Declared last so the positional order
  # of the pre-existing fields is unchanged.
  low_level: int = 1
  mid_level: int = 5
  high_level: int = 9

  def __post_init__(self):
    # Derive the log directory from the *instance's* data_path rather than
    # from the class-level default.
    if self.log_path is None:
      self.log_path = os.path.join(self.data_path, "logs")


In [4]:
class Logger:
    """Thin wrapper around ``logging`` that writes to a timestamped file and to stdout.

    Each instance (re)configures the named logger: previously attached
    handlers are removed *and closed*, then a UTF-8 file handler and a stdout
    handler are attached.
    """

    # Accepted textual level names (case-insensitive) -> logging constants.
    _LEVELS = {
        "DEBUG": logging.DEBUG,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }

    def __init__(self, log_dir, name=__name__, log_name_prefix="log", level=logging.INFO):
        """
        :param log_dir: directory for the log file (created if missing)
        :param name: name passed to ``logging.getLogger``
        :param log_name_prefix: file name prefix; the file is ``<prefix>_<YYYYmmdd-HHMMSS>.log``
        :param level: minimum level, as an int or a level name
        """
        self.logger = logging.getLogger(name)
        self.logger.setLevel(self._get_log_level(level))

        self.logger.propagate = False
        # Re-running the cell re-creates the Logger on the same named logger.
        # Detach AND close old handlers so the previous log file is released
        # instead of staying open for the life of the kernel.
        for old_handler in list(self.logger.handlers):
            self.logger.removeHandler(old_handler)
            old_handler.close()

        # log file location
        os.makedirs(log_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        self.log_path = os.path.join(log_dir, f"{log_name_prefix}_{timestamp}.log")

        # file output
        file_formatter = logging.Formatter(
            "[{asctime}][{levelname}] {message}",
            datefmt="%Y-%m-%d %H:%M:%S",
            style='{')

        file_handler = logging.FileHandler(self.log_path, encoding='utf-8')
        file_handler.setFormatter(file_formatter)
        self.logger.addHandler(file_handler)

        # console output
        console_formatter = logging.Formatter("{message}", style='{')
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(console_formatter)
        self.logger.addHandler(console_handler)

    def _get_log_level(self, level: str = "INFO") -> int:
        """Map an int or a level name to a ``logging`` level; unknown input falls back to INFO."""
        if isinstance(level, int):
            return level
        if isinstance(level, str):
            return self._LEVELS.get(level.upper(), logging.INFO)
        # Neither int nor str: previously returned None, which broke the
        # level comparison in log().
        return logging.INFO

    def log(self, message: str, *args, level: str = "INFO", **kwargs):
        """Log ``message`` at ``level``.

        When positional/keyword arguments are supplied, ``message`` is treated
        as a ``str.format`` template. When none are supplied the message is
        logged verbatim, so already-formatted text containing ``{``/``}``
        (dict reprs, JSON, ...) is not mangled into a ``[FormatError]`` line.
        """
        log_level_int = self._get_log_level(level)
        if log_level_int < self.logger.level:
            return

        if args or kwargs:
            try:
                formatted_message = message.format(*args, **kwargs)
            except Exception as e:
                # Never let a bad template abort the caller; log what we have.
                formatted_message = f"[FormatError] {message} | Args: {args} | Kwargs: {kwargs} | Error: {e}"
        else:
            formatted_message = message

        self.logger.log(log_level_int, formatted_message)

    def log_metrics(self, subject_id, metrics: dict, level="info"):
        """Log one line with every metric of ``metrics`` multiplied by 100."""
        msg = f"Subject {subject_id} |" + " | ".join(
            f"{k}: {v * 100: .2f}" for k, v in metrics.items()
        )
        self.log(msg, level=level)

    def save_summary(self, performance_list, metrics):
        """Log mean and standard deviation (as percentages) of each metric over ``performance_list``.

        :param performance_list: sequence of mappings indexed by metric name
        :param metrics: metric names to summarise
        """
        self.log("\n=== Summary of Repetitions ===")
        for metric in metrics:
            values = [perf[metric] for perf in performance_list]
            mean = np.mean(values)
            std = np.std(values)
            self.log(f"{metric.capitalize():<9}: {mean * 100:.2f}% ± {std * 100:.2f}%")

    def get_log_path(self):
        """Return the path of the log file created by this instance."""
        return self.log_path


In [5]:
class LabelClassifier:
    """Map a numeric MWL rating to a class index (3-class or binary)."""

    def __init__(self, 
                 low: int,
                 mid: int,
                 high: int,
                 binary_threshold: int,
                 num_classes: int):
        """
        :param low: stored as ``low_start``; not consulted by ``classify``
        :param mid: first rating that belongs to class 1 in 3-class mode
        :param high: first rating that belongs to class 2 in 3-class mode
        :param binary_threshold: largest rating that belongs to class 0 in binary mode
        :param num_classes: 3 or 2; selects the mode used by ``classify``
        """
        self.num_classes = num_classes
        self.low_start = low
        self.mid_start = mid
        self.high_start = high
        self.binary_threshold = binary_threshold
        
    def classify(self, rating):
        """
        Classify a single rating.

        3-class mode: ``rating < mid`` -> 0, ``mid <= rating < high`` -> 1, otherwise 2.
        Binary mode: ``rating <= binary_threshold`` -> 0, otherwise 1.

        :param rating: a single MWL_Rating value
        :return: class label (0/1/2 or 0/1); NaN when ``rating`` is missing
        :raises ValueError: if ``num_classes`` is neither 2 nor 3
        """
        if self.num_classes not in (2, 3):
            # Previously fell through and returned None without any signal.
            raise ValueError(f"Unsupported num_classes: {self.num_classes!r} (expected 2 or 3)")

        # Every comparison with NaN is False, so a missing rating used to fall
        # into the last branch and be labelled as the top class. Propagate the
        # missing value instead.
        if pd.isna(rating):
            return float('nan')

        if self.num_classes == 3:
            if rating < self.mid_start:
                return 0
            if rating < self.high_start:
                return 1
            return 2

        return 0 if rating <= self.binary_threshold else 1

In [31]:
class MultimodalLoader:
  """Load per-subject, per-modality feature CSVs and merge them into one frame.

  For every subject in ``cfg.subjects`` and every modality in
  ``cfg.modal_types`` the file
  ``<cfg.data_path>/<subject>/<WINDOW_DIR>/combined_<modality>_features.csv``
  is read. Modalities whose file is missing, or which lack the
  ``relative_time`` / ``MWL_Rating`` columns, are skipped for that subject.
  """

  # Sub-directory (below each subject folder) that holds the feature CSVs.
  WINDOW_DIR = "20width-4step"
  # Columns that are not features.
  KEY_COLS = ('relative_time', 'MWL_Rating')

  # Annotations are quoted so this cell does not need the three classes to be
  # defined at class-creation time.
  def __init__(self, cfg: "Config", logger: "Logger", lbl_classifier: "LabelClassifier"):
    self.cfg = cfg
    self.logger = logger
    self.lbl_classifier = lbl_classifier

  def LoadMultimodalData(self) -> pd.DataFrame:
    """
    Load every configured subject and stack the results.

    :return: row-wise concatenation of all subjects' merged frames (fresh 0..n-1 index)
    :raises ValueError: if no subject produced any data
    """
    all_data = []
    for subject in self.cfg.subjects:
      data = self._loadSubjectData(subject)
      if data is not None:
        all_data.append(data)

    if not all_data:
      raise ValueError("cannot load any subject data")
    return pd.concat(all_data, ignore_index=True)

  def _loadSubjectData(self, subject: str) -> Optional[pd.DataFrame]:
    """
    Load all available modalities of one subject and merge them.

    :return: merged frame, or ``None`` when no modality could be loaded
    """
    subject_data: Dict[str, pd.DataFrame] = {}
    available_modal_types: List[str] = []

    for modal_type in self.cfg.modal_types:
      file_path = os.path.join(self.cfg.data_path, subject, self.WINDOW_DIR,
                               f'combined_{modal_type}_features.csv')
      try:
        data = pd.read_csv(file_path)
      except FileNotFoundError:
        self.logger.log("Missing file for subject: {} modal type: {}", subject, modal_type)
        continue

      if 'relative_time' not in data.columns or 'MWL_Rating' not in data.columns:
        self.logger.log("No relative_time or MWL_Rating column in [subject: {}] data", 
                        subject, level="WARNING")
        continue
      subject_data[modal_type] = data
      available_modal_types.append(modal_type)

    if not available_modal_types:
      self.logger.log("No available modal type for subject: {}", subject)
      return None

    combined_data = self._combineModalities(subject_data, available_modal_types, subject)
    self.logger.log("Merged subject {}: all features shape={}", subject, combined_data.shape, level="INFO")
    return combined_data

  def _combineModalities(self, subject_data: Dict, avail_model_type: List[str], subject: str,
                         tol: float = 3.0, direction: str = 'backward') -> pd.DataFrame:
    """
    Merge the modalities of one subject on a common time base.

    Steps:
      1. Collect every non-null ``relative_time`` over all modalities and keep
         one representative per cluster: a timestamp starts a new cluster when
         it lies more than ``tol`` after the current representative.
      2. As-of merge each modality's features (prefixed ``<modality>_``) onto
         the representatives.
      3. Take the first non-null ``MWL_Rating`` across modalities per
         timestamp, join it on exact time, and forward/backward fill gaps.
      4. Fill missing features with 0, classify the rating with
         ``lbl_classifier.classify``, tag rows with ``subject_id`` and drop
         rows whose feature values duplicate an earlier row.

    :param subject_data: modality name -> raw frame
    :param avail_model_type: modalities to merge (must be keys of ``subject_data``)
    :param subject: subject id written to the ``subject_id`` column
    :param tol: clustering distance and as-of tolerance, in ``relative_time`` units
    :param direction: ``direction`` argument forwarded to ``pd.merge_asof``
    :raises ValueError: if no modality contains a non-null timestamp
    """
    # [collect timestamps]
    time_arrays = [
      df['relative_time'].dropna().to_numpy(dtype='float64')
      for df in subject_data.values()
    ]
    if time_arrays:
      all_timestamp = np.unique(np.concatenate(time_arrays))  # sorted, unique
    else:
      all_timestamp = np.array([], dtype='float64')

    if all_timestamp.size == 0:
      # The message used to be passed as ("...{}...", subject) and was never formatted.
      raise ValueError("No timestamp found in [subject: {}] data".format(subject))

    base_times = [all_timestamp[0]]
    for t in all_timestamp[1:]:
      if t - base_times[-1] > tol:
        base_times.append(t)
    # Float key throughout: merge_asof needs identical key dtypes on both
    # sides and rejects a float tolerance on integer keys.
    total_features = pd.DataFrame({'relative_time': np.asarray(base_times, dtype='float64')})

    # [merge features]
    # NOTE(review): each representative is the EARLIEST timestamp of its
    # cluster, so with direction='backward' a modality whose timestamps are
    # offset slightly later than the representative cannot match its own
    # cluster. 'nearest' may be what is intended -- confirm against the data
    # before changing the default.
    for modal_type in self.cfg.modal_types:
      if modal_type not in avail_model_type:
        continue
      df = subject_data[modal_type]
      feature_cols = [col for col in df.columns if col not in self.KEY_COLS]
      if not feature_cols:
        continue
      block = (
        df[['relative_time'] + feature_cols]
        .dropna(subset=['relative_time'])            # merge_asof rejects null keys
        .astype({'relative_time': 'float64'})
        .sort_values('relative_time', kind='stable')  # merge_asof requires sorted keys
        .rename(columns={col: f"{modal_type}_{col}" for col in feature_cols})
      )
      total_features = pd.merge_asof(total_features,
                                     block,
                                     on='relative_time',
                                     direction=direction,
                                     tolerance=tol)

    # [merge labels]
    labels = [
      subject_data[modal_type][['relative_time', 'MWL_Rating']]
      .astype({'relative_time': 'float64'})
      .rename(columns={'MWL_Rating': f'MWL_Rating__{modal_type}'})
      for modal_type in self.cfg.modal_types
      if modal_type in avail_model_type
      and modal_type in subject_data
      and 'MWL_Rating' in subject_data[modal_type].columns
    ]

    if labels:
      lbl = labels[0]
      for extra in labels[1:]:
        lbl = lbl.merge(extra, on='relative_time', how='outer')
      label_cols = [c for c in lbl.columns if c.startswith('MWL_Rating__')]
      # After a row-wise back-fill the first column holds the first non-null
      # rating across modalities.
      lbl['MWL_Rating'] = lbl[label_cols].bfill(axis=1).iloc[:, 0]
      lbl = lbl[['relative_time', 'MWL_Rating']]
      total_features = total_features.merge(lbl, on='relative_time', how='left')
      total_features['MWL_Rating'] = total_features['MWL_Rating'].ffill().bfill()
    else:
      total_features['MWL_Rating'] = np.nan

    features_only = [c for c in total_features.columns if c not in self.KEY_COLS]
    if features_only:
      total_features[features_only] = total_features[features_only].fillna(0)

    # [classification]
    total_features['MWL_Rating'] = total_features['MWL_Rating'].apply(self.lbl_classifier.classify)

    # The column-wise assignment above can leave the frame split into one
    # block per column; copying consolidates it before a new column is
    # inserted (presumably the source of the warning pandas printed at the
    # subject_id assignment -- TODO confirm).
    total_features = total_features.copy()
    total_features['subject_id'] = subject

    total_features = total_features.sort_values(by='relative_time', kind='stable')
    if features_only:
      total_features = total_features.drop_duplicates(subset=features_only, keep='first')
    total_features = total_features.reset_index(drop=True)

    self.logger.log("Success to combine modalities for subject: {}, shape: {}", 
                    subject, total_features.shape, level="INFO")
    return total_features




In [34]:
# Wire up the pipeline objects: configuration -> logger -> label mapper -> loader.
cfg = Config()
logger = Logger(log_dir=cfg.log_path)
lbl_classifier = LabelClassifier(
    low=cfg.low_level,
    mid=cfg.mid_level,
    high=cfg.high_level,
    binary_threshold=cfg.binary_threshold,
    num_classes=cfg.num_classes,
)
mm_data = MultimodalLoader(cfg=cfg, logger=logger, lbl_classifier=lbl_classifier)

In [35]:
# Load every subject listed in cfg.subjects and stack the merged frames into one DataFrame.
total_data = mm_data.LoadMultimodalData()

Success to combine modalities for subject: NP03, shape: (1623, 495)
Merged subject NP03: all features shape=(1623, 495)


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP04, shape: (1546, 495)
Merged subject NP04: all features shape=(1546, 495)


  total_features['subject_id'] = subject
  total_features['subject_id'] = subject


Success to combine modalities for subject: NP05, shape: (1650, 495)
Merged subject NP05: all features shape=(1650, 495)
Success to combine modalities for subject: NP06, shape: (1427, 495)
Merged subject NP06: all features shape=(1427, 495)


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP07, shape: (1532, 495)
Merged subject NP07: all features shape=(1532, 495)


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP08, shape: (1527, 495)
Merged subject NP08: all features shape=(1527, 495)


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP09, shape: (1732, 495)
Merged subject NP09: all features shape=(1732, 495)


  total_features['subject_id'] = subject
  total_features['subject_id'] = subject


Success to combine modalities for subject: NP10, shape: (1723, 495)
Merged subject NP10: all features shape=(1723, 495)
Success to combine modalities for subject: NP11, shape: (1678, 495)
Merged subject NP11: all features shape=(1678, 495)


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP12, shape: (1653, 495)
Merged subject NP12: all features shape=(1653, 495)


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP13, shape: (1715, 495)
Merged subject NP13: all features shape=(1715, 495)


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP14, shape: (1657, 495)
Merged subject NP14: all features shape=(1657, 495)
Missing file for subject: NP15 modal type: eye


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP15, shape: (1758, 483)
Merged subject NP15: all features shape=(1758, 483)
Missing file for subject: NP16 modal type: eye


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP16, shape: (1714, 483)
Merged subject NP16: all features shape=(1714, 483)
Missing file for subject: NP17 modal type: eye


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP17, shape: (1732, 483)
Merged subject NP17: all features shape=(1732, 483)
Missing file for subject: NP18 modal type: eye


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP18, shape: (1710, 483)
Merged subject NP18: all features shape=(1710, 483)
Missing file for subject: NP19 modal type: eye


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP19, shape: (1567, 483)
Merged subject NP19: all features shape=(1567, 483)
Missing file for subject: NP20 modal type: eye


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP20, shape: (1562, 483)
Merged subject NP20: all features shape=(1562, 483)
Missing file for subject: NP21 modal type: eye


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP21, shape: (1343, 483)
Merged subject NP21: all features shape=(1343, 483)
Missing file for subject: NP22 modal type: eye


  total_features['subject_id'] = subject
  total_features['subject_id'] = subject


Success to combine modalities for subject: NP22, shape: (1714, 483)
Merged subject NP22: all features shape=(1714, 483)
Missing file for subject: NP23 modal type: eye
Success to combine modalities for subject: NP23, shape: (1653, 483)
Merged subject NP23: all features shape=(1653, 483)
Missing file for subject: NP24 modal type: eye


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP24, shape: (1745, 483)
Merged subject NP24: all features shape=(1745, 483)
Missing file for subject: NP25 modal type: eye


  total_features['subject_id'] = subject
  total_features['subject_id'] = subject


Success to combine modalities for subject: NP25, shape: (1657, 483)
Merged subject NP25: all features shape=(1657, 483)
Missing file for subject: NP26 modal type: eye
Success to combine modalities for subject: NP26, shape: (1805, 483)
Merged subject NP26: all features shape=(1805, 483)


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP27, shape: (1746, 495)
Merged subject NP27: all features shape=(1746, 495)
Missing file for subject: NP28 modal type: eye


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP28, shape: (1672, 483)
Merged subject NP28: all features shape=(1672, 483)
Missing file for subject: NP29 modal type: eye


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP29, shape: (1663, 483)
Merged subject NP29: all features shape=(1663, 483)
Missing file for subject: NP30 modal type: eye


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP30, shape: (1425, 483)
Merged subject NP30: all features shape=(1425, 483)


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP31, shape: (1747, 495)
Merged subject NP31: all features shape=(1747, 495)


  total_features['subject_id'] = subject


Success to combine modalities for subject: NP32, shape: (1651, 495)
Merged subject NP32: all features shape=(1651, 495)


  total_features['subject_id'] = subject


In [37]:
# Bare last expression: show the combined frame using the notebook's rich repr.
total_data

Unnamed: 0,relative_time,eeg_Fp1_delta_PSD,eeg_Fp1_theta_PSD,eeg_Fp1_alpha_PSD,eeg_Fp1_beta_PSD,eeg_Fpz_delta_PSD,eeg_Fpz_theta_PSD,eeg_Fpz_alpha_PSD,eeg_Fpz_beta_PSD,eeg_Fp2_delta_PSD,...,eye_Saccade_Duration,eye_Saccade_Frequency,eye_Avg_Velocity_Count,eye_Avg_Velocity,eye_Saccade_Amplitude_Count,eye_Avg_Saccade_Amplitude,eye_Avg_Amplitude,eye_Avg_Pupil_Diameter,MWL_Rating,subject_id
0,150.000,119284.967383,3187.071067,209.955324,21.815460,122016.744454,3250.179139,195.643888,24.244840,122585.845195,...,0.00,0.000000,0.0,0.000000,0.0,0.000000,0.000000,0.000000,1,NP03
1,154.000,60004.291996,2103.637583,169.341553,12.376694,60492.269850,2180.295764,172.240027,14.045738,63001.775268,...,1.26,1.610630,33.0,210.169394,7.0,6.532857,6.532857,3.152475,1,NP03
2,158.000,59507.427397,3207.667219,238.766322,14.656701,59458.503479,3335.313952,256.375639,16.397078,61927.499552,...,1.22,1.672665,34.0,226.257647,5.0,5.468000,5.468000,3.166479,1,NP03
3,162.000,97356.887970,5163.942438,363.664856,26.030363,97603.409333,5256.477830,386.907906,28.141836,100218.774038,...,1.24,1.849093,37.0,208.760541,9.0,6.366667,6.366667,3.129368,1,NP03
4,166.000,58192.279293,3087.346277,201.488797,17.094732,58435.380151,3088.646130,213.624687,18.858007,60903.466550,...,1.16,1.797484,37.0,203.598108,12.0,5.961667,5.961667,3.143987,1,NP03
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
49322,13060.618,141.725367,23.649767,5.054168,2.354176,396.197326,27.122484,5.961487,2.594063,1237.124431,...,0.00,0.000000,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0,NP32
49323,13064.618,371.677870,60.445750,8.270844,3.613309,406.689014,75.311848,9.681655,3.702077,845.847998,...,0.00,0.000000,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0,NP32
49324,13068.618,488.161669,63.473177,7.949261,3.592574,523.549268,73.263961,9.168316,3.792235,1174.196092,...,0.00,0.000000,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0,NP32
49325,13072.618,240.336572,25.269240,4.280280,1.867680,270.665617,30.763740,5.699906,2.156052,616.544041,...,0.00,0.000000,0.0,0.000000,0.0,0.000000,0.000000,0.000000,0,NP32
