# v3_2 层次有监督分类（LightGBM + DeBERTa）与可解释性

要点：
- 统一文本字段：text = NameEnglish + '. ' + ShortDescription（缺失值按空串处理）；
- 复用 v3_1 生成的 4096 维嵌入（从 EMBEDDED_JSON_PATH 文件的 `embed` 字段读取）；
- 分层训练：Dispatcher (Macro/Micro) + Specialists（Macro三类、Micro五类）
- 强化可解释性：
  - LightGBM：SHAP 全局/局部 + 类别 c-TF-IDF 词特征对照；
  - DeBERTa：Integrated Gradients token 归因并导出着色 HTML；
- 明确保留并输出 v1 风格的混淆矩阵（热力图），包括：dispatcher、两个 specialist，以及端到端管道。

In [None]:
import sys, subprocess
def pip_install(pkg):
    """Install ``pkg`` into the running interpreter's environment via ``python -m pip install``.

    Raises subprocess.CalledProcessError if pip exits with a non-zero status.
    """
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', pkg])

# pip distribution name -> importable module name. They differ for scikit-learn
# (imported as ``sklearn``); probing ``__import__('scikit-learn')`` always failed,
# which reinstalled scikit-learn on every run.
_REQUIRED_PACKAGES = {
    'pandas': 'pandas',
    'numpy': 'numpy',
    'scikit-learn': 'sklearn',
    'matplotlib': 'matplotlib',
    'seaborn': 'seaborn',
    'lightgbm': 'lightgbm',
    'transformers': 'transformers',
    'datasets': 'datasets',
    'shap': 'shap',
    'captum': 'captum',
    'nltk': 'nltk',
}

for pip_name, module_name in _REQUIRED_PACKAGES.items():
    try:
        __import__(module_name)
    except Exception:
        # Best effort: any import failure (missing or broken install) triggers a (re)install.
        pip_install(pip_name)

import os, json, re
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix

import lightgbm as lgb
import shap
import torch
from captum.attr import IntegratedGradients
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding, EarlyStoppingCallback
from datasets import Dataset

# Disable Weights & Biases reporting for the Hugging Face Trainer.
os.environ['WANDB_DISABLED'] = 'true'

# Path configuration (edit as needed).
# CTM_OUTPUT_DIR: output directory of the v3_1 run; it holds the JSON with precomputed embeddings.
CTM_OUTPUT_DIR = '/content/drive/MyDrive/Colab Notebooks/验证3+5/v3_ctm_results'
EMBEDDED_JSON_PATH = os.path.join(CTM_OUTPUT_DIR, 'policies_with_embeddings_api_4096.json')
# All figures, reports and checkpoints produced by this notebook are written here.
OUTPUT_DIR = '/content/drive/MyDrive/Colab Notebooks/验证3+5/v3_supervised_results'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Label hierarchy (8 classes -> two levels, Macro / Micro).
# Labels listed in MACRO_CLASSES are mapped to 'Macro' by the labelling code; every other
# label is treated as 'Micro'.
MACRO_CLASSES = ['Guideline_Strategy', 'Planning_Layout', 'Institutional_Arrangements']
# Full 8-class scheme. NOTE(review): not referenced by the visible code; kept as documentation.
ALL_CLASSES = [
    'Guideline_Strategy', 'Planning_Layout', 'Institutional_Arrangements',
    'Resource_Allocation_Policy', 'Innovation_Actor_Policy', 'Talent_Policy',
    'Commercialization_Policy', 'Environment_Shaping_Policy'
]

# Seed NumPy and PyTorch (CPU and every CUDA device) for reproducibility.
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_STATE)

In [None]:
# 1) 读入数据，构造文本与标签（统一 NameEnglish），并划分数据集

def load_data(path):
    """Read the policy JSON at ``path`` and build the unified ``text`` column.

    ``text`` = NameEnglish + '. ' + ShortDescription. A missing column or a null cell
    contributes an empty string. Rows whose text is 10 characters or shorter are
    dropped, and the index is reset.

    Raises:
        ValueError: if the file has no ``ClassificationLabel`` column (8-class annotations).
    """
    df = pd.read_json(path)
    # Validate first so a wrong file fails fast, before any text processing.
    if 'ClassificationLabel' not in df.columns:
        raise ValueError('数据缺少 ClassificationLabel 列。请使用带有 8 类标注的数据文件。')

    def _text_col(name):
        # df.get(name, '') returned a plain str for a missing column, which has no .fillna;
        # return an index-aligned Series of '' instead. astype(str) also covers non-str cells.
        if name in df.columns:
            return df[name].fillna('').astype(str)
        return pd.Series('', index=df.index, dtype=object)

    df['text'] = _text_col('NameEnglish') + '. ' + _text_col('ShortDescription')
    df = df[df['text'].str.len() > 10].reset_index(drop=True)
    return df

# Load the data and derive the two-level labels.
df = load_data(EMBEDDED_JSON_PATH)
# Primary (dispatcher) label: 'Macro' for MACRO_CLASSES, 'Micro' for every other label.
df['primary_label_str'] = df['ClassificationLabel'].apply(lambda x: 'Macro' if x in MACRO_CLASSES else 'Micro')

# Integer encodings: primary (Macro/Micro) and secondary (all fine-grained classes).
le_primary = LabelEncoder().fit(df['primary_label_str'])
le_secondary = LabelEncoder().fit(df['ClassificationLabel'])
df['primary_label'] = le_primary.transform(df['primary_label_str'])
df['secondary_label'] = le_secondary.transform(df['ClassificationLabel'])

# 80/10/10 split, stratified by the fine-grained label: hold out 20% as test, then take
# 0.1 / 0.8 = 12.5% of the remaining 80% as validation (= 10% of all rows).
train_val_df, test_df = train_test_split(df, test_size=0.2, random_state=RANDOM_STATE, stratify=df['secondary_label'])
relative_val = 0.1 / 0.8
train_df, val_df = train_test_split(train_val_df, test_size=relative_val, random_state=RANDOM_STATE, stratify=train_val_df['secondary_label'])
print(f'Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}')

# Specialist datasets, derived from the main split so no row changes split membership.
macro_train = train_df[train_df['primary_label_str']=='Macro'].copy()
macro_val   = val_df[val_df['primary_label_str']=='Macro'].copy()
macro_test  = test_df[test_df['primary_label_str']=='Macro'].copy()
micro_train = train_df[train_df['primary_label_str']=='Micro'].copy()
micro_val   = val_df[val_df['primary_label_str']=='Micro'].copy()
micro_test  = test_df[test_df['primary_label_str']=='Micro'].copy()

# Specialist encoders are fit on the training subsets only. LabelEncoder.transform raises
# ValueError for a val/test label never seen in training; the stratified split is meant to
# prevent that. NOTE(review): a class with very few rows could still trip it.
le_macro = LabelEncoder().fit(macro_train['ClassificationLabel'])
le_micro = LabelEncoder().fit(micro_train['ClassificationLabel'])

# Attach the per-specialist integer label ('specialist_label') to every split.
for split, df_ in [('train', macro_train), ('val', macro_val), ('test', macro_test)]:
    df_[f'specialist_label'] = le_macro.transform(df_['ClassificationLabel'])
for split, df_ in [('train', micro_train), ('val', micro_val), ('test', micro_test)]:
    df_[f'specialist_label'] = le_micro.transform(df_['ClassificationLabel'])

In [None]:
# 2) 评估与可视化（保留 v1 风格的混淆矩阵热力图）

def evaluate_and_plot(y_true, y_pred, labels, title):
    """Print accuracy, macro-F1 and a classification report, and save a confusion-matrix heatmap.

    Args:
        y_true, y_pred: integer-encoded class ids in 0 .. len(labels)-1 (e.g. LabelEncoder output).
        labels: class names indexed by class id (e.g. ``LabelEncoder.classes_``).
        title: used in the printout, the plot title, and the PNG file name under OUTPUT_DIR.

    Returns:
        dict with 'accuracy' and 'f1' (macro-F1 over the classes present in y_true/y_pred).
    """
    label_ids = np.arange(len(labels))
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    print(f'\n=== {title} ===\nAccuracy={acc:.4f} | Macro-F1={f1:.4f}')
    # Pin the full label set. Without labels=, classification_report raises ValueError and
    # the confusion matrix shrinks (so the tick labels no longer line up) whenever a class
    # occurs in neither y_true nor y_pred, e.g. a rare class in a small test split.
    print(classification_report(y_true, y_pred, labels=label_ids, target_names=labels, zero_division=0))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    fig = plt.figure(figsize=(max(8, len(labels)*0.9), max(6, len(labels)*0.7)))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
    plt.title(f'Confusion Matrix - {title}')
    plt.ylabel('True Label'); plt.xlabel('Predicted Label')
    fn = os.path.join(OUTPUT_DIR, f'cm_{re.sub(r"[^a-zA-Z0-9_]+","_", title)}.png')
    plt.savefig(fn, bbox_inches='tight')
    plt.close(fig)
    print('混淆矩阵已保存：', fn)
    return {'accuracy': acc, 'f1': f1}

In [None]:
# 3) LightGBM 分层训练（使用 4096 维嵌入），含早停与类权重

# Shared LightGBM hyperparameters for the dispatcher and both specialists.
# n_estimators is only an upper bound: train_lgbm stops after 200 rounds without validation
# improvement (lgb.early_stopping). feature_fraction / bagging_fraction with bagging_freq=1
# subsample 80% of columns / rows for each boosting round.
LGBM_BASE = {
    'learning_rate': 0.02,
    'n_estimators': 5000,
    'feature_fraction': 0.8,
    'bagging_fraction': 0.8,
    'bagging_freq': 1,
    'num_leaves': 31,
    'reg_alpha': 0.1,
    'reg_lambda': 0.1,
    'n_jobs': -1,
    'seed': RANDOM_STATE,
    'boosting_type': 'gbdt',
}

def train_lgbm(X_train, y_train, X_val, y_val, num_classes, title, use_balanced=True):
    """Fit an LGBMClassifier from LGBM_BASE with early stopping on (X_val, y_val).

    The objective is binary log-loss for 2 classes and multiclass log-loss otherwise.
    use_balanced=True sets class_weight='balanced'. ``title`` is accepted for call-site
    readability and is not used. Returns the fitted classifier.
    """
    if num_classes == 2:
        objective_cfg = {'objective': 'binary', 'metric': 'binary_logloss'}
    else:
        objective_cfg = {'objective': 'multiclass', 'metric': 'multi_logloss', 'num_class': num_classes}
    params = {**LGBM_BASE, **objective_cfg}
    if use_balanced:
        params['class_weight'] = 'balanced'

    stopper = lgb.early_stopping(200, verbose=False)
    clf = lgb.LGBMClassifier(**params)
    clf.fit(X_train, y_train, eval_set=[(X_val, y_val)], callbacks=[stopper])
    return clf

# Dispatcher: features are the precomputed 'embed' vectors (one row per policy);
# labels are the Macro/Micro ids.
Xd_tr = np.array(train_df['embed'].tolist());    yd_tr = train_df['primary_label'].values
Xd_va = np.array(val_df['embed'].tolist());      yd_va = val_df['primary_label'].values
Xd_te = np.array(test_df['embed'].tolist());     yd_te = test_df['primary_label'].values

lgbm_dispatcher = train_lgbm(Xd_tr, yd_tr, Xd_va, yd_va, len(le_primary.classes_), 'LGBM Dispatcher (Macro-Micro)')
pred_d = lgbm_dispatcher.predict(Xd_te)
_ = evaluate_and_plot(yd_te, pred_d, le_primary.classes_, 'LGBM Dispatcher (Macro-Micro)')

# Macro specialist: trained and evaluated only on rows whose true primary label is Macro.
Xm_tr = np.array(macro_train['embed'].tolist()); ym_tr = macro_train['specialist_label'].values
Xm_va = np.array(macro_val['embed'].tolist());   ym_va = macro_val['specialist_label'].values
Xm_te = np.array(macro_test['embed'].tolist());  ym_te = macro_test['specialist_label'].values
lgbm_macro = train_lgbm(Xm_tr, ym_tr, Xm_va, ym_va, len(le_macro.classes_), 'LGBM Macro Specialist')
pred_m = lgbm_macro.predict(Xm_te)
_ = evaluate_and_plot(ym_te, pred_m, le_macro.classes_, 'LGBM Macro Specialist')

# Micro specialist: trained and evaluated only on rows whose true primary label is Micro.
xmi_tr = np.array(micro_train['embed'].tolist()); ymi_tr = micro_train['specialist_label'].values
xmi_va = np.array(micro_val['embed'].tolist());   ymi_va = micro_val['specialist_label'].values
xmi_te = np.array(micro_test['embed'].tolist());  ymi_te = micro_test['specialist_label'].values
lgbm_micro = train_lgbm(xmi_tr, ymi_tr, xmi_va, ymi_va, len(le_micro.classes_), 'LGBM Micro Specialist')
pred_mi = lgbm_micro.predict(xmi_te)
_ = evaluate_and_plot(ymi_te, pred_mi, le_micro.classes_, 'LGBM Micro Specialist')

# End-to-end pipeline: the dispatcher picks Macro/Micro, then the matching specialist picks the
# fine-grained class, which is mapped back to the 8-class encoding.
# Vectorized: one predict() per model instead of up to two single-row predict() calls per
# test row. Xd_te was built from test_df['embed'] in row order, so predictions align with test_df.
y_true_final = test_df['secondary_label'].values
primary_pred_str = le_primary.inverse_transform(lgbm_dispatcher.predict(Xd_te))
routed_to_macro = primary_pred_str == 'Macro'
final_label_str = np.empty(len(test_df), dtype=object)
routes = (
    (routed_to_macro, lgbm_macro, le_macro),
    (~routed_to_macro, lgbm_micro, le_micro),
)
for route_mask, specialist, le_spec in routes:
    if route_mask.any():  # skip a specialist that received no rows
        final_label_str[route_mask] = le_spec.inverse_transform(specialist.predict(Xd_te[route_mask]))
y_pred_final = le_secondary.transform(final_label_str)
evaluate_and_plot(y_true_final, y_pred_final, le_secondary.classes_, 'End-to-End Pipeline (LGBM)')

In [None]:
# 4) LightGBM 可解释性：SHAP 全局/局部 + 类别 c-TF-IDF 词对照（辅助）

def _mean_abs_shap(shap_values):
    """Reduce TreeExplainer output to one mean |SHAP| value per feature, across SHAP output layouts."""
    if isinstance(shap_values, list):
        # Older SHAP: one (n_samples, n_features) array per class.
        return np.mean([np.mean(np.abs(sv), axis=0) for sv in shap_values], axis=0)
    values = np.abs(np.asarray(shap_values))
    if values.ndim == 3:
        # Newer SHAP returns a single (n_samples, n_features, n_outputs) array for multi-output
        # models. NOTE(review): verify the layout against the installed SHAP version.
        return values.mean(axis=(0, 2))
    return values.mean(axis=0)


def shap_global_bar(model, X, class_names, name, top_n=20, max_samples=200):
    """Save a bar chart of the ``top_n`` feature dimensions ranked by mean |SHAP|.

    Args:
        model: fitted tree model accepted by shap.TreeExplainer (here LGBMClassifier).
        X: 2-D feature matrix; only its first ``max_samples`` rows are explained (cost control).
        class_names: unused; kept for call-site symmetry.
        name: used in the plot title and the PNG file name under OUTPUT_DIR.
        top_n: number of dimensions to plot.
        max_samples: number of leading rows of X to explain.

    Importance is |SHAP| averaged over samples and, for multi-output models, over classes.
    Plotting failures are printed rather than raised; the figure is always closed.
    """
    explainer = shap.TreeExplainer(model)
    X_sample = X[:min(max_samples, X.shape[0])]
    shap_values = explainer.shap_values(X_sample)
    fig = plt.figure()
    try:
        mean_abs = _mean_abs_shap(shap_values)
        idx = np.argsort(-mean_abs)[:top_n]
        plt.bar(range(len(idx)), mean_abs[idx])
        plt.xticks(range(len(idx)), [f'dim_{i}' for i in idx], rotation=45, ha='right')
        plt.title(f'SHAP Global Importance (Top {top_n}) - {name}')
        fn = os.path.join(OUTPUT_DIR, f'shap_global_{re.sub(r"[^a-zA-Z0-9_]+","_", name)}.png')
        plt.tight_layout(); plt.savefig(fn, dpi=150)
        print('保存 SHAP 全局重要度图：', fn)
    except Exception as e:
        print('SHAP 全局绘制失败：', e)
    finally:
        # Close even on failure so repeated calls do not accumulate open figures.
        plt.close(fig)

# Global SHAP importance for the dispatcher, on its test embeddings.
shap_global_bar(lgbm_dispatcher, Xd_te, le_primary.classes_, 'LGBM Dispatcher (Macro-Micro)')
# Global SHAP importance for each specialist, on its own test subset.
shap_global_bar(lgbm_macro, Xm_te, le_macro.classes_, 'LGBM Macro Specialist')
shap_global_bar(lgbm_micro, xmi_te, le_micro.classes_, 'LGBM Micro Specialist')

In [None]:
# 5) DeBERTa 分层训练（含早停），并输出混淆矩阵

TRANSFORMER_MODEL = 'microsoft/deberta-v3-base'

def train_transformer(text_series, y, text_series_val, y_val, num_labels, title, class_names, max_length=512):
    """Fine-tune TRANSFORMER_MODEL as a sequence classifier and return the fitted Trainer.

    Args:
        text_series, text_series_val: pandas Series of input texts (train / validation).
        y, y_val: pandas Series of integer labels in [0, num_labels).
        num_labels: number of output classes.
        title: names the checkpoint sub-directory under OUTPUT_DIR/ckpt.
        class_names: unused; kept for call-site symmetry.
        max_length: tokenizer truncation length.

    Training runs for up to 8 epochs with per-epoch evaluation and checkpointing. It stops
    early after 3 epochs without validation macro-F1 improvement and reloads the best-F1
    checkpoint at the end.
    """
    import inspect

    tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_MODEL)
    model = AutoModelForSequenceClassification.from_pretrained(
        TRANSFORMER_MODEL, num_labels=num_labels
    )

    def tok_fn(batch):
        # No padding here: DataCollatorWithPadding pads each batch to its own longest sequence,
        # which is far cheaper than padding every short title+description to max_length.
        return tokenizer(batch['text'], truncation=True, max_length=max_length)

    train_ds = Dataset.from_dict({'text': text_series.tolist(), 'label': y.tolist()}).map(tok_fn, batched=True)
    val_ds = Dataset.from_dict({'text': text_series_val.tolist(), 'label': y_val.tolist()}).map(tok_fn, batched=True)

    # transformers renamed TrainingArguments.evaluation_strategy -> eval_strategy (the old name was
    # later removed, raising TypeError) and Trainer(tokenizer=...) -> processing_class. Use whichever
    # keyword the installed version accepts.
    ta_params = inspect.signature(TrainingArguments.__init__).parameters
    eval_key = 'eval_strategy' if 'eval_strategy' in ta_params else 'evaluation_strategy'
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tok_key = 'processing_class' if 'processing_class' in trainer_params else 'tokenizer'

    args = TrainingArguments(
        output_dir=os.path.join(OUTPUT_DIR, 'ckpt', re.sub(r'[^a-zA-Z0-9_]+','_', title)),
        num_train_epochs=8,
        learning_rate=2e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        save_strategy='epoch',
        # Keep the best and the latest checkpoint only, instead of one multi-GB checkpoint per epoch.
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model='f1',
        report_to='none',
        **{eval_key: 'epoch'},
    )

    def compute_metrics(p):
        preds = np.argmax(p.predictions, axis=1)
        return {
            'accuracy': accuracy_score(p.label_ids, preds),
            'f1': f1_score(p.label_ids, preds, average='macro', zero_division=0)
        }

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=DataCollatorWithPadding(tokenizer),
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
        **{tok_key: tokenizer},
    )
    trainer.train()
    return trainer

# Dispatcher: fine-tune on Macro/Micro, then evaluate on the full test split.
# Test inputs are padded/truncated to 512 tokens with the trainer's own tokenizer.
tr_dispatcher = train_transformer(train_df['text'], train_df['primary_label'], val_df['text'], val_df['primary_label'], len(le_primary.classes_), 'DeBERTa Dispatcher', le_primary.classes_)
test_ds_d = Dataset.from_dict({'text': test_df['text'].tolist(), 'label': test_df['primary_label'].tolist()})
tok_d = tr_dispatcher.tokenizer
test_ds_d = test_ds_d.map(lambda b: tok_d(b['text'], padding='max_length', truncation=True, max_length=512), batched=True)
pred_d_logits = tr_dispatcher.predict(test_ds_d).predictions
pred_d_label  = np.argmax(pred_d_logits, axis=1)
evaluate_and_plot(test_df['primary_label'].values, pred_d_label, le_primary.classes_, 'DeBERTa Dispatcher (Macro-Micro)')

# Macro specialist: trained and evaluated on rows whose true primary label is Macro.
tr_macro = train_transformer(macro_train['text'], macro_train['specialist_label'], macro_val['text'], macro_val['specialist_label'], len(le_macro.classes_), 'DeBERTa Macro Specialist', le_macro.classes_)
test_ds_m = Dataset.from_dict({'text': macro_test['text'].tolist(), 'label': macro_test['specialist_label'].tolist()})
tok_m = tr_macro.tokenizer
test_ds_m = test_ds_m.map(lambda b: tok_m(b['text'], padding='max_length', truncation=True, max_length=512), batched=True)
pred_m_logits = tr_macro.predict(test_ds_m).predictions
pred_m_label  = np.argmax(pred_m_logits, axis=1)
evaluate_and_plot(macro_test['specialist_label'].values, pred_m_label, le_macro.classes_, 'DeBERTa Macro Specialist')

# Micro specialist: trained and evaluated on rows whose true primary label is Micro.
tr_micro = train_transformer(micro_train['text'], micro_train['specialist_label'], micro_val['text'], micro_val['specialist_label'], len(le_micro.classes_), 'DeBERTa Micro Specialist', le_micro.classes_)
test_ds_mi = Dataset.from_dict({'text': micro_test['text'].tolist(), 'label': micro_test['specialist_label'].tolist()})
tok_mi = tr_micro.tokenizer
test_ds_mi = test_ds_mi.map(lambda b: tok_mi(b['text'], padding='max_length', truncation=True, max_length=512), batched=True)
pred_mi_logits = tr_micro.predict(test_ds_mi).predictions
pred_mi_label  = np.argmax(pred_mi_logits, axis=1)
evaluate_and_plot(micro_test['specialist_label'].values, pred_mi_label, le_micro.classes_, 'DeBERTa Micro Specialist')

# End-to-end pipeline (DeBERTa): the dispatcher picks Macro/Micro, then the matching specialist
# picks the fine-grained class.
# Reuses the dispatcher test predictions computed above (pred_d_label, rows in test_df order)
# and runs each specialist once, batched, on the rows routed to it. This replaces two
# single-example forward passes per test row.
def _predict_class_ids(trainer, texts, max_length=512):
    """Return the argmax class id of ``trainer.model`` for each text."""
    tok = trainer.tokenizer
    ds = Dataset.from_dict({'text': list(texts)})
    ds = ds.map(lambda b: tok(b['text'], truncation=True, max_length=max_length), batched=True)
    # Trainer.predict switches the model to eval mode and batches with the trainer's collator.
    return np.argmax(trainer.predict(ds).predictions, axis=1)

y_true_final = test_df['secondary_label'].values
primary_pred_str = le_primary.inverse_transform(pred_d_label)
routed_to_macro = primary_pred_str == 'Macro'
test_texts = test_df['text'].values
final_label_str = np.empty(len(test_df), dtype=object)
routes = (
    (routed_to_macro, tr_macro, le_macro),
    (~routed_to_macro, tr_micro, le_micro),
)
for route_mask, spec_trainer, le_spec in routes:
    if route_mask.any():  # skip a specialist that received no rows
        final_label_str[route_mask] = le_spec.inverse_transform(
            _predict_class_ids(spec_trainer, test_texts[route_mask])
        )
y_pred_final = le_secondary.transform(final_label_str)
evaluate_and_plot(y_true_final, y_pred_final, le_secondary.classes_, 'End-to-End Pipeline (DeBERTa)')

In [None]:
# 6) DeBERTa 可解释性：Integrated Gradients（单样本 token 归因，导出 HTML）

def ig_explain_to_html(trainer, text, outfile, target=None, n_steps=50, internal_batch_size=8, max_length=256):
    """Explain one prediction with Integrated Gradients and write token attributions as colored HTML.

    Attributions are taken w.r.t. the input word-embedding layer (LayerIntegratedGradients).
    Plain IntegratedGradients on integer ``input_ids`` cannot work: IG interpolates inputs to
    non-integer values, which an embedding lookup cannot take. Each token's score is its
    attribution summed over the embedding dimension, scaled to [-1, 1] by the max |score|.

    Args:
        trainer: fitted HF Trainer; its ``.model`` and ``.tokenizer`` are used.
        text: input text, truncated to ``max_length`` tokens.
        outfile: path of the HTML file to write.
        target: class index to explain. Defaults to the predicted class, because IG needs a
            single output and the model returns one logit per class.
        n_steps: number of IG integration steps.
        internal_batch_size: interpolated inputs per forward pass (bounds memory use).
        max_length: tokenizer truncation length.
    """
    import html as html_lib
    from captum.attr import LayerIntegratedGradients

    model = trainer.model
    tokenizer = trainer.tokenizer
    model.eval()
    enc = tokenizer(text, return_tensors='pt', truncation=True, max_length=max_length)
    input_ids = enc['input_ids'].to(model.device)
    attention_mask = enc['attention_mask'].to(model.device)

    def forward_fn(ids, mask):
        return model(input_ids=ids, attention_mask=mask).logits

    with torch.no_grad():
        logits = forward_fn(input_ids, attention_mask)
    target_idx = int(logits.argmax(dim=-1)[0].item()) if target is None else int(target)

    # Reference input: same length; special tokens (e.g. CLS/SEP) are kept and every other token
    # becomes PAD, so attributions describe the content tokens.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    special_ids = torch.tensor(sorted(set(tokenizer.all_special_ids)), device=input_ids.device)
    keep = torch.isin(input_ids, special_ids)
    baseline_ids = torch.where(keep, input_ids, torch.full_like(input_ids, pad_id))

    lig = LayerIntegratedGradients(forward_fn, model.get_input_embeddings())
    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=baseline_ids,
        additional_forward_args=(attention_mask,),
        target=target_idx,
        n_steps=n_steps,
        internal_batch_size=internal_batch_size,
        return_convergence_delta=True,
    )
    # (1, seq_len, hidden) -> one score per token.
    atts = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()
    toks = tokenizer.convert_ids_to_tokens(input_ids[0].detach().cpu().tolist())

    max_abs = float(np.max(np.abs(atts))) if atts.size else 0.0
    if max_abs > 0:
        atts = atts / max_abs

    def color(v):
        # White at 0, fading to red for positive and blue for negative contribution, so the
        # black text stays readable (previously near-zero tokens got a black background).
        shade = int(round((1.0 - min(abs(float(v)), 1.0)) * 255))
        return f'rgb(255,{shade},{shade})' if v >= 0 else f'rgb({shade},{shade},255)'

    hidden_tokens = {'[CLS]', '[SEP]', '[PAD]'} | {
        t for t in (tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token) if t
    }
    spans = []
    for tok, score in zip(toks, atts):
        if tok in hidden_tokens:
            continue
        # SentencePiece marks word starts with U+2581; escape so tokens such as '<' cannot break the HTML.
        label = html_lib.escape(tok.replace('\u2581', ' '))
        spans.append(
            f'<span title="{score:.3f}" style="background-color:{color(score)};color:#000;'
            f'padding:2px;margin:1px;display:inline-block;border-radius:3px;">{label}</span>'
        )
    header = f'<p>target class: {target_idx} | max |convergence delta|: {float(delta.abs().max()):.4f}</p>'
    page = f"<html><head><meta charset='utf-8'></head><body>{header}<div>{''.join(spans)}</div></body></html>"
    with open(outfile, 'w', encoding='utf-8') as f:
        f.write(page)
    print('IG 可视化已保存：', outfile)

# Explain one example per specialist: the first row of that specialist's test subset.
if len(macro_test) > 0:
    ex_text = macro_test.iloc[0]['text']
    ig_explain_to_html(tr_macro, ex_text, os.path.join(OUTPUT_DIR, 'ig_macro_example.html'))
if len(micro_test) > 0:
    ex_text2 = micro_test.iloc[0]['text']
    ig_explain_to_html(tr_micro, ex_text2, os.path.join(OUTPUT_DIR, 'ig_micro_example.html'))

# Final notice; every output of this notebook is under OUTPUT_DIR.
print('全部训练与可解释性分析完成。结果目录：', OUTPUT_DIR)

In [None]:
# -*- coding: utf-8 -*-
# 说明：
# 将本文件中各段代码按顺序粘贴为 v3_2_supervised_hier_and_xai.ipynb 的新增单元格（放在训练与评估完成之后），
# 即可在 v3_2 中集成：
# 1) c-TF-IDF 类词对照（宏/微 与 8 类）
# 2) 可靠性（温度标定）与可靠性图（Reliability Diagram）、ECE 指标（LightGBM 与 DeBERTa）
# 3) “3+5 vs 2+6” 对比图（读取 v3_1 的 grid_{macro,micro}.json 指标，绘制与标注 K=3/5/2/6）
#
# 依赖：numpy, pandas, scikit-learn, matplotlib, seaborn, nltk, datasets, transformers, torch
# 运行前提：v3_2 中已定义 df、train_df/val_df/test_df、宏/微拆分数据集、lgbm_* 模型与 tr_* 训练器、OUTPUT_DIR 等变量。
# 路径前提：若已运行 v3_1，CTM_OUTPUT_DIR 下有 grid_macro.json 与 grid_micro.json（用于 3+5 vs 2+6 绘图）。
#
# 注意：以下代码对缺失对象做了存在性检查，若某些变量在你的 v3_2 中命名不同，请按需替换。


import os
import re
import json
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.feature_extraction.text import CountVectorizer
from nltk.corpus import stopwords
import nltk
# Download the NLTK English stop-word list only when it is not already available locally.
try:
    _ = stopwords.words('english')
except LookupError:
    nltk.download('stopwords')

from datasets import Dataset
import torch


# =========================
# 一、c-TF-IDF 类词对照
# =========================

def compute_c_tf_idf_per_class(texts, labels, top_n=20, stop_words_extra=None, min_df=3, max_df=0.9):
    """
    Compute class-based TF-IDF (c-TF-IDF) and return the top terms of every class.

    Steps:
    - Build the vocabulary on the individual documents, so ``min_df`` / ``max_df`` keep their
      usual per-document meaning (a term must occur in >= min_df documents and in at most a
      max_df share of them). Applying them to the concatenated class documents instead pruned
      exactly the class-specific terms. With only 2 classes (Macro/Micro), min_df=3 could not
      be met at all and CountVectorizer raised ValueError.
    - Sum the document counts of each class into one "class document" row.
    - TF: counts normalized within each class row.
    - IDF: log(1 + n_classes / number of classes containing the term).
    - c-TF-IDF = TF * IDF.

    Args:
        texts: iterable of documents (each converted with str()).
        labels: iterable of class labels aligned with ``texts``.
        top_n: maximum number of terms per class.
        stop_words_extra: extra stop words added to NLTK's English list.
        min_df, max_df: CountVectorizer document-frequency limits over individual documents.

    Returns:
        dict class label -> DataFrame(columns=['term', 'score']), sorted by score descending and
        holding only positive scores. If the vocabulary cannot be built (everything pruned, or
        min_df/max_df incompatible with a tiny corpus), every class gets an empty DataFrame.
        Empty input returns {}.
    """
    df_tmp = pd.DataFrame({'text': [str(t) for t in texts], 'label': list(labels)})
    if df_tmp.empty:
        return {}
    # codes[i] is the position of row i's label in class_names (sorted, like groupby).
    codes, uniques = pd.factorize(df_tmp['label'], sort=True)
    class_names = list(uniques)

    def _empty_results():
        return {cn: pd.DataFrame(columns=['term', 'score']) for cn in class_names}

    sw = list(set(stopwords.words('english')) | set(stop_words_extra or []))
    vectorizer = CountVectorizer(stop_words=sw, min_df=min_df, max_df=max_df)
    try:
        X_docs = vectorizer.fit_transform(df_tmp['text'])  # [n_docs, vocab_size], sparse
    except ValueError as err:
        print('c-TF-IDF 词表为空或 min_df/max_df 设置不适用，返回空结果：', err)
        return _empty_results()
    vocab = np.asarray(vectorizer.get_feature_names_out())

    n_classes = len(class_names)
    # Aggregate document counts into one row per class: [n_classes, vocab_size].
    X = np.vstack([
        np.asarray(X_docs[codes == i].sum(axis=0)).ravel() for i in range(n_classes)
    ]).astype(float)
    if X.sum() == 0:
        return _empty_results()

    tf = X / (X.sum(axis=1, keepdims=True) + 1e-12)
    df_class = (X > 0).sum(axis=0)  # number of classes in which each term occurs
    idf = np.log(1 + n_classes / (df_class + 1e-12))
    c_tfidf = tf * idf

    results = {}
    for i, cn in enumerate(class_names):
        scores = c_tfidf[i]
        order = np.argsort(-scores, kind='stable')[:top_n]
        order = order[scores[order] > 0]  # drop terms that never occur in this class
        results[cn] = pd.DataFrame({'term': vocab[order], 'score': scores[order]})
    return results


def plot_ctfidf_bars(ctfidf_dict, title_prefix, out_dir):
    """
    Save one c-TF-IDF bar chart per class under ``<out_dir>/ctfidf`` (the directory is created).

    Classes whose table is empty are skipped. Files are named
    ctfidf_<prefix>_<class>.png, with every run of characters outside [a-zA-Z0-9_] replaced by '_'.
    """
    def _slug(s):
        return re.sub(r"[^a-zA-Z0-9_]+", "_", s)

    target_dir = os.path.join(out_dir, 'ctfidf')
    os.makedirs(target_dir, exist_ok=True)
    non_empty = ((name, table) for name, table in ctfidf_dict.items() if not table.empty)
    for cls_name, table in non_empty:
        plt.figure(figsize=(10, 5))
        plt.bar(table['term'], table['score'], color='#4C72B0')
        plt.title(f'{title_prefix} - {cls_name}')
        plt.xticks(rotation=60, ha='right')
        plt.ylabel('c-TF-IDF Score')
        plt.tight_layout()
        png_path = os.path.join(target_dir, f'ctfidf_{_slug(title_prefix)}_{_slug(cls_name)}.png')
        plt.savefig(png_path, dpi=150)
        plt.close()


def export_ctfidf_tables(ctfidf_dict, title_prefix, out_dir):
    """
    Write c-TF-IDF tables as CSV under ``<out_dir>/ctfidf``: one file per non-empty class,
    plus ctfidf_<prefix>_ALL.csv with columns class, term, score when any rows exist.
    Name parts have runs of characters outside [a-zA-Z0-9_] replaced by '_'.
    """
    def _slug(s):
        return re.sub(r"[^a-zA-Z0-9_]+", "_", s)

    target_dir = os.path.join(out_dir, 'ctfidf')
    os.makedirs(target_dir, exist_ok=True)
    combined = []
    for cls_name, table in ctfidf_dict.items():
        if table.empty:
            continue
        per_class_path = os.path.join(target_dir, f'ctfidf_{_slug(title_prefix)}_{_slug(cls_name)}.csv')
        table.to_csv(per_class_path, index=False, encoding='utf-8')
        combined.extend(
            {'class': cls_name, 'term': term, 'score': score}
            for term, score in zip(table['term'], table['score'])
        )
    if not combined:
        return
    all_path = os.path.join(target_dir, f'ctfidf_{_slug(title_prefix)}_ALL.csv')
    pd.DataFrame(combined).to_csv(all_path, index=False, encoding='utf-8')


# 运行 c-TF-IDF（宏/微 与 8 类）
if 'df' in globals() and 'OUTPUT_DIR' in globals():
    # 宏/微：primary_label_str
    domain_sw = [
        'policy','policies','measure','measures','action','actions','law','laws','government','ministry','council',
        'support','development','research','innovation','technology','science','program','programs','programme',
        'national','international','regional','local','state','country','countries'
    ]
    ctfidf_primary = compute_c_tf_idf_per_class(df['text'].tolist(),
                                                df['primary_label_str'].tolist(),
                                                top_n=25,
                                                stop_words_extra=domain_sw)
    plot_ctfidf_bars(ctfidf_primary, 'c-TF-IDF (Primary Macro/Micro)', OUTPUT_DIR)
    export_ctfidf_tables(ctfidf_primary, 'c-TF-IDF (Primary Macro/Micro)', OUTPUT_DIR)

    # 8 类：ClassificationLabel
    ctfidf_secondary = compute_c_tf_idf_per_class(df['text'].tolist(),
                                                  df['ClassificationLabel'].tolist(),
                                                  top_n=25,
                                                  stop_words_extra=domain_sw)
    plot_ctfidf_bars(ctfidf_secondary, 'c-TF-IDF (Secondary 8 Classes)', OUTPUT_DIR)
    export_ctfidf_tables(ctfidf_secondary, 'c-TF-IDF (Secondary 8 Classes)', OUTPUT_DIR)
    print('c-TF-IDF 类词对照 已完成并导出。')
else:
    print('警告：未检测到 df 或 OUTPUT_DIR，跳过 c-TF-IDF 步骤。')


# ==========================================
# 二、可靠性评估：温度标定 + 可靠性图 + ECE
# ==========================================

def softmax_np(logits):
    """Row-wise softmax of a 2-D NumPy array; each row's max is subtracted first for numerical stability."""
    row_max = logits.max(axis=1, keepdims=True)
    exps = np.exp(logits - row_max)
    return exps / exps.sum(axis=1, keepdims=True)


def nll_from_logits(logits, y_true):
    """Mean negative log-likelihood of integer labels ``y_true`` under softmax(logits) (1e-12 added inside the log)."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exps = np.exp(shifted)
    probs = exps / exps.sum(axis=1, keepdims=True)
    picked = probs[np.arange(len(y_true)), y_true]
    return -np.mean(np.log(picked + 1e-12))


def fit_temperature_from_logits(logits_val, y_val, max_iter=500, lr=0.01):
    """
    Fit a single temperature T on validation logits by minimizing the NLL of softmax(logits / T).

    Args:
        logits_val: array-like [N, C] of uncalibrated logits.
        y_val: array-like [N] of integer class ids.
        max_iter: maximum number of L-BFGS iterations. Previously ignored (a fixed 5 x 50 ran).
        lr: initial L-BFGS step size. The strong-Wolfe line search adapts the step, so this
            small value no longer stalls convergence.

    Returns:
        float T > 0. Returns 1.0 (no scaling) for empty input or a non-finite result.
    """
    logits_np = np.asarray(logits_val, dtype=np.float64)
    y_np = np.asarray(y_val, dtype=np.int64)
    if logits_np.ndim != 2 or logits_np.shape[0] == 0:
        return 1.0

    # One scalar parameter: CPU float64 is accurate and fast enough, with no device transfer.
    L = torch.from_numpy(logits_np)
    y = torch.from_numpy(y_np)
    # Optimize log(T): T = exp(log_t) stays positive with a smooth gradient everywhere
    # (clamping T at a floor zeroes the gradient below that floor).
    log_t = torch.zeros(1, dtype=torch.float64, requires_grad=True)
    opt = torch.optim.LBFGS([log_t], lr=lr, max_iter=max_iter, line_search_fn='strong_wolfe')
    loss_fn = torch.nn.CrossEntropyLoss()

    def closure():
        opt.zero_grad()
        loss = loss_fn(L / torch.exp(log_t), y)
        loss.backward()
        return loss

    opt.step(closure)
    T = float(torch.exp(log_t).detach().item())
    return T if np.isfinite(T) and T > 0 else 1.0


def fit_temperature_from_probs(probs_val, y_val, max_iter=500, lr=0.01):
    """
    Temperature scaling for models that only expose probabilities (e.g. LightGBM predict_proba).

    Probabilities are clipped to [1e-12, 1] and turned into pseudo-logits log(p). T is then
    fitted as for real logits, so the calibrated output is softmax(log(p) / T).
    """
    pseudo_logits = np.log(np.clip(probs_val, 1e-12, 1.0))
    return fit_temperature_from_logits(pseudo_logits, y_val, max_iter=max_iter, lr=lr)


def apply_temperature_to_logits(logits, T):
    """Return logits / T, with T floored at 1e-6 to avoid division by zero."""
    safe_t = 1e-6 if T < 1e-6 else T
    return logits / safe_t


def apply_temperature_to_probs(probs, T):
    """Temperature-scale probabilities: softmax(log(clip(p, 1e-12, 1)) / max(T, 1e-6)), row-wise."""
    scaled = np.log(np.clip(probs, 1e-12, 1.0)) / (1e-6 if T < 1e-6 else T)
    shifted = scaled - scaled.max(axis=1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=1, keepdims=True)


def compute_ece(probs, y_true, n_bins=15):
    """
    Expected Calibration Error (ECE) over ``n_bins`` equal-width confidence bins.

    Confidence is the maximum predicted probability of each row, and the prediction is its argmax.
    Each bin (lo, hi] contributes |accuracy - mean confidence| weighted by its share of samples.
    A confidence of exactly 0 falls in no bin; that cannot happen for rows summing to 1.

    Args:
        probs: array-like [N, C] of class probabilities.
        y_true: array-like [N] of integer class ids (lists and pandas Series are accepted;
            previously a list raised TypeError at ``y_true[mask]``).
        n_bins: number of bins.

    Returns:
        float ECE; 0.0 for empty input.
    """
    probs = np.asarray(probs, dtype=float)
    y_true = np.asarray(y_true)
    n = len(y_true)
    if n == 0:
        return 0.0
    confidences = probs.max(axis=1)
    preds = probs.argmax(axis=1)
    bins = np.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if not np.any(mask):
            continue
        acc = np.mean(preds[mask] == y_true[mask])
        conf = np.mean(confidences[mask])
        ece += np.abs(acc - conf) * (np.sum(mask) / n)
    return float(ece)


def plot_reliability_diagram(probs, y_true, title, out_dir, n_bins=15):
    """
    Draw a reliability diagram (per-bin accuracy and mean confidence vs. the diagonal) and save it.

    Binning matches compute_ece: ``n_bins`` equal-width bins (lo, hi] over the max predicted
    probability. Empty bins are left out. They used to be drawn at (center, 0), which looked
    like an accuracy collapse.

    Args:
        probs: array-like [N, C] of class probabilities.
        y_true: array-like [N] of integer class ids.
        title: plot title suffix and file-name stem.
        out_dir: directory for reliability_<title>.png.
        n_bins: number of bins.

    Returns:
        path of the saved PNG.
    """
    probs = np.asarray(probs, dtype=float)
    y_true = np.asarray(y_true)
    confidences = probs.max(axis=1)
    preds = probs.argmax(axis=1)
    edges = np.linspace(0, 1, n_bins + 1)

    centers, accs, confs = [], [], []
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if not np.any(mask):
            continue
        centers.append((lo + hi) / 2)
        accs.append(np.mean(preds[mask] == y_true[mask]))
        confs.append(np.mean(confidences[mask]))

    fn = os.path.join(out_dir, f'reliability_{re.sub(r"[^a-zA-Z0-9_]+","_", title)}.png')
    fig = plt.figure(figsize=(6.5, 6.5))
    try:
        plt.plot([0, 1], [0, 1], 'k--', label='Perfect calibration')
        plt.plot(centers, accs, 'o-', label='Accuracy per bin')
        plt.plot(centers, confs, 's--', label='Confidence per bin')
        plt.xlim(0, 1)
        plt.ylim(0, 1)
        plt.xlabel('Confidence')
        plt.ylabel('Accuracy')
        plt.title(f'Reliability Diagram - {title}')
        plt.legend()
        plt.tight_layout()
        plt.savefig(fn, dpi=150)
    finally:
        # Close even if saving fails, so repeated calls do not accumulate open figures.
        plt.close(fig)
    return fn


# 2.1 LightGBM 温度标定与可靠性图（Dispatcher + Macro/Micro Specialists）
# For each LightGBM model: fit T on validation probabilities, apply it to the test probabilities,
# print test ECE before/after scaling, and save both reliability diagrams.
# Only the model objects and OUTPUT_DIR are checked. The feature/label arrays (Xd_va, Xd_te,
# yd_va, yd_te, Xm_*, ym_*, xmi_*, ymi_*) must also exist from the LightGBM training cell.
if all(v in globals() for v in ['lgbm_dispatcher', 'lgbm_macro', 'lgbm_micro', 'OUTPUT_DIR']):
    print('开始 LightGBM 可靠性（温度标定）评估 ...')

    # Dispatcher
    P_d_val = lgbm_dispatcher.predict_proba(Xd_va)
    P_d_test = lgbm_dispatcher.predict_proba(Xd_te)
    T_d = fit_temperature_from_probs(P_d_val, yd_va)
    P_d_test_cal = apply_temperature_to_probs(P_d_test, T_d)
    ece_d_before = compute_ece(P_d_test, yd_te)
    ece_d_after = compute_ece(P_d_test_cal, yd_te)
    print(f'Dispatcher(LGBM) 温度 T={T_d:.3f} | ECE 前={ece_d_before:.4f} 后={ece_d_after:.4f}')
    plot_reliability_diagram(P_d_test, yd_te, 'LGBM Dispatcher (Before TS)', OUTPUT_DIR)
    plot_reliability_diagram(P_d_test_cal, yd_te, 'LGBM Dispatcher (After TS)', OUTPUT_DIR)

    # Macro Specialist
    P_m_val = lgbm_macro.predict_proba(Xm_va)
    P_m_test = lgbm_macro.predict_proba(Xm_te)
    T_m = fit_temperature_from_probs(P_m_val, ym_va)
    P_m_test_cal = apply_temperature_to_probs(P_m_test, T_m)
    ece_m_before = compute_ece(P_m_test, ym_te)
    ece_m_after = compute_ece(P_m_test_cal, ym_te)
    print(f'Macro Specialist(LGBM) 温度 T={T_m:.3f} | ECE 前={ece_m_before:.4f} 后={ece_m_after:.4f}')
    plot_reliability_diagram(P_m_test, ym_te, 'LGBM Macro Specialist (Before TS)', OUTPUT_DIR)
    plot_reliability_diagram(P_m_test_cal, ym_te, 'LGBM Macro Specialist (After TS)', OUTPUT_DIR)

    # Micro Specialist
    P_mi_val = lgbm_micro.predict_proba(xmi_va)
    P_mi_test = lgbm_micro.predict_proba(xmi_te)
    T_mi = fit_temperature_from_probs(P_mi_val, ymi_va)
    P_mi_test_cal = apply_temperature_to_probs(P_mi_test, T_mi)
    ece_mi_before = compute_ece(P_mi_test, ymi_te)
    ece_mi_after = compute_ece(P_mi_test_cal, ymi_te)
    print(f'Micro Specialist(LGBM) 温度 T={T_mi:.3f} | ECE 前={ece_mi_before:.4f} 后={ece_mi_after:.4f}')
    plot_reliability_diagram(P_mi_test, ymi_te, 'LGBM Micro Specialist (Before TS)', OUTPUT_DIR)
    plot_reliability_diagram(P_mi_test_cal, ymi_te, 'LGBM Micro Specialist (After TS)', OUTPUT_DIR)
else:
    print('提示：未检测到 LightGBM 相关对象，跳过 LGBM 的温度标定。')


# 2.2 DeBERTa 温度标定与可靠性图（Dispatcher + Macro/Micro Specialists）
def build_encoded_dataset(text_series, label_series, tokenizer, max_length=512):
    """Build a datasets.Dataset with 'text'/'label' plus tokenizer outputs padded/truncated to ``max_length``."""
    def _encode(batch):
        return tokenizer(batch['text'], padding='max_length', truncation=True, max_length=max_length)

    raw = Dataset.from_dict({
        'text': text_series.tolist(),
        'label': label_series.tolist(),
    })
    return raw.map(_encode, batched=True)

# For each DeBERTa model: take raw logits on validation and test, fit T on the validation logits,
# print test ECE before/after scaling, and save both reliability diagrams.
if all(v in globals() for v in ['tr_dispatcher', 'tr_macro', 'tr_micro', 'val_df', 'test_df', 'macro_val', 'macro_test', 'micro_val', 'micro_test', 'le_primary', 'le_macro', 'le_micro', 'OUTPUT_DIR']):
    print('开始 DeBERTa 可靠性（温度标定）评估 ...')

    # NOTE(review): `device` is assigned but not used in this block.
    device = tr_dispatcher.model.device

    # Dispatcher
    tok_d = tr_dispatcher.tokenizer
    val_ds_d = build_encoded_dataset(val_df['text'], val_df['primary_label'], tok_d)
    test_ds_d = build_encoded_dataset(test_df['text'], test_df['primary_label'], tok_d)

    logits_d_val = tr_dispatcher.predict(val_ds_d).predictions  # [N, C]
    logits_d_test = tr_dispatcher.predict(test_ds_d).predictions
    y_d_val = val_df['primary_label'].values
    y_d_test = test_df['primary_label'].values

    T_d_bert = fit_temperature_from_logits(logits_d_val, y_d_val)
    probs_d_test = softmax_np(logits_d_test)
    probs_d_test_cal = softmax_np(logits_d_test / max(T_d_bert, 1e-6))
    ece_d_b = compute_ece(probs_d_test, y_d_test)
    ece_d_a = compute_ece(probs_d_test_cal, y_d_test)
    print(f'DeBERTa Dispatcher 温度 T={T_d_bert:.3f} | ECE 前={ece_d_b:.4f} 后={ece_d_a:.4f}')
    plot_reliability_diagram(probs_d_test, y_d_test, 'DeBERTa Dispatcher (Before TS)', OUTPUT_DIR)
    plot_reliability_diagram(probs_d_test_cal, y_d_test, 'DeBERTa Dispatcher (After TS)', OUTPUT_DIR)

    # Macro Specialist
    tok_m = tr_macro.tokenizer
    val_ds_m = build_encoded_dataset(macro_val['text'], macro_val['specialist_label'], tok_m)
    test_ds_m = build_encoded_dataset(macro_test['text'], macro_test['specialist_label'], tok_m)

    logits_m_val = tr_macro.predict(val_ds_m).predictions
    logits_m_test = tr_macro.predict(test_ds_m).predictions
    y_m_val = macro_val['specialist_label'].values
    y_m_test = macro_test['specialist_label'].values

    T_m_bert = fit_temperature_from_logits(logits_m_val, y_m_val)
    probs_m_test = softmax_np(logits_m_test)
    probs_m_test_cal = softmax_np(logits_m_test / max(T_m_bert, 1e-6))
    ece_m_b = compute_ece(probs_m_test, y_m_test)
    ece_m_a = compute_ece(probs_m_test_cal, y_m_test)
    print(f'DeBERTa Macro Specialist 温度 T={T_m_bert:.3f} | ECE 前={ece_m_b:.4f} 后={ece_m_a:.4f}')
    plot_reliability_diagram(probs_m_test, y_m_test, 'DeBERTa Macro Specialist (Before TS)', OUTPUT_DIR)
    plot_reliability_diagram(probs_m_test_cal, y_m_test, 'DeBERTa Macro Specialist (After TS)', OUTPUT_DIR)

    # Micro Specialist
    tok_mi = tr_micro.tokenizer
    val_ds_mi = build_encoded_dataset(micro_val['text'], micro_val['specialist_label'], tok_mi)
    test_ds_mi = build_encoded_dataset(micro_test['text'], micro_test['specialist_label'], tok_mi)

    logits_mi_val = tr_micro.predict(val_ds_mi).predictions
    logits_mi_test = tr_micro.predict(test_ds_mi).predictions
    y_mi_val = micro_val['specialist_label'].values
    y_mi_test = micro_test['specialist_label'].values

    T_mi_bert = fit_temperature_from_logits(logits_mi_val, y_mi_val)
    probs_mi_test = softmax_np(logits_mi_test)
    probs_mi_test_cal = softmax_np(logits_mi_test / max(T_mi_bert, 1e-6))
    ece_mi_b = compute_ece(probs_mi_test, y_mi_test)
    ece_mi_a = compute_ece(probs_mi_test_cal, y_mi_test)
    print(f'DeBERTa Micro Specialist 温度 T={T_mi_bert:.3f} | ECE 前={ece_mi_b:.4f} 后={ece_mi_a:.4f}')
    plot_reliability_diagram(probs_mi_test, y_mi_test, 'DeBERTa Micro Specialist (Before TS)', OUTPUT_DIR)
    plot_reliability_diagram(probs_mi_test_cal, y_mi_test, 'DeBERTa Micro Specialist (After TS)', OUTPUT_DIR)
else:
    print('提示：未检测到 DeBERTa 相关对象或数据集，跳过 DeBERTa 的温度标定。')


# ==========================================
# 三、“3+5 vs 2+6” 对比图（读取 v3_1 结果）
# ==========================================

def plot_k_search_comparison(ctm_dir, out_dir):
    """
    Plot the v3_1 K-search metrics (coherence_c_v and topic_diversity vs K) for the Macro and
    Micro levels in a 2x2 grid. The chosen K is marked with a green dashed line (Macro 3,
    Micro 5) and the alternative with a red dotted line (Macro 2, Micro 6).

    Reads grid_macro.json and grid_micro.json from ``ctm_dir``. Each is presumably a list of
    per-K records with keys 'k', 'coherence_c_v' and 'topic_diversity'. If either file is
    missing, prints a notice and returns without plotting. Otherwise saves
    ctm_3plus5_vs_2plus6_comparison.png into ``out_dir``.
    """
    macro_path = os.path.join(ctm_dir, 'grid_macro.json')
    micro_path = os.path.join(ctm_dir, 'grid_micro.json')
    if not os.path.exists(macro_path) or not os.path.exists(micro_path):
        print('未找到 v3_1 的 K 搜索结果（grid_macro.json/grid_micro.json），跳过“3+5 vs 2+6”绘图。')
        return

    records = {}
    for level, path in (('macro', macro_path), ('micro', micro_path)):
        with open(path, 'r', encoding='utf-8') as fh:
            records[level] = json.load(fh)
    frame_macro = pd.DataFrame(records['macro']).sort_values('k')
    frame_micro = pd.DataFrame(records['micro']).sort_values('k')

    fig, axes = plt.subplots(2, 2, figsize=(11, 8), sharex='col')
    # (axis, data, metric column, line label, chosen K, alternative K, title, x label, y label)
    panel_specs = (
        (axes[0, 0], frame_macro, 'coherence_c_v', 'c_v', 3, 2, 'Macro: Coherence (c_v)', None, 'c_v'),
        (axes[1, 0], frame_macro, 'topic_diversity', 'TD', 3, 2, 'Macro: Topic Diversity', 'K', 'Diversity'),
        (axes[0, 1], frame_micro, 'coherence_c_v', 'c_v', 5, 6, 'Micro: Coherence (c_v)', None, None),
        (axes[1, 1], frame_micro, 'topic_diversity', 'TD', 5, 6, 'Micro: Topic Diversity', 'K', None),
    )
    for ax, frame, metric, line_label, k_chosen, k_alt, panel_title, x_label, y_label in panel_specs:
        ax.plot(frame['k'], frame[metric], 'o-', label=line_label)
        ax.axvline(k_chosen, color='g', linestyle='--', label=f'K={k_chosen}')
        ax.axvline(k_alt, color='r', linestyle=':', label=f'K={k_alt}')
        ax.set_title(panel_title)
        if x_label is not None:
            ax.set_xlabel(x_label)
        if y_label is not None:
            ax.set_ylabel(y_label)
        ax.legend()

    plt.suptitle('v3_1 层次 CTM：3+5 vs 2+6 指标对比')
    plt.tight_layout(rect=[0, 0.03, 1, 0.97])
    out_path = os.path.join(out_dir, 'ctm_3plus5_vs_2plus6_comparison.png')
    plt.savefig(out_path, dpi=150)
    plt.close()
    print('已保存 “3+5 vs 2+6” 对比图：', out_path)


# Draw the "3+5 vs 2+6" comparison only when both directory variables are defined;
# plot_k_search_comparison itself skips the plot if the v3_1 grid files are missing.
if 'CTM_OUTPUT_DIR' in globals() and 'OUTPUT_DIR' in globals():
    plot_k_search_comparison(CTM_OUTPUT_DIR, OUTPUT_DIR)
else:
    print('提示：未检测到 CTM_OUTPUT_DIR 或 OUTPUT_DIR，跳过 “3+5 vs 2+6” 绘图。')


print('v3_2 增强部分（c-TF-IDF、温度标定与可靠性、3+5 vs 2+6 对比图）已执行完成。')