In [2]:
import scanpy as sc
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np

# 读取h5ad文件
adata = sc.read("/home/lxz/scmamba/gene_pretubation/norman.h5ad")
a=adata.obs
b=adata.var
print("adata.obs 的列名:", adata.obs.columns.tolist())
print("adata.var 的列名:", adata.var.columns.tolist())
print("adata.var 的索引:", adata.var.index.name)

# 检查基因名称的存储位置
# 通常基因名称可能在：
# 1. adata.var.index (索引)
# 2. adata.var['gene_name'] 或类似的列
# 3. adata.var['gene_names']

# 方法1：检查索引
if "HMGA2" in adata.var.index:
    print("##### HMGA2 found in var index")

adata.obs 的列名: ['condition', 'pert_type', 'cell_type', 'guide_identity', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'n_genes', 'cell_type_condition', 'lane', 'dose_val', 'control', 'condition_name', 'total_count']
adata.var 的列名: ['gene_name', 'highly_variable', 'means', 'dispersions', 'dispersions_norm']
adata.var 的索引: None


In [14]:
import os
import pickle
import torch
import numpy as np
import pandas as pd
import anndata as ad
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import sys
from pathlib import Path

# 获取当前 Notebook 所在目录的父目录
current_dir = Path.cwd()
parent_dir = current_dir.parent
sys.path.append(str(parent_dir))
from models.gene_tokenizer import GeneVocab
from models.model import MambaModel

max_seq_len = 1070
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class H5ADDataset(Dataset):
    """读取h5ad文件的数据集，只选择ctrl细胞"""
    def __init__(self, file_path):
        self.adata = ad.read_h5ad(file_path)
        
        # 筛选ctrl细胞
        if 'condition' in self.adata.obs.columns:
            ctrl_mask = self.adata.obs['condition'] == 'ctrl'
            self.adata = self.adata[ctrl_mask]
            print(f"筛选后ctrl细胞数量: {len(self.adata.obs)}")
        else:
            print("警告: 数据集中没有condition列，使用所有细胞")
        
        # 确保数据是稠密矩阵
        if hasattr(self.adata.X, 'toarray'):
            self.X = self.adata.X.toarray()
        else:
            self.X = self.adata.X
        
        # 获取细胞类型信息
        if 'condition' in self.adata.obs.columns:
            self.cell_types = self.adata.obs['condition'].values
        else:
            self.cell_types = ['unknown'] * len(self.adata.obs)
    
    def __len__(self):
        return len(self.adata.obs)
    
    def __getitem__(self, idx):
        return {
            'X': torch.FloatTensor(self.X[idx]),
            'cell_type': self.cell_types[idx]
        }

def preprocess_batch_fixed_genes(x_batch, feature_names, vocab, device, max_seq_len=2049):
    """预处理批次数据，使用固定的1049个基因顺序"""
    batch_size = x_batch.size(0)
    n_genes = len(feature_names)
    
    # 初始化张量
    values_tensor = torch.full((batch_size, max_seq_len), -2.0, device=device)
    src_tensor = torch.full((batch_size, max_seq_len), vocab["<pad>"], dtype=torch.long, device=device)
    padding_mask = torch.zeros((batch_size, max_seq_len), dtype=torch.bool, device=device)
    
    for i in range(batch_size):
        # 使用固定的基因顺序：CLS + 所有1049个基因
        genes = ["<cls>"] + feature_names
        
        # 截断或填充到max_seq_len
        if len(genes) > max_seq_len:
            genes = genes[:max_seq_len]
        else:
            genes = genes + ["<pad>"] * (max_seq_len - len(genes))
        
        # 映射到token ID
        src_tensor[i] = torch.tensor([
            vocab[gene] if gene in vocab else vocab["<unk>"] for gene in genes
        ], device=device)
        
        # 填充values_tensor：CLS位置为0，基因位置为实际表达值
        values_tensor[i, 0] = 0.0  # <cls> 位置
        
        # 填充基因表达值
        n_genes_to_fill = min(n_genes, max_seq_len - 1)
        values_tensor[i, 1:1+n_genes_to_fill] = x_batch[i, :n_genes_to_fill]
        
        # 生成填充掩码
        padding_mask[i] = torch.tensor([gene == "<pad>" for gene in genes], device=device)
        
    return values_tensor, src_tensor, padding_mask

def get_gene_embeddings(model, file_path, feature_names, vocab):
    """
    从h5ad文件提取所有基因的嵌入
    """
    all_gene_embeddings = []
    all_cell_types = []
    
    dataset = H5ADDataset(file_path)
    loader = DataLoader(
        dataset,
        batch_size=256,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    
    with torch.no_grad():
        for batch in tqdm(loader, desc=f"Extracting gene embeddings"):
            x = batch['X'].to(device, non_blocking=True)
            values, src, mask = preprocess_batch_fixed_genes(x, feature_names, vocab, device, max_seq_len)
            print(src)
            with torch.cuda.amp.autocast(enabled=True):
                outputs = model(src=src, values=values, src_key_padding_mask=mask)
                
                # 当使用 "all-genes" 模式时，outputs["cell_emb"] 已经是基因嵌入
                # 形状: [batch_size, n_genes, d_model]
                gene_embeddings = outputs["cell_emb"].cpu().numpy()
            
            all_gene_embeddings.append(gene_embeddings)
            all_cell_types.append(batch['cell_type'])
    
    # 合并所有结果
    all_gene_embeddings = np.concatenate(all_gene_embeddings, axis=0)
    all_cell_types = np.concatenate(all_cell_types, axis=0)
    
    return all_gene_embeddings, all_cell_types

def load_model(model_path, vocab_path):
    """加载预训练模型和词汇表"""
    vocab = GeneVocab.from_file(vocab_path)
    model = MambaModel(
        ntoken=len(vocab), d_model=512, nhead=8, d_hid=512, nlayers=6,
        dropout=0.2, pad_token="<pad>", pad_value=-2,
        input_emb_style="continuous", 
        cell_emb_style="all-genes"  # 修改为返回所有基因嵌入的模式
    )
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model, vocab

if __name__ == "__main__":
    # 1. 加载模型和词汇表
    model, vocab = load_model(
        "/home/lxz/scmamba/model_state/cell_cls_3loss_6layer_final.pth",
        "/home/lxz/scmamba/vocab.json"
    )
    
    # 2. 加载h5ad文件并获取基因名称
    h5ad_file_path = "/home/lxz/scmamba/gene_pretubation/GEARS/data/adamson.h5ad"
    adata = ad.read_h5ad(h5ad_file_path)
    
    # 从var中获取gene_name作为特征名称
    if 'gene_name' in adata.var.columns:
        feature_names = adata.var['gene_name'].tolist()
    else:
        # 如果没有gene_name列，使用var的index
        feature_names = adata.var.index.tolist()
        print("警告: 使用var的index作为基因名称")
    
    print(f"数据集包含 {adata.n_obs} 个细胞和 {len(feature_names)} 个基因")
    print(f"前10个基因名称: {feature_names[:10]}")
    
    # 3. 提取基因嵌入（只处理ctrl细胞）
    gene_embeddings, cell_types = get_gene_embeddings(model, h5ad_file_path, feature_names, vocab)
    print(f"基因嵌入矩阵形状: {gene_embeddings.shape}")
    # 4. 保存结果
    output_dir = "/home/lxz/scmamba/gene_pretubation/embeddings"
    os.makedirs(output_dir, exist_ok=True)
    
    # 保存基因嵌入（每个细胞对应1049个基因的嵌入）
    np.save(os.path.join(output_dir, "adamson_gene_embeddings.npy"), gene_embeddings)
    np.save(os.path.join(output_dir, "adamson_cell_types.npy"), cell_types)
    
    print(f"\n结果已保存至 {output_dir}:")
    print(f"- 基因嵌入矩阵形状: {gene_embeddings.shape} (cells × genes × features)")
    print(f"- 细胞类型数量: {len(cell_types)}")
    print(f"- 每个细胞的基因数量: {gene_embeddings.shape[1]}")
    print(f"- 嵌入维度: {gene_embeddings.shape[2]}")

数据集包含 47795 个细胞和 1069 个基因
前10个基因名称: ['IL7R', 'GYPB', 'SEZ6', 'ELMO1', 'RPS10', 'FCER1G', 'CTRB2', 'HBG2', 'XAF1', 'CTTNBP2']
筛选后ctrl细胞数量: 3952


Extracting gene embeddings:   0%|          | 0/16 [00:00<?, ?it/s]

tensor([[60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        ...,
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535]], device='cuda:0')


Extracting gene embeddings:   6%|▋         | 1/16 [00:01<00:20,  1.36s/it]

tensor([[60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        ...,
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535]], device='cuda:0')


Extracting gene embeddings:  12%|█▎        | 2/16 [00:02<00:17,  1.21s/it]

tensor([[60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        ...,
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535]], device='cuda:0')


Extracting gene embeddings:  19%|█▉        | 3/16 [00:03<00:15,  1.17s/it]

tensor([[60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        ...,
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535]], device='cuda:0')


Extracting gene embeddings:  25%|██▌       | 4/16 [00:04<00:13,  1.14s/it]

tensor([[60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        ...,
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535]], device='cuda:0')


Extracting gene embeddings:  31%|███▏      | 5/16 [00:05<00:12,  1.13s/it]

tensor([[60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        ...,
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535]], device='cuda:0')


Extracting gene embeddings:  38%|███▊      | 6/16 [00:06<00:11,  1.12s/it]

tensor([[60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        ...,
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535]], device='cuda:0')


Extracting gene embeddings:  44%|████▍     | 7/16 [00:08<00:10,  1.12s/it]

tensor([[60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        ...,
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535]], device='cuda:0')


Extracting gene embeddings:  50%|█████     | 8/16 [00:09<00:08,  1.12s/it]

tensor([[60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        ...,
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535]], device='cuda:0')


Extracting gene embeddings:  56%|█████▋    | 9/16 [00:10<00:07,  1.12s/it]

tensor([[60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        ...,
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535]], device='cuda:0')


Extracting gene embeddings:  62%|██████▎   | 10/16 [00:11<00:06,  1.11s/it]

tensor([[60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        ...,
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535]], device='cuda:0')


Extracting gene embeddings:  69%|██████▉   | 11/16 [00:12<00:05,  1.11s/it]

tensor([[60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        ...,
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535]], device='cuda:0')


Extracting gene embeddings:  75%|███████▌  | 12/16 [00:13<00:04,  1.11s/it]

tensor([[60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        ...,
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535]], device='cuda:0')


Extracting gene embeddings:  81%|████████▏ | 13/16 [00:14<00:03,  1.11s/it]

tensor([[60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        ...,
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535]], device='cuda:0')


Extracting gene embeddings:  88%|████████▊ | 14/16 [00:15<00:02,  1.10s/it]

tensor([[60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        ...,
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535]], device='cuda:0')


Extracting gene embeddings:  94%|█████████▍| 15/16 [00:16<00:01,  1.10s/it]

tensor([[60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        ...,
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535],
        [60695, 12051, 10670,  ..., 19574,  3836,  5535]], device='cuda:0')


Extracting gene embeddings: 100%|██████████| 16/16 [00:17<00:00,  1.09s/it]


基因嵌入矩阵形状: (3952, 1069, 512)

结果已保存至 /home/lxz/scmamba/gene_pretubation/embeddings:
- 基因嵌入矩阵形状: (3952, 1069, 512) (cells × genes × features)
- 细胞类型数量: 3952
- 每个细胞的基因数量: 1069
- 嵌入维度: 512


In [13]:
import pickle
import scanpy as sc

# 读取pkl文件
with open('/home/lxz/scmamba/gene_pretubation/GEARS/data/gene2go.pkl', 'rb') as f:
    data = pickle.load(f)

# 获取gene2go中的基因列表
gene_names = list(data.keys())
print(f"gene2go.pkl 中的基因数量: {len(gene_names)}")

# 读取h5ad文件
adata = sc.read("/home/lxz/scmamba/gene_pretubation/GEARS/data/adamson.h5ad")
print(f"h5ad文件中的细胞数: {adata.n_obs}, 基因数: {adata.n_vars}")

# 获取h5ad文件中的基因名称
h5ad_genes = adata.var.gene_name.tolist()
print(f"h5ad文件中的基因数量: {len(h5ad_genes)}")

# 检查HMGA2是否在两个文件中
if "HMGA2" in gene_names:
    print("HMGA2 存在于 gene2go.pkl 中")
if "HMGA2" in h5ad_genes:
    print("HMGA2 存在于 h5ad 文件中")

# 找出两个文件的基因交集
common_genes = set(gene_names) & set(h5ad_genes)
print(f"两个文件共有的基因数量: {len(common_genes)}")

# 计算重合比例
overlap_ratio_gene2go = len(common_genes) / len(gene_names) * 100
overlap_ratio_h5ad = len(common_genes) / len(h5ad_genes) * 100

print(f"gene2go.pkl 中与 h5ad 重合的基因比例: {overlap_ratio_gene2go:.2f}%")
print(f"h5ad 中与 gene2go.pkl 重合的基因比例: {overlap_ratio_h5ad:.2f}%")

# 如果需要查看具体的重合基因列表（前10个作为示例）
print("\n重合基因示例（前10个）:")
for gene in list(common_genes)[:10]:
    print(gene)
# print(f"字典键的数量: {len(data)}")
# print("前几个键值对:")
# for i, (key, value) in enumerate(data.items()):
#     if i < 5:  # 只显示前5个
#         print(f"{key}: {value}")
#     else:
#         break

gene2go.pkl 中的基因数量: 67832
h5ad文件中的细胞数: 47795, 基因数: 1069
h5ad文件中的基因数量: 1069
HMGA2 存在于 gene2go.pkl 中
HMGA2 存在于 h5ad 文件中
两个文件共有的基因数量: 1049
gene2go.pkl 中与 h5ad 重合的基因比例: 1.55%
h5ad 中与 gene2go.pkl 重合的基因比例: 98.13%

重合基因示例（前10个）:
PLD3
NTRK1
AVP
OLAH
GOLIM4
IFITM1
PDIA6
PLXNC1
BIRC3
MMP10


In [7]:
with open('/home/lxz/scmamba/gene_pretubation/GEARS/data/gene2go.pkl', 'rb') as f:
    gene2go = pickle.load(f)

pert_names = np.unique(list(gene2go.keys()))
print(pert_names[:20])
print(len(pert_names))

["'C-K-RAS" '(FM-3)' '(ppGpp)ase' '0610011B16Rik' '0610037N12Rik'
 '0710008D09Rik' '1' '1-8D' '1-8U' '1-AGPAT 3' '1-AGPAT 6' '1-AGPAT1'
 '1-AGPAT2' '1-AGPAT4' '1-Cys' '10-FTHFDH' '10-fTHF' '101F10.1' '101F6'
 '104p']
67832


In [27]:
import os
import numpy as np
import matplotlib.pyplot as plt

# ======= 路径与参数 =======
H5AD = "/home/lxz/scmamba/gene_pretubation/GEARS/data/norman.h5ad"  # 仅用于取基因名，可不读
NPZ  = "/home/lxz/scmamba/gene_pretubation/GEARS/results/eval_outputs/mamba_norman_eval_results.npz"
DATA = "norman"                         # "norman" 或 "adamson"
TOPK = 20
OUT  = "/home/lxz/scmamba/gene_pretubation/GEARS/results/eval_outputs/"
os.makedirs(OUT, exist_ok=True)

# ========== 如果你本机有 scanpy，解开下两行即可从 h5ad 取基因名 ==========
import scanpy as sc
genes = sc.read_h5ad(H5AD).var_names.to_numpy()

# ======= 读取评估结果 =======
npz = np.load(NPZ, allow_pickle=True)
pred, truth, ctrl, pert_cat = npz["pred"], npz["truth"], npz["ctrl"], npz["pert_cat"]

# 这里尝试从 npz 取基因名；若没有，则回退为索引字符串（Gene_0,...）
genes = npz.get("genes", None)
if genes is None:
    try:
        import scanpy as sc
        genes = sc.read_h5ad(H5AD).var_names.to_numpy()
    except Exception:
        genes = np.array([f"Gene_{i}" for i in range(pred.shape[1])])

# ======= 选择一个扰动（示例：EIF2B2+ctrl）=======
target = "AHR+ctrl" if DATA.lower() == "norman" else "EIF2B2+ctrl"
mask = (pert_cat == target)
if mask.sum() == 0:
    raise ValueError(f"在结果里找不到扰动：{target}")

# ======= 计算 Δ 分布 =======
# 每个细胞的 Δ = 表达(perturbed) - 控制均值；truth/pred 都得到 (n_cells, n_genes)
delta_t_cells = truth[mask] - ctrl
delta_p_cells = pred[mask]  - ctrl

# 用真实 Δ 的"细胞均值"的绝对值排序，取 Top-K 基因
truth_mean_all = delta_t_cells.mean(axis=0)        # (n_genes,)
order = np.argsort(np.abs(truth_mean_all))[::-1][:TOPK]
genes_sel   = genes[order]
pred_mean   = delta_p_cells.mean(axis=0)[order]    # 预测均值（用于三角标记）
truth_mean  = truth_mean_all[order]                # 真实均值（用于方块标记）
# ======= 画图 =======
fig_h = max(6, TOPK * 0.42)  # 自适应高度
plt.figure(figsize=(8.5, fig_h))

# 箱线图数据（list of 1D arrays），每个基因为一行（Truth 的细胞分布）
data_for_box = [delta_t_cells[:, i] for i in order]
box = plt.boxplot(
    data_for_box,
    vert=False,
    labels=genes_sel,
    whis=1.5,
    showfliers=False,
    patch_artist=True,
    widths=0.6  # 加宽箱体
)

# ——箱线图颜色设置（修改中位线颜色和粗细）——
box_color = '#87CEEB'  # 天蓝色箱体
whisker_color = '#4682B4'  # 钢蓝色须线
median_color = '#2E8B57'  # 海绿色中位线

for patch in box['boxes']:
    patch.set_facecolor(box_color)  # 箱体填充色
    patch.set_edgecolor(whisker_color)  # 箱体边框色
    patch.set_alpha(0.7)  # 透明度
    patch.set_linewidth(1.5)
for whisker in box['whiskers']:
    whisker.set_color(whisker_color)
    whisker.set_linewidth(1.5)
for cap in box['caps']:
    cap.set_color(whisker_color)
    cap.set_linewidth(1.5)
for median in box['medians']:
    median.set_color(median_color)
    median.set_linewidth(3)  # 加粗中位线

# y 轴位置（箱线图内部是 1..K）
ypos = np.arange(1, len(order) + 1)

# ——只保留Pred Mean的标记点——
pred_marker = plt.plot(pred_mean, ypos, '^', markersize=10, 
                       label="Pred Mean", color='#FF8C00',  # 深橙色
                       markerfacecolor='#FFA07A',  # 淡橘色填充
                       markeredgecolor='#FF8C00',  # 深橙色边框
                       linestyle='None',
                       markeredgewidth=1.2)

# x=0 竖直虚线
plt.axvline(0.0, linestyle='--', linewidth=1.5, color="gray", alpha=0.7)

# 坐标轴和标题设置
plt.xlabel("Change in Gene Expression over Control", fontsize=12, labelpad=10)
plt.ylabel("Genes", fontsize=12, labelpad=10)
# plt.title(f"{target}", 
#           fontsize=14, pad=20, loc='center')  # 确保标题居中

# 调整坐标轴刻度
plt.tick_params(axis='both', which='major', labelsize=10)

# 创建自定义图例（长方形表示Truth）
from matplotlib.patches import Rectangle

# 创建自定义图例句柄
truth_legend = Rectangle((0,0), 1, 0.5, facecolor='#87CEEB', edgecolor='#4682B4',
                        linewidth=1.2, label='Truth')
pred_legend = plt.Line2D([], [], marker='^', markersize=10, 
                        color='#FF8C00', markerfacecolor='#FFA07A',
                        linestyle='None', label='Pred Mean')

# 添加图例（右上方）
plt.legend(handles=[truth_legend, pred_legend],
           loc='upper right', framealpha=1, edgecolor='black', fontsize=10)

# 设置固定横轴范围[-2,4]
plt.xlim(-2, 4)

# 设置网格线
plt.grid(axis='x', linestyle=':', alpha=0.5)

# 调整布局和保存
plt.tight_layout()
save_path = os.path.join(OUT, f"{target}_top{TOPK}_enhanced_plot.svg")
plt.savefig(save_path, dpi=300, bbox_inches="tight")
plt.close()  # 关闭图形以释放内存
print("[Saved]", save_path)

  box = plt.boxplot(


[Saved] /home/lxz/scmamba/gene_pretubation/GEARS/results/eval_outputs/AHR+ctrl_top20_enhanced_plot.svg


In [30]:
import os
import numpy as np
import matplotlib.pyplot as plt

# ======= 路径与参数 =======
H5AD = "/home/lxz/scmamba/gene_pretubation/GEARS/data/norman.h5ad"  # 仅用于取基因名，可不读
NPZ  = "/home/lxz/scmamba/gene_pretubation/GEARS/results/eval_outputs/mamba_norman_eval_results.npz"
DATA = "norman"                         # "norman" 或 "adamson"
TOPK = 20
OUT  = "/home/lxz/scmamba/gene_pretubation/GEARS/results/eval_outputs/"
os.makedirs(OUT, exist_ok=True)

# ========== 如果你本机有 scanpy，解开下两行即可从 h5ad 取基因名 ==========
import scanpy as sc
genes = sc.read_h5ad(H5AD).var_names.to_numpy()

# ======= 读取评估结果 =======
npz = np.load(NPZ, allow_pickle=True)
pred, truth, ctrl, pert_cat = npz["pred"], npz["truth"], npz["ctrl"], npz["pert_cat"]

# 这里尝试从 npz 取基因名；若没有，则回退为索引字符串（Gene_0,...）
genes = npz.get("genes", None)
if genes is None:
    try:
        import scanpy as sc
        genes = sc.read_h5ad(H5AD).var_names.to_numpy()
    except Exception:
        genes = np.array([f"Gene_{i}" for i in range(pred.shape[1])])

# ======= 选择一个扰动（示例：EIF2B2+ctrl）=======
target = "BPGM+ZBTB1" if DATA.lower() == "norman" else "EIF2B2+ctrl"
mask = (pert_cat == target)
if mask.sum() == 0:
    raise ValueError(f"在结果里找不到扰动：{target}")

# ======= 计算 Δ 分布 =======
# 每个细胞的 Δ = 表达(perturbed) - 控制均值；truth/pred 都得到 (n_cells, n_genes)
delta_t_cells = truth[mask] - ctrl
delta_p_cells = pred[mask]  - ctrl

# 用真实 Δ 的"细胞均值"的绝对值排序，取 Top-K 基因
truth_mean_all = delta_t_cells.mean(axis=0)        # (n_genes,)
order = np.argsort(np.abs(truth_mean_all))[::-1][:TOPK]
genes_sel   = genes[order]
pred_mean   = delta_p_cells.mean(axis=0)[order]    # 预测均值（用于三角标记）
truth_mean  = truth_mean_all[order]                # 真实均值（用于方块标记）
# ======= 画图 =======
fig_h = max(6, TOPK * 0.42)  # 自适应高度
plt.figure(figsize=(8.5, fig_h))

# 箱线图数据（list of 1D arrays），每个基因为一行（Truth 的细胞分布）
data_for_box = [delta_t_cells[:, i] for i in order]
box = plt.boxplot(
    data_for_box,
    vert=False,
    labels=genes_sel,
    whis=1.5,
    showfliers=False,
    patch_artist=True,
    widths=0.6  # 加宽箱体
)

# ——箱线图颜色设置（修改中位线颜色和粗细）——
box_color = '#B0E0E6'  # 淡蓝色箱体
whisker_color = '#4169E1'  # 藏青色须线
median_color = '#228B22'  # 森林绿中位线

for patch in box['boxes']:
    patch.set_facecolor(box_color)  # 箱体填充色
    patch.set_edgecolor(whisker_color)  # 箱体边框色
    patch.set_alpha(0.7)  # 透明度
    patch.set_linewidth(1.5)
for whisker in box['whiskers']:
    whisker.set_color(whisker_color)
    whisker.set_linewidth(1.5)
for cap in box['caps']:
    cap.set_color(whisker_color)
    cap.set_linewidth(1.5)
for median in box['medians']:
    median.set_color(median_color)
    median.set_linewidth(3)  # 加粗中位线

# y 轴位置（箱线图内部是 1..K）
ypos = np.arange(1, len(order) + 1)

# ——只保留Pred Mean的标记点——
pred_marker = plt.plot(pred_mean, ypos, '^', markersize=10, 
                       label="Pred Mean", color='#FF4500',  # 橙红色
                       markerfacecolor='#FF7F50',  # 珊瑚色填充
                       markeredgecolor='#FF4500',  # 橙红色边框
                       linestyle='None',
                       markeredgewidth=1.2)

# x=0 竖直虚线
plt.axvline(0.0, linestyle='--', linewidth=1.5, color="gray", alpha=0.7)

# 坐标轴和标题设置
plt.xlabel("Change in Gene Expression over Control", fontsize=12, labelpad=10)
plt.ylabel("Genes", fontsize=12, labelpad=10)

# 调整坐标轴刻度
plt.tick_params(axis='both', which='major', labelsize=10)

# 创建自定义图例（长方形表示Truth）
from matplotlib.patches import Rectangle

# 创建自定义图例句柄
truth_legend = Rectangle((0,0), 1, 0.5, facecolor='#B0E0E6', edgecolor='#4169E1',
                        linewidth=1.2, label='Truth')
pred_legend = plt.Line2D([], [], marker='^', markersize=10, 
                        color='#FF4500', markerfacecolor='#FF7F50',
                        linestyle='None', label='Pred Mean')

# 添加图例（右上方）
plt.legend(handles=[truth_legend, pred_legend],
           loc='upper right', framealpha=1, edgecolor='black', fontsize=10)

# 设置固定横轴范围[-2,4]
plt.xlim(-2, 5)

# 设置网格线
plt.grid(axis='x', linestyle=':', alpha=0.5, color='#DDDDDD')

# 调整布局和保存
plt.tight_layout()
save_path = os.path.join(OUT, f"{target}_top{TOPK}_enhanced_plot.svg")
plt.savefig(save_path, dpi=300, bbox_inches="tight")
plt.close()  # 关闭图形以释放内存
print("[Saved]", save_path)

  box = plt.boxplot(


[Saved] /home/lxz/scmamba/gene_pretubation/GEARS/results/eval_outputs/BPGM+ZBTB1_top20_enhanced_plot.svg


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# 数据准备
tools = ['KcellFM', 'scfoundation', 'scgpt', 'geneformer']
colors = ['#4e79a7', '#59a14f', '#e15759', '#79706e']  # 协调的颜色组合

# 原始数据
norman_mse = [0.1851, 0.1927, 0.1845, 0.1915]
norman_pcc = [0.6840, 0.6709, 0.6850, 0.6555]
adamson_mse = [0.1641, 0.1749, 0.1669, 0.1692]
adamson_pcc = [0.7886, 0.7858, 0.7842, 0.7882]

# 确定统一的Y轴范围
mse_min = min(min(norman_mse), min(adamson_mse)) - 0.01
mse_max = max(max(norman_mse), max(adamson_mse)) + 0.01

pcc_min = min(min(norman_pcc), min(adamson_pcc)) - 0.01
pcc_max = max(max(norman_pcc), max(adamson_pcc)) + 0.01

# 创建2x2的画布
fig, axs = plt.subplots(2, 2, figsize=(14, 10), dpi=100)
plt.subplots_adjust(hspace=0.3, wspace=0.3)  # 调整子图间距

# --------------------------
# 1. Norman MSE
# --------------------------
axs[0,0].bar(tools, norman_mse, color=colors, edgecolor='black', linewidth=0.5)
axs[0,0].set_title('Norman Dataset - MSE (Lower is Better)', fontweight='bold')
axs[0,0].set_ylim(mse_min, mse_max)  # 使用统一的MSE范围
axs[0,0].grid(axis='y', linestyle='--', alpha=0.6)

# --------------------------
# 2. Norman PCC
# --------------------------
axs[0,1].bar(tools, norman_pcc, color=colors, edgecolor='black', linewidth=0.5)
axs[0,1].set_title('Norman Dataset - PCC', fontweight='bold')
axs[0,1].set_ylim(pcc_min, pcc_max)  # 使用统一的PCC范围
axs[0,1].grid(axis='y', linestyle='--', alpha=0.6)

# --------------------------
# 3. Adamson MSE
# --------------------------
axs[1,0].bar(tools, adamson_mse, color=colors, edgecolor='black', linewidth=0.5)
axs[1,0].set_title('Adamson Dataset - MSE', fontweight='bold')
axs[1,0].set_ylim(mse_min, mse_max)  # 使用统一的MSE范围
axs[1,0].grid(axis='y', linestyle='--', alpha=0.6)

# --------------------------
# 4. Adamson PCC
# --------------------------
axs[1,1].bar(tools, adamson_pcc, color=colors, edgecolor='black', linewidth=0.5)
axs[1,1].set_title('Adamson Dataset - PCC (Higher is Better)', fontweight='bold')
axs[1,1].set_ylim(pcc_min, pcc_max)  # 使用统一的PCC范围
axs[1,1].grid(axis='y', linestyle='--', alpha=0.6)

# 保存为SVG
plt.savefig('performance_breakdown.svg', format='svg', bbox_inches='tight')
plt.show()