In [2]:
import scvi
import scanpy as sc
import anndata
import pandas as pd
import time
from sklearn.metrics import accuracy_score, f1_score, classification_report

# Load the reference (pbmc_1, labelled) and query (pbmc_2, to be annotated) datasets.
REF_PATH = "pbmc_1.h5ad"
QUERY_PATH = "pbmc_2.h5ad"
adata_ref = sc.read_h5ad(REF_PATH)
adata_query = sc.read_h5ad(QUERY_PATH)

In [3]:
# Keep pbmc_2's ground-truth labels so the predictions can be evaluated later.
# Guarded so that re-running this cell does not copy the "Unknown" placeholder
# (assigned just below) into true_label and silently destroy the ground truth.
if "true_label" not in adata_query.obs:
    adata_query.obs["true_label"] = adata_query.obs["cell_type"].astype(str)

# Hide the query labels from the model: "Unknown" is later passed to SCANVI as
# the unlabeled category.
adata_query.obs["cell_type"] = "Unknown"

# Tag each dataset with its origin.
# NOTE: concatenate(batch_key="dataset") below replaces this column with its
# default batch categories ("0" = reference, "1" = query); the query-selection
# cell downstream therefore compares against "1", not "query".
adata_ref.obs["dataset"] = "reference"
adata_query.obs["dataset"] = "query"

# Merge: reference cells first, then query cells.
adata_all = adata_ref.concatenate(adata_query, batch_key="dataset", uns_merge="unique")

  adata_all = adata_ref.concatenate(adata_query, batch_key="dataset", uns_merge="unique")


In [4]:
# Training configuration (hyperparameters previously inlined as magic numbers).
SEED = 0
N_LATENT = 30
SCVI_MAX_EPOCHS = 100
SCANVI_MAX_EPOCHS = 25

# Seed scvi-tools' random number generators so training is reproducible.
scvi.settings.seed = SEED

# Register the merged AnnData for scvi-tools: "cell_type" holds the labels
# ("Unknown" for query cells) and "dataset" is the batch covariate.
scvi.model.SCVI.setup_anndata(adata_all, labels_key="cell_type", batch_key="dataset")

# Pre-train SCVI. perf_counter is monotonic, so durations are not skewed by
# system clock adjustments the way time.time() can be.
start_scvi = time.perf_counter()
model = scvi.model.SCVI(adata_all, n_latent=N_LATENT)
model.train(max_epochs=SCVI_MAX_EPOCHS)
scvi_time = time.perf_counter() - start_scvi

# Initialise SCANVI from the pre-trained SCVI model and fine-tune it; cells
# labelled "Unknown" are treated as unlabeled.
start_scanvi = time.perf_counter()
scanvi_model = scvi.model.SCANVI.from_scvi_model(model, unlabeled_category="Unknown")
scanvi_model.train(max_epochs=SCANVI_MAX_EPOCHS)
scanvi_time = time.perf_counter() - start_scanvi

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB MIG 1g.10gb') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/root/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=2` in the `DataLoader` to improve performance.


Epoch 100/100: 100%|██████████| 100/100 [02:01<00:00,  1.20s/it, v_num=1, train_loss_step=2.93e+3, train_loss_epoch=2.99e+3]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 100/100: 100%|██████████| 100/100 [02:01<00:00,  1.21s/it, v_num=1, train_loss_step=2.93e+3, train_loss_epoch=2.99e+3]
[34mINFO    [0m Training for [1;36m25[0m epochs.                                                                                   


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/root/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=2` in the `DataLoader` to improve performance.


Epoch 25/25: 100%|██████████| 25/25 [00:55<00:00,  2.23s/it, v_num=1, train_loss_step=2.93e+3, train_loss_epoch=2.96e+3]

`Trainer.fit` stopped: `max_epochs=25` reached.


Epoch 25/25: 100%|██████████| 25/25 [00:55<00:00,  2.23s/it, v_num=1, train_loss_step=2.93e+3, train_loss_epoch=2.96e+3]


In [36]:
# Predict a label for every cell in the merged object; the query cells are
# selected from this array in the next cell.
# NOTE(review): the measured time therefore covers reference + query cells.
# If Testing_Time should cover only the query set, predict on the query subset
# instead (and adjust the selection cell accordingly).
start_predict = time.perf_counter()
pred_labels_all = scanvi_model.predict(adata_all)
predict_time = time.perf_counter() - start_predict

# Ground-truth pbmc_2 labels captured before they were masked as "Unknown".
true_labels = adata_query.obs["true_label"]

In [38]:
# Select the predictions for query cells. concatenate(batch_key="dataset")
# relabels the "dataset" column with its default batch categories, where "1"
# is the second object passed in (the query). Convert to a NumPy mask so the
# indexing is purely positional.
query_idx = (adata_all.obs["dataset"] == "1").to_numpy()
# Fail loudly if the batch-label convention changes (e.g. explicit
# batch_categories) instead of silently scoring the wrong cells.
assert query_idx.sum() == adata_query.n_obs, (
    f"query mask selects {query_idx.sum()} cells, expected {adata_query.n_obs}"
)
pred_labels = pred_labels_all[query_idx]

In [44]:
# Persist true/predicted labels and timings for the cross-dataset benchmark.
# Each entry maps an output CSV path to the table written there (no index column).
benchmark_tables = {
    "scANVI_True_Labels_cross.csv": pd.DataFrame(true_labels),
    "scANVI_Pred_Labels_cross.csv": pd.DataFrame(pred_labels),
    "scANVI_Training_Times_cross.csv": pd.DataFrame({"Training_Time": [scvi_time + scanvi_time]}),
    "scANVI_Testing_Times_cross.csv": pd.DataFrame({"Testing_Time": [predict_time]}),
}
for csv_path, table in benchmark_tables.items():
    table.to_csv(csv_path, index=False)

In [None]:
# Score the query predictions against the held-out pbmc_2 labels.
acc = accuracy_score(true_labels, pred_labels)
f1 = f1_score(true_labels, pred_labels, average="weighted")

for metric_label, metric_value in [("Accuracy", acc), ("F1 Score", f1)]:
    print(f"{metric_label}: {metric_value:.4f}")

# Per-class precision / recall / F1 with support counts.
print(classification_report(true_labels, pred_labels))


=== scANVI 1/5训练结果 ===
Accuracy: 0.9387
F1 Score: 0.9382
              precision    recall  f1-score   support

      B cell       0.99      1.00      0.99       250
  CD4 T cell       0.96      0.96      0.96      1238
  CD8 T cell       0.85      0.91      0.88       676
     NK cell       0.98      0.77      0.86       270
 Plasma cell       1.00      0.50      0.67         6
         cDC       0.94      0.85      0.89        20
       cMono       0.99      0.98      0.99       409
      ncMono       0.94      0.99      0.96       119
         pDC       1.00      0.83      0.91        12

    accuracy                           0.94      3000
   macro avg       0.96      0.87      0.90      3000
weighted avg       0.94      0.94      0.94      3000

