In [1]:
import os
import scanpy as sc
import scvi
import json
from sklearn.model_selection import train_test_split
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# All inputs and generated splits live under one local directory.
data_dir = "data_input"
os.makedirs(data_dir, exist_ok=True)

# File locations: the full download followed by the three derived splits.
pancreas_adata_path, train_path, valid_path, test_path = (
    os.path.join(data_dir, fname)
    for fname in (
        "pancreas_full.h5ad",
        "pancreas_train.h5ad",
        "pancreas_valid.h5ad",
        "pancreas_test.h5ad",
    )
)

# Download if missing, otherwise load from the local file.
PANCREAS_URL = "https://figshare.com/ndownloader/files/24539828"
pancreas_adata = sc.read(pancreas_adata_path, backup_url=PANCREAS_URL)

# Split dataset by technology: the held-out technologies form the test (query) set.
HELDOUT_TECHS = ["smartseq2", "celseq2"]
SPLIT_SEED = 42

query_mask = pancreas_adata.obs["tech"].isin(HELDOUT_TECHS).to_numpy()
pancreas_no_test = pancreas_adata[~query_mask].copy()
pancreas_test    = pancreas_adata[ query_mask].copy()

# 80/20 train/valid split on the remaining data, stratified by technology so each
# reference technology is represented proportionally in both splits.
y = pancreas_no_test.obs["tech"].astype("category")
indices = np.arange(pancreas_no_test.n_obs)

idx_train, idx_valid = train_test_split(
    indices,
    test_size=0.20,
    train_size=0.80,
    random_state=SPLIT_SEED,
    shuffle=True,
    stratify=y,  # stratify by technology
)

pancreas_train = pancreas_no_test[idx_train].copy()
pancreas_valid = pancreas_no_test[idx_valid].copy()

# Sanity checks: the splits partition the full dataset, train/valid do not overlap,
# and no held-out technology leaks into the reference (train/valid) data.
assert pancreas_no_test.n_obs + pancreas_test.n_obs == pancreas_adata.n_obs
assert len(idx_train) + len(idx_valid) == pancreas_no_test.n_obs
assert np.intersect1d(idx_train, idx_valid).size == 0
assert not pancreas_train.obs["tech"].isin(HELDOUT_TECHS).any()
assert not pancreas_valid.obs["tech"].isin(HELDOUT_TECHS).any()

# One row per split: display label, AnnData object, output file (write order preserved).
split_table = [
    ("Train", pancreas_train, train_path),
    ("Valid", pancreas_valid, valid_path),
    ("Test", pancreas_test, test_path),
]

# Persist every split to its own .h5ad file.
for _, split_adata, out_path in split_table:
    split_adata.write(out_path)

# One-line size summary, e.g. "Train: N cells | Valid: N cells | Test: N cells".
print(" | ".join(f"{label}: {split_adata.n_obs} cells"
                 for label, split_adata, _ in split_table))

# Per-technology breakdown of each split, technologies listed alphabetically.
print("\nCells per technology:")
for label, split_adata, _ in split_table:
    per_tech = split_adata.obs["tech"].value_counts().sort_index()
    print(f"\n{label} split:")
    for tech_name, n_cells in per_tech.items():
        print(f"  {tech_name}: {n_cells}")

# --- Cleanup: optionally delete the original full dataset file ---
# Deleting the download means a fresh "Restart & Run All" has to fetch the ~300 MB
# file again. Set to False to keep the local copy cached between runs.
DELETE_FULL_DATASET = True

del pancreas_adata  # drop the in-memory full object; only the split copies are kept
if DELETE_FULL_DATASET:
    try:
        # Remove directly instead of exists()-then-remove(), which is racy.
        os.remove(pancreas_adata_path)
        print(f"Deleted '{pancreas_adata_path}'")
    except FileNotFoundError:
        pass  # already gone: nothing to clean up
    except OSError as e:
        # Best-effort cleanup: report but do not abort the notebook.
        print(f"[WARN] Could not delete '{pancreas_adata_path}': {e}")

100%|██████████| 301M/301M [00:21<00:00, 14.5MB/s] 


Train: 9362 cells | Valid: 2341 cells | Test: 4679 cells

Cells per technology:

Train split:
  celseq: 803
  fluidigmc1: 510
  inDrop1: 1550
  inDrop2: 1379
  inDrop3: 2884
  inDrop4: 1042
  smarter: 1194

Valid split:
  celseq: 201
  fluidigmc1: 128
  inDrop1: 387
  inDrop2: 345
  inDrop3: 721
  inDrop4: 261
  smarter: 298

Test split:
  celseq2: 2285
  smartseq2: 2394
Deleted 'data_input/pancreas_full.h5ad'


In [None]:
# Utility to load HVG list
def load_hvg_list(hvg_list_path: str) -> list:
    """Load a highly-variable-gene list from a JSON file.

    Args:
        hvg_list_path: Path to a JSON file whose top-level value is a list of
            gene names.

    Returns:
        The list of gene names, in file order.

    Raises:
        TypeError: If the JSON payload is not a list of strings, which would not
            be usable as a gene (column) indexer.
    """
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(hvg_list_path, encoding="utf-8") as f:
        genes = json.load(f)
    if not isinstance(genes, list) or not all(isinstance(g, str) for g in genes):
        raise TypeError(
            f"Expected a JSON list of gene-name strings in '{hvg_list_path}', "
            f"got {type(genes).__name__}"
        )
    return genes

# NOTE(review): hvg_list.json is an external input that no cell in this notebook
# creates. Document where it comes from so the notebook reproduces from a clean checkout.
hvg_list_path = os.path.join(data_dir, "hvg_list.json")
hvg_list = load_hvg_list(hvg_list_path)

# Fail early with an actionable message if the list names genes absent from the data.
missing_genes = [g for g in hvg_list if g not in pancreas_train.var_names]
if missing_genes:
    raise KeyError(
        f"{len(missing_genes)} HVG(s) from '{hvg_list_path}' not found in "
        f"pancreas_train.var_names, e.g. {missing_genes[:5]}"
    )

# Restrict to HVG genes.
# NOTE(review): only the train split is subset here. The valid/test splits were written
# to disk with all genes, so confirm downstream code applies the same HVG list.
pancreas_train = pancreas_train[:, hvg_list].copy()

pancreas_train
AnnData object with n_obs × n_vars = 9362 × 2000
    obs: 'tech', 'celltype', 'size_factors'
    layers: 'counts'


In [17]:
# Quick inspection of the training AnnData: expression matrix, cell metadata, gene metadata.
inspection_sections = [
    ("pancreas_train", [pancreas_train.X, pancreas_train.X.shape]),
    ("\n pancreas_train obs", [pancreas_train.obs.head(), pancreas_train.obs.shape]),
    ("\n pancreas_train.obs.columns", [pancreas_train.obs.columns]),
    ("\npancreas_train.var.head()", [pancreas_train.var.head()]),
]
for section_title, section_items in inspection_sections:
    print(section_title)
    for section_item in section_items:
        print(section_item)

pancreas_train
[[0.         0.         0.         ... 0.         0.86844474 0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]
 ...
 [0.         0.         0.         ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]]
(9362, 2000)

 pancreas_train obs
                                   tech celltype  size_factors  _scvi_batch  \
3rd-C86_S85                  fluidigmc1    delta      5.060723            1   
human3_lib4.final_cell_0804     inDrop3    alpha      0.010361            4   
human3_lib4.final_cell_0815     inDrop3     beta      0.011553            4   
Sample_163                      smarter     beta      1.000000            6   
human3_lib1.final_cell_0737     inDrop3    alpha      0.014493            4   

                             _sc

## Train the scVI model


In [7]:
# Seed scvi-tools' global RNG state so weight initialisation and minibatch shuffling
# are reproducible across "Restart & Run All". This matches the split seed (42) above.
scvi.settings.seed = 42
MAX_EPOCHS = 50

# Register the data with scvi-tools: `tech` is the batch covariate, and the model reads
# its input from the "counts" layer rather than `.X`.
# NOTE(review): this call emits a warning from `validate_field`, and training emits
# warnings at the `log_prob(x)` line. Together they may indicate that layers["counts"]
# does not hold raw integer counts. Verify this before trusting the fitted model.
scvi.model.SCVI.setup_anndata(pancreas_train, batch_key="tech", layer="counts")

scvi_ref = scvi.model.SCVI(
    pancreas_train,
    use_layer_norm="both",
    use_batch_norm="none",
    encode_covariates=True,
    dropout_rate=0.2,
    n_layers=2,
)
scvi_ref.train(max_epochs=MAX_EPOCHS)

  self.validate_field(adata)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/katwre/miniconda3/envs/fl-course-env/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 1/50:   0%|          | 0/50 [00:00<?, ?it/s]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 2/50:   2%|▏         | 1/50 [00:02<01:57,  2.40s/it, v_num=1, train_loss_step=962, train_loss_epoch=1.31e+3]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 3/50:   4%|▍         | 2/50 [00:04<01:47,  2.24s/it, v_num=1, train_loss_step=1.08e+3, train_loss_epoch=1.05e+3]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 4/50:   6%|▌         | 3/50 [00:06<01:43,  2.20s/it, v_num=1, train_loss_step=977, train_loss_epoch=971]        

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 5/50:   8%|▊         | 4/50 [00:08<01:41,  2.20s/it, v_num=1, train_loss_step=853, train_loss_epoch=932]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 6/50:  10%|█         | 5/50 [00:11<01:42,  2.28s/it, v_num=1, train_loss_step=973, train_loss_epoch=911]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 7/50:  12%|█▏        | 6/50 [00:13<01:41,  2.30s/it, v_num=1, train_loss_step=856, train_loss_epoch=896]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 8/50:  14%|█▍        | 7/50 [00:16<01:42,  2.38s/it, v_num=1, train_loss_step=944, train_loss_epoch=886]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 9/50:  16%|█▌        | 8/50 [00:18<01:45,  2.52s/it, v_num=1, train_loss_step=857, train_loss_epoch=877]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 10/50:  18%|█▊        | 9/50 [00:21<01:45,  2.57s/it, v_num=1, train_loss_step=945, train_loss_epoch=871]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 11/50:  20%|██        | 10/50 [00:24<01:40,  2.50s/it, v_num=1, train_loss_step=993, train_loss_epoch=866]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 12/50:  22%|██▏       | 11/50 [00:26<01:38,  2.53s/it, v_num=1, train_loss_step=886, train_loss_epoch=862]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 13/50:  24%|██▍       | 12/50 [00:29<01:37,  2.56s/it, v_num=1, train_loss_step=926, train_loss_epoch=857]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 14/50:  26%|██▌       | 13/50 [00:31<01:34,  2.54s/it, v_num=1, train_loss_step=804, train_loss_epoch=854]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 15/50:  28%|██▊       | 14/50 [00:34<01:28,  2.47s/it, v_num=1, train_loss_step=893, train_loss_epoch=851]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 16/50:  30%|███       | 15/50 [00:36<01:25,  2.44s/it, v_num=1, train_loss_step=845, train_loss_epoch=848]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 17/50:  32%|███▏      | 16/50 [00:38<01:20,  2.38s/it, v_num=1, train_loss_step=792, train_loss_epoch=845]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 18/50:  34%|███▍      | 17/50 [00:41<01:18,  2.39s/it, v_num=1, train_loss_step=757, train_loss_epoch=842]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 19/50:  36%|███▌      | 18/50 [00:43<01:16,  2.40s/it, v_num=1, train_loss_step=818, train_loss_epoch=840]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 20/50:  38%|███▊      | 19/50 [00:45<01:14,  2.39s/it, v_num=1, train_loss_step=857, train_loss_epoch=837]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 21/50:  40%|████      | 20/50 [00:48<01:11,  2.39s/it, v_num=1, train_loss_step=777, train_loss_epoch=835]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 22/50:  42%|████▏     | 21/50 [00:50<01:09,  2.40s/it, v_num=1, train_loss_step=826, train_loss_epoch=833]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 23/50:  44%|████▍     | 22/50 [00:53<01:07,  2.42s/it, v_num=1, train_loss_step=928, train_loss_epoch=831]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 24/50:  46%|████▌     | 23/50 [00:55<01:06,  2.48s/it, v_num=1, train_loss_step=875, train_loss_epoch=829]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 25/50:  48%|████▊     | 24/50 [00:58<01:08,  2.63s/it, v_num=1, train_loss_step=928, train_loss_epoch=827]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 26/50:  50%|█████     | 25/50 [01:01<01:05,  2.60s/it, v_num=1, train_loss_step=792, train_loss_epoch=826]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 27/50:  52%|█████▏    | 26/50 [01:03<01:02,  2.61s/it, v_num=1, train_loss_step=781, train_loss_epoch=824]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 28/50:  54%|█████▍    | 27/50 [01:06<00:59,  2.58s/it, v_num=1, train_loss_step=759, train_loss_epoch=823]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 29/50:  56%|█████▌    | 28/50 [01:08<00:55,  2.52s/it, v_num=1, train_loss_step=923, train_loss_epoch=821]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 30/50:  58%|█████▊    | 29/50 [01:11<00:52,  2.49s/it, v_num=1, train_loss_step=749, train_loss_epoch=820]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 31/50:  60%|██████    | 30/50 [01:13<00:49,  2.48s/it, v_num=1, train_loss_step=789, train_loss_epoch=818]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 32/50:  62%|██████▏   | 31/50 [01:16<00:47,  2.48s/it, v_num=1, train_loss_step=870, train_loss_epoch=817]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 33/50:  64%|██████▍   | 32/50 [01:18<00:44,  2.45s/it, v_num=1, train_loss_step=969, train_loss_epoch=816]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 34/50:  66%|██████▌   | 33/50 [01:20<00:41,  2.42s/it, v_num=1, train_loss_step=810, train_loss_epoch=815]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 35/50:  68%|██████▊   | 34/50 [01:23<00:39,  2.46s/it, v_num=1, train_loss_step=732, train_loss_epoch=814]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 36/50:  70%|███████   | 35/50 [01:25<00:36,  2.46s/it, v_num=1, train_loss_step=742, train_loss_epoch=813]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 37/50:  72%|███████▏  | 36/50 [01:28<00:34,  2.49s/it, v_num=1, train_loss_step=756, train_loss_epoch=812]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 38/50:  74%|███████▍  | 37/50 [01:31<00:32,  2.51s/it, v_num=1, train_loss_step=902, train_loss_epoch=811]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 39/50:  76%|███████▌  | 38/50 [01:33<00:30,  2.56s/it, v_num=1, train_loss_step=748, train_loss_epoch=810]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 40/50:  78%|███████▊  | 39/50 [01:36<00:29,  2.64s/it, v_num=1, train_loss_step=827, train_loss_epoch=809]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 41/50:  80%|████████  | 40/50 [01:39<00:26,  2.66s/it, v_num=1, train_loss_step=781, train_loss_epoch=809]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 42/50:  82%|████████▏ | 41/50 [01:41<00:23,  2.64s/it, v_num=1, train_loss_step=751, train_loss_epoch=808]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 43/50:  84%|████████▍ | 42/50 [01:44<00:21,  2.66s/it, v_num=1, train_loss_step=785, train_loss_epoch=807]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 44/50:  86%|████████▌ | 43/50 [01:47<00:18,  2.65s/it, v_num=1, train_loss_step=735, train_loss_epoch=806]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 45/50:  88%|████████▊ | 44/50 [01:49<00:15,  2.66s/it, v_num=1, train_loss_step=871, train_loss_epoch=805]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 46/50:  90%|█████████ | 45/50 [01:52<00:13,  2.70s/it, v_num=1, train_loss_step=764, train_loss_epoch=805]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 47/50:  92%|█████████▏| 46/50 [01:54<00:10,  2.57s/it, v_num=1, train_loss_step=941, train_loss_epoch=804]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 48/50:  94%|█████████▍| 47/50 [01:57<00:07,  2.52s/it, v_num=1, train_loss_step=732, train_loss_epoch=804]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 49/50:  96%|█████████▌| 48/50 [01:59<00:04,  2.46s/it, v_num=1, train_loss_step=840, train_loss_epoch=803]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 50/50:  98%|█████████▊| 49/50 [02:01<00:02,  2.41s/it, v_num=1, train_loss_step=883, train_loss_epoch=802]

  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)


Epoch 50/50: 100%|██████████| 50/50 [02:04<00:00,  2.43s/it, v_num=1, train_loss_step=788, train_loss_epoch=802]

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 50/50: 100%|██████████| 50/50 [02:04<00:00,  2.49s/it, v_num=1, train_loss_step=788, train_loss_epoch=802]


In [8]:
# Persist the trained reference model, replacing any earlier save in the same directory.
MODEL_DIR = "model_centralized"
scvi_ref.save(MODEL_DIR, overwrite=True)