In [1]:
import os
import torch
import anndata as ad
import scanpy as sc
from torch.utils.tensorboard import SummaryWriter

# 导入我们的模块
from models.cl_scetm import CL_scETM
from trainers.cl_scETM_trainer import CL_scETM_Trainer
from data.preprocess import preprocess_data, setup_anndata,read_data
import logging

# Configure logging so the INFO messages emitted by the project modules
# (data.preprocess, trainers.*) are shown in the notebook output.
# `force=True` (Python 3.8+) removes any handlers already attached to the root
# logger before configuring it. Without it, basicConfig is a silent no-op
# (the INFO level is not applied either) whenever the kernel or a library
# imported above has already attached a root handler.
logging.basicConfig(level=logging.INFO, force=True)
logger = logging.getLogger(__name__)


--------------------------------------------------------------------------------

  CuPy may not function correctly because multiple CuPy packages are installed
  in your environment:

    cupy-cuda101, cupy-cuda117

  Follow these steps to resolve this issue:

    1. For all packages listed above, run the following command to remove all
       existing CuPy installations:

         $ pip uninstall <package_name>

      If you previously installed CuPy via conda, also run the following:

         $ conda uninstall cupy

    2. Install the appropriate CuPy package.
       Refer to the Installation Guide for detailed instructions.

         https://docs.cupy.dev/en/stable/install.html

--------------------------------------------------------------------------------

2025-05-23 03:58:41.343353: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To tu

In [2]:
# Load the PBMC dataset. The location can be overridden with the PBMC_H5AD
# environment variable so the notebook is not tied to one machine; the
# default keeps the original path.
DATA_PATH = os.environ.get('PBMC_H5AD', '/volume1/home/pxie/data/PBMC.h5ad')
adata = read_data(file_path=DATA_PATH)

# Preprocess the data: normalize and log-transform (no scaling), and keep the
# top 2000 highly variable genes. min_cells / min_genes are presumably the
# gene/cell filtering thresholds (the log shows the cell count dropping).
adata = preprocess_data(
    adata,
    normalize=True,
    log_transform=True,
    scale=False,
    min_cells=3,
    min_genes=200,
    hvg_selection=True,
    n_top_genes=2000,
)

# Cheap guard: downstream model construction needs a non-empty matrix.
assert adata.n_obs > 0 and adata.n_vars > 0, 'preprocessing left an empty AnnData'

INFO:data.preprocess:读取了66985个细胞，36263个基因
INFO:data.preprocess:预处理后数据：66944个细胞，2000个基因


In [3]:
# Tell setup_anndata which .obs columns hold the batch and cell-type labels.
adata = setup_anndata(
    adata,
    batch_col='batch',
    cell_type_col='cell_type',
)



In [4]:
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One batch per distinct value of obs['batch']; fall back to a single batch
# when the column is absent.
if 'batch' in adata.obs:
    n_batches = adata.obs['batch'].nunique()
else:
    n_batches = 1

# CL-scETM topic model with 50 topics, one hidden layer of width 128,
# 400-dimensional gene embeddings and the standard prior.
model = CL_scETM(
    n_genes=adata.n_vars,
    n_topics=50,
    hidden_sizes=[128],
    gene_emb_dim=400,
    bn=True,
    dropout_prob=0.1,
    n_batches=n_batches,
    normalize_beta=False,
    input_batch_id=True,
    enable_batch_bias=True,
    enable_global_bias=False,
    prior_type='standard',  # use the standard prior
    device=device,
)

In [None]:
    # Trainer for the standard-prior PBMC run. Checkpoints go to a
    # timestamped subdirectory of ckpt_dir named after train_instance_name.
    # test_ratio=0 presumably holds out no cells, which is why the run
    # reports its best test NLL as inf; use a positive ratio for held-out
    # evaluation.
    # NOTE(review): this cell is indented at top level; IPython strips the
    # common leading indent, but exported as a .py script it would raise
    # IndentationError.
    trainer = CL_scETM_Trainer(
        model=model,
        adata=adata,
        train_instance_name='pbmc_standard',
        ckpt_dir='./saved_models/standard_prior',
        seed=42,
        learning_rate=5e-3,
        batch_size=1024,
        test_ratio=0,
    )

INFO:trainers.trainer_utils:设置种子为 42。
INFO:trainers.cl_scETM_trainer:检查点目录: ./saved_models/standard_prior/pbmc_standard_05_23-03_59_03


In [6]:
    # Train for 100 epochs. Per the log below, metrics are reported every
    # 10 epochs and a checkpoint is written every 20. num_workers=0
    # presumably loads minibatches in the main process.
    history = trainer.train(
        n_epochs=100,
        batch_col='batch',
        eval_every=10,
        save_every=20,
        num_workers=0,
    )

INFO:trainers.cl_scETM_trainer:开始训练任务 0


loss:      562.2	nll:      552.7	kl:      31.64	Epoch     9/  100	Next ckpt:      10

INFO:trainers.cl_scETM_trainer:
INFO:trainers.cl_scETM_trainer:学习率: 0.005000, KL权重: 0.300000
INFO:trainers.trainer_utils:loss        :      562.9
INFO:trainers.trainer_utils:nll         :      554.1
INFO:trainers.trainer_utils:kl          :  6.858e+05


loss:      566.4	nll:        559	kl:      11.64	Epoch    19/  100	Next ckpt:      20

INFO:trainers.cl_scETM_trainer:
INFO:trainers.cl_scETM_trainer:学习率: 0.005000, KL权重: 0.633333
INFO:trainers.trainer_utils:loss        :      564.9
INFO:trainers.trainer_utils:nll         :      556.9
INFO:trainers.trainer_utils:kl          :      17.52
INFO:trainers.trainer_utils:检查点已保存到 ./saved_models/standard_prior/pbmc_standard_05_23-03_59_03


loss:      568.6	nll:      560.9	kl:       8.02	Epoch    29/  100	Next ckpt:      30

INFO:trainers.cl_scETM_trainer:
INFO:trainers.cl_scETM_trainer:学习率: 0.005000, KL权重: 0.966667
INFO:trainers.trainer_utils:loss        :      567.6
INFO:trainers.trainer_utils:nll         :      560.2
INFO:trainers.trainer_utils:kl          :      9.273


loss:      568.4	nll:      560.7	kl:      7.705	Epoch    39/  100	Next ckpt:      40

INFO:trainers.cl_scETM_trainer:
INFO:trainers.cl_scETM_trainer:学习率: 0.005000, KL权重: 1.000000
INFO:trainers.trainer_utils:loss        :      568.6
INFO:trainers.trainer_utils:nll         :      560.9
INFO:trainers.trainer_utils:kl          :      7.727
INFO:trainers.trainer_utils:检查点已保存到 ./saved_models/standard_prior/pbmc_standard_05_23-03_59_03


loss:      568.1	nll:      560.5	kl:      7.645	Epoch    49/  100	Next ckpt:      50

INFO:trainers.cl_scETM_trainer:
INFO:trainers.cl_scETM_trainer:学习率: 0.005000, KL权重: 1.000000
INFO:trainers.trainer_utils:loss        :      568.3
INFO:trainers.trainer_utils:nll         :      560.6
INFO:trainers.trainer_utils:kl          :      7.679


loss:        568	nll:      560.3	kl:      7.633	Epoch    59/  100	Next ckpt:      60

INFO:trainers.cl_scETM_trainer:
INFO:trainers.cl_scETM_trainer:学习率: 0.005000, KL权重: 1.000000
INFO:trainers.trainer_utils:loss        :      568.1
INFO:trainers.trainer_utils:nll         :      560.4
INFO:trainers.trainer_utils:kl          :      7.643
INFO:trainers.trainer_utils:检查点已保存到 ./saved_models/standard_prior/pbmc_standard_05_23-03_59_03


loss:      567.8	nll:      560.2	kl:      7.583	Epoch    69/  100	Next ckpt:      70

INFO:trainers.cl_scETM_trainer:
INFO:trainers.cl_scETM_trainer:学习率: 0.005000, KL权重: 1.000000
INFO:trainers.trainer_utils:loss        :      567.9
INFO:trainers.trainer_utils:nll         :      560.3
INFO:trainers.trainer_utils:kl          :      7.603


loss:      567.6	nll:      560.1	kl:      7.542	Epoch    79/  100	Next ckpt:      80

INFO:trainers.cl_scETM_trainer:
INFO:trainers.cl_scETM_trainer:学习率: 0.005000, KL权重: 1.000000
INFO:trainers.trainer_utils:loss        :      567.7
INFO:trainers.trainer_utils:nll         :      560.1
INFO:trainers.trainer_utils:kl          :      7.567
INFO:trainers.trainer_utils:检查点已保存到 ./saved_models/standard_prior/pbmc_standard_05_23-03_59_03


loss:      567.5	nll:      559.9	kl:      7.567	Epoch    89/  100	Next ckpt:      90

INFO:trainers.cl_scETM_trainer:
INFO:trainers.cl_scETM_trainer:学习率: 0.005000, KL权重: 1.000000
INFO:trainers.trainer_utils:loss        :      567.6
INFO:trainers.trainer_utils:nll         :        560
INFO:trainers.trainer_utils:kl          :      7.557


loss:      567.5	nll:      559.9	kl:      7.553	Epoch    99/  100	Next ckpt:     100

INFO:trainers.cl_scETM_trainer:
INFO:trainers.cl_scETM_trainer:学习率: 0.005000, KL权重: 1.000000
INFO:trainers.trainer_utils:loss        :      567.5
INFO:trainers.trainer_utils:nll         :      559.9
INFO:trainers.trainer_utils:kl          :      7.566
INFO:trainers.trainer_utils:检查点已保存到 ./saved_models/standard_prior/pbmc_standard_05_23-03_59_03
INFO:trainers.cl_scETM_trainer:训练完成！最佳测试NLL: inf
