## 阶段一：环境设置与数据准备

### 划分训练、验证、测试集

In [None]:
# ===================================================================
# CELL 5: 实现综合路径分析的任务集划分逻辑
# ===================================================================
import os
import json
import pandas as pd
import collections
import numpy as np
import random
from tqdm.notebook import tqdm
from IPython.display import display, HTML

print("\n--- 正在基于路径分析实现精确的元任务集划分 ---")

# --- Locations (everything lives under the mounted Google Drive folder) ---
GDRIVE_PATH = '/content/drive/MyDrive/Colab Notebooks/'
OUTPUT_DIR = os.path.join(GDRIVE_PATH, "prediction_outputs_paper_final")
# NOTE(review): the path-finding cell in this notebook writes its results under
# "prediction_outputs_stratified_v4", while this cell reads from "..._v1" --
# confirm which directory holds the intended stratified_paths_k1000.json.
STRATIFIED_OUTPUT_DIR = os.path.join(GDRIVE_PATH, "prediction_outputs_stratified_v1")

# Inputs
KG_PATH = os.path.join(GDRIVE_PATH, 'drkg_autoimmune_enhanced_glm4.tsv')
AUTOIMMUNE_DISEASE_LIST_PATH = os.path.join(GDRIVE_PATH, 'mondo_autoimmune_disease_ALL_results.csv')
DRUG_STATUS_PATH = os.path.join(OUTPUT_DIR, 'drug_status_lookup.json')
STRATIFIED_PATHS_FILE = os.path.join(STRATIFIED_OUTPUT_DIR, "stratified_paths_k1000.json")

# Output
NEW_TASK_SPLIT_PATH = os.path.join(OUTPUT_DIR, 'meta_task_split_enhanced_v3.json')

# --- Load the base data ---
print("加载基础数据...")

# Knowledge graph: header-less TSV of (head, relation, tail); incomplete rows are dropped.
kg_df = pd.read_csv(
    KG_PATH,
    sep='\t',
    header=None,
    names=['head', 'relation', 'tail'],
).dropna()
print(f"知识图谱加载成功，包含 {len(kg_df)} 条三元组。")

# Autoimmune disease list: the 'ID' column is prefixed with "Disease::" to form node ids.
autoimmune_df = pd.read_csv(AUTOIMMUNE_DISEASE_LIST_PATH)
autoimmune_disease_ids = {
    "Disease::" + raw_id for raw_id in autoimmune_df['ID'].astype(str)
}
print(f"加载了 {len(autoimmune_disease_ids)} 个自身免疫病ID。")

# Drug status lookup (drug id -> status string); only 'approved' drugs are kept.
with open(DRUG_STATUS_PATH, 'r') as status_file:
    drug_status_lookup = json.load(status_file)
approved_drugs = set()
for drug, status in drug_status_lookup.items():
    if status == 'approved':
        approved_drugs.add(drug)
print(f"加载了 {len(drug_status_lookup)} 个药物状态，其中 {len(approved_drugs)} 个已批准药物。")

# Relation labels that count as a positive "drug treats disease" edge.
POSITIVE_TREATMENT_RELATIONS = {
    'treats',
    'DRUGBANK::treats::Compound:Disease',
    'GNBR::T::Compound:Disease',
    'Hetionet::CtD::Compound:Disease',
}

# --- Step 1: extract direct treatment relations from the knowledge graph ---
print("\n1. 从知识图谱提取直接治疗关系...")

# A triple is kept only if it is a treatment relation, points at an autoimmune
# disease, and starts from an approved drug.
is_treatment_edge = kg_df['relation'].isin(POSITIVE_TREATMENT_RELATIONS)
hits_autoimmune_disease = kg_df['tail'].isin(autoimmune_disease_ids)
starts_from_approved_drug = kg_df['head'].isin(approved_drugs)
treats_df = kg_df[is_treatment_edge & hits_autoimmune_disease & starts_from_approved_drug]

# disease id -> set of drug ids with a direct treatment edge
direct_disease_drug_pairs = collections.defaultdict(set)
for drug_id, disease_id in zip(treats_df['head'], treats_df['tail']):
    direct_disease_drug_pairs[disease_id].add(drug_id)

n_direct_pairs = sum(map(len, direct_disease_drug_pairs.values()))
print(f"找到 {n_direct_pairs} 个直接治疗药物-疾病对，涉及 {len(direct_disease_drug_pairs)} 个疾病。")

# --- Step 2: load the path-finding results ---
print("\n2. 加载路径查找结果...")
valid_pathways_pairs = set()
inferred_disease_drug_pairs = collections.defaultdict(set)

if not os.path.exists(STRATIFIED_PATHS_FILE):
    print(f"警告: 路径文件 {STRATIFIED_PATHS_FILE} 不存在。仅使用直接治疗关系进行划分。")
else:
    with open(STRATIFIED_PATHS_FILE, 'r') as paths_file:
        stratified_paths = json.load(paths_file)

    print(f"加载了 {len(stratified_paths)} 个药物-疾病路径组。")

    # A pair counts as "inferred" when at least one of its paths is short
    # (2..5 nodes) and cheap (total weight below 20).
    for pair_key, paths in tqdm(stratified_paths.items(), desc="分析路径"):
        if not paths:
            # Nothing was found for this pair.
            continue

        try:
            # Keys have the form "drug_id|||disease_id".
            drug_id, disease_id = pair_key.split('|||')

            # Only autoimmune diseases paired with approved drugs are of interest.
            if disease_id in autoimmune_disease_ids and drug_id in approved_drugs:
                has_valid_path = any(
                    2 <= len(path_info['nodes']) <= 5 and path_info['total_weight'] < 20
                    for path_info in paths
                )
                if has_valid_path:
                    valid_pathways_pairs.add(pair_key)
                    inferred_disease_drug_pairs[disease_id].add(drug_id)
        except Exception as e:
            # Malformed entries are reported and skipped so one bad record
            # does not abort the whole scan.
            print(f"处理路径时出错: {e} - {pair_key}")

    print(f"从路径中识别了 {len(valid_pathways_pairs)} 个有效药物-疾病对，涉及 {len(inferred_disease_drug_pairs)} 个疾病。")

# --- Step 3: merge direct treatment relations with path-inferred relations ---
print("\n3. 合并直接治疗关系和路径推断关系...")
combined_disease_drug_pairs = collections.defaultdict(set)

# Direct relations go in first, then the path-inferred ones.
for pair_source in (direct_disease_drug_pairs, inferred_disease_drug_pairs):
    for disease, drugs in pair_source.items():
        combined_disease_drug_pairs[disease] |= drugs

# Number of distinct drugs known for each disease.
disease_drug_counts = {}
for disease, drugs in combined_disease_drug_pairs.items():
    disease_drug_counts[disease] = len(drugs)

# How many pairs each source contributes that the other one does not.
unique_direct_pairs = 0
unique_inferred_pairs = 0
total_pairs = 0

for disease, drugs in combined_disease_drug_pairs.items():
    # .get() is used on purpose: indexing the defaultdicts would insert keys.
    direct_only_source = direct_disease_drug_pairs.get(disease, set())
    inferred_only_source = inferred_disease_drug_pairs.get(disease, set())

    total_pairs += len(drugs)
    unique_direct_pairs += len(direct_only_source - inferred_only_source)
    unique_inferred_pairs += len(inferred_only_source - direct_only_source)

print(f"合并后共有 {total_pairs} 个有效药物-疾病对，涉及 {len(combined_disease_drug_pairs)} 个疾病。")
print(f"其中，直接治疗关系独特贡献: {unique_direct_pairs} 对，路径推断独特贡献: {unique_inferred_pairs} 对。")

# --- Step 4: split the diseases (meta-tasks) into train / val / test ---
print("\n4. 基于综合关系进行任务集划分...")

# Only diseases with at least MIN_DRUGS_FOR_SPLIT known drugs become tasks.
MIN_DRUGS_FOR_SPLIT = 6
valid_task_diseases = {disease: count for disease, count in disease_drug_counts.items()
                      if count >= MIN_DRUGS_FOR_SPLIT}

# Diseases ordered from most to fewest drugs (ties keep dictionary order).
sorted_diseases = sorted(valid_task_diseases, key=valid_task_diseases.get, reverse=True)

# Distribution of drug counts over the selected diseases.
drug_counts = [disease_drug_counts[d] for d in sorted_diseases]
print("\n药物数量分布统计:")
if drug_counts:
    print(f"平均每个疾病的药物数量: {np.mean(drug_counts):.2f}")
    print(f"中位数药物数量: {np.median(drug_counts):.2f}")
    print(f"最小药物数量: {min(drug_counts)}")
    print(f"最大药物数量: {max(drug_counts)}")
else:
    # min()/max() raise on an empty list, so report the situation instead.
    print(f"警告: 没有疾病满足至少 {MIN_DRUGS_FOR_SPLIT} 个药物的要求，划分结果将为空。")

# Validation and test each get max(5, 10%) diseases, taken from the tail of
# the ranking (i.e. the diseases with the fewest drugs).
n = len(sorted_diseases)
MIN_TEST_DISEASES = max(5, int(n * 0.1))
MIN_VAL_DISEASES = max(5, int(n * 0.1))

# Split points
val_start = n - MIN_TEST_DISEASES - MIN_VAL_DISEASES
test_start = n - MIN_TEST_DISEASES

if val_start >= 1:
    # Enough diseases: the training set is non-empty and val/test have >= 5 each.
    meta_train_diseases = sorted_diseases[:val_start]
    meta_val_diseases = sorted_diseases[val_start:test_start]
    meta_test_diseases = sorted_diseases[test_start:]
else:
    # Too few diseases for the 5/5 reservation (val_start would be zero or
    # negative, which yields an empty training set or wrapped-around slices).
    # Fall back to splitting into thirds; the remainder goes to the test set.
    print("警告: 测试集或验证集疾病数量太少，重新调整分配...")
    chunk_size = n // 3
    meta_train_diseases = sorted_diseases[:chunk_size]
    meta_val_diseases = sorted_diseases[chunk_size:2*chunk_size]
    meta_test_diseases = sorted_diseases[2*chunk_size:]

# Total number of drug pairs per subset.
train_drugs = sum(disease_drug_counts[d] for d in meta_train_diseases)
val_drugs = sum(disease_drug_counts[d] for d in meta_val_diseases)
test_drugs = sum(disease_drug_counts[d] for d in meta_test_diseases)

# Per-disease drug counts of the test set.
test_drug_dist = [disease_drug_counts[d] for d in meta_test_diseases]
print("\n测试集药物分布:")
for i, disease_id in enumerate(meta_test_diseases):
    drug_count = disease_drug_counts[disease_id]
    print(f"{i+1}. 疾病 {disease_id.split('::')[-1]} - {drug_count} 个药物")

# Summary of the split. The average is reported as 0.00 for an empty subset
# instead of dividing by zero.
print("\n任务集划分结果:")
for subset_label, subset_diseases, subset_total in (
    ("训练集", meta_train_diseases, train_drugs),
    ("验证集", meta_val_diseases, val_drugs),
    ("测试集", meta_test_diseases, test_drugs),
):
    subset_mean = subset_total / len(subset_diseases) if subset_diseases else 0.0
    print(f"- {subset_label}: {len(subset_diseases)} 个疾病, 共 {subset_total} 个药物对, "
          f"平均每疾病 {subset_mean:.2f} 个药物")

# --- Step 5: save the enhanced split ---
# Detailed disease -> drugs mapping for the diseases that became tasks.
# The drug ids are sorted so the saved file is identical across runs
# (iterating a set of strings has no stable order between interpreter runs).
disease_to_drugs_map = {
    disease_id: sorted(combined_disease_drug_pairs[disease_id])
    for disease_id in sorted_diseases
}

# Split membership plus the supporting statistics.
split_info = {
    'meta_train': meta_train_diseases,
    'meta_val': meta_val_diseases,
    'meta_test': meta_test_diseases,
    'disease_drug_counts': disease_drug_counts,
    'disease_to_drugs': disease_to_drugs_map,
    'stats': {
        'total_diseases': len(combined_disease_drug_pairs),
        'total_pairs': total_pairs,
        'direct_unique_pairs': unique_direct_pairs,
        'inferred_unique_pairs': unique_inferred_pairs,
        'train_stats': {'diseases': len(meta_train_diseases), 'drugs': train_drugs},
        'val_stats': {'diseases': len(meta_val_diseases), 'drugs': val_drugs},
        'test_stats': {'diseases': len(meta_test_diseases), 'drugs': test_drugs},
    }
}

with open(NEW_TASK_SPLIT_PATH, 'w') as f:
    json.dump(split_info, f, indent=4)
print(f"\n增强的任务划分信息已保存到: {NEW_TASK_SPLIT_PATH}")

# --- Step 6: tell the user how to point the evaluation code at the new file ---
print("\n完成！请在评估代码中使用新的任务划分文件:")
# The file name is derived from NEW_TASK_SPLIT_PATH so the hint cannot drift
# from the path that was actually written.
print(f"META_TASK_SPLIT_PATH = os.path.join(BASE_DATA_DIR, '{os.path.basename(NEW_TASK_SPLIT_PATH)}')")

## 阶段二：Meta-KGLM训练与预测

#### 路径查找

In [None]:
import os
import json
import pandas as pd
import networkx as nx
from itertools import product, islice
import multiprocessing
from tqdm import tqdm
import collections
import time
import pickle
import signal
import psutil
from functools import partial
import numpy as np

print("--- Stage 2, Block 1: Stratified Pathfinding for All Autoimmune Diseases ---")

# --- Configuration: file locations ---
GDRIVE_PATH = '/content/drive/MyDrive/Colab Notebooks/'
OUTPUT_DIR = os.path.join(GDRIVE_PATH, "prediction_outputs_stratified_v4")
KG_PATH = os.path.join(GDRIVE_PATH, 'drkg_autoimmune_enhanced_glm4.tsv')
STRATIFIED_PATHS_FILE = os.path.join(OUTPUT_DIR, "stratified_paths_k1000.json")
AUTOIMMUNE_DISEASE_LIST_PATH = os.path.join(GDRIVE_PATH, 'mondo_autoimmune_disease_ALL_results.csv')

# --- Constants: search limits ---
K_SHORTEST_PATHS = 50        # upper bound on paths requested per pair
MAX_PATH_LENGTH = 5          # longest path that is still considered
NUM_PATHS_PER_LENGTH = 2     # paths kept per length in the stratified selection

# --- Constants: batching / parallelism ---
BATCH_SIZE = 200             # pairs per batch
PATH_TIMEOUT = 15            # per-pair timeout passed to signal.alarm (seconds)
MAX_CONCURRENT_WORKERS = 40  # parallel processes to use
SAVE_INTERVAL = 50           # write intermediate results every N pairs of a batch

# Edge weight per relation label (lower is better). Relations that are not
# listed use the 'default' entry.
RELATION_WEIGHTS = {
    'DRUGBANK::target::Compound:Gene': 1,
    'DGIDB::INHIBITOR::Gene:Compound': 1,
    'STRING::INHIBITION::Gene:Gene': 1,
    'GNBR::Te::Gene:Disease': 1,
    'Hetionet::CtD::Compound:Disease': 1,
    'treats': 1,
    'DGIDB::ANTAGONIST::Gene:Compound': 2,
    'DGIDB::BLOCKER::Gene:Compound': 2,
    'STRING::ACTIVATION::Gene:Gene': 2,
    'Hetionet::CuG::Compound:Gene': 2,
    'Hetionet::CdG::Compound:Gene': 2,
    'DRUGBANK::enzyme::Compound:Gene': 5,
    'Hetionet::CbG::Compound:Gene': 5,
    'STRING::BINDING::Gene:Gene': 5,
    'INTACT::DIRECT INTERACTION::Compound:Gene': 5,
    'Hetionet::GpPW::Gene:Pathway': 5,
    'Hetionet::DaG::Disease:Gene': 10,
    'associated_with': 10,
    'interacts_with': 10,
    'default': 8,
}

# --- Per-process state, filled in by init_worker() ---
G_multi = None             # MultiDiGraph used to look up relation / weight of an edge
G_simple_weighted = None   # collapsed weighted DiGraph used for path search
node_neighborhoods = None  # precomputed neighborhoods keyed by node
disease_nodes_set = None   # disease nodes that may not appear inside a path

class TimeoutException(Exception):
    """Signals that path finding for one pair exceeded its time budget."""

def timeout_handler(signum, frame):
    """Signal handler that converts the signal into a TimeoutException.

    Both arguments are the standard signal-handler parameters and are ignored.
    """
    timeout_error = TimeoutException("Pathfinding took too long")
    raise timeout_error

def save_graph(G, file_path):
    """Pickle ``G`` into ``file_path``, overwriting any existing file."""
    with open(file_path, 'wb') as sink:
        pickle.Pickler(sink).dump(G)

def load_graph(file_path):
    """Return the object unpickled from ``file_path``.

    The file is unpickled as-is, so it must come from a trusted source
    (unpickling can execute arbitrary code).
    """
    with open(file_path, 'rb') as source:
        loaded = pickle.Unpickler(source).load()
    return loaded

def log_memory_usage():
    """Print the resident set size of the current process in GB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    print(f"Memory usage: {rss_bytes / (1024 * 1024 * 1024):.2f} GB")

def is_disease_node(node):
    """Return True when ``node`` contains the marker 'Disease'.

    This is a plain containment test, so the marker may appear anywhere in
    the node identifier, not only as its prefix.
    """
    has_disease_marker = 'Disease' in node
    return has_disease_marker

def build_disease_nodes_set(G):
    """Return the set of nodes of ``G`` for which is_disease_node() is true."""
    return set(filter(is_disease_node, G.nodes()))

def precompute_neighborhood(G, cutoff=2):
    """
    Precompute, for every drug/compound and disease node, the nodes within
    ``cutoff`` hops together with their hop distance.

    Parameters:
    - G: graph to explore.
    - cutoff: maximum hop distance to include.

    Returns a dict mapping each selected node to the result of
    nx.single_source_shortest_path_length(G, node, cutoff=cutoff); a node
    whose lookup raises is mapped to an empty dict.

    NOTE(review): on a directed graph NetworkX follows outgoing edges here,
    so a disease node's neighborhood is what is reachable *from* it, not
    what can reach it -- confirm this is what the path search expects.
    """
    print(f"Precomputing {cutoff}-hop neighborhoods for all nodes...")

    # Drugs/compounds first, then diseases; other node types are skipped to
    # save memory.
    important_nodes = [n for n in G.nodes() if 'Drug' in n or 'Compound' in n]
    important_nodes = important_nodes + [n for n in G.nodes() if 'Disease' in n]

    neighborhoods = {}
    with tqdm(total=len(important_nodes), desc=f"Building {cutoff}-hop neighborhoods") as progress:
        for node in important_nodes:
            try:
                reachable = nx.single_source_shortest_path_length(G, node, cutoff=cutoff)
            except Exception:
                # Best effort: a failed lookup leaves the node without neighbors.
                reachable = {}
            neighborhoods[node] = reachable
            progress.update(1)

    print(f"Precomputed neighborhoods for {len(neighborhoods)} nodes")
    return neighborhoods

def init_worker(multi_graph_path, simple_graph_path, neighborhoods_path=None):
    """
    Load the graph data into this process's module-level globals.

    Parameters:
    - multi_graph_path: pickle of the MultiDiGraph (stored in G_multi).
    - simple_graph_path: pickle of the collapsed weighted graph
      (stored in G_simple_weighted).
    - neighborhoods_path: optional pickle of precomputed neighborhoods; when
      it is missing or not given, node_neighborhoods becomes an empty dict.

    Also rebuilds disease_nodes_set from the simple graph.
    """
    global G_multi, G_simple_weighted, node_neighborhoods, disease_nodes_set

    worker_id = os.getpid()
    print(f"Worker process {worker_id} initializing...")
    start_time = time.time()

    # Graphs are read from disk in each worker.
    G_multi = load_graph(multi_graph_path)
    G_simple_weighted = load_graph(simple_graph_path)

    # Neighborhoods are optional.
    neighborhoods_available = bool(neighborhoods_path) and os.path.exists(neighborhoods_path)
    if not neighborhoods_available:
        node_neighborhoods = {}
    else:
        node_neighborhoods = load_graph(neighborhoods_path)
        print(f"Worker {worker_id} loaded precomputed neighborhoods")

    # Disease nodes must not be used as intermediate nodes of a path.
    disease_nodes_set = build_disease_nodes_set(G_simple_weighted)
    elapsed = time.time() - start_time
    print(f"Worker {worker_id} initialized in {elapsed:.2f}s, found {len(disease_nodes_set)} disease nodes")

def get_path_details(node_path):
    """
    Expand a node list into per-edge details using the global MultiDiGraph.

    For every consecutive node pair the parallel edge with the smallest
    'weight' is chosen (the first one wins on ties). Returns a dict with the
    original node list ("nodes"), one {"head", "relation", "tail", "weight"}
    entry per hop ("details"), and the sum of the chosen weights
    ("total_weight", 0 for a path without edges).
    """
    hops = []
    for head, tail in zip(node_path, node_path[1:]):
        parallel_edges = G_multi.get_edge_data(head, tail)
        lightest = min(parallel_edges.values(), key=lambda attrs: attrs['weight'])
        hops.append({
            "head": head,
            "relation": lightest['relation'],
            "tail": tail,
            "weight": lightest['weight'],
        })

    path_weight = sum(hop["weight"] for hop in hops)
    return {"nodes": node_path, "details": hops, "total_weight": path_weight}

def bidirectional_search(G, source, target, cutoff=4):
    """
    Find candidate paths from ``source`` to ``target`` that do not pass
    through disease nodes.

    When precomputed neighborhoods exist for both endpoints, every node the
    two neighborhoods share is tried as a midpoint and the segments
    source -> midpoint and midpoint -> target are joined. Otherwise a single
    disease-avoiding path is searched directly.

    Parameters:
    - G: graph handed on to find_path_avoiding_disease().
    - source, target: endpoint nodes.
    - cutoff: maximum number of edges; longer paths are dropped.

    Returns a list of node lists. The result is free of duplicates, contains
    only simple paths (no node visited twice) and is produced in a
    deterministic order (midpoints are visited in sorted order).

    NOTE(review): the neighborhoods are intersected as given; if they were
    built by following outgoing edges of a directed graph, the target's
    neighborhood holds nodes reachable *from* the target -- confirm that
    this is the intended midpoint candidate set.
    """
    have_neighborhoods = (
        node_neighborhoods
        and source in node_neighborhoods
        and target in node_neighborhoods
    )

    if not have_neighborhoods:
        # No precomputed data: fall back to one direct disease-avoiding search.
        try:
            path = find_path_avoiding_disease(G, source, target, allow_target_disease=True)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return []
        if path and len(path) <= cutoff + 1:
            return [path]
        return []

    # Midpoint candidates: nodes present in both neighborhoods.
    common_nodes = set(node_neighborhoods[source]) & set(node_neighborhoods[target])

    # Disease nodes may only appear as the endpoints themselves.
    if disease_nodes_set:
        common_nodes = {
            node for node in common_nodes
            if node == target or node == source or node not in disease_nodes_set
        }

    paths = []
    seen_paths = set()
    # Sorted iteration keeps the output order stable: iterating a set of
    # strings directly changes order between interpreter runs, and callers
    # keep only the first few paths per length.
    for mid in sorted(common_nodes, key=str):
        try:
            if mid == source:
                forward_path = [source]
            else:
                forward_path = find_path_avoiding_disease(G, source, mid)
            if not forward_path:
                continue

            if mid == target:
                backward_path = [target]
            else:
                # Only the final node of this segment may be a disease.
                backward_path = find_path_avoiding_disease(
                    G, mid, target, allow_target_disease=True
                )
            if not backward_path:
                continue
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            continue

        # Join the two segments; the midpoint is shared, so drop one copy.
        full_path = forward_path + backward_path[1:]

        if len(full_path) > cutoff + 1:
            continue
        # The two segments are searched independently and may overlap;
        # a joined path that revisits a node is not a valid simple path.
        if len(set(full_path)) != len(full_path):
            continue
        # Every midpoint lying on the same path reproduces that path.
        path_key = tuple(full_path)
        if path_key in seen_paths:
            continue
        seen_paths.add(path_key)
        paths.append(full_path)

    return paths

def find_path_avoiding_disease(G, source, target, allow_target_disease=False):
    """
    Breadth-first search for a fewest-hops path from ``source`` to ``target``
    that never steps onto a disease node.

    Parameters:
    - G: graph providing has_edge() and neighbors().
    - source, target: endpoint nodes.
    - allow_target_disease: when True and ``target`` is a disease node, the
      search may step onto ``target`` (and only onto that disease node).

    Returns the path as a list of nodes, or None when no such path exists.
    A direct edge source -> target is always returned as [source, target].
    """
    # Shortcut: a direct edge needs no search.
    if G.has_edge(source, target):
        return [source, target]

    # The target is exempt from the disease filter only when explicitly allowed.
    target_is_disease = bool(allow_target_disease and target in disease_nodes_set)

    # came_from doubles as the visited set and lets the path be rebuilt at
    # the end instead of carrying a copy of the path in every queue entry.
    came_from = {source: None}
    frontier = collections.deque([source])

    while frontier:
        current = frontier.popleft()

        for neighbor in G.neighbors(current):
            if neighbor in came_from:
                continue

            # Disease nodes are off limits, except for the permitted target.
            if neighbor in disease_nodes_set and not (target_is_disease and neighbor == target):
                continue

            came_from[neighbor] = current

            if neighbor == target:
                # Walk the predecessor links back to the source.
                reversed_path = [neighbor]
                while reversed_path[-1] != source:
                    reversed_path.append(came_from[reversed_path[-1]])
                reversed_path.reverse()
                return reversed_path

            frontier.append(neighbor)

    # The queue ran dry without reaching the target.
    return None

# Relation labels that encode a known direct "drug treats disease" edge.
# These are the same four labels the task-split cell of this notebook treats
# as positive treatment relations; a two-node path made of such an edge would
# leak the label into the path features and is therefore discarded.
_DIRECT_TREATMENT_RELATIONS = frozenset({
    'treats',
    'DRUGBANK::treats::Compound:Disease',
    'GNBR::T::Compound:Disease',
    'Hetionet::CtD::Compound:Disease',
})


def find_and_filter_paths_worker(pair):
    """
    Find, filter and describe paths for one (drug_id, disease_id) pair.

    Steps:
    1. Run bidirectional_search() under a PATH_TIMEOUT-second alarm.
    2. If fewer than 5 paths came back, add up to K_SHORTEST_PATHS weighted
       shortest simple paths computed on a copy of the graph from which all
       other disease nodes were removed.
    3. Drop duplicates, paths with a disease node in the interior, direct
       treatment edges, and paths with more than MAX_PATH_LENGTH nodes.
    4. Keep the first NUM_PATHS_PER_LENGTH paths of every length and expand
       them with get_path_details().

    Returns ``(pair, details)`` where ``details`` is a (possibly empty) list
    of dicts as produced by get_path_details(). Search errors never
    propagate; they result in fewer or no paths.
    """
    drug_id, disease_id = pair

    if not (G_simple_weighted.has_node(drug_id) and G_simple_weighted.has_node(disease_id)):
        return (pair, [])

    # Arm the per-pair timeout (SIGALRM raises TimeoutException via the handler).
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(PATH_TIMEOUT)

    try:
        found_paths = bidirectional_search(
            G_simple_weighted, source=drug_id, target=disease_id,
            cutoff=MAX_PATH_LENGTH
        )

        # Too few candidates: top up with weighted shortest simple paths.
        if len(found_paths) < 5:
            try:
                # NOTE(review): this copies the whole graph for every pair
                # that reaches this branch; consider a filtered view if the
                # copy turns out to dominate the time budget.
                tmp_graph = G_simple_weighted.copy()

                # Every disease except the target is removed so it cannot
                # appear inside a path.
                other_diseases = [n for n in disease_nodes_set if n != disease_id]
                tmp_graph.remove_nodes_from(other_diseases)

                if drug_id in tmp_graph and disease_id in tmp_graph:
                    try:
                        path_generator = nx.shortest_simple_paths(
                            tmp_graph, source=drug_id, target=disease_id, weight='weight'
                        )
                        # Materialise first so that an interruption adds
                        # either all of the extra paths or none of them.
                        more_paths = list(islice(path_generator, K_SHORTEST_PATHS))
                        found_paths.extend(more_paths)
                    except (nx.NodeNotFound, nx.NetworkXNoPath):
                        pass
            except Exception:
                # Deliberate best effort: any failure of the top-up step,
                # including a timeout raised while it runs, keeps the paths
                # already found by the bidirectional search.
                pass

    except TimeoutException:
        # The first search ran out of time; fall back to cheap alternatives.
        # NOTE(review): the alarm is one-shot and has already fired, so this
        # fallback runs without a time limit -- confirm that is intended.
        try:
            if G_simple_weighted.has_edge(drug_id, disease_id):
                found_paths = [[drug_id, disease_id]]
            else:
                found_paths = bidirectional_search(
                    G_simple_weighted, drug_id, disease_id, cutoff=2
                )
        except Exception:
            found_paths = []

    except Exception:
        found_paths = []

    finally:
        # Always disarm the alarm before leaving the search phase.
        signal.alarm(0)

    filtered_paths = []
    seen_paths = set()
    for p in found_paths:
        # The two search strategies can return the same path; without this
        # check duplicates would use up the per-length selection slots.
        path_key = tuple(p)
        if path_key in seen_paths:
            continue
        seen_paths.add(path_key)

        # Final safety net: no disease node strictly inside the path.
        if any(node in disease_nodes_set for node in p[1:-1]):
            continue

        if len(p) == 2:
            # A direct edge is kept only if none of its parallel relations
            # is a known treatment relation.
            edge_data = G_multi.get_edge_data(p[0], p[1])
            is_direct_treats = any(
                d.get('relation') in _DIRECT_TREATMENT_RELATIONS
                for d in edge_data.values()
            )
            if not is_direct_treats:
                filtered_paths.append(p)
        elif len(p) <= MAX_PATH_LENGTH:
            filtered_paths.append(p)

    if not filtered_paths:
        return (pair, [])

    # Stratified selection: group by node count, keep the first few per group.
    paths_by_length = collections.defaultdict(list)
    for p in filtered_paths:
        paths_by_length[len(p)].append(p)

    stratified_selection = []
    for length in sorted(paths_by_length):
        stratified_selection.extend(paths_by_length[length][:NUM_PATHS_PER_LENGTH])

    final_path_details = [get_path_details(p) for p in stratified_selection]

    return (pair, final_path_details)

def process_batch(batch_idx, pairs, multi_graph_path, simple_graph_path, neighborhoods_path=None):
    """
    Run path finding for every pair of one batch inside the current process.

    Parameters:
    - batch_idx: batch number, used for log messages and the temp file name.
    - pairs: iterable of (drug_id, disease_id) tuples.
    - multi_graph_path, simple_graph_path, neighborhoods_path: forwarded to
      init_worker(), which loads the graph data for this process.

    Every SAVE_INTERVAL pairs the results collected so far are written to a
    temp file in OUTPUT_DIR; the file is removed again once the batch is
    finished. Returns a dict mapping "drug_id|||disease_id" to the path
    details of every pair for which at least one path was found.
    """
    batch_start = time.time()
    batch_results = {}

    # Load graphs and neighborhoods into this process.
    init_worker(multi_graph_path, simple_graph_path, neighborhoods_path)

    # The timestamp keeps temp files of repeated runs apart.
    temp_file = os.path.join(OUTPUT_DIR, f"temp_batch_{batch_idx}_{int(time.time())}.json")

    for processed, pair in enumerate(pairs, start=1):
        (drug_id, disease_id), path_details = find_and_filter_paths_worker(pair)

        # Pairs without any path are not recorded.
        if path_details:
            batch_results[f"{drug_id}|||{disease_id}"] = path_details

        # Periodic checkpoint; a failed write is reported but not fatal.
        if processed % SAVE_INTERVAL == 0:
            try:
                with open(temp_file, 'w') as f:
                    json.dump(batch_results, f)
            except Exception as e:
                print(f"Warning: Could not save intermediate results for batch {batch_idx}: {str(e)}")

    batch_time = time.time() - batch_start
    print(f"Batch {batch_idx} completed in {batch_time:.2f}s, found paths for {len(batch_results)} pairs")

    # Remove the checkpoint. Only filesystem errors are tolerated here, so
    # KeyboardInterrupt / SystemExit are no longer swallowed.
    try:
        os.remove(temp_file)
    except FileNotFoundError:
        # No checkpoint was ever written for this batch.
        pass
    except OSError as e:
        print(f"Warning: Could not remove temp file for batch {batch_idx}: {str(e)}")

    return batch_results

def build_subgraph_for_pathfinding(kg_df):
    """
    Build the two graphs used for path finding from the triple table.

    Parameters:
    - kg_df: DataFrame with 'head', 'relation' and 'tail' columns. A 'weight'
      column (looked up in RELATION_WEIGHTS, falling back to its 'default'
      entry) is added to this DataFrame in place.

    Returns ``(multi_graph, simple_graph)``: a MultiDiGraph carrying every
    triple with all its attributes, and a DiGraph with a single edge per
    (head, tail) pair whose 'weight' is the minimum over the parallel edges.
    """
    print("Building graphs with node type filtering...")

    print("Assigning biological weights to graph edges...")

    def _weight_for(relation):
        # Unlisted relations share the 'default' weight.
        return RELATION_WEIGHTS.get(relation, RELATION_WEIGHTS['default'])

    kg_df['weight'] = kg_df['relation'].apply(_weight_for)

    # Full multigraph: keeps every relation for later detail lookups.
    print("Building MultiDiGraph for detailed lookups...")
    multi_graph = nx.from_pandas_edgelist(
        kg_df, 'head', 'tail', edge_attr=True, create_using=nx.MultiDiGraph()
    )

    # Collapsed graph: one edge per node pair, carrying the smallest weight.
    print("Building weighted SimpleDiGraph for efficient pathfinding...")
    simple_graph = nx.DiGraph()

    with tqdm(total=multi_graph.number_of_edges(), desc="Collapsing MultiGraph") as progress:
        for u, v, attrs in multi_graph.edges(data=True):
            candidate = attrs['weight']
            if not simple_graph.has_edge(u, v):
                simple_graph.add_edge(u, v, weight=candidate)
            elif candidate < simple_graph[u][v]['weight']:
                simple_graph[u][v]['weight'] = candidate
            progress.update(1)

    return multi_graph, simple_graph

def load_all_batch_results(output_dir):
    """
    Merge all previously saved batch result files found in ``output_dir``.

    Files named ``stratified_paths_batch_*.json`` are read in sorted file
    name order and merged into one dict, so when two files contain the same
    key the file that sorts later wins -- independent of the order in which
    the operating system lists the directory.

    A missing directory yields an empty dict. Files that cannot be read,
    are not valid JSON, or do not contain a JSON object are reported and
    skipped without affecting the merged result.
    """
    all_results = {}

    if not os.path.isdir(output_dir):
        return all_results

    batch_files = sorted(
        name for name in os.listdir(output_dir)
        if name.startswith("stratified_paths_batch_") and name.endswith(".json")
    )

    if not batch_files:
        return all_results

    print(f"Found {len(batch_files)} existing batch result files")

    for batch_file in batch_files:
        try:
            with open(os.path.join(output_dir, batch_file), 'r') as f:
                batch_data = json.load(f)
            if not isinstance(batch_data, dict):
                # Guard against a partial update from list-like content.
                raise TypeError(f"expected a JSON object, got {type(batch_data).__name__}")
        except (OSError, ValueError, TypeError) as e:
            print(f"Error loading {batch_file}: {str(e)}")
            continue

        all_results.update(batch_data)
        print(f"Loaded {len(batch_data)} results from {batch_file}")

    return all_results

def _save_json_atomically(data, target_path):
    """Dump ``data`` as JSON to a ``.tmp`` sibling file, then move it over ``target_path``."""
    temp_path = f"{target_path}.tmp"
    with open(temp_path, 'w') as f:
        json.dump(data, f)
    os.replace(temp_path, target_path)


def _consolidate_existing_results(final_paths_map):
    """Write the already-loaded batch results to STRATIFIED_PATHS_FILE and report the outcome."""
    print(f"Found path sets for {len(final_paths_map)} pairs.")
    print(f"Saving final results to {STRATIFIED_PATHS_FILE}...")

    try:
        _save_json_atomically(final_paths_map, STRATIFIED_PATHS_FILE)
    except Exception as e:
        print(f"Error saving results: {e}")
        print("Please check individual batch files for results")
        return

    print(f"Results saved successfully to {STRATIFIED_PATHS_FILE}")
    print(f"\n--- Block 1 finished. Output saved to {STRATIFIED_PATHS_FILE} ---")


def main():
    """
    Run Block 1: find paths for every (approved drug, autoimmune disease) pair.

    Steps:
      1. Build and save the graphs and the neighborhood data, or reuse the
         files already present in OUTPUT_DIR.
      2. Enumerate the drug-disease pairs whose nodes exist in the graph and
         split them into batches of BATCH_SIZE.
      3. Resume from ``resume_info.json`` when present, process the remaining
         batches in a multiprocessing pool, and write one JSON file per batch.
      4. Merge everything into STRATIFIED_PATHS_FILE.

    Pairs are generated in sorted order so that a given batch index always
    refers to the same pairs across interpreter runs; resuming by batch index
    relies on this.
    """
    if not os.path.exists(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR)
        print(f"Created output directory: {OUTPUT_DIR}")

    # Define paths for saving the graphs
    multi_graph_path = os.path.join(OUTPUT_DIR, "multi_graph.pkl")
    simple_graph_path = os.path.join(OUTPUT_DIR, "simple_graph.pkl")
    neighborhoods_path = os.path.join(OUTPUT_DIR, "node_neighborhoods.pkl")

    # Check if graph files already exist
    if not (os.path.exists(multi_graph_path) and os.path.exists(simple_graph_path)):
        print("Loading knowledge graph...")
        # '\\t' is the two-character regex \t (a tab); pandas treats
        # multi-character separators as regular expressions with the python engine.
        kg_df = pd.read_csv(KG_PATH, sep='\\t', header=None,
                           names=['head', 'relation', 'tail'], engine='python')
        kg_df.dropna(inplace=True)

        # Build graphs - potentially simplified
        G_multi, G_simple_weighted = build_subgraph_for_pathfinding(kg_df)

        print("Saving graph files for multiprocessing...")
        save_graph(G_multi, multi_graph_path)
        save_graph(G_simple_weighted, simple_graph_path)
        print("Graph construction and serialization complete.")

        # Precompute neighborhoods for faster pathfinding
        print("Computing node neighborhoods...")
        neighborhoods = precompute_neighborhood(G_simple_weighted, cutoff=2)
        save_graph(neighborhoods, neighborhoods_path)
        print(f"Neighborhood data saved to {neighborhoods_path}")
    else:
        print(f"Using precomputed graph files from {OUTPUT_DIR}")
        # Temporarily load for node checking
        print("Loading simple graph to check nodes...")
        G_simple_weighted = load_graph(simple_graph_path)

        # Create neighborhoods file if it doesn't exist
        if not os.path.exists(neighborhoods_path):
            print("Computing neighborhood data for first time...")
            neighborhoods = precompute_neighborhood(G_simple_weighted, cutoff=2)
            save_graph(neighborhoods, neighborhoods_path)
            print(f"Neighborhood data saved to {neighborhoods_path}")

    print("Preparing drug-disease pairs...")
    drug_status_path = os.path.join(GDRIVE_PATH, "prediction_outputs_paper_final", 'drug_status_lookup.json')

    # Load drug status lookup
    with open(drug_status_path, 'r') as f:
        drug_status_lookup = json.load(f)

    # Get all approved drugs that are in the graph
    all_approved_drugs = {d for d, s in drug_status_lookup.items()
                         if s == 'approved' and d in G_simple_weighted.nodes}
    print(f"Found {len(all_approved_drugs)} approved drugs in the knowledge graph")

    # Load all autoimmune diseases from CSV
    print("Loading all autoimmune diseases from CSV file...")
    autoimmune_df = pd.read_csv(AUTOIMMUNE_DISEASE_LIST_PATH)

    # Format disease IDs correctly (add 'Disease::' prefix if not already present)
    all_autoimmune_disease_ids = set()
    for disease_id in autoimmune_df['ID'].astype(str):
        formatted_id = disease_id if disease_id.startswith('Disease::') else f"Disease::{disease_id}"
        all_autoimmune_disease_ids.add(formatted_id)

    # Filter to only include diseases that exist in the knowledge graph
    all_diseases = {d for d in all_autoimmune_disease_ids if d in G_simple_weighted.nodes}

    print(f"Found {len(all_diseases)} autoimmune diseases in the knowledge graph out of {len(all_autoimmune_disease_ids)} total")

    # Free up memory
    del G_simple_weighted
    log_memory_usage()

    # Create all possible drug-disease pairs. Set iteration order for strings
    # changes between interpreter runs (hash randomization), so sort first:
    # otherwise batch N would hold different pairs after a restart and
    # resuming by batch index would silently skip unprocessed pairs.
    drug_disease_pairs = list(product(sorted(all_approved_drugs), sorted(all_diseases)))
    print(f"Total pairs to search: {len(drug_disease_pairs)}")

    # Create batches with larger batch size; the batch index equals the
    # batch's position in this list.
    batches = []
    for i in range(0, len(drug_disease_pairs), BATCH_SIZE):
        batches.append((i // BATCH_SIZE, drug_disease_pairs[i:i + BATCH_SIZE]))

    print(f"Split work into {len(batches)} batches of ~{BATCH_SIZE} pairs each")
    print(f"Using up to {MAX_CONCURRENT_WORKERS} parallel workers")

    # Check if resume file exists
    resume_file = os.path.join(OUTPUT_DIR, "resume_info.json")
    start_batch = 0

    # First check if we have any existing batch results
    final_paths_map = load_all_batch_results(OUTPUT_DIR)

    if os.path.exists(resume_file):
        with open(resume_file, 'r') as f:
            resume_info = json.load(f)
        start_batch = resume_info.get('next_batch', 0)
        print(f"Resume file found, starting from batch {start_batch}")

        # Check if start_batch is valid
        if start_batch >= len(batches):
            print(f"WARNING: Resume file indicates batch {start_batch}, but there are only {len(batches)} batches")
            print("All batches appear to be completed already.")

            # Ask user if they want to restart from beginning
            user_input = input("Would you like to restart from the beginning? (y/n): ")
            if user_input.lower() == 'y':
                start_batch = 0
                print("Restarting from batch 0")
            else:
                print("Using existing results and skipping processing")
                # Just consolidate existing results and exit
                _consolidate_existing_results(final_paths_map)
                return
    else:
        print("No resume file found, starting from the beginning")

    # Process batches in parallel using a pool
    remaining_batches = batches[start_batch:]
    total_pairs = sum(len(batch[1]) for batch in remaining_batches)
    print(f"Processing {len(remaining_batches)} remaining batches with {total_pairs} total pairs")

    # Check if there are actually batches to process
    if not remaining_batches:
        print("No batches to process. Consolidating existing results...")
        _consolidate_existing_results(final_paths_map)
        return

    # Batches finish out of order. Recording "last finished batch + 1" would
    # make a later resume skip earlier batches that were still running, so the
    # resume point only advances across a gap-free run of finished batches.
    finished_batches = set()
    resume_point = start_batch

    # Create a master progress bar for overall progress
    with tqdm(total=len(remaining_batches), desc="Overall Progress", unit="batch") as pbar:
        # Process batches with a pool of workers
        with multiprocessing.Pool(processes=MAX_CONCURRENT_WORKERS) as pool:
            active_tasks = []

            # Submit initial batch of tasks
            num_initial_tasks = min(MAX_CONCURRENT_WORKERS, len(remaining_batches))
            for i in range(num_initial_tasks):
                batch_idx, batch_pairs = remaining_batches[i]
                task = pool.apply_async(process_batch,
                    args=(batch_idx, batch_pairs, multi_graph_path, simple_graph_path, neighborhoods_path))
                active_tasks.append((batch_idx, task))

            # Position (within remaining_batches) of the next batch to submit
            next_batch_idx = num_initial_tasks

            # Process results as they complete and submit new tasks
            while active_tasks:
                for i, (batch_idx, task) in enumerate(active_tasks):
                    if not task.ready():
                        continue

                    try:
                        batch_results = task.get()
                        final_paths_map.update(batch_results)
                    except Exception as e:
                        print(f"Error in batch {batch_idx}: {e}")
                        # A failed batch is deliberately counted as finished
                        # so that a resume does not get stuck on it.
                        finished_batches.add(batch_idx)
                        while resume_point in finished_batches:
                            resume_point += 1
                        with open(resume_file, 'w') as f:
                            json.dump({
                                'next_batch': resume_point,
                                'total_results': len(final_paths_map),
                                'error_in_batch': batch_idx
                            }, f)
                        pbar.update(1)
                    else:
                        # Save after each batch - use a temporary file first to prevent corruption
                        batch_file = os.path.join(OUTPUT_DIR, f"stratified_paths_batch_{batch_idx}.json.tmp")
                        final_batch_file = os.path.join(OUTPUT_DIR, f"stratified_paths_batch_{batch_idx}.json")

                        try:
                            with open(batch_file, 'w') as f:
                                json.dump(batch_results, f)
                            os.replace(batch_file, final_batch_file)

                            # Only a batch whose file is on disk may move the
                            # resume point forward.
                            finished_batches.add(batch_idx)
                            while resume_point in finished_batches:
                                resume_point += 1
                            with open(resume_file, 'w') as f:
                                json.dump({'next_batch': resume_point, 'total_results': len(final_paths_map)}, f)

                            # Update progress
                            pbar.update(1)
                            pbar.set_description(f"Overall Progress (Found paths: {len(final_paths_map)})")
                        except Exception as e:
                            print(f"Error saving batch {batch_idx} results: {e}")

                    # Remove completed task; safe because the loop is left
                    # immediately afterwards via ``break``.
                    active_tasks.pop(i)

                    # Submit a new task if there are more batches
                    if next_batch_idx < len(remaining_batches):
                        new_batch_idx, new_batch_pairs = remaining_batches[next_batch_idx]
                        new_task = pool.apply_async(process_batch,
                            args=(new_batch_idx, new_batch_pairs, multi_graph_path, simple_graph_path, neighborhoods_path))
                        active_tasks.append((new_batch_idx, new_task))
                        next_batch_idx += 1

                    break

                # Small sleep to avoid busy waiting
                time.sleep(0.1)

    print(f"Pathfinding complete. Found path sets for {len(final_paths_map)} pairs.")
    print(f"Saving final results to {STRATIFIED_PATHS_FILE}...")

    # Save final results with a safe approach
    saved = False
    temp_final_file = f"{STRATIFIED_PATHS_FILE}.tmp"
    try:
        _save_json_atomically(final_paths_map, STRATIFIED_PATHS_FILE)
        saved = True
    except Exception as e:
        print(f"Error saving final results: {e}")
        # Try again with a smaller chunk approach
        try:
            print("Trying alternative saving approach...")
            with open(temp_final_file, 'w') as f:
                # Serialize in chunks to avoid memory issues
                f.write("{")
                for i, (key, value) in enumerate(final_paths_map.items()):
                    if i > 0:
                        f.write(",")
                    chunk = json.dumps({key: value})[1:-1]  # Remove outer braces
                    f.write(chunk)
                f.write("}")

            os.replace(temp_final_file, STRATIFIED_PATHS_FILE)
            saved = True
        except Exception as e2:
            print(f"Final attempt to save failed: {e2}")
            print("Please check batch files for results")

    # Only claim success when one of the two attempts actually wrote the file.
    if saved:
        print(f"\n--- Block 1 finished. Output saved to {STRATIFIED_PATHS_FILE} ---")

# Script entry point: run the pipeline only when this code is executed
# directly, not when it is imported (e.g. by multiprocessing worker processes).
if __name__ == '__main__':
    # NOTE(review): per the multiprocessing docs, freeze_support() is only
    # needed for frozen Windows executables and is a no-op otherwise.
    multiprocessing.freeze_support()
    main()

#### LLM 的推理

In [None]:
# ===================================================================
# Stage 2, Block 2: API-based Mechanism Synthesis & API-based Embedding
# Environment: Any environment with a stable internet connection
#
# [USER UPDATE V3.7]:
# 1. Local Embedding Removed: The script no longer uses a local
#    SentenceTransformer model. The GPU is not required for this block.
# 2. API-based Embedding: Embedding is now performed by calling the
#    `qwen3-embedding:8b` model via the same API endpoint.
# 3. Parallel Embedding: The embedding process is also parallelized
#    using a ThreadPoolExecutor with 100 workers for high throughput.
#
# [FIX V3.8]:
# 1. Resolved UnboundLocalError by initializing `newly_generated_results`
#    before the conditional block that might skip its creation.
# ===================================================================
import os
import json
import torch
import torch.nn.functional as F
from openai import OpenAI
from tqdm import tqdm
import time
import collections
from google.colab import userdata
# 导入并发处理库
import concurrent.futures

print("--- 阶段二, 代码块2: API机制合成与API嵌入 ---")

# --- Configuration ---
# Every input and output file of this block lives in IO_DIR on Google Drive.
GDRIVE_PATH = '/content/drive/MyDrive/Colab Notebooks/'
IO_DIR = os.path.join(GDRIVE_PATH, "prediction_outputs_stratified_v4")

# --- Input and output files ---
# STRATIFIED_PATHS_FILE is read; LLM_STORIES_FILE is both the story cache and
# an output; STORY_EMBEDDINGS_FILE is written with torch.save at the end of main().
STRATIFIED_PATHS_FILE = os.path.join(IO_DIR, "stratified_paths_k1000.json")
LLM_STORIES_FILE = os.path.join(IO_DIR, "llm_generated_stories.json")
STORY_EMBEDDINGS_FILE = os.path.join(IO_DIR, "story_embeddings_from_llm.pt")

# --- API configuration ---
# The token is read from the Colab secret store. When the secret is missing,
# API_TOKEN is left as None and main() stops early with a message.
try:
    API_TOKEN = userdata.get('DEEPSEEK_API_TOKEN')
except userdata.SecretNotFoundError:
    print("错误：未找到名为 'DEEPSEEK_API_TOKEN' 的Colab密钥。")
    print("请在Colab左侧的“密钥”选项卡(钥匙图标)中添加您的DeepSeek API Token。")
    API_TOKEN = None

# Base URL handed to the OpenAI client for both chat and embedding calls.
BASE_URL = "https://uni-api.cstcloud.cn/v1"

# --- 辅助函数 ---
def format_path_for_prompt(path_detail):
    """Render a single path record as one LLM-readable line.

    Args:
        path_detail: Mapping with a ``'nodes'`` sequence of node-id strings
            and a numeric ``'total_weight'``.

    Returns:
        str: ``"(Length <node count>, Weight <weight, 2 decimals>): a -> b -> c"``,
        where each node is shortened to the text after its last ``'::'``.
    """
    nodes = path_detail['nodes']

    # Keep only the trailing identifier of each node, e.g. "Gene::123" -> "123".
    short_labels = []
    for node_id in nodes:
        short_labels.append(node_id.split('::')[-1])

    header = "(Length {}, Weight {:.2f})".format(len(nodes), path_detail['total_weight'])
    return header + ": " + " -> ".join(short_labels)

def get_story_from_api(client, drug_name, disease_name, paths, retries=3, delay=5):
    """Ask the chat API for a mechanism summary, retrying with exponential backoff.

    Args:
        client: OpenAI-style client exposing ``chat.completions.create``.
        drug_name: Drug label inserted into the prompt.
        disease_name: Disease label inserted into the prompt.
        paths: Path records; each is rendered with ``format_path_for_prompt``.
        retries: Maximum number of API attempts.
        delay: Base wait in seconds; attempt ``k`` (0-based) waits
            ``delay * 2**k`` before the next attempt.

    Returns:
        str: The stripped model answer, or a string starting with
        ``"Error: Failed to generate story"`` once all attempts have failed.
        A string is always returned, including when ``retries`` is below 1.
    """
    mechanisms_str = "\n".join([f"{i+1}. {format_path_for_prompt(p)}" for i, p in enumerate(paths)])

    messages = [
        {"role": "system", "content": "You are a biomedical data analyst synthesizing a factual summary of potential mechanistic pathways."},
        {"role": "user", "content": f"""
Based ONLY on the {len(paths)} connections provided below, generate a single, concise paragraph in English outlining the potential mechanistic links between '{drug_name}' and '{disease_name}'.

Your summary should:
1.  Start with the primary mechanism (the path with the lowest weight).
2.  Integrate any secondary or alternative connections from other paths into a cohesive narrative.
3.  Throughout the description, naturally highlight the key intermediate entities (genes, pathways) that are central to the connections.

Your description must be strictly limited to the entities and relationships given. DO NOT add any speculative or external information not present in the paths.

Provided Connections:
{mechanisms_str}

Mechanistic Summary:
"""}
    ]

    for attempt in range(retries):
        try:
            completion = client.chat.completions.create(
                model="deepseek-v3:671b",
                messages=messages,
                temperature=0.1,
            )
            return completion.choices[0].message.content.strip()
        except Exception as e:
            # Any failure (network, API error, empty content) counts as one attempt.
            print(f"API call for {drug_name}-{disease_name} failed (Attempt {attempt + 1}/{retries}): {e}")
            if attempt < retries - 1:
                time.sleep(delay * (2 ** attempt))

    # Reached after the last failed attempt, and also when retries < 1, so the
    # caller never receives an implicit None.
    return f"Error: Failed to generate story for {drug_name} and {disease_name} after {retries} attempts."

def get_embedding_from_api(client, texts, retries=3, delay=5):
    """Fetch embedding vectors for a batch of texts via the API, with retries.

    Args:
        client: OpenAI-style client exposing ``embeddings.create``.
        texts: List of strings to embed in one request.
        retries: Maximum number of API attempts.
        delay: Seconds to wait between attempts.

    Returns:
        list: One ``torch.Tensor`` per input text, in response order, or a
        list of ``None`` with the same length as ``texts`` once all attempts
        have failed (including when ``retries`` is below 1).
    """
    for attempt in range(retries):
        try:
            response = client.embeddings.create(
                model="qwen3-embedding:8b",
                input=texts
            )
            vectors = [torch.tensor(data.embedding) for data in response.data]
            # The caller zips the result with the input texts; a shorter or
            # longer response would silently pair stories with wrong vectors.
            if len(vectors) != len(texts):
                raise ValueError(f"expected {len(texts)} embeddings, got {len(vectors)}")
            return vectors
        except Exception as e:
            print(f"Embedding API call failed (Attempt {attempt + 1}/{retries}): {e}")
            if attempt < retries - 1:
                time.sleep(delay)

    # All attempts failed (or none were allowed): keep the result aligned with
    # the inputs so the caller can skip the missing entries.
    return [None] * len(texts)

def save_checkpoint(data, file_path):
    """Write ``data`` as JSON to ``file_path`` without exposing a half-written file.

    The JSON (indented by 2, non-ASCII characters kept as-is, UTF-8 encoded)
    is first written to ``file_path + ".tmp"`` and then moved over the target
    with ``os.replace``.

    Args:
        data: JSON-serializable object.
        file_path: Destination path as a string.
    """
    staging_path = file_path + ".tmp"
    with open(staging_path, mode='w', encoding='utf-8') as staging_file:
        json.dump(data, staging_file, ensure_ascii=False, indent=2)
    # Swap the finished file into place in a single step.
    os.replace(staging_path, file_path)

# Prefixes of the placeholder strings stored when story generation failed:
# the first is returned by get_story_from_api, the second is written by main().
_FAILED_STORY_PREFIXES = (
    "Error: Failed to generate story",
    "Error during story generation:",
)


def _is_failed_story(story):
    """Return True when a cached story is a failure placeholder rather than real LLM output."""
    return not isinstance(story, str) or story.startswith(_FAILED_STORY_PREFIXES)


def main():
    """
    Run Block 2: generate one mechanism story per drug-disease pair and embed it.

    Steps:
      1. Load the story cache (LLM_STORIES_FILE) and the paths
         (STRATIFIED_PATHS_FILE).
      2. Generate stories in parallel for every pair that has no cached
         story, or whose cached story is a failure placeholder, saving a
         checkpoint every 100 new results.
      3. Embed the unique successful stories in parallel via the API and
         L2-normalize each vector.
      4. Map embeddings back to ``(drug_id, disease_id)`` tuples and save them
         to STORY_EMBEDDINGS_FILE with ``torch.save``.

    Failure placeholders stay in the story file, but they are retried on the
    next run and are never embedded, so a failed pair has no embedding instead
    of the embedding of an error message.
    """
    if not API_TOKEN:
        print("API Token未设置，程序已终止。请先在Colab密钥中设置您的Token。")
        return

    os.makedirs(IO_DIR, exist_ok=True)

    # --- 1. Initialize the API client ---
    print("初始化API客户端...")
    client = OpenAI(api_key=API_TOKEN, base_url=BASE_URL)

    # --- 2. Generate mechanism stories ---
    llm_stories_map = {}
    if os.path.exists(LLM_STORIES_FILE):
        print(f"发现已有的故事文件，正在从 {LLM_STORIES_FILE} 加载...")
        with open(LLM_STORIES_FILE, 'r', encoding='utf-8') as f:
            llm_stories_map = json.load(f)
        print(f"已加载 {len(llm_stories_map)} 个缓存的故事。")

    print(f"正在从 {STRATIFIED_PATHS_FILE} 加载所有路径...")
    with open(STRATIFIED_PATHS_FILE, 'r') as f:
        stratified_paths_map = json.load(f)

    # A pair needs work when it was never processed or when its cached entry
    # is only a failure placeholder from an earlier run.
    items_to_process = [
        (key, paths) for key, paths in stratified_paths_map.items()
        if key not in llm_stories_map or _is_failed_story(llm_stories_map[key])
    ]

    # Initialized outside the if/else so the final save below can always read it.
    newly_generated_results = []
    if not items_to_process:
        print("所有故事均已生成，跳过API调用步骤。")
    else:
        print(f"需要为 {len(items_to_process)} 个新的药-病对生成故事。")
        retry_count = sum(1 for key, _ in items_to_process if key in llm_stories_map)
        if retry_count:
            print(f"其中 {retry_count} 个是此前生成失败、本次重试的药-病对。")

        with concurrent.futures.ThreadPoolExecutor(max_workers=100) as executor:
            # Keys look like "<drug id>|||<disease id>"; the prompt uses only
            # the part of each id after its last '::'.
            future_to_key = {
                executor.submit(
                    get_story_from_api,
                    client,
                    key.split('|||')[0].split('::')[-1],
                    key.split('|||')[1].split('::')[-1],
                    paths
                ): key
                for key, paths in items_to_process
            }
            for future in tqdm(concurrent.futures.as_completed(future_to_key), total=len(items_to_process), desc="通过API并行生成故事"):
                key = future_to_key[future]
                try:
                    story = future.result()
                except Exception as exc:
                    print(f'{key} 生成时产生了一个错误: {exc}')
                    llm_stories_map[key] = f"Error during story generation: {exc}"
                    continue

                llm_stories_map[key] = story
                newly_generated_results.append((key, story))
                if len(newly_generated_results) < 100:
                    continue

                print(f"\n--- 已生成100个新结果，正在保存检查点并打印最新示例 ---")
                last_key, last_story = newly_generated_results[-1]
                print(f"--- 最新示例 ---\n药-病 对: {last_key}\n生成的故事: {last_story}\n")
                try:
                    save_checkpoint(llm_stories_map, LLM_STORIES_FILE)
                except Exception as exc:
                    # A failed checkpoint must not discard the story that was
                    # just generated; keep the pending list so that the next
                    # checkpoint (or the final save) tries again.
                    print(f"保存检查点失败: {exc}")
                    continue
                print(f"检查点已保存到 {LLM_STORIES_FILE}。包含总共 {len(llm_stories_map)} 个故事。")
                newly_generated_results.clear()

    if newly_generated_results:
        print("\n--- 所有API调用完成，正在保存最后批次的结果 ---")
        save_checkpoint(llm_stories_map, LLM_STORIES_FILE)
        print(f"最终缓存已保存到 {LLM_STORIES_FILE}。包含总共 {len(llm_stories_map)} 个故事。")

    # --- 3. Embed the unique stories in parallel via the API ---
    print("准备通过API并行嵌入故事...")
    # Failure placeholders are excluded: embedding an error message would give
    # the pair a meaningless vector.
    all_unique_stories = list({
        story for story in llm_stories_map.values() if not _is_failed_story(story)
    })
    print(f"找到 {len(all_unique_stories)} 个独特的故事进行嵌入。")

    story_to_embedding = {}
    batch_size = 128
    story_batches = [all_unique_stories[i:i + batch_size] for i in range(0, len(all_unique_stories), batch_size)]

    with concurrent.futures.ThreadPoolExecutor(max_workers=100) as executor:
        future_to_batch = {
            executor.submit(get_embedding_from_api, client, batch): batch
            for batch in story_batches
        }
        for future in tqdm(concurrent.futures.as_completed(future_to_batch), total=len(story_batches), desc="并行获取嵌入"):
            original_batch = future_to_batch[future]
            try:
                embedding_vectors = future.result()
                for story, emb_tensor in zip(original_batch, embedding_vectors):
                    # None marks a story whose embedding request failed.
                    if emb_tensor is not None:
                        story_to_embedding[story] = F.normalize(emb_tensor, p=2, dim=0)
            except Exception as exc:
                print(f'一个嵌入批次处理失败: {exc}')

    # --- 4. Map embeddings back to drug-disease pairs and save ---
    print("将嵌入映射回“药物-疾病”对...")
    final_pair_to_embedding = {}
    for pair_key, story in tqdm(llm_stories_map.items(), desc="聚合嵌入"):
        drug_id, disease_id = pair_key.split('|||')
        pair_tuple = (drug_id, disease_id)
        # Pairs without an embedding (failed story or failed embedding call)
        # are left out of the output.
        if story in story_to_embedding:
            final_pair_to_embedding[pair_tuple] = story_to_embedding[story]

    print(f"正在将最终的“药物-疾病”对到嵌入的映射保存到 {STORY_EMBEDDINGS_FILE}...")
    torch.save(final_pair_to_embedding, STORY_EMBEDDINGS_FILE)
    print(f"\n--- 代码块2完成。输出已保存到 {STORY_EMBEDDINGS_FILE} ---")

# Script entry point: run Block 2 only when this code is executed directly.
if __name__ == '__main__':
    main()


#### 元学习和最后评估

# ===================================================================
# Stage 2, Block 3: Advanced Multi-Prototype Meta-Learning
# Environment: Standard GPU (e.g., L4, T4)
# [ENHANCED VERSION V5.1]:
# 1. 设备一致性优化：确保所有张量在同一设备上，解决CUDA错误
# 2. 单样本评估增强：特殊处理单样本疾病，提升MRR
# 3. 动态参数调整：根据样本数量自适应选择最佳参数
# 4. 多级对比评分：融合多种相似度度量与交叉疾病知识
# 5. 数值稳定性增强：避免边缘情况导致的数值问题
# ===================================================================
import os
import json
import torch
import torch.nn.functional as F
from openai import OpenAI
import pandas as pd
import numpy as np
import random
import collections
from tqdm.notebook import tqdm
from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans
from sklearn.metrics.pairwise import cosine_similarity as sklearn_cosine_similarity
from IPython.display import display, HTML
import time
from google.colab import userdata
import concurrent.futures
from torch_geometric.nn import ComplEx
from collections import OrderedDict, defaultdict
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import normalize
from sklearn.decomposition import PCA
from scipy.spatial.distance import cdist
import warnings
warnings.filterwarnings('ignore')

print("--- 阶段二, 代码块3: 高级多原型元学习 (版本5.1) ---")

# --- Configuration ---
GDRIVE_PATH = '/content/drive/MyDrive/Colab Notebooks/'
IO_DIR = os.path.join(GDRIVE_PATH, "prediction_outputs_stratified_v4")
BASE_DATA_DIR = os.path.join(GDRIVE_PATH, "prediction_outputs_paper_final")

# --- Input files ---
STORY_EMBEDDINGS_FILE = os.path.join(IO_DIR, "story_embeddings_from_llm.pt")
LLM_STORIES_FILE = os.path.join(IO_DIR, "llm_generated_stories.json")
META_TASK_SPLIT_PATH = os.path.join(BASE_DATA_DIR, 'meta_task_split_enhanced_v3.json')
DRUG_STATUS_PATH = os.path.join(BASE_DATA_DIR, 'drug_status_lookup.json')
KG_PATH = os.path.join(GDRIVE_PATH, 'drkg_autoimmune_enhanced_glm4.tsv')
GNN_MODEL_PATH = os.path.join(BASE_DATA_DIR, 'complex_model_paper.pth')

# --- Output files ---
EXPLAINABLE_RESULTS_FILE = os.path.join(IO_DIR, 'final_explainable_predictions_enhanced.json')
VISUALIZATION_DIR = os.path.join(IO_DIR, 'visualizations')
# The visualization directory is created as a side effect of running this cell.
os.makedirs(VISUALIZATION_DIR, exist_ok=True)

# --- Evaluation hyperparameters ---
LOCAL_SUPPORT_SHOTS = 3 
NUM_EVAL_RUNS = 100      # balances evaluation stability against speed
NUM_NEG_SAMPLES_PER_QUERY = 1000  # more negative samples for a more stable evaluation
K_SIMILAR_DISEASES = 20   # more similar diseases in order to obtain more samples
MIN_AUGMENTED_SUPPORT_SIZE = 5   # lower bar so that more diseases can be evaluated
TARGET_AUGMENTED_SAMPLES = 200  # higher target number of augmented samples for more diversity
SIM_THRESHOLD = 0.65       # threshold tuned to balance relevance and diversity

# --- Enhanced clustering and prototype-construction hyperparameters ---
NUM_CLUSTERS_RANGE = range(2, 4)  # tuned range of cluster counts (tries 2 and 3)
LOCAL_WEIGHT = 5.0       # larger weight for local samples to reflect their importance
HIERARCHY_LEVELS = 2     # number of hierarchy levels
MIN_CLUSTER_SAMPLES = 2  # minimum number of samples per cluster

# --- Scoring hyperparameters ---
# These are adjusted dynamically according to the number of samples.
PROTO_QUALITY_WEIGHT = True  # enable prototype-quality weighting
DISTANCE_METRICS = ['cosine', 'euclidean', 'manhattan']  # use several distance metrics
METRIC_WEIGHTS = {'cosine': 0.7, 'euclidean': 0.2, 'manhattan': 0.1}  # default weights
TEMPERATURE = 0.1        # temperature parameter for contrastive learning
USE_CALIBRATION = True   # enable distance calibration

# --- Negative-sample selection parameters ---
HARD_NEG_RATIO = 0.25     # share of hard negatives
SEMI_HARD_NEG_RATIO = 0.35  # share of semi-hard negatives
RANDOM_NEG_RATIO = 0.4   # share of random negatives

# --- Cross-disease parameters ---
ENABLE_CROSS_DISEASE = True  # enable cross-disease learning
CROSS_DISEASE_WEIGHT = 0.8   # weight of cross-disease samples

# Relation names that are treated as positive "drug treats disease" edges.
POSITIVE_TREATMENT_RELATIONS = {
    'treats',
    'DRUGBANK::treats::Compound:Disease',
    'GNBR::T::Compound:Disease',
    'Hetionet::CtD::Compound:Disease'
}

# --- API configuration ---
# Any failure while reading the Colab secret leaves API_TOKEN as None.
try:
    API_TOKEN = userdata.get('DEEPSEEK_API_TOKEN')
except Exception:
    print("错误：未找到名为 'DEEPSEEK_API_TOKEN' 的Colab密钥。")
    print("请在Colab左侧的密钥选项卡(钥匙图标)中添加您的DeepSeek API Token。")
    API_TOKEN = None

BASE_URL = "https://uni-api.cstcloud.cn/v1"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")

# --- 动态参数调整 ---
def adjust_parameters_for_sample_size(sample_size):
    """Pick a hyperparameter preset according to the number of available samples.

    Three presets exist: ``sample_size <= 2``, ``sample_size <= 4`` and
    everything larger.

    Args:
        sample_size: Number of samples, compared against 2 and 4.

    Returns:
        dict: A new dict on every call with the keys ``METRIC_WEIGHTS``,
        ``TEMPERATURE``, ``HARD_NEG_RATIO``, ``SEMI_HARD_NEG_RATIO``,
        ``RANDOM_NEG_RATIO``, ``LOCAL_WEIGHT``, ``USE_CALIBRATION`` and
        ``SIM_THRESHOLD``.
    """
    if sample_size <= 2:
        # Few-shot case (1-2 samples): the similarity threshold is lowered to
        # obtain more samples.
        metric_mix = (0.85, 0.15, 0.0)
        temperature = 0.05
        negative_mix = (0.15, 0.25, 0.6)
        local_weight = 5.0
        sim_threshold = 0.55
    elif sample_size <= 4:
        # Medium case (3-4 samples).
        metric_mix = (0.75, 0.2, 0.05)
        temperature = 0.08
        negative_mix = (0.2, 0.3, 0.5)
        local_weight = 4.0
        sim_threshold = 0.6
    else:
        # Larger case (5 or more samples).
        metric_mix = (0.7, 0.2, 0.1)
        temperature = 0.1
        negative_mix = (0.25, 0.35, 0.4)
        local_weight = 3.0
        sim_threshold = 0.65

    cosine_w, euclidean_w, manhattan_w = metric_mix
    hard_ratio, semi_hard_ratio, random_ratio = negative_mix

    return {
        'METRIC_WEIGHTS': {'cosine': cosine_w, 'euclidean': euclidean_w, 'manhattan': manhattan_w},
        'TEMPERATURE': temperature,
        'HARD_NEG_RATIO': hard_ratio,
        'SEMI_HARD_NEG_RATIO': semi_hard_ratio,
        'RANDOM_NEG_RATIO': random_ratio,
        'LOCAL_WEIGHT': local_weight,
        'USE_CALIBRATION': True,
        'SIM_THRESHOLD': sim_threshold,
    }

# --- 辅助函数：机制簇命名 ---
def get_theme_from_api(client, stories, retries=3, delay=5):
    """Ask the chat API for a short label naming a mechanism cluster.

    Only the first three stories are sent as evidence.

    Args:
        client: OpenAI-style client exposing ``chat.completions.create``.
        stories: Sequence of mechanism descriptions belonging to one cluster.
        retries: Maximum number of API attempts.
        delay: Base wait in seconds; attempt ``k`` (0-based) waits
            ``delay * 2**k`` before the next attempt.

    Returns:
        str: The stripped model answer, or ``"Mechanism Theme Undetermined"``
        once all attempts have failed (including when ``retries`` is below 1).
    """
    stories_text = "\n\n".join(f"Evidence {i+1}: {story}" for i, story in enumerate(stories[:3]))
    messages = [
        {"role": "system", "content": "You are a biologist classifying mechanisms. Your task is to identify the single, most specific, common biological process from the provided texts and state it as a label."},
        {"role": "user", "content": f"""
Based on the following mechanism descriptions, identify and state the single most specific, common biological process or pathway.

Output ONLY the name of the process/pathway (e.g., 'JAK-STAT Signaling Inhibition', 'PI3K Pathway Modulation'). Do not add any extra text.

Mechanisms:
{stories_text}

Process/Pathway Label:
"""}
    ]
    for attempt in range(retries):
        try:
            completion = client.chat.completions.create(
                model="deepseek-v3:671b",
                messages=messages,
                temperature=0.0,
                max_tokens=20,
            )
            return completion.choices[0].message.content.strip()
        except Exception as e:
            # Report the failure instead of swallowing it, matching the other
            # API helpers in this notebook.
            print(f"Theme API call failed (Attempt {attempt + 1}/{retries}): {e}")
            if attempt < retries - 1:
                time.sleep(delay * (2 ** attempt))

    # Reached after the last failed attempt, and also when retries < 1, so the
    # caller never receives an implicit None.
    return "Mechanism Theme Undetermined"

# --- 增强型聚类和原型构建 ---
def _fit_agglomerative(norm_embeddings, n_clusters, linkage, distance):
    """Fit AgglomerativeClustering using ``metric`` or, on old scikit-learn, ``affinity``."""
    try:
        model = AgglomerativeClustering(n_clusters=n_clusters, metric=distance, linkage=linkage)
    except TypeError:
        # Older scikit-learn releases only accept the legacy ``affinity`` keyword.
        model = AgglomerativeClustering(n_clusters=n_clusters, affinity=distance, linkage=linkage)
    return model.fit_predict(norm_embeddings)


def _search_cluster_labels(norm_embeddings, n_samples):
    """Try hierarchical and DBSCAN settings; return the best (labels, method, params) by silhouette."""
    best_score = -1
    best_labels, best_method, best_params = None, None, {}

    # 1. Agglomerative clustering over the configured cluster counts.
    for n_clusters in NUM_CLUSTERS_RANGE:
        if n_clusters >= n_samples:
            continue
        for linkage in ('average', 'complete', 'ward'):
            # Ward linkage is only defined for Euclidean distances.
            distances = ('euclidean',) if linkage == 'ward' else ('euclidean', 'cosine')
            for distance in distances:
                try:
                    labels = _fit_agglomerative(norm_embeddings, n_clusters, linkage, distance)
                    if len(set(labels)) <= 1:
                        continue
                    score = silhouette_score(norm_embeddings, labels, metric='cosine')
                except Exception:
                    # Best effort: a configuration that cannot be fitted/scored is skipped.
                    continue
                if score > best_score:
                    best_score = score
                    best_labels = labels
                    best_method = "hierarchical"
                    best_params = {
                        "n_clusters": n_clusters,
                        "linkage": linkage,
                        "affinity": distance,
                    }

    # 2. DBSCAN (no preset number of clusters).
    for eps in (0.15, 0.2, 0.25, 0.3):
        for min_samples in (2, 3):
            if min_samples >= n_samples:
                continue
            try:
                labels = DBSCAN(eps=eps, min_samples=min_samples, metric='cosine').fit_predict(norm_embeddings)
                found = set(labels)
                # Only accept results with at least two clusters and no noise points (-1).
                if len(found) <= 1 or -1 in found:
                    continue
                score = silhouette_score(norm_embeddings, labels, metric='cosine')
            except Exception:
                continue
            if score > best_score:
                best_score = score
                best_labels = labels
                best_method = "dbscan"
                best_params = {"eps": eps, "min_samples": min_samples}

    return best_labels, best_method, best_params


def _weighted_cluster_prototype(support_embeddings, local_indices, augmented_indices, local_weight):
    """Return the unit-norm weighted mean of one cluster, or None when its total weight is not positive."""
    weighted_sum = torch.zeros_like(support_embeddings[0])
    total_weight = 0.0
    local_centroid = None

    if local_indices:
        local_embs = support_embeddings[local_indices]
        local_centroid = local_embs.mean(dim=0)
        # Every local sample carries the same weight ``local_weight``.
        weighted_sum = weighted_sum + local_weight * local_embs.sum(dim=0)
        total_weight += local_weight * len(local_indices)

    if augmented_indices:
        aug_embs = support_embeddings[augmented_indices]
        if local_centroid is not None:
            # Borrowed samples count more the closer they are to the local centroid.
            # Cosine similarity lies in [-1, 1], so these weights lie in [0, 1].
            sim_to_local = F.cosine_similarity(aug_embs, local_centroid.unsqueeze(0), dim=1)
            aug_weights = 0.5 + 0.5 * sim_to_local
        else:
            aug_weights = torch.ones(len(augmented_indices), device=aug_embs.device, dtype=aug_embs.dtype)
        weighted_sum = weighted_sum + (aug_embs * aug_weights.unsqueeze(1)).sum(dim=0)
        total_weight += aug_weights.sum().item()

    if total_weight <= 0:
        return None
    return F.normalize(weighted_sum / total_weight, p=2, dim=0)


def advanced_clustering_and_prototypes(support_embeddings, support_pairs, all_pos_pairs,
                                     disease_id=None, local_weight=None, adaptive=True):
    """Cluster the support set and build one quality-weighted prototype per cluster.

    Steps:
      1. L2-normalise the embeddings and pick cluster labels: KMeans for fewer
         than four samples, otherwise the best silhouette score among several
         agglomerative and DBSCAN settings, with KMeans as a fallback.
      2. For every cluster, compute a unit-norm weighted mean in which local
         samples (pairs contained in ``all_pos_pairs``) weigh ``local_weight``
         and borrowed samples weigh ``0.5 + 0.5 * cos(sample, local centroid)``
         (or 1 when the cluster has no local sample).

    Fixes relative to the previous version:
      * the single-sample shortcut was guarded by a condition that could never
        hold for one sample; it now triggers on ``n_samples == 1``;
      * agglomerative clustering works on scikit-learn versions that dropped
        the ``affinity`` keyword (the resulting TypeError used to be swallowed,
        silently disabling every hierarchical candidate);
      * the prototype is a true weighted mean - previously two already-averaged
        centroids were added, so the weights had no effect on its direction.

    Args:
        support_embeddings: Tensor of support embeddings, one row per entry of
            ``support_pairs`` (presumably shape (n_samples, dim)).
        support_pairs: Pairs aligned with ``support_embeddings``.
        all_pos_pairs: Known positive pairs of the target disease; membership
            marks a support pair as "local".
        disease_id: Unused here; kept for interface compatibility.
        local_weight: Weight of local samples; defaults to the module-level
            ``LOCAL_WEIGHT``.
        adaptive: Unused here; kept for interface compatibility.

    Returns:
        Tuple ``(prototypes, cluster_info, labels)``: the stacked prototypes on
        the input device, a dict keyed by cluster label (``-999`` for the
        overall-mean fallback) describing each cluster, and the per-sample
        label array.

    Raises:
        ValueError: If ``support_embeddings`` is empty.
    """
    n_samples = len(support_embeddings)
    if n_samples == 0:
        raise ValueError("support_embeddings must contain at least one sample")

    device = support_embeddings.device
    local_weight = local_weight if local_weight is not None else LOCAL_WEIGHT

    # Evaluate the local/borrowed membership once per support pair.
    is_local = [pair in all_pos_pairs for pair in support_pairs]

    if n_samples == 1:
        # A single sample is its own prototype; no clustering needed.
        prototype = F.normalize(support_embeddings[0], p=2, dim=0).unsqueeze(0)
        single_info = {"prototype_vector": prototype[0].cpu(),
                       "indices": [0],
                       "local_samples": 1 if is_local[0] else 0,
                       "augmented_samples": 0 if is_local[0] else 1,
                       "cluster_size": 1,
                       "quality_score": 1.0 if is_local[0] else 0.5,
                       "stories": [support_pairs[0]],
                       "method": "single_sample",
                       "params": {}}
        return prototype, {0: single_info}, np.zeros(1, dtype=int)

    # Clustering runs on CPU with L2-normalised vectors.
    norm_embeddings = normalize(support_embeddings.detach().cpu().numpy())

    # Stage 1: choose cluster labels.
    if n_samples < 4:
        # Too few samples for model selection: use a more stable KMeans setup.
        n_clusters = min(2, n_samples)
        clustering = KMeans(n_clusters=n_clusters, random_state=42, n_init=20, max_iter=500)
        best_labels = clustering.fit_predict(norm_embeddings)
        best_method = "kmeans_few_shot"
        best_params = {"n_clusters": n_clusters}
    else:
        best_labels, best_method, best_params = _search_cluster_labels(norm_embeddings, n_samples)
        if best_labels is None:
            # No candidate produced a valid clustering: fall back to plain KMeans.
            n_clusters = min(2, n_samples)
            clustering = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
            best_labels = clustering.fit_predict(norm_embeddings)
            best_method = "kmeans_fallback"
            best_params = {"n_clusters": n_clusters}

    # Stage 2: one quality-aware prototype per cluster.
    prototypes = []
    cluster_info = {}
    for label in np.unique(best_labels):
        if label == -1:  # skip noise points
            continue

        cluster_indices = np.where(best_labels == label)[0]
        local_indices = [int(i) for i in cluster_indices if is_local[i]]
        augmented_indices = [int(i) for i in cluster_indices if not is_local[i]]

        prototype = _weighted_cluster_prototype(
            support_embeddings, local_indices, augmented_indices, local_weight
        )
        if prototype is None:
            continue
        prototypes.append(prototype)

        # Share of the (weighted) cluster mass contributed by local samples.
        local_mass = local_weight * len(local_indices)
        quality_score = local_mass / (local_mass + len(augmented_indices) + 0.001)

        cluster_info[int(label)] = {
            "indices": cluster_indices.tolist(),
            "prototype_vector": prototype.cpu(),  # kept on CPU to save device memory
            "local_samples": len(local_indices),
            "augmented_samples": len(augmented_indices),
            "cluster_size": len(cluster_indices),
            "quality_score": quality_score,
            "stories": [support_pairs[i] for i in cluster_indices],
            "method": best_method,
            "params": best_params
        }

    # No usable cluster: fall back to the normalised mean of all samples.
    if not prototypes:
        overall_mean = F.normalize(torch.mean(support_embeddings, dim=0), p=2, dim=0)
        prototypes.append(overall_mean)
        cluster_info[-999] = {
            "indices": list(range(len(support_embeddings))),
            "prototype_vector": overall_mean.cpu(),
            "local_samples": sum(1 for flag in is_local if flag),
            "augmented_samples": sum(1 for flag in is_local if not flag),
            "cluster_size": len(support_embeddings),
            "quality_score": 0.5,
            "stories": support_pairs,
            "method": "fallback_mean",
            "params": {}
        }

    return torch.stack(prototypes).to(device), cluster_info, best_labels

# --- 多层次负样本选择 ---
def select_stratified_negative_samples(query_pos_pair, candidate_neg_pairs, pair_to_story_embedding,
                                     num_samples=100, hard_ratio=None, semi_hard_ratio=None,
                                     random_ratio=None):
    """Pick negatives in three strata: hard, semi-hard and random.

    Candidates are ranked by cosine similarity to the query positive's
    embedding (most similar first). The hard stratum is the top of that
    ranking, the semi-hard stratum is a window starting at the first quartile,
    and the rest is drawn at random (``random.sample``) from what is left.

    Fixes relative to the previous version: the semi-hard window can no longer
    overlap the hard stratum (it used to when ``num_hard`` exceeded a quarter
    of the candidates, and the shortfall was then topped up in plain input
    order), the random share can no longer become negative, and the top-up
    uses a set instead of repeated list scans.

    Args:
        query_pos_pair: Key of the query positive in ``pair_to_story_embedding``.
        candidate_neg_pairs: List of candidate negative pair keys.
        pair_to_story_embedding: Mapping pair -> embedding tensor.
        num_samples: Number of negatives to return.
        hard_ratio: Share of hard negatives; defaults to ``HARD_NEG_RATIO``.
        semi_hard_ratio: Share of semi-hard negatives; defaults to
            ``SEMI_HARD_NEG_RATIO``.
        random_ratio: Accepted for interface compatibility (defaults to
            ``RANDOM_NEG_RATIO``); the random stratum always fills whatever
            the other two leave open.

    Returns:
        ``candidate_neg_pairs`` itself when it has at most ``num_samples``
        entries, otherwise a new list of ``num_samples`` distinct pairs
        ordered hard, semi-hard, random.
    """
    hard_ratio = hard_ratio if hard_ratio is not None else HARD_NEG_RATIO
    semi_hard_ratio = semi_hard_ratio if semi_hard_ratio is not None else SEMI_HARD_NEG_RATIO
    random_ratio = random_ratio if random_ratio is not None else RANDOM_NEG_RATIO

    if len(candidate_neg_pairs) <= num_samples:
        return candidate_neg_pairs
    if num_samples <= 0:
        return []

    # Similarity of every candidate to the query positive, on a common device.
    neg_embeddings = torch.stack([pair_to_story_embedding[p] for p in candidate_neg_pairs])
    pos_embedding = pair_to_story_embedding[query_pos_pair].to(neg_embeddings.device)
    similarities = F.cosine_similarity(pos_embedding.unsqueeze(0), neg_embeddings).detach().cpu().tolist()

    # Stable descending sort: ties keep their input order.
    order = sorted(range(len(candidate_neg_pairs)), key=similarities.__getitem__, reverse=True)
    ranked_pairs = [candidate_neg_pairs[i] for i in order]
    n_candidates = len(ranked_pairs)

    # Stratum sizes, clamped so that together they never exceed num_samples.
    num_hard = min(max(int(num_samples * hard_ratio), 0), num_samples)
    num_semi_hard = min(max(int(num_samples * semi_hard_ratio), 0), num_samples - num_hard)

    # Hard negatives: the most similar candidates.
    hard_samples = ranked_pairs[:num_hard]

    # Semi-hard negatives: a window from the first quartile, pushed past the hard stratum if needed.
    semi_start = max(n_candidates // 4, num_hard)
    semi_end = min(semi_start + num_semi_hard, n_candidates)
    semi_hard_samples = ranked_pairs[semi_start:semi_end]

    # Random negatives: drawn from every rank not used above.
    num_random = num_samples - len(hard_samples) - len(semi_hard_samples)
    remaining_indices = list(range(num_hard, semi_start)) + list(range(semi_end, n_candidates))
    if 0 < num_random < len(remaining_indices):
        random_indices = random.sample(remaining_indices, num_random)
    else:
        random_indices = remaining_indices[:num_random]
    random_samples = [ranked_pairs[i] for i in random_indices]

    # De-duplicate while keeping order (the candidate list itself may hold duplicates).
    unique_samples = list(dict.fromkeys(hard_samples + semi_hard_samples + random_samples))

    # Top up from the input order if duplicates left the selection short.
    if len(unique_samples) < num_samples:
        chosen = set(unique_samples)
        for pair in candidate_neg_pairs:
            if len(unique_samples) >= num_samples:
                break
            if pair not in chosen:
                chosen.add(pair)
                unique_samples.append(pair)

    return unique_samples[:num_samples]

# --- 智能支撑集增强 ---
def augment_support_set_advanced(target_disease_id, all_pos_pairs, disease_to_pos_pairs,
                               disease_embeddings, split_info, pair_to_story_embedding,
                               pairkey_to_story_text, sim_threshold=None):
    """Borrow support pairs for a target disease from similar meta-train diseases.

    Meta-train diseases are ranked by cosine similarity of their disease
    embedding to the target's, damped by ``min(1, log1p(#pairs) / 3)`` so that
    diseases with very few known pairs rank lower. From the top-ranked
    diseases, pairs are borrowed whose story embedding is similar enough to
    the centroid of the target's own story embeddings; when the target has no
    embedded local pair, all embedded pairs of a similar disease qualify and
    the search is widened (two more diseases).

    Compared with the previous version, disease embeddings and the local story
    embeddings are moved to a common device before being compared/stacked,
    matching the handling of borrowed story embeddings further down.

    Args:
        target_disease_id: Disease to augment.
        all_pos_pairs: Known positive pairs of the target disease.
        disease_to_pos_pairs: Mapping disease id -> its positive pairs.
        disease_embeddings: Mapping disease id -> embedding tensor.
        split_info: Mapping whose ``'meta_train'`` entry lists donor diseases.
        pair_to_story_embedding: Mapping pair -> story embedding tensor.
        pairkey_to_story_text: Unused here; kept for interface compatibility.
        sim_threshold: Minimum story similarity to the local centroid;
            defaults to the module-level ``SIM_THRESHOLD``.

    Returns:
        ``(augmented_support, borrowed_from)``: the borrowed pairs (at most
        ``TARGET_AUGMENTED_SAMPLES``) and a list of
        ``(disease_id, adjusted_similarity, n_taken)`` tuples. Both are empty
        when the target disease has no embedding.
    """
    sim_threshold = sim_threshold if sim_threshold is not None else SIM_THRESHOLD

    # Without a target embedding there is nothing to compare against.
    target_embedding = disease_embeddings.get(target_disease_id)
    if target_embedding is None:
        return [], []

    # Rank donor diseases by damped embedding similarity.
    disease_similarities = {}
    for train_disease_id in split_info['meta_train']:
        # Skip diseases without an embedding and the target itself.
        if train_disease_id not in disease_embeddings or train_disease_id == target_disease_id:
            continue

        train_embedding = disease_embeddings[train_disease_id].to(target_embedding.device)
        sim_score = F.cosine_similarity(target_embedding.unsqueeze(0), train_embedding.unsqueeze(0)).item()

        # Diseases with few known pairs are down-weighted.
        drug_count = len(disease_to_pos_pairs.get(train_disease_id, []))
        disease_similarities[train_disease_id] = sim_score * min(1.0, np.log1p(drug_count) / 3.0)

    sorted_similar_diseases = sorted(disease_similarities, key=disease_similarities.get, reverse=True)

    augmented_support = []
    borrowed_from = []

    # Story embeddings of the target's own pairs define the "mechanism centroid".
    target_embeddings = [pair_to_story_embedding[pair] for pair in all_pos_pairs
                         if pair in pair_to_story_embedding]

    if not target_embeddings:
        # No local evidence: look at more diseases with a relaxed threshold.
        max_diseases = K_SIMILAR_DISEASES + 2
        current_sim_threshold = sim_threshold * 0.8
        target_centroid = None
    else:
        max_diseases = K_SIMILAR_DISEASES
        current_sim_threshold = sim_threshold
        centroid_device = target_embeddings[0].device
        target_centroid = torch.stack([e.to(centroid_device) for e in target_embeddings]).mean(dim=0)

    for sim_disease_id in sorted_similar_diseases[:max_diseases]:
        if len(augmented_support) >= TARGET_AUGMENTED_SAMPLES:
            break

        borrowed_pairs = disease_to_pos_pairs.get(sim_disease_id, [])
        if len(borrowed_pairs) < 1:
            continue

        if target_centroid is not None:
            # Keep pairs whose story is similar enough to the local centroid, best first.
            scored_pairs = []
            for pair in borrowed_pairs:
                if pair in pair_to_story_embedding:
                    emb = pair_to_story_embedding[pair].to(target_centroid.device)
                    sim = F.cosine_similarity(emb.unsqueeze(0), target_centroid.unsqueeze(0))
                    scored_pairs.append((pair, sim.item()))
            scored_pairs.sort(key=lambda x: x[1], reverse=True)
            filtered_pairs = [(p, s) for p, s in scored_pairs if s > current_sim_threshold]
        else:
            # No local samples: every embedded pair of the donor disease qualifies.
            filtered_pairs = [(p, 0.5) for p in borrowed_pairs if p in pair_to_story_embedding]

        # Quota for this donor: the per-disease share scaled by a 0.5-1.5 similarity factor, at least 2.
        disease_sim = disease_similarities[sim_disease_id]
        similarity_factor = max(0.5, min(1.5, disease_sim * 2))
        samples_from_this_disease = min(len(filtered_pairs),
                                        max(2, int(TARGET_AUGMENTED_SAMPLES / max_diseases * similarity_factor)))

        if filtered_pairs:
            borrowed_from.append((sim_disease_id, disease_sim, samples_from_this_disease))
            augmented_support.extend([p for p, _ in filtered_pairs[:samples_from_this_disease]])

    # Drop duplicates (order preserved) and anything that is already a local positive.
    augmented_support = list(dict.fromkeys(augmented_support))
    augmented_support = [p for p in augmented_support if p not in all_pos_pairs]

    # Cap the total number of borrowed samples.
    if len(augmented_support) > TARGET_AUGMENTED_SAMPLES:
        augmented_support = augmented_support[:TARGET_AUGMENTED_SAMPLES]

    return augmented_support, borrowed_from

# --- 多距离融合评分机制 ---
def _normalise_by_max(distances):
    """Scale a distance matrix by its largest entry; empty or all-zero input is returned unchanged."""
    if distances.numel() > 0 and distances.max() > 0:
        return distances / distances.max()
    return distances


def multi_distance_scoring(candidate_embeddings, prototypes, cluster_info=None, disease_id=None,
                         metric_weights=None, temperature=None, use_calibration=None):
    """Score candidates by a weighted mix of distances to the prototypes (lower is better).

    For every candidate/prototype pair the cosine distance (``1 - cos``) and
    the max-normalised Euclidean and Manhattan distances are combined with
    ``metric_weights``. Each prototype column is then multiplied by
    ``1 / (quality_score + 0.2)`` of the matching ``cluster_info`` entry,
    divided by ``temperature`` (when positive) and reduced to one value per
    candidate.

    Fixes relative to the previous version:
      * with a single prototype the softmax calibration returned 0 for every
        candidate (softmax over one column is always 1), which erased the
        ranking; calibration is now skipped in that case and the scaled
        distance to the prototype is returned;
      * the Manhattan distance uses ``torch.cdist(p=1)`` instead of a Python
        double loop;
      * an empty distance matrix no longer breaks the max-normalisation.

    Args:
        candidate_embeddings: Candidate embeddings, one row per candidate.
        prototypes: Prototype vectors, one row per prototype; moved to the
            candidates' device.
        cluster_info: Optional dict of cluster descriptions providing
            ``'prototype_vector'`` and ``'quality_score'``.
        disease_id: Unused here; kept for interface compatibility.
        metric_weights: Mapping metric name (``'cosine'``, ``'euclidean'``,
            ``'manhattan'``) -> weight; other keys are ignored. Defaults to
            ``METRIC_WEIGHTS``.
        temperature: Divisor applied to the distances when > 0; defaults to
            ``TEMPERATURE``.
        use_calibration: Whether to apply the softmax calibration; defaults
            to ``USE_CALIBRATION``.

    Returns:
        1-D tensor with one distance-like score per candidate. With
        calibration (and at least two prototypes) this is
        ``1 - max softmax(-distance)`` over the prototypes, i.e. it measures
        how decisively the nearest prototype wins; otherwise it is the scaled
        distance to the nearest prototype.
    """
    metric_weights = metric_weights if metric_weights is not None else METRIC_WEIGHTS
    temperature = temperature if temperature is not None else TEMPERATURE
    use_calibration = use_calibration if use_calibration is not None else USE_CALIBRATION

    # Work on the candidates' device.
    device = candidate_embeddings.device
    prototypes = prototypes.to(device)

    n_candidates = len(candidate_embeddings)
    n_prototypes = len(prototypes)

    # 1 + 2. Compute each requested metric and accumulate the weighted sum.
    combined_distances = torch.zeros((n_candidates, n_prototypes), device=device)
    for metric, weight in metric_weights.items():
        if metric == 'cosine':
            candidate_norm = F.normalize(candidate_embeddings, p=2, dim=1)
            proto_norm = F.normalize(prototypes, p=2, dim=1)
            # Clamp to [-1, 1] to absorb rounding error.
            cosine_sim = torch.clamp(torch.mm(candidate_norm, proto_norm.t()), -1.0, 1.0)
            metric_distances = 1.0 - cosine_sim
        elif metric == 'euclidean':
            metric_distances = _normalise_by_max(torch.cdist(candidate_embeddings, prototypes))
        elif metric == 'manhattan':
            metric_distances = _normalise_by_max(torch.cdist(candidate_embeddings, prototypes, p=1))
        else:
            continue
        combined_distances += metric_distances * weight

    # 3. Prototype quality: higher-quality clusters get their distances shrunk.
    if cluster_info:
        quality_weights = torch.ones(n_prototypes, device=device)
        for i, proto in enumerate(prototypes):
            for info in cluster_info.values():
                # Stored vectors may live on the CPU; compare on the working device.
                proto_vector = info['prototype_vector'].to(device)
                if torch.allclose(proto_vector, proto, atol=1e-4):
                    quality_weights[i] = 1.0 / (info['quality_score'] + 0.2)
                    break
        quality_adjusted_distances = combined_distances * quality_weights.unsqueeze(0)
    else:
        quality_adjusted_distances = combined_distances

    # 4. Temperature scaling: a lower temperature sharpens distance differences.
    if temperature > 0:
        scaled_distances = quality_adjusted_distances / temperature
    else:
        scaled_distances = quality_adjusted_distances

    # 5. Reduce over prototypes.
    if use_calibration and n_prototypes > 1:
        # Softmin over prototypes: small distance -> high probability.
        probabilities = F.softmax(-scaled_distances, dim=1)
        best_probs, _ = torch.max(probabilities, dim=1)
        # Back to a distance-like score: high probability -> small value.
        return 1.0 - best_probs

    # Single prototype (softmax would be constant) or calibration disabled.
    min_distances, _ = torch.min(scaled_distances, dim=1)
    return min_distances

# --- 交叉疾病知识增强 ---
def get_cross_disease_predictions(disease_id, query_pos_pair, all_approved_drugs,
                                disease_embeddings, disease_to_pos_pairs,
                                pair_to_story_embedding, split_info):
    """Score drugs by their known use in diseases similar to ``disease_id``.

    The ``K_SIMILAR_DISEASES`` meta-train diseases whose embeddings are most
    cosine-similar to the target are selected; every approved drug known for
    one of them receives ``similarity * CROSS_DISEASE_WEIGHT``, summed over
    all similar diseases that list it.

    Fixes relative to the previous version:
      * the target disease itself is skipped (it would otherwise be its own
        nearest neighbour with similarity 1 and leak its known drugs), in line
        with ``augment_support_set_advanced``;
      * entries of ``disease_to_pos_pairs`` may be ``(drug, disease)`` pairs
        or bare drug ids - indexing a bare id string used to yield its first
        character, so no drug ever matched.
        NOTE(review): callers in this file pass ``disease_to_true_drugs``
        here; confirm which of the two shapes it has.

    Args:
        disease_id: Target disease.
        query_pos_pair: Not used by the computation; kept for interface
            compatibility.
        all_approved_drugs: Iterable of drug ids eligible for a score.
        disease_embeddings: Mapping disease id -> embedding tensor.
        disease_to_pos_pairs: Mapping disease id -> known positives (pairs or
            drug ids, see above).
        pair_to_story_embedding: Not used here; kept for interface
            compatibility.
        split_info: Mapping whose ``'meta_train'`` entry lists the diseases
            that may serve as neighbours.

    Returns:
        Dict drug id -> accumulated score (possibly empty), or ``None`` when
        ``ENABLE_CROSS_DISEASE`` is off or the target has no embedding.
    """
    if not ENABLE_CROSS_DISEASE:
        return None

    target_embedding = disease_embeddings.get(disease_id)
    if target_embedding is None:
        return None

    # Similarity of the target to every other meta-train disease with an embedding.
    similarities = {}
    for train_disease_id in split_info['meta_train']:
        if train_disease_id == disease_id:
            continue
        train_embedding = disease_embeddings.get(train_disease_id)
        if train_embedding is None:
            continue
        # Compare on the target's device.
        train_embedding = train_embedding.to(target_embedding.device)
        sim = F.cosine_similarity(target_embedding.unsqueeze(0), train_embedding.unsqueeze(0))
        similarities[train_disease_id] = sim.item()

    top_similar = sorted(similarities.items(), key=lambda x: x[1], reverse=True)[:K_SIMILAR_DISEASES]

    # Accumulate similarity-weighted votes for drugs known in the neighbouring diseases.
    cross_disease_scores = {}
    for similar_disease_id, sim_score in top_similar:
        known_entries = disease_to_pos_pairs.get(similar_disease_id, [])
        pos_drugs = {entry[0] if isinstance(entry, (tuple, list)) else entry
                     for entry in known_entries}

        for drug_id in all_approved_drugs:
            if drug_id in pos_drugs:
                cross_disease_scores[drug_id] = (
                    cross_disease_scores.get(drug_id, 0) + sim_score * CROSS_DISEASE_WEIGHT
                )

    return cross_disease_scores

# --- 主评估函数 ---
def _query_rank(distances):
    """Return the 1-based rank of candidate 0, counting ties (and NaN) against it."""
    query_distance = distances[0]
    if torch.isnan(query_distance).item():
        return len(distances)
    return 1 + int((distances[1:] <= query_distance).sum().item())


def _score_query(query_pos_pair, candidate_neg_pairs, prototypes, cluster_info, disease_id,
                 params, pair_to_story_embedding, cross_scores):
    """Rank one query positive against sampled negatives; return [hit@1, hit@5, hit@10, 1/rank]."""
    sampled_neg_pairs = select_stratified_negative_samples(
        query_pos_pair, candidate_neg_pairs, pair_to_story_embedding,
        num_samples=NUM_NEG_SAMPLES_PER_QUERY,
        hard_ratio=params['HARD_NEG_RATIO'],
        semi_hard_ratio=params['SEMI_HARD_NEG_RATIO'],
        random_ratio=params['RANDOM_NEG_RATIO']
    )

    # The query positive is always candidate 0.
    candidate_pairs = [query_pos_pair] + list(sampled_neg_pairs)
    candidate_embeddings = torch.stack([pair_to_story_embedding[p] for p in candidate_pairs]).to(device)

    distances = multi_distance_scoring(
        candidate_embeddings, prototypes, cluster_info, disease_id,
        metric_weights=params['METRIC_WEIGHTS'],
        temperature=params['TEMPERATURE'],
        use_calibration=params['USE_CALIBRATION']
    )

    if cross_scores:
        # A higher cross-disease score shrinks the distance. The factor is floored at 0
        # so that a score above 2 cannot flip the sign of the distance.
        for idx, pair in enumerate(candidate_pairs):
            drug_id = pair[0]
            if drug_id in cross_scores:
                factor = max(0.0, 1.0 - cross_scores[drug_id] * 0.5)
                distances[idx] = distances[idx] * factor

    rank = _query_rank(distances)
    return [1 if rank <= k else 0 for k in (1, 5, 10)] + [1.0 / rank]


def evaluate_with_enhancements(disease_id, disease_name, all_pos_pairs, local_support_pairs,
                              augmented_support, pair_to_story_embedding, disease_to_true_drugs,
                              all_approved_drugs, entity_to_name, disease_embeddings=None,
                              split_info=None):
    """Evaluate prototype-based ranking for one disease (Hits@1/5/10 and MRR).

    Two modes:
      * few-shot (at most two known positives): leave-one-out - every positive
        is the query once and prototypes are rebuilt from the remaining local
        positives plus the augmented support (only the augmented support when
        there is a single positive);
      * standard: prototypes come from ``local_support_pairs`` plus
        ``augmented_support`` and every positive outside the local support is
        a query.
    Each query is ranked against stratified negatives drawn from approved
    drugs that are not known positives of this disease and have an embedding.

    Changes relative to the previous version:
      * the scoring code duplicated in both modes lives in one helper;
      * the rank counts ties against the query instead of relying on the
        unspecified tie order of a non-stable ``torch.argsort``;
      * the cross-disease factor cannot become negative;
      * the negative pool and the cross-disease scores, which do not depend on
        the query, are computed once per disease;
      * a query is never scored against a prototype built from itself (the
        branch doing so was unreachable for distinct positives and would have
        leaked the label otherwise) - such a query is skipped.

    Args:
        disease_id: Disease under evaluation.
        disease_name: Name stored in the result under ``'disease'``.
        all_pos_pairs: Known positive pairs of the disease.
        local_support_pairs: Positives used as local support.
        augmented_support: Pairs borrowed from similar diseases.
        pair_to_story_embedding: Mapping pair -> embedding tensor.
        disease_to_true_drugs: Passed to ``get_cross_disease_predictions`` as
            its known-positives mapping.
        all_approved_drugs: Set of approved drug ids (set difference is used).
        entity_to_name: Unused here; kept for interface compatibility.
        disease_embeddings: Optional mapping disease id -> embedding; needed
            for the cross-disease adjustment.
        split_info: Optional split description; needed for the cross-disease
            adjustment.

    Returns:
        ``(result, (prototypes, cluster_info, labels))`` where ``result`` holds
        ``'disease'``, ``'h1'``, ``'h5'``, ``'h10'``, ``'mrr'`` and
        ``'num_eval_points'``; ``(None, None)`` without local support; or
        ``(None, clustering)`` when the standard mode has no query pair.
    """
    # Hyper-parameters depend on how many positives the disease has.
    params = adjust_parameters_for_sample_size(len(all_pos_pairs))

    support_pairs = local_support_pairs + augmented_support

    if len(local_support_pairs) < 1:
        print(f"  - 警告: 没有本地样本，跳过评估")
        return None, None

    few_shot_mode = len(all_pos_pairs) <= 2
    single_sample_mode = len(all_pos_pairs) == 1

    # Prototypes of the full support set (also returned for visualisation).
    support_embeddings = torch.stack([pair_to_story_embedding[p] for p in support_pairs]).to(device)
    prototypes, cluster_info, labels = advanced_clustering_and_prototypes(
        support_embeddings, support_pairs, all_pos_pairs, disease_id,
        local_weight=params['LOCAL_WEIGHT']
    )
    prototypes = prototypes.to(device)

    if few_shot_mode:
        query_pos_pairs = list(all_pos_pairs)
    else:
        query_pos_pairs = [p for p in all_pos_pairs if p not in local_support_pairs]
        if not query_pos_pairs:
            print(f"  - 警告: 没有可用于查询的本地样本")
            # Still hand back the clustering so it can be visualised.
            return None, (prototypes, cluster_info, labels)

    # Negative pool: approved drugs that are not known positives and have an embedding.
    pos_drugs_in_task = {p[0] for p in all_pos_pairs}
    neg_drug_pool = all_approved_drugs - pos_drugs_in_task
    candidate_neg_pairs = [(d, disease_id) for d in neg_drug_pool
                           if (d, disease_id) in pair_to_story_embedding]

    # Cross-disease scores do not depend on the query pair, so compute them once.
    cross_scores = None
    if (query_pos_pairs and candidate_neg_pairs
            and ENABLE_CROSS_DISEASE and disease_embeddings and split_info):
        cross_scores = get_cross_disease_predictions(
            disease_id, query_pos_pairs[0], all_approved_drugs,
            disease_embeddings, disease_to_true_drugs,
            pair_to_story_embedding, split_info
        )

    metrics = []
    for query_pos_pair in query_pos_pairs:
        if not candidate_neg_pairs:
            break  # nothing to rank against

        if few_shot_mode:
            # Leave-one-out: the query must not contribute to its own prototypes.
            if single_sample_mode:
                temp_support_pairs = augmented_support
            else:
                temp_support_pairs = [p for p in all_pos_pairs if p != query_pos_pair] + augmented_support
            if not temp_support_pairs:
                continue

            temp_support_embeddings = torch.stack(
                [pair_to_story_embedding[p] for p in temp_support_pairs]
            ).to(device)
            query_prototypes, query_cluster_info, _ = advanced_clustering_and_prototypes(
                temp_support_embeddings, temp_support_pairs, all_pos_pairs, disease_id,
                local_weight=params['LOCAL_WEIGHT']
            )
            query_prototypes = query_prototypes.to(device)
        else:
            query_prototypes, query_cluster_info = prototypes, cluster_info

        metrics.append(_score_query(
            query_pos_pair, candidate_neg_pairs, query_prototypes, query_cluster_info,
            disease_id, params, pair_to_story_embedding, cross_scores
        ))

    # Average over all evaluated queries (zeros when none could be evaluated).
    avg_metrics = np.mean(metrics, axis=0) if metrics else np.zeros(4)

    result = {
        'disease': disease_name,
        'h1': avg_metrics[0], 'h5': avg_metrics[1],
        'h10': avg_metrics[2], 'mrr': avg_metrics[3],
        'num_eval_points': len(metrics)
    }

    return result, (prototypes, cluster_info, labels)

# --- 生成解释性预测 ---
def generate_explainable_predictions(disease_id, entity_to_name, drug_status_lookup,
                                    pair_to_story_embedding, pairkey_to_story_text,
                                    disease_to_true_drugs, all_approved_drugs,
                                    api_client, final_clustering_result=None,
                                    disease_embeddings=None, split_info=None):
    """Rank every approved drug for one disease and attach an explanation to each.

    Args:
        disease_id: Identifier of the disease being predicted for.
        entity_to_name: Mapping from entity id to display name; the id itself is
            used when a drug is missing from the mapping.
        drug_status_lookup: Accepted for interface compatibility; not read here.
        pair_to_story_embedding: Mapping ``(drug_id, disease_id) -> tensor``.
            Only drugs with an entry for this disease become candidates.
        pairkey_to_story_text: Mapping ``"drug|||disease" -> story text``.
        disease_to_true_drugs: Mapping from disease id to its known drugs, used
            to set ``is_true_positive``.
        all_approved_drugs: Iterable of candidate drug ids.
        api_client: Client handed to ``get_theme_from_api`` for cluster naming.
        final_clustering_result: ``(prototypes, cluster_info, labels)`` tuple.
            ``cluster_info`` is updated in place (tensors moved to the active
            device, ``theme_name`` added).
        disease_embeddings: Optional disease embeddings; together with
            ``split_info`` and ``ENABLE_CROSS_DISEASE`` enables the
            cross-disease score blend.
        split_info: Optional task-split information (see above).

    Returns:
        ``None`` when no clustering result is supplied, ``[]`` when no approved
        drug has a story embedding for this disease, otherwise a list of
        prediction dicts sorted by descending ``score`` with a 1-based ``rank``.
    """
    if not final_clustering_result:
        return None

    prototypes, cluster_info, _ = final_clustering_result

    # Collect candidates first so that no LLM/API or cross-disease work is spent
    # on a disease that cannot produce any prediction.
    all_candidates = [
        drug_id for drug_id in all_approved_drugs
        if (drug_id, disease_id) in pair_to_story_embedding
    ]
    if not all_candidates:
        return []

    # Keep prototypes and per-cluster prototype vectors on the same device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    prototypes = prototypes.to(device)
    for info in cluster_info.values():
        if "prototype_vector" in info:
            info["prototype_vector"] = info["prototype_vector"].to(device)

    # Name each cluster concurrently from (at most) its first five stories.
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        future_to_label = {}
        for label, info in cluster_info.items():
            stories = []
            for pair in info['stories'][:5]:
                story_key = f"{pair[0]}|||{pair[1]}"
                if story_key in pairkey_to_story_text:
                    stories.append(pairkey_to_story_text[story_key])
            if stories:
                future_to_label[executor.submit(get_theme_from_api, api_client, stories)] = label

        for future in concurrent.futures.as_completed(future_to_label):
            label = future_to_label[future]
            try:
                cluster_info[label]['theme_name'] = future.result()
            except Exception as e:
                # Best-effort: naming must never abort prediction, but report why it failed.
                print(f"  - 簇 {label} 主题命名失败: {e}")
                cluster_info[label]['theme_name'] = "Mechanism Theme Undetermined"

    # Optional cross-disease knowledge.
    # NOTE(review): the helper is called once per approved drug and the results
    # are merged; if its output does not depend on the pair argument this is
    # redundant work - confirm against get_cross_disease_predictions.
    cross_disease_scores = {}
    if ENABLE_CROSS_DISEASE and disease_embeddings and split_info:
        for drug_id in all_approved_drugs:
            temp_scores = get_cross_disease_predictions(
                disease_id, (drug_id, disease_id), all_approved_drugs,
                disease_embeddings, disease_to_true_drugs,
                pair_to_story_embedding, split_info
            )
            if temp_scores:
                cross_disease_scores.update(temp_scores)

    true_drug_set = disease_to_true_drugs.get(disease_id, set())

    candidate_embeddings = torch.stack([
        pair_to_story_embedding[(d, disease_id)].to(device) for d in all_candidates
    ])

    # Cosine similarity to every prototype: used both to pick the best-matching
    # cluster and as the fallback score, so compute it exactly once.
    candidate_norm = F.normalize(candidate_embeddings, p=2, dim=1)
    proto_norm = F.normalize(prototypes, p=2, dim=1)
    cos_sim = torch.mm(candidate_norm, proto_norm.t())
    best_proto_indices = torch.argmax(cos_sim, dim=1).tolist()

    try:
        # Parameters for a support size of 3, which the original author uses
        # as the standard configuration.
        params = adjust_parameters_for_sample_size(3)
        distances = multi_distance_scoring(
            candidate_embeddings, prototypes, cluster_info, disease_id,
            metric_weights=params['METRIC_WEIGHTS'],
            temperature=params['TEMPERATURE'],
            use_calibration=params['USE_CALIBRATION']
        )
        # Turn distances into similarity-style scores (higher is better).
        scores = 1.0 - distances
    except Exception as e:
        print(f"计算距离时出错: {e}")
        # Fall back to the best plain cosine similarity.
        scores, _ = torch.max(cos_sim, dim=1)

    # Resolve prototype index -> (cluster label, theme) once per distinct index
    # rather than once per candidate. Clusters without a prototype vector are
    # skipped instead of raising KeyError.
    proto_to_cluster = {}
    for proto_idx in set(best_proto_indices):
        target = prototypes[proto_idx]
        for label, info in cluster_info.items():
            cluster_vector = info.get('prototype_vector')
            if cluster_vector is None:
                continue
            if torch.allclose(cluster_vector, target, atol=1e-4):
                proto_to_cluster[proto_idx] = (label, info.get('theme_name', "未知机制"))
                break

    predictions = []
    for i, drug_id in enumerate(all_candidates):
        score = scores[i].item()

        # Blend in the cross-disease score when one exists for this drug.
        cross_score = 0.0
        if drug_id in cross_disease_scores:
            cross_score = cross_disease_scores[drug_id]
            score = score * (1.0 - CROSS_DISEASE_WEIGHT) + cross_score * CROSS_DISEASE_WEIGHT

        matched_cluster_label, matched_cluster_theme = proto_to_cluster.get(
            best_proto_indices[i], (-1, "未知机制")
        )

        predictions.append({
            "drug_id": drug_id,
            "drug_name": entity_to_name.get(drug_id, drug_id),
            "score": float(score),  # plain Python float so the result is JSON-serialisable
            "mechanism_story": pairkey_to_story_text.get(f"{drug_id}|||{disease_id}", ""),
            "matched_cluster_id": int(matched_cluster_label),
            "matched_cluster_theme": matched_cluster_theme,
            "is_true_positive": drug_id in true_drug_set,
            "cross_disease_score": float(cross_score)
        })

    # Highest score first, then attach 1-based ranks.
    predictions.sort(key=lambda x: x['score'], reverse=True)
    for position, pred in enumerate(predictions, start=1):
        pred['rank'] = position

    return predictions

# --- 可视化函数 ---
def _project_embeddings_2d(features):
    """Reduce ``features`` to 2-D; return ``(coordinates, method_label)``.

    t-SNE is attempted when there are at least five samples and PCA is used
    otherwise, or when t-SNE raises ``ValueError``.
    """
    n_samples = len(features)
    if n_samples < 5:
        return PCA(n_components=2).fit_transform(features), "PCA (样本数<5)"
    try:
        # Perplexity must stay below the sample count, otherwise t-SNE rejects it.
        perplexity = min(30, max(5, n_samples // 3), n_samples - 1)
        tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
        return tsne.fit_transform(features), "t-SNE"
    except ValueError:
        return PCA(n_components=2).fit_transform(features), "PCA (t-SNE失败)"


def visualize_clustering(support_embeddings, labels, disease_name, support_pairs, all_pos_pairs):
    """Plot the clustered support set of one disease in two dimensions.

    Args:
        support_embeddings: Tensor with one embedding row per support pair.
        labels: Cluster label per support pair; ``-1`` marks noise points.
        disease_name: Name shown in the figure title.
        support_pairs: Support pairs, aligned with ``support_embeddings``.
        all_pos_pairs: The disease's own positive pairs. Support pairs found
            here are drawn as "local" samples, all others as "augmented".

    The figure is shown and then always closed, so repeated calls do not
    accumulate open figures even when drawing fails.
    """
    fig = plt.figure(figsize=(12, 10))
    try:
        labels = np.asarray(labels)
        n_samples = len(support_embeddings)
        features = support_embeddings.detach().cpu().numpy()
        embeddings_2d, vis_method = _project_embeddings_2d(features)

        # One colour per real cluster; noise is black so it matches the legend.
        unique_labels = sorted(set(labels.tolist()))
        cluster_labels = [label for label in unique_labels if label != -1]
        colors = plt.cm.tab10(np.linspace(0, 1, max(10, len(unique_labels))))
        color_map = {label: colors[i % 10] for i, label in enumerate(cluster_labels)}
        color_map[-1] = 'black'

        # Local samples come from the disease itself, the rest were borrowed.
        is_local = [pair in all_pos_pairs for pair in support_pairs]

        # Centroid and size of every non-noise cluster in the 2-D projection.
        cluster_centers = {}
        cluster_sizes = {}
        for label in cluster_labels:
            mask = labels == label
            if np.any(mask):
                cluster_centers[label] = np.mean(embeddings_2d[mask], axis=0)
                cluster_sizes[label] = int(np.sum(mask))

        # Thin connector from each clustered sample to its centroid.
        for i, (x, y) in enumerate(embeddings_2d):
            label = int(labels[i])
            if label in cluster_centers:
                cx, cy = cluster_centers[label]
                plt.plot([x, cx], [y, cy], '-', color=color_map.get(label, 'black'),
                         alpha=0.2, linewidth=0.5)

        # Samples: large outlined circles for local pairs, small crosses for augmented ones.
        for i, (x, y) in enumerate(embeddings_2d):
            color = color_map.get(int(labels[i]), 'black')
            if is_local[i]:
                marker, size, edge_color = 'o', 120, 'black'
            else:
                marker, size, edge_color = 'x', 60, 'none'
            plt.scatter(x, y, c=[color], marker=marker, s=size, alpha=0.8,
                        edgecolors=edge_color, linewidths=1)

        # Centroids, annotated with the number of samples in the cluster.
        for label, (cx, cy) in cluster_centers.items():
            plt.scatter(cx, cy, marker='*', s=250, c=[color_map.get(label, 'black')],
                        edgecolors='black', linewidths=1.5)
            plt.annotate(f'n={cluster_sizes[label]}', (cx, cy), xytext=(8, 8),
                         textcoords='offset points', fontsize=10)

        # Legend: one entry per cluster, then the marker-type entries.
        legend_elements = []
        for label in unique_labels:
            if label == -1:
                legend_elements.append(plt.Line2D([0], [0], marker='o', color='w',
                                                  markerfacecolor='black', markersize=10,
                                                  label='噪声点'))
            else:
                legend_elements.append(plt.Line2D([0], [0], marker='o', color='w',
                                                  markerfacecolor=color_map.get(label, 'black'),
                                                  markersize=10, label=f'簇 {label}'))
        legend_elements.append(plt.Line2D([0], [0], marker='o', color='w',
                                          markerfacecolor='gray', markersize=10, label='本地样本'))
        legend_elements.append(plt.Line2D([0], [0], marker='x', color='gray',
                                          markersize=10, label='增强样本'))
        legend_elements.append(plt.Line2D([0], [0], marker='*', color='w',
                                          markerfacecolor='gold', markersize=15, label='簇中心'))

        plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, 1.0))
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.title(f'疾病 {disease_name} 的机制聚类 (样本数: {n_samples}, 方法: {vis_method})', fontsize=14)
        plt.tight_layout()
        plt.show()
    finally:
        # Release the figure even if projection or drawing raised.
        plt.close(fig)

def main():
    """Evaluate the enhanced multi-prototype model on the meta-test diseases.

    Steps:
      1. Load story embeddings, story texts, the task split, drug status and
         the knowledge graph, and build the positive drug-disease pair maps.
      2. Load the ComplEx model and extract an embedding per split disease.
      3. For each test disease, run ``NUM_EVAL_RUNS`` randomised evaluations
         and average Hit@1/5/10 and MRR.
      4. Build one final clustering from all local pairs plus the augmented
         support set, visualise it and generate explainable predictions.
      5. Display the score tables and write the predictions to
         ``EXPLAINABLE_RESULTS_FILE``.

    Returns early (after printing a message) when ``API_TOKEN`` is not set.
    """
    if not API_TOKEN:
        print("API Token未设置，程序已终止。请先在Colab密钥中设置您的Token。")
        return

    print("加载数据...")
    pair_to_story_embedding = torch.load(STORY_EMBEDDINGS_FILE)
    with open(LLM_STORIES_FILE, 'r', encoding='utf-8') as f:
        pairkey_to_story_text = json.load(f)
    with open(META_TASK_SPLIT_PATH, 'r', encoding='utf-8') as f:
        split_info = json.load(f)
    with open(DRUG_STATUS_PATH, 'r', encoding='utf-8') as f:
        drug_status_lookup = json.load(f)
    kg_df = pd.read_csv(KG_PATH, sep='\t', header=None, names=['head', 'relation', 'tail'], engine='python')
    entities = list(pd.concat([kg_df['head'], kg_df['tail']]).unique())
    relations = kg_df['relation'].unique()
    entity_to_idx = {name: i for i, name in enumerate(entities)}
    # str() guards against non-string entities (e.g. NaN from blank cells),
    # which would otherwise raise AttributeError on .split().
    entity_to_name = {e: str(e).split('::')[-1] for e in entities}

    all_approved_drugs = {d for d, s in drug_status_lookup.items() if s == 'approved'}

    # Known drugs per disease, taken from the positive treatment relations in the KG.
    treats_df = kg_df[kg_df['relation'].isin(POSITIVE_TREATMENT_RELATIONS)]
    disease_to_true_drugs = collections.defaultdict(set)
    for drug_id, disease_id in zip(treats_df['head'], treats_df['tail']):
        disease_to_true_drugs[disease_id].add(drug_id)

    # Original mapping: KG-derived positive pairs that have a story embedding.
    original_disease_to_pos_pairs = collections.defaultdict(list)
    for disease_id, drugs in disease_to_true_drugs.items():
        for drug_id in drugs:
            pair = (drug_id, disease_id)
            if pair in pair_to_story_embedding:
                original_disease_to_pos_pairs[disease_id].append(pair)

    # Enhanced mapping: pairs listed by the split file, when it provides them.
    print("检查和更新有效药物-疾病对...")
    enhanced_disease_to_pos_pairs = collections.defaultdict(list)
    pairs_missing_embedding = 0
    pairs_with_embedding = 0

    if 'disease_to_drugs' in split_info:
        for disease_id, drug_ids in split_info['disease_to_drugs'].items():
            for drug_id in drug_ids:
                pair = (drug_id, disease_id)
                if pair in pair_to_story_embedding:
                    enhanced_disease_to_pos_pairs[disease_id].append(pair)
                    pairs_with_embedding += 1
                else:
                    pairs_missing_embedding += 1

        print(f"从新划分文件构建药物-疾病对: {pairs_with_embedding} 个有嵌入，{pairs_missing_embedding} 个缺失嵌入")

        original_total = sum(len(pairs) for pairs in original_disease_to_pos_pairs.values())
        enhanced_total = sum(len(pairs) for pairs in enhanced_disease_to_pos_pairs.values())
        print(f"原始映射: {len(original_disease_to_pos_pairs)} 个疾病, {original_total} 个有效对")
        print(f"增强映射: {len(enhanced_disease_to_pos_pairs)} 个疾病, {enhanced_total} 个有效对")

        # Prefer the enhanced mapping unless it covers fewer pairs than the original.
        if enhanced_total >= original_total:
            disease_to_pos_pairs = enhanced_disease_to_pos_pairs
            print("成功更新药物-疾病对映射！")
        else:
            disease_to_pos_pairs = original_disease_to_pos_pairs
            print("保留原始药物-疾病对映射，因为它包含更多有效对。")
    else:
        disease_to_pos_pairs = original_disease_to_pos_pairs
        print("使用原始药物-疾病对映射（新划分文件缺少详细映射）。")

    # Report how many usable pairs each test disease has.
    test_disease_drug_counts = {
        disease_id: len(disease_to_pos_pairs.get(disease_id, []))
        for disease_id in split_info['meta_test']
    }

    print("\n测试集疾病的有效药物数量:")
    for disease_id, count in test_disease_drug_counts.items():
        disease_name = entity_to_name.get(disease_id, disease_id)
        print(f"  - {disease_name}: {count} 个有效药物")

    print("加载GNN模型以计算疾病嵌入相似度...")
    num_nodes, num_relations = len(entity_to_idx), len(relations)
    gnn_model = ComplEx(num_nodes, num_relations, hidden_channels=512).to(device)
    state_dict_gnn = torch.load(GNN_MODEL_PATH, map_location=device)
    # Strip the '_orig_mod.' prefix from checkpoint keys so they match the model.
    new_state_dict_gnn = OrderedDict(
        (k.replace('_orig_mod.', ''), v) for k, v in state_dict_gnn.items()
    )
    gnn_model.load_state_dict(new_state_dict_gnn)
    gnn_model.eval()

    disease_ids = split_info['meta_train'] + split_info['meta_test']
    disease_embeddings = {}
    with torch.no_grad():
        for d_id in disease_ids:
            if d_id in entity_to_idx:
                idx = torch.tensor([entity_to_idx[d_id]], device=device)
                disease_embeddings[d_id] = gnn_model.node_emb(idx).squeeze()

    print("数据加载和GNN嵌入提取完成。")

    print("初始化API客户端用于可解释性步骤...")
    api_client = OpenAI(api_key=API_TOKEN, base_url=BASE_URL)
    print("API客户端加载成功。")

    print("\n--- 开始在测试疾病上进行评估 (增强型支撑集 + 多级原型) ---")
    final_explainable_output = {}
    all_results_list = []

    for disease_id in tqdm(split_info['meta_test'], desc="评估测试疾病"):
        disease_name = entity_to_name.get(disease_id, disease_id)
        print(f"\n--- 正在处理疾病: {disease_name} ---")

        # Work on a copy: the list is shuffled below, and shuffling the object
        # stored in disease_to_pos_pairs would silently mutate shared state
        # that is also handed to the augmentation helper.
        all_pos_pairs = list(disease_to_pos_pairs.get(disease_id, []))
        print(f"  - 本地找到 {len(all_pos_pairs)} 个有效药物对。")

        # Evaluate even with very few samples; only skip when there are none.
        if not all_pos_pairs:
            print(f"  - 跳过: 没有本地有效药物对。")
            continue

        # With too few local samples, nearly all of them go into the support set.
        special_mode = len(all_pos_pairs) <= LOCAL_SUPPORT_SHOTS
        if special_mode:
            print(f"  - 注意: 本地样本数 ({len(all_pos_pairs)}) ≤ {LOCAL_SUPPORT_SHOTS}，使用全部样本作为支撑集")

        # --- 1. Repeat the randomised evaluation to obtain stable metrics ---
        disease_runs_results = []
        for _ in range(NUM_EVAL_RUNS):
            random.shuffle(all_pos_pairs)

            if special_mode:
                # Hold out the last pair as the query when there is more than one.
                local_support_pairs = all_pos_pairs[:-1] if len(all_pos_pairs) > 1 else all_pos_pairs
            else:
                local_support_pairs = all_pos_pairs[:LOCAL_SUPPORT_SHOTS]

            augmented_support, _ = augment_support_set_intelligently(
                disease_id, all_pos_pairs, disease_to_pos_pairs,
                disease_embeddings, split_info, pair_to_story_embedding,
                pairkey_to_story_text
            )

            result, _ = evaluate_with_enhancements(
                disease_id, disease_name, all_pos_pairs, local_support_pairs,
                augmented_support, pair_to_story_embedding, disease_to_true_drugs,
                all_approved_drugs, entity_to_name
            )

            if result:
                disease_runs_results.append(result)

        if not disease_runs_results:
            print(f"  - 无法计算评估指标，跳过疾病 {disease_name}。")
            continue

        # Average every metric over the successful runs.
        n_runs = len(disease_runs_results)
        avg_h1 = sum(r['h1'] for r in disease_runs_results) / n_runs
        avg_h5 = sum(r['h5'] for r in disease_runs_results) / n_runs
        avg_h10 = sum(r['h10'] for r in disease_runs_results) / n_runs
        avg_mrr = sum(r['mrr'] for r in disease_runs_results) / n_runs

        all_results_list.append({
            'disease': disease_name,
            'h1': avg_h1, 'h5': avg_h5, 'h10': avg_h10, 'mrr': avg_mrr
        })

        print(f"  - 平均性能: MRR = {avg_mrr:.4f}, Hit@1 = {avg_h1:.4f}, "
              f"Hit@5 = {avg_h5:.4f}, Hit@10 = {avg_h10:.4f}")

        # --- 2. Build one fixed clustering for the final predictions ---
        print(f"  - 为疾病 {disease_name} 生成最终预测...")

        # Every local pair is used as support for the final model.
        local_support_pairs = all_pos_pairs

        augmented_support, _ = augment_support_set_intelligently(
            disease_id, all_pos_pairs, disease_to_pos_pairs,
            disease_embeddings, split_info, pair_to_story_embedding,
            pairkey_to_story_text
        )

        support_pairs = local_support_pairs + augmented_support
        if not support_pairs:
            continue

        support_embeddings = torch.stack([pair_to_story_embedding[p] for p in support_pairs])

        prototypes, cluster_info, labels = enhanced_clustering_and_prototypes(
            support_embeddings, support_pairs, all_pos_pairs
        )

        # Visualisation is optional; a failure must not stop the predictions.
        try:
            visualize_clustering(support_embeddings, labels, disease_name, support_pairs, all_pos_pairs)
        except Exception as e:
            print(f"  - 可视化时出错: {e}")

        # NOTE(review): disease_embeddings / split_info are not forwarded here,
        # so the cross-disease blend inside generate_explainable_predictions
        # stays inactive - confirm whether that is intended.
        try:
            predictions = generate_explainable_predictions(
                disease_id, entity_to_name, drug_status_lookup,
                pair_to_story_embedding, pairkey_to_story_text,
                disease_to_true_drugs, all_approved_drugs,
                api_client, (prototypes, cluster_info, labels)
            )

            if predictions:
                final_explainable_output[disease_name] = predictions
        except Exception as e:
            print(f"  - 生成预测时出错: {e}")

    # --- 3. Display and persist the final results ---
    if all_results_list:
        results_df = pd.DataFrame(all_results_list).set_index('disease')
        print("\n--- 最终性能分数 (增强型多原型模型) ---")
        display(HTML(results_df.to_html()))

        avg_scores_df = results_df.mean(numeric_only=True).to_frame('Average Score').T
        print("\n--- 平均性能 ---")
        display(HTML(avg_scores_df.to_html()))
    else:
        print("评估失败或没有符合条件的测试疾病。")

    with open(EXPLAINABLE_RESULTS_FILE, 'w', encoding='utf-8') as f:
        json.dump(final_explainable_output, f, indent=2, ensure_ascii=False)
    print(f"\n详细的可解释性预测已保存到: {EXPLAINABLE_RESULTS_FILE}")
    print(f"\n--- 代码块3完成。 ---")


# Entry point: run the evaluation pipeline only when this code is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()