In [None]:
import os

import torch
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from torch.optim import Adam, AdamW

os.environ["WANDB_API_KEY"] = "KEY"
os.environ["WANDB_MODE"] = 'offline'

import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import tqdm

from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader, Dataset
import random
import csv
from torch import Tensor
import itertools
import math
import re
import numpy as np
import argparse
import pickle


from PIL import Image
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from sklearn.model_selection import train_test_split
import pandas as pd

In [18]:
mouse_month = {
    'mouse6': ['021322', '022522', '031722', '042422', 
              '052422', '062422', '072322', '082322', 
              '092422', '102122', '112022', '122022', 
              '012123', '022223', '032123', '042323'],
    'mouse5': ['030222', '042422', '052322', '062322', '082422', 
              '092222', '102522', '112822', '122322', 
              '012123', '022423', '032323', '042323', '052423', '062323', '072123'],
    'mouse2': ['031722', '042322', '052322', '062422', '072322', '082322', 
              '092222', '102522', '112822', '122022', '012123', '022223'],
    'mouse11': ['021722', '030122', '032322', '042322', '052322', '052422', 
              '062422', '072422', '082422', '092222', '102522', '112822', '122322', '012123', 
              '022423', '032323', '042323', '052423', '062323', '072123']
}

In [4]:
with open("/media/ubuntu/sda/data/paper_architecture/01_closed_loop/mouse6/pkl/cluster_042323.pkl", "rb") as f:
    cluster_inf = pickle.load(f)

with open("/media/ubuntu/sda/data/paper_architecture/01_closed_loop/mouse6/pkl/spike_042323.pkl", 'rb') as f:
    spike_inf = pickle.load(f)


cluster_inf = cluster_inf[cluster_inf['date'] != '012123']
spike_inf = spike_inf[spike_inf['date'] != '012123']

- Trough_index: 最低点的位置
- Trough_amp: 最低点的值
- Peak_index: 最高点的位置
- Peak_amp: 最高点的值
- Duration: Peak_index - Trough_index
- PT_ratio: abs(Peak_amp/Trough_amp)
- Repolarization slope: Trough点恢复的斜率
- Recovery slope: Peak点恢复的斜率

In [12]:
def calculate_waveform_features(waveform):
    """
    计算波形特征
    
    参数:
    waveform: 1D numpy数组，表示波形数据
    
    返回:
    dict: 包含所有波形特征的字典
    """
    # 确保输入是numpy数组
    waveform = np.array(waveform)
    
    # 1. Trough_index: 最低点的位置
    trough_index = np.argmin(waveform)
    
    # 2. Trough_amp: 最低点的值
    trough_amp = waveform[trough_index]
    
    # 3. Peak_index: 最高点的位置
    peak_index = np.argmax(waveform)
    
    # 4. Peak_amp: 最高点的值
    peak_amp = waveform[peak_index]
    
    # 5. Duration: Peak_index - Trough_index
    duration = abs(peak_index - trough_index)
    
    # 6. PT_ratio: abs(Peak_amp/Trough_amp)
    # 避免除零错误
    if trough_amp != 0:
        pt_ratio = abs(peak_amp / trough_amp)
    else:
        pt_ratio = np.inf if peak_amp != 0 else 0
    
    # 7. Repolarization slope: Trough点恢复的斜率
    # 计算trough点之后的斜率（使用trough点后5个点的线性拟合）
    if trough_index < len(waveform) - 5:
        x_repol = np.arange(5)
        y_repol = waveform[trough_index:trough_index+5]
        repol_slope = np.polyfit(x_repol, y_repol, 1)[0]
    else:
        # 如果trough点太靠近末尾，使用剩余的点
        remaining_points = len(waveform) - trough_index - 1
        if remaining_points > 1:
            x_repol = np.arange(remaining_points)
            y_repol = waveform[trough_index+1:]
            repol_slope = np.polyfit(x_repol, y_repol, 1)[0]
        else:
            repol_slope = 0
    
    # 8. Recovery slope: Peak点恢复的斜率
    # 计算peak点之后的斜率（使用peak点后5个点的线性拟合）
    if peak_index < len(waveform) - 5:
        x_recovery = np.arange(5)
        y_recovery = waveform[peak_index:peak_index+5]
        recovery_slope = np.polyfit(x_recovery, y_recovery, 1)[0]
    else:
        # 如果peak点太靠近末尾，使用剩余的点
        remaining_points = len(waveform) - peak_index - 1
        if remaining_points > 1:
            x_recovery = np.arange(remaining_points)
            y_recovery = waveform[peak_index+1:]
            recovery_slope = np.polyfit(x_recovery, y_recovery, 1)[0]
        else:
            recovery_slope = 0
    
    return {
        'Trough_index': trough_index,
        'Trough_amp': trough_amp,
        'Peak_index': peak_index,
        'Peak_amp': peak_amp,
        'Duration': duration,
        'PT_ratio': pt_ratio,
        'Repolarization_slope': repol_slope,
        'Recovery_slope': recovery_slope
    }


In [14]:
# 为所有cluster计算波形特征
print("正在计算所有cluster的波形特征...")

# 初始化特征列
feature_columns = ['Trough_index', 'Trough_amp', 'Peak_index', 'Peak_amp', 
                   'Duration', 'PT_ratio', 'Repolarization_slope', 'Recovery_slope']

for col in feature_columns:
    cluster_inf[col] = np.nan

# 计算每个cluster的波形特征
for idx in tqdm.tqdm(range(len(cluster_inf)), desc="计算波形特征"):
    try:
        waveform = cluster_inf['position_waveform'].iloc[idx]
        features = calculate_waveform_features(waveform)
        
        # 将特征添加到对应的行
        for feature_name, feature_value in features.items():
            cluster_inf.at[idx, feature_name] = feature_value
            
    except Exception as e:
        print(f"处理第{idx}行时出错: {e}")
        continue

print("波形特征计算完成！")
print(f"处理了 {len(cluster_inf)} 个cluster")


正在计算所有cluster的波形特征...


计算波形特征: 100%|██████████| 327/327 [00:00<00:00, 8284.98it/s]

波形特征计算完成！
处理了 347 个cluster





In [20]:
mouse6_dates

NameError: name 'mouse6_dates' is not defined

In [23]:
mouse6_dates = mouse_month['mouse6']
mouse6_dates_filtered = [date for date in mouse6_dates if date != '012123']

In [24]:
def plot_feature_stability(cluster_inf, feature_name, mouse6_dates_filtered, figsize=(15, 8)):
    """
    绘制单个特征在月份之间的稳定性
    
    参数:
    cluster_inf: 包含波形特征的DataFrame
    feature_name: 要绘制的特征名称
    mouse6_dates_filtered: 按时间顺序排列的日期列表
    figsize: 图像大小
    """
    plt.figure(figsize=figsize)
    
    # 获取所有唯一的cluster_id
    unique_clusters = cluster_inf['cluster_id'].unique()
    
    # 为每个cluster绘制折线
    for cluster_id in unique_clusters:
        cluster_data = cluster_inf[cluster_inf['cluster_id'] == cluster_id]
        
        # 按日期排序
        cluster_data = cluster_data.sort_values('date')
        
        # 提取该cluster在时间序列上的特征值
        dates = cluster_data['date'].tolist()
        values = cluster_data[feature_name].tolist()
        
        # 只绘制有数据的点
        valid_indices = [i for i, v in enumerate(values) if not np.isnan(v)]
        if len(valid_indices) > 1:  # 至少需要2个点才能画线
            valid_dates = [dates[i] for i in valid_indices]
            valid_values = [values[i] for i in valid_indices]
            
            # 绘制淡色的个体折线
            plt.plot(valid_dates, valid_values, 'o-', alpha=0.3, linewidth=1, markersize=3)
    
    # 计算每个日期的平均值和标准差
    mean_values = []
    std_values = []
    valid_dates = []
    
    for date in mouse6_dates_filtered:
        date_data = cluster_inf[cluster_inf['date'] == date]
        if len(date_data) > 0:
            feature_values = date_data[feature_name].dropna()
            if len(feature_values) > 0:
                mean_values.append(feature_values.mean())
                std_values.append(feature_values.std())
                valid_dates.append(date)
    
    # 绘制深色的平均值折线图和误差条
    if len(mean_values) > 1:
        plt.plot(valid_dates, mean_values, 'o-', color='darkred', linewidth=3, 
                markersize=8, label='平均值', zorder=10)
        plt.errorbar(valid_dates, mean_values, yerr=std_values, 
                    color='darkred', alpha=0.7, capsize=5, capthick=2, 
                    elinewidth=2, zorder=10)
    
    plt.title(f'{feature_name} 在月份间的稳定性', fontsize=16, fontweight='bold')
    plt.xlabel('日期', fontsize=14)
    plt.ylabel(f'{feature_name}', fontsize=14)
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.show()
    
    # 打印统计信息
    print(f"\n{feature_name} 统计信息:")
    print(f"平均值范围: {min(mean_values):.4f} - {max(mean_values):.4f}")
    print(f"标准差范围: {min(std_values):.4f} - {max(std_values):.4f}")
    print(f"变异系数 (CV): {(np.array(std_values) / np.array(mean_values)).mean():.4f}")


In [None]:
# 为所有波形特征创建稳定性分析图
feature_columns = ['Trough_index', 'Trough_amp', 'Peak_index', 'Peak_amp', 
                   'Duration', 'PT_ratio', 'Repolarization_slope', 'Recovery_slope']

# 使用mouse6的日期顺序（过滤掉'012123'）
mouse6_dates_filtered = [date for date in mouse_month['mouse6'] if date != '012123']

print("开始绘制所有特征的稳定性分析图...")
print("=" * 60)

for i, feature in enumerate(feature_columns):
    print(f"\n正在绘制第 {i+1}/8 个特征: {feature}")
    plot_feature_stability(cluster_inf, feature, mouse6_dates_filtered)
    print("=" * 60)


In [None]:
# 创建综合的稳定性分析图 - 所有特征在一个图中
def plot_all_features_stability(cluster_inf, feature_columns, mouse6_dates_filtered, figsize=(20, 16)):
    """
    在一个大图中绘制所有特征的稳定性分析
    """
    fig, axes = plt.subplots(2, 4, figsize=figsize)
    axes = axes.flatten()
    
    for i, feature_name in enumerate(feature_columns):
        ax = axes[i]
        
        # 获取所有唯一的cluster_id
        unique_clusters = cluster_inf['cluster_id'].unique()
        
        # 为每个cluster绘制折线
        for cluster_id in unique_clusters:
            cluster_data = cluster_inf[cluster_inf['cluster_id'] == cluster_id]
            cluster_data = cluster_data.sort_values('date')
            
            dates = cluster_data['date'].tolist()
            values = cluster_data[feature_name].tolist()
            
            # 只绘制有数据的点
            valid_indices = [j for j, v in enumerate(values) if not np.isnan(v)]
            if len(valid_indices) > 1:
                valid_dates = [dates[j] for j in valid_indices]
                valid_values = [values[j] for j in valid_indices]
                
                # 绘制淡色的个体折线
                ax.plot(valid_dates, valid_values, 'o-', alpha=0.2, linewidth=0.8, markersize=2)
        
        # 计算每个日期的平均值和标准差
        mean_values = []
        std_values = []
        valid_dates = []
        
        for date in mouse6_dates_filtered:
            date_data = cluster_inf[cluster_inf['date'] == date]
            if len(date_data) > 0:
                feature_values = date_data[feature_name].dropna()
                if len(feature_values) > 0:
                    mean_values.append(feature_values.mean())
                    std_values.append(feature_values.std())
                    valid_dates.append(date)
        
        # 绘制深色的平均值折线图和误差条
        if len(mean_values) > 1:
            ax.plot(valid_dates, mean_values, 'o-', color='darkred', linewidth=2.5, 
                   markersize=6, label='平均值', zorder=10)
            ax.errorbar(valid_dates, mean_values, yerr=std_values, 
                       color='darkred', alpha=0.6, capsize=3, capthick=1.5, 
                       elinewidth=1.5, zorder=10)
        
        ax.set_title(f'{feature_name}', fontsize=12, fontweight='bold')
        ax.set_xlabel('日期', fontsize=10)
        ax.set_ylabel(f'{feature_name}', fontsize=10)
        ax.tick_params(axis='x', rotation=45, labelsize=8)
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=8)
    
    plt.suptitle('所有波形特征在月份间的稳定性分析', fontsize=16, fontweight='bold', y=0.98)
    plt.tight_layout()
    plt.subplots_adjust(top=0.93)
    plt.show()

# 绘制综合图
print("绘制所有特征的综合稳定性分析图...")
plot_all_features_stability(cluster_inf, feature_columns, mouse6_dates_filtered)


In [None]:
# 稳定性分析总结
def calculate_stability_summary(cluster_inf, feature_columns, mouse6_dates_filtered):
    """
    计算所有特征的稳定性指标
    """
    stability_data = []
    
    for feature in feature_columns:
        # 计算每个日期的统计量
        date_stats = []
        for date in mouse6_dates_filtered:
            date_data = cluster_inf[cluster_inf['date'] == date]
            if len(date_data) > 0:
                feature_values = date_data[feature].dropna()
                if len(feature_values) > 0:
                    date_stats.append({
                        'date': date,
                        'mean': feature_values.mean(),
                        'std': feature_values.std(),
                        'cv': feature_values.std() / feature_values.mean() if feature_values.mean() != 0 else np.inf
                    })
        
        if len(date_stats) > 1:
            means = [stat['mean'] for stat in date_stats]
            stds = [stat['std'] for stat in date_stats]
            cvs = [stat['cv'] for stat in date_stats if not np.isinf(stat['cv'])]
            
            # 计算稳定性指标
            mean_cv = np.mean(cvs) if len(cvs) > 0 else np.inf
            mean_range = (max(means) - min(means)) / np.mean(means) if np.mean(means) != 0 else np.inf
            std_range = (max(stds) - min(stds)) / np.mean(stds) if np.mean(stds) != 0 else np.inf
            
            stability_data.append({
                'Feature': feature,
                'Mean_CV': mean_cv,
                'Mean_Range_Ratio': mean_range,
                'Std_Range_Ratio': std_range,
                'Num_Dates': len(date_stats)
            })
    
    return pd.DataFrame(stability_data)

# 计算稳定性总结
print("计算稳定性分析总结...")
stability_summary = calculate_stability_summary(cluster_inf, feature_columns, mouse6_dates_filtered)

print("\n稳定性分析总结:")
print("=" * 80)
print(stability_summary.round(4))

# 按稳定性排序（CV越小越稳定）
print("\n按稳定性排序（CV越小越稳定）:")
print("=" * 80)
stability_sorted = stability_summary.sort_values('Mean_CV')
print(stability_sorted.round(4))

# 可视化稳定性排名
plt.figure(figsize=(12, 6))
plt.barh(range(len(stability_sorted)), stability_sorted['Mean_CV'], 
         color='skyblue', alpha=0.7, edgecolor='navy')
plt.yticks(range(len(stability_sorted)), stability_sorted['Feature'])
plt.xlabel('平均变异系数 (CV)')
plt.title('波形特征稳定性排名 (CV越小越稳定)')
plt.grid(True, alpha=0.3, axis='x')
plt.tight_layout()
plt.show()
