In [None]:
# Import libraries
import os
import sys
import json
import time
import math
import warnings
from typing import Tuple, Dict, List, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from matplotlib_venn import venn3

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    accuracy_score, roc_auc_score, f1_score, 
    confusion_matrix, roc_curve, precision_recall_curve, 
    average_precision_score
)
from sklearn.utils.class_weight import compute_class_weight
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC, SVC
from sklearn.pipeline import Pipeline
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.manifold import TSNE

import shap
from lime import lime_tabular

# Silence a torch_geometric UserWarning about `dropout_adj` (this filter only
# has an effect if a torch_geometric module emits that warning).
warnings.filterwarnings("ignore", category=UserWarning, message=".*dropout_adj.*", module="torch_geometric")

# Matplotlib font sizes and the seaborn "ticks" style.
# NOTE: only sizes are configured here; no CJK font family is set.
plt.rcParams.update({
    'font.size': 10,
    'axes.labelsize': 11,
    'axes.titlesize': 12,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'legend.fontsize': 9,
    'figure.titlesize': 13,
})
sns.set_style("ticks")

In [None]:
# Configuration and paths.
# Each path can be overridden with an environment variable so the notebook can
# run outside the original machine; the defaults are the original locations.
DATA_ROOT = os.environ.get("AD_SSL_ROOT", "/home/fujing/ad_ssl")
EXPR_FILE = os.environ.get("AD_SSL_EXPR_FILE", os.path.join(DATA_ROOT, "0_rawdata", "exp_mono.csv"))
LABEL_FILE = os.environ.get("AD_SSL_LABEL_FILE", os.path.join(DATA_ROOT, "0_rawdata", "ph_mono.csv"))
SAMPLE_ID_COL = "id"   # sample-ID column in LABEL_FILE
LABEL_COL = "group"    # label column in LABEL_FILE

# All model checkpoints and JSON summaries are written here.
OUT_DIR = os.environ.get("AD_SSL_OUT_DIR", os.path.join(DATA_ROOT, "3_ssl_model"))
os.makedirs(OUT_DIR, exist_ok=True)


In [None]:
# Hyper-parameter configuration (a single flat dict shared by all cells below).
CONFIG = dict(
    # --- general ---
    RANDOM_SEED=2025,
    N_FOLDS=5,                 # outer stratified CV folds
    MAX_EPOCHS=100,
    PATIENCE=30,               # epochs without validation-AUC improvement before early stopping

    # --- preprocessing ---
    SELECT_TOP_VAR=3000,       # number of most-variable genes to keep
    LOG1P=True,
    Z_SCORE_PER_GENE=True,
    Z_SCORE_PER_SAMPLE=True,   # NOTE(review): not read by any code visible in this notebook

    # --- data augmentation ---
    USE_MIXUP=True,
    MIXUP_ALPHA=0.2,
    USE_GAUSSIAN_NOISE=True,
    NOISE_STD=0.05,
    USE_FEATURE_DROPOUT=True,
    FEATURE_DROPOUT_RATE=0.1,

    # --- classification loss (focal loss wins over label smoothing when both are on) ---
    USE_FOCAL_LOSS=True,
    FOCAL_ALPHA=0.25,
    FOCAL_GAMMA=2.0,
    USE_LABEL_SMOOTHING=True,
    LABEL_SMOOTHING=0.1,

    # --- supervised contrastive loss ---
    SUPCON_WEIGHT=0.5,         # weight of the SupCon term added to the classification loss
    SUPCON_TEMP=0.05,          # SupCon temperature

    # --- optimisation ---
    LR=1e-3,
    MIN_LR=1e-6,               # cosine-annealing floor
    WEIGHT_DECAY=1e-5,
    BATCH_SIZE=64,
    WARMUP_EPOCHS=5,           # linear LR warmup before cosine annealing
    USE_SWA=False,             # NOTE(review): not read by any code visible in this notebook
    SWA_START=70,              # NOTE(review): not read by any code visible in this notebook

    # --- architecture ---
    HIDDEN_DIMS=[512, 256, 128, 64],
    DROPOUT=0.4,
    PROJ_DIM=256,              # output size of the contrastive projection head
    NORM_FEATS=True,           # L2-normalise projections inside the SupCon loss
    USE_BATCH_NORM=True,
)


In [None]:
# tools
def seed_everything(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU and
    CUDA generators, and switches cuDNN to deterministic, non-benchmark mode so
    repeated runs with the same seed give the same results.

    Args:
        seed: integer seed applied to every generator.
    """
    import random  # function-scope import: the shared import cell stays untouched

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # seeds every visible GPU, including the current one
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_device():
    """Return ``"cuda"`` when PyTorch can see a CUDA device, otherwise ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def _detect_delimiter(path: str):
    if path.endswith((".tsv", ".txt")):
        return "\t"
    return ","


def read_expression(expr_file: str) -> pd.DataFrame:
    """Read a genes x samples expression table into a numeric DataFrame.

    The delimiter is a tab for ``.tsv``/``.txt`` files and a comma otherwise.
    Double quotes and surrounding whitespace are stripped from the header and
    from the cells. If the first column is not purely numeric, it becomes the
    row index (gene IDs); otherwise it is kept as a data column. Cells that
    cannot be parsed as numbers become NaN. Rows that are entirely NaN are
    dropped, the remaining NaNs are set to 0, and rows whose total is not
    positive are removed.

    Args:
        expr_file: path to the expression table; its first row is the header.

    Returns:
        Numeric DataFrame with one row per retained gene and one column per sample.
    """
    sep = "\t" if expr_file.endswith((".tsv", ".txt")) else ","
    table = pd.read_csv(expr_file, sep=sep, header=0, dtype=str, low_memory=False)
    table.columns = table.columns.astype(str).str.strip().str.replace('"', "")

    id_col = table.columns[0]
    first_col_is_numeric = (
        table[id_col].str.replace('"', "").str.match(r"^-?\d+(\.\d+)?$").all()
    )
    if not first_col_is_numeric:
        table[id_col] = table[id_col].astype(str).str.strip().str.replace('"', "")
        table = table.set_index(id_col)

    for sample in table.columns:
        cleaned = table[sample].astype(str).str.strip().str.replace('"', "")
        table[sample] = pd.to_numeric(cleaned, errors="coerce")

    table = table.dropna(how="all").fillna(0.0)
    has_signal = table.sum(axis=1) > 0
    return table.loc[has_signal]


def read_labels(label_file: str, sample_col: str, label_col: str) -> pd.Series:
    """Read a sample sheet and return binary labels indexed by sample ID.

    Rows missing either the sample ID or the label are dropped. Labels are
    upper-cased and stripped, then mapped as follows:
    ``0``/``0.0``/HC/CTRL/CONTROL/NEG -> 0 and ``1``/``1.0``/AD/CASE/PATIENT/POS -> 1.

    Args:
        label_file: CSV/TSV path; the delimiter is chosen by ``_detect_delimiter``.
        sample_col: column holding sample IDs.
        label_col: column holding the group labels.

    Returns:
        ``pd.Series`` named ``"label"`` of ints, indexed by the stripped sample
        IDs as strings.

    Raises:
        KeyError: if ``sample_col`` or ``label_col`` is not in the file.
        ValueError: if a label value has no known mapping; the offending
            values are listed in the message.
    """
    tab = pd.read_csv(label_file, sep=_detect_delimiter(label_file), header=0)
    tab = tab[[sample_col, label_col]].dropna()

    # "0.0"/"1.0" appear when pandas parses a 0/1 column containing blanks as float.
    label_map = {
        "0": 0, "1": 1, "0.0": 0, "1.0": 1,
        "AD": 1, "CASE": 1, "PATIENT": 1, "POS": 1,
        "HC": 0, "CTRL": 0, "CONTROL": 0, "NEG": 0,
    }
    y_raw = tab[label_col].astype(str).str.upper().str.strip()
    y_num = y_raw.map(label_map)  # unmapped values become NaN
    unknown = sorted(y_raw[y_num.isna()].unique())
    if unknown:
        raise ValueError(f"Unrecognized values in label column '{label_col}': {unknown}")

    # Strip IDs the same way the expression headers are cleaned, so they match.
    sample_ids = tab[sample_col].astype(str).str.strip()
    return pd.Series(y_num.astype(int).values, index=sample_ids.values, name="label")


def align_expr_labels(expr: pd.DataFrame, labels: pd.Series=None):
    """Restrict expression and labels to their shared samples, in column order.

    Args:
        expr: expression matrix whose columns are sample IDs.
        labels: optional Series of labels indexed by sample ID. The caller's
            Series is not modified.

    Returns:
        ``(expr, y)``:
        - With ``labels``: ``expr`` has string column names and only the
          samples that have a label, in ``expr``'s original column order, and
          ``y`` holds the matching labels in the same order.
        - With ``labels`` None: ``expr`` is returned unchanged and ``y`` is
          all zeros.

    Raises:
        ValueError: if a shared sample ID appears more than once in ``labels``.
            Such labels cannot be aligned one-to-one with the columns.
    """
    samples = expr.columns.astype(str)
    if labels is None:
        y = pd.Series(np.zeros(len(samples), dtype=int), index=samples, name="label")
        return expr, y

    # Re-index copies instead of overwriting the caller's index in place.
    labels = labels.set_axis(labels.index.astype(str))
    expr = expr.set_axis(samples, axis=1)
    common = samples.intersection(labels.index)

    duplicated = labels.index[labels.index.duplicated()].intersection(common)
    if len(duplicated):
        raise ValueError(f"Duplicate sample IDs in labels: {list(duplicated)}")

    return expr[common], labels.loc[common]


def select_top_var_genes(expr: pd.DataFrame, n_top: int=None) -> pd.DataFrame:
    """Keep the ``n_top`` rows (genes) with the largest variance across columns.

    If ``n_top`` is falsy (None/0) or not smaller than the number of rows,
    ``expr`` itself is returned unchanged.
    """
    if n_top and n_top < expr.shape[0]:
        ranked = expr.var(axis=1).sort_values(ascending=False)
        return expr.loc[ranked.index[:n_top]]
    return expr


def preprocess_expr(expr: pd.DataFrame, log1p=True, zscore_per_gene=True, stats=None) -> pd.DataFrame:
    """Optionally log1p-transform, then optionally z-score each row (gene).

    Args:
        expr: genes x samples matrix; it is copied, never modified.
        log1p: apply ``np.log1p`` first.
        zscore_per_gene: centre and scale every row.
        stats: optional ``(mean, std)`` Series indexed like the rows. When
            omitted, they are computed from the (log-transformed) data with
            ``ddof=0``. A std of 0 is replaced by 1 to avoid division by zero.

    Returns:
        The transformed copy.
    """
    out = np.log1p(expr.copy()) if log1p else expr.copy()
    if not zscore_per_gene:
        return out

    if stats is None:
        center, scale = out.mean(axis=1), out.std(axis=1, ddof=0)
    else:
        center, scale = stats
    scale = scale.replace(0, 1.0)
    return out.sub(center, axis=0).div(scale, axis=0)


def load_and_preprocess_raw_data(expr_file, label_file, sample_id_col, label_col, config=CONFIG):
    """Load expression and labels and return model-ready arrays.

    Steps:
    1. Read the expression table and the labels.
    2. Keep the samples present in both.
    3. Keep the ``config["SELECT_TOP_VAR"]`` most variable genes; all genes
       are kept when the value is missing or falsy.
    4. Optionally ``log1p`` (``config["LOG1P"]``).
    5. Optionally z-score each gene across samples (``config["Z_SCORE_PER_GENE"]``).

    NOTE(review): the gene selection and z-score statistics use ALL samples,
    before any CV split is made by the callers. Labels are not involved, but
    held-out samples still shape the preprocessing. Confirm this transductive
    setup is intended.

    Args:
        expr_file: expression table path (see ``read_expression``).
        label_file: sample-sheet path (see ``read_labels``).
        sample_id_col: sample-ID column in ``label_file``.
        label_col: label column in ``label_file``.
        config: hyper-parameter dict.

    Returns:
        ``(features_matrix, labels_vector, gene_names)``: features of shape
        (n_samples, n_genes), labels of shape (n_samples,), and the gene
        identifiers (the processed expression row index) as a list.
    """
    expr_df = read_expression(expr_file)
    labels_series = read_labels(label_file, sample_id_col, label_col)
    expr_aligned, labels_aligned = align_expr_labels(expr_df, labels_series)

    # Honour the configured gene budget (previously never passed, so no filtering happened).
    expr_top_var = select_top_var_genes(expr_aligned, config.get("SELECT_TOP_VAR"))

    # Without explicit stats, preprocess_expr derives the per-gene mean/std
    # (ddof=0, zero std -> 1) from the log-transformed data.
    expr_processed = preprocess_expr(
        expr_top_var,
        log1p=config["LOG1P"],
        zscore_per_gene=config["Z_SCORE_PER_GENE"],
    )

    features_matrix = expr_processed.T.values   # (n_samples, n_genes)
    labels_vector = labels_aligned.values       # (n_samples,)
    gene_names = expr_processed.index.tolist()

    print(f"Final features matrix shape: {features_matrix.shape}")
    print(f"Final labels vector shape: {labels_vector.shape}")
    print(f"Unique labels: {np.unique(labels_vector)}")
    print(f"Number of genes: {len(gene_names)}")

    return features_matrix, labels_vector, gene_names

In [None]:
# define dataset, model, contrastive loss and evaluation
class GeneDataset(Dataset):
    """Map-style dataset over an in-memory feature matrix and integer labels.

    Args:
        features: array-like, expected shape (n_samples, n_features); stored as
            a float32 tensor.
        labels: array-like of length n_samples; stored as an int64 (``torch.long``) tensor.
    """

    def __init__(self, features, labels):
        self.features, self.labels = (
            torch.tensor(features, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.long),
        )

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        sample = self.features[idx]
        target = self.labels[idx]
        return sample, target


class GeneClassifier(nn.Module):
    """MLP classifier with an auxiliary projection head for contrastive learning.

    The feature extractor stacks, for every entry of ``hidden_dims``:
    ``Linear -> [BatchNorm1d] -> ReLU -> Dropout``. A linear ``classifier``
    maps the last hidden features to ``num_classes`` logits. A two-layer
    ``proj_head`` (``Linear -> ReLU -> Linear``) maps the same features to
    ``proj_dim`` dimensions.

    Args:
        input_dim: number of input features.
        hidden_dims: sizes of the hidden layers, in order.
        num_classes: number of output logits.
        dropout_rate: dropout probability after every hidden layer.
        use_batch_norm: insert ``BatchNorm1d`` after every hidden ``Linear``.
        proj_dim: output size of the projection head.
    """

    def __init__(self, input_dim, hidden_dims, num_classes=2,
                 dropout_rate=0.3, use_batch_norm=True, proj_dim=256):
        super().__init__()

        layers = []
        in_features = input_dim
        for out_features in hidden_dims:
            block = [nn.Linear(in_features, out_features)]
            if use_batch_norm:
                block.append(nn.BatchNorm1d(out_features))
            block += [nn.ReLU(), nn.Dropout(dropout_rate)]
            layers.extend(block)
            in_features = out_features
        self.feature_extractor = nn.Sequential(*layers)
        self.feat_dim = in_features

        self.classifier = nn.Linear(self.feat_dim, num_classes)
        self.proj_head = nn.Sequential(
            nn.Linear(self.feat_dim, self.feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.feat_dim, proj_dim),
        )

    def forward(self, x, return_feats=False, return_proj=False):
        """Return the logits by default.

        ``return_proj=True`` returns ``(logits, feats, projection)``; this
        takes precedence over ``return_feats``. ``return_feats=True`` alone
        returns ``(logits, feats)``.
        """
        feats = self.feature_extractor(x)
        logits = self.classifier(feats)
        if not (return_proj or return_feats):
            return logits
        if return_proj:
            return logits, feats, self.proj_head(feats)
        return logits, feats


class SupConLoss(nn.Module):
    """Supervised contrastive loss over a single batch of embeddings.

    Every sample is an anchor. Other samples with the same label are its
    positives, and all other samples form the softmax denominator. The anchor's
    loss is the mean negative log-probability of its positives. Anchors without
    any positive are excluded from the batch average.

    Args:
        temperature: divisor applied to the pairwise similarities.
        eps: small constant guarding the divisions.
        normalize: L2-normalise the features first (cosine similarity).
    """

    def __init__(self, temperature=0.07, eps=1e-8, normalize=True):
        super().__init__()
        self.tau = temperature
        self.eps = eps
        self.normalize = normalize

    def forward(self, features: torch.Tensor, labels: torch.Tensor):
        n = features.size(0)
        dev = features.device
        if self.normalize:
            features = torch.nn.functional.normalize(features, dim=1)

        eye = torch.eye(n, device=dev)
        # Subtract a huge constant on the diagonal so a sample never matches itself
        # (a finite value keeps 0 * log_prob well-defined, unlike -inf).
        sim = (features @ features.t()) / self.tau
        sim = sim - eye * 1e9

        lbl = labels.contiguous().view(-1, 1)
        pos_mask = torch.eq(lbl, lbl.t()).float().to(dev) - eye

        log_prob = torch.nn.functional.log_softmax(sim, dim=1)
        n_pos = pos_mask.sum(dim=1)
        per_anchor = -(pos_mask * log_prob).sum(dim=1) / (n_pos + self.eps)
        has_pos = (n_pos > 0).float()
        return (per_anchor * has_pos).sum() / (has_pos.sum() + self.eps)


def evaluate_model(model, dataloader, device, criterion):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    The model is switched to eval mode and left in eval mode.

    Args:
        model: module returning logits of shape (batch, n_classes). Column 1 is
            treated as the positive class (binary task assumed).
        dataloader: yields ``(features, labels)`` batches.
        device: device the batches are moved to.
        criterion: loss applied to ``(logits, labels)``. Assumed to return a
            per-batch mean (as ``nn.CrossEntropyLoss()`` does by default).

    Returns:
        ``(avg_loss, acc, auc, f1, preds, labels, probs)``.
        - ``avg_loss`` is averaged per sample.
        - ``auc``/``f1`` are 0.0 when only one class is present or they cannot
          be computed.
        - An empty dataloader gives NaN metrics and empty lists.
    """
    model.eval()
    if len(dataloader) == 0:
        print("Warning: DataLoader is empty in evaluate function.")
        return float('nan'), float('nan'), float('nan'), float('nan'), [], [], []

    loss_sum = 0.0
    n_samples = 0
    all_preds, all_labels, all_probs = [], [], []

    with torch.no_grad():
        for batch_features, batch_labels in dataloader:
            batch_features, batch_labels = batch_features.to(device), batch_labels.to(device)
            outputs = model(batch_features)
            batch_size = batch_labels.size(0)
            # Weight each batch-mean loss by its size, so a short final batch
            # does not count as much as a full one (matches train_model).
            loss_sum += criterion(outputs, batch_labels).item() * batch_size
            n_samples += batch_size

            probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()  # positive-class probability
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(batch_labels.cpu().numpy())
            all_probs.extend(probs)

    avg_loss = loss_sum / max(1, n_samples)
    acc = accuracy_score(all_labels, all_preds)

    auc = 0.0
    f1 = 0.0
    try:
        if len(np.unique(all_labels)) > 1:
            auc = roc_auc_score(all_labels, all_probs)
            f1 = f1_score(all_labels, all_preds)
        else:
            print(f"Warning: Only one unique label ({np.unique(all_labels)}) in evaluation batch. Cannot calculate AUC/F1.")
    except ValueError as e:
        print(f"Could not calculate AUC/F1: {e}")

    return avg_loss, acc, auc, f1, all_preds, all_labels, all_probs

In [None]:
# data augmentation and classification losses
class DataAugmentation:
    """Stochastic input-space augmentations for batches of feature vectors.

    ``apply_augmentation`` reads these optional keys from ``config``:
    ``USE_GAUSSIAN_NOISE``, ``NOISE_STD``, ``USE_FEATURE_DROPOUT`` and
    ``FEATURE_DROPOUT_RATE``. Mixup is exposed separately via ``mixup``.
    Randomness comes from NumPy's and PyTorch's global RNGs.
    """

    def __init__(self, config):
        self.config = config

    def mixup(self, x, y, alpha=0.2):
        """Blend each sample with a randomly permuted partner from the same batch.

        Returns ``(mixed_x, y_a, y_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y`` and
        ``y_b`` is ``y[perm]``. ``lam`` is drawn from Beta(alpha, alpha); when
        ``alpha <= 0`` it is 1 (no mixing).
        """
        lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
        perm = torch.randperm(x.size(0)).to(x.device)
        mixed_x = lam * x + (1 - lam) * x[perm]
        return mixed_x, y, y[perm], lam

    def gaussian_noise(self, x, std=0.05):
        """Add zero-mean Gaussian noise with standard deviation ``std``."""
        return x + torch.randn_like(x) * std

    def feature_dropout(self, x, dropout_rate=0.1):
        """Zero each entry independently with probability ``dropout_rate`` (no rescaling)."""
        keep = torch.bernoulli(torch.full_like(x, 1 - dropout_rate))
        return x * keep

    def apply_augmentation(self, x, training=True):
        """Return ``x`` unchanged when not training.

        Otherwise, each enabled augmentation (noise, then feature dropout) is
        applied independently with probability 0.5.
        """
        if not training:
            return x
        cfg = self.config
        if cfg.get("USE_GAUSSIAN_NOISE", False) and np.random.rand() > 0.5:
            x = self.gaussian_noise(x, cfg.get("NOISE_STD", 0.05))
        if cfg.get("USE_FEATURE_DROPOUT", False) and np.random.rand() > 0.5:
            x = self.feature_dropout(x, cfg.get("FEATURE_DROPOUT_RATE", 0.1))
        return x


class FocalLoss(nn.Module):
    """Multi-class focal loss: ``alpha * (1 - p_t) ** gamma * CE`` per sample.

    ``p_t`` is the softmax probability of the true class, recovered as
    ``exp(-CE)``. ``alpha`` is one scalar applied to every sample; it is not a
    per-class weight, so it only rescales the loss.

    Args:
        alpha: global scale factor.
        gamma: focusing exponent; 0 reduces to ``alpha * CE``.
        reduction: ``'mean'``, ``'sum'``, or anything else for per-sample losses.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce = nn.functional.cross_entropy(inputs, targets, reduction='none')
        p_true = torch.exp(-ce)
        per_sample = self.alpha * (1 - p_true) ** self.gamma * ce

        reducers = {'mean': torch.mean, 'sum': torch.sum}
        reduce_fn = reducers.get(self.reduction)
        return per_sample if reduce_fn is None else reduce_fn(per_sample)


class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The true class gets probability ``1 - smoothing``, and the remaining
    ``smoothing`` is split evenly over the other ``n_classes - 1`` classes.
    With ``weight``, each class column of the per-sample loss is multiplied by
    its weight before summing. The result is the plain mean over the batch
    (not normalised by the weights).

    Args:
        smoothing: probability mass moved away from the true class.
        weight: optional 1-D tensor of per-class weights.
    """

    def __init__(self, smoothing=0.1, weight=None):
        super().__init__()
        self.smoothing = smoothing
        self.weight = weight

    def forward(self, pred, target):
        num_classes = pred.size(-1)
        log_probs = nn.functional.log_softmax(pred, dim=-1)

        with torch.no_grad():
            off_value = self.smoothing / (num_classes - 1)
            soft_targets = torch.full_like(log_probs, off_value)
            soft_targets.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)

        per_class = -soft_targets * log_probs
        if self.weight is not None:
            per_class = per_class * self.weight.unsqueeze(0)
        return per_class.sum(dim=-1).mean()


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss: ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b


In [None]:
# define train_model
def _build_ce_criterion(config, class_weights, device):
    """Return the classification loss selected by ``config``.

    Priority: focal loss, then label-smoothed CE, then plain weighted CE.
    The last two use ``class_weights[0]``/``class_weights[1]``; the focal-loss
    branch does not use ``class_weights``.
    """
    if config["USE_FOCAL_LOSS"]:
        return FocalLoss(alpha=config["FOCAL_ALPHA"], gamma=config["FOCAL_GAMMA"])
    cw = torch.tensor([class_weights[0], class_weights[1]], dtype=torch.float, device=device)
    if config["USE_LABEL_SMOOTHING"]:
        return LabelSmoothingCrossEntropy(smoothing=config["LABEL_SMOOTHING"], weight=cw)
    return nn.CrossEntropyLoss(weight=cw)


def train_model(train_loader, val_loader, input_dim, num_classes, device, 
                class_weights, config=CONFIG):
    """Train a GeneClassifier with a classification loss plus a SupCon term.

    Per batch:
    - With ``USE_MIXUP``, about half of the batches are mixed up.
    - The noise / feature-dropout augmentations are applied.
    - The loss is ``ce_loss + SUPCON_WEIGHT * supcon_loss``.
    - Gradients are clipped to norm 5.

    Schedule and model selection:
    - The LR ramps up linearly for WARMUP_EPOCHS epochs, then follows cosine
      annealing down to MIN_LR.
    - After each epoch the model is scored on ``val_loader`` with unweighted
      CE. Training stops once PATIENCE epochs pass without a validation-AUC
      improvement.
    - At the end, the weights from the best-AUC epoch are restored.

    Args:
        train_loader: DataLoader of ``(features, labels)`` used for training.
        val_loader: DataLoader used for early stopping / model selection.
        input_dim: number of input features.
        num_classes: number of output logits.
        device: torch device string.
        class_weights: mapping with keys 0 and 1 (see ``_build_ce_criterion``).
        config: hyper-parameter dict.

    Returns:
        ``(model, best_val_loss, best_val_auc)``: the model with best-AUC
        weights loaded, plus the validation loss and AUC of that epoch. If
        validation AUC never exceeded 0, the last-epoch weights are kept and
        ``best_val_loss`` is ``inf``.
    """
    model = GeneClassifier(
        input_dim=input_dim, 
        hidden_dims=config["HIDDEN_DIMS"], 
        num_classes=num_classes, 
        dropout_rate=config["DROPOUT"], 
        use_batch_norm=config["USE_BATCH_NORM"], 
        proj_dim=config["PROJ_DIM"]
    ).to(device)

    augmenter = DataAugmentation(config)
    ce_criterion = _build_ce_criterion(config, class_weights, device)
    supcon_criterion = SupConLoss(temperature=config["SUPCON_TEMP"], normalize=config["NORM_FEATS"])
    val_criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=config["LR"], weight_decay=config["WEIGHT_DECAY"])

    def warmup_lambda(epoch):
        # Linear ramp: LR * (epoch + 1) / WARMUP_EPOCHS until the warmup ends.
        if epoch < config["WARMUP_EPOCHS"]:
            return (epoch + 1) / config["WARMUP_EPOCHS"]
        return 1.0

    warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lambda)
    cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        # max(1, ...) keeps T_max positive even when WARMUP_EPOCHS >= MAX_EPOCHS.
        T_max=max(1, config["MAX_EPOCHS"] - config["WARMUP_EPOCHS"]),
        eta_min=config["MIN_LR"],
    )

    best_val_loss = float('inf')
    best_val_auc = 0.0
    epochs_without_improvement = 0
    best_model_state = None

    for epoch in range(1, config["MAX_EPOCHS"] + 1):
        model.train()
        train_loss_sum = 0.0
        train_correct = 0
        train_count = 0

        for batch_features, batch_labels in train_loader:
            bs = batch_labels.size(0)
            if bs < 2 and config["USE_BATCH_NORM"]:
                # BatchNorm1d raises on a single-sample batch in training mode
                # (e.g. a trailing batch of size 1); skip it instead of crashing.
                continue
            batch_features = batch_features.to(device)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()

            if config["USE_MIXUP"] and np.random.rand() > 0.5:
                mixed_x, y_a, y_b, lam = augmenter.mixup(
                    batch_features, batch_labels, alpha=config["MIXUP_ALPHA"]
                )
                mixed_x = augmenter.apply_augmentation(mixed_x, training=True)
                logits, _, z = model(mixed_x, return_proj=True)
                ce_loss = mixup_criterion(ce_criterion, logits, y_a, y_b, lam)
                # NOTE(review): SupCon pairs mixed projections with the un-mixed
                # labels here — confirm this is intended.
                sc_loss = supcon_criterion(z, batch_labels)
            else:
                aug_x = augmenter.apply_augmentation(batch_features, training=True)
                logits, _, z = model(aug_x, return_proj=True)
                ce_loss = ce_criterion(logits, batch_labels)
                sc_loss = supcon_criterion(z, batch_labels)

            loss = ce_loss + config["SUPCON_WEIGHT"] * sc_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            train_loss_sum += loss.item() * bs
            # For mixed batches this is measured against the un-mixed labels.
            train_correct += (logits.argmax(dim=1) == batch_labels).sum().item()
            train_count += bs

        avg_train_loss = train_loss_sum / max(1, train_count)
        train_acc = train_correct / max(1, train_count)

        val_loss, val_acc, val_auc, val_f1, _, _, _ = evaluate_model(
            model, val_loader, device, val_criterion
        )

        if epoch <= config["WARMUP_EPOCHS"]:
            warmup_scheduler.step()
        else:
            cosine_scheduler.step()

        current_lr = optimizer.param_groups[0]['lr']

        if epoch % 5 == 0 or epoch == 1:
            print(f"[Epoch {epoch:03d}] LR: {current_lr:.6f} | "
                  f"TrainLoss: {avg_train_loss:.4f}, TrainAcc: {train_acc:.3f} | "
                  f"ValLoss: {val_loss:.4f}, ValAcc: {val_acc:.3f}, ValAUC: {val_auc:.3f}")

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_val_loss = val_loss
            epochs_without_improvement = 0
            # Snapshot on CPU so later updates don't overwrite the best weights.
            best_model_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= config["PATIENCE"]:
            print(f"Early stopping triggered after {epoch} epochs.")
            break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, best_val_loss, best_val_auc


In [None]:
# define run_improved_classification
def run_improved_classification():
    """Cross-validate the GeneClassifier pipeline with a nested validation split.

    For each of N_FOLDS stratified outer folds:
    1. Hold out the test fold.
    2. Carve a stratified validation split (the first split of an inner
       5-fold) from the remaining samples, for early stopping.
    3. Train with ``train_model`` and evaluate on the test fold.

    Fold weights are saved to ``OUT_DIR/improved_model_fold_{k}.pth``. The JSON
    summary is written to ``OUT_DIR/improved_classification_results.json``.

    Returns:
        dict with ``"config"``, per-fold ``"fold_results"`` and ``"overall_metrics"``
        (mean and population std of test accuracy, AUC and F1).
    """
    seed_everything(CONFIG["RANDOM_SEED"])
    device = get_device()

    features, labels, _ = load_and_preprocess_raw_data(EXPR_FILE, LABEL_FILE, SAMPLE_ID_COL, LABEL_COL)
    classes = np.unique(labels)
    n_classes = len(classes)

    skf = StratifiedKFold(n_splits=CONFIG["N_FOLDS"], shuffle=True, random_state=CONFIG["RANDOM_SEED"])
    fold_results = []

    for fold, (train_idx, test_idx) in enumerate(skf.split(features, labels), 1):
        X_train_fold, X_test_fold = features[train_idx], features[test_idx]
        y_train_fold, y_test_fold = labels[train_idx], labels[test_idx]

        skf_inner = StratifiedKFold(n_splits=5, shuffle=True, random_state=CONFIG["RANDOM_SEED"] + fold)
        inner_train_idx, inner_val_idx = next(skf_inner.split(X_train_fold, y_train_fold))

        X_tr, X_val = X_train_fold[inner_train_idx], X_train_fold[inner_val_idx]
        y_tr, y_val = y_train_fold[inner_train_idx], y_train_fold[inner_val_idx]

        # Class weights come from the training split only, so the held-out
        # test fold's labels never influence the loss.
        fold_weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_tr)
        class_weight_dict = {int(cls): float(weight) for cls, weight in zip(classes, fold_weights)}

        train_loader = DataLoader(GeneDataset(X_tr, y_tr), batch_size=CONFIG["BATCH_SIZE"], shuffle=True)
        val_loader = DataLoader(GeneDataset(X_val, y_val), batch_size=CONFIG["BATCH_SIZE"], shuffle=False)
        test_loader = DataLoader(GeneDataset(X_test_fold, y_test_fold), batch_size=CONFIG["BATCH_SIZE"], shuffle=False)

        start_time = time.time()
        trained_model, best_val_loss, best_val_auc = train_model(
            train_loader, val_loader, features.shape[1], n_classes, device, class_weight_dict
        )
        training_time = time.time() - start_time

        test_loss, test_acc, test_auc, test_f1, test_preds, test_labels, test_probs = evaluate_model(
            trained_model, test_loader, device, nn.CrossEntropyLoss()
        )

        fold_results.append({
            "fold": fold,
            "test_loss": test_loss,
            "test_acc": test_acc,
            "test_auc": test_auc,
            "test_f1": test_f1,
            "best_val_auc": best_val_auc,
            "training_time": training_time,
            "y_true": list(map(int, test_labels)),
            "y_prob": list(map(float, test_probs)),
            "y_pred": list(map(int, test_preds))
        })

        fold_model_path = os.path.join(OUT_DIR, f"improved_model_fold_{fold}.pth")
        torch.save(trained_model.state_dict(), fold_model_path)

    test_accs = [r["test_acc"] for r in fold_results]
    test_aucs = [r["test_auc"] for r in fold_results]
    test_f1s = [r["test_f1"] for r in fold_results]

    mean_acc, std_acc = np.mean(test_accs), np.std(test_accs)
    mean_auc, std_auc = np.mean(test_aucs), np.std(test_aucs)
    mean_f1, std_f1 = np.mean(test_f1s), np.std(test_f1s)

    print(f"Accuracy:  {mean_acc:.4f} ± {std_acc:.4f}")
    print(f"AUC:       {mean_auc:.4f} ± {std_auc:.4f}")
    print(f"F1-Score:  {mean_f1:.4f} ± {std_f1:.4f}")

    summary = {
        "config": CONFIG,
        "fold_results": fold_results,
        "overall_metrics": {
            "mean_acc": mean_acc,
            "std_acc": std_acc,
            "mean_auc": mean_auc,
            "std_auc": std_auc,
            "mean_f1": mean_f1,
            "std_f1": std_f1
        }
    }

    summary_path = os.path.join(OUT_DIR, "improved_classification_results.json")
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)

    return summary

In [None]:
# define train_ensemble_models
def train_ensemble_models(n_models=10):
    """Cross-validate a seed ensemble of GeneClassifier models.

    For each of N_FOLDS stratified outer folds:
    - Train ``n_models`` members. Each has its own seed and its own stratified
      validation split (the first split of an inner 5-fold) carved from the
      training part.
    - Average the members' positive-class probabilities on the test fold and
      threshold at 0.5.

    Every member is saved to ``OUT_DIR/ensemble_fold{k}_model{i}.pth``. The
    summary is written to ``OUT_DIR/ensemble_{n_models}models_results.json``.

    Args:
        n_models: number of ensemble members per fold.

    Returns:
        dict with ``"n_models"``, per-fold ``"fold_results"`` and
        ``"overall_metrics"`` (mean and population std over folds).
    """
    device = get_device()

    features, labels, _ = load_and_preprocess_raw_data(EXPR_FILE, LABEL_FILE, SAMPLE_ID_COL, LABEL_COL)
    classes = np.unique(labels)
    n_classes = len(classes)

    skf = StratifiedKFold(n_splits=CONFIG["N_FOLDS"], shuffle=True, random_state=CONFIG["RANDOM_SEED"])
    ensemble_fold_results = []

    for fold, (train_idx, test_idx) in enumerate(skf.split(features, labels), 1):
        X_train_fold, X_test_fold = features[train_idx], features[test_idx]
        y_train_fold, y_test_fold = labels[train_idx], labels[test_idx]

        # The test fold is the same for every member, so build its loader once.
        test_loader = DataLoader(GeneDataset(X_test_fold, y_test_fold), batch_size=CONFIG["BATCH_SIZE"], shuffle=False)

        fold_model_probs = []
        for model_idx in range(n_models):
            member_seed = CONFIG["RANDOM_SEED"] + fold * 100 + model_idx
            seed_everything(member_seed)

            skf_inner = StratifiedKFold(n_splits=5, shuffle=True, random_state=member_seed)
            inner_train_idx, inner_val_idx = next(skf_inner.split(X_train_fold, y_train_fold))

            X_tr, X_val = X_train_fold[inner_train_idx], X_train_fold[inner_val_idx]
            y_tr, y_val = y_train_fold[inner_train_idx], y_train_fold[inner_val_idx]

            # Class weights from this member's training split only (no test-fold labels).
            member_weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_tr)
            class_weight_dict = {int(cls): float(weight) for cls, weight in zip(classes, member_weights)}

            train_loader = DataLoader(GeneDataset(X_tr, y_tr), batch_size=CONFIG["BATCH_SIZE"], shuffle=True)
            val_loader = DataLoader(GeneDataset(X_val, y_val), batch_size=CONFIG["BATCH_SIZE"], shuffle=False)

            trained_model, _, _ = train_model(
                train_loader, val_loader, features.shape[1], n_classes, device, class_weight_dict
            )
            _, _, _, _, _, _, test_probs = evaluate_model(
                trained_model, test_loader, device, nn.CrossEntropyLoss()
            )
            fold_model_probs.append(test_probs)

            # Save immediately instead of keeping every member in memory until the fold ends.
            model_path = os.path.join(OUT_DIR, f"ensemble_fold{fold}_model{model_idx + 1}.pth")
            torch.save(trained_model.state_dict(), model_path)

        ensemble_probs = np.mean(fold_model_probs, axis=0)
        ensemble_preds = (ensemble_probs > 0.5).astype(int)

        test_acc = accuracy_score(y_test_fold, ensemble_preds)
        test_auc = roc_auc_score(y_test_fold, ensemble_probs)
        test_f1 = f1_score(y_test_fold, ensemble_preds)

        print(f"- Ensemble AUC: {test_auc:.4f}")
        print(f"- Ensemble Acc: {test_acc:.4f}")
        print(f"- Ensemble F1:  {test_f1:.4f}")

        individual_aucs = [roc_auc_score(y_test_fold, probs) for probs in fold_model_probs]

        ensemble_fold_results.append({
            "fold": fold,
            "test_acc": test_acc,
            "test_auc": test_auc,
            "test_f1": test_f1,
            "individual_aucs": individual_aucs,
            "ensemble_prob": ensemble_probs.tolist(),
            "ensemble_pred": ensemble_preds.tolist(),
            "y_true": y_test_fold.tolist()
        })

    ensemble_aucs = [r["test_auc"] for r in ensemble_fold_results]
    ensemble_accs = [r["test_acc"] for r in ensemble_fold_results]
    ensemble_f1s = [r["test_f1"] for r in ensemble_fold_results]

    mean_auc, std_auc = np.mean(ensemble_aucs), np.std(ensemble_aucs)
    mean_acc, std_acc = np.mean(ensemble_accs), np.std(ensemble_accs)
    mean_f1, std_f1 = np.mean(ensemble_f1s), np.std(ensemble_f1s)

    print(f"Ensemble Accuracy:  {mean_acc:.4f} ± {std_acc:.4f}")
    print(f"Ensemble AUC:       {mean_auc:.4f} ± {std_auc:.4f}")
    print(f"Ensemble F1-Score:  {mean_f1:.4f} ± {std_f1:.4f}")

    summary = {
        "n_models": n_models,
        "fold_results": ensemble_fold_results,
        "overall_metrics": {
            "mean_acc": mean_acc,
            "std_acc": std_acc,
            "mean_auc": mean_auc,
            "std_auc": std_auc,
            "mean_f1": mean_f1,
            "std_f1": std_f1
        }
    }

    summary_path = os.path.join(OUT_DIR, f"ensemble_{n_models}models_results.json")
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)

    return summary


In [None]:
# Visualization module - plotting functions
def _plot_metric_panel(ax, x, models, fold_results_dict, means, stds, fold_key,
                       color, title, ylabel, ylim, rng):
    """Draw one panel: mean ± sd bars, jittered per-fold points and value labels."""
    ax.bar(x, means, yerr=stds, capsize=6, alpha=0.9, color=color)

    for j, model in enumerate(models):
        if model not in fold_results_dict:
            continue
        vals = [fr[fold_key] for fr in fold_results_dict[model]["fold_results"]
                if fr.get(fold_key) is not None]
        jitter = rng.uniform(-0.06, 0.06, size=len(vals))
        ax.scatter(np.full(len(vals), x[j]) + jitter, vals, color="gray", alpha=0.75, s=20, zorder=3)

    for j, (m, s) in enumerate(zip(means, stds)):
        if np.isnan(m):
            continue
        offset = 0.0 if np.isnan(s) else s  # keep the label on-axis when sd is undefined
        ax.text(x[j], m + offset + 0.02, f"{m:.3f}±{s:.3f}", ha="center", va="bottom", fontsize=8)

    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels(models, rotation=30, ha="right")
    ax.set_ylabel(ylabel)
    ax.set_ylim(*ylim)


def plot_three_metrics_separate_subplots(
    df_summary: pd.DataFrame, 
    fold_results_dict: Dict,
    figsize: Tuple[int, int] = (12, 4.5),
    colors: Dict[str, str] = None,
    save_path: Optional[str] = None,
    show: bool = True,
    ylims: Optional[Dict[str, Tuple[float, float]]] = None,
) -> Tuple[plt.Figure, Tuple[plt.Axes, plt.Axes, plt.Axes]]:
    """Bar charts of Accuracy, AUC and F1 (mean ± sd) with per-fold points.

    Args:
        df_summary: one row per model (the index is used as the model name),
            with columns ``mean_acc``/``std_acc``, ``mean_auc``/``std_auc`` and
            ``mean_f1``/``std_f1``.
        fold_results_dict: ``{model_name: {"fold_results": [...]}}``. Each fold
            dict may hold ``test_acc``/``test_auc``/``test_f1``; missing or
            None values are skipped.
        figsize: figure size in inches.
        colors: bar colours keyed by ``"acc"``/``"auc"``/``"f1"``. Missing keys
            fall back to the defaults.
        save_path: if given, the figure is saved there at 300 dpi.
        show: call ``plt.show()`` at the end.
        ylims: per-metric y-limits keyed like ``colors``. Missing keys fall
            back to the defaults acc (0.5, 0.75), auc (0.6, 0.8), f1 (0, 0.7).

    Returns:
        ``(fig, (ax_acc, ax_auc, ax_f1))`` for further tweaking in the notebook.
    """
    colors = {"acc": "#4C72B0", "auc": "#55A868", "f1": "#C44E52", **(colors or {})}
    ylims = {"acc": (0.5, 0.75), "auc": (0.6, 0.8), "f1": (0, 0.7), **(ylims or {})}

    models = df_summary.index.tolist()
    x = np.arange(len(models))
    # Local, seeded generator: reproducible jitter that leaves the global NumPy RNG untouched.
    rng = np.random.default_rng(0)

    panels = [("acc", "Accuracy", "Accuracy"), ("auc", "AUC", "AUC"), ("f1", "F1-Score", "F1")]
    fig, axes = plt.subplots(1, 3, figsize=figsize)
    for ax, (key, title, ylabel) in zip(axes, panels):
        _plot_metric_panel(
            ax, x, models, fold_results_dict,
            df_summary[f"mean_{key}"].values, df_summary[f"std_{key}"].values,
            f"test_{key}", colors[key], title, ylabel, ylims[key], rng,
        )

    fig.suptitle("Model Performance (mean ± sd with per-fold points)", fontsize=14, y=1.02)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, bbox_inches="tight", dpi=300)
    if show:
        plt.show()

    return fig, tuple(axes)


def _collect_curves_from_folds(fold_pack):
    """收集所有折的y_true和y_prob对"""
    pairs = []
    for fr in fold_pack.get("fold_results", []):
        y_true = fr.get("y_true", None)
        y_prob = fr.get("y_prob") or fr.get("ensemble_prob")
        if y_true is None or y_prob is None:
            continue
        if len(set(y_true)) < 2:
            continue
        if len(y_true) != len(y_prob):
            continue
        pairs.append((np.array(y_true, dtype=int), np.array(y_prob, dtype=float)))
    return pairs


def _mean_std_curve_roc(pairs, n_points=101):
    """Average per-fold ROC curves on a shared, evenly spaced FPR grid.

    Args:
        pairs: list of (y_true, y_score) arrays, one entry per fold.
        n_points: number of FPR grid points spanning [0, 1].

    Returns:
        None for an empty `pairs`; otherwise a dict with the grid ("x"), the
        fold-wise mean and std of the interpolated TPR ("mean", "std") and the
        mean/std of the per-fold AUC ("mean_auc", "std_auc"). Std values use
        numpy's default population std (ddof=0).
    """
    if not pairs:
        return None

    grid = np.linspace(0, 1, n_points)
    curves, aucs = [], []
    for truth, scores in pairs:
        fpr, tpr, _thresholds = roc_curve(truth, scores)
        curve = np.interp(grid, fpr, tpr)
        curve[0] = 0.0  # pin every fold's curve to the origin
        curves.append(curve)
        aucs.append(roc_auc_score(truth, scores))

    stacked = np.vstack(curves)
    return {
        "x": grid,
        "mean": stacked.mean(axis=0),
        "std": stacked.std(axis=0),
        "mean_auc": float(np.mean(aucs)),
        "std_auc": float(np.std(aucs)),
    }


def _mean_std_curve_pr(pairs, n_points=101):
    """Average per-fold precision-recall curves on a shared recall grid.

    Args:
        pairs: list of (y_true, y_score) arrays, one entry per fold.
        n_points: number of recall grid points spanning [0, 1].

    Returns:
        None for an empty `pairs`; otherwise a dict with the grid ("x"), the
        fold-wise mean and std of the interpolated precision ("mean", "std")
        and the mean/std of the per-fold average precision ("mean_ap",
        "std_ap"). Std values use numpy's default ddof=0.
    """
    if not pairs:
        return None
    recall_grid = np.linspace(0, 1, n_points)
    prec_mat = []
    ap_list = []
    for y, p in pairs:
        precision, recall, _ = precision_recall_curve(y, p)
        # precision_recall_curve returns recall in decreasing order, but
        # np.interp needs increasing x. A *stable* sort keeps tied recall
        # values in a deterministic order; the default quicksort gives no
        # such guarantee, so the interpolated curve was not reproducible.
        order = np.argsort(recall, kind="stable")
        prec_interp = np.interp(recall_grid, recall[order], precision[order])
        prec_mat.append(prec_interp)
        ap_list.append(average_precision_score(y, p))
    prec_mat = np.vstack(prec_mat)
    return {
        "x": recall_grid,
        "mean": prec_mat.mean(axis=0),
        "std": prec_mat.std(axis=0),
        "mean_ap": float(np.mean(ap_list)),
        "std_ap": float(np.std(ap_list)),
    }


def plot_mean_roc_curve(
    fold_results_dict: Dict,
    figsize: Tuple[int, int] = (5, 5),
    save_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, plt.Axes]:
    """Plot each model's fold-averaged ROC curve with a ±1 sd band.

    Models without any usable fold are left out of the figure. A dashed
    diagonal marks the chance level.

    Args:
        fold_results_dict: {model name: {"fold_results": [...]}}.
        figsize: figure size in inches.
        save_path: if given, the figure is written there at 300 dpi.
        show: call plt.show() when True.

    Returns:
        (fig, ax)
    """
    curves = {}
    for name, pack in fold_results_dict.items():
        fold_pairs = _collect_curves_from_folds(pack)
        if fold_pairs:
            curves[name] = _mean_std_curve_roc(fold_pairs)

    fig, ax = plt.subplots(figsize=figsize)
    for name, stats in curves.items():
        if stats is None:
            continue
        grid, mean_tpr, sd_tpr = stats["x"], stats["mean"], stats["std"]
        legend_label = f"{name} (AUC {stats['mean_auc']:.3f}±{stats['std_auc']:.3f})"
        ax.plot(grid, mean_tpr, label=legend_label, linewidth=2)
        band_low = np.maximum(mean_tpr - sd_tpr, 0)
        band_high = np.minimum(mean_tpr + sd_tpr, 1)
        ax.fill_between(grid, band_low, band_high, alpha=0.2)

    ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Random")
    ax.set_xlabel("False Positive Rate", fontsize=11)
    ax.set_ylabel("True Positive Rate", fontsize=11)
    ax.set_title("Mean ROC Curves (mean ± 1 sd)", fontsize=12)
    ax.legend(loc="lower right", fontsize=9)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, bbox_inches="tight", dpi=300)
    if show:
        plt.show()

    return fig, ax


def plot_mean_pr_curve(
    fold_results_dict: Dict,
    figsize: Tuple[int, int] = (5, 5),
    save_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, plt.Axes]:
    """Plot each model's fold-averaged precision-recall curve with a ±1 sd band.

    The dashed "Random" reference line is the positive-class rate. It is
    pooled over the folds of the first model that has usable folds.

    Args:
        fold_results_dict: {model name: {"fold_results": [...]}}.
        figsize: figure size in inches.
        save_path: if given, the figure is written there at 300 dpi.
        show: call plt.show() when True.

    Returns:
        (fig, ax)
    """
    curves = {}
    baseline_rate = None
    for name, pack in fold_results_dict.items():
        fold_pairs = _collect_curves_from_folds(pack)
        if not fold_pairs:
            continue
        if baseline_rate is None:
            pooled_truth = np.concatenate([truth for truth, _ in fold_pairs])
            baseline_rate = float(pooled_truth.mean())
        curves[name] = _mean_std_curve_pr(fold_pairs)

    fig, ax = plt.subplots(figsize=figsize)
    for name, stats in curves.items():
        if stats is None:
            continue
        grid, mean_prec, sd_prec = stats["x"], stats["mean"], stats["std"]
        legend_label = f"{name} (AP {stats['mean_ap']:.3f}±{stats['std_ap']:.3f})"
        ax.plot(grid, mean_prec, label=legend_label, linewidth=2)
        band_low = np.maximum(mean_prec - sd_prec, 0)
        band_high = np.minimum(mean_prec + sd_prec, 1)
        ax.fill_between(grid, band_low, band_high, alpha=0.2)

    if baseline_rate is not None:
        ax.hlines(baseline_rate, 0, 1, linestyles="--", color="gray", label=f"Random (AP={baseline_rate:.3f})")
    ax.set_xlabel("Recall", fontsize=11)
    ax.set_ylabel("Precision", fontsize=11)
    ax.set_title("Mean PR Curves (mean ± 1 sd)", fontsize=12)
    ax.legend(loc="lower left", fontsize=9)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, bbox_inches="tight", dpi=300)
    if show:
        plt.show()

    return fig, ax


In [None]:
# 可视化模块 - 混淆矩阵和t-SNE可视化
def _compute_confusion_stats(fold_results_dict, target_model, normalize=True, quiet=False):
    """Mean and std of one model's per-fold 2x2 confusion matrices.

    Args:
        fold_results_dict: {model name: {"fold_results": [...]}}. Each fold
            dict may hold "y_true" and either "y_pred" or, when that is
            missing or empty, "ensemble_pred".
        target_model: key of the model to summarise.
        normalize: when True, each row (true class) is divided by its total.
            Rows with no samples stay all-zero instead of dividing by zero.
        quiet: suppress the printed messages for missing data.

    Returns:
        (mean_cm, std_cm) as 2x2 arrays (labels ordered [0, 1]). Returns
        (None, None) when the model is unknown or no fold is usable. A fold
        is unusable if labels/predictions are missing or their lengths differ.
    """
    if target_model not in fold_results_dict:
        if not quiet:
            print(f"Model '{target_model}' not found.")
        return None, None

    folds = fold_results_dict[target_model].get("fold_results", [])
    cm_list = []

    for fr in folds:
        y_true = fr.get("y_true")
        # Explicit fallback instead of `a or b`, which raises on numpy arrays.
        y_pred = fr.get("y_pred")
        if y_pred is None or len(y_pred) == 0:
            y_pred = fr.get("ensemble_pred")
        if y_true is None or y_pred is None:
            continue
        if len(y_true) != len(y_pred):
            # Skip malformed folds, consistent with _collect_curves_from_folds.
            continue
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        if normalize:
            row_sum = cm.sum(axis=1, keepdims=True)
            row_sum[row_sum == 0] = 1  # class absent from this fold -> row of zeros
            cm = cm.astype("float") / row_sum
        cm_list.append(cm)

    if not cm_list:
        if not quiet:
            print(f"No valid confusion matrices found for '{target_model}'.")
        return None, None

    cm_array = np.stack(cm_list)
    return cm_array.mean(axis=0), cm_array.std(axis=0)


def plot_confusion_matrix_single(
    fold_results_dict: Dict,
    target_model: str = "DL+SupCon",
    class_names: Tuple[str, str] = ("Negative", "Positive"),
    normalize: bool = True,
    figsize: Tuple[int, int] = (5, 4),
    cmap: str = "Blues",
    save_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, plt.Axes]:
    """Plot one model's fold-averaged confusion matrix, annotated as mean ± sd.

    Args:
        fold_results_dict: {model name: {"fold_results": [...]}}.
        target_model: model to plot.
        class_names: tick labels for class 0 and class 1.
        normalize: row-normalise each fold's matrix before averaging.
        figsize, cmap: figure size and colormap.
        save_path: if given, the figure is written there at 300 dpi.
        show: call plt.show() when True.

    Returns:
        (fig, ax), or (None, None) when no confusion matrix could be computed.
    """
    mean_cm, std_cm = _compute_confusion_stats(fold_results_dict, target_model, normalize=normalize)
    if mean_cm is None:
        return None, None

    fig, ax = plt.subplots(figsize=figsize)
    image = ax.imshow(mean_cm, interpolation="nearest", cmap=cmap)
    ax.set_title(f"Mean Confusion Matrix ({target_model})", fontsize=13)
    fig.colorbar(image, ax=ax)

    positions = np.arange(len(class_names))
    ax.set_xticks(positions)
    ax.set_xticklabels(class_names)
    ax.set_yticks(positions)
    ax.set_yticklabels(class_names)
    ax.set_xlabel("Predicted label", fontsize=11)
    ax.set_ylabel("True label", fontsize=11)

    # np.ndindex walks the cells in row-major order (row, col).
    for row, col in np.ndindex(mean_cm.shape):
        ax.text(
            col, row, f"{mean_cm[row, col]:.2f}\n±{std_cm[row, col]:.2f}",
            ha="center", va="center", color="black",
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="none", alpha=0.8),
            fontsize=10,
        )

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor='white', edgecolor='none')
    if show:
        plt.show()

    return fig, ax


def _draw_mean_confusion_panel(ax, mean_cm, std_cm, class_names, cmap, vmin, vmax,
                               title, title_fontsize, text_fontsize):
    """Draw one mean±sd confusion-matrix heatmap on `ax` and return its image."""
    im = ax.imshow(mean_cm, interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax)
    ax.set_title(title, fontsize=title_fontsize)

    tick_marks = np.arange(len(class_names))
    ax.set_xticks(tick_marks)
    ax.set_yticks(tick_marks)
    ax.set_xticklabels(class_names)
    ax.set_yticklabels(class_names)
    ax.set_xlabel("Predicted", fontsize=10)
    ax.set_ylabel("True", fontsize=10)

    for i, j in np.ndindex(mean_cm.shape):
        ax.text(j, i, f"{mean_cm[i, j]:.2f}\n±{std_cm[i, j]:.2f}",
                ha="center", va="center",
                color="black",
                bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="none", alpha=0.8),
                fontsize=text_fontsize)
    return im


def plot_confusion_matrices_grid(
    fold_results_dict: Dict,
    target_models: Optional[List[str]] = None,
    class_names: Tuple[str, str] = ("Negative", "Positive"),
    normalize: bool = True,
    n_cols: int = 3,
    figsize_per_panel: Tuple[float, float] = (4.2, 4.2),
    cmap: str = "Blues",
    save_path: Optional[str] = None,
    save_individual_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, np.ndarray]:
    """Plot several models' mean confusion matrices on one figure.

    All panels, and the optional per-model PDFs, share one colour scale. It is
    [0, 1] when normalised, otherwise the global min/max count. The single
    colorbar is therefore valid for every panel. Previously each imshow was
    autoscaled on its own and the colorbar only described the last panel.

    Args:
        fold_results_dict: {model name: {"fold_results": [...]}}.
        target_models: models to plot (default: every key, in dict order).
            Models without usable folds are skipped with a printed warning.
        class_names: tick labels for class 0 and class 1.
        normalize: row-normalise each fold's matrix before averaging.
        n_cols: maximum number of panels per row.
        figsize_per_panel: size of each panel in inches.
        cmap: colormap name.
        save_path: if given, the grid figure is written there at 300 dpi.
        save_individual_path: if given, one PDF per model is written into this
            directory (created if needed) and those figures are closed.
        show: call plt.show() when True.

    Returns:
        (fig, axes) with axes of shape (n_rows, n_cols), or (None, None) if no
        model had a valid confusion matrix.
    """
    if target_models is None:
        target_models = list(fold_results_dict.keys())

    valid_results = []
    for model_name in target_models:
        mean_cm, std_cm = _compute_confusion_stats(
            fold_results_dict, model_name, normalize=normalize, quiet=True)
        if mean_cm is not None:
            valid_results.append((model_name, mean_cm, std_cm))
        else:
            print(f"⚠️ 跳过模型 '{model_name}'，缺少有效的折结果。")

    if not valid_results:
        print("未收集到任何有效的混淆矩阵。")
        return None, None

    # Shared colour limits so colours are comparable across panels.
    if normalize:
        vmin, vmax = 0.0, 1.0
    else:
        vmin = min(float(cm.min()) for _, cm, _ in valid_results)
        vmax = max(float(cm.max()) for _, cm, _ in valid_results)
    cbar_label = "Normalized" if normalize else "Count"

    n_models = len(valid_results)
    n_cols = max(1, min(n_cols, n_models))
    n_rows = math.ceil(n_models / n_cols)

    fig_width = figsize_per_panel[0] * n_cols
    fig_height = figsize_per_panel[1] * n_rows
    # squeeze=False always yields a 2-D (n_rows, n_cols) array of axes.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), squeeze=False)

    im = None
    for idx, (model_name, mean_cm, std_cm) in enumerate(valid_results):
        r, c = divmod(idx, n_cols)
        im = _draw_mean_confusion_panel(
            axes[r, c], mean_cm, std_cm, class_names, cmap, vmin, vmax,
            title=model_name, title_fontsize=12, text_fontsize=9)

    for idx in range(n_models, n_rows * n_cols):
        r, c = divmod(idx, n_cols)
        axes[r, c].axis("off")

    fig.suptitle("Mean Confusion Matrices", fontsize=14, y=0.995)
    # Lay out first: a colorbar that steals space from a list of axes is not
    # compatible with tight_layout, so it is attached afterwards.
    fig.tight_layout(rect=[0, 0, 1, 0.97])
    if im is not None:
        cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.85)
        cbar.set_label(cbar_label, fontsize=10)

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor='white', edgecolor='none')

    # Save one figure per model, using the same colour scale as the grid.
    if save_individual_path:
        os.makedirs(save_individual_path, exist_ok=True)
        for model_name, mean_cm, std_cm in valid_results:
            fig_individual, ax_individual = plt.subplots(figsize=(5, 4))
            im_individual = _draw_mean_confusion_panel(
                ax_individual, mean_cm, std_cm, class_names, cmap, vmin, vmax,
                title=f"{model_name} Confusion Matrix", title_fontsize=12, text_fontsize=10)
            cbar_individual = fig_individual.colorbar(im_individual, ax=ax_individual, shrink=0.8)
            cbar_individual.set_label(cbar_label, fontsize=10)
            fig_individual.tight_layout()

            # Make the model name filesystem-safe.
            safe_model_name = model_name.replace("/", "_").replace("\\", "_").replace(" ", "_")
            individual_save_path = os.path.join(save_individual_path, f"confusion_matrix_{safe_model_name}.pdf")
            fig_individual.savefig(individual_save_path, dpi=300, bbox_inches="tight", facecolor='white', edgecolor='none')
            plt.close(fig_individual)  # release memory; only the grid stays open

    if show:
        plt.show()

    return fig, axes


def plot_tsne_comparison(
    features_raw: np.ndarray,
    features_learned: np.ndarray,
    labels: np.ndarray,
    random_state: int = 2025,
    figsize: Tuple[int, int] = (10, 4),
    save_path: Optional[str] = None,
    show: bool = True,
    perplexity: float = 30.0,
) -> Tuple[plt.Figure, np.ndarray]:
    """Side-by-side t-SNE embeddings of raw features vs. learned features.

    Raw features are standardised per column (StandardScaler) before t-SNE.
    Learned features are embedded as-is. Points are coloured by class label.

    Args:
        features_raw: sample-by-feature input matrix.
        features_learned: sample-by-dimension model representation.
        labels: class label per sample (any array-like).
        random_state: seed for both t-SNE runs.
        figsize: figure size in inches.
        save_path: if given, the figure is written there at 300 dpi.
        show: call plt.show() when True.
        perplexity: t-SNE perplexity (default 30, as before). sklearn
            requires perplexity < n_samples, so it is capped at
            n_samples - 1 for small cohorts instead of raising.

    Returns:
        fig, (ax_raw, ax_learned)
    """
    labels = np.asarray(labels)  # `labels == label` masks need an ndarray
    n_samples = np.shape(features_raw)[0]
    effective_perplexity = min(float(perplexity), max(1.0, float(n_samples - 1)))

    # Raw features
    features_scaled = StandardScaler().fit_transform(features_raw)
    raw_2d = TSNE(n_components=2, random_state=random_state,
                  perplexity=effective_perplexity).fit_transform(features_scaled)

    # Learned features
    learned_2d = TSNE(n_components=2, random_state=random_state,
                      perplexity=effective_perplexity).fit_transform(features_learned)

    def _scatter_by_class(ax, embedding, title):
        # One scatter call per class so each class gets its own legend entry.
        for label in np.unique(labels):
            idx = labels == label
            ax.scatter(embedding[idx, 0], embedding[idx, 1], label=f"Class {label}", s=15, alpha=0.7)
        ax.set_title(title, fontsize=12)
        ax.set_xlabel("Dim 1", fontsize=11)
        ax.set_ylabel("Dim 2", fontsize=11)
        ax.legend(fontsize=9)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    _scatter_by_class(ax1, raw_2d, "t-SNE of Raw Gene Expression Features")
    _scatter_by_class(ax2, learned_2d, "t-SNE of Learned Representation (DL+SupCon)")

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor='white', edgecolor='none')
    if show:
        plt.show()

    return fig, (ax1, ax2)


In [None]:
# 可解释性分析可视化函数
def plot_shap_summary(
    shap_values: np.ndarray,
    test_data: np.ndarray,
    gene_names: Optional[List[str]] = None,
    max_display: int = 20,
    plot_type: str = "summary",
    figsize: Tuple[int, int] = (6, 5),
    save_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Draw a SHAP summary (beeswarm) or bar plot.

    Args:
        shap_values: SHAP value array (n_samples, n_features)
        test_data: data the SHAP values were computed on (n_samples, n_features)
        gene_names: feature names
        max_display: maximum number of features to show
        plot_type: "summary" or "bar"
        figsize: figure size. It is forwarded to SHAP as ``plot_size``
            because shap.summary_plot otherwise resizes the figure to its
            own "auto" size and silently discards ``figsize``.
        save_path: optional output path
        show: whether to display the figure

    Returns:
        fig, ax: the figure SHAP drew on and its current axes

    Raises:
        ValueError: if plot_type is neither "summary" nor "bar". Previously
            an unknown value silently produced an empty figure.
    """
    if plot_type not in ("summary", "bar"):
        raise ValueError(f"plot_type must be 'summary' or 'bar', got {plot_type!r}")

    # shap.summary_plot has no `ax` argument; it draws on the current figure,
    # so create that figure first.
    plt.figure(figsize=figsize)
    type_kwargs = {"plot_type": "bar"} if plot_type == "bar" else {}
    shap.summary_plot(
        shap_values, test_data,
        feature_names=gene_names,
        max_display=max_display,
        plot_size=figsize,
        show=False,
        **type_kwargs,
    )

    # Retrieve the figure/axes SHAP drew on.
    fig = plt.gcf()
    ax = plt.gca()

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor='white', edgecolor='none')
    if show:
        plt.show()
    # The figure object is returned even when show=False so it can be edited further.

    return fig, ax


def plot_shap_top_features(
    shap_importance: pd.DataFrame,
    top_n: int = 20,
    figsize: Tuple[int, int] = (6, 5),
    color: str = "#4C72B0",
    save_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, plt.Axes]:
    """Horizontal bar chart of the first `top_n` rows of a SHAP importance table.

    Args:
        shap_importance: DataFrame with 'feature' and 'mean_abs_shap' columns,
            assumed to be sorted already (rows are plotted in table order).
        top_n: maximum number of features to plot. Fewer are plotted, without
            error, when the table is shorter.
        figsize, color: figure size and bar colour.
        save_path: if given, the figure is written there at 300 dpi.
        show: call plt.show() when True.

    Returns:
        (fig, ax)
    """
    top_features = shap_importance.head(top_n)
    # Use the real row count: range(top_n) mismatched the bar heights (and
    # raised) whenever the table had fewer than top_n rows.
    n_shown = len(top_features)
    positions = np.arange(n_shown)

    fig, ax = plt.subplots(figsize=figsize)
    ax.barh(positions, top_features['mean_abs_shap'].values, color=color)
    ax.set_yticks(positions)
    ax.set_yticklabels(top_features['feature'].values, fontsize=9)
    ax.set_xlabel('Mean Absolute SHAP Value', fontsize=11)
    ax.set_title(f'Top {n_shown} Features by SHAP', fontsize=12)
    ax.invert_yaxis()  # most important feature on top
    ax.grid(True, alpha=0.3, axis='x')
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor='white', edgecolor='none')
    if show:
        plt.show()

    return fig, ax


def plot_lime_top_features(
    lime_importance: pd.DataFrame,
    top_n: int = 20,
    figsize: Tuple[int, int] = (6, 5),
    color: str = "#55A868",
    save_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, plt.Axes]:
    """Horizontal bar chart of the first `top_n` rows of a LIME importance table.

    Args:
        lime_importance: DataFrame with 'feature' and 'mean_abs_importance'
            columns, assumed to be sorted already (rows are plotted in order).
        top_n: maximum number of features to plot. Fewer are plotted, without
            error, when the table is shorter.
        figsize, color: figure size and bar colour.
        save_path: if given, the figure is written there at 300 dpi.
        show: call plt.show() when True.

    Returns:
        (fig, ax)
    """
    top_features = lime_importance.head(top_n)
    # Use the real row count so bar positions and heights always match.
    n_shown = len(top_features)
    positions = np.arange(n_shown)

    fig, ax = plt.subplots(figsize=figsize)
    ax.barh(positions, top_features['mean_abs_importance'].values, color=color)
    ax.set_yticks(positions)
    ax.set_yticklabels(top_features['feature'].values, fontsize=9)
    ax.set_xlabel('Mean Absolute LIME Importance', fontsize=11)
    ax.set_title(f'Top {n_shown} Features by LIME', fontsize=12)
    ax.invert_yaxis()  # most important feature on top
    ax.grid(True, alpha=0.3, axis='x')
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor='white', edgecolor='none')
    if show:
        plt.show()

    return fig, ax


def plot_lime_heatmap(
    feature_importance_matrix: np.ndarray,
    gene_names: Optional[List[str]] = None,
    top_n_genes: int = 30,
    n_samples: int = 50,
    figsize: Tuple[int, int] = (8, 6),
    cmap: str = 'RdBu_r',
    save_path: Optional[str] = None,
    show: bool = True,
    random_state: Optional[int] = None,
) -> Tuple[plt.Figure, plt.Axes]:
    """Heatmap of per-sample LIME weights for the most important genes.

    Args:
        feature_importance_matrix: (n_explained_samples, n_genes) LIME
            weights. Signed values are expected, hence the zero-centred
            diverging colormap.
        gene_names: optional name per gene column (list, ndarray or Index).
            Falls back to "Gene_<idx>" when None or empty.
        top_n_genes: number of genes (rows) to show, ranked by mean
            absolute weight.
        n_samples: number of explained samples (columns) drawn at random
            without replacement.
        figsize, cmap: figure size and colormap.
        save_path: if given, the figure is written there at 300 dpi.
        show: call plt.show() when True.
        random_state: seed for the sample draw. None keeps the previous
            behaviour of using numpy's global RNG.

    Returns:
        (fig, ax), or (None, None) when every weight is zero (or the matrix
        is empty).
    """
    importance = np.asarray(feature_importance_matrix, dtype=float)
    # np.any instead of max()==0: a matrix of non-positive weights with one
    # zero entry is not "all zero".
    if not np.any(importance):
        print("警告: 所有特征重要性为0，无法绘制热图")
        return None, None

    # Rank by mean |weight|: LIME weights are signed, and ranking by the
    # signed mean would never select genes that consistently push towards
    # the other class.
    mean_abs_importance = np.abs(importance).mean(axis=0)
    top_features_idx = np.argsort(mean_abs_importance)[::-1][:max(int(top_n_genes), 0)]

    n_samples_actual = min(n_samples, importance.shape[0])
    rng = np.random if random_state is None else np.random.default_rng(random_state)
    sample_indices = rng.choice(importance.shape[0], n_samples_actual, replace=False)

    heatmap_data = importance[sample_indices][:, top_features_idx]
    # Explicit None/len check: `if gene_names` raises for ndarray / pd.Index.
    has_names = gene_names is not None and len(gene_names) > 0
    top_gene_names = [
        gene_names[i] if has_names else f"Gene_{i}"
        for i in top_features_idx
    ]

    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(
        heatmap_data.T,
        yticklabels=top_gene_names,
        xticklabels=False,
        cmap=cmap,
        center=0,
        cbar_kws={'label': 'LIME Importance'},
        ax=ax
    )
    ax.set_xlabel('Samples', fontsize=11)
    ax.set_ylabel('Genes', fontsize=11)
    ax.set_title(f'LIME: Top {len(top_features_idx)} Genes Importance Heatmap', fontsize=12)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor='white', edgecolor='none')
    if show:
        plt.show()

    return fig, ax


def plot_ablation_top_features(
    ablation_importance: pd.DataFrame,
    top_n: int = 20,
    figsize: Tuple[int, int] = (6, 5),
    color: str = "coral",
    save_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, plt.Axes]:
    """Horizontal bar chart of the first `top_n` rows of a feature-ablation table.

    Args:
        ablation_importance: DataFrame with 'feature' and
            'mean_abs_prob_change' columns, assumed to be sorted already
            (rows are plotted in table order).
        top_n: maximum number of features to plot. Fewer are plotted, without
            error, when the table is shorter.
        figsize, color: figure size and bar colour.
        save_path: if given, the figure is written there at 300 dpi.
        show: call plt.show() when True.

    Returns:
        (fig, ax)
    """
    top_features = ablation_importance.head(top_n)
    # Use the real row count so bar positions and heights always match.
    n_shown = len(top_features)
    positions = np.arange(n_shown)

    fig, ax = plt.subplots(figsize=figsize)
    ax.barh(positions, top_features['mean_abs_prob_change'].values, color=color)
    ax.set_yticks(positions)
    ax.set_yticklabels(top_features['feature'].values, fontsize=9)
    ax.set_xlabel('Mean Absolute Probability Change', fontsize=11)
    ax.set_title(f'Top {n_shown} Features by Feature Ablation', fontsize=12)
    ax.invert_yaxis()  # most important feature on top
    ax.grid(True, alpha=0.3, axis='x')
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor='white', edgecolor='none')
    if show:
        plt.show()

    return fig, ax


def plot_ablation_distribution(
    feature_importance_scores: np.ndarray,
    figsize: Tuple[int, int] = (6, 4),
    color: str = "coral",
    bins: int = 50,
    save_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, plt.Axes]:
    """Histogram of per-feature ablation scores with a dashed line at their mean.

    Args:
        feature_importance_scores: one absolute probability change per feature.
        figsize, color, bins: figure size, bar colour and histogram bin count.
        save_path: if given, the figure is written there at 300 dpi.
        show: call plt.show() when True.

    Returns:
        (fig, ax)
    """
    average = feature_importance_scores.mean()

    fig, ax = plt.subplots(figsize=figsize)
    ax.hist(feature_importance_scores, bins=bins, alpha=0.7, color=color, edgecolor='black')
    ax.set_xlabel('Absolute Probability Change', fontsize=11)
    ax.set_ylabel('Number of Features', fontsize=11)
    ax.set_title('Distribution of Feature Importance (Ablation Method)', fontsize=12)
    ax.axvline(average, color='red', linestyle='--', linewidth=2, label=f'Mean: {average:.4f}')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor='white', edgecolor='none')
    if show:
        plt.show()

    return fig, ax


def plot_interpretability_comparison_top_features(
    merged_df: pd.DataFrame,
    top_n: int = 30,
    figsize: Tuple[int, int] = (14, 5),
    colors: Dict[str, str] = None,
    save_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, np.ndarray]:
    """Three side-by-side bar charts of normalised importance for the first `top_n` rows.

    Args:
        merged_df: DataFrame with 'feature', 'shap_norm', 'lime_norm' and
            'ablation_norm' columns. Rows are plotted in table order, so sort
            it beforehand.
        top_n: maximum number of features. Fewer are plotted, without error,
            when the table is shorter.
        figsize: figure size in inches.
        colors: optional overrides for the "shap", "lime" and "ablation" bar
            colours. Missing keys fall back to the defaults instead of raising
            KeyError.
        save_path: if given, the figure is written there at 300 dpi.
        show: call plt.show() when True.

    Returns:
        (fig, axes) with three axes.
    """
    default_colors = {"shap": "#4C72B0", "lime": "#55A868", "ablation": "#C44E52"}
    palette = {**default_colors, **(colors or {})}

    top_features = merged_df.head(top_n)
    # Use the real row count so bar positions and heights always match.
    n_shown = len(top_features)
    positions = np.arange(n_shown)

    panels = (
        ('shap_norm', "shap", 'SHAP'),
        ('lime_norm', "lime", 'LIME'),
        ('ablation_norm', "ablation", 'Feature Ablation'),
    )

    fig, axes = plt.subplots(1, 3, figsize=figsize)
    for ax, (column, color_key, title) in zip(axes, panels):
        ax.barh(positions, top_features[column].values, color=palette[color_key])
        ax.set_yticks(positions)
        ax.set_yticklabels(top_features['feature'].values, fontsize=8)
        ax.set_xlabel('Normalized Importance', fontsize=10)
        ax.set_title(title, fontsize=11)
        ax.invert_yaxis()  # first row on top
        ax.grid(True, alpha=0.3, axis='x')

    fig.suptitle(f'Top {n_shown} Features by Different Interpretability Methods', fontsize=13, y=1.02)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor='white', edgecolor='none')
    if show:
        plt.show()

    return fig, axes


def plot_interpretability_correlation(
    merged_df: pd.DataFrame,
    figsize: Tuple[int, int] = (14, 4),
    save_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, np.ndarray]:
    """Pairwise scatter plots of normalised SHAP, LIME and ablation importance.

    Each panel title reports the Pearson correlation (pandas `corr` default)
    between the two plotted columns of `merged_df`.

    Returns:
        (fig, axes) with three axes: SHAP vs LIME, SHAP vs Ablation,
        LIME vs Ablation.
    """
    comparisons = (
        ('shap_norm', 'lime_norm', 'SHAP (normalized)', 'LIME (normalized)', 'SHAP vs LIME'),
        ('shap_norm', 'ablation_norm', 'SHAP (normalized)', 'Feature Ablation (normalized)', 'SHAP vs Ablation'),
        ('lime_norm', 'ablation_norm', 'LIME (normalized)', 'Feature Ablation (normalized)', 'LIME vs Ablation'),
    )

    fig, axes = plt.subplots(1, 3, figsize=figsize)
    for ax, (x_col, y_col, x_label, y_label, pair_name) in zip(axes, comparisons):
        ax.scatter(merged_df[x_col], merged_df[y_col], alpha=0.5, s=20)
        pair_corr = merged_df[[x_col, y_col]].corr().iloc[0, 1]
        ax.set_xlabel(x_label, fontsize=10)
        ax.set_ylabel(y_label, fontsize=10)
        ax.set_title(f'{pair_name} (corr={pair_corr:.3f})', fontsize=11)
        ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor='white', edgecolor='none')
    if show:
        plt.show()

    return fig, axes


def plot_interpretability_venn(
    shap_importance: pd.DataFrame,
    lime_importance: pd.DataFrame,
    ablation_importance: pd.DataFrame,
    top_n: int = 10,
    figsize: Tuple[int, int] = (6, 6),
    save_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, plt.Axes]:
    """Three-set Venn diagram of the first `top_n` 'feature' entries of each table.

    Each importance table is assumed to be sorted already; only its first
    `top_n` rows are compared.

    Returns:
        (fig, ax)
    """
    def _leading_features(table):
        return set(table.head(top_n)['feature'].values)

    subsets = [_leading_features(table) for table in (shap_importance, lime_importance, ablation_importance)]

    fig, ax = plt.subplots(figsize=figsize)
    venn3(subsets, set_labels=('SHAP', 'LIME', 'Feature\nAblation'), ax=ax)
    ax.set_title(f'Overlap of Top {top_n} Features Identified by Different Methods', fontsize=12)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor='white', edgecolor='none')
    if show:
        plt.show()

    return fig, ax


In [None]:
# run baseline
def _default_baseline_models(seed):
    """Build the default {display name: unfitted estimator} baseline set.

    Scaling lives inside each pipeline, so it is re-fitted on every training
    fold and never sees test-fold statistics.
    """
    return {
        "LogisticRegression(L2)": Pipeline([
            ("scaler", StandardScaler(with_mean=True, with_std=True)),
            # max_iter was 5: lbfgs cannot converge in 5 iterations on
            # thousands of genes, which silently under-trained this baseline.
            ("clf", LogisticRegression(penalty="l2", solver="lbfgs", max_iter=1000)),
        ]),
        "LinearSVM(Calibrated)": CalibratedClassifierCV(
            estimator=Pipeline([
                ("scaler", StandardScaler(with_mean=True, with_std=True)),
                ("svc", LinearSVC(C=10, class_weight='balanced', max_iter=1000, random_state=seed)),
            ]),
            method="sigmoid", cv=3),
        "SVM-RBF": Pipeline([
            ("scaler", StandardScaler(with_mean=True, with_std=True)),
            # probability=True fits an internal Platt-scaling CV; seed it so
            # y_prob / AUC are reproducible.
            ("clf", SVC(kernel="rbf", C=0.5, gamma="scale", probability=True, random_state=seed)),
        ]),
        "RandomForest": RandomForestClassifier(
            n_estimators=50, max_depth=None, n_jobs=-1,
            class_weight="balanced_subsample", random_state=seed),
    }


def _positive_class_scores(model, X):
    """Return positive-class scores for X.

    Uses predict_proba[:, 1] when available, otherwise the sigmoid of
    decision_function, otherwise None.
    """
    if hasattr(model, "predict_proba"):
        return model.predict_proba(X)[:, 1]
    if hasattr(model, "decision_function"):
        return 1.0 / (1.0 + np.exp(-model.decision_function(X)))
    return None


def _summarize_baseline_folds(fold_records):
    """Return the mean/std (ddof=0) of acc/AUC/F1 across folds.

    AUC and F1 are None when no fold has that metric.
    """
    accs = [r["test_acc"] for r in fold_records]
    aucs = [r["test_auc"] for r in fold_records if r["test_auc"] is not None]
    f1s = [r["test_f1"] for r in fold_records if r["test_f1"] is not None]
    return {
        "mean_acc": float(np.mean(accs)),
        "std_acc": float(np.std(accs)),
        "mean_auc": float(np.mean(aucs)) if aucs else None,
        "std_auc": float(np.std(aucs)) if aucs else None,
        "mean_f1": float(np.mean(f1s)) if f1s else None,
        "std_f1": float(np.std(f1s)) if f1s else None,
    }


def run_sklearn_baselines(features, labels, out_dir, n_folds=5, seed=2025, models=None):
    """Cross-validate classic sklearn baselines and save per-fold results to JSON.

    Args:
        features: (n_samples, n_features) array-like.
        labels: binary 0/1 labels, length n_samples.
        out_dir: directory for "baselines_results.json" (created if missing).
        n_folds: number of StratifiedKFold splits (shuffled with `seed`).
        seed: seed for the fold split and for the stochastic estimators.
        models: optional {name: unfitted estimator}. Defaults to
            _default_baseline_models(seed). Each estimator is re-fitted on
            every training fold.

    Returns:
        {model_name: {"fold_results": [...], "overall_metrics": {...}}}, the
        same structure written to JSON. Each fold record contains:
        - fold, test_acc
        - test_auc / test_f1 (None when the test fold has a single class)
        - y_true, y_prob (None if the model exposes no scores), y_pred
        - train_time_sec
    """
    if models is None:
        models = _default_baseline_models(seed)
    # asarray so integer-array fold indexing also works for DataFrame/list input.
    features = np.asarray(features)
    labels = np.asarray(labels)

    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)

    all_results = {}
    for model_name, model in models.items():
        fold_records = []
        print(f"\n=== Baseline: {model_name} ===")
        for fold, (tr_idx, te_idx) in enumerate(skf.split(features, labels), 1):
            X_tr, X_te = features[tr_idx], features[te_idx]
            y_tr, y_te = labels[tr_idx], labels[te_idx]

            t0 = time.time()
            model.fit(X_tr, y_tr)
            train_time = time.time() - t0

            y_prob = _positive_class_scores(model, X_te)
            y_pred = model.predict(X_te)

            # F1/AUC are undefined when the test fold contains a single class.
            has_both_classes = np.unique(y_te).size > 1
            acc = accuracy_score(y_te, y_pred)
            f1 = f1_score(y_te, y_pred) if has_both_classes else float("nan")
            auc = roc_auc_score(y_te, y_prob) if (y_prob is not None and has_both_classes) else float("nan")

            print(f"Fold {fold}/{n_folds} | Acc={acc:.3f}, AUC={auc:.3f}, F1={f1:.3f}")

            fold_records.append({
                "fold": fold,
                "test_acc": float(acc),
                "test_auc": float(auc) if not np.isnan(auc) else None,
                "test_f1": float(f1) if not np.isnan(f1) else None,
                "y_true": y_te.astype(int).tolist(),
                "y_prob": y_prob.tolist() if y_prob is not None else None,
                "y_pred": y_pred.astype(int).tolist(),
                "train_time_sec": float(train_time),
            })

        all_results[model_name] = {
            "fold_results": fold_records,
            "overall_metrics": _summarize_baseline_folds(fold_records),
        }

    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, "baselines_results.json")
    with open(out_path, "w") as f:
        json.dump(all_results, f, indent=2)
    print(f"\n[Baselines] Results saved to {out_path}")

    return all_results

In [None]:
# Train the DL+SupCon ensemble (train_ensemble_models is defined in an earlier cell,
# not shown here). NOTE(review): expensive and not cached, so Restart & Run All
# retrains everything. Presumably this writes ensemble_10models_results.json into
# OUT_DIR, which the summary cell below reads; verify.
ensemble_results_full = train_ensemble_models(n_models=10)
# Reload the preprocessed expression matrix and labels for the sklearn baselines;
# the third return value is discarded.
# NOTE(review): if preprocessing selects top-variance genes or z-scores on the full
# cohort (see CONFIG SELECT_TOP_VAR / Z_SCORE_PER_GENE) before CV, the folds below
# see test-fold statistics. Confirm against load_and_preprocess_raw_data.
features, labels, _ = load_and_preprocess_raw_data(EXPR_FILE, LABEL_FILE, SAMPLE_ID_COL, LABEL_COL)
# Same fold count and seed as CONFIG so the baseline split is reproducible
# (presumably matching the ensemble's split; TODO confirm).
baseline_results = run_sklearn_baselines(features, labels, OUT_DIR, n_folds=CONFIG["N_FOLDS"], seed=CONFIG["RANDOM_SEED"])

In [None]:
# get summary
# Result files read by the summary/plotting cells below:
# - MAIN_JSON: ensemble results, shown under the name "DL+SupCon". Presumably written
#   by train_ensemble_models; verify.
# - BASE_JSON: sklearn baseline results written by run_sklearn_baselines.
MAIN_JSON = os.path.join(OUT_DIR, "ensemble_10models_results.json")
BASE_JSON = os.path.join(OUT_DIR, "baselines_results.json")

def _safe_mean_std(d, mean_key, std_key):
    m = d.get(mean_key, np.nan)
    s = d.get(std_key,  np.nan)
    if m is None: m = np.nan
    if s is None: s = np.nan
    return float(m), float(s)


def load_models_summary(main_json=MAIN_JSON, base_json=BASE_JSON):
    """Build a per-model table of mean/std accuracy, AUC and F1.

    Reads the ensemble results (`main_json`, reported as "DL+SupCon") and the
    sklearn baseline results (`base_json`, one entry per model). A file that
    does not exist is skipped, and missing/None metrics become NaN.

    Returns:
        DataFrame indexed by "model" with columns mean_acc, std_acc, mean_auc,
        std_auc, mean_f1, std_f1. "DL+SupCon" comes first, then the other
        models alphabetically. If neither file exists, an empty frame with
        those columns is returned; previously set_index("model") raised
        KeyError in that case.
    """
    columns = ["mean_acc", "std_acc", "mean_auc", "std_auc", "mean_f1", "std_f1"]

    def _row(model_name, overall_metrics):
        # Flatten one "overall_metrics" dict (possibly null in JSON) into a table row.
        om = overall_metrics or {}
        row = {"model": model_name}
        for metric in ("acc", "auc", "f1"):
            mean_val, std_val = _safe_mean_std(om, f"mean_{metric}", f"std_{metric}")
            row[f"mean_{metric}"] = mean_val
            row[f"std_{metric}"] = std_val
        return row

    rows = []
    if os.path.isfile(main_json):
        with open(main_json, "r") as f:
            main_sum = json.load(f)
        rows.append(_row("DL+SupCon", main_sum.get("overall_metrics", {})))

    if os.path.isfile(base_json):
        with open(base_json, "r") as f:
            base_all = json.load(f)
        for name, pack in base_all.items():
            rows.append(_row(name, pack.get("overall_metrics", {})))

    if not rows:
        return pd.DataFrame(columns=columns, index=pd.Index([], name="model"), dtype=float)

    df = pd.DataFrame(rows, columns=["model"] + columns).set_index("model")
    # DL+SupCon first, then the remaining models alphabetically.
    df = df.loc[sorted(df.index, key=lambda x: (0 if x == "DL+SupCon" else 1, x))]
    return df


# Build the per-model summary table (mean/std of acc, AUC, F1) and save it as CSV
# in OUT_DIR. `df` is also passed to the metrics plot in the next cell.
df = load_models_summary(MAIN_JSON, BASE_JSON)
table_csv = os.path.join(OUT_DIR, "metrics_summary_table.csv")
df.to_csv(table_csv)

In [None]:
# load data
# Gather every model's per-fold results into one dict keyed by display name.
# The ensemble results are stored under "DL+SupCon" to match the summary table.
fold_results_dict = {}
with open(MAIN_JSON) as f:
    main_res = json.load(f)
fold_results_dict["DL+SupCon"] = main_res

with open(BASE_JSON) as f:
    base_all = json.load(f)
for name, pack in base_all.items():
    fold_results_dict[name] = pack

# Derive output paths from OUT_DIR instead of repeating the absolute directory,
# so changing OUT_DIR in the config cell moves these figures too.
plot_three_metrics_separate_subplots(df, fold_results_dict, save_path=os.path.join(OUT_DIR, "metrics_comparison.pdf"))
plot_mean_roc_curve(fold_results_dict, save_path=os.path.join(OUT_DIR, "mean_roc_curve.pdf"))
plot_mean_pr_curve(fold_results_dict, save_path=os.path.join(OUT_DIR, "mean_pr_curve.pdf"));

In [None]:
# Models shown in the confusion-matrix grid, in panel order.
models_to_plot = [
    "DL+SupCon",
    "RandomForest",
    "LogisticRegression(L2)",
    "LinearSVM(Calibrated)",
    "SVM-RBF"
]

# Output paths are derived from OUT_DIR so they follow the config cell.
# The trailing ';' keeps Jupyter from echoing the returned (fig, axes) tuple.
plot_confusion_matrices_grid(
    fold_results_dict,
    target_models=models_to_plot,
    class_names=("Control", "AD"),
    normalize=True,
    n_cols=3,
    save_path=os.path.join(OUT_DIR, "confusion_matrix_all_models.pdf"),
    save_individual_path=os.path.join(OUT_DIR, "individual_confusion_matrices")
);


In [None]:
# Reload one trained ensemble member for the t-SNE visualisation.
# Note: this is fold 3 / model 7 (not the last fold). Change the indices to use
# another member, e.g. the best-performing fold.
TSNE_FOLD, TSNE_MODEL_IDX = 3, 7
device = get_device()
input_dim = features.shape[1]
features_raw = features

# Rebuild the architecture with the training hyper-parameters, then load the weights.
model_path = os.path.join(OUT_DIR, f"ensemble_fold{TSNE_FOLD}_model{TSNE_MODEL_IDX}.pth")
model = GeneClassifier(input_dim=input_dim,
                       hidden_dims=CONFIG["HIDDEN_DIMS"],
                       num_classes=2,
                       dropout_rate=CONFIG["DROPOUT"],
                       use_batch_norm=CONFIG["USE_BATCH_NORM"],
                       proj_dim=CONFIG["PROJ_DIM"]).to(device)
# The checkpoint is a plain state_dict, so weights_only=True loads it without
# unpickling arbitrary objects.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
model.eval()

# DataLoader over the full dataset, in original order, for feature extraction.
dataset_all = GeneDataset(features, labels)
loader_all = DataLoader(dataset_all, batch_size=128, shuffle=False)

In [None]:
# plot_tsne_comparison
def extract_features(model, dataloader, device):
    """Embed every batch of ``dataloader`` with ``model.feature_extractor``.

    The model is switched to eval mode and autograd is disabled. Inputs are
    moved to ``device``; outputs are brought back to the CPU.

    Returns:
        (feats, labels): NumPy arrays obtained by concatenating the per-batch
        embeddings and the per-batch labels along axis 0.
    """
    model.eval()
    feature_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            embedded = model.feature_extractor(batch_x.to(device))
            feature_chunks.append(embedded.cpu().numpy())
            label_chunks.append(batch_y.numpy())
    stacked_feats = np.concatenate(feature_chunks, axis=0)
    stacked_labels = np.concatenate(label_chunks, axis=0)
    return stacked_feats, stacked_labels

# Embed the full dataset with the loaded model's feature extractor.
features_learned, labels_extracted = extract_features(model, loader_all, device)

# Raw vs learned t-SNE panels. The helper is called without saving or showing,
# so the custom titles and tick sizes below end up in BOTH the saved PDF and
# the inline figure. (Setting them after the helper had already saved and
# shown the figure left the PDF without them.)
# NOTE(review): assumes plot_tsne_comparison(show=False) leaves the figure
# open (does not close it) - confirm in its definition.
tsne_pdf_path = os.path.join(OUT_DIR, "tsne_comparison.pdf")
fig, (ax1, ax2) = plot_tsne_comparison(
    features_raw=features_raw,           # raw feature matrix (n_samples, n_features)
    features_learned=features_learned,   # learned representation (n_samples, hidden_dim)
    labels=labels_extracted,             # label array (n_samples,)
    random_state=CONFIG["RANDOM_SEED"],  # same seed as the rest of the notebook
    figsize=(8, 4),                      # (width, height) in inches
    save_path=None,                      # saved below, after customisation
    show=False,                          # shown below, after customisation
)

ax1.set_title("t-SNE of Raw Gene Expression Features", fontsize=14)
ax2.set_title("t-SNE of Learned Representation (DL+SupCon)", fontsize=14)
ax1.tick_params(labelsize=11)
ax2.tick_params(labelsize=11)
fig.tight_layout()
fig.savefig(tsne_pdf_path, dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# 额外的可解释性联合分析可视化函数

def plot_interpretability_heatmap(
    merged_df: pd.DataFrame,
    top_n: int = 30,
    figsize: Tuple[int, int] = (10, 12),
    save_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Heatmap of normalized SHAP / LIME / Ablation scores for the leading genes.

    Takes the first ``top_n`` rows of ``merged_df``. The heatmap has one row
    per method and one column per gene (from the ``feature`` column). Every
    cell is annotated with its value to two decimals.

    Args:
        merged_df: table with ``feature``, ``shap_norm``, ``lime_norm`` and
            ``ablation_norm`` columns.
        top_n: number of leading rows to plot.
        figsize: figure size in inches.
        save_path: if given, the figure is saved there (dpi=300).
        show: call ``plt.show()`` when True.

    Returns:
        (fig, ax)
    """
    score_columns = ['shap_norm', 'lime_norm', 'ablation_norm']
    selected = merged_df.head(top_n).copy()

    heatmap_data = selected[score_columns].T
    heatmap_data.columns = selected['feature'].values

    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(
        heatmap_data,
        ax=ax,
        annot=True,
        fmt='.2f',
        cmap='YlOrRd',
        linewidths=0.5,
        cbar_kws={'label': 'Normalized Importance'},
        xticklabels=heatmap_data.columns,
        yticklabels=['SHAP', 'LIME', 'Ablation'],
    )
    ax.set_title(f'Feature Importance Heatmap (Top {top_n} Genes)', fontsize=13, fontweight='bold')
    ax.set_xlabel('Genes', fontsize=11)
    ax.set_ylabel('Methods', fontsize=11)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    if show:
        plt.show()

    return fig, ax


def plot_interpretability_ranking_comparison(
    merged_df: pd.DataFrame,
    top_n: int = 30,
    figsize: Tuple[int, int] = (12, 8),
    save_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Grouped-bar comparison of each gene's rank under SHAP, LIME and Ablation.

    The first ``top_n`` rows of ``merged_df`` are taken. Within that subset,
    every method is ranked from its own normalized score (``shap_norm``,
    ``lime_norm``, ``ablation_norm``): rank 1 is the highest score, and ties
    share the lowest rank (``method='min'``). The y-axis is inverted so that
    better ranks appear higher.

    Args:
        merged_df: table with ``feature`` and the three ``*_norm`` columns.
        top_n: number of leading rows to plot.
        figsize: figure size in inches.
        save_path: if given, the figure is saved there (dpi=300).
        show: call ``plt.show()`` when True.

    Returns:
        (fig, ax)
    """
    top_features = merged_df.head(top_n).copy()

    # Rank every method from its own score. SHAP must not be ranked by row
    # position: merged_df is normally ordered by avg_importance, not by SHAP.
    rank_sources = {
        'shap_rank': 'shap_norm',
        'lime_rank': 'lime_norm',
        'ablation_rank': 'ablation_norm',
    }
    for rank_col, score_col in rank_sources.items():
        top_features[rank_col] = (
            top_features[score_col].rank(ascending=False, method='min').astype(int)
        )

    genes = top_features['feature'].values
    x = np.arange(len(genes))
    width = 0.25

    fig, ax = plt.subplots(figsize=figsize)

    ax.bar(x - width, top_features['shap_rank'], width, label='SHAP', color='#4C72B0', alpha=0.8)
    ax.bar(x, top_features['lime_rank'], width, label='LIME', color='#55A868', alpha=0.8)
    ax.bar(x + width, top_features['ablation_rank'], width, label='Ablation', color='#C44E52', alpha=0.8)

    ax.set_xlabel('Genes', fontsize=11)
    ax.set_ylabel('Rank', fontsize=11)
    ax.set_title(f'Feature Importance Ranking Comparison (Top {top_n} Genes)', fontsize=13, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(genes, rotation=45, ha='right', fontsize=8)
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')
    ax.invert_yaxis()  # smaller rank = more important, so put rank 1 at the top

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    if show:
        plt.show()

    return fig, ax


def plot_consensus_score_heatmap(
    merged_df: pd.DataFrame,
    top_n: int = 50,
    figsize: Tuple[int, int] = (12, 8),
    save_path: Optional[str] = None,
    show: bool = True,
    threshold: float = 0.3,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Heatmap of per-method "important" flags and their consensus count.

    For the first ``top_n`` rows of ``merged_df``, a method flags a gene as
    important (1) when its normalized score is strictly greater than
    ``threshold``; otherwise the flag is 0. The last row is the number of
    methods flagging the gene (0-3). Genes are ordered by that count,
    descending. The sort is stable, so ties keep their incoming order.

    Args:
        merged_df: table with ``feature`` and the three ``*_norm`` columns.
        top_n: number of leading rows to consider.
        figsize: figure size in inches.
        save_path: if given, the figure is saved there (dpi=300).
        show: call ``plt.show()`` when True.
        threshold: cut-off on the normalized scores (default 0.3).

    Returns:
        (fig, ax)
    """
    top_features = merged_df.head(top_n).copy()

    flag_sources = {
        'shap_important': 'shap_norm',
        'lime_important': 'lime_norm',
        'ablation_important': 'ablation_norm',
    }
    for flag_col, score_col in flag_sources.items():
        top_features[flag_col] = (top_features[score_col] > threshold).astype(int)

    # Consensus = number of methods that flag the gene as important.
    top_features['consensus_score'] = top_features[list(flag_sources)].sum(axis=1)

    # Order genes by consensus. Reorder rows before transposing (rather than
    # selecting transposed columns by name), which stays correct even if
    # gene names repeat.
    ordered = top_features.sort_values('consensus_score', ascending=False, kind='mergesort')
    heatmap_data = ordered[list(flag_sources) + ['consensus_score']].T
    heatmap_data.columns = ordered['feature'].values

    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(
        heatmap_data,
        annot=True,
        fmt='d',
        cmap='RdYlGn',
        cbar_kws={'label': 'Method flag (0/1) / Consensus count (0-3)'},
        yticklabels=['SHAP', 'LIME', 'Ablation', 'Consensus'],
        xticklabels=heatmap_data.columns,
        ax=ax,
        linewidths=0.5,
        vmin=0,
        vmax=3
    )
    ax.set_title(f'Consensus Score Heatmap (Top {top_n} Genes, Threshold={threshold})',
                 fontsize=13, fontweight='bold')
    ax.set_xlabel('Genes', fontsize=11)
    ax.set_ylabel('Methods', fontsize=11)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=8)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    if show:
        plt.show()

    return fig, ax


def plot_cumulative_importance(
    merged_df: pd.DataFrame,
    top_n: int = 50,
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Cumulative share of total importance captured by the leading genes.

    Genes are taken in the order of ``merged_df`` (the first ``top_n`` rows).
    For each score column, the running sum over those genes is divided by the
    column's total over ALL rows of ``merged_df``. The curves therefore show
    what fraction of each method's total importance the top k genes account
    for, and the 80% / 90% reference lines are meaningful. (Dividing by the
    top-N sum would force every curve to reach 100% at N.) A column whose
    total is 0 is drawn as a flat zero line.

    Args:
        merged_df: table with ``shap_norm``, ``lime_norm``, ``ablation_norm``
            and ``avg_importance`` columns.
        top_n: number of leading rows to accumulate.
        figsize: figure size in inches.
        save_path: if given, the figure is saved there (dpi=300).
        show: call ``plt.show()`` when True.

    Returns:
        (fig, ax)
    """
    top_features = merged_df.head(top_n)

    def _cumulative_share(column: str) -> pd.Series:
        # Normalize by the total over every feature, guarding an all-zero column.
        total = merged_df[column].sum()
        running = top_features[column].cumsum()
        return running / total if total > 0 else running * 0.0

    cumulative_shap = _cumulative_share('shap_norm')
    cumulative_lime = _cumulative_share('lime_norm')
    cumulative_ablation = _cumulative_share('ablation_norm')
    cumulative_avg = _cumulative_share('avg_importance')

    x = np.arange(1, len(top_features) + 1)

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(x, cumulative_shap, label='SHAP', linewidth=2, marker='o', markersize=4)
    ax.plot(x, cumulative_lime, label='LIME', linewidth=2, marker='s', markersize=4)
    ax.plot(x, cumulative_ablation, label='Ablation', linewidth=2, marker='^', markersize=4)
    ax.plot(x, cumulative_avg, label='Average', linewidth=2, linestyle='--', color='black')

    # 80% and 90% reference lines
    ax.axhline(y=0.8, color='gray', linestyle=':', alpha=0.7, label='80%')
    ax.axhline(y=0.9, color='gray', linestyle=':', alpha=0.7, label='90%')

    ax.set_xlabel('Number of Top Features', fontsize=11)
    ax.set_ylabel('Cumulative Share of Total Importance', fontsize=11)
    ax.set_title(f'Cumulative Importance of Top {top_n} Features', fontsize=13, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 1.05)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    if show:
        plt.show()

    return fig, ax


def plot_3d_importance_scatter(
    merged_df: pd.DataFrame,
    top_n: int = 50,
    figsize: Tuple[int, int] = (12, 10),
    save_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, plt.Axes]:
    """
    3D scatter of the three normalized importance scores.

    Plots the first ``top_n`` rows of ``merged_df`` with SHAP on x, LIME on y
    and Ablation on z, colored by ``avg_importance``. The first 10 of those
    rows are labelled with their ``feature`` name.

    Args:
        merged_df: table with ``feature``, ``shap_norm``, ``lime_norm``,
            ``ablation_norm`` and ``avg_importance`` columns.
        top_n: number of leading rows to plot.
        figsize: figure size in inches.
        save_path: if given, the figure is saved there (dpi=300).
        show: call ``plt.show()`` when True.

    Returns:
        (fig, ax): the figure and its 3D axes.
    """
    # Importing Axes3D registers the '3d' projection on older matplotlib versions.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    subset = merged_df.head(top_n).copy()

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection='3d')

    points = ax.scatter(
        subset['shap_norm'],
        subset['lime_norm'],
        subset['ablation_norm'],
        c=subset['avg_importance'],
        s=50,
        alpha=0.6,
        cmap='viridis',
        edgecolors='black',
        linewidths=0.5
    )

    # Label the leading ten genes next to their points.
    labelled = subset.head(10)
    label_coords = zip(
        labelled['shap_norm'],
        labelled['lime_norm'],
        labelled['ablation_norm'],
        labelled['feature'],
    )
    for gx, gy, gz, gene in label_coords:
        ax.text(gx, gy, gz, gene, fontsize=7, alpha=0.8)

    ax.set_xlabel('SHAP (normalized)', fontsize=10)
    ax.set_ylabel('LIME (normalized)', fontsize=10)
    ax.set_zlabel('Ablation (normalized)', fontsize=10)
    ax.set_title(f'3D Feature Importance Space (Top {top_n} Genes)', fontsize=13, fontweight='bold')

    colorbar = fig.colorbar(points, ax=ax, shrink=0.8)
    colorbar.set_label('Average Importance', fontsize=10)

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    if show:
        plt.show()

    return fig, ax


def plot_hierarchical_clustering_heatmap(
    merged_df: pd.DataFrame,
    top_n: int = 50,
    figsize: Tuple[int, int] = (14, 10),
    save_path: Optional[str] = None,
    show: bool = True
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Clustered heatmap of normalized SHAP / LIME / Ablation scores.

    Uses the first ``top_n`` rows of ``merged_df``. Both methods (rows) and
    genes (columns) are clustered with Ward linkage on Euclidean distance.

    ``sns.clustermap`` creates its own figure, so that figure and its heatmap
    axes are returned. Pre-creating a figure with ``plt.subplots`` would leave
    an empty orphan figure, which ``plt.show()`` would also display.

    Args:
        merged_df: table with ``feature`` and the three ``*_norm`` columns.
        top_n: number of leading rows to cluster.
        figsize: figure size in inches (passed to clustermap).
        save_path: if given, the figure is saved there (dpi=300).
        show: call ``plt.show()`` when True.

    Returns:
        (fig, ax): the clustermap figure and its heatmap axes.
    """
    top_features = merged_df.head(top_n).copy()

    # Rows = methods, columns = genes.
    data_for_clustering = top_features[['shap_norm', 'lime_norm', 'ablation_norm']].T
    data_for_clustering.columns = top_features['feature'].values

    grid = sns.clustermap(
        data_for_clustering,
        method='ward',
        metric='euclidean',
        cmap='YlOrRd',
        figsize=figsize,
        cbar_kws={'label': 'Normalized Importance'},
        row_cluster=True,
        col_cluster=True,
        linewidths=0.5
    )
    fig = grid.fig
    ax = grid.ax_heatmap

    fig.suptitle(f'Hierarchical Clustering Heatmap (Top {top_n} Genes)',
                 fontsize=13, fontweight='bold', y=0.98)

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    if show:
        plt.show()

    return fig, ax


def plot_consensus_distribution(
    merged_df: pd.DataFrame,
    top_n: int = 100,
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None,
    show: bool = True,
    threshold: float = 0.3,
) -> Tuple[plt.Figure, np.ndarray]:
    """
    Distribution of how many methods agree that a gene is important.

    For the first ``top_n`` rows of ``merged_df``, a method counts as agreeing
    when its normalized score is strictly greater than ``threshold``. The
    resulting 0-3 counts are shown as a bar chart and as a pie chart.

    Args:
        merged_df: table with ``shap_norm``, ``lime_norm``, ``ablation_norm``.
        top_n: number of leading rows to consider.
        figsize: figure size in inches.
        save_path: if given, the figure is saved there (dpi=300).
        show: call ``plt.show()`` when True.
        threshold: cut-off on the normalized scores (default 0.3).

    Returns:
        (fig, axes): the figure and its array of two axes (bar, pie).
    """
    top_features = merged_df.head(top_n).copy()

    # Number of methods (0-3) whose normalized score exceeds the threshold.
    score_cols = ['shap_norm', 'lime_norm', 'ablation_norm']
    top_features['consensus_count'] = (top_features[score_cols] > threshold).sum(axis=1)

    consensus_dist = top_features['consensus_count'].value_counts().sort_index()

    # Explicit color per agreement count. Indexing a 3-color list with count-1
    # would give count 0 the same color as count 3 (index -1).
    color_by_count = {0: '#BBBBBB', 1: '#C44E52', 2: '#F39C12', 3: '#55A868'}
    slice_colors = [color_by_count.get(int(count), 'gray') for count in consensus_dist.index]

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # Bar chart
    bars = axes[0].bar(consensus_dist.index, consensus_dist.values,
                       color=slice_colors, alpha=0.7, edgecolor='black')
    axes[0].set_xticks(consensus_dist.index)
    axes[0].set_xlabel('Number of Methods Agreeing', fontsize=11)
    axes[0].set_ylabel('Number of Genes', fontsize=11)
    axes[0].set_title('Consensus Distribution', fontsize=12)
    axes[0].grid(True, alpha=0.3, axis='y')

    # Value labels on top of the bars
    for bar in bars:
        height = bar.get_height()
        axes[0].text(bar.get_x() + bar.get_width() / 2., height,
                     f'{int(height)}',
                     ha='center', va='bottom', fontsize=10)

    # Pie chart
    pie_labels = [f'{i} method(s)' for i in consensus_dist.index]
    axes[1].pie(consensus_dist.values, labels=pie_labels, autopct='%1.1f%%',
                colors=slice_colors, startangle=90)
    axes[1].set_title('Consensus Distribution (Pie Chart)', fontsize=12)

    fig.suptitle(f'Consensus Score Distribution (Top {top_n} Genes)',
                 fontsize=13, fontweight='bold')
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    if show:
        plt.show()

    return fig, axes



In [None]:
# 可解释性分析
# 1. SHAP Analysis using DeepExplainer
def shap_analysis_deep(model, dataloader, device, 
                       num_background=50, num_test=100,
                       save_dir="/home/fujing/ad_ssl/3_ssl_model/shap",
                       gene_names=None,
                       figsize=(6, 5),
                       save_plots=True,
                       show_plots=True):
    """
    SHAP (DeepExplainer) analysis of the AD class (class index 1).

    The first ``num_background`` samples yielded by ``dataloader`` (in loader
    order) form the SHAP background set. The following ``num_test`` samples
    are explained; fewer are used, with a printed warning, if the loader runs
    out. The per-gene mean absolute SHAP value is written to
    ``shap_importance_fixed.csv`` in ``save_dir``, and three plots are produced.

    Args:
        model: torch classifier whose output is used by DeepExplainer.
        dataloader: yields ``(x, y)`` batches; only ``x`` is used.
        device: device the background/test tensors are moved to.
        num_background: number of background samples.
        num_test: number of samples to explain.
        save_dir: output directory (created if missing).
        gene_names: optional sequence of feature names, one per input column.
        figsize: figure size for the plots.
        save_plots: save plots as PDF in ``save_dir`` when True.
        show_plots: display plots when True.

    Returns:
        shap_pos: SHAP values for class 1 (2-D: samples x features).
        importance: DataFrame ``feature`` / ``mean_abs_shap``, sorted descending.
        figs: {'beeswarm': (fig, ax), 'bar': (fig, ax), 'top_features': (fig, ax)}.

    Raises:
        ValueError: if the loader yields no samples, if no samples remain
            after the background set, or if ``gene_names`` has the wrong length.
    """
    os.makedirs(save_dir, exist_ok=True)
    model.eval()

    # Collect just enough rows for background + test. Count the rows actually
    # received: the last batch can be smaller than the others.
    needed = num_background + num_test
    batches = []
    n_collected = 0
    for x, _ in dataloader:
        batches.append(x.numpy())
        n_collected += x.shape[0]
        if n_collected >= needed:
            break
    if not batches:
        raise ValueError("dataloader yielded no samples")

    all_data = np.concatenate(batches, axis=0)
    bg = torch.tensor(all_data[:num_background], dtype=torch.float32).to(device)
    test = torch.tensor(all_data[num_background:needed], dtype=torch.float32).to(device)

    print(f"Background: {bg.shape}, Test: {test.shape}")
    if test.shape[0] == 0:
        raise ValueError(
            f"Only {all_data.shape[0]} samples available; none left to explain after "
            f"{num_background} background samples. Reduce num_background."
        )
    if test.shape[0] < num_test:
        print(f"Warning: only {test.shape[0]} test samples available (requested {num_test}).")

    # Feature names: explicit None/length check, because `if gene_names`
    # raises for numpy arrays / pandas Index.
    n_features = all_data.shape[1]
    has_names = gene_names is not None and len(gene_names) > 0
    if has_names:
        feature_names = list(gene_names)
        if len(feature_names) != n_features:
            raise ValueError(
                f"gene_names has {len(feature_names)} entries but data has {n_features} features"
            )
    else:
        feature_names = [f"Gene_{i}" for i in range(n_features)]
    plot_names = feature_names if has_names else gene_names

    # SHAP analysis
    explainer = shap.DeepExplainer(model, bg)
    print("Computing SHAP values...")
    shap_values = explainer.shap_values(test)

    # Older shap versions return one array per class; take class 1 (AD).
    shap_pos = shap_values[1] if isinstance(shap_values, list) else shap_values
    print(f"Original SHAP shape: {shap_pos.shape}")

    # Newer shap versions return a single (samples, features, classes) array;
    # keep only class 1 (AD).
    if shap_pos.ndim == 3:
        shap_pos = shap_pos[:, :, 1]
        print(f"Fixed SHAP shape: {shap_pos.shape}")

    test_np = test.cpu().numpy()

    # Per-feature importance = mean |SHAP| over explained samples.
    mean_shap = np.abs(shap_pos).mean(axis=0)

    importance = pd.DataFrame({
        'feature': feature_names,
        'mean_abs_shap': mean_shap
    }).sort_values('mean_abs_shap', ascending=False)

    importance.to_csv(os.path.join(save_dir, "shap_importance_fixed.csv"), index=False)

    figs = {}

    # 1. SHAP summary plot (beeswarm)
    fig1, ax1 = plot_shap_summary(
        shap_pos, test_np, gene_names=plot_names,
        max_display=20, plot_type="summary",
        figsize=figsize,
        save_path=os.path.join(save_dir, "shap_beeswarm_fixed.pdf") if save_plots else None,
        show=show_plots
    )
    figs['beeswarm'] = (fig1, ax1)

    # 2. SHAP bar plot
    fig2, ax2 = plot_shap_summary(
        shap_pos, test_np, gene_names=plot_names,
        max_display=20, plot_type="bar",
        figsize=figsize,
        save_path=os.path.join(save_dir, "shap_bar_fixed.pdf") if save_plots else None,
        show=show_plots
    )
    figs['bar'] = (fig2, ax2)

    # 3. Top features bar plot
    fig3, ax3 = plot_shap_top_features(
        importance, top_n=20,
        figsize=figsize,
        save_path=os.path.join(save_dir, "shap_top20_features.pdf") if save_plots else None,
        show=show_plots
    )
    figs['top_features'] = (fig3, ax3)

    return shap_pos, importance, figs

# 2. LIME Analysis
def lime_analysis(model, dataloader, device,
                  num_samples=200,
                  num_features=30,
                  num_perturbations=3000,
                  save_dir="/home/fujing/ad_ssl/3_ssl_model/lime",
                  gene_names=None,
                  figsize=(6, 5),
                  save_plots=True,
                  show_plots=True,
                  random_state=2025):
    """
    LIME analysis of the model's AD-class (class index 1) probability.

    The first ``num_samples`` rows yielded by ``dataloader`` (in loader order)
    are LIME's training data. Up to the next ``num_samples`` rows are
    explained one by one; fewer are explained, with a printed warning, if the
    loader runs out. Absolute class-1 LIME weights go into a per-sample
    matrix, which is averaged per feature.

    Args:
        model: torch classifier returning logits (softmax is applied here).
        dataloader: yields ``(x, y)`` batches; only ``x`` is used.
        device: device used for model predictions.
        num_samples: size of the LIME training set and maximum number of
            explained samples.
        num_features: number of features LIME reports per explanation.
        num_perturbations: LIME perturbation samples per explanation.
        save_dir: output directory (created if missing).
        gene_names: optional sequence of feature names, one per input column.
        figsize: figure size for the plots.
        save_plots: save plots as PDF in ``save_dir`` when True.
        show_plots: display plots when True.
        random_state: seed passed to ``LimeTabularExplainer``.

    Returns:
        feature_importance_matrix: (n_explained, n_features) absolute LIME
            weights; features LIME did not select stay 0.
        feature_importance: DataFrame ``feature`` / ``mean_abs_importance``,
            sorted descending.
        figs: {'top_features': (fig, ax)}, plus 'heatmap' when any importance
            is non-zero and the heatmap helper returns a figure.

    Raises:
        ValueError: if the loader yields no samples, if no rows remain after
            the training set, or if ``gene_names`` has the wrong length.
    """
    os.makedirs(save_dir, exist_ok=True)
    model.eval()

    # Collect data. Count the rows actually received: the last batch can be
    # smaller than the others.
    print("\n[1/4] 收集数据...")
    needed = num_samples * 2
    batches = []
    n_collected = 0
    for x, _ in dataloader:
        batches.append(x.numpy())
        n_collected += x.shape[0]
        if n_collected >= needed:
            break
    if not batches:
        raise ValueError("dataloader yielded no samples")

    all_data = np.concatenate(batches, axis=0)
    train_data = all_data[:num_samples]
    test_data = all_data[num_samples:needed]
    n_test = test_data.shape[0]
    n_features = train_data.shape[1]

    print(f"  ✓ 训练数据: {train_data.shape}")
    print(f"  ✓ 测试数据: {test_data.shape}")
    if n_test == 0:
        raise ValueError(
            f"Only {all_data.shape[0]} samples available; none left to explain after "
            f"{num_samples} LIME training samples. Reduce num_samples."
        )
    if n_test < num_samples:
        print(f"  Warning: only {n_test} samples available to explain (requested {num_samples}).")

    # Feature names: explicit None/length check, because `if gene_names`
    # raises for numpy arrays / pandas Index.
    has_names = gene_names is not None and len(gene_names) > 0
    if has_names:
        feature_names = list(gene_names)
        if len(feature_names) != n_features:
            raise ValueError(
                f"gene_names has {len(feature_names)} entries but data has {n_features} features"
            )
    else:
        feature_names = [f"Gene_{i}" for i in range(n_features)]
    plot_names = feature_names if has_names else gene_names

    def predict_fn(X):
        """Class probabilities for a batch of rows, as LIME expects."""
        X_tensor = torch.tensor(X, dtype=torch.float32).to(device)
        with torch.no_grad():
            outputs = model(X_tensor)
            probs = torch.softmax(outputs, dim=1).cpu().numpy()
        return probs

    print("\n[2/4] 创建 LIME Explainer...")
    explainer = lime_tabular.LimeTabularExplainer(
        training_data=train_data,
        feature_names=feature_names,
        class_names=['Control', 'AD'],
        mode='classification',
        random_state=random_state
    )

    # Sanity check: show what a single explanation looks like.
    print("\n[DEBUG] 测试LIME输出格式...")
    test_exp = explainer.explain_instance(
        test_data[0],
        predict_fn,
        num_features=10,
        num_samples=500
    )

    print(f"  as_list() 示例: {test_exp.as_list()[:3]}")
    print(f"  as_map() 示例: {list(test_exp.as_map()[1][:3])}")  # class-1 features

    print(f"\n[3/4] 分析 {n_test} 个样本...")
    all_explanations = []
    feature_importance_matrix = np.zeros((n_test, n_features))

    for idx in tqdm(range(n_test), desc="LIME分析进度"):
        explanation = explainer.explain_instance(
            test_data[idx],
            predict_fn,
            num_features=num_features,
            num_samples=num_perturbations
        )
        all_explanations.append(explanation)

        # as_map() -> {class_index: [(feature_index, weight), ...]}. Use the
        # indices (not as_list() strings) and prefer the AD class (1), falling
        # back to class 0 if class 1 was not explained.
        exp_map = explanation.as_map()
        class_weights = exp_map.get(1, exp_map.get(0, []))
        for feat_idx, weight in class_weights:
            feature_importance_matrix[idx, feat_idx] = abs(weight)

    print("LIME分析完成")

    print("\n[4/4] 计算特征重要性...")
    mean_importance = feature_importance_matrix.mean(axis=0)

    print(f"  非零特征数: {np.sum(mean_importance > 0)}")
    print(f"  最大重要性: {mean_importance.max():.6f}")
    print(f"  平均重要性: {mean_importance.mean():.6f}")

    feature_importance = pd.DataFrame({
        'feature': feature_names,
        'mean_abs_importance': mean_importance
    }).sort_values('mean_abs_importance', ascending=False)

    feature_importance.to_csv(os.path.join(save_dir, "lime_feature_importance_fixed.csv"), index=False)

    figs = {}

    # 1. Top features bar plot
    fig1, ax1 = plot_lime_top_features(
        feature_importance, top_n=20,
        figsize=figsize,
        save_path=os.path.join(save_dir, "lime_top20_features_fixed.pdf") if save_plots else None,
        show=show_plots
    )
    figs['top_features'] = (fig1, ax1)

    # 2. Heatmap - only meaningful when some importance is non-zero.
    if mean_importance.max() > 0:
        fig2, ax2 = plot_lime_heatmap(
            feature_importance_matrix,
            gene_names=plot_names,
            top_n_genes=30,
            n_samples=50,
            figsize=figsize,
            save_path=os.path.join(save_dir, "lime_importance_heatmap_fixed.pdf") if save_plots else None,
            show=show_plots
        )
        if fig2 is not None:
            figs['heatmap'] = (fig2, ax2)
            if save_plots:
                print("LIME importance heatmap saved")

    print("LIME分析完成!")
    return feature_importance_matrix, feature_importance, figs

# 3. Feature Ablation Analysis
def feature_ablation_analysis(model, dataloader, device,
                              num_samples=200,
                              ablation_baseline='mean',
                              save_dir="/home/fujing/ad_ssl/3_ssl_model/ablation",
                              gene_names=None,
                              figsize=(6, 5),
                              save_plots=True,
                              show_plots=True):
    """
    Single-feature ablation importance for the AD-class (class index 1) probability.

    Takes the first ``num_samples`` rows yielded by ``dataloader``. For each
    feature in turn, its column is replaced by a baseline value, and the mean
    absolute change of the softmax probability of class 1 is recorded.

    Args:
        model: torch classifier returning logits.
        dataloader: yields ``(x, y)`` batches; only ``x`` is used.
        device: device used for model predictions.
        num_samples: number of rows to use.
        ablation_baseline: 'zero', 'mean' or 'median'. 'mean' and 'median' are
            per-feature statistics over the collected rows. Any other value
            falls back to zero and emits a warning.
        save_dir: output directory (created if missing).
        gene_names: optional sequence of feature names, one per input column.
        figsize: figure size for the plots.
        save_plots: save plots as PDF in ``save_dir`` when True.
        show_plots: display plots when True.

    Returns:
        feature_importance_scores: (n_features,) mean |p_original - p_ablated|.
        feature_importance: DataFrame ``feature`` / ``mean_abs_prob_change``,
            sorted descending.
        figs: {'top_features': (fig, ax), 'distribution': (fig, ax)}.

    Raises:
        ValueError: if the loader yields no samples or ``gene_names`` has the
            wrong length.
    """
    os.makedirs(save_dir, exist_ok=True)
    model.eval()

    # Collect data. Count the rows actually received: the last batch can be
    # smaller than the others.
    print("\n[1/4] 收集数据...")
    batches = []
    n_collected = 0
    for x, _ in dataloader:
        batches.append(x.numpy())
        n_collected += x.shape[0]
        if n_collected >= num_samples:
            break
    if not batches:
        raise ValueError("dataloader yielded no samples")

    all_data = np.concatenate(batches, axis=0)[:num_samples]

    n_features = all_data.shape[1]
    print(f"  ✓ 数据形状: {all_data.shape}")
    print(f"  ✓ 特征数量: {n_features}")

    # Feature names: explicit None/length check, because `if gene_names`
    # raises for numpy arrays / pandas Index.
    if gene_names is not None and len(gene_names) > 0:
        feature_names = list(gene_names)
        if len(feature_names) != n_features:
            raise ValueError(
                f"gene_names has {len(feature_names)} entries but data has {n_features} features"
            )
    else:
        feature_names = [f"Gene_{i}" for i in range(n_features)]

    print(f"\n[2/4] 计算baseline值 (方法: {ablation_baseline})...")
    if ablation_baseline == 'mean':
        baseline_values = all_data.mean(axis=0)
    elif ablation_baseline == 'median':
        baseline_values = np.median(all_data, axis=0)
    else:
        if ablation_baseline != 'zero':
            warnings.warn(
                f"Unknown ablation_baseline={ablation_baseline!r}; falling back to 'zero'."
            )
        baseline_values = np.zeros(n_features)

    print("\n[3/4] 计算原始预测...")
    data_tensor = torch.tensor(all_data, dtype=torch.float32).to(device)
    baseline_tensor = torch.tensor(baseline_values, dtype=torch.float32).to(device)
    with torch.no_grad():
        original_outputs = model(data_tensor)
        original_probs = torch.softmax(original_outputs, dim=1)[:, 1].cpu().numpy()  # P(AD)

    print(f"  ✓ 原始预测完成，平均AD概率: {original_probs.mean():.4f}")

    print(f"\n[4/4] 逐个消除特征并观察影响 ({n_features} 个特征)...")
    feature_importance_scores = np.zeros(n_features)

    # Use one working copy on the device and restore each column after use,
    # instead of copying the whole matrix (and re-uploading it) per feature.
    ablated = data_tensor.clone()
    with torch.no_grad():
        for feat_idx in tqdm(range(n_features), desc="Feature Ablation进度"):
            ablated[:, feat_idx] = baseline_tensor[feat_idx]
            ablated_probs = torch.softmax(model(ablated), dim=1)[:, 1].cpu().numpy()
            ablated[:, feat_idx] = data_tensor[:, feat_idx]

            # Mean absolute change of P(AD) caused by removing this feature.
            feature_importance_scores[feat_idx] = np.abs(original_probs - ablated_probs).mean()

    print("  ✓ Feature Ablation分析完成")

    feature_importance = pd.DataFrame({
        'feature': feature_names,
        'mean_abs_prob_change': feature_importance_scores
    }).sort_values('mean_abs_prob_change', ascending=False)

    feature_importance.to_csv(os.path.join(save_dir, "ablation_feature_importance.csv"), index=False)

    figs = {}

    # 1. Top features bar plot
    fig1, ax1 = plot_ablation_top_features(
        feature_importance, top_n=20,
        figsize=figsize,
        save_path=os.path.join(save_dir, "ablation_top20_features.pdf") if save_plots else None,
        show=show_plots
    )
    figs['top_features'] = (fig1, ax1)

    # 2. Distribution plot
    fig2, ax2 = plot_ablation_distribution(
        feature_importance_scores,
        figsize=figsize,
        save_path=os.path.join(save_dir, "ablation_importance_distribution.pdf") if save_plots else None,
        show=show_plots
    )
    figs['distribution'] = (fig2, ax2)

    return feature_importance_scores, feature_importance, figs

# 4. Compare the results of the three interpretability methods

def _minmax_normalize_importance(values: pd.Series) -> pd.Series:
    """Min-max scale ``values`` to [0, 1].

    A constant or empty series has no spread to scale by, so it maps to all
    zeros instead of the NaNs a plain ``(v - min) / (max - min)`` would give.
    """
    lo, hi = values.min(), values.max()
    span = hi - lo
    if pd.isna(span) or span == 0:
        return pd.Series(0.0, index=values.index)
    return (values - lo) / span


def compare_interpretability_methods(shap_importance, lime_importance, ablation_importance,
                                     save_dir="/home/fujing/ad_ssl/3_ssl_model/comparison",
                                     top_n=10,
                                     figsize=(12, 4),
                                     save_plots=True,
                                     show_plots=True):
    """
    Merge and compare SHAP, LIME and Ablation feature importances.

    Each method's score (``mean_abs_shap``, ``mean_abs_importance``,
    ``mean_abs_prob_change``) is min-max normalized to [0, 1]. The three
    tables are outer-joined on ``feature``, with missing scores filled with 0.
    Genes are ranked by the mean of the three normalized scores
    (``avg_importance``). The merged table is written to
    ``combined_feature_importance.csv`` in ``save_dir``.

    Args:
        shap_importance: DataFrame with ``feature`` / ``mean_abs_shap``.
        lime_importance: DataFrame with ``feature`` / ``mean_abs_importance``.
        ablation_importance: DataFrame with ``feature`` / ``mean_abs_prob_change``.
        save_dir: output directory (created if missing).
        top_n: number of genes in the top-features comparison plot.
        figsize: figure size for the comparison and correlation plots.
        save_plots: save plots as PDF in ``save_dir`` when True.
        show_plots: display plots when True.

    Returns:
        merged: merged importance DataFrame sorted by ``avg_importance`` (descending).
        figs: {'top_features': (fig, axes), 'correlation': (fig, axes), 'venn': (fig, ax)}.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Normalize each method's scores (guarded against constant columns).
    shap_df = shap_importance.copy()
    shap_df['shap_norm'] = _minmax_normalize_importance(shap_df['mean_abs_shap'])

    lime_df = lime_importance.copy()
    lime_df['lime_norm'] = _minmax_normalize_importance(lime_df['mean_abs_importance'])

    abl_df = ablation_importance.copy()
    abl_df['ablation_norm'] = _minmax_normalize_importance(abl_df['mean_abs_prob_change'])

    # Outer-join on gene name; a gene missing from a method scores 0 there.
    merged = shap_df[['feature', 'shap_norm']].merge(
        lime_df[['feature', 'lime_norm']], on='feature', how='outer'
    ).merge(
        abl_df[['feature', 'ablation_norm']], on='feature', how='outer'
    ).fillna(0)

    merged['avg_importance'] = (merged['shap_norm'] + merged['lime_norm'] + merged['ablation_norm']) / 3
    merged = merged.sort_values('avg_importance', ascending=False)

    merged.to_csv(os.path.join(save_dir, "combined_feature_importance.csv"), index=False)

    figs = {}

    # 1. Top-N feature comparison
    fig1, axes1 = plot_interpretability_comparison_top_features(
        merged, top_n=top_n,
        figsize=figsize,
        save_path=os.path.join(save_dir, f"comparison_top{top_n}_features.pdf") if save_plots else None,
        show=show_plots
    )
    figs['top_features'] = (fig1, axes1)

    # 2. Correlation between methods
    fig2, axes2 = plot_interpretability_correlation(
        merged,
        figsize=figsize,
        save_path=os.path.join(save_dir, "method_correlation.pdf") if save_plots else None,
        show=show_plots
    )
    figs['correlation'] = (fig2, axes2)

    # 3. Venn diagram of the top-20 sets
    fig3, ax3 = plot_interpretability_venn(
        shap_df, lime_df, abl_df,
        top_n=20,
        figsize=(6, 6),
        save_path=os.path.join(save_dir, "venn_diagram_top20.pdf") if save_plots else None,
        show=show_plots
    )
    figs['venn'] = (fig3, ax3)

    return merged, figs

In [None]:
# 1. Load and preprocess the expression matrix and labels, using the configured
#    sample-id / label column names.
features, labels, gene_names = load_and_preprocess_raw_data(EXPR_FILE, LABEL_FILE, SAMPLE_ID_COL, LABEL_COL, CONFIG)

# 2. Rebuild the classifier and load the saved checkpoint (fold 3, ensemble member 7).
MODEL_PATH = os.path.join(OUT_DIR, "ensemble_fold3_model7.pth")
device = get_device()
model = GeneClassifier(
    input_dim=features.shape[1],
    hidden_dims=CONFIG["HIDDEN_DIMS"],
    num_classes=2,
    dropout_rate=CONFIG["DROPOUT"],
    use_batch_norm=CONFIG["USE_BATCH_NORM"],
    proj_dim=CONFIG["PROJ_DIM"]
).to(device)

model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

# 3. Unshuffled loader over the whole dataset for the interpretability analyses.
dataset_all = GeneDataset(features, labels)
loader_all = DataLoader(dataset_all, batch_size=128, shuffle=False)


In [None]:
# 4. SHAP analysis: 250 background samples, then 100 explained samples,
#    with gene names used as feature labels.
shap_values, shap_importance, shap_figs = shap_analysis_deep(
    model,
    loader_all,
    device,
    num_background=250,
    num_test=100,
    gene_names=gene_names,
    save_dir=os.path.join(OUT_DIR, "shap"),
    figsize=(6, 5),
    save_plots=True,
    show_plots=True,
)

In [None]:
# 5. LIME analysis. LIME is slow (one perturbation set per explained sample),
#    so a limited sample is used: 200 training rows and up to 200 explained rows.
lime_matrix, lime_importance, lime_figs = lime_analysis(
    model,
    loader_all,
    device,
    num_samples=200,
    num_features=30,
    num_perturbations=3000,
    gene_names=gene_names,
    save_dir=os.path.join(OUT_DIR, "lime"),
    figsize=(6, 5),
    save_plots=True,
    show_plots=True,
)


In [None]:
# 6. Feature ablation: replace one feature at a time by its mean over 500
#    samples and measure the change in P(AD).
ablation_scores, ablation_importance, ablation_figs = feature_ablation_analysis(
    model,
    loader_all,
    device,
    num_samples=500,
    ablation_baseline='mean',
    gene_names=gene_names,
    save_dir=os.path.join(OUT_DIR, "ablation"),
    figsize=(6, 5),
    save_plots=True,
    show_plots=True,
)


In [None]:
# 7. Merge and compare the three importance tables; outputs go to OUT_DIR/comparison_v2.
combined_importance, comparison_figs = compare_interpretability_methods(
    shap_importance,
    lime_importance,
    ablation_importance,
    save_dir=os.path.join(OUT_DIR, "comparison_v2"),
    top_n=10,
    figsize=(12, 4),
    save_plots=True,
    show_plots=True,
)


In [None]:
# Extra joint-analysis figures built from `combined_importance` (the merged
# SHAP / LIME / Ablation table). Each figure is saved as a PDF under
# OUT_DIR/comparison_v2 and shown inline.
# NOTE(review): this cell rebinds globals (save_dir, fig1/ax1 ... fig7/axes7),
# overwriting the ax1/ax2 from the t-SNE cell.

# Only run when combined_importance exists (at notebook top level, locals()
# is the module namespace).
if 'combined_importance' in locals():
    print("生成额外的联合分析可视化图...\n")
    
    save_dir = os.path.join(OUT_DIR, "comparison_v2")
    
    # 1. Importance heatmap (top 30 genes x 3 methods)
    print("1. 生成重要性热图...")
    fig1, ax1 = plot_interpretability_heatmap(
        combined_importance,
        top_n=30,
        figsize=(12, 6),
        save_path=os.path.join(save_dir, "importance_heatmap_top30.pdf"),
        show=True
    )
    
    # 2. Per-method rank comparison
    print("2. 生成排名对比图...")
    fig2, ax2 = plot_interpretability_ranking_comparison(
        combined_importance,
        top_n=30,
        figsize=(14, 8),
        save_path=os.path.join(save_dir, "ranking_comparison_top30.pdf"),
        show=True
    )
    
    # 3. Consensus-score heatmap
    print("3. 生成一致性得分热图...")
    fig3, ax3 = plot_consensus_score_heatmap(
        combined_importance,
        top_n=50,
        figsize=(14, 6),
        save_path=os.path.join(save_dir, "consensus_score_heatmap_top50.pdf"),
        show=True
    )
    
    # 4. Cumulative importance curves
    print("4. 生成累积重要性图...")
    fig4, ax4 = plot_cumulative_importance(
        combined_importance,
        top_n=50,
        figsize=(10, 6),
        save_path=os.path.join(save_dir, "cumulative_importance_top50.pdf"),
        show=True
    )
    
    # 5. 3D scatter - best effort: a failure is reported and the remaining
    #    figures are still generated.
    print("5. 生成3D散点图...")
    try:
        fig5, ax5 = plot_3d_importance_scatter(
            combined_importance,
            top_n=50,
            figsize=(12, 10),
            save_path=os.path.join(save_dir, "3d_scatter_top50.pdf"),
            show=True
        )
    except Exception as e:
        print(f"   3D图生成失败: {e}")
    
    # 6. Hierarchical clustering heatmap - same best-effort handling.
    print("6. 生成层次聚类热图...")
    try:
        fig6, ax6 = plot_hierarchical_clustering_heatmap(
            combined_importance,
            top_n=50,
            figsize=(14, 10),
            save_path=os.path.join(save_dir, "hierarchical_clustering_top50.pdf"),
            show=True
        )
    except Exception as e:
        print(f"   层次聚类图生成失败: {e}")
    
    # 7. Distribution of how many methods agree per gene
    print("7. 生成一致性分布图...")
    fig7, axes7 = plot_consensus_distribution(
        combined_importance,
        top_n=100,
        figsize=(12, 6),
        save_path=os.path.join(save_dir, "consensus_distribution_top100.pdf"),
        show=True
    )
    
    print("\n✅ 所有额外的联合分析图已生成并保存！")
    print(f"保存位置: {save_dir}")
else:
    print("⚠️ 请先运行compare_interpretability_methods函数生成combined_importance")


In [None]:
# 8. Summary: genes that are in the Top-20 list of all three methods
#    (each importance table is already sorted in descending order).
print("\n📊 三种方法一致识别的Top基因:")
shap_top20 = set(shap_importance.head(20)['feature'])
lime_top20 = set(lime_importance.head(20)['feature'])
abl_top20 = set(ablation_importance.head(20)['feature'])
consensus_genes = shap_top20 & lime_top20 & abl_top20

if consensus_genes:
    print(f"  三方一致 (Top 20): {sorted(consensus_genes)}")
else:
    # Report the empty result explicitly instead of leaving a bare header.
    print("  No gene is in the Top 20 of all three methods.")

In [None]:
# ============================================================================
# 绘制consensus_genes在HC和AD中的表达差异
# ============================================================================

def plot_consensus_genes_expression(
    expr_file: str,
    label_file: str,
    consensus_genes: List[str],
    sample_id_col: str = "id",
    label_col: str = "group",
    figsize: Tuple[int, int] = (14, 6),
    save_path: Optional[str] = None,
    show: bool = True,
    jitter_seed: Optional[int] = 0,
) -> Tuple[Optional[plt.Figure], Optional[np.ndarray]]:
    """Plot per-gene HC vs AD expression boxplots for the consensus genes.

    Each gene gets one panel (at most 4 panels per row) containing a boxplot per
    group, the individual samples as jittered points, and a two-sided
    Mann-Whitney U p-value.

    Args:
        expr_file: Path of the expression matrix, loaded with ``read_expression``.
        label_file: Path of the label table, loaded with ``read_labels``.
        consensus_genes: Genes to plot. Panels follow this order, duplicates are
            dropped, and genes absent from the expression matrix are reported
            and skipped.
        sample_id_col: Sample-ID column name in the label file.
        label_col: Label column name in the label file.
        figsize: Figure size in inches.
        save_path: If given, the figure is saved there at 300 dpi.
        show: Whether to call ``plt.show()``.
        jitter_seed: Seed for the horizontal jitter of the scatter points, so the
            figure is identical across runs. ``None`` gives a non-reproducible
            jitter.

    Returns:
        ``(fig, axes)``, where ``axes`` is a flat array of every subplot Axes
        (unused ones are hidden). Returns ``(None, None)`` when none of the genes
        are present in the expression matrix.

    Notes:
        Samples whose label equals 1 are drawn as "AD" and all others as "HC".
        This assumes ``read_labels`` yields 0/1-encoded labels (TODO confirm).
        The p-values are per gene and are NOT corrected for multiple testing.
    """
    from scipy import stats  # scipy is not part of the notebook's import cell

    # 1. Load raw expression and labels, then align them on sample IDs.
    expr_df = read_expression(expr_file)
    labels_series = read_labels(label_file, sample_id_col, label_col)
    expr_aligned, labels_aligned = align_expr_labels(expr_df, labels_series)

    # 2. Keep only requested genes present in the matrix. Rows of expr_aligned
    #    are genes here, given the .loc[genes].T below.
    unique_genes = list(dict.fromkeys(consensus_genes))  # dedupe, keep order
    available_genes = [g for g in unique_genes if g in expr_aligned.index]
    missing = [g for g in unique_genes if g not in expr_aligned.index]
    if not available_genes:
        print("警告: consensus_genes中没有在表达数据中找到的基因")
        return None, None
    if missing:
        print(f"警告: 以下基因在表达数据中未找到: {missing}")

    expr_consensus = expr_aligned.loc[available_genes].T  # (n_samples, n_genes)

    # 3. Per-sample group masks. The pandas comparison is elementwise whatever
    #    the label dtype, which replaces the original per-cell .loc loop.
    is_ad = (labels_aligned.loc[expr_consensus.index] == 1).to_numpy()
    groups = ['HC', 'AD']
    colors = ['#4C72B0', '#C44E52']  # HC blue, AD red
    group_masks = [~is_ad, is_ad]

    # 4. Grid layout. squeeze=False always yields a 2-D array, even for 1 gene.
    n_genes = len(available_genes)
    n_cols = min(4, n_genes)
    n_rows = (n_genes + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    rng = np.random.default_rng(jitter_seed)

    for ax, gene in zip(axes, available_genes):
        values = expr_consensus[gene].to_numpy()
        data_to_plot = [values[mask] for mask in group_masks]

        bp = ax.boxplot(data_to_plot, patch_artist=True,
                        widths=0.6, showmeans=True, meanline=True)
        # Set tick labels explicitly: boxplot's `labels=` kwarg is deprecated
        # (renamed `tick_labels`) in matplotlib 3.9.
        ax.set_xticks([1, 2])
        ax.set_xticklabels(groups)
        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)
            patch.set_alpha(0.7)

        # Overlay individual samples with a small horizontal jitter around each box.
        for i, (group_values, color) in enumerate(zip(data_to_plot, colors)):
            x_pos = rng.normal(i + 1, 0.04, size=len(group_values))
            ax.scatter(x_pos, group_values, alpha=0.4, s=20, color=color, zorder=3)

        ax.set_title(gene, fontsize=11, fontweight='bold')
        ax.set_ylabel('Expression Level', fontsize=10)
        ax.grid(True, alpha=0.3, axis='y')

        # Two-sided Mann-Whitney U test (HC vs AD), only when both groups have data.
        hc_data, ad_data = data_to_plot
        if len(hc_data) > 0 and len(ad_data) > 0:
            try:
                _, p_value = stats.mannwhitneyu(hc_data, ad_data, alternative='two-sided')
            except ValueError as exc:
                print(f"警告: {gene} 的Mann-Whitney检验失败: {exc}")
            else:
                significance = ('***' if p_value < 0.001 else '**' if p_value < 0.01
                                else '*' if p_value < 0.05 else 'ns')
                # ':.3f' would print 'p=0.000' for the most significant genes.
                p_text = f'p={p_value:.2e}' if p_value < 0.001 else f'p={p_value:.3f}'
                ax.text(0.5, 0.95, f'{p_text} ({significance})',
                        transform=ax.transAxes, ha='center', va='top',
                        fontsize=9, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # Hide the unused grid cells.
    for ax in axes[n_genes:]:
        ax.axis('off')

    fig.suptitle('Expression Differences of Consensus Genes (HC vs AD)',
                 fontsize=14, fontweight='bold', y=1.02)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor='white', edgecolor='none')
    if show:
        plt.show()

    return fig, axes


# Plot expression differences of the consensus genes between HC and AD.
# sorted(): a set of strings iterates in hash order, and string hashing is
# randomized per process (PYTHONHASHSEED). Without sorting, the panel order
# would change after every kernel restart.
if 'consensus_genes' not in globals():
    print("⚠️ 未找到consensus_genes，请先运行可解释性分析")
elif len(consensus_genes) == 0:
    print("⚠️ consensus_genes为空：三种方法的Top基因没有交集，跳过表达差异图")
else:
    fig, axes = plot_consensus_genes_expression(
        expr_file=EXPR_FILE,
        label_file=LABEL_FILE,
        consensus_genes=sorted(consensus_genes),
        sample_id_col=SAMPLE_ID_COL,
        label_col=LABEL_COL,
        figsize=(16, 8),
        save_path=os.path.join(OUT_DIR, "consensus_genes_expression_difference.pdf"),
        show=True
    )

    # The plotting function returns (None, None) when no gene is in the matrix.
    if fig is None:
        print("⚠️ consensus_genes均不在表达数据中，未生成表达差异图")
    else:
        print(f"\n✅ 已绘制 {len(consensus_genes)} 个consensus genes的表达差异图")
