# In [None]:
import os
import gc
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt



# Fraction of the cleaned training set that is kept for training.
# Set to e.g. 0.1 to train on 10% of it; any value outside (0, 1) disables
# the subsampling step further down.
TRAIN_FRACTION = 1

# The bare string below is a no-op expression kept from the notebook: it holds
# the shell command for installing the parquet dependency.
'''
# 如果还没装依赖
!pip -q install pyarrow
'''

# =========================================================
# 0) Load the data
# =========================================================
# Earlier dataset variants, kept for reference:
#parquet_path = "/content/drive/MyDrive/CA-AA/Cleandata_Lable20_N157794_ML30_MU3000.parquet" # 20251201
#parquet_path = "/content/drive/MyDrive/CA-AA/Cleandata5_Lable20_N54820_3000_ML30_MU3000.parquet"  # 20251207-1
#parquet_path = "/content/drive/MyDrive/CA-AA/Cleandata5_Lable20_N70848_4000_ML30_MU3000.parquet"  # 20251207-2
parquet_path = "/content/drive/MyDrive/CA-AA/Cleandata5_Lable20_N83597_5000_ML30_MU3000.parquet"   # 20251207-3

data = pd.read_parquet(parquet_path)
print("读取完成:", data.shape)

# =========================================================
# 1) Valid length per sample (used for the per-label filtering below)
# =========================================================
label_col = data.columns[0]      # first column is treated as the label
feature_cols = data.columns[3:]  # columns from the fourth onwards are treated as the time series

data = data.copy()
# Number of non-zero entries per row. NOTE(review): a NaN compares as != 0 and
# is therefore counted as valid — confirm the data contains no NaNs.
data["valid_length"] = (data[feature_cols] != 0).sum(axis=1)

def remove_outliers_by_label(df):
    """Keep only the rows of *df* whose ``valid_length`` is in the central band.

    The band spans the 35th to the 65th percentile of the ``valid_length``
    column of *df* (both ends inclusive); rows outside it are dropped.
    """
    lengths = df["valid_length"]
    lower_bound = lengths.quantile(0.35)
    upper_bound = lengths.quantile(0.65)
    in_band = lengths.between(lower_bound, upper_bound)
    return df[in_band]

# Apply the valid-length filter separately inside every label group, so the
# percentile band is computed per label rather than over the whole data set.
cleaned_data = (
    data
    .groupby(label_col, group_keys=False)
    .apply(remove_outliers_by_label)
)

# The helper column is no longer needed once the rows are filtered.
cleaned_data = cleaned_data.drop(columns=["valid_length"])
print(f"清洗前样本数: {len(data)}")
print(f"清洗后样本数: {len(cleaned_data)}")

data = cleaned_data

# =========================================================
# 2) Drop the samples of selected labels
# =========================================================
remove_labels = [8, 12, 14]
filtered_data = data[~data.iloc[:, 0].isin(remove_labels)].copy()

print("原始数据 shape:", data.shape)
print("删除后的数据 shape:", filtered_data.shape)
data = filtered_data

# =========================================================
# 3) Split into train / val / test
# =========================================================
labels = data.iloc[:, 0].values
raw = data.iloc[:, 3:].values         # numpy array [N, L]

# Release the DataFrame early to lower peak memory.
del data
gc.collect()

# First split: 80% train, 20% temporary hold-out (stratified by label).
raw_train, raw_tmp, y_train, y_tmp = train_test_split(
    raw, labels,
    test_size=0.2,
    random_state=42,
    stratify=labels
)
# Second split: the hold-out is halved into validation and test.
raw_val, raw_test, y_val, y_test = train_test_split(
    raw_tmp, y_tmp,
    test_size=0.5,
    random_state=42,
    stratify=y_tmp
)

del raw, raw_tmp, y_tmp
gc.collect()

print(f"训练集样本数（拆分后）：{len(raw_train)}")

# =========================================================
# 4) First trim (training set only): per label, drop the x fraction of
#    samples with the smallest and the largest non-zero length
# =========================================================
x = 0.0  # fraction trimmed at each end; 0 or 0.005 were used before, tune as needed
keep_idx = []
for lbl in np.unique(y_train):
    idxs = np.where(y_train == lbl)[0]
    nz = (raw_train[idxs] != 0).sum(axis=1)   # non-zero count per sample of this label
    order = idxs[np.argsort(nz)]              # sample indices sorted by that count
    n = len(order)
    low, high = int(n * x), int(n * (1 - x))
    keep_idx.extend(order[low:high])

# Sorting restores the original sample order after the per-label selection.
keep_idx = np.sort(keep_idx)
raw_train = raw_train[keep_idx]
y_train   = y_train[keep_idx]
gc.collect()

# =========================================================
# 5) Second trim (training set only): per label, drop the x fraction of
#    samples with the smallest and the largest amplitude score
# =========================================================
x = 0.0  # same meaning as in step 4

def calc_max_drop_amp(arr):
    """Return the absolute mean of the samples around the maximum of *arr*.

    NaN and zero entries are discarded first. The window covers up to 10
    remaining samples before the maximum and up to 10 after it (clipped at
    the array ends). Returns ``0.0`` when nothing remains after filtering.
    """
    usable = np.logical_and(~np.isnan(arr), arr != 0)
    samples = arr[usable]
    if samples.size == 0:
        return 0.0

    peak = np.argmax(samples)
    start = max(0, peak - 10)
    stop = min(samples.size, peak + 11)
    window_mean = samples[start:stop].mean()
    return abs(window_mean)

# Amplitude score of every training sample (one value per row).
amps = np.apply_along_axis(calc_max_drop_amp, 1, raw_train)
keep2 = []
for lbl in np.unique(y_train):
    idxs = np.where(y_train == lbl)[0]
    order = idxs[np.argsort(amps[idxs])]      # this label's samples sorted by score
    n = len(order)
    low, high = int(n * x), int(n * (1 - x))
    keep2.extend(order[low:high])

# Sorting restores the original sample order after the per-label selection.
keep2 = np.sort(keep2)
raw_train = raw_train[keep2]
y_train   = y_train[keep2]

del amps, keep_idx, keep2
gc.collect()

print(f"训练集样本数（清洗后，尚未下采样）：{len(raw_train)}")

# =========================================================
# 5.1) Training-set subsampling, controlled by TRAIN_FRACTION
# =========================================================

if 0.0 < TRAIN_FRACTION < 1.0:
    rng = np.random.RandomState(42)
    selected_idx = []

    # Stratified subsampling: draw the same fraction from every label
    # (at least one sample each) to roughly keep the class distribution.
    for lbl in np.unique(y_train):
        idxs = np.where(y_train == lbl)[0]
        n_lbl = len(idxs)
        n_keep_lbl = max(1, int(n_lbl * TRAIN_FRACTION))
        chosen = rng.choice(idxs, size=n_keep_lbl, replace=False)
        selected_idx.extend(chosen)

    # Unlike the trims above, the indices are not re-sorted here, so the kept
    # samples end up grouped by label.
    selected_idx = np.array(selected_idx)
    raw_train = raw_train[selected_idx]
    y_train   = y_train[selected_idx]

print(f"训练集样本数（下采样后，比例={TRAIN_FRACTION:.2f}）：{len(raw_train)}")

# =========================================================
# 6) Standardization: fit on the training set only, then transform all
#    three subsets with the same statistics
# =========================================================
scaler = StandardScaler()
scaler.fit(raw_train)                 # training set only, to avoid leakage
X_train_np = scaler.transform(raw_train)
X_val_np   = scaler.transform(raw_val)
X_test_np  = scaler.transform(raw_test)

del raw_train, raw_val, raw_test
gc.collect()

# Convert to torch tensors (float32 features, int64 labels).
X_train = torch.tensor(X_train_np, dtype=torch.float32)
X_val   = torch.tensor(X_val_np,   dtype=torch.float32)
X_test  = torch.tensor(X_test_np,  dtype=torch.float32)
y_train = torch.tensor(y_train,    dtype=torch.long)
y_val   = torch.tensor(y_val,      dtype=torch.long)
y_test  = torch.tensor(y_test,     dtype=torch.long)

del X_train_np, X_val_np, X_test_np
gc.collect()



# =========================================================
# 7) Channel construction: raw signal plus sliding-window std channels
# =========================================================
use_sliding_std = 1     # 1: add one sliding-std channel per entry of `windows`; 0: raw signal only
windows = [10,20,40]      # sliding-window sizes; each one adds a channel

def sliding_std(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """Population standard deviation over a stride-1 sliding window.

    Args:
        x: 2-D tensor; the windows slide along dimension 1.
        window_size: number of consecutive samples per window.

    Returns:
        A tensor with the same shape as ``x``. The ``L - window_size + 1``
        window results are zero-padded on both sides (the extra element goes
        to the right when ``window_size`` is even) to restore the length.
    """
    left = (window_size - 1) // 2
    right = (window_size - 1) - left

    windowed = x.unfold(1, window_size, 1)              # (N, L-ws+1, ws)
    local_std = windowed.std(dim=2, unbiased=False)     # (N, L-ws+1)
    return F.pad(local_std, (left, right), value=0.0)   # (N, L)

def _stack_std_channels(signal: torch.Tensor) -> torch.Tensor:
    """Stack ``signal`` with one sliding-std channel per entry of ``windows``.

    Channel 0 is the unchanged signal; channel ``i + 1`` is
    ``sliding_std(signal, windows[i])``. The result has shape
    ``[N, 1 + len(windows), L]``. An empty ``windows`` list yields a single
    channel instead of failing.
    """
    channels = [signal]
    channels.extend(sliding_std(signal, ws) for ws in windows)
    return torch.stack(channels, dim=1)


if use_sliding_std:
    # Building the stack inside a helper lets the per-window temporaries be
    # released as soon as each split is done. The original module-level `del`
    # raised NameError when `windows` was empty.
    X_train = _stack_std_channels(X_train)
    X_val   = _stack_std_channels(X_val)
    X_test  = _stack_std_channels(X_test)
    gc.collect()

else:
    # Single-channel input: just add the channel axis.
    X_train = X_train.unsqueeze(1)
    X_val   = X_val.unsqueeze(1)
    X_test  = X_test.unsqueeze(1)

print(f"训练集样本数：{X_train.size(0)}, 通道数：{X_train.size(1)}, 长度：{X_train.size(2)}")





################################################################################################################################################训练
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# ====== Attention After TCN (unchanged) ======
class AttentionAfterTCN(nn.Module):
    """Windowed self-attention head applied to a ``(B, C, L)`` feature map.

    The sequence is cut into windows of ``window_size`` samples taken every
    ``stride`` samples. Each window (all channels flattened) is projected to
    ``hidden_dim``, the windows attend to each other, and the result is
    mean-pooled into one ``hidden_dim`` vector per sample.

    Args:
        channel_in: number of channels of the incoming feature map.
        window_size: samples per window.
        stride: hop between consecutive window starts.
        hidden_dim: token / attention embedding size.
        num_heads: number of attention heads.
        dropout: dropout rate, used both inside the attention module and on
            the pooled output.
    """

    def __init__(self, channel_in, window_size=100, stride=100,
                 hidden_dim=32, num_heads=2, dropout=0):
        super().__init__()
        self.window_size = window_size
        self.stride = stride
        flat_window = channel_in * window_size
        self.linear_proj = nn.Linear(flat_window, hidden_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` of shape ``(B, C, L)`` to pooled features ``(B, hidden_dim)``."""
        batch, channels, length = x.shape
        n_windows = (length - self.window_size) // self.stride + 1

        # Cut into windows, then flatten channels x window into one token each.
        windows = x.unfold(2, self.window_size, self.stride)            # (B, C, N, W)
        tokens = windows.permute(0, 2, 1, 3).reshape(
            batch, n_windows, channels * self.window_size)              # (B, N, C*W)

        # Token embedding.
        projected = F.relu(self.linear_proj(tokens))                    # (B, N, hidden_dim)

        # Self-attention across windows, with a residual connection.
        attended, _ = self.attn(projected, projected, projected)        # (B, N, hidden_dim)
        if attended.shape == projected.shape:
            attended = attended + projected

        # Average over windows, then dropout.
        pooled = attended.mean(dim=1)                                   # (B, hidden_dim)
        return self.dropout(pooled)

# ====== TCNWithAttention using your TCN backbone ======
class TCNWithAttention(nn.Module):
    """Dilated Conv1d stack followed by windowed self-attention and a classifier.

    Every backbone layer is ``Conv1d -> ReLU -> MaxPool1d(2)`` with dilation
    ``2 ** i``. The convolution is padded by ``(kernel_size - 1) * dilation``
    on both sides (symmetric, not causal), and the pooling then halves the
    length.

    Args:
        input_channels: channels of the input signal.
        output_channels: channels produced by every backbone layer.
        kernel_size: convolution kernel size.
        num_layers: number of backbone layers.
        num_classes: size of the logit vector.
        threshold: minimum softmax confidence for ``predict`` to accept a
            prediction; stored as ``self.threshold``.
        window_size: window length used by the attention head.
        stride: window hop used by the attention head.
        attn_dim: embedding size of the attention head.
        num_heads: number of attention heads.
        dropout: dropout rate passed to the attention head.
    """

    def __init__(self,
                 input_channels=1,
                 output_channels=128,
                 kernel_size=3,
                 num_layers=8,
                 num_classes=20,
                 threshold=0.9,
                 window_size=50,
                 stride=50,
                 attn_dim=512,
                 num_heads=8,
                 dropout=0.8):
        super().__init__()
        # Backbone: conv -> ReLU -> pool in every layer.
        self.tcn_layers = nn.ModuleList()
        in_ch = input_channels
        for i in range(num_layers):
            dilation = 2 ** i
            padding = (kernel_size - 1) * dilation
            block = nn.Sequential(
                nn.Conv1d(in_ch, output_channels,
                          kernel_size=kernel_size,
                          padding=padding,
                          dilation=dilation),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=2)
            )
            self.tcn_layers.append(block)
            in_ch = output_channels

        # Attention head and classifier.
        self.attn_block = AttentionAfterTCN(output_channels,
                                            window_size, stride,
                                            attn_dim, num_heads, dropout)
        self.fc = nn.Linear(attn_dim, num_classes)
        self.threshold = threshold

    def forward(self, x):
        """Return raw logits ``(B, num_classes)`` for an input of shape ``(B, C, L)``.

        Raises:
            RuntimeError: if the backbone output is shorter than one
                attention window.
        """
        for block in self.tcn_layers:
            x = block(x)   # conv -> relu -> pool

        window = self.attn_block.window_size
        stride = self.attn_block.stride
        length = x.size(2)
        if length < window:
            # Same exception type that the unfold inside the attention head
            # would raise, but with an actionable message.
            raise RuntimeError(
                f"Feature length {length} after the TCN backbone is shorter "
                f"than the attention window ({window})."
            )

        # Trim to the span covered by whole windows. The last window starts at
        # (n_windows - 1) * stride, so the covered span ends `window` samples
        # after that. The previous `n_windows * stride + window` overshot by
        # one stride, so the slice never removed anything.
        n_windows = (length - window) // stride + 1
        needed = (n_windows - 1) * stride + window
        x = x[:, :, :needed]

        out = self.attn_block(x)        # (B, attn_dim)
        logits = self.fc(out)           # (B, num_classes)
        return logits                   # raw logits, no softmax

    def predict(self, x):
        """Return ``(pred, conf)`` with low-confidence predictions rejected.

        ``conf`` is the maximum softmax probability per sample and ``pred``
        the matching class index, replaced by ``-1`` wherever
        ``conf < self.threshold``. This method neither switches to eval mode
        nor disables gradients; callers handle both.
        """
        logits = self.forward(x)
        probs = F.softmax(logits, dim=1)
        conf, pred = torch.max(probs, dim=1)
        pred[conf < self.threshold] = -1
        return pred, conf

# ====== Dynamic Class Weight Tracker (unchanged) ======
class DynamicClassWeightTracker:
    """Keeps an exponential moving average of per-class accuracy and turns it
    into per-class loss weights (lower accuracy -> larger weight).

    Args:
        num_classes: number of classes tracked (ids ``0 .. num_classes - 1``).
        device: device holding the running accuracies.
        alpha: smoothing factor; share of the previous running value kept on
            every update.
        epsilon: added to the accuracy before inversion to avoid division by
            zero.
        weight_min: lower clamp applied to the returned weights.
        weight_max: upper clamp applied to the returned weights.
    """

    def __init__(self, num_classes, device,
                 alpha=0.9, epsilon=1e-6,
                 weight_min=0.5, weight_max=1000.0):
        self.num_classes = num_classes
        self.device = device
        self.alpha = alpha
        self.epsilon = epsilon
        self.weight_min = weight_min
        self.weight_max = weight_max
        # Every class starts from a running accuracy of 1.
        self.smoothed_accuracies = torch.ones(
            num_classes, dtype=torch.float32, device=device)

    def update(self, preds, targets):
        """Fold one round of predictions into the running accuracies.

        A class with no samples in ``targets`` is scored as accuracy 1.0 for
        this round. Returns the clamped inverse-accuracy weights, one per
        class.
        """
        hits = preds == targets
        for cls_idx in range(self.num_classes):
            in_class = targets == cls_idx
            n_seen = in_class.sum().item()
            n_hit = hits[in_class].sum().item()
            if n_seen > 0:
                cls_acc = n_hit / n_seen
            else:
                cls_acc = 1.0
            previous = self.smoothed_accuracies[cls_idx]
            self.smoothed_accuracies[cls_idx] = (
                self.alpha * previous + (1 - self.alpha) * cls_acc
            )
        inverse_acc = 1.0 / (self.smoothed_accuracies + self.epsilon)
        return torch.clamp(inverse_acc, self.weight_min, self.weight_max)

# ====== Training Preparation ======
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The number of input channels follows the prepared tensor (raw signal plus
# any sliding-std channels).
model = TCNWithAttention(
    input_channels=X_train.size(1),
    output_channels=64,
    kernel_size=3,
    num_layers=5,
    num_classes=20,
    threshold=0.9,
    window_size=5,
    stride=5,
    attn_dim=128,
    num_heads=8,
    dropout=0.7
).to(device)

# Only the training loader is shuffled.
train_loader = DataLoader(TensorDataset(X_train, y_train),
                          batch_size=64, shuffle=True)
val_loader   = DataLoader(TensorDataset(X_val,   y_val),
                          batch_size=64)

optimizer = optim.AdamW(model.parameters(), lr=1e-4)
# The scheduler is stepped once per batch in the training loop, so its length
# is steps_per_epoch * epochs; `epochs` must match the loop's epoch count.
# NOTE(review): per the OneCycleLR docs the schedule sets the learning rate
# itself (starting from max_lr / div_factor), so lr=1e-4 above is presumably
# overridden — confirm.
scheduler = OneCycleLR(
    optimizer,
    max_lr=1e-3,
    steps_per_epoch=len(train_loader),
    epochs=40,
    pct_start=0.1,
    anneal_strategy='cos',
    final_div_factor=1e4
)
# Starts unweighted; the training loop replaces it with a class-weighted loss.
criterion = nn.CrossEntropyLoss()
tracker = DynamicClassWeightTracker(num_classes=20, device=device)

# Per-epoch history for the curves plotted after training.
train_accs, val_accs_max, val_accs_pred, coverages = [], [], [], []

# ====== Training Loop ======
# The epoch count must match `epochs` given to OneCycleLR above.
for epoch in range(40):
    # ---- train ----
    model.train()
    correct = total = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss   = criterion(logits, yb)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()   # per-batch scheduler step

        # Training accuracy is accumulated from the same forward pass
        # (dropout active).
        preds = logits.argmax(dim=1)
        correct += (preds == yb).sum().item()
        total += yb.size(0)
    train_accs.append(100 * correct / total)

    # ---- validate ----
    model.eval()
    cm_max = cm_pred = total_max = total_pred = 0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            # Val Acc(max): plain argmax accuracy over all samples.
            _, pm = torch.max(logits, dim=1)
            cm_max += (pm == yb).sum().item()
            total_max += yb.size(0)
            # Val Acc(pred): accuracy with rejection (runs a second forward pass).
            pp, conf = model.predict(xb)
            mask = (pp != -1)   # accepted samples are those whose prediction is not -1
            if mask.any():
                cm_pred   += (pp[mask] == yb[mask]).sum().item()   # correct among accepted samples
                total_pred += mask.sum().item()                    # denominator = number of accepted samples

    val_accs_max.append(100 * cm_max / total_max)
    val_accs_pred.append(100 * cm_pred / total_pred if total_pred > 0 else 0.0)
    cov = 100 * (total_pred / total_max) if total_max > 0 else 0.0  # coverage: share of accepted samples
    coverages.append(cov)

    # ---- dynamic weight update ----
    # From the second epoch on, re-score the whole training set (the model is
    # still in eval mode from the validation step) and rebuild the loss with
    # per-class weights for the next epoch.
    if epoch >= 1:
        all_p, all_t = [], []
        with torch.no_grad():
            for xb, yb in train_loader:
                xb, yb = xb.to(device), yb.to(device)
                out = model(xb)
                all_p.append(out.argmax(dim=1))
                all_t.append(yb)
        wp = tracker.update(torch.cat(all_p), torch.cat(all_t))
        criterion = nn.CrossEntropyLoss(weight=wp.to(device))

    print(f"Epoch {epoch+1:02d} | "
          f"Train Acc: {train_accs[-1]:.2f}% | "
          f"Val Acc(max): {val_accs_max[-1]:.2f}% | "
          f"Val Acc(pred): {val_accs_pred[-1]:.2f}% | "
          f"Coverage: {coverages[-1]:.1f}%")



# --------------------------- Training curves ---------------------------
# One point per epoch; all four series are percentages.
# NOTE(review): the title mentions "CLS-Attn" and "ArcFace", neither of which
# appears in the model defined in this file — confirm the title is current.
plt.figure(figsize=(10,5))
plt.plot(train_accs, label="Train Acc")
plt.plot(val_accs_max, label="Val Acc (max)")
plt.plot(val_accs_pred, label="Val Acc (predict, with rejection)")
plt.plot(coverages, label="Coverage (%)")
plt.xlabel("Epoch")
plt.ylabel("Accuracy / %")
plt.title("Training Curve: TCN (pool every 2) + CLS-Attn + ArcFace")
plt.legend()
plt.grid(True)
plt.show()
###


# ================= Validation: threshold calibration + per-class gate stats + report =================
import os, json
import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import confusion_matrix


# ----------------- A) Configuration: calibration strategy -----------------
CAL_POLICY      = "target_coverage"   # either "target_coverage" or "target_risk"
TARGET_VALUE    = 0.65                # target_coverage: coverage goal in 0..1; target_risk: allowed risk (1 - accuracy)
CALIB_BATCHSIZE = 512                 # batch size used during calibration
PLOT_CALIB      = True                # draw the confidence histogram and the accuracy-coverage curve
SAVE_JSON_PATH  = "./val_gate_stats.json"  # file that receives the per-class gate statistics

# ----------------- B) 在验证集上标定阈值（风险-覆盖率） -----------------
@torch.no_grad()
def calibrate_threshold_on_val(model, X_val, y_val,
                               device="cpu", batch_size=512,
                               policy="target_coverage", target=0.90,
                               plot=True):
    """Calibrate the softmax rejection threshold on a validation set.

    Every sample contributes its maximum softmax confidence and whether its
    argmax prediction is correct. Sorting by confidence (descending) and
    accumulating gives a curve of (coverage, accuracy on accepted, risk).

    Args:
        model: classifier returning logits; its ``threshold`` attribute is
            overwritten with the selected value.
        X_val, y_val: validation inputs and integer labels.
        device: device the batches are moved to.
        batch_size: evaluation batch size.
        policy: ``"target_coverage"`` picks the lowest-risk point among those
            with coverage >= ``target``; ``"target_risk"`` picks the
            highest-coverage point with risk <= ``target`` (falling back to
            the overall lowest-risk point if none qualifies).
        target: coverage goal or risk budget, depending on ``policy``.
        plot: when true, draw the confidence histogram and the
            accuracy-coverage curve.

    Returns:
        ``(threshold, coverage, accuracy_on_accepted)`` at the chosen point.

    Raises:
        RuntimeError: if no validation sample was scored.
        ValueError: if ``policy`` is not one of the two supported names.
    """
    model.eval()
    batches = DataLoader(TensorDataset(X_val, y_val), batch_size=batch_size, shuffle=False)

    conf_chunks, hit_chunks = [], []
    for inputs, targets in batches:
        inputs = inputs.to(device)
        targets = targets.to(device)
        top_prob, top_class = F.softmax(model(inputs), dim=1).max(dim=1)
        conf_chunks.append(top_prob.cpu())
        hit_chunks.append((top_class == targets).cpu())

    confs = torch.cat(conf_chunks).numpy()                        # [N]
    correct = torch.cat(hit_chunks).numpy().astype(np.int32)
    n_samples = len(confs)
    if n_samples == 0:
        raise RuntimeError("Validation set is empty, cannot calibrate threshold.")

    # Rank by confidence, highest first; accepting the top k samples gives
    # point k of the coverage curve.
    ranking = np.argsort(-confs)
    conf_sorted = confs[ranking]
    hits_sorted = correct[ranking]

    accepted_counts = np.arange(1, n_samples + 1)
    coverage = accepted_counts / n_samples
    acc_on_accepted = np.cumsum(hits_sorted) / accepted_counts
    risk = 1.0 - acc_on_accepted

    # Pick the operating point.
    if policy == "target_coverage":
        first_ok = int(np.searchsorted(coverage, target, side="left"))
        first_ok = np.clip(first_ok, 0, n_samples - 1)
        # Lowest risk at or to the right of the requested coverage.
        idx_star = int(first_ok + np.argmin(risk[first_ok:]))
    elif policy == "target_risk":
        within_budget = np.flatnonzero(risk <= target)
        if within_budget.size > 0:
            idx_star = int(within_budget[-1])
        else:
            idx_star = int(np.argmin(risk))
    else:
        raise ValueError("CAL_POLICY must be 'target_coverage' or 'target_risk'.")

    thr_star = float(conf_sorted[idx_star])
    cov_star = float(coverage[idx_star])
    acc_star = float(acc_on_accepted[idx_star])

    # Store the threshold on the model.
    model.threshold = thr_star

    print("\n[Calibration]")
    print(f"  policy = {policy}, target = {target}")
    print(f"  -> threshold = {thr_star:.4f}")
    print(f"  -> coverage  = {cov_star*100:.2f}%")
    print(f"  -> acc_on_accepted (1-risk) = {acc_star*100:.2f}%")

    if plot:
        fig = plt.figure(figsize=(12,4.5))

        # Left panel: confidence distribution of correct vs. incorrect predictions.
        hist_ax = fig.add_subplot(1,2,1)
        for flag, name in ((1, "Correct"), (0, "Incorrect")):
            hist_ax.hist(confs[correct==flag], bins=40, alpha=0.6, label=name, density=True)
        hist_ax.axvline(thr_star, ls="--", c="k", label=f"Threshold {thr_star:.2f}")
        hist_ax.set_xlabel("Max softmax confidence")
        hist_ax.set_ylabel("Density")
        hist_ax.set_title("Validation confidence distribution")
        hist_ax.legend()

        # Right panel: accuracy on accepted samples versus coverage.
        curve_ax = fig.add_subplot(1,2,2)
        curve_ax.plot(coverage, acc_on_accepted, lw=2)
        curve_ax.scatter([cov_star], [acc_star], s=50, edgecolors="k",
                         label=f"Chosen: cov={cov_star*100:.1f}%, acc={acc_star*100:.2f}%")
        curve_ax.set_xlabel("Coverage (accepted proportion)")
        curve_ax.set_ylabel("Accuracy on accepted (Precision)")
        curve_ax.set_title("Accuracy–Coverage curve (validation)")
        curve_ax.set_xlim(0.2, 1)
        curve_ax.set_ylim(0.92, 1)
        curve_ax.grid(True, alpha=0.3)
        curve_ax.legend()
        plt.tight_layout()
        plt.show()

    return thr_star, cov_star, acc_star

# ----------------- C) Calibrate the threshold first -----------------
# The call also stores the selected threshold on model.threshold, so every
# later model.predict() uses the calibrated value.
thr_star, cov_star, acc_star = calibrate_threshold_on_val(
    model, X_val, y_val,
    device=device,
    batch_size=CALIB_BATCHSIZE,
    policy=CAL_POLICY,
    target=TARGET_VALUE,
    plot=PLOT_CALIB
)
print(f"[Model] model.threshold 已设置为 {model.threshold:.4f}")

# ----------------- D) 用选定阈值统计每类 gate 参数 (r_c, a_c, S_1mr, S_full) 并缓存 -----------------
@torch.no_grad()
def compute_val_gate_stats(model, X_val, y_val, thr, device="cpu", batch_size=512):
    """Compute overall and per-class acceptance statistics at threshold ``thr``.

    A sample is accepted when its maximum softmax probability is >= ``thr``.

    Returns:
        A dict with keys ``"selected_threshold"``, ``"per_class"`` and
        ``"overall"``. ``"overall"`` holds ``acc_on_accepted`` and
        ``coverage``. ``"per_class"`` maps ``str(class_id)`` to:

        * ``r``: share of the true-class samples that were rejected;
        * ``a``: share of the accepted true-class samples predicted as that
          class (``0.0`` when none was accepted);
        * ``S_1mr``: ``1 - r``;
        * ``S_full``: ``(1 - r) * a``.

        Classes with no validation sample get ``None`` for all four entries.
    """
    model.eval()
    batches = DataLoader(TensorDataset(X_val, y_val), batch_size=batch_size, shuffle=False)

    truths, guesses, top_probs = [], [], []
    for inputs, targets in batches:
        inputs = inputs.to(device)
        targets = targets.to(device)
        best_prob, best_class = F.softmax(model(inputs), dim=1).max(dim=1)
        truths.append(targets.cpu())
        guesses.append(best_class.cpu())
        top_probs.append(best_prob.cpu())

    y_true = torch.cat(truths).numpy()
    y_hat = torch.cat(guesses).numpy()
    p_max = torch.cat(top_probs).numpy()
    num_classes = int(max(y_true.max(), y_hat.max())) + 1

    accepted = (p_max >= thr)

    # Overall figures.
    if accepted.any():
        overall_acc = float((y_hat[accepted] == y_true[accepted]).mean())
    else:
        overall_acc = 0.0
    gate = {
        "selected_threshold": float(thr),
        "per_class": {},
        "overall": {"acc_on_accepted": overall_acc, "coverage": float(accepted.mean())},
    }

    # Per-class figures.
    per_class = gate["per_class"]
    for cls in range(num_classes):
        members = (y_true == cls)
        if int(members.sum()) == 0:
            per_class[str(cls)] = {"r": None, "a": None, "S_1mr": None, "S_full": None}
            continue

        # Rejection rate among the true members of this class.
        reject_rate = float((~accepted[members]).mean())

        # Accuracy among the accepted true members of this class.
        accepted_members = accepted & members
        if accepted_members.any():
            hit_rate = float((y_hat[accepted_members] == cls).mean())
        else:
            hit_rate = 0.0

        per_class[str(cls)] = {
            "r": reject_rate,
            "a": hit_rate,
            "S_1mr": 1.0 - reject_rate,
            "S_full": (1.0 - reject_rate) * hit_rate,
        }
    return gate

# Gate statistics at the calibrated threshold.
VAL_GATE_STATS = compute_val_gate_stats(model, X_val, y_val, thr_star, device=device, batch_size=CALIB_BATCHSIZE)
VAL_SELECTED_THR = float(thr_star)

# Save to JSON so that downstream modules can reuse the statistics directly.
with open(SAVE_JSON_PATH, "w") as f:
    json.dump(VAL_GATE_STATS, f, indent=2)
print(f"[Saved] per-class gate stats -> {os.path.abspath(SAVE_JSON_PATH)}")

# ----------------- E) Per-class metrics at the selected threshold -----------------
# Runs another forward pass over the validation set (the results computed
# above could be reused instead). This rebinds val_loader with batch size 256.
val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=256, shuffle=False)
y_true_list, y_hat_list, p_max_list = [], [], []
with torch.no_grad():
    for xb, yb in val_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        probs  = F.softmax(logits, dim=1)
        p_max, y_hat = probs.max(dim=1)
        y_true_list.append(yb.cpu()); y_hat_list.append(y_hat.cpu()); p_max_list.append(p_max.cpu())
y_true = torch.cat(y_true_list).numpy()
y_hat  = torch.cat(y_hat_list).numpy()
p_max  = torch.cat(p_max_list).numpy()

# A sample is accepted when its top probability reaches the threshold.
accepted_mask = (p_max >= VAL_SELECTED_THR)
# Cover every class id that occurs either as a true label or as a prediction.
num_classes = int(max(y_true.max(), y_hat.max())) + 1

print("\n==== Accuracy and Uncertainty Per Label (Validation @ selected threshold) ====")
overall_correct_conf = 0   # correct predictions among accepted samples
overall_total_conf   = 0   # accepted samples
overall_not_pred     = 0   # rejected samples

for c in range(num_classes):
    mask_c  = (y_true == c)
    total_c = int(mask_c.sum())
    if total_c == 0:
        continue   # class absent from the validation set

    # Plain argmax accuracy over all samples of this class.
    acc_max = float((y_hat[mask_c] == c).mean()) * 100.0

    # Accuracy restricted to the accepted samples of this class.
    mask_acc = mask_c & accepted_mask
    accept_c = int(mask_acc.sum())
    if accept_c > 0:
        acc_conf = float((y_hat[mask_acc] == c).mean()) * 100.0
        correct_c = int((y_hat[mask_acc] == c).sum())
        acc_conf_display = f"{acc_conf:5.2f}%"
    else:
        acc_conf_display = "  NaN"
        correct_c = 0

    # Share of this class's samples that were rejected.
    not_pred_pct = float((~accepted_mask[mask_c]).mean()) * 100.0

    overall_correct_conf += correct_c
    overall_total_conf   += accept_c
    overall_not_pred     += int((~accepted_mask[mask_c]).sum())

    print(f"Label {c:02d}: 总样本={total_c:4d}, "
          f"argmax Acc={acc_max:5.2f}%, "
          f"阈值 Acc={acc_conf_display:>6}, "
          f"Not Predicted={not_pred_pct:5.2f}% "
          f"(正确 {correct_c}/{accept_c})")

overall_acc_conf = (overall_correct_conf / overall_total_conf * 100.0) if overall_total_conf > 0 else 0.0
overall_not_pred_pct = overall_not_pred / y_true.size * 100.0
print(f"\nOverall Accuracy (on accepted): {overall_acc_conf:.2f}%")
print(f"Overall Not Predicted Percentage: {overall_not_pred_pct:.2f}%")
print(f"Selected threshold: {VAL_SELECTED_THR:.4f}")

# ---- (Optional) confusion matrix restricted to the accepted samples ----
PLOT_CM = True
if PLOT_CM:
    valid = accepted_mask
    y_true_f = y_true[valid]
    y_pred_f = y_hat[valid]
    if y_true_f.size > 0:
        # Axis order: every label seen as a true label or as a prediction.
        labels_order = sorted(set(y_true_f.tolist()) | set(y_pred_f.tolist()))
        cm = confusion_matrix(y_true_f, y_pred_f, labels=labels_order)

        plt.figure(figsize=(12, 9))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=labels_order, yticklabels=labels_order)
        plt.xlabel('Predicted Label'); plt.ylabel('True Label')
        plt.title('Confusion Matrix on Accepted Samples (Validation)')
        plt.show()

        # Row-normalize. A label that occurs only as a prediction has an
        # all-zero row; leave it at 0 instead of dividing by zero, which would
        # fill the row with NaN.
        row_totals = cm.sum(axis=1, keepdims=True)
        cm_norm = np.divide(cm.astype(float), row_totals,
                            out=np.zeros(cm.shape, dtype=float),
                            where=row_totals != 0)
        plt.figure(figsize=(12, 9))
        sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="Blues",
                    xticklabels=labels_order, yticklabels=labels_order)
        plt.xlabel('Predicted Label'); plt.ylabel('True Label')
        plt.title('Normalized Confusion Matrix on Accepted Samples (Validation)')
        plt.show()
    else:
        print("（所有验证样本都被拒识，跳过混淆矩阵绘制）")
# ====================================================================================================




import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix



from torch.utils.data import TensorDataset, DataLoader

# ====== Batched evaluation on the test split (replaces a one-shot forward pass) ======
model.eval()

# NOTE: these names say "val" but wrap the TEST split; they are kept because
# they are module-level names. Lower the batch size if memory is tight.
val_dataset = TensorDataset(X_test, y_test)
val_loader  = DataLoader(val_dataset, batch_size=32, shuffle=False)

all_true       = []
all_pred_max   = []
all_pred_conf  = []

with torch.no_grad():
    for xb, yb in val_loader:
        xb, yb = xb.to(device), yb.to(device)
        out = model(xb)                         # [B, num_classes]
        _, pm = torch.max(out, dim=1)           # plain argmax prediction
        pc, _ = model.predict(xb)               # thresholded prediction (-1 = rejected)

        all_true.append(yb.cpu())
        all_pred_max.append(pm.cpu())
        all_pred_conf.append(pc.cpu())

all_true      = torch.cat(all_true)
all_pred_max  = torch.cat(all_pred_max)
all_pred_conf = torch.cat(all_pred_conf)

print("\n=== 测试集上各标签准确度（分批计算） ===")
# Cover every class id present in the labels OR the predictions. Using the
# predictions alone dropped true classes above the highest predicted id.
num_classes = int(max(all_true.max().item(), all_pred_max.max().item())) + 1
for c in range(num_classes):
    mask_c = (all_true == c)
    total_c = mask_c.sum().item()
    if total_c == 0:
        continue   # class absent from the test set

    # Plain argmax accuracy.
    correct_max_c = (all_pred_max[mask_c] == c).sum().item()
    acc_max = correct_max_c / total_c * 100

    # Thresholded accuracy, computed only over samples that were not rejected.
    mask_conf_c = mask_c & (all_pred_conf != -1)
    count_conf_c = mask_conf_c.sum().item()
    if count_conf_c > 0:
        correct_conf_c = (all_pred_conf[mask_conf_c] == c).sum().item()
        acc_conf = correct_conf_c / count_conf_c * 100
        acc_conf_display = f"{acc_conf:5.2f}%"
    else:
        acc_conf = float('nan')
        correct_conf_c = 0
        acc_conf_display = "  NaN"

    # Same layout as the validation report: two decimals, or NaN without a
    # trailing percent sign (the old `{acc_conf:6}` printed the raw float).
    print(f"Label {c:02d}: "
          f" 总样本={total_c:3d}, "
          f"argmax Acc={acc_max:5.2f}%, "
          f"阈值 Acc={acc_conf_display:>6}, "
          f"(正确 {correct_conf_c}/{count_conf_c})")


# Evaluate the model: thresholded predictions, accuracy, rejection rate.
def evaluate_model_per_label_before_cm(model, test_signals, test_labels, device='cpu', batch_size=32):
    """Score ``model`` on a labelled set using its confidence threshold.

    A prediction whose top softmax probability is below ``model.threshold``
    is replaced by ``-1`` (rejected). Rejected samples stay in every
    denominator, so they count against the accuracies reported here.

    Returns:
        A 6-tuple:
        ``accuracy_per_label`` (dict label -> percent correct),
        ``not_predicted_per_label_percent`` (dict label -> percent rejected),
        ``overall_accuracy`` (percent),
        ``overall_not_predicted_percent`` (percent),
        all thresholded predictions (tensor, ``-1`` = rejected),
        all target labels (tensor).
    """
    def _percent(part, whole):
        return (part / whole) * 100 if whole > 0 else 0

    model.eval()
    unique_labels = torch.unique(test_labels).tolist()

    hits = dict.fromkeys(unique_labels, 0)
    seen = dict.fromkeys(unique_labels, 0)
    rejected = dict.fromkeys(unique_labels, 0)
    n_hits_all = n_seen_all = n_rejected_all = 0

    # The whole set is moved to the device up front and sliced per batch.
    test_signals = test_signals.to(device)
    test_labels = test_labels.to(device)

    pred_chunks, label_chunks = [], []

    with torch.no_grad():
        for start in range(0, test_signals.size(0), batch_size):
            stop = start + batch_size
            batch_x = test_signals[start:stop]
            batch_y = test_labels[start:stop]

            probs = F.softmax(model(batch_x), dim=1)
            top_prob, batch_pred = torch.max(probs, dim=1)
            batch_pred[top_prob < model.threshold] = -1   # reject low confidence

            pred_chunks.append(batch_pred)
            label_chunks.append(batch_y)

            is_hit = batch_pred == batch_y
            is_rejected = batch_pred == -1
            for label in unique_labels:
                of_label = batch_y == label
                hits[label] += is_hit[of_label].sum().item()
                seen[label] += of_label.sum().item()
                rejected[label] += is_rejected[of_label].sum().item()

            n_hits_all += is_hit.sum().item()
            n_seen_all += batch_y.size(0)
            n_rejected_all += is_rejected.sum().item()

    confident_predictions_all = torch.cat(pred_chunks)
    confident_labels_all = torch.cat(label_chunks)

    accuracy_per_label = {label: _percent(hits[label], seen[label]) for label in seen}
    not_predicted_per_label_percent = {
        label: _percent(rejected[label], seen[label]) for label in seen
    }
    overall_accuracy = _percent(n_hits_all, n_seen_all)
    overall_not_predicted_percent = _percent(n_rejected_all, n_seen_all)

    return (accuracy_per_label, not_predicted_per_label_percent,
            overall_accuracy, overall_not_predicted_percent,
            confident_predictions_all, confident_labels_all)


# Evaluate on accepted samples only and build the confusion matrix.
def evaluate_model_per_label_after_cm(model, test_signals, test_labels, device='cpu', batch_size=32):
    """Recompute the per-label metrics on accepted samples and build their
    confusion matrix.

    Runs ``evaluate_model_per_label_before_cm`` and drops every sample whose
    thresholded prediction is ``-1``. Because rejected samples are removed
    first, the "not predicted" figures returned here are always zero; they
    are kept so the return shape mirrors the "before" function.

    Returns:
        A 6-tuple: per-label accuracy (percent), per-label not-predicted
        percent, overall accuracy (percent), overall not-predicted percent,
        the confusion matrix, and the sorted labels occurring among the
        accepted true labels or predictions.
    """
    before = evaluate_model_per_label_before_cm(
        model, test_signals, test_labels, device=device, batch_size=batch_size)
    all_predictions, all_targets = before[4], before[5]

    # Keep accepted samples only.
    keep = all_predictions != -1
    kept_pred = all_predictions[keep]
    kept_true = all_targets[keep]

    true_np = kept_true.cpu().numpy()
    pred_np = kept_pred.cpu().numpy()
    cm = confusion_matrix(true_np, pred_np)
    unique_labels = sorted(set(true_np) | set(pred_np))

    correct_per_label = {}
    total_per_label = {}
    not_predicted_per_label = {}
    for label in unique_labels:
        of_label = kept_true == label
        correct_per_label[label] = (kept_pred[of_label] == kept_true[of_label]).sum().item()
        total_per_label[label] = of_label.sum().item()
        not_predicted_per_label[label] = (kept_pred[of_label] == -1).sum().item()

    total_correct_all = (kept_pred == kept_true).sum().item()
    total_samples_all = kept_true.size(0)
    total_not_predicted_all = (kept_pred == -1).sum().item()

    accuracy_per_label_after_cm = {}
    not_predicted_per_label_percent_after_cm = {}
    for label, n_total in total_per_label.items():
        if n_total > 0:
            accuracy_per_label_after_cm[label] = (correct_per_label[label] / n_total) * 100
            not_predicted_per_label_percent_after_cm[label] = (not_predicted_per_label[label] / n_total) * 100
        else:
            accuracy_per_label_after_cm[label] = 0
            not_predicted_per_label_percent_after_cm[label] = 0

    if total_samples_all > 0:
        overall_accuracy_after_cm = (total_correct_all / total_samples_all) * 100
        overall_not_predicted_percent_after_cm = (total_not_predicted_all / total_samples_all) * 100
    else:
        overall_accuracy_after_cm = 0
        overall_not_predicted_percent_after_cm = 0

    return (accuracy_per_label_after_cm, not_predicted_per_label_percent_after_cm,
            overall_accuracy_after_cm, overall_not_predicted_percent_after_cm,
            cm, unique_labels)


# ======= Usage (assumes model, X_test and y_test already exist) =======

# Step 1: thresholded evaluation with rejected samples counted in the totals.
acc1, np1, overall_acc1, overall_np1, preds, labels = evaluate_model_per_label_before_cm(
    model, X_test, y_test, device=device)

print("\n==== Accuracy and Uncertainty Per Label ====")
for label in acc1:
    print(f'Label {label}: Accuracy = {acc1[label]:.2f}%, Not Predicted = {np1[label]:.2f}%')

print(f'\nOverall Accuracy: {overall_acc1:.2f}%')
print(f'Overall Not Predicted Percentage: {overall_np1:.2f}%')


# Step 2: confusion-matrix analysis on the accepted samples.
acc2, np2, overall_acc2, overall_np2, cm, label_order = evaluate_model_per_label_after_cm(
    model, X_test, y_test, device=device)

# Print the raw numbers before plotting.
print("\n==== Test Confusion Matrix (absolute counts) ====")
print("Label order (rows = true, cols = pred):", label_order)
print(cm)

# Row-normalize. A label that occurs only as a prediction has an all-zero
# row; leave it at 0 instead of dividing by zero, which would fill the row
# with NaN.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(cm.astype('float'), row_totals,
                          out=np.zeros(cm.shape, dtype=float),
                          where=row_totals != 0)
print("\n==== Test Confusion Matrix (row-normalized) ====")
print(cm_normalized)

# Plot: absolute confusion matrix.
plt.figure(figsize=(12, 9))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=label_order, yticklabels=label_order)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix (Absolute Values)')
plt.show()

# Plot: row-normalized confusion matrix.
plt.figure(figsize=(12, 9))
sns.heatmap(cm_normalized, annot=True, fmt=".2f", cmap="Blues",
            xticklabels=label_order, yticklabels=label_order)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix (Normalized)')
plt.show()


