In [None]:
%cd /content/drive/MyDrive/Colab Notebooks/

/content/drive/MyDrive/Colab Notebooks


# 1. 政策文本嵌入

In [None]:
import json
from tqdm import tqdm
import time
from openai import OpenAI

# --- API configuration ---
# Replace the token below with your own API token.
# For safety, prefer managing the token through an environment variable (or a
# similar mechanism); it is written inline here only for demo convenience.
API_TOKEN = ""
# Base URL handed to the OpenAI client in generate_embeddings_with_api().
BASE_URL = "https://uni-api.cstcloud.cn/v1"
# Model name sent with every embeddings request.
EMBEDDING_MODEL = "qwen3-embedding:8b"

def _select_embeddable_policies(all_policies):
    """Return (policies, texts) for records that have a non-blank ShortDescription."""
    valid_policies = []
    texts_to_embed = []
    for policy in all_policies:
        short_desc = policy.get('ShortDescription')
        # Filter: ShortDescription must be present and not whitespace-only.
        if short_desc and str(short_desc).strip():
            # Combine the English name and the short description into one text;
            # a missing/None NameEnglish is replaced by an empty string.
            name_orig = policy.get('NameEnglish', '') or ''
            texts_to_embed.append(f"{name_orig}. {short_desc}")
            valid_policies.append(policy)
    return valid_policies, texts_to_embed


def _embed_batch_with_retry(client, batch_texts, max_retries):
    """Embed one batch of texts, retrying failed API calls with exponential backoff.

    Returns one embedding per input text, in input order. Re-raises the last
    API error once the attempts are exhausted, and raises ValueError when the
    service returns a different number of vectors than texts were sent.
    """
    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            response = client.embeddings.create(
                model=EMBEDDING_MODEL,
                input=batch_texts
            )
        except Exception as e:
            if attempt == attempts:
                raise
            delay = min(2 ** attempt, 30)
            print(f"\nAPI call failed (attempt {attempt}/{attempts}): {e}. Retrying in {delay}s...")
            time.sleep(delay)
            continue

        items = list(response.data)
        # Each returned item carries the position of its input text; order by it
        # (when present) instead of trusting the order of the response list.
        if all(getattr(item, 'index', None) is not None for item in items):
            items.sort(key=lambda item: item.index)
        if len(items) != len(batch_texts):
            raise ValueError(
                f"API returned {len(items)} embeddings for {len(batch_texts)} input texts"
            )
        return [item.embedding for item in items]


def generate_embeddings_with_api(input_file_path, output_file_path, batch_size=100, max_retries=3):
    """
    Load policy records, drop those without a usable ShortDescription, embed the
    remaining ones through the API and save them (with an added 'embed' key) to a
    new JSON file.

    Errors are reported with print() and end the run early; nothing is written
    unless every batch was embedded successfully.

    Args:
        input_file_path (str): Path of the input JSON file (a list of policy dicts).
        output_file_path (str): Path of the output JSON file.
        batch_size (int): Number of texts sent per API request.
        max_retries (int): Attempts per batch before the run is aborted.

    Raises:
        ValueError: If batch_size is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    # 1. Initialise the OpenAI-compatible API client with BASE_URL and API_TOKEN.
    print("Initializing API client...")
    try:
        client = OpenAI(
            api_key=API_TOKEN,
            base_url=BASE_URL,
        )
        print("API client initialized successfully.")
    except Exception as e:
        print(f"Error initializing API client: {e}")
        return

    # 2. Read the raw JSON data.
    try:
        with open(input_file_path, 'r', encoding='utf-8') as f:
            all_policies = json.load(f)
        print(f"Successfully loaded {len(all_policies)} total records from {input_file_path}")
    except FileNotFoundError:
        print(f"Error: Input file not found at {input_file_path}")
        return
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {input_file_path}")
        return

    # 3. Filter the records and prepare the text to embed.
    print("Filtering policies and preparing text for embedding...")
    valid_policies, texts_to_embed = _select_embeddable_policies(all_policies)

    print(f"Found {len(valid_policies)} valid policies to process.")
    if not valid_policies:
        print("No valid policies to process. Exiting.")
        return

    # 4. Generate the embeddings batch by batch.
    print(f"Generating embeddings via API using model: {EMBEDDING_MODEL}. This may take a while...")

    all_embeddings = []
    for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="API Call Progress"):
        batch_texts = texts_to_embed[i:i + batch_size]
        try:
            all_embeddings.extend(_embed_batch_with_retry(client, batch_texts, max_retries))
        except Exception as e:
            print(f"\nAn error occurred during API call for batch starting at index {i}: {e}")
            print("Stopping the process. Please check your API key, network connection, or API service status.")
            return  # Abort: a partial result would leave policies without vectors.

    # 5. Attach each vector to its policy; both lists are in the same order and,
    #    thanks to the per-batch count check, have the same length.
    for policy, embedding in zip(valid_policies, all_embeddings):
        policy['embed'] = embedding

    # 6. Save the policies together with their embeddings.
    print(f"\nSaving data with embeddings to {output_file_path}...")
    try:
        with open(output_file_path, 'w', encoding='utf-8') as f:
            json.dump(valid_policies, f, ensure_ascii=False, indent=4)
        print("✅ Success!")
        print(f"Processed {len(valid_policies)} policies and saved to {output_file_path}")
    except Exception as e:
        print(f"Error saving file: {e}")

# --- Script entry point ---
if __name__ == '__main__':
    # Replace these paths with your own file locations.
    # Note: when running in Google Colab, make sure the paths are correct.
    INPUT_JSON_PATH = '/content/drive/MyDrive/Colab Notebooks/验证3+5/gemini_merged_policy_data_with_labels_v2.json'
    OUTPUT_JSON_PATH = '/content/drive/MyDrive/Colab Notebooks/验证3+5/policies_with_embeddings_api.json'

    generate_embeddings_with_api(INPUT_JSON_PATH, OUTPUT_JSON_PATH)

Initializing API client...
API client initialized successfully.
Successfully loaded 13532 total records from /content/drive/MyDrive/Colab Notebooks/验证3+5/gemini_merged_policy_data_with_labels_v2.json
Filtering policies and preparing text for embedding...
Found 13532 valid policies to process.
Generating embeddings via API using model: qwen3-embedding:8b. This may take a while...


API Call Progress: 100%|██████████| 136/136 [07:46<00:00,  3.43s/it]



Saving data with embeddings to /content/drive/MyDrive/Colab Notebooks/验证3+5/policies_with_embeddings_api.json...
✅ Success!
Processed 13532 policies and saved to /content/drive/MyDrive/Colab Notebooks/验证3+5/policies_with_embeddings_api.json


# 2. 无监督分类

In [None]:
# -*- coding: utf-8 -*-
"""
Unsupervised Classification: BERTopic and Hierarchical Topic Modeling with HTM-WS (Improved)
"""

# Install required packages
import sys
import subprocess

def install(package):
    """Install *package* with pip, run through the current Python interpreter.

    Raises subprocess.CalledProcessError when pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(pip_command)

# Install every dependency up front so the imports further down succeed on a
# fresh runtime.
print("Installing required packages...")
required_packages = [
    "bertopic",
    "hdbscan",
    "umap-learn",
    "nltk",
    "tqdm",
    "openai",  # Added for API access
    "tenacity"  # Added for retry functionality
]

# Best effort: a failed install is reported and the loop moves on to the next
# package instead of aborting.
for package in required_packages:
    try:
        print(f"Installing {package}...")
        install(package)
    except Exception as e:
        print(f"Error installing {package}: {e}")

import json
import os
import time
import warnings

import hdbscan
import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import umap
from bertopic import BERTopic
from bertopic.backend import BaseEmbedder
from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance
from nltk.corpus import stopwords
from openai import OpenAI
from sklearn.feature_extraction.text import CountVectorizer
from tenacity import retry, stop_after_attempt, wait_exponential
from tqdm import tqdm

warnings.filterwarnings("ignore")

class Qwen3EmbeddingModel(BaseEmbedder):
    """
    Embedding backend that calls the Qwen3 embedding API for documents and keywords.

    Derives from BERTopic's ``BaseEmbedder``, the documented extension point
    for custom embedding backends, so that BERTopic routes its embedding calls
    through :meth:`embed`.
    """

    # SECURITY NOTE: this credential was committed in the original source and is
    # kept only as a last-resort fallback so existing runs keep working. It
    # should be rotated and removed; pass ``api_key`` or set QWEN_API_KEY instead.
    _LEGACY_API_KEY = "f71888795322d9ab77a5508cee879f70a301f1f72de3aaf520858e16ae645d72"

    def __init__(self, api_key=None, base_url="https://uni-api.cstcloud.cn/v1",
                 model="qwen3-embedding:8b", embedding_dim=4096, batch_size=100):
        """
        Initialize the Qwen3 embedding model.

        Args:
            api_key (str | None): API key. When omitted, the QWEN_API_KEY
                environment variable is used, then the legacy built-in key.
            base_url (str): Base URL of the OpenAI-compatible endpoint.
            model (str): Name of the embedding model to request.
            embedding_dim (int): Vector size used for fallback rows until the
                first successful response reveals the real size.
            batch_size (int): Number of texts sent per API request.

        Raises:
            ValueError: If batch_size is smaller than 1.
        """
        super().__init__()
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        self.client = OpenAI(
            api_key=api_key or os.environ.get("QWEN_API_KEY") or self._LEGACY_API_KEY,
            base_url=base_url
        )
        self.model = model
        self.embedding_dim = embedding_dim
        self.batch_size = batch_size
        print(f"Initialized Qwen3 embedding model with dimension: {self.embedding_dim}")

    @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60),
           reraise=True)
    def _embed_batch(self, batch):
        """Embed one batch through the API; retried by tenacity on any exception."""
        response = self.client.embeddings.create(
            input=batch,
            model=self.model
        )
        items = list(response.data)
        # Order by the position reported for each item (when present) rather
        # than trusting the order of the response list.
        if all(getattr(item, "index", None) is not None for item in items):
            items.sort(key=lambda item: item.index)
        if len(items) != len(batch):
            raise ValueError(f"API returned {len(items)} embeddings for {len(batch)} texts")
        return [item.embedding for item in items]

    def embed(self, documents, verbose=False):
        """
        Embed documents using the Qwen3 embedding API.

        Each batch is retried with exponential backoff. A batch that still
        fails is filled with zero vectors (best effort) and reported, so the
        output always has one row per input document.

        Args:
            documents (list): List of documents to embed
            verbose (bool): Whether to print progress

        Returns:
            np.ndarray: Array of embeddings with shape (len(documents), embedding_dim)
        """
        documents = list(documents)
        if verbose:
            print(f"Embedding {len(documents)} texts with Qwen3 API")

        # One entry per batch: (vectors or None for a failed batch, batch length).
        batch_results = []
        failed_documents = 0

        for i in tqdm(range(0, len(documents), self.batch_size), disable=not verbose):
            batch = documents[i:i + self.batch_size]
            try:
                vectors = self._embed_batch(batch)
            except Exception as e:
                print(f"Error during embedding batch {i}-{i + self.batch_size}: {e}")
                vectors = None
                failed_documents += len(batch)
            else:
                if vectors:
                    # Learn the real vector size so fallback rows match it.
                    self.embedding_dim = len(vectors[0])
                # Respect rate limits
                time.sleep(0.5)
            batch_results.append((vectors, len(batch)))

        all_embeddings = []
        for vectors, size in batch_results:
            if vectors is None:
                # Fallback rows are built only now, when the vector size is
                # known, and as independent lists rather than aliases of one list.
                vectors = [[0.0] * self.embedding_dim for _ in range(size)]
            all_embeddings.extend(vectors)

        if failed_documents:
            print(f"Warning: {failed_documents} documents received zero-vector fallback embeddings.")

        if not all_embeddings:
            return np.empty((0, self.embedding_dim))
        return np.array(all_embeddings)

class BERTopicAnalyzer:
    """
    Two-phase unsupervised topic analysis of policy texts built on BERTopic.

    Phase 1 fits a flat BERTopic model on all documents. Phase 2 (HTM-WS) fits
    nested BERTopic models on sub-corpora formed from a manual mapping of flat
    topics to theoretical categories. All results are written to ``output_dir``.
    """

    def __init__(self, input_path, output_dir, use_api_for_embeddings=True):
        """
        Initialize the BERTopicAnalyzer with paths and parameters.

        Creates the output directory, loads the documents (embedding them when
        the API mode is on) and prepares the stopword list.

        Args:
            input_path (str): Path to the JSON file with embedded policy data.
            output_dir (str): Directory to save results.
            use_api_for_embeddings (bool): Whether to use Qwen3 API for embeddings.
        """
        self.input_path = input_path
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

        # Flag to decide whether to use the API for embeddings
        self.use_api_for_embeddings = use_api_for_embeddings

        # --- Model Parameters ---
        self.n_neighbors = 15       # UMAP: balances local vs. global structure
        self.min_cluster_size = 30  # HDBSCAN: minimum size for a topic cluster
        self.random_state = 42      # For reproducibility of UMAP

        # --- Data & Stopwords ---
        self.documents = []
        self.embeddings = np.array([])
        self.doc_ids = []

        # Initialize the embedding model
        if self.use_api_for_embeddings:
            self.embedding_model = Qwen3EmbeddingModel()
        else:
            self.embedding_model = None

        self.load_data()
        self.setup_stopwords()

        # --- Models ---
        self.topic_model = None

    def load_data(self):
        """
        Load policy texts and their embeddings.

        Fills ``self.documents``, ``self.doc_ids`` (index of each kept record in
        the input file) and ``self.embeddings``, all in the same order. In API
        mode the embeddings are computed here; otherwise they are read from each
        record's 'embed' key and records without one are skipped so the three
        collections stay aligned.
        """
        print(f"Loading data from {self.input_path}...")
        with open(self.input_path, 'r', encoding='utf-8') as f:
            policies = json.load(f)

        temp_docs, temp_embeds, temp_ids = [], [], []
        missing_embeddings = 0
        for i, policy in enumerate(policies):
            name = policy.get('NameEnglish', '') or ''
            desc = policy.get('ShortDescription', '') or ''
            combined_text = f"{name}. {desc}".strip()

            # Only add documents with sufficient content
            if len(combined_text) <= 10:
                continue

            if not self.use_api_for_embeddings:
                if 'embed' not in policy:
                    # Keeping the text without a vector would shift every
                    # later document against its embedding.
                    missing_embeddings += 1
                    continue
                temp_embeds.append(policy['embed'])

            temp_docs.append(combined_text)
            temp_ids.append(i)  # Store original index as doc_id

        if missing_embeddings:
            print(f"Skipped {missing_embeddings} documents without a stored 'embed' vector.")

        self.documents = temp_docs
        self.doc_ids = temp_ids

        # If we're using the API, compute embeddings now
        if self.use_api_for_embeddings:
            print("Computing embeddings using Qwen3 API (this may take some time)...")
            self.embeddings = self.embedding_model.embed(self.documents, verbose=True)
        else:
            self.embeddings = np.array(temp_embeds)

        print(f"Loaded {len(self.documents)} documents with embeddings.")

    def setup_stopwords(self):
        """Build ``self.stopwords`` from NLTK's English list plus domain terms."""
        try:
            base_stopwords = stopwords.words('english')
        except LookupError:
            # The corpus is not installed yet: fetch it once and retry.
            print("Downloading NLTK stopwords...")
            nltk.download('stopwords')
            base_stopwords = stopwords.words('english')

        domain_stopwords = [
            'policy', 'measure', 'action', 'law', 'government',
            'regulation', 'support', 'development', 'research', 'innovation',
            'technology', 'science',
            'para', 'use', 'based', 'provide', 'new', 'including', 'public',
            'sector', 'countries', 'value', 'mainstreaming', 'access', 'activities',
            'capacities', 'change', 'opportunities', 'quality', 'level',
            'smes', 'sme', 'ist', 'national', 'international', 'federal',
            'regional', 'european'
        ]
        self.stopwords = base_stopwords + domain_stopwords

    def run_flat_topic_modeling(self):
        """
        Phase 1: Run BERTopic for flat topic discovery.

        Returns:
            BERTopic: The fitted flat topic model (also stored in ``self.topic_model``).
        """
        print("\n--- Phase 1: Running Flat Topic Modeling with BERTopic ---")

        # Configure UMAP for dimensionality reduction
        umap_model = umap.UMAP(n_neighbors=self.n_neighbors, n_components=5,
                               min_dist=0.0, metric='cosine', random_state=self.random_state)

        # Configure HDBSCAN for clustering
        hdbscan_model = hdbscan.HDBSCAN(min_cluster_size=self.min_cluster_size,
                                        metric='euclidean', cluster_selection_method='eom',
                                        prediction_data=True)

        # Configure CountVectorizer with our custom stopwords
        vectorizer = CountVectorizer(stop_words=self.stopwords, min_df=5, max_df=0.8)

        # Configure advanced representation models for better topic interpretability
        keybert_model = KeyBERTInspired()
        mmr_model = MaximalMarginalRelevance(diversity=0.3)
        representation_model = {
            "KeyBERT": keybert_model,
            "MMR": mmr_model
        }

        # Initialize BERTopic. In API mode the embedding model is still handed
        # over; document vectors themselves are supplied to fit_transform below.
        self.topic_model = BERTopic(
            embedding_model=self.embedding_model if self.use_api_for_embeddings else None,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            vectorizer_model=vectorizer,
            representation_model=representation_model,
            verbose=True,
            nr_topics="auto",
            calculate_probabilities=True
        )

        print("Fitting BERTopic model...")
        # load_data() has already produced one vector per document in both
        # modes, so pass them in instead of embedding the whole corpus again.
        self.topic_model.fit_transform(self.documents, self.embeddings)

        print("\nPhase 1 complete. Saving results and visualizations...")
        self.save_topic_info()
        self.generate_visualizations()
        print("\nPlease analyze 'topic_info.csv' and 'topic_keywords.json' to perform manual mapping for Phase 2.")
        return self.topic_model

    def run_hierarchical_topic_modeling(self, manual_topic_mapping=None):
        """
        Phase 2: Implement HTM-WS by running nested BERTopic models on
        theoretically-grouped synthetic sub-corpora.

        Args:
            manual_topic_mapping (dict | None): Maps a category name to the list
                of flat topic ids that belong to it. When omitted, the
                placeholder mapping below is used.

        Returns:
            dict | None: Per-category summary (also saved as JSON), or None when
            the flat model has not been fitted yet.
        """
        print("\n--- Phase 2: Running Hierarchical Topic Modeling (HTM-WS) ---")

        if self.topic_model is None:
            print("Error: Flat topic model (Phase 1) must be run first.")
            return

        # ========================== CRITICAL MANUAL STEP ==========================
        # After running Phase 1, you MUST manually inspect the generated topics
        # (e.g., from 'topic_info.csv') and map them to your theoretical framework.
        # This dictionary is the bridge between empirical findings and your theory.
        #
        # EXAMPLE:
        # manual_topic_mapping = {
        #     "macro": [0, 5, 12, 18],  # Topics that correspond to macro-level policies
        #     "micro": [1, 2, 3, 4, 6, 7, 9, 10, 11, 13, 14, 15, 16, 17] # Topics for micro-level
        # }
        #
        # Pass your own mapping as ``manual_topic_mapping``; the values below are
        # only a placeholder.
        # ==========================================================================
        if manual_topic_mapping is None:
            manual_topic_mapping = {
                "macro": [0, 1, 2],  # Placeholder: Replace with your actual topic IDs
                "micro": [3, 4, 5, 6, 7]   # Placeholder: Replace with your actual topic IDs
            }
        print(f"Using manual mapping for HTM-WS: {manual_topic_mapping}")

        # Get the document-topic assignments from the flat model and label each
        # row with the record's index in the input file.
        doc_info = self.topic_model.get_document_info(self.documents)
        doc_info['Original_Index'] = self.doc_ids
        doc_info = doc_info.set_index('Original_Index')

        hierarchical_results = {}

        # Process each theoretical category (macro and micro)
        for category, topic_ids in manual_topic_mapping.items():
            print(f"\nProcessing category: '{category.upper()}'")

            # --- 1. Create Synthetic Sub-Corpus ---
            # doc_ids was assigned row by row above, so row k of doc_info is
            # self.documents[k]; the row positions are the corpus positions.
            in_category = doc_info['Topic'].isin(topic_ids).to_numpy()
            corpus_indices = np.flatnonzero(in_category)

            if len(corpus_indices) < self.min_cluster_size:
                print(f"  Skipping '{category}': only {len(corpus_indices)} documents found (less than min_cluster_size).")
                continue

            sub_corpus_docs = [self.documents[i] for i in corpus_indices]
            # Reuse the vectors computed in load_data() in both modes.
            sub_corpus_embeddings = self.embeddings[corpus_indices]

            print(f"  Created sub-corpus for '{category}' with {len(sub_corpus_docs)} documents.")

            # --- 2. Run Nested BERTopic Model ---
            print(f"  Running nested BERTopic on '{category}' sub-corpus...")
            sub_topic_model = self.create_subtopic_model()
            sub_topic_model.fit_transform(sub_corpus_docs, sub_corpus_embeddings)

            # --- 3. Save and Visualize Results ---
            sub_topic_info = sub_topic_model.get_topic_info()
            sub_topic_info.to_csv(f"{self.output_dir}/htm_ws_{category}_subtopic_info.csv", index=False)

            # Save keywords (topic -1 is excluded)
            sub_topic_keywords = {}
            for sub_topic_id in sub_topic_info['Topic']:
                if sub_topic_id != -1:
                    keywords = sub_topic_model.get_topic(sub_topic_id)
                    sub_topic_keywords[f"SubTopic_{sub_topic_id}"] = [kw[0] for kw in keywords]

            hierarchical_results[category] = {
                "parent_topics": topic_ids,
                "num_documents": len(sub_corpus_docs),
                "num_sub_topics_found": len(sub_topic_info[sub_topic_info.Topic != -1]),
                "sub_topics": sub_topic_keywords
            }

            # Visualize
            try:
                fig = sub_topic_model.visualize_topics()
                fig.write_html(f"{self.output_dir}/htm_ws_{category}_subtopics_visualization.html")
                plt.close()
            except Exception as e:
                print(f"  Could not generate topic visualization for {category}: {e}")

        # Save the overall hierarchical structure summary
        with open(f"{self.output_dir}/hierarchical_structure_summary.json", 'w', encoding='utf-8') as f:
            json.dump(hierarchical_results, f, ensure_ascii=False, indent=4)

        print("\nPhase 2 complete. Hierarchical analysis results saved.")
        return hierarchical_results

    def create_subtopic_model(self):
        """Create a new BERTopic instance with parameters suitable for smaller sub-corpora."""
        umap_model = umap.UMAP(n_neighbors=10, n_components=5, min_dist=0.0,
                              metric='cosine', random_state=self.random_state)
        hdbscan_model = hdbscan.HDBSCAN(min_cluster_size=15, metric='euclidean',
                                       cluster_selection_method='eom', prediction_data=True)
        vectorizer = CountVectorizer(stop_words=self.stopwords, min_df=3, max_df=0.85)

        return BERTopic(
            # Use the same embedding approach as the main model
            embedding_model=self.embedding_model if self.use_api_for_embeddings else None,
            umap_model=umap_model,
            hdbscan_model=hdbscan_model,
            vectorizer_model=vectorizer,
            verbose=False,
            nr_topics="auto",
            calculate_probabilities=False
        )

    def save_topic_info(self):
        """Save topic information from the flat model to 'topic_info.csv' and 'topic_keywords.json'."""
        topic_info = self.topic_model.get_topic_info()
        topic_info.to_csv(f"{self.output_dir}/topic_info.csv", index=False)

        # Keyword lists per topic (topic -1 is excluded)
        topic_keywords = {}
        for topic_id in topic_info['Topic']:
            if topic_id != -1:
                keywords = self.topic_model.get_topic(topic_id)
                topic_keywords[str(topic_id)] = [kw[0] for kw in keywords]

        with open(f"{self.output_dir}/topic_keywords.json", 'w', encoding='utf-8') as f:
            json.dump(topic_keywords, f, ensure_ascii=False, indent=4)
        print(f"Topic information saved to {self.output_dir}")

    def generate_visualizations(self):
        """Generate and save visualizations for the flat model; failures are reported, not raised."""
        print("Generating visualizations for the flat model...")
        try:
            # 1. Interactive 2D projection
            fig = self.topic_model.visualize_topics()
            fig.write_html(f"{self.output_dir}/topics_visualization.html")

            # 2. Hierarchical clustering dendrogram
            fig = self.topic_model.visualize_hierarchy()
            fig.write_html(f"{self.output_dir}/topic_hierarchy.html")

            # 3. Bar charts for top N topics
            fig = self.topic_model.visualize_barchart(top_n_topics=12)
            fig.write_html(f"{self.output_dir}/topic_barchart.html")

            # 4. Topic similarity matrix
            fig = self.topic_model.visualize_heatmap(top_n_topics=20)
            fig.write_html(f"{self.output_dir}/topic_similarity_heatmap.html")

            print(f"Interactive visualizations saved as HTML files in {self.output_dir}")
        except Exception as e:
            print(f"An error occurred during visualization generation: {e}")
            print("Please check if the model produced enough topics.")
        plt.close('all')  # Close all plot figures

    def run_full_analysis(self, manual_topic_mapping=None):
        """
        Execute the complete unsupervised analysis pipeline.

        Args:
            manual_topic_mapping (dict | None): Forwarded to
                run_hierarchical_topic_modeling().

        Returns:
            BERTopic: The fitted flat topic model.
        """
        # Phase 1: Flat topic modeling
        self.run_flat_topic_modeling()

        # Phase 2: Hierarchical topic modeling
        self.run_hierarchical_topic_modeling(manual_topic_mapping)

        print("\nUnsupervised analysis complete!")
        print(f"All results saved to {self.output_dir}")
        return self.topic_model

# --- Main Execution ---
if __name__ == "__main__":
    # Ensure you are running this in an environment with sufficient memory.
    # Google Colab Pro is recommended for datasets of this size.
    # The input is the file written by the embedding step in section 1.
    INPUT_JSON_PATH = '/content/drive/MyDrive/Colab Notebooks/验证3+5/policies_with_embeddings_api.json'
    OUTPUT_DIR = '/content/drive/MyDrive/Colab Notebooks/验证3+5/bertopic_results_improved'

    # Set use_api_for_embeddings=True to use Qwen3 API for all embeddings
    analyzer = BERTopicAnalyzer(INPUT_JSON_PATH, OUTPUT_DIR, use_api_for_embeddings=True)
    topic_model = analyzer.run_full_analysis()

# 3. 有监督分类

## 环境配置

In [None]:
!pip install bertopic hdbscan umap-learn lightgbm shap transformers datasets nltk hiclass
!pip install bertviz  # for attention visualization (optional)

# Download NLTK resources
import nltk
nltk.download('stopwords')

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hub identifier of the model to fetch and the Drive folder it is copied to.
MODEL_NAME = 'microsoft/deberta-v3-base'
SAVE_DIRECTORY = '/content/drive/MyDrive/Colab Notebooks/验证3+5/deberta-v3-base-local'

print(f"正在下载模型和分词器：'{MODEL_NAME}'...")
print("这是一个大型模型，可能需要几分钟时间...")

# Download and save the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.save_pretrained(SAVE_DIRECTORY)

# Download and save the model
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.save_pretrained(SAVE_DIRECTORY)

print(f"\n模型和分词器已成功保存到 '{SAVE_DIRECTORY}'")

## 下载模型

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# --- UPDATED ---
# Hub identifier of the model to fetch and the Drive folder it is copied to.
MODEL_NAME = 'microsoft/deberta-v3-base'
SAVE_DIRECTORY = '/content/drive/MyDrive/Colab Notebooks/验证3+5/deberta-v3-base-local'
# --- END UPDATE ---

print(f"Downloading model and tokenizer for '{MODEL_NAME}'...")
print("This is a large model and may take several minutes...")

# Download and save the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.save_pretrained(SAVE_DIRECTORY)

# Download and save the model
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.save_pretrained(SAVE_DIRECTORY)

print(f"\nModel and tokenizer saved successfully to '{SAVE_DIRECTORY}'")

## 层次有监督分类模型

In [None]:
# -*- coding: utf-8 -*-
"""
Hierarchical Supervised Classification with DeBERTa and LightGBM
Enhanced with XAI via SHAP and Attention Visualization
(Revised and Enhanced Version)
"""

# --- Environment Setup ---
# Disable W&B logging to avoid prompts and unnecessary outputs
import os
os.environ["WANDB_DISABLED"] = "true"

# --- Core Imports ---
import json
import numpy as np
import pandas as pd
import torch
import lightgbm as lgb
import matplotlib
matplotlib.use('Agg') # Use a non-interactive backend for saving figures
import matplotlib.pyplot as plt
import seaborn as sns
import shap
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight

# --- Transformers Imports ---
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification, Trainer,
    TrainingArguments, DataCollatorWithPadding,
    EarlyStoppingCallback
)
from datasets import Dataset

# --- Hierarchical Metrics ---
# Note: Your custom implementation is clear and correct. For future projects,
# the 'hiclass' library offers a standardized way to compute these metrics.
def h_precision(y_true_paths, y_pred_paths):
    """Return hierarchical precision, pooled over all samples.

    Each path is reduced to its set of nodes. The result is the number of
    predicted nodes that also occur in the matching true path, divided by the
    number of predicted nodes (a denominator of 0 is treated as 1).
    """
    node_sets = [
        (set(gold_path), set(guess_path))
        for gold_path, guess_path in zip(y_true_paths, y_pred_paths)
    ]
    shared = sum(len(gold & guess) for gold, guess in node_sets)
    predicted = sum(len(guess) for _, guess in node_sets)
    return shared / max(predicted, 1)

def h_recall(y_true_paths, y_pred_paths):
    """Return hierarchical recall, pooled over all samples.

    Each path is reduced to its set of nodes. The result is the number of true
    nodes that also occur in the matching predicted path, divided by the number
    of true nodes (a denominator of 0 is treated as 1).
    """
    node_sets = [
        (set(gold_path), set(guess_path))
        for gold_path, guess_path in zip(y_true_paths, y_pred_paths)
    ]
    shared = sum(len(gold & guess) for gold, guess in node_sets)
    expected = sum(len(gold) for gold, _ in node_sets)
    return shared / max(expected, 1)

def h_f1(y_true_paths, y_pred_paths):
    """Return the hierarchical F1-score: the harmonic mean of h_precision and h_recall.

    Returns 0 when precision and recall are both zero.
    """
    precision = h_precision(y_true_paths, y_pred_paths)
    recall = h_recall(y_true_paths, y_pred_paths)
    total = precision + recall
    return 0 if total == 0 else 2 * precision * recall / total

# --- Attention Visualization (BertViz) ---
# BertViz is an optional dependency: HAS_BERTVIZ records whether the import
# succeeded, and a warning is printed when it did not.
try:
    from bertviz import head_view, model_view
    HAS_BERTVIZ = True
except ImportError:
    print("Warning: BertViz not installed. Attention visualization will be disabled.")
    HAS_BERTVIZ = False

# --- Configuration Class ---
class Config:
    """Settings for the supervised classification pipeline: paths, split sizes,
    model hyper-parameters and the label hierarchy."""

    # ---- Paths ----
    # IMPORTANT: point INPUT_JSON_PATH at your actual data file.
    INPUT_JSON_PATH = '/content/drive/MyDrive/Colab Notebooks/验证3+5/policies_with_embeddings_api.json'
    OUTPUT_DIR = '/content/drive/MyDrive/Colab Notebooks/验证3+5/supervised_results'
    # Hub identifier of the transformer model.
    TRANSFORMER_MODEL = 'microsoft/deberta-v3-base'

    # ---- Data split ----
    RANDOM_STATE = 42
    TEST_SIZE = 0.2
    # Fraction of the data that remains after the test split is removed.
    VALIDATION_SIZE = 0.1

    # ---- LightGBM hyper-parameters ----
    LGBM_PARAMS = {
        'objective': 'multiclass',
        'metric': 'multi_logloss',
        'n_estimators': 2000,
        'learning_rate': 0.02,
        'feature_fraction': 0.8,
        'bagging_fraction': 0.8,
        'num_leaves': 31,
        'reg_alpha': 0.1,
        'reg_lambda': 0.1,
        'bagging_freq': 1,
        'verbose': -1,
        'n_jobs': -1,
        'seed': RANDOM_STATE,
        'boosting_type': 'gbdt',
    }

    # ---- Transformer training arguments ----
    # 'output_dir' is deliberately absent: it is set dynamically by the training
    # function, and keeping it here as well caused a TypeError.
    # NOTE(review): newer transformers releases appear to rename
    # 'evaluation_strategy' to 'eval_strategy' - confirm against the installed version.
    TRAINING_ARGS = {
        'num_train_epochs': 8,
        'learning_rate': 2e-5,
        'per_device_train_batch_size': 8,
        'per_device_eval_batch_size': 16,
        'gradient_accumulation_steps': 1,
        'warmup_ratio': 0.1,
        'weight_decay': 0.01,
        'logging_dir': './logs',
        'logging_steps': 50,
        'evaluation_strategy': "epoch",
        'save_strategy': "epoch",
        'load_best_model_at_end': True,
        'metric_for_best_model': 'f1',
        'greater_is_better': True,
        'report_to': "none",
    }

    # ---- Label hierarchy (child -> parent; the root has no parent) ----
    HIERARCHY = {
        'Guideline_Strategy': 'Macro',
        'Planning_Layout': 'Macro',
        'Institutional_Arrangements': 'Macro',
        'Resource_Allocation_Policy': 'Micro',
        'Innovation_Actor_Policy': 'Micro',
        'Talent_Policy': 'Micro',
        'Commercialization_Policy': 'Micro',
        'Environment_Shaping_Policy': 'Micro',
        'Macro': 'Root',
        'Micro': 'Root',
        'Root': None,
    }
    # Leaf labels that belong to the 'Macro' branch.
    MACRO_CLASSES = [
        'Guideline_Strategy',
        'Planning_Layout',
        'Institutional_Arrangements',
    ]

# --- Custom Trainer for Weighted Loss ---
class CustomTrainer(Trainer):
    """A custom trainer that applies balanced class weights to the cross-entropy
    loss, for handling imbalanced datasets."""

    def _get_class_weights(self, device):
        """Return the per-class loss weights as a float tensor on *device*.

        The weights are derived once from the training labels and cached on the
        trainer, instead of being recomputed from the full label column on
        every step. The vector always has ``num_labels`` entries: a class that
        does not occur in the training split keeps a neutral weight of 1.0, so
        the tensor length matches the logits.
        """
        cached = getattr(self, "_class_weights_cpu", None)
        if cached is None:
            train_labels = np.asarray(self.train_dataset['label'])
            present_classes = np.unique(train_labels)
            balanced = compute_class_weight(
                class_weight='balanced',
                classes=present_classes,
                y=train_labels
            )
            weights = np.ones(self.model.config.num_labels, dtype=np.float32)
            # assumes labels are integer class ids in [0, num_labels) - the same
            # assumption the loss itself makes.
            weights[present_classes] = balanced
            cached = torch.tensor(weights, dtype=torch.float)
            self._class_weights_cpu = cached
        return cached.to(device)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the class-weighted cross-entropy loss for one batch.

        Args:
            model: The model being trained.
            inputs (dict): Batch inputs; its "labels" entry is removed and used
                as the target.
            return_outputs (bool): When True, return ``(loss, outputs)``.
            num_items_in_batch: Accepted so the method can be called with this
                extra argument; it is not used by the weighted loss.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when return_outputs is True.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # Take the device from the logits: it is always the device the loss is
        # computed on, and a wrapped model may not expose ``.device``.
        weights_tensor = self._get_class_weights(logits.device)
        loss_fct = torch.nn.CrossEntropyLoss(weight=weights_tensor)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

# --- Main Classifier Class ---
class HierarchicalClassifier:
    """Implements the complete hierarchical classification and XAI pipeline."""
    def __init__(self, config):
        self.config = config
        os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
        self.lgbm_models = {}
        self.transformer_models = {}
        self.label_encoders = {}
        self.datasets = {}
        self.load_and_prepare_data()

    def load_and_prepare_data(self):
        """Loads data, creates labels, and performs methodologically sound splits."""
        print("1. Loading and preparing data...")
        # A check to ensure the input file exists
        if not os.path.exists(self.config.INPUT_JSON_PATH):
            raise FileNotFoundError(f"Input JSON not found at: {self.config.INPUT_JSON_PATH}. Please update the path in the Config class.")

        df = pd.read_json(self.config.INPUT_JSON_PATH)
        df['text'] = df['NameEnglish'].fillna('') + ". " + df['ShortDescription'].fillna('')
        df = df[df['text'].str.len() > 10].reset_index(drop=True) # Basic cleaning

        # Create primary (Macro/Micro) and secondary (8 classes) labels
        df['primary_label_str'] = df['ClassificationLabel'].apply(lambda x: 'Macro' if x in self.config.MACRO_CLASSES else 'Micro')
        self.df = df

        # Encode all labels and store encoders
        self.label_encoders['primary'] = LabelEncoder().fit(self.df['primary_label_str'])
        self.label_encoders['secondary'] = LabelEncoder().fit(self.df['ClassificationLabel'])
        self.df['primary_label'] = self.label_encoders['primary'].transform(self.df['primary_label_str'])
        self.df['secondary_label'] = self.label_encoders['secondary'].transform(self.df['ClassificationLabel'])

        # --- Methodologically Sound Data Splitting ---
        # Split the main dataframe first
        train_val_df, test_df = train_test_split(
            self.df, test_size=self.config.TEST_SIZE, random_state=self.config.RANDOM_STATE,
            stratify=self.df['secondary_label']
        )
        relative_val_size = self.config.VALIDATION_SIZE / (1 - self.config.TEST_SIZE)
        train_df, val_df = train_test_split(
            train_val_df, test_size=relative_val_size, random_state=self.config.RANDOM_STATE,
            stratify=train_val_df['secondary_label']
        )
        self.datasets['main'] = {'train': train_df, 'val': val_df, 'test': test_df}
        print(f"Data Split (Main): Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}")

        # Create and split specialist datasets DERIVED from the main splits
        for level in ['Macro', 'Micro']:
            le = LabelEncoder().fit(self.df[self.df['primary_label_str'] == level]['ClassificationLabel'])
            self.label_encoders[level.lower()] = le

            self.datasets[level.lower()] = {
                'train': train_df[train_df['primary_label_str'] == level].copy(),
                'val': val_df[val_df['primary_label_str'] == level].copy(),
                'test': test_df[test_df['primary_label_str'] == level].copy(),
            }
            # Add specialist label column to each split
            for split in ['train', 'val', 'test']:
                self.datasets[level.lower()][split]['specialist_label'] = le.transform(
                    self.datasets[level.lower()][split]['ClassificationLabel']
                )

    def train_lgbm_models(self):
        """Trains the complete hierarchical pipeline for LightGBM."""
        print("\n" + "="*50 + "\n2. TRAINING LIGHTGBM HIERARCHICAL CLASSIFIERS\n" + "="*50)
        # Train dispatcher
        self.lgbm_models['dispatcher'], _ = self._train_lgbm_node('main', 'primary_label', 'LGBM Dispatcher')
        # Train specialists
        self.lgbm_models['macro_specialist'], _ = self._train_lgbm_node('macro', 'specialist_label', 'LGBM Macro Specialist', True)
        self.lgbm_models['micro_specialist'], _ = self._train_lgbm_node('micro', 'specialist_label', 'LGBM Micro Specialist', True)
        # Evaluate
        self.evaluate_hierarchical_pipeline('lightgbm')

    def _train_lgbm_node(self, data_key, label_col, title, use_class_weight=False):
        """Fit and evaluate one LightGBM classifier of the hierarchy.

        Args:
            data_key: Key into ``self.datasets`` (``'main'``, ``'macro'`` or
                ``'micro'``); ``'main'`` uses the ``'primary'`` label encoder,
                any other key uses the encoder stored under that same key.
            label_col: Name of the target column in each split.
            title: Label used in console output and the evaluation report.
            use_class_weight: When true, sets ``class_weight='balanced'``.

        Returns:
            ``(model, metrics)``: the fitted classifier, with the matching
            label encoder attached as ``model.le_``, and its test-set metrics.
        """
        print(f"\n--- Training: {title} ---")
        splits = self.datasets[data_key]
        encoder = self.label_encoders['primary' if data_key == 'main' else data_key]

        def features_and_target(split_name):
            # The 'embed' column holds one vector per row; stack them into a 2-D array.
            frame = splits[split_name]
            return np.array(frame['embed'].tolist()), frame[label_col]

        X_train, y_train = features_and_target('train')
        X_val, y_val = features_and_target('val')
        X_test, y_test = features_and_target('test')

        n_classes = len(encoder.classes_)
        is_binary = n_classes == 2
        params = self.config.LGBM_PARAMS.copy()
        params['objective'] = 'binary' if is_binary else 'multiclass'
        params['metric'] = 'binary_logloss' if is_binary else 'multi_logloss'
        if n_classes > 2:
            params['num_class'] = n_classes
        if use_class_weight:
            params['class_weight'] = 'balanced'

        model = lgb.LGBMClassifier(**params)
        # Early stopping is monitored on the validation split.
        stopper = lgb.early_stopping(100, verbose=False)
        model.fit(X_train, y_train, eval_set=[(X_val, y_val)], callbacks=[stopper])

        metrics = self._evaluate_node(y_test, model.predict(X_test), encoder.classes_, title)
        model.le_ = encoder # Attach encoder for later use
        return model, metrics

    def train_transformer_models(self):
        """Trains the complete hierarchical pipeline for DeBERTa."""
        print("\n" + "="*50 + "\n3. TRAINING TRANSFORMER (DEBERTA) CLASSIFIERS\n" + "="*50)
        if not torch.cuda.is_available(): print("WARNING: No CUDA device found. Training will be very slow.")
        # Train dispatcher
        self.transformer_models['dispatcher'], _ = self._train_transformer_node('main', 'primary_label', 'DeBERTa Dispatcher')
        # Train specialists
        self.transformer_models['macro_specialist'], _ = self._train_transformer_node('macro', 'specialist_label', 'DeBERTa Macro Specialist', True)
        self.transformer_models['micro_specialist'], _ = self._train_transformer_node('micro', 'specialist_label', 'DeBERTa Micro Specialist', True)
        # Evaluate
        self.evaluate_hierarchical_pipeline('transformer')

    def _train_transformer_node(self, data_key, label_col, title, use_custom_trainer=False):
        """Fine-tune and evaluate one transformer classifier of the hierarchy.

        Args:
            data_key: Key into ``self.datasets`` (``'main'``, ``'macro'`` or
                ``'micro'``); ``'main'`` uses the ``'primary'`` label encoder,
                any other key uses the encoder stored under that same key.
            label_col: Column holding the integer target; renamed to ``label``.
            title: Label used in console output, the checkpoint directory
                name, and the evaluation report.
            use_custom_trainer: When true, trains with ``CustomTrainer``
                (class-weighted loss) instead of the stock ``Trainer``.

        Returns:
            ``(trainer, metrics)``: the trained trainer, with the matching
            label encoder attached as ``trainer.le_``, and its test-set metrics.
        """
        print(f"\n--- Training: {title} ---")
        dsets = self.datasets[data_key]
        le = self.label_encoders['primary' if data_key == 'main' else data_key]

        def to_dataset(split):
            frame = dsets[split][['text', label_col]].rename(columns={label_col: 'label'})
            return Dataset.from_pandas(frame)

        train_ds, val_ds, test_ds = [to_dataset(split) for split in ('train', 'val', 'test')]

        tokenizer = AutoTokenizer.from_pretrained(self.config.TRANSFORMER_MODEL)
        # Attention outputs stay off while training and evaluating. With them on,
        # every eval/predict pass would collect per-layer attention tensors and hand
        # compute_metrics a tuple instead of the logits array. They are switched on
        # at the end of this method for the visualization step.
        model = AutoModelForSequenceClassification.from_pretrained(
            self.config.TRANSFORMER_MODEL, num_labels=len(le.classes_), ignore_mismatched_sizes=True
        )

        def tokenize(examples):
            # No padding here: DataCollatorWithPadding pads each batch to its longest
            # member, which avoids padding every example to 512 tokens.
            return tokenizer(examples['text'], truncation=True, max_length=512)

        train_ds, val_ds, test_ds = [ds.map(tokenize, batched=True) for ds in [train_ds, val_ds, test_ds]]

        def logits_only(predictions):
            # Defensive: models configured to return extra outputs yield (logits, ...).
            return predictions[0] if isinstance(predictions, (tuple, list)) else predictions

        def compute_metrics(p):
            preds = np.argmax(logits_only(p.predictions), axis=1)
            return {'f1': f1_score(p.label_ids, preds, average='macro'), 'accuracy': accuracy_score(p.label_ids, preds)}

        # output_dir is supplied here per node; TRAINING_ARGS must not contain it.
        args = TrainingArguments(
            output_dir=f"{self.config.OUTPUT_DIR}/checkpoints/{title.replace(' ', '_')}",
            **self.config.TRAINING_ARGS
        )
        TrainerClass = CustomTrainer if use_custom_trainer else Trainer
        trainer = TrainerClass(
            model=model, args=args, train_dataset=train_ds, eval_dataset=val_ds,
            tokenizer=tokenizer, data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
            compute_metrics=compute_metrics, callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
        )
        trainer.train()

        preds = trainer.predict(test_ds)
        y_pred = np.argmax(logits_only(preds.predictions), axis=-1)
        metrics = self._evaluate_node(test_ds['label'], y_pred, le.classes_, title)

        # Training is done: expose attentions so later visualization can read them.
        trainer.model.config.output_attentions = True
        trainer.le_ = le # Attach encoder
        return trainer, metrics

    def evaluate_hierarchical_pipeline(self, model_type):
        """Score the chained dispatcher -> specialist pipeline on the main test split.

        Every test row is first routed by the dispatcher, then classified by
        the specialist matching the predicted primary label. The final
        predictions are reported with the flat node metrics and with the
        hierarchical precision/recall/F1 computed over [parent, leaf] paths.

        Args:
            model_type: ``'lightgbm'`` selects ``self.lgbm_models``; any other
                value selects ``self.transformer_models``.
        """
        print(f"\n--- 4. Evaluating End-to-End Pipeline for: {model_type.upper()} ---")
        if model_type == 'lightgbm':
            node_models = self.lgbm_models
        else:
            node_models = self.transformer_models
        held_out = self.datasets['main']['test']
        gold_encoded = held_out['secondary_label']
        encoded_predictions, true_paths, predicted_paths = [], [], []

        progress = tqdm(held_out.iterrows(), total=len(held_out), desc=f"Predicting with {model_type}")
        for _, record in progress:
            # Stage 1: the dispatcher picks the branch (Macro / Micro).
            coarse_label, _ = self._predict_node(node_models['dispatcher'], record, model_type)
            # Stage 2: the branch's specialist picks the leaf class.
            expert = node_models[f"{coarse_label.lower()}_specialist"]
            fine_label, _ = self._predict_node(expert, record, model_type)

            # Specialist labels are strings; re-encode them in the 8-class space.
            encoded_predictions.append(self.label_encoders['secondary'].transform([fine_label])[0])
            gold_label = record['ClassificationLabel']
            # The gold parent comes from HIERARCHY, so the label must be one of its keys.
            true_paths.append([self.config.HIERARCHY[gold_label], gold_label])
            predicted_paths.append([coarse_label, fine_label])

        # Standard Evaluation
        self._evaluate_node(
            gold_encoded,
            encoded_predictions,
            self.label_encoders['secondary'].classes_,
            f"End-to-End Pipeline ({model_type.upper()})",
        )
        # Hierarchical Evaluation
        h_p = h_precision(true_paths, predicted_paths)
        h_r = h_recall(true_paths, predicted_paths)
        h_f1_score = h_f1(true_paths, predicted_paths)
        print("\n--- Hierarchical Evaluation Metrics ---")
        print(f"Hierarchical Precision: {h_p:.4f}\nHierarchical Recall: {h_r:.4f}\nHierarchical F1 Score: {h_f1_score:.4f}")

    def _predict_node(self, model, row, model_type):
        """Helper to get a prediction from a single model node."""
        if model_type == 'lightgbm':
            features = np.array(row['embed']).reshape(1, -1)
            pred_encoded = model.predict(features)[0]
        else: # transformer
            inputs = model.tokenizer(row['text'], return_tensors="pt", truncation=True, max_length=512).to(model.model.device)
            with torch.no_grad():
                logits = model.model(**inputs).logits
            pred_encoded = torch.argmax(logits, dim=1).item()
        return model.le_.inverse_transform([pred_encoded])[0], pred_encoded

    def analyze_explainability(self):
        """Performs the full XAI analysis pipeline for the transformer models."""
        print("\n" + "="*50 + "\n5. PERFORMING EXPLAINABILITY ANALYSIS (SHAP & ATTENTION)\n" + "="*50)
        if not self.transformer_models:
            print("No transformer models available for XAI analysis.")
            return

        # Use a small, representative subset of the test data for explanations
        sample_df = self.datasets['main']['test'].sample(20, random_state=self.config.RANDOM_STATE)

        # --- SHAP Analysis ---
        print("\n--- Generating SHAP Explanations ---")
        # Create a single SHAP explainer for each model to reuse
        shap_explainers = {
            name: shap.Explainer(self._get_prediction_function(model), model.tokenizer)
            for name, model in self.transformer_models.items()
        }

        # Global Explanations (Summary Plots)
        for name, model in self.transformer_models.items():
            print(f"Generating global SHAP plot for {name}...")
            # Use the appropriate test set for each model
            if name == 'dispatcher':
                data_key = 'main'
            else:
                data_key = name.split('_')[0]

            texts = self.datasets[data_key]['test']['text'].tolist()
            # Use a manageable sample size for SHAP to avoid excessive computation time
            sample_size = min(50, len(texts))
            shap_values = shap_explainers[name](texts[:sample_size])
            plt.figure()
            shap.summary_plot(shap_values, plot_type='bar', class_names=model.le_.classes_, show=False)
            plt.title(f"SHAP Feature Importance - {name}")
            plt.savefig(f"{self.config.OUTPUT_DIR}/shap_global_{name}.png", bbox_inches='tight')
            plt.close()

        # Local Explanations (for specific examples)
        print("\nGenerating local SHAP explanations for one example per class...")
        for label_str in self.label_encoders['secondary'].classes_:
            example_row = self.df[self.df['ClassificationLabel'] == label_str].sample(1, random_state=self.config.RANDOM_STATE).iloc[0]
            text = example_row['text']
            primary_label = example_row['primary_label_str']

            # Explain dispatcher
            self._explain_local_shap(shap_explainers['dispatcher'], text, f"{label_str}_as_dispatcher")
            # Explain specialist
            specialist_name = f"{primary_label.lower()}_specialist"
            self._explain_local_shap(shap_explainers[specialist_name], text, f"{label_str}_as_{specialist_name}")

        # --- Attention Visualization ---
        if HAS_BERTVIZ:
            print("\n--- Visualizing Attention Patterns ---")
            example_text = sample_df.iloc[0]['text']
            for name, model in self.transformer_models.items():
                self._visualize_attention(model, example_text, name)

    def _get_prediction_function(self, trainer):
        """Creates a prediction function compatible with SHAP."""
        model = trainer.model.to("cuda" if torch.cuda.is_available() else "cpu")
        model.eval()
        def predict(texts):
            inputs = trainer.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
            with torch.no_grad():
                logits = model(**inputs).logits
            return torch.softmax(logits, dim=-1).cpu().numpy() # SHAP works best with probabilities
        return predict

    def _explain_local_shap(self, explainer, text, filename):
        """Generate one local SHAP text explanation and save it as an HTML file.

        ``shap.plots.text`` renders HTML rather than a Matplotlib figure, so
        the markup is requested with ``display=False`` and written to
        ``<OUTPUT_DIR>/shap_local_<filename>.html``.

        Args:
            explainer: Callable SHAP explainer; invoked on ``[text]``.
            text: The single text to explain.
            filename: Suffix used in the output file name.

        Any error is reported on stdout and swallowed so one failed example
        does not stop the remaining explanations.
        """
        try:
            shap_values = explainer([text])
            html = shap.plots.text(shap_values, display=False)
            out_path = f"{self.config.OUTPUT_DIR}/shap_local_{filename}.html"
            with open(out_path, 'w', encoding='utf-8') as f:
                f.write(html)
        except Exception as e:
            print(f"Could not generate SHAP plot for {filename}: {e}")

    def _visualize_attention(self, trainer, text, model_name):
        """Render a BertViz head view for one text and save it as HTML.

        The view is written to
        ``<OUTPUT_DIR>/attention_<model_name>_head_view.html``.

        Args:
            trainer: Object exposing ``model`` and ``tokenizer``.
            text: The text whose attention pattern is visualized.
            model_name: Name used in console output and the output file name.

        Any error is reported on stdout and swallowed so one failed model does
        not stop the remaining visualizations.
        """
        print(f"Visualizing attention for {model_name}...")
        model = trainer.model
        tokenizer = trainer.tokenizer
        try:
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(model.device)
            tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
            # Ask for attentions explicitly so this does not depend on how the model
            # was configured at load time; no gradients are needed for a visualization.
            with torch.no_grad():
                outputs = model(**inputs, output_attentions=True)
            # 'return' makes head_view hand back an HTML object instead of displaying it.
            view = head_view(outputs.attentions, tokens, html_action='return')
            html_head_view = getattr(view, 'data', view)
            with open(f"{self.config.OUTPUT_DIR}/attention_{model_name}_head_view.html", 'w', encoding='utf-8') as f:
                f.write(html_head_view)
        except Exception as e:
            print(f"Could not visualize attention for {model_name}: {e}")

    def _evaluate_node(self, y_true, y_pred, labels, title):
        """Print evaluation metrics for one node and save its confusion matrix.

        Args:
            y_true: Encoded true labels.
            y_pred: Encoded predicted labels.
            labels: Class names; callers in this class pass an encoder's
                ``classes_``, so position ``i`` names encoded label ``i``.
            title: Heading for the printed report and the plot; also used,
                with spaces replaced by underscores, in the PNG file name.

        Returns:
            Dict with ``accuracy`` and macro-averaged ``f1``.
        """
        accuracy = accuracy_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
        print(f"\n--- Evaluation: {title} ---\nAccuracy: {accuracy:.4f} | Macro F1: {f1:.4f}")

        # Pin the full label set. Without it, a class that is absent from both
        # y_true and y_pred makes the report reject target_names and shrinks the
        # confusion matrix, so it no longer lines up with the tick labels.
        class_ids = np.arange(len(labels))
        print(classification_report(y_true, y_pred, labels=class_ids, target_names=labels, zero_division=0))
        cm = confusion_matrix(y_true, y_pred, labels=class_ids)

        plt.figure(figsize=(max(8, len(labels)), max(6, len(labels)*0.8)))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
        plt.title(f'Confusion Matrix - {title}')
        plt.ylabel('True Label'); plt.xlabel('Predicted Label')
        filename = f"{self.config.OUTPUT_DIR}/cm_{title.replace(' ', '_')}.png"
        plt.savefig(filename, bbox_inches='tight'); plt.close()
        return {'accuracy': accuracy, 'f1': f1}

    def run_full_pipeline(self):
        """Runs the entire supervised learning pipeline."""
        self.train_lgbm_models()
        self.train_transformer_models()
        self.analyze_explainability()
        print("\nSupervised classification and XAI pipeline complete!")
        print(f"All results and visualizations saved to {self.config.OUTPUT_DIR}")

# --- Main Execution ---
if __name__ == "__main__":
    # Local import: only this entry point needs it.
    import traceback

    try:
        config = Config()
        classifier = HierarchicalClassifier(config)
        classifier.run_full_pipeline()
    except Exception as e:
        # Top-level boundary: report the failure without raising, but keep the
        # traceback so the failing stage can be located.
        print(f"\nAn error occurred during pipeline execution: {e}")
        traceback.print_exc()