# In [None]:
# -*- coding: utf-8 -*-
# 对齐代码2：histogram 输入使用“未标准化 raw + 非负截断（<=0置0）”，并用 baseline/bin 分箱 + z-score
# 说明：本脚本只针对“对比模型分支（hist+StevieNet）”严格复现代码2的 my_binning 输入体系；
#       不再对直方图输入做 StandardScaler（你原代码里的 scaler 逻辑已移除/绕开）。

import gc
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, TensorDataset

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split



TRAIN_FRACTION = 1  # e.g. 0.1 keeps 10% of the cleaned training set; any value outside (0, 1) disables downsampling

# =========================================================
# 0) Load data
# =========================================================
# parquet_path = "/content/drive/MyDrive/CA-AA/Cleandata_Lable20_N157794_ML30_MU3000.parquet"
parquet_path = "/content/drive/MyDrive/CA-AA/Cleandata5_Lable20_N83597_5000_ML30_MU3000.parquet"
data = pd.read_parquet(parquet_path)
print("读取完成:", data.shape)

# =========================================================
# 1) Column layout
#   column 1 (index 0): label
#   column 2 (index 1): baseline
#   column 4 onward (index 3+): time-series features
#   NOTE: column 3 (index 2) is not used anywhere below.
# =========================================================
label_col = data.columns[0]
baseline_col = data.columns[1]
feature_cols = data.columns[3:]  # time-series features start at the 4th column

# =========================================================
# 2) Step 1: per-sample valid length (number of non-zero entries)
#    NaN entries compare != 0, so they are counted as "valid" here.
# =========================================================
data = data.copy()
data["valid_length"] = (data[feature_cols] != 0).sum(axis=1)

# =========================================================
# 3) Step 2: 按标签剔除异常（你原来是 0.35~0.65 分位）
# =========================================================
def remove_outliers_by_label(df, lower=0.35, upper=0.65, col="valid_length"):
    """Keep only the rows whose ``col`` value lies inside a quantile band.

    Intended to be applied per label group (``groupby(label).apply(...)``), so
    the band is computed within each label separately.

    Args:
        df: one label group; must contain column ``col``.
        lower: lower quantile of the kept band (inclusive). Default 0.35.
        upper: upper quantile of the kept band (inclusive). Default 0.65.
        col: column the band is computed on. Default ``"valid_length"``.

    Returns:
        The subset of ``df`` with ``quantile(lower) <= df[col] <= quantile(upper)``.
        An empty ``df`` yields an empty result (its quantiles are NaN).
    """
    q_low, q_high = df[col].quantile([lower, upper])
    # between() is inclusive on both ends, matching the original >= / <= filter
    return df[df[col].between(q_low, q_high)]

# Apply the per-label valid-length band filter.
# NOTE: groupby(...).apply concatenates the per-group results, so rows come back
# grouped by label (sorted label order) instead of file order. The seeded splits
# below depend on row order, so changing this step changes the split membership.
# NOTE(review): recent pandas warns that apply() also sees the grouping column;
# passing include_groups=False would drop the label column from the result.
cleaned_data = (
    data.groupby(label_col, group_keys=False)
        .apply(remove_outliers_by_label)
)

cleaned_data = cleaned_data.drop(columns=["valid_length"])
print(f"清洗前样本数: {len(data)}")
print(f"清洗后样本数: {len(cleaned_data)}")

data = cleaned_data
del cleaned_data
gc.collect()

# =========================================================
# 4) Step 3: drop the listed labels
#    The filter reads the first column, i.e. the same column as label_col.
#    Remaining labels are NOT re-indexed (so class ids keep gaps).
# =========================================================
remove_labels = [8, 12, 14]  # as in the original code
data = data[~data.iloc[:, 0].isin(remove_labels)].copy()
print("删除标签后 shape:", data.shape)

# =========================================================
# 5) Extract label / baseline / raw sequences
#    (raw is fed to the histogram mapping as-is; no StandardScaler)
# =========================================================
labels = data[label_col].values.astype(np.int64)          # (N,)
baselines = data[baseline_col].values.astype(np.float32)  # (N,)
raw = data[feature_cols].values.astype(np.float32)        # (N, L)

del data
gc.collect()

# =========================================================
# 6) Split train / val / test, keeping baselines aligned with samples.
#    80% train, then the remaining 20% is halved into val and test;
#    both splits are stratified by label and seeded.
# =========================================================
raw_train, raw_tmp, y_train, y_tmp, b_train, b_tmp = train_test_split(
    raw, labels, baselines,
    test_size=0.2,
    random_state=42,
    stratify=labels
)

raw_val, raw_test, y_val, y_test, b_val, b_test = train_test_split(
    raw_tmp, y_tmp, b_tmp,
    test_size=0.5,
    random_state=42,
    stratify=y_tmp
)

del raw, raw_tmp, y_tmp, b_tmp
gc.collect()

print(f"训练集样本数（拆分后）：{len(raw_train)}")

# =========================================================
# 7) Training-set cleaning 1: per label, drop the x fraction of samples with the
#    fewest and the x fraction with the most non-zero entries.
#    With x = 0.0 the slice [0, n) keeps every sample (no trimming).
# =========================================================
x = 0.0  # kept at 0.0 as in the original code
keep_idx = []
for lbl in np.unique(y_train):
    idxs = np.where(y_train == lbl)[0]
    nz = (raw_train[idxs] != 0).sum(axis=1)
    order = idxs[np.argsort(nz)]
    n = len(order)
    low, high = int(n * x), int(n * (1 - x))
    keep_idx.extend(order[low:high])

keep_idx = np.sort(np.array(keep_idx))  # sorted so the original sample order is kept
raw_train = raw_train[keep_idx]
y_train   = y_train[keep_idx]
b_train   = b_train[keep_idx]
gc.collect()

# =========================================================
# 8) Training-set cleaning 2: per label, drop the x fraction of samples with the
#    smallest and the x fraction with the largest "max drop amplitude"
#    (see calc_max_drop_amp)
# =========================================================
x = 0.0  # same as above
def calc_max_drop_amp(arr):
    """Return |mean| of a window around the peak of the valid (non-NaN, non-zero) values.

    The window spans up to 10 values before and 10 after the position of the
    maximum valid value. Returns 0.0 when there are no valid values.
    """
    vals = arr[~np.isnan(arr) & (arr != 0)]
    if not vals.size:
        return 0.0
    peak = int(np.argmax(vals))
    start = max(0, peak - 10)
    stop = min(vals.size, peak + 11)
    window = vals[start:stop]
    return abs(window.mean())

# Per-sample "max drop amplitude" of the training set (one value per row).
amps = np.apply_along_axis(calc_max_drop_amp, 1, raw_train)

# Per label, trim the x fraction at each end of the amplitude ordering
# (x is the 0.0 set just before calc_max_drop_amp, so nothing is trimmed by default).
keep2 = []
for lbl in np.unique(y_train):
    idxs = np.where(y_train == lbl)[0]
    order = idxs[np.argsort(amps[idxs])]
    n = len(order)
    low, high = int(n * x), int(n * (1 - x))
    keep2.extend(order[low:high])

keep2 = np.sort(np.array(keep2))
raw_train = raw_train[keep2]
y_train   = y_train[keep2]
b_train   = b_train[keep2]

del amps, keep_idx, keep2
gc.collect()

print(f"训练集样本数（清洗后）：{len(raw_train)}")

# =========================================================
# 8.1) Training-set downsampling: TRAIN_FRACTION controls the training-set size
#      (e.g. 0.1 keeps 10%). Skipped entirely unless 0 < TRAIN_FRACTION < 1.
# =========================================================


if 0.0 < TRAIN_FRACTION < 1.0:
    rng = np.random.RandomState(42)  # seeded so the subset is reproducible
    selected_idx = []

    # Stratified per-label downsampling, to keep the class distribution;
    # at least one sample is kept per label.
    for lbl in np.unique(y_train):
        idxs = np.where(y_train == lbl)[0]
        n_lbl = len(idxs)
        n_keep_lbl = max(1, int(n_lbl * TRAIN_FRACTION))
        chosen = rng.choice(idxs, size=n_keep_lbl, replace=False)
        selected_idx.extend(chosen)

    # NOTE: not re-sorted, so rows end up grouped by label; the training
    # DataLoader shuffles anyway.
    selected_idx = np.array(selected_idx)
    raw_train = raw_train[selected_idx]
    y_train   = y_train[selected_idx]
    b_train   = b_train[selected_idx]

    print(f"训练集样本数（下采样后，比例={TRAIN_FRACTION:.2f}）：{len(raw_train)}")

# =========================================================
# 9) Key change: the comparison (hist) branch does NOT use StandardScaler.
#    raw_* is converted to torch directly; the histogram mapping itself applies
#    nan_to_num + non-negative clipping.
# =========================================================
X_train = torch.tensor(raw_train, dtype=torch.float32)
X_val   = torch.tensor(raw_val,   dtype=torch.float32)
X_test  = torch.tensor(raw_test,  dtype=torch.float32)

y_train = torch.tensor(y_train, dtype=torch.long)
y_val   = torch.tensor(y_val,   dtype=torch.long)
y_test  = torch.tensor(y_test,  dtype=torch.long)

baseline_train = torch.tensor(b_train, dtype=torch.float32)
baseline_val   = torch.tensor(b_val,   dtype=torch.float32)
baseline_test  = torch.tensor(b_test,  dtype=torch.float32)

del raw_train, raw_val, raw_test, b_train, b_val, b_test
gc.collect()

# =========================================================
# 10) Channel construction: same logic as before (single channel by default)
# =========================================================
use_sliding_std = 0  # truthy -> add a sliding-window std channel
ws = 10              # window size for that channel

def sliding_std(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """Population std over a sliding window along dim 1, zero-padded back to length L.

    Args:
        x: (N, L) tensor.
        window_size: window length (stride 1).

    Returns:
        (N, L) tensor; the window std is roughly centered. There is one more
        pad element on the right when ``window_size`` is even, and the edges
        are 0.
    """
    span = window_size - 1
    left = span // 2
    right = span - left
    windows = x.unfold(1, window_size, 1)          # (N, L - ws + 1, ws)
    spread = windows.std(dim=2, unbiased=False)    # population std per window
    return F.pad(spread, (left, right), value=0.0)

# Build the (N, C, L) inputs: channel 0 is the raw sequence; channel 1 (optional)
# is its sliding-window std.
# NOTE: the hist model below is built with input_use_channel0=True, so it only
# reads channel 0 and an extra std channel would be ignored.
if use_sliding_std:
    std_tr = sliding_std(X_train, ws)
    std_v  = sliding_std(X_val,   ws)
    std_te = sliding_std(X_test,  ws)

    X_train = torch.stack([X_train, std_tr], dim=1)    # (N,2,L)
    X_val   = torch.stack([X_val,   std_v],  dim=1)
    X_test  = torch.stack([X_test,  std_te], dim=1)

    del std_tr, std_v, std_te
    gc.collect()
else:
    X_train = X_train.unsqueeze(1)  # (N,1,L)
    X_val   = X_val.unsqueeze(1)
    X_test  = X_test.unsqueeze(1)

# Keep baseline > 0 (the histogram step p_step = baseline / bin must not be 0).
# NOTE(review): clamp_min presumably leaves NaN baselines as NaN — confirm the data has none.
baseline_train = baseline_train.clamp_min(1e-8)
baseline_val   = baseline_val.clamp_min(1e-8)
baseline_test  = baseline_test.clamp_min(1e-8)

print(f"X_train: {X_train.shape}, y_train: {y_train.shape}, baseline_train: {baseline_train.shape}")
print(f"X_val  : {X_val.shape},   y_val  : {y_val.shape},   baseline_val  : {baseline_val.shape}")
print(f"X_test : {X_test.shape},  y_test : {y_test.shape},  baseline_test : {baseline_test.shape}")


# =========================================================
# 11) baseline sanity check
# =========================================================
def _check_baseline(x, b, y, name="train"):
    assert isinstance(x, torch.Tensor) and isinstance(b, torch.Tensor) and isinstance(y, torch.Tensor)
    assert x.dim() == 3, f"{name}: X must be (N,C,L), got {x.shape}"
    assert y.dim() == 1, f"{name}: y must be (N,), got {y.shape}"
    assert b.dim() == 1, f"{name}: baseline must be (N,), got {b.shape}"
    assert x.size(0) == y.size(0) == b.size(0), f"{name}: N mismatch: X={x.size(0)}, y={y.size(0)}, b={b.size(0)}"
    b = b.float().clamp_min(1e-8)
    y = y.long()
    return x, b, y


# =========================================================
# 12) 直方图映射：对齐代码2.my_binning
# =========================================================
class MyBinningTorch(nn.Module):
    """Turn each raw sequence into a z-scored, fixed-width histogram (mirrors code 2's ``my_binning``).

    For every sample:
      - non-finite values and values <= 0 are set to 0;
      - the range [0, baseline) is split into ``bin`` equal steps of width
        ``baseline / bin``, and values at or above ``baseline`` are counted in
        the last bin;
      - the per-sample bin counts are standardized along the bin axis.

    The module has no learnable parameters.

    Args:
        bin: number of histogram bins (must be >= 1).

    Raises:
        ValueError: if ``bin`` < 1.
    """

    def __init__(self, bin: int):
        super().__init__()
        self.bin = int(bin)
        # bin <= 0 would give a non-positive/infinite step and an empty clamp
        # range [0, bin - 1]; fail fast with a clear message instead.
        if self.bin < 1:
            raise ValueError(f"bin must be >= 1, got {bin!r}")

    @staticmethod
    def _standardize(P: torch.Tensor, eps: float = 1e-8):
        """Z-score ``P`` along its last axis using torch's default (N-1) std.

        Rows whose std is 0 or NaN (e.g. a single bin) become all zeros.
        """
        mean = P.mean(dim=-1, keepdim=True)
        std = P.std(dim=-1, keepdim=True)
        out = (P - mean) / (std + eps)
        # `std > 0` is False for both 0 and NaN, so those rows are zeroed
        return torch.where(std > 0, out, torch.zeros_like(out))

    def forward(self, sequence: torch.Tensor, baseline: torch.Tensor):
        """
        sequence: (B, L) float
        baseline: (B,)  float
        return  : (B, bin) float32, z-scored bin counts per sample
        """
        if sequence.dim() != 2:
            raise ValueError(f"sequence must be (B,L), got {sequence.shape}")
        if baseline.dim() != 1:
            raise ValueError(f"baseline must be (B,), got {baseline.shape}")

        B, L = sequence.shape

        # Match code 2: nan/inf -> 0
        x = torch.nan_to_num(sequence, nan=0.0, posinf=0.0, neginf=0.0)

        # Match code 2: everything <= 0 becomes 0 (temp > 0 else 0)
        x = torch.clamp(x, min=0.0)

        baseline = baseline.clamp_min(1e-8)
        p_step = baseline / self.bin  # (B,)

        # Fractional bin position. Clamp in floating point BEFORE casting to int64.
        # Casting inf/NaN or values beyond the int64 range to an integer is
        # undefined, so such samples could end up in the wrong bin (e.g. a huge
        # value over a tiny step). Overflow goes to the last bin and NaN
        # (e.g. from a NaN baseline) goes to bin 0.
        q = x / p_step.unsqueeze(1)  # (B, L)
        q = torch.nan_to_num(q, nan=0.0, posinf=float(self.bin - 1), neginf=0.0)
        s = torch.floor(q.clamp(0, self.bin - 1)).long()  # (B, L), indices in [0, bin - 1]

        # Count samples per bin: P[b, s[b, l]] += 1 for every position l
        P = torch.zeros(B, self.bin, device=x.device, dtype=torch.float32)
        ones = torch.ones(B, L, device=x.device, dtype=torch.float32)
        P.scatter_add_(dim=1, index=s, src=ones)

        return self._standardize(P)


# =========================================================
# 13) StevieNet（ResNet1D）
# =========================================================
class Bottleneck1D(nn.Module):
    """1-D ResNet bottleneck block.

    Main path:
      1. 1x1 conv (carries ``stride``) -> BN -> ReLU
      2. 3x3 conv -> BN -> ReLU
      3. 1x1 conv expanding to ``planes * extention`` channels -> BN

    The main path is then added to the shortcut (``downsample(x)`` if given,
    else ``x``) and passed through a final ReLU.
    """

    extention = 4  # channel expansion factor; StevieNet reads this name (spelling kept)

    def __init__(self, inplanes, planes, stride, downsample=None):
        super().__init__()
        wide = planes * self.extention
        # Submodules are created in a fixed order; it determines state_dict keys
        # and the sequence of random weight initializations.
        self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = nn.Conv1d(planes, wide, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm1d(wide)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(y + shortcut)


class StevieNet(nn.Module):
    """ResNet-style 1-D classifier backbone.

    Layers, in order:
      - stem: 7-wide stride-2 conv -> BN -> ReLU -> stride-2 max-pool;
      - stage1..stage4: stacks of ``block`` with base widths 64/128/256/512
        (each stage outputs ``width * block.extention`` channels); stages
        2-4 start with stride 2;
      - global average pooling and a linear head with ``num_class`` outputs.

    Input is (B, input_channels, L); output is (B, num_class) logits.
    """

    def __init__(self, block, layers, num_class, input_channels=1):
        super().__init__()
        self.inplane = 64  # running channel count, updated by make_layer
        self.conv1 = nn.Conv1d(input_channels, self.inplane, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(self.inplane)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # Registered as stage1..stage4, in this order (keeps state_dict keys and init order).
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (width, first_stride) in enumerate(stage_specs):
            setattr(self, f"stage{idx + 1}",
                    self.make_layer(block, width, layers[idx], stride=first_stride))

        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512 * block.extention, num_class)

    def make_layer(self, block, plane, block_num, stride=1):
        """Build one stage of ``block_num`` blocks.

        Only the first block gets ``stride`` and, when the shape changes, a
        1x1-conv + BN projection shortcut.
        """
        expanded = plane * block.extention
        needs_projection = stride != 1 or self.inplane != expanded
        downsample = (
            nn.Sequential(
                nn.Conv1d(self.inplane, expanded, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm1d(expanded),
            )
            if needs_projection
            else None
        )
        blocks = [block(self.inplane, plane, stride=stride, downsample=downsample)]
        self.inplane = expanded
        blocks += [block(self.inplane, plane, stride=1) for _ in range(1, block_num)]
        return nn.Sequential(*blocks)

    def forward(self, x):
        feats = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            feats = stage(feats)
        pooled = torch.flatten(self.avgpool(feats), 1)
        return self.fc(pooled)


# =========================================================
# 14) 对比模型封装：hist + StevieNet（forward/predict 需要 baseline）
# =========================================================
class StevieNetHistBaseline(nn.Module):
    """Comparison model: baseline-aware histogram mapping followed by a StevieNet backbone.

    Both ``forward`` and ``predict`` need the per-sample baseline.

    Args:
        num_classes: number of output logits.
        bin: number of histogram bins (name kept for API compatibility).
        threshold: ``predict`` rejects (returns -1) when the max softmax
            probability is below this value.
        input_use_channel0: if True, use channel 0 of the (B, C, L) input;
            otherwise average over channels.
    """

    def __init__(self, num_classes=20, bin=1024, threshold=0.9, input_use_channel0=True):
        super().__init__()
        self.hist = MyBinningTorch(bin=int(bin))
        self.backbone = StevieNet(Bottleneck1D, layers=[3, 4, 6, 3], num_class=num_classes, input_channels=1)
        self.threshold = float(threshold)
        self.input_use_channel0 = bool(input_use_channel0)

    def forward(self, x, baseline):
        """x: (B, C, L), baseline: (B,) -> logits (B, num_classes)."""
        seq = x[:, 0, :] if self.input_use_channel0 else x.mean(dim=1)  # (B, L)
        hist_in = self.hist(seq, baseline).unsqueeze(1)                 # (B, 1, bin)
        return self.backbone(hist_in)

    def predict(self, x, baseline):
        """Return (pred, conf); pred is -1 wherever conf < self.threshold."""
        probs = F.softmax(self.forward(x, baseline), dim=1)
        conf, pred = torch.max(probs, dim=1)
        rejected = conf < self.threshold
        pred = torch.where(rejected, torch.full_like(pred, -1), pred)
        return pred, conf


# =========================================================
# 16) Training setup
# =========================================================
X_train, baseline_train, y_train = _check_baseline(X_train, baseline_train, y_train, "train")
X_val,   baseline_val,   y_val   = _check_baseline(X_val,   baseline_val,   y_val,   "val")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NUM_CLASSES = 20   # logit count; labels dropped earlier are not re-indexed, so keep the full range
BIN = 256          # histogram bins = length of the sequence the backbone sees
EPOCHS = 40
BATCH = 64
THRESH = 0         # rejection threshold on max softmax prob; 0 never rejects (probs are >= 0)

model = StevieNetHistBaseline(
    num_classes=NUM_CLASSES,
    bin=BIN,
    threshold=THRESH,
    input_use_channel0=True
).to(device)

# Pinned host memory only helps host->CUDA copies. On a CPU-only runtime PyTorch
# just warns and skips pinning, so request it only when training on CUDA.
pin_memory = device.type == "cuda"

train_loader = DataLoader(
    TensorDataset(X_train, baseline_train, y_train),
    batch_size=BATCH,
    shuffle=True,
    pin_memory=pin_memory
)
val_loader = DataLoader(
    TensorDataset(X_val, baseline_val, y_val),
    batch_size=BATCH,
    shuffle=False,
    pin_memory=pin_memory
)

optimizer = optim.AdamW(model.parameters(), lr=1e-4)
# OneCycleLR sets the learning rate itself: warm-up toward max_lr over the first
# 10% of steps, then cosine annealing. It is stepped once per batch.
scheduler = OneCycleLR(
    optimizer,
    max_lr=1e-3,
    steps_per_epoch=len(train_loader),
    epochs=EPOCHS,
    pct_start=0.1,
    anneal_strategy='cos',
    final_div_factor=1e4
)

# Fixed, unweighted loss: class weights are intentionally not adjusted during training
criterion = nn.CrossEntropyLoss()

train_accs, val_accs_max, val_accs_pred, coverages = [], [], [], []


# =========================================================
# 17) Training loop
# =========================================================
for epoch in range(EPOCHS):
    # -- train --
    model.train()
    correct = total = 0

    for xb, bb, yb in train_loader:
        xb = xb.to(device, non_blocking=True)
        bb = bb.to(device, non_blocking=True).float().clamp_min(1e-8)
        yb = yb.to(device, non_blocking=True).long()

        logits = model(xb, bb)
        loss = criterion(logits, yb)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        scheduler.step()  # OneCycleLR is scheduled per batch (steps_per_epoch=len(train_loader))

        # Running train accuracy, from the train-mode logits of this same step
        preds = logits.argmax(dim=1)
        correct += (preds == yb).sum().item()
        total += yb.size(0)

    train_accs.append(100.0 * correct / max(total, 1))

    # -- validate --
    model.eval()
    cm_max = cm_pred = total_max = total_pred = 0

    with torch.no_grad():
        for xb, bb, yb in val_loader:
            xb = xb.to(device, non_blocking=True)
            bb = bb.to(device, non_blocking=True).float().clamp_min(1e-8)
            yb = yb.to(device, non_blocking=True).long()

            logits = model(xb, bb)

            # Val Acc(max): plain argmax, no rejection
            pm = logits.argmax(dim=1)
            cm_max += (pm == yb).sum().item()
            total_max += yb.size(0)

            # Val Acc(pred) with rejection. This applies the same rule as
            # StevieNetHistBaseline.predict (reject when max softmax prob < model.threshold).
            # It reuses the logits above; calling predict() would rerun the whole forward pass.
            conf, pp = F.softmax(logits, dim=1).max(dim=1)
            accepted = ~(conf < model.threshold)  # written as ~(<) to match predict() exactly, incl. NaN
            cm_pred += (pp[accepted] == yb[accepted]).sum().item()
            total_pred += accepted.sum().item()

    val_accs_max.append(100.0 * cm_max / max(total_max, 1))
    val_accs_pred.append(100.0 * cm_pred / total_pred if total_pred > 0 else 0.0)
    coverages.append(100.0 * (total_pred / max(total_max, 1)))

    print(f"Epoch {epoch+1:02d} | "
          f"Train Acc: {train_accs[-1]:.2f}% | "
          f"Val Acc(max): {val_accs_max[-1]:.2f}% | "
          f"Val Acc(pred): {val_accs_pred[-1]:.2f}% | "
          f"Coverage: {coverages[-1]:.1f}%")

# =========================================================
# 18) Visualization: per-epoch train/val accuracy and rejection coverage
# =========================================================
plt.figure(figsize=(10, 5))
plt.plot(train_accs, label="Train Acc")
plt.plot(val_accs_max, label="Val Acc (max)")
plt.plot(val_accs_pred, label="Val Acc (predict, with rejection)")
plt.plot(coverages, label="Coverage (%)")
plt.xlabel("Epoch")
plt.ylabel("Accuracy / %")
plt.title("Training Curve: StevieNet + my_binning(raw+nonneg) + rejection")
plt.legend()
plt.grid(True)
plt.show()


model.threshold = 0 # set to the threshold you want to evaluate with (0 = never reject)
'''
# 以下为门槛值超参数搜索环节，可选
# =========================================================# =========================================================# =========================================================# =========================================================



# ================= Validation: threshold calibration + per-class gate stats + report (baseline-aware) =================
import os, json
import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix


# ----------------- A) 配置：选择校准策略 -----------------
CAL_POLICY      = "target_coverage"   # "target_coverage" 或 "target_risk"
TARGET_VALUE    = 0.65                # coverage目标(0~1) 或 risk目标(0~1)
CALIB_BATCHSIZE = 512
PLOT_CALIB      = True
SAVE_JSON_PATH  = "./val_gate_stats.json"


# ----------------- B) 在验证集上标定阈值（风险-覆盖率） -----------------
@torch.no_grad()
def calibrate_threshold_on_val(model, X_val, baseline_val, y_val,
                               device="cpu", batch_size=512,
                               policy="target_coverage", target=0.90,
                               plot=True):
    """
    选择性分类标定 Softmax 阈值（baseline-aware）：
      - 取每个样本最大置信度 conf 与是否预测正确 correct
      - conf 降序 -> (coverage, acc_on_accepted, risk) 曲线
      - policy="target_coverage": 覆盖率>=target 的区间内选 risk 最小点
        policy="target_risk": risk<=target 的区间内选 coverage 最大点
    返回：thr_star, cov_star, acc_star，并写回 model.threshold
    """
    model.eval()

    # ✅ DataLoader：多了 baseline_val
    loader = DataLoader(
        TensorDataset(X_val, baseline_val, y_val),
        batch_size=batch_size,
        shuffle=False
    )

    confs, corrects = [], []
    for xb, bb, yb in loader:
        xb = xb.to(device, non_blocking=True)
        bb = bb.to(device, non_blocking=True).float().clamp_min(1e-8)
        yb = yb.to(device, non_blocking=True)

        # ✅ 关键改动：模型前向需要 baseline
        logits = model(xb, bb)
        probs  = F.softmax(logits, dim=1)
        conf, pred = probs.max(dim=1)

        confs.append(conf.detach().cpu())
        corrects.append((pred == yb).detach().cpu())

    confs   = torch.cat(confs).numpy()
    correct = torch.cat(corrects).numpy().astype(np.int32)
    N = len(confs)
    if N == 0:
        raise RuntimeError("Validation set is empty, cannot calibrate threshold.")

    order = np.argsort(-confs)
    conf_sorted    = confs[order]
    correct_sorted = correct[order]

    k = np.arange(1, N + 1)
    coverage = k / N
    acc_on_accepted = np.cumsum(correct_sorted) / k
    risk = 1.0 - acc_on_accepted

    if policy == "target_coverage":
        idx0 = int(np.searchsorted(coverage, target, side="left"))
        idx0 = np.clip(idx0, 0, N - 1)
        tail = np.arange(idx0, N)
        idx_star = int(tail[np.argmin(risk[tail])])
    elif policy == "target_risk":
        ok = np.where(risk <= target)[0]
        idx_star = int(ok[-1]) if ok.size > 0 else int(np.argmin(risk))
    else:
        raise ValueError("CAL_POLICY must be 'target_coverage' or 'target_risk'.")

    thr_star = float(conf_sorted[idx_star])
    cov_star = float(coverage[idx_star])
    acc_star = float(acc_on_accepted[idx_star])

    model.threshold = thr_star

    print("\n[Calibration]")
    print(f"  policy = {policy}, target = {target}")
    print(f"  -> threshold = {thr_star:.4f}")
    print(f"  -> coverage  = {cov_star*100:.2f}%")
    print(f"  -> acc_on_accepted (1-risk) = {acc_star*100:.2f}%")

    if plot:
        fig = plt.figure(figsize=(12, 4.5))
        ax1 = fig.add_subplot(1, 2, 1)
        ax1.hist(confs[correct == 1], bins=40, alpha=0.6, label="Correct", density=True)
        ax1.hist(confs[correct == 0], bins=40, alpha=0.6, label="Incorrect", density=True)
        ax1.axvline(thr_star, ls="--", c="k", label=f"Threshold {thr_star:.2f}")
        ax1.set_xlabel("Max softmax confidence")
        ax1.set_ylabel("Density")
        ax1.set_title("Validation confidence distribution")
        ax1.legend()

        ax2 = fig.add_subplot(1, 2, 2)
        ax2.plot(coverage, acc_on_accepted, lw=2)
        ax2.scatter([cov_star], [acc_star], s=50, edgecolors="k",
                    label=f"Chosen: cov={cov_star*100:.1f}%, acc={acc_star*100:.2f}%")
        ax2.set_xlabel("Coverage (accepted proportion)")
        ax2.set_ylabel("Accuracy on accepted (Precision)")
        ax2.set_title("Accuracy–Coverage curve (validation)")
        ax2.set_xlim(0.2, 1)
        ax2.set_ylim(0.0, 1.0)
        ax2.grid(True, alpha=0.3)
        ax2.legend()
        plt.tight_layout()
        plt.show()

    return thr_star, cov_star, acc_star


# ----------------- C) 先标定阈值 -----------------
thr_star, cov_star, acc_star = calibrate_threshold_on_val(
    model, X_val, baseline_val, y_val,
    device=device,
    batch_size=CALIB_BATCHSIZE,
    policy=CAL_POLICY,
    target=TARGET_VALUE,
    plot=PLOT_CALIB
)
print(f"[Model] model.threshold 已设置为 {model.threshold:.4f}")


# ----------------- D) 用选定阈值统计每类 gate 参数并缓存 -----------------
@torch.no_grad()
def compute_val_gate_stats(model, X_val, baseline_val, y_val, thr,
                           device="cpu", batch_size=512):
    model.eval()
    loader = DataLoader(
        TensorDataset(X_val, baseline_val, y_val),
        batch_size=batch_size,
        shuffle=False
    )

    y_true_list, y_hat_list, p_max_list = [], [], []
    for xb, bb, yb in loader:
        xb = xb.to(device, non_blocking=True)
        bb = bb.to(device, non_blocking=True).float().clamp_min(1e-8)
        yb = yb.to(device, non_blocking=True)

        logits = model(xb, bb)
        probs  = F.softmax(logits, dim=1)
        p_max, y_hat = probs.max(dim=1)

        y_true_list.append(yb.detach().cpu())
        y_hat_list.append(y_hat.detach().cpu())
        p_max_list.append(p_max.detach().cpu())

    y_true = torch.cat(y_true_list).numpy()
    y_hat  = torch.cat(y_hat_list).numpy()
    p_max  = torch.cat(p_max_list).numpy()

    num_classes = int(max(y_true.max(), y_hat.max())) + 1
    accepted = (p_max >= thr)

    gate = {"selected_threshold": float(thr), "per_class": {}, "overall": {}}

    acc_on_accepted = float((y_hat[accepted] == y_true[accepted]).mean()) if accepted.any() else 0.0
    coverage = float(accepted.mean())
    gate["overall"] = {"acc_on_accepted": acc_on_accepted, "coverage": coverage}

    for c in range(num_classes):
        is_c = (y_true == c)
        n_c  = int(is_c.sum())
        if n_c == 0:
            gate["per_class"][str(c)] = {"r": None, "a": None, "S_1mr": None, "S_full": None}
            continue

        rc = float((~accepted[is_c]).mean())  # 真c被拒比例
        mask_acc_c = accepted & is_c
        ac = float((y_hat[mask_acc_c] == c).mean()) if mask_acc_c.any() else 0.0

        gate["per_class"][str(c)] = {
            "r": rc,
            "a": ac,
            "S_1mr": 1.0 - rc,
            "S_full": (1.0 - rc) * ac
        }
    return gate


VAL_GATE_STATS = compute_val_gate_stats(
    model, X_val, baseline_val, y_val, thr_star,
    device=device, batch_size=CALIB_BATCHSIZE
)
VAL_SELECTED_THR = float(thr_star)

with open(SAVE_JSON_PATH, "w") as f:
    json.dump(VAL_GATE_STATS, f, indent=2)
print(f"[Saved] per-class gate stats -> {os.path.abspath(SAVE_JSON_PATH)}")


# ----------------- E) 用选定阈值输出逐类指标 & （可选）混淆矩阵 -----------------
val_loader = DataLoader(
    TensorDataset(X_val, baseline_val, y_val),
    batch_size=256,
    shuffle=False
)

y_true_list, y_hat_list, p_max_list = [], [], []
with torch.no_grad():
    for xb, bb, yb in val_loader:
        xb = xb.to(device, non_blocking=True)
        bb = bb.to(device, non_blocking=True).float().clamp_min(1e-8)
        yb = yb.to(device, non_blocking=True)

        logits = model(xb, bb)
        probs  = F.softmax(logits, dim=1)
        p_max, y_hat = probs.max(dim=1)

        y_true_list.append(yb.detach().cpu())
        y_hat_list.append(y_hat.detach().cpu())
        p_max_list.append(p_max.detach().cpu())

y_true = torch.cat(y_true_list).numpy()
y_hat  = torch.cat(y_hat_list).numpy()
p_max  = torch.cat(p_max_list).numpy()

accepted_mask = (p_max >= VAL_SELECTED_THR)
num_classes = int(max(y_true.max(), y_hat.max())) + 1

print("\n==== Accuracy and Uncertainty Per Label (Validation @ selected threshold) ====")
overall_correct_conf = 0
overall_total_conf   = 0
overall_not_pred     = 0

for c in range(num_classes):
    mask_c  = (y_true == c)
    total_c = int(mask_c.sum())
    if total_c == 0:
        continue

    acc_max = float((y_hat[mask_c] == c).mean()) * 100.0

    mask_acc = mask_c & accepted_mask
    accept_c = int(mask_acc.sum())
    if accept_c > 0:
        acc_conf = float((y_hat[mask_acc] == c).mean()) * 100.0
        correct_c = int((y_hat[mask_acc] == c).sum())
        acc_conf_display = f"{acc_conf:5.2f}%"
    else:
        acc_conf_display = "  NaN"
        correct_c = 0

    not_pred_pct = float((~accepted_mask[mask_c]).mean()) * 100.0

    overall_correct_conf += correct_c
    overall_total_conf   += accept_c
    overall_not_pred     += int((~accepted_mask[mask_c]).sum())

    print(f"Label {c:02d}: 总样本={total_c:4d}, "
          f"argmax Acc={acc_max:5.2f}%, "
          f"阈值 Acc={acc_conf_display:>6}, "
          f"Not Predicted={not_pred_pct:5.2f}% "
          f"(正确 {correct_c}/{accept_c})")

overall_acc_conf = (overall_correct_conf / overall_total_conf * 100.0) if overall_total_conf > 0 else 0.0
overall_not_pred_pct = overall_not_pred / y_true.size * 100.0
print(f"\nOverall Accuracy (on accepted): {overall_acc_conf:.2f}%")
print(f"Overall Not Predicted Percentage: {overall_not_pred_pct:.2f}%")
print(f"Selected threshold: {VAL_SELECTED_THR:.4f}")

# ——（可选）只在“被接收”的样本上画混淆矩阵 —— #
PLOT_CM = True
if PLOT_CM:
    valid = accepted_mask
    y_true_f = y_true[valid]
    y_pred_f = y_hat[valid]
    if y_true_f.size > 0:
        labels_order = sorted(list(set(y_true_f.tolist()) | set(y_pred_f.tolist())))
        cm = confusion_matrix(y_true_f, y_pred_f, labels=labels_order)

        plt.figure(figsize=(12, 9))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=labels_order, yticklabels=labels_order)
        plt.xlabel('Predicted Label'); plt.ylabel('True Label')
        plt.title('Confusion Matrix on Accepted Samples (Validation)')
        plt.show()

        cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True)
        plt.figure(figsize=(12, 9))
        sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="Blues",
                    xticklabels=labels_order, yticklabels=labels_order)
        plt.xlabel('Predicted Label'); plt.ylabel('True Label')
        plt.title('Normalized Confusion Matrix on Accepted Samples (Validation)')
        plt.show()
    else:
        print("（所有验证样本都被拒识，跳过混淆矩阵绘制）")
# ====================================================================================================






# =========================================================# =========================================================# =========================================================# =========================================================

'''


import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from torch.utils.data import TensorDataset, DataLoader


# =========================================================
# 0) Quick sanity check on the test split (optional, but recommended)
# =========================================================
# Reuse the same validator as the train/val splits so all three are checked the
# same way. It also clamps baseline to >= 1e-8 (the histogram step is
# baseline / bin, so a zero baseline would give a zero step) and casts y to long.
X_test, baseline_test, y_test = _check_baseline(X_test, baseline_test, y_test, "test")


# =========================================================
# 1) 分批得到：y_true / argmax_pred / threshold_pred(-1拒识) / conf
# =========================================================
def batch_predict_test(model, X_test, baseline_test, y_test, device="cpu", batch_size=32):
    """Run a baseline-aware classifier over a labelled split in batches.

    Args:
        model: called as ``model(x, baseline)`` and returning (B, num_classes)
            logits. Its ``threshold`` attribute is used for rejection, with the
            same rule as StevieNetHistBaseline.predict. This function does not
            move the model; it must already be on ``device``.
        X_test: (N, C, L) inputs.
        baseline_test: (N,) baselines, cast to float and clamped to >= 1e-8 per batch.
        y_test: (N,) integer labels.
        device: device each batch is moved to.
        batch_size: inference batch size.

    Returns:
        (all_true, all_pred_max, all_pred_conf, all_conf), 1-D CPU tensors of length N:
          - all_true: true labels;
          - all_pred_max: argmax predictions (no rejection);
          - all_pred_conf: thresholded predictions (-1 where the max softmax
            probability < model.threshold);
          - all_conf: that max softmax probability.
    """
    model.eval()
    test_loader = DataLoader(
        TensorDataset(X_test, baseline_test, y_test),
        batch_size=batch_size,
        shuffle=False,
    )

    all_true, all_pred_max, all_pred_conf, all_conf = [], [], [], []

    with torch.no_grad():
        for xb, bb, yb in test_loader:
            xb = xb.to(device, non_blocking=True)
            bb = bb.to(device, non_blocking=True).float().clamp_min(1e-8)
            yb = yb.to(device, non_blocking=True).long()

            # A single baseline-aware forward pass per batch. Calling
            # model.predict() as well would rerun the whole network.
            logits = model(xb, bb)
            pm = logits.argmax(dim=1)  # argmax prediction, no rejection

            # Thresholded prediction, identical to StevieNetHistBaseline.predict:
            # softmax -> max -> mark conf < threshold as -1 (rejected).
            conf, pc = torch.max(F.softmax(logits, dim=1), dim=1)
            pc[conf < model.threshold] = -1

            all_true.append(yb.cpu())
            all_pred_max.append(pm.cpu())
            all_pred_conf.append(pc.cpu())
            all_conf.append(conf.cpu())

    all_true      = torch.cat(all_true)
    all_pred_max  = torch.cat(all_pred_max)
    all_pred_conf = torch.cat(all_pred_conf)
    all_conf      = torch.cat(all_conf)
    return all_true, all_pred_max, all_pred_conf, all_conf


# =========================================================
# 2) 输出：每类 argmax acc / 阈值 acc(只在被接收样本上) / 拒识率
# =========================================================
def report_per_class_metrics(all_true, all_pred_max, all_pred_conf):
    """Print per-class and overall test metrics for argmax and threshold predictions.

    For every class label ``c`` in ``range(max(all_true) + 1)`` that has at
    least one sample, this prints:
      * argmax accuracy, computed over all samples of the class (no rejection);
      * threshold accuracy, computed only over the samples that were accepted
        (``all_pred_conf != -1``), or NaN when none were accepted;
      * reject rate, the share of the class's samples with ``all_pred_conf == -1``.
    Overall accuracy-on-accepted, coverage and reject rate are printed last.

    Args:
        all_true: 1-D integer tensor of ground-truth labels.
        all_pred_max: 1-D tensor of argmax predictions aligned with ``all_true``.
        all_pred_conf: 1-D tensor of threshold predictions aligned with
            ``all_true``; ``-1`` marks a rejected sample.

    Returns:
        int: The number of classes considered (``max(all_true) + 1``), or 0
        when ``all_true`` is empty.
    """
    overall_total = all_true.numel()
    if overall_total == 0:
        # Tensor.max() raises on an empty tensor; report and bail out instead.
        print("\n（测试集为空，跳过各标签准确度统计）")
        return 0

    num_classes = int(all_true.max().item()) + 1

    print("\n=== 测试集上各标签准确度（baseline-aware, 分批计算） ===")
    overall_correct_conf = 0
    overall_total_conf   = 0
    overall_reject       = 0

    # The reject mask does not depend on the class, so build it once.
    rejected = (all_pred_conf == -1)
    accepted = ~rejected

    for c in range(num_classes):
        mask_c  = (all_true == c)
        total_c = int(mask_c.sum().item())
        if total_c == 0:
            # Label ids may be non-contiguous; skip ids with no test samples.
            continue

        # argmax accuracy (no rejection)
        correct_max_c = int((all_pred_max[mask_c] == c).sum().item())
        acc_max = 100.0 * correct_max_c / total_c

        # threshold accuracy, computed only on accepted samples
        accepted_c_mask = mask_c & accepted
        accept_c = int(accepted_c_mask.sum().item())
        if accept_c > 0:
            correct_conf_c = int((all_pred_conf[accepted_c_mask] == c).sum().item())
            acc_conf = 100.0 * correct_conf_c / accept_c
            acc_conf_display = f"{acc_conf:5.2f}%"
        else:
            correct_conf_c = 0
            acc_conf_display = "  NaN "

        # reject rate
        reject_c = int((mask_c & rejected).sum().item())
        reject_pct = 100.0 * reject_c / total_c

        overall_correct_conf += correct_conf_c
        overall_total_conf   += accept_c
        overall_reject       += reject_c

        print(f"Label {c:02d}: 总样本={total_c:4d}, "
              f"argmax Acc={acc_max:6.2f}%, "
              f"阈值 Acc={acc_conf_display:>6}, "
              f"Reject={reject_pct:6.2f}% "
              f"(正确 {correct_conf_c}/{accept_c})")

    # overall_total > 0 is guaranteed by the early return above.
    overall_acc_conf = 100.0 * overall_correct_conf / overall_total_conf if overall_total_conf > 0 else 0.0
    overall_coverage = 100.0 * overall_total_conf / overall_total
    overall_reject_pct = 100.0 * overall_reject / overall_total

    print("\n--- Overall (threshold-based) ---")
    print(f"Accuracy on accepted: {overall_acc_conf:.2f}%")
    print(f"Coverage (accepted):  {overall_coverage:.2f}%")
    print(f"Reject rate:          {overall_reject_pct:.2f}%")

    return num_classes


# =========================================================
# 3) 混淆矩阵（只在 accepted 样本上画）
# =========================================================


def plot_confusion_matrix_on_accepted(
    all_true, all_pred_conf,
    title_prefix="Test",
    save_csv=False,
    csv_prefix="cm_test"
):
    """Print, plot and optionally save confusion matrices over accepted samples.

    Samples whose threshold prediction is ``-1`` (rejected) are excluded. Two
    matrices are printed and drawn as heatmaps:
      * absolute counts;
      * a row-normalized version, where each row is divided by its total. Row
        sums are clipped to >= 1, so an all-zero row stays zero instead of
        dividing by zero.
    Axis labels are the sorted union of true and predicted labels among the
    accepted samples.

    Args:
        all_true: 1-D tensor of ground-truth labels.
        all_pred_conf: 1-D tensor of threshold predictions aligned with
            ``all_true``; ``-1`` marks a rejected sample.
        title_prefix: Prefix used in printed headers and plot titles.
        save_csv: If True, also write both matrices to CSV files.
        csv_prefix: Path prefix for the CSV files; writes
            ``{csv_prefix}_absolute.csv`` and ``{csv_prefix}_normalized.csv``.

    Returns:
        None. Returns early, after printing a notice, when every sample was rejected.
    """
    accepted = (all_pred_conf != -1)
    # .cpu() is a no-op for CPU tensors and makes .numpy() safe for GPU ones.
    y_true_f = all_true[accepted].cpu().numpy()
    y_pred_f = all_pred_conf[accepted].cpu().numpy()

    if y_true_f.size == 0:
        print("（所有测试样本都被拒识，跳过混淆矩阵绘制）")
        return

    labels_order = sorted(set(y_true_f.tolist()) | set(y_pred_f.tolist()))
    cm = confusion_matrix(y_true_f, y_pred_f, labels=labels_order)

    # ---- Print the absolute-count confusion matrix ----
    cm_df = pd.DataFrame(cm, index=labels_order, columns=labels_order)
    print(f"\n==== {title_prefix} Confusion Matrix on Accepted Samples (Absolute) ====")
    with pd.option_context("display.max_rows", None, "display.max_columns", None, "display.width", 200):
        print(cm_df)

    # ---- Plot: absolute counts ----
    plt.figure(figsize=(12, 9))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=labels_order, yticklabels=labels_order)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title(f"{title_prefix} Confusion Matrix on Accepted Samples (Absolute)")
    plt.show()

    # ---- Row-wise normalization (clip avoids 0/0 for empty rows) ----
    cm_norm = cm.astype(float) / np.clip(cm.sum(axis=1, keepdims=True), 1, None)

    # ---- Print the normalized confusion matrix ----
    cm_norm_df = pd.DataFrame(cm_norm, index=labels_order, columns=labels_order)
    print(f"\n==== {title_prefix} Confusion Matrix on Accepted Samples (Normalized, row-wise) ====")
    with pd.option_context("display.max_rows", None, "display.max_columns", None, "display.width", 200):
        print(cm_norm_df.round(4))

    # ---- Plot: normalized ----
    plt.figure(figsize=(12, 9))
    sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="Blues",
                xticklabels=labels_order, yticklabels=labels_order)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title(f"{title_prefix} Confusion Matrix on Accepted Samples (Normalized)")
    plt.show()

    # ---- Optional: save both matrices as CSV (opt-in via save_csv) ----
    if save_csv:
        cm_df.to_csv(f"{csv_prefix}_absolute.csv", index=True)
        cm_norm_df.to_csv(f"{csv_prefix}_normalized.csv", index=True)
        print(f"\n[Saved] {csv_prefix}_absolute.csv")
        print(f"[Saved] {csv_prefix}_normalized.csv")




# =========================================================
# 4) One-shot run (assumes model.threshold has already been calibrated on the validation set)
# =========================================================
PLOT_CM = True  # set to False to skip printing/plotting the confusion matrices

model.eval()

# Gather labels, argmax predictions, threshold predictions (-1 = rejected)
# and confidences over the whole test split.
all_true, all_pred_max, all_pred_conf, all_conf = batch_predict_test(
    model,
    X_test,
    baseline_test,
    y_test,
    device=device,
    batch_size=32,
)

num_classes = report_per_class_metrics(all_true, all_pred_max, all_pred_conf)

if PLOT_CM:
    plot_confusion_matrix_on_accepted(all_true, all_pred_conf, title_prefix="Test")
