In [2]:
#!/usr/bin/env python3
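"""
Standalone script that computes:
 1. Descriptive statistics of the per-ROI 95th-percentile ΔF/F (far, close, single) pooled over PMC sessions
    and over V1 sessions, with a paired Wilcoxon signed-rank test (far vs close) and unpaired
    Mann–Whitney U tests (single vs close, single vs far).
 2. For two V1 session pairs (single vs social-close), without matching ROI names: the 95th-percentile ΔF/F
    values of all ROIs in each session are treated as two independent samples and compared with an
    unpaired Mann–Whitney U test.
 3. For the same two session pairs, an approximate per-ROI firing frequency (Hz), whose distributions are
    also compared with a Mann–Whitney U test.

Usage:
    python3 compute_v1_pairwise_no_roi_matching.py
"""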

import os
import glob
import json
import numpy as np
import pandas as pd
from scipy.stats import wilcoxon, mannwhitneyu

# ----------------------------------------------------------------------------
# 1. 加载排除字典 & HDF5 读取函数
# ----------------------------------------------------------------------------
# Exclusion table, read once at import time. Its keys are matched against
# session paths and its values are used as lists of ROI indices to drop
# (see get_excluded_neurons_for_session).
exclude_json_path = '/home/lq53/mir_repos/BBOP/random_tests/25feb_more_corr_explo/neuro_exclude.json'
# JSON is UTF-8 by specification; pin the encoding instead of relying on the
# locale-dependent default of open().
with open(exclude_json_path, 'r', encoding='utf-8') as f:
    exclude_mapping = json.load(f)

def get_excluded_neurons_for_session(session_path, exclude_dict):
    """
    Look up the ROI indices to exclude for one session.

    An exact key match on ``session_path`` takes precedence. Otherwise the
    first key (in the dict's iteration order) that occurs as a substring of
    ``session_path`` is used. An empty list is returned when ``exclude_dict``
    is not a dict or when nothing matches.
    """
    if not isinstance(exclude_dict, dict):
        return []

    # Sentinel distinguishes "key absent" from a stored falsy value.
    absent = object()
    exact_hit = exclude_dict.get(session_path, absent)
    if exact_hit is not absent:
        return exact_hit

    # Fall back to substring matching; first matching key wins.
    return next(
        (indices for key, indices in exclude_dict.items() if key in session_path),
        [],
    )

def load_session_data(rec_path):
    """
    Load the aligned HDF5 table of one recording session as a DataFrame.

    Searches ``<rec_path>/MIR_Aligned`` for files matching
    ``*aligned_predictions_with_ca_and_dF_F*.h5`` and reads key ``'df'`` from
    the lexicographically first match. If the table is indexed by
    ``'timestamp_ms_mini'`` the index is turned back into a column. A
    ``session_path`` column holding ``rec_path`` is added to every row.

    Raises:
        FileNotFoundError: no matching .h5 file exists under MIR_Aligned.
    """
    h5_dir = os.path.join(rec_path, 'MIR_Aligned')
    pattern = '*aligned_predictions_with_ca_and_dF_F*.h5'
    # glob returns matches in arbitrary order; sort so that the same file is
    # chosen on every run when more than one file matches.
    h5_files = sorted(glob.glob(os.path.join(h5_dir, pattern)))
    if not h5_files:
        raise FileNotFoundError(f"No .h5 found in {h5_dir}")
    df = pd.read_hdf(h5_files[0], key='df')
    if df.index.name == 'timestamp_ms_mini':
        df = df.reset_index()
    df['session_path'] = rec_path
    return df

def build_neuron_matrix_raw(df, exclude_dict):
    """
    Build the raw dF/F matrix of one session after ROI exclusion.

    Steps:
      - If ``df`` is indexed by ``'timestamp_ms_mini'``, reset the index.
      - Collect every column whose name starts with ``'dF_F_roi'``, parse the
        integer ROI index that follows ``roi`` in the last ``_``-separated
        token, and drop ROIs listed for this session in ``exclude_dict``.
        Columns whose suffix is not an integer are skipped.
      - Compute the per-ROI variance and keep only ROIs whose variance is
        strictly above the 5th percentile of those variances.

    Returns:
      kept_names: list[str] of the surviving column names
      raw_matrix: numpy.ndarray with one row per kept ROI and one column per
                  row of ``df``; shape ``(0, len(df))`` when no ROI column
                  survives the exclusion step.
    """
    if df.index.name == 'timestamp_ms_mini':
        df = df.reset_index()
    sess = df['session_path'].iloc[0]
    excluded = get_excluded_neurons_for_session(sess, exclude_dict)

    keep_names = []
    for col in df.columns:
        if not col.startswith('dF_F_roi'):
            continue
        try:
            # 'dF_F_roi12' -> last token 'roi12' -> '12'
            idx = int(col.split('_')[-1][3:])
        except ValueError:
            # Only a non-integer suffix is expected here; anything else
            # (e.g. KeyboardInterrupt) must not be swallowed.
            continue
        if idx not in excluded:
            keep_names.append(col)

    if not keep_names:
        return [], np.zeros((0, len(df)))

    activity = df[keep_names].values.T  # one row per candidate ROI
    # NOTE(review): np.var / np.percentile propagate NaN, so a NaN in any
    # trace makes the cutoff NaN and every ROI fails the comparison below.
    # Confirm the traces are NaN-free or switch to the nan-aware variants.
    variances = np.var(activity, axis=1)
    cutoff = np.percentile(variances, 5)
    # Strict '>' also drops ROIs sitting exactly at the cutoff (a single-ROI
    # session therefore ends up empty).
    keep_mask = variances > cutoff
    kept_names = [name for name, keep in zip(keep_names, keep_mask) if keep]
    filtered = activity[keep_mask, :]
    return kept_names, filtered

# ----------------------------------------------------------------------------
# 2a. 获取 social session 中每个 ROI 在 far/close 片段的 95th 百分位
# ----------------------------------------------------------------------------
def get_social_far_close_stats(rec_path, exclude_dict, threshold=250.0, pct=95):
    """
    Per-ROI percentile of raw dF/F in the "far" and "close" frames of a social session.

    Procedure:
      1. Load the session table and build the raw ROI matrix.
      2. Read ``MIR_Aligned/com_distances_filtered.csv``; its first column is
         taken as the frame index and its second as the distance, and it is
         left-joined onto the session's ``camera_frame_sixcam`` column.
      3. Frames with distance <= ``threshold`` are "close", frames with
         distance > ``threshold`` are "far". Frames without a distance in the
         CSV are filled with ``threshold + 1.0`` and therefore count as "far".
      4. For every ROI, take the ``pct``-th percentile (NaN-ignoring) over the
         far frames and over the close frames separately. A segment with no
         frames yields an all-NaN vector.

    Returns:
      roi_names:  list[str]
      stat_far:   numpy array, one value per kept ROI
      stat_close: numpy array, one value per kept ROI
      (``[]`` and two empty arrays when the ROI matrix is empty)

    Raises:
      FileNotFoundError: the distance CSV is missing.
      KeyError: the CSV has fewer than two columns, or the session table has
                no ``camera_frame_sixcam`` column.
    """
    session_df = load_session_data(rec_path)
    roi_names, traces = build_neuron_matrix_raw(session_df, exclude_dict)
    if traces.size == 0:
        return [], np.array([]), np.array([])

    com_csv_path = os.path.join(rec_path, 'MIR_Aligned', 'com_distances_filtered.csv')
    if not os.path.exists(com_csv_path):
        raise FileNotFoundError(f"Missing {com_csv_path}")

    com_table = pd.read_csv(com_csv_path)
    if com_table.shape[1] < 2:
        raise KeyError("com_distances_filtered.csv 必须包含至少两列：帧索引和距离")

    # Columns are identified by position, not by name.
    frame_col, dist_col = com_table.columns[0], com_table.columns[1]
    com_table[frame_col] = com_table[frame_col].astype(int)
    com_table[dist_col] = com_table[dist_col].astype(float)

    if 'camera_frame_sixcam' not in session_df.columns:
        raise KeyError("HDF5 必须包含 'camera_frame_sixcam' 列以便合并")

    # Left join keeps one distance lookup per row of the session table.
    frames_left = session_df[['camera_frame_sixcam']].reset_index(drop=True)
    dist_right = com_table[[frame_col, dist_col]].rename(
        columns={frame_col: 'camera_frame_sixcam'}
    )
    aligned = frames_left.merge(dist_right, on='camera_frame_sixcam', how='left')
    distances = aligned[dist_col].fillna(threshold + 1.0).values

    def segment_percentile(frame_mask):
        # All-NaN placeholder when the segment contains no frame at all.
        if not np.any(frame_mask):
            return np.full(len(roi_names), np.nan)
        return np.nanpercentile(traces[:, frame_mask], pct, axis=1)

    stat_close = segment_percentile(distances <= threshold)
    stat_far = segment_percentile(distances > threshold)

    return roi_names, stat_far, stat_close

# ----------------------------------------------------------------------------
# 2b. 获取 single session 中每个 ROI 在全时段的 95th 百分位
# ----------------------------------------------------------------------------
def get_single_stats(rec_path, exclude_dict, pct=95):
    """
    Per-ROI percentile of raw dF/F over the whole recording of a single session.

    Loads the session, builds the raw ROI matrix, and takes the ``pct``-th
    percentile (NaN-ignoring) along the frame axis of every ROI.

    Returns:
      roi_names: list[str]
      stat_all:  numpy array, one value per kept ROI
      (``[]`` and an empty array when the ROI matrix is empty)
    """
    session_df = load_session_data(rec_path)
    roi_names, traces = build_neuron_matrix_raw(session_df, exclude_dict)
    if traces.size:
        return roi_names, np.nanpercentile(traces, pct, axis=1)
    return [], np.array([])

# ----------------------------------------------------------------------------
# 2c. 计算每个 ROI 的近似“firing frequency”（Hz）
# ----------------------------------------------------------------------------
def get_firing_frequency(rec_path, exclude_dict, threshold_factor=2.0, frame_rate=30.0):
    """
    Approximate per-ROI "firing frequency" (Hz) of one session.

    For every ROI the threshold is ``median + threshold_factor * std`` of its
    trace (NaN-ignoring). The number of frames above that threshold is divided
    by the recording duration, taken as ``n_frames / frame_rate`` seconds.

    Note that every supra-threshold frame is counted, so consecutive frames of
    one transient each contribute to the count.

    Returns:
      roi_names: list[str]
      freqs:     numpy array, one value per kept ROI (all NaN when the
                 computed duration is not positive)
      (``[]`` and an empty array when the ROI matrix is empty)
    """
    session_df = load_session_data(rec_path)
    roi_names, traces = build_neuron_matrix_raw(session_df, exclude_dict)
    if traces.size == 0:
        return [], np.array([])

    baseline = np.nanmedian(traces, axis=1)
    spread = np.nanstd(traces, axis=1)
    cutoffs = baseline + threshold_factor * spread

    duration_sec = traces.shape[1] / frame_rate
    if duration_sec <= 0:
        return roi_names, np.full(len(roi_names), np.nan)

    # Broadcast one cutoff per ROI row across all frames.
    frames_above = np.sum(traces > cutoffs[:, None], axis=1)
    return roi_names, frames_above / duration_sec

# ----------------------------------------------------------------------------
# 3a. 聚合所有 social session 的 far & close 数据（用于整体统计）
# ----------------------------------------------------------------------------
def aggregate_social_far_close(social_sessions, exclude_dict, threshold=250.0, pct=95):
    """
    Pool the per-ROI far / close percentiles of several social sessions.

    Calls ``get_social_far_close_stats`` for every path in
    ``social_sessions``. A session whose processing raises is reported on
    stdout and skipped. Only ROIs with a non-NaN value in both the far and
    the close vector are kept, so the two outputs stay index-aligned.

    Returns:
      (far_values, close_values) as two numpy arrays of equal length.
    """
    far_values, close_values = [], []

    for rec_path in social_sessions:
        try:
            _, stat_far, stat_close = get_social_far_close_stats(rec_path, exclude_dict, threshold, pct)
        except Exception as e:
            # Best-effort aggregation: report and move on to the next session.
            print(f"  Skipping {rec_path}: {e}")
        else:
            both_valid = ~(np.isnan(stat_far) | np.isnan(stat_close))
            far_values += stat_far[both_valid].tolist()
            close_values += stat_close[both_valid].tolist()

    return np.array(far_values), np.array(close_values)

# ----------------------------------------------------------------------------
# 3b. 聚合所有 single session 的全部 95th 数据（用于整体统计）
# ----------------------------------------------------------------------------
def aggregate_single_stats(single_sessions, exclude_dict, pct=95):
    """
    Pool the per-ROI whole-session percentiles of several single sessions.

    Calls ``get_single_stats`` for every path in ``single_sessions``. A
    session whose processing raises is reported on stdout and skipped. NaN
    values are dropped before pooling.

    Returns:
      numpy array with one value per pooled ROI.
    """
    pooled = []

    for rec_path in single_sessions:
        try:
            _, stat_all = get_single_stats(rec_path, exclude_dict, pct)
        except Exception as e:
            # Best-effort aggregation: report and move on to the next session.
            print(f"  Skipping {rec_path}: {e}")
        else:
            pooled += stat_all[~np.isnan(stat_all)].tolist()

    return np.array(pooled)

# ----------------------------------------------------------------------------
# 4. 针对指定的两对 V1 会话，分别比较 single vs social-close（不匹配 ROI 名称）
# ----------------------------------------------------------------------------
def _drop_nan(values):
    """Return ``values`` as a float array with NaN entries removed."""
    arr = np.asarray(values, dtype=float)
    return arr[~np.isnan(arr)]


def _format_distribution(values):
    """Format n / mean / std / median of a NaN-free 1-D array for printing."""
    return (
        f"n={values.size}, mean={np.mean(values):.3f}, "
        f"std={np.std(values):.3f}, median={np.median(values):.3f}"
    )


def compute_session_pair_unpaired(pairs, exclude_dict, pct=95, threshold=250.0):
    """
    Compare a single session against a social session, pair by pair, without ROI matching.

    For every ``(single_path, social_path)`` in ``pairs``:
      1. Get the per-ROI ``pct``-th percentile dF/F of the single session.
      2. Get the per-ROI ``pct``-th percentile dF/F of the social session's
         close frames (distance <= ``threshold``).
      3. Treat the two sets as independent samples and run a two-sided
         Mann–Whitney U test.
      4. Compute the per-ROI firing frequency of both sessions (whole
         recording in each case) and compare them with the same test.
      5. Print n, mean, std, median and the test results.

    NaN values are removed before both the tests and the descriptive
    statistics, so the printed ``n`` is the sample size actually tested. A
    pair is skipped (with a message) when either dF/F sample is empty after
    that filtering; the frequency comparison is skipped on the same condition.

    Args:
      pairs: iterable of ``(single_path, social_path)`` tuples.
      exclude_dict: exclusion mapping forwarded to the loaders.
      pct: percentile used for the dF/F statistic.
      threshold: distance separating close from far frames in the social
                 session; the default keeps the previously hard-coded 250.0.

    Exceptions raised while loading a session are not caught here.
    """
    for single_path, social_path in pairs:
        # dF/F percentile per ROI: whole single session vs close frames of the social session.
        _, stat_single = get_single_stats(single_path, exclude_dict, pct)
        _, _, stat_close = get_social_far_close_stats(
            social_path, exclude_dict, threshold=threshold, pct=pct
        )
        # stat_close is all-NaN when the social session has no close frame;
        # NaNs must not reach mannwhitneyu.
        stat_single = _drop_nan(stat_single)
        stat_close = _drop_nan(stat_close)

        # Skip the pair when either sample is empty.
        if stat_single.size == 0 or stat_close.size == 0:
            print(f"\nPair:\n  Single: {single_path}\n  Social: {social_path}")
            print("  单侧或社交-close 数据为空，跳过此对。")
            continue

        u_stat_df, p_value_df = mannwhitneyu(stat_single, stat_close, alternative='two-sided')

        # Firing frequency per ROI over the whole recording of each session.
        _, freq_single = get_firing_frequency(single_path, exclude_dict)
        _, freq_social = get_firing_frequency(social_path, exclude_dict)
        freq_single = _drop_nan(freq_single)
        freq_social = _drop_nan(freq_social)
        freq_test_pass = freq_single.size > 0 and freq_social.size > 0
        if freq_test_pass:
            u_stat_freq, p_value_freq = mannwhitneyu(freq_single, freq_social, alternative='two-sided')

        print("\n===== 会话对 比较（不匹配 ROI 名）=====")
        print(f"Single 会话: {single_path}")
        print(f"Social-close 会话: {social_path}")

        print("\n--- 95th-% ΔF/F 分布（Single vs Social-close）---")
        print(f"Single:       {_format_distribution(stat_single)}")
        print(f"Social-close: {_format_distribution(stat_close)}")
        print(f"Mann–Whitney U (ΔF/F): U={u_stat_df:.3f}, p-value={p_value_df:.5f}")

        if freq_test_pass:
            print("\n--- firing frequency (Hz) 分布（Single vs Social）---")
            print(f"Single freq:       {_format_distribution(freq_single)}")
            print(f"Social freq:       {_format_distribution(freq_social)}")
            print(f"Mann–Whitney U (freq): U={u_stat_freq:.3f}, p-value={p_value_freq:.5f}")
        else:
            print("\n无法比较 firing frequency：某一侧频率数据为空。")

# ----------------------------------------------------------------------------
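# 5. Descriptive statistics for the pooled V1 / PMC sessions, with a paired test (Wilcoxon, far vs close) and unpaired tests (Mann–Whitney U, single vs close / far)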
# ----------------------------------------------------------------------------
def compute_and_print_stats(label, all_far, all_close, all_single):
    """
    Print descriptive statistics and three significance tests for pooled ROI values.

    Printed, in order:
      - n (non-NaN count), mean, std and median (all NaN-ignoring) of the
        far, close and single samples;
      - a paired Wilcoxon signed-rank test of far vs close, run only when the
        two inputs have the same non-zero length;
      - a two-sided Mann–Whitney U test of single vs close;
      - a two-sided Mann–Whitney U test of single vs far.
    Each Mann–Whitney test is replaced by a notice when either of its inputs
    is empty. Nothing is returned.
    """
    def summarize(values):
        # Same four figures for every group, already formatted for printing.
        n_valid = np.count_nonzero(~np.isnan(values))
        return (
            f"n={n_valid}, mean={np.nanmean(values):.3f}, "
            f"std={np.nanstd(values):.3f}, median={np.nanmedian(values):.3f}"
        )

    def report_mann_whitney(title, other):
        # Unpaired comparison of the single-session sample against `other`.
        if len(all_single) > 0 and len(other) > 0:
            u_stat, u_p = mannwhitneyu(all_single, other, alternative='two-sided')
            print(f"Mann-Whitney U ({title}): U={u_stat:.3f}, p-value={u_p:.5f}")
        else:
            print(f"Mann-Whitney U ({title}): 数据不足，无法进行检验。")

    # Compute all summaries first so nothing is printed if one of them fails.
    far_line = summarize(all_far)
    close_line = summarize(all_close)
    single_line = summarize(all_single)

    print(f"\n===== {label} =====")
    print(f"Social-Far:    {far_line}")
    print(f"Social-Close:  {close_line}")
    print(f"Single:        {single_line}")

    # Far and close are index-aligned per ROI, hence the paired test.
    can_pair = len(all_far) == len(all_close) and len(all_far) > 0
    if not can_pair:
        print("\nPaired Wilcoxon (Far vs Close): 数据长度不匹配或无有效数据，无法进行配对检验。")
    else:
        w_stat, w_p = wilcoxon(all_far, all_close)
        print(f"\nPaired Wilcoxon (Far vs Close): statistic={w_stat:.3f}, p-value={w_p:.5f}")

    report_mann_whitney("Single vs Close", all_close)
    report_mann_whitney("Single vs Far", all_far)

# ----------------------------------------------------------------------------
# 6. Example session paths (edit to match your own data layout)
# ----------------------------------------------------------------------------
# Each entry is a recording directory; the loaders look for the aligned .h5
# file (and, for social sessions, com_distances_filtered.csv) under its
# MIR_Aligned sub-directory.

# PMC recordings analysed as social sessions (far / close split).
pmc_social_sessions = [
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2025_02_27/20241015PMCBE1mini_p20241015PMCRE1_12_33",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2025_05_16/20240303PMCBE0r1coatedmini_p20240303RE1",
]

# PMC recordings analysed as single sessions (whole recording).
pmc_single_sessions = [
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_10_24/20241001PMCr2_16_19",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_10_25/20241002PMCr2_16_25",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_11_06/20241015pmcr2_16_53",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_11_06/20241015pmcr2_17_13",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2025_02_12/20241001PMCRE2mini_13_44",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2025_02_12/20241001PMCRE2mini_13_57",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2025_02_12/20241001PMCRE2mini_15_35",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2025_02_13/20241225PMCLE1mini_11_06",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2025_02_27/20241015PMCBE1mini_12_24"
]

# V1 recordings analysed as social sessions (far / close split).
# NOTE(review): the third entry's directory name contains "_single_" even
# though it is listed here as social — confirm it belongs in this list.
v1_social_sessions = [
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_12_18/20240919v1l5r1mini_p20240717PMC_social_test_11_30",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_12_31/20240919v1l5r2mini_p20240717PMC_social_14_04",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_10_31/2social_mini_20240819V1r1_single_11_29",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_10_31/2social_mini_20240819V1r1_femalebleach_11_48",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2025_05_16/20241216V1RE1Fmini_p20241216RE2",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2025_05_16/20241216V1RE1Fmini_p20241224PMCLE1",
]

# V1 recordings analysed as single sessions (whole recording).
v1_single_sessions = [
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_10_14/20240916v1r1_16_37",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_10_14/20240916v1r1_16_53",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_10_14/20240916v1r2_14_30",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_10_14/20240916v1r2_15_58",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_10_17/20240819V1r1_13_41",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_10_17/20240819V1r1_14_25",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_12_18/20240919v1l5r1mini_11_21",
    "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_12_31/20240919v1l5r2mini_13_54",
]

# ----------------------------------------------------------------------------
# 7. The two designated V1 session pairs (single, social-close); no ROI matching
# ----------------------------------------------------------------------------
# Each tuple is (single_path, social_path). Both paths of each pair also
# appear in v1_single_sessions / v1_social_sessions respectively.
v1_paired_sessions = [
    (
        "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_12_18/20240919v1l5r1mini_11_21",
        "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_12_18/20240919v1l5r1mini_p20240717PMC_social_test_11_30"
    ),
    (
        "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_12_31/20240919v1l5r2mini_13_54",
        "/data/big_rim/rsync_dcc_sum/Oct3V1/2024_12_31/20240919v1l5r2mini_p20240717PMC_social_14_04"
    ),
]

# ----------------------------------------------------------------------------
# 8. 主流程
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    # Script entry point. Only the pairwise V1 comparison is active; the
    # pooled PMC / V1 analyses below are kept commented out.
    # NOTE: `threshold` is only consumed by those commented-out sections; the
    # call to compute_session_pair_unpaired does not receive it.
    threshold = 250.0  # COM distance threshold
    pct = 95           # use the 95th percentile

    # # --- Pooled PMC statistics ---
    # print("=== PMC 整体统计 ===")
    # pmc_far, pmc_close = aggregate_social_far_close(pmc_social_sessions, exclude_mapping, threshold, pct)
    # pmc_single        = aggregate_single_stats(pmc_single_sessions, exclude_mapping, pct)
    # compute_and_print_stats("PMC (Social vs Single)", pmc_far, pmc_close, pmc_single)

    # # --- Pooled V1 statistics ---
    # print("\n=== V1 整体统计 ===")
    # v1_far, v1_close  = aggregate_social_far_close(v1_social_sessions, exclude_mapping, threshold, pct)
    # v1_single         = aggregate_single_stats(v1_single_sessions, exclude_mapping, pct)
    # compute_and_print_stats("V1 (Social vs Single)", v1_far, v1_close, v1_single)

    # --- Designated V1 session pairs (no ROI matching) ---
    print("\n=== V1 指定会话对比较（不匹配 ROI 名称） ===")
    compute_session_pair_unpaired(v1_paired_sessions, exclude_mapping, pct)

    print("\n所有统计计算完毕。")



=== V1 指定会话对比较（不匹配 ROI 名称） ===

===== 会话对 比较（不匹配 ROI 名）=====
Single 会话: /data/big_rim/rsync_dcc_sum/Oct3V1/2024_12_18/20240919v1l5r1mini_11_21
Social-close 会话: /data/big_rim/rsync_dcc_sum/Oct3V1/2024_12_18/20240919v1l5r1mini_p20240717PMC_social_test_11_30

--- 95th-% ΔF/F 分布（Single vs Social-close）---
Single:       n=31, mean=2.968, std=0.813, median=2.849
Social-close: n=42, mean=2.688, std=1.737, median=2.015
Mann–Whitney U (ΔF/F): U=873.000, p-value=0.01344

--- firing frequency (Hz) 分布（Single vs Social）---
Single freq:       n=31, mean=1.833, std=0.290, median=1.896
Social freq:       n=42, mean=1.689, std=0.456, median=1.686
Mann–Whitney U (freq): U=822.000, p-value=0.05706

===== 会话对 比较（不匹配 ROI 名）=====
Single 会话: /data/big_rim/rsync_dcc_sum/Oct3V1/2024_12_31/20240919v1l5r2mini_13_54
Social-close 会话: /data/big_rim/rsync_dcc_sum/Oct3V1/2024_12_31/20240919v1l5r2mini_p20240717PMC_social_14_04

--- 95th-% ΔF/F 分布（Single vs Social-close）---
Single:       n=15, mean=3.917, std=1.588, m