In [1]:

import os
import argparse
import numpy as np
import torch
import pandas as pd
import random
import time
from tqdm import tqdm
from pathlib import Path

from model import CTM
from dataset import CTMDataset
from utils.cell_embedding import generate_cell_embedding as generate_cell_embedding_new


--------------------------------------------------------------------------------

  CuPy may not function correctly because multiple CuPy packages are installed
  in your environment:

    cupy-cuda101, cupy-cuda117

  Follow these steps to resolve this issue:

    1. For all packages listed above, run the following command to remove all
       existing CuPy installations:

         $ pip uninstall <package_name>

      If you previously installed CuPy via conda, also run the following:

         $ conda uninstall cupy

    2. Install the appropriate CuPy package.
       Refer to the Installation Guide for detailed instructions.

         https://docs.cupy.dev/en/stable/install.html

--------------------------------------------------------------------------------



In [2]:
def set_seed(seed):
    """
    Seed the random number generators used in this notebook for repeatable runs.

    Applies ``seed`` to Python's ``random`` module, NumPy's global RNG and
    PyTorch (``manual_seed``, ``cuda.manual_seed``, ``cuda.manual_seed_all``),
    sets ``cudnn.deterministic = True`` / ``cudnn.benchmark = False``, and
    stores the seed in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Seed value passed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # for multi-GPU setups
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)

In [3]:
def prepare_dataset(adata_path, gene_embedding_path, device, cell_embed_epochs=300, cell_embed_lr=1e-4):
    """
    Build the CTM training dataset from an h5ad file and a gene-embedding file.

    Cell embeddings and the gene-count object restricted to shared genes are
    obtained from ``generate_cell_embedding_new``; they are then wrapped in a
    ``CTMDataset`` together with an index -> gene-name vocabulary.

    Args:
        adata_path: Path to the h5ad file.
        gene_embedding_path: Path to the gene-embedding file.
        device: Device to use. NOTE: currently not referenced by this function;
            kept so existing callers keep working.
        cell_embed_epochs: Number of cell-embedding training epochs.
            NOTE: currently not referenced by this function.
        cell_embed_lr: Cell-embedding learning rate.
            NOTE: currently not referenced by this function.

    Returns:
        tuple: ``(train_dataset, gene_counts_common, cell_embeddings)`` where
        ``train_dataset`` is the ``CTMDataset``, ``gene_counts_common`` is the
        gene-count object returned by the embedding helper, and
        ``cell_embeddings`` is the embedding tensor exactly as returned by it.
    """
    print(f"正在处理数据集: {adata_path}")

    # Generate the cell embeddings.
    print("生成细胞嵌入...")
    cell_embeddings, gene_counts_common, _ = generate_cell_embedding_new(
        adata_path=adata_path,
        gene_embedding_path=gene_embedding_path
    )

    # Build the vocabulary in a single pass (the gene-name list used to be
    # re-materialised once per gene, which is quadratic in the gene count).
    num_genes = gene_counts_common.n_vars
    id2token = dict(enumerate(gene_counts_common.var_names))

    # .cpu() is a no-op for CPU tensors and is required before .numpy() when
    # the embeddings live on a CUDA device.
    cell_embeddings_numpy = cell_embeddings.detach().cpu().numpy()

    # The count matrix may be sparse (has .toarray()) or already dense.
    counts = gene_counts_common.X
    bow_matrix = counts.toarray() if hasattr(counts, "toarray") else np.asarray(counts)

    train_dataset = CTMDataset(
        X_contextual=cell_embeddings_numpy,
        X_bow=bow_matrix,
        idx2token=id2token
    )

    print(f"数据集准备完成，基因数: {num_genes}, 细胞数: {len(cell_embeddings)}")

    return train_dataset, gene_counts_common, cell_embeddings

def train_model(train_dataset, cell_embeddings, num_topics, batch_size, lr, num_epochs,
                patience, device, prior_mean=None, prior_variance=None, seed=42):
    """
    Construct a CTM topic model and fit it on the given dataset.

    Args:
        train_dataset: Training dataset; its ``X_bow.shape[1]`` sets the
            bag-of-words size.
        cell_embeddings: Cell embeddings; ``shape[1]`` sets the contextual size.
        num_topics: Number of topics (passed as ``n_components``).
        batch_size: Batch size.
        lr: Learning rate.
        num_epochs: Number of training epochs.
        patience: Early-stopping patience forwarded to ``fit``.
        device: Device to train on.
        prior_mean: Prior mean forwarded to the model.
        prior_variance: Prior variance forwarded to the model.
        seed: Random seed, forwarded to both the constructor and ``fit``.

    Returns:
        tuple: ``(ctm, prior_mean, prior_variance)`` - the fitted model and the
        two values returned by ``ctm.fit``.
    """
    print("开始训练主题模型...")

    # Collect the constructor arguments first, then build the model.
    model_kwargs = {
        "contextual_size": cell_embeddings.shape[1],
        "bow_size": train_dataset.X_bow.shape[1],
        "n_components": num_topics,
        "batch_size": batch_size,
        "device": device,
        "lr": lr,
        "num_epochs": num_epochs,
        "prior_mean": prior_mean,
        "prior_variance": prior_variance,
        "seed": seed,
    }
    ctm = CTM(**model_kwargs)

    # Fit; the model hands back a (mean, variance) pair.
    learned_mean, learned_variance = ctm.fit(
        train_dataset, return_mean=False, patience=patience, seed=seed
    )

    print("模型训练完成")

    return ctm, learned_mean, learned_variance

def save_results(ctm, train_dataset, gene_counts_common, model_name, train_order, data_name,
                num_topics, output_dir, checkpoint_dir):
    """
    Save the model checkpoint plus the cell-topic and gene-topic matrices.

    Two CSV files are written into ``output_dir`` (created if missing):

    * ``{model_name}_{train_order}_{data_name}_t{num_topics}_c.csv`` -
      ``ctm.get_thetas(train_dataset)``, written without an index column.
    * ``{model_name}_{train_order}_{data_name}_t{num_topics}_g.csv`` -
      the transpose of ``ctm.get_topic_word_matrix()``, indexed by
      ``gene_counts_common.var_names``.

    Args:
        ctm: Trained model exposing ``save``, ``get_thetas`` and
            ``get_topic_word_matrix``.
        train_dataset: Dataset passed to ``ctm.get_thetas``.
        gene_counts_common: Gene-count object; only ``var_names`` is read.
        model_name: Model name used in the file names.
        train_order: Training order used in the file names.
        data_name: Dataset name used in the file names.
        num_topics: Number of topics used in the file names.
        output_dir: Directory for the CSV files.
        checkpoint_dir: Directory handed to ``ctm.save`` as ``model_dir``.
    """
    # Shared prefix for the checkpoint and both CSV files.
    run_tag = f"{model_name}_{train_order}_{data_name}"
    print(f"保存模型和结果: {run_tag}")

    # Make sure the CSV target exists, so results from a long training run are
    # not lost to a missing directory.
    os.makedirs(output_dir, exist_ok=True)

    # Save the model checkpoint.
    ctm.save(model_dir=checkpoint_dir, part_name=run_tag)

    # Cell-topic matrix.
    df_topics_per_cell = pd.DataFrame(ctm.get_thetas(train_dataset))
    cell_output_file = os.path.join(output_dir, f"{run_tag}_t{num_topics}_c.csv")
    df_topics_per_cell.to_csv(cell_output_file, index=False)
    print(f"细胞-主题矩阵已保存至: {cell_output_file}")

    # Gene-topic matrix: transpose the topic-word matrix so rows are genes.
    df_gene_topic = pd.DataFrame(ctm.get_topic_word_matrix()).T
    df_gene_topic.index = gene_counts_common.var_names
    gene_output_file = os.path.join(output_dir, f"{run_tag}_t{num_topics}_g.csv")
    df_gene_topic.to_csv(gene_output_file)
    print(f"基因-主题矩阵已保存至: {gene_output_file}")

In [4]:

# Seed every RNG before any data preparation or training happens. set_seed is
# defined earlier in this notebook; 3407 is the same seed the training cell
# passes to train_model.
set_seed(3407)

# Initial prior hyperparameters; the training cell passes these to train_model.
prior_mean = 0.0
prior_variance = None

# Build the training dataset from the PBMC h5ad file and the fused gene embeddings.
train_dataset, gene_counts_common, cell_embeddings = prepare_dataset(
    adata_path="/volume1/home/pxie/data/PBMC.h5ad",
    gene_embedding_path="/volume1/home/pxie/data/embeddings/fused_geneformerv2_genePT.pkl",
    device="cuda"
)

正在处理数据集: /volume1/home/pxie/data/PBMC.h5ad
生成细胞嵌入...
AnnData中的基因数: 36263
基因嵌入中的基因数: 20271
共有基因数: 19132
数据集准备完成，基因数: 19132, 细胞数: 66985


In [None]:
# Train a 50-topic model on the prepared dataset.
# NOTE: prior_mean / prior_variance are both inputs to and outputs of this
# call, so re-running this cell feeds the values returned by the previous run
# back in as priors - presumably intentional for sequential training
# (train_order); confirm before re-running.
ctm, prior_mean, prior_variance = train_model(
    train_dataset=train_dataset,
    cell_embeddings=cell_embeddings,
    num_topics=50,
    batch_size=1024,
    lr=2e-3,
    num_epochs=300,
    patience=10,
    device="cuda:0",
    prior_mean=prior_mean,
    prior_variance=prior_variance,
    seed=3407
)

# Save results: model checkpoint plus cell-topic and gene-topic CSV files.
# num_topics here must match the value passed to train_model above.
save_results(
    ctm=ctm,
    train_dataset=train_dataset,
    gene_counts_common=gene_counts_common,
    model_name='fusion_gpt_gfv2',
    train_order=1,
    data_name='PBMC',
    num_topics=50,
    output_dir="/volume1/home/pxie/topic_model/solution4/results",
    checkpoint_dir="/volume1/home/pxie/topic_model/solution4/results/checkpoint"
)

开始训练主题模型...
Settings: 
                N Components: 50
                Topic Prior Mean: 0.0
                Topic Prior Variance: None
                Model Type: prodLDA
                Hidden Sizes: (100, 100)
                Activation: softplus
                Dropout: 0.2
                Learn Priors: True
                Learning Rate: 0.002
                Momentum: 0.99
                Reduce On Plateau: False
                Save Dir: None


0it [00:00, ?it/s]

epoch: 0 loss: 4390.235614483173 best_loss: 4390.235614483173
epoch: 1 loss: 4337.835201322116 best_loss: 4337.835201322116
epoch: 2 loss: 4331.097896634616 best_loss: 4331.097896634616
epoch: 3 loss: 4326.9749399038465 best_loss: 4326.9749399038465
epoch: 4 loss: 4323.48251953125 best_loss: 4323.48251953125
epoch: 5 loss: 4321.752411358173 best_loss: 4321.752411358173
epoch: 6 loss: 4319.4267578125 best_loss: 4319.4267578125
epoch: 7 loss: 4319.096221454327 best_loss: 4319.096221454327
epoch: 8 loss: 4318.141180889423 best_loss: 4318.141180889423
epoch: 9 loss: 4315.889926382211 best_loss: 4315.889926382211
epoch: 10 loss: 4314.270395132212 best_loss: 4314.270395132212
epoch: 11 loss: 4296.521146334135 best_loss: 4296.521146334135
epoch: 12 loss: 4278.837154447116 best_loss: 4278.837154447116
epoch: 13 loss: 4261.7213792067305 best_loss: 4261.7213792067305
epoch: 14 loss: 4246.530506310096 best_loss: 4246.530506310096
epoch: 15 loss: 4242.146213942307 best_loss: 4242.146213942307
epoc