In [1]:
%matplotlib inline

import os
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from scMVP.dataset import LoadData,GeneExpressionDataset, CellMeasurement
from scMVP.models import VAE_Attention, Multi_VAE_Attention, VAE_Peak_SelfAttention
from scMVP.inference import UnsupervisedTrainer
from scMVP.inference import MultiPosterior, MultiTrainer
import torch

import scanpy as sc
import anndata

import scipy.io as sp_io
from scipy.sparse import csr_matrix, issparse

[2021-11-23 10:13:25,778] INFO - scMVP._settings | Added StreamHandler with custom formatter to 'scMVP' logger.


In [12]:
# Limit PyTorch to 20 CPU threads for this session.
torch.set_num_threads(20)

# Keep the CUDA check as the LAST expression of the cell so the notebook
# displays its result; as a mid-cell statement the value was discarded.
torch.cuda.is_available()


In [3]:
# Directory holding the 10x lymph-node input files, and the directory that
# later cells write model checkpoints and CSV exports into.
input_path = "./10x_lymph_node/"
output_path = "./10x_lymph_node/scMVP_output"

# torch.save / DataFrame.to_csv (used by later cells) do not create missing
# parent directories, so make sure the output directory exists up front.
os.makedirs(output_path, exist_ok=True)

# File names handed to LoadData together with data_path=input_path.
# RNA and ATAC use the same barcode file.
lymph_node_dataset = {
                "gene_names": '10x_lymph_node_scale_gene.txt',
                "gene_expression": '10x_lymph_node_rna_normalize_count.mtx',
                "gene_barcodes": '10x_lymph_node_cell_barcode.txt',
                "atac_names": '10x_lymph_node_peak.txt',
                "atac_expression": '10x_lymph_node_atac_normalize_count.mtx',
                "atac_barcodes": '10x_lymph_node_cell_barcode.txt'
                }

# NOTE(review): atac_threshold / cell_threshold are presumably filtering
# cut-offs applied inside LoadData -- confirm against the scMVP documentation.
dataset = LoadData(dataset=lymph_node_dataset,data_path=input_path,
                       dense=False,gzipped=False, atac_threshold=0.001,
                       cell_threshold=1)

[2021-11-23 10:20:39,072] INFO - scMVP.dataset.scMVP_dataloader | Preprocessing joint profiling dataset.
[2021-11-23 10:21:20,840] INFO - scMVP.dataset.scMVP_dataloader | Finished preprocessing dataset
[2021-11-23 10:21:20,927] INFO - scMVP.dataset.dataset | Remapping batch_indices to [0,N]
[2021-11-23 10:21:20,928] INFO - scMVP.dataset.dataset | Remapping labels to [0,N]
[2021-11-23 10:22:48,105] INFO - scMVP.dataset.dataset | Computing the library size for the new data
[2021-11-23 10:22:48,150] INFO - scMVP.dataset.dataset | Downsampled from 7039 to 7039 cells


In [10]:
# --- ATAC-only dataset -------------------------------------------------------
# Wrap the peak matrix from `dataset` in its own GeneExpressionDataset so it
# can be used on its own (peaks are passed in the "gene" slots).
cell_attributes_dict = dict(barcodes=dataset.barcodes)
atac_dataset = GeneExpressionDataset()
atac_dataset.populate_from_data(
    X=dataset.atac_expression,  # original author's note: notice the normalization
    batch_indices=None,
    gene_names=dataset.atac_names,
    cell_attributes_dict=cell_attributes_dict,
    Ys=[],
)

# --- RNA dataset carrying the ATAC matrix as an extra measurement ------------
# The ATAC matrix and peak names are attached as a CellMeasurement named
# "atac_expression", with its column names stored under "atac_names".
measurement = CellMeasurement(
    name="atac_expression",
    data=atac_dataset.X,
    columns_attr_name="atac_names",
    columns=atac_dataset.gene_names,
)
Ys = [measurement]

# A fresh attribute dict is built for the RNA dataset (not shared with ATAC).
cell_attributes_dict = dict(barcodes=dataset.barcodes)
rna_dataset = GeneExpressionDataset()
rna_dataset.populate_from_data(
    X=dataset.X,
    batch_indices=None,
    gene_names=dataset.gene_names,
    cell_attributes_dict=cell_attributes_dict,
    Ys=Ys,
)

[2021-11-23 10:26:31,524] INFO - scMVP.dataset.dataset | Remapping batch_indices to [0,N]
[2021-11-23 10:26:31,526] INFO - scMVP.dataset.dataset | Remapping labels to [0,N]
[2021-11-23 10:26:32,349] INFO - scMVP.dataset.dataset | Remapping batch_indices to [0,N]
[2021-11-23 10:26:32,350] INFO - scMVP.dataset.dataset | Remapping labels to [0,N]


In [11]:
# Run pre-training

# model para
n_epochs = 15
lr = 5e-3
# use_batches, n_centroids and n_alfa are not referenced in this cell;
# presumably they are consumed by later cells -- kept unchanged.
use_batches = False
use_cuda = True
n_centroids = 15
n_alfa = 1.0

# ATAC peak embedding
pre_atac_vae = VAE_Peak_SelfAttention(atac_dataset.nb_genes, n_latent=20,n_batch=0, n_layers=1, log_variational=True, reconstruction_loss="nb")
pre_atac_trainer = UnsupervisedTrainer(
    pre_atac_vae,
    atac_dataset,
    train_size=0.9,
    use_cuda=use_cuda,
    frequency=5,
)
is_test_pragram = False

# Single source of truth for the checkpoint path. The original existence check
# looked for 'pre_atac_trainerk.pkl' (typo) while loading/saving
# 'pre_atac_trainer.pkl', so the cached weights were never reused and the model
# was retrained on every run.
pre_atac_checkpoint = '%s/pre_atac_trainer.pkl' % output_path

if os.path.isfile(pre_atac_checkpoint):
    # Reuse previously trained weights.
    pre_atac_trainer.model.load_state_dict(torch.load(pre_atac_checkpoint))
else:
    # No checkpoint yet: train from scratch and cache the weights.
    pre_atac_trainer.train(n_epochs=n_epochs, lr=lr)
    torch.save(pre_atac_trainer.model.state_dict(), pre_atac_checkpoint)

# Either way, switch to evaluation mode before the model is used for inference.
pre_atac_trainer.model.eval()

training:   0%|          | 0/15 [00:00<?, ?it/s]reconst_loss=9125.333008, kl_divergence=0.088799
tensor(9126.3330, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=9254.258789, kl_divergence=0.420087
tensor(9255.2588, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8314.435547, kl_divergence=1.114696
tensor(8315.4355, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=9850.153320, kl_divergence=2.569014
tensor(9851.1533, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=9102.714844, kl_divergence=6.883841
tensor(9103.7148, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8570.944336, kl_divergence=16.938183
tensor(8571.9443, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=9065.055664, kl_divergence=37.709003
tensor(9066.0557, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=9515.861328, kl_divergence=63.096493
tensor(9516.8613, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8638.613281, kl_divergence=65.691856
tensor(8639.6133, device='cuda:0', 

reconst_loss=7503.891113, kl_divergence=62.380787
tensor(7504.8911, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7432.861328, kl_divergence=63.501621
tensor(7433.8613, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7322.632324, kl_divergence=63.107399
tensor(7323.6323, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8581.403320, kl_divergence=64.828659
tensor(8582.4033, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7171.101562, kl_divergence=61.345520
tensor(7172.1016, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8265.664062, kl_divergence=60.925644
tensor(8266.6641, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6788.920410, kl_divergence=61.713516
tensor(6789.9199, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6804.696777, kl_divergence=63.772766
tensor(6805.6968, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7492.392090, kl_divergence=63.288494
tensor(7493.3921, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7032.8

reconst_loss=7099.422363, kl_divergence=56.246025
tensor(7100.5625, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6946.735352, kl_divergence=56.739662
tensor(6947.8770, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7762.263184, kl_divergence=57.017845
tensor(7763.4062, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7834.073242, kl_divergence=54.276646
tensor(7835.2090, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8470.613281, kl_divergence=53.376984
tensor(8471.7461, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7830.607910, kl_divergence=51.728401
tensor(7831.7373, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8073.495605, kl_divergence=51.186691
tensor(8074.6230, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7806.469727, kl_divergence=52.307518
tensor(7807.5996, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8787.669922, kl_divergence=53.337708
tensor(8788.8047, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8252.5

reconst_loss=7700.290039, kl_divergence=56.522793
tensor(7701.5732, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7488.291016, kl_divergence=56.591419
tensor(7489.5742, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6995.848633, kl_divergence=56.945526
tensor(6997.1338, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6205.057617, kl_divergence=55.650185
tensor(6206.3354, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7618.054199, kl_divergence=54.632168
tensor(7619.3267, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6987.916504, kl_divergence=54.523010
tensor(6989.1890, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7872.727539, kl_divergence=54.084591
tensor(7873.9985, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7522.890625, kl_divergence=53.704330
tensor(7524.1582, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7570.966797, kl_divergence=53.063278
tensor(7572.2324, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6851.3

reconst_loss=7468.663086, kl_divergence=57.513840
tensor(7470.0942, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7772.580078, kl_divergence=57.963665
tensor(7774.0146, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6607.936035, kl_divergence=57.007160
tensor(6609.3633, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6699.752441, kl_divergence=57.919670
tensor(6701.1865, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7174.629883, kl_divergence=58.674229
tensor(7176.0698, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8105.976074, kl_divergence=61.306488
tensor(8107.4365, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6640.676758, kl_divergence=60.300327
tensor(6642.1289, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8279.327148, kl_divergence=59.776009
tensor(8280.7754, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8124.650391, kl_divergence=60.153618
tensor(8126.1011, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7226.9

reconst_loss=7151.356445, kl_divergence=61.536282
tensor(7152.8184, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6829.212402, kl_divergence=61.712273
tensor(6830.6753, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8397.724609, kl_divergence=60.837944
tensor(8399.1807, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7012.485352, kl_divergence=60.614082
tensor(7013.9404, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6873.400879, kl_divergence=61.376766
tensor(6874.8613, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7229.198242, kl_divergence=63.189636
tensor(7230.6719, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6373.608398, kl_divergence=63.007992
tensor(6375.0806, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7337.974121, kl_divergence=64.441193
tensor(7339.4580, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8432.971680, kl_divergence=61.876900
tensor(8434.4355, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7413.4

reconst_loss=6798.239746, kl_divergence=66.558762
tensor(6799.9053, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6938.499512, kl_divergence=68.333908
tensor(6940.1826, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7552.681641, kl_divergence=68.305038
tensor(7554.3652, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8005.400879, kl_divergence=67.971832
tensor(8007.0811, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6809.913086, kl_divergence=69.419441
tensor(6811.6064, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7880.894531, kl_divergence=67.925186
tensor(7882.5737, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6357.948242, kl_divergence=67.476883
tensor(6359.6230, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7581.928711, kl_divergence=66.371078
tensor(7583.5923, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7685.395020, kl_divergence=68.583549
tensor(7687.0811, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7331.5

reconst_loss=6967.847656, kl_divergence=66.570343
reconst_loss=8064.604980, kl_divergence=63.631210
reconst_loss=7976.362305, kl_divergence=63.863068
reconst_loss=7435.815430, kl_divergence=64.767120
reconst_loss=6872.533691, kl_divergence=64.552582
reconst_loss=7806.462891, kl_divergence=63.914764
reconst_loss=7216.583008, kl_divergence=65.755005
reconst_loss=8338.306641, kl_divergence=67.091896
reconst_loss=7065.178711, kl_divergence=65.788307
reconst_loss=7290.307617, kl_divergence=65.592850
reconst_loss=7562.718750, kl_divergence=64.631500
reconst_loss=7156.977539, kl_divergence=68.457008
reconst_loss=6964.546387, kl_divergence=64.252609
reconst_loss=6704.056641, kl_divergence=66.268188
reconst_loss=7917.177734, kl_divergence=65.331169
reconst_loss=7744.875977, kl_divergence=64.066269
reconst_loss=9146.562500, kl_divergence=64.021957
reconst_loss=6925.099121, kl_divergence=66.840706
reconst_loss=8581.476562, kl_divergence=60.665192
reconst_loss=7169.574219, kl_divergence=60.957199


reconst_loss=8098.712891, kl_divergence=60.711941
tensor(8100.4717, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8612.494141, kl_divergence=60.856247
tensor(8614.2539, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7750.712891, kl_divergence=60.824921
tensor(7752.4736, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7466.687500, kl_divergence=61.610870
tensor(7468.4575, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6953.307617, kl_divergence=62.033188
tensor(6955.0830, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7040.375000, kl_divergence=63.003162
tensor(7042.1631, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7027.770996, kl_divergence=62.648750
tensor(7029.5537, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6931.020020, kl_divergence=62.989456
tensor(6932.8071, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8577.841797, kl_divergence=62.741074
tensor(8579.6250, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=5544.7

reconst_loss=8044.187500, kl_divergence=63.451416
tensor(8046.1392, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6914.953125, kl_divergence=64.356483
tensor(6916.9185, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8501.041992, kl_divergence=62.114643
tensor(8502.9746, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8277.297852, kl_divergence=60.990158
tensor(8279.2129, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7971.204102, kl_divergence=60.244141
tensor(7973.1074, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7809.367676, kl_divergence=60.302925
tensor(7811.2725, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7334.660156, kl_divergence=59.233452
tensor(7336.5479, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7483.663574, kl_divergence=58.592644
tensor(7485.5420, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6956.183594, kl_divergence=57.870247
tensor(6958.0513, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7828.3

reconst_loss=8802.341797, kl_divergence=65.225296
tensor(8804.4824, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7942.704590, kl_divergence=64.437881
tensor(7944.8320, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6776.020020, kl_divergence=63.089905
tensor(6778.1240, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6868.167969, kl_divergence=64.354218
tensor(6870.2944, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7945.777344, kl_divergence=62.442787
tensor(7947.8701, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7890.671387, kl_divergence=62.887543
tensor(7892.7715, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7909.082031, kl_divergence=65.513756
tensor(7911.2275, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6647.298828, kl_divergence=64.490845
tensor(6649.4268, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6242.041016, kl_divergence=65.838737
tensor(6244.1924, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7593.4

reconst_loss=7930.202637, kl_divergence=56.916389
tensor(7932.1997, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7087.449707, kl_divergence=57.476372
tensor(7089.4546, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7156.979004, kl_divergence=57.493355
tensor(7158.9854, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6589.139648, kl_divergence=57.367058
tensor(6591.1436, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8364.535156, kl_divergence=59.335197
tensor(8366.5732, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7940.079102, kl_divergence=58.099617
tensor(7942.0957, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7516.470703, kl_divergence=58.391163
tensor(7518.4927, device='cuda:0', grad_fn=<DivBackward0>)
training:  80%|████████  | 12/15 [01:19<00:15,  5.29s/it]reconst_loss=7402.672852, kl_divergence=59.323116
tensor(7404.8594, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7870.414062, kl_divergence=57.655369
tensor(7872.5674, dev

reconst_loss=6560.112793, kl_divergence=65.205780
tensor(6562.4165, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7524.262207, kl_divergence=64.764397
tensor(7526.5571, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6761.152344, kl_divergence=64.595306
tensor(6763.4438, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6619.282227, kl_divergence=65.341629
tensor(6621.5884, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7338.866211, kl_divergence=63.386799
tensor(7341.1338, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6939.128906, kl_divergence=63.615997
tensor(6941.4009, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7561.195312, kl_divergence=64.667084
tensor(7563.4883, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7007.985840, kl_divergence=63.677078
tensor(7010.2593, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7552.124512, kl_divergence=62.586838
tensor(7554.3765, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6684.7

reconst_loss=7763.691406, kl_divergence=59.419273
tensor(7766.0278, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6867.221191, kl_divergence=59.353104
tensor(6869.5562, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8829.545898, kl_divergence=60.373680
tensor(8831.9043, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7662.160645, kl_divergence=60.392036
tensor(7664.5200, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7420.485352, kl_divergence=61.794186
tensor(7422.8755, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7155.470703, kl_divergence=62.522823
tensor(7157.8770, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7906.008789, kl_divergence=62.967926
tensor(7908.4253, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7957.361816, kl_divergence=64.036606
tensor(7959.8027, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7124.333496, kl_divergence=64.137199
tensor(7126.7764, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7492.1

reconst_loss=7592.467773, kl_divergence=65.712341
reconst_loss=6582.620117, kl_divergence=64.932266
reconst_loss=6270.321777, kl_divergence=63.757896
reconst_loss=7019.556152, kl_divergence=62.698845
reconst_loss=8720.160156, kl_divergence=62.724091
reconst_loss=7988.744141, kl_divergence=63.857063
reconst_loss=7611.309570, kl_divergence=64.029541
reconst_loss=7795.665039, kl_divergence=64.285782
reconst_loss=8339.416992, kl_divergence=65.088226
reconst_loss=7766.720703, kl_divergence=64.192421
reconst_loss=9024.149414, kl_divergence=64.039864
reconst_loss=8202.297852, kl_divergence=62.909859
reconst_loss=7963.249512, kl_divergence=63.819252
reconst_loss=7491.624023, kl_divergence=63.360199
reconst_loss=7177.293945, kl_divergence=63.474998
reconst_loss=6915.468262, kl_divergence=63.392860
reconst_loss=7802.977539, kl_divergence=66.138199
reconst_loss=7604.133789, kl_divergence=63.819717
reconst_loss=8649.816406, kl_divergence=64.225914
reconst_loss=7507.886230, kl_divergence=62.407387


reconst_loss=7114.973633, kl_divergence=53.524864
tensor(7117.3120, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6654.471191, kl_divergence=55.353588
tensor(6656.8550, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7985.455078, kl_divergence=54.591732
tensor(7987.8198, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6161.244141, kl_divergence=56.110970
tensor(6163.6470, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7295.147461, kl_divergence=57.194633
tensor(7297.5776, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7213.246094, kl_divergence=55.907288
tensor(7215.6440, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7828.100098, kl_divergence=56.586311
tensor(7830.5146, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7867.722656, kl_divergence=58.062653
tensor(7870.1743, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8756.939453, kl_divergence=58.840057
tensor(8759.4111, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6882.0

reconst_loss=7828.501953, kl_divergence=61.856403
tensor(7831.2036, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7670.869141, kl_divergence=60.703327
tensor(7673.5381, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7420.067871, kl_divergence=61.363628
tensor(7422.7554, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6622.731445, kl_divergence=60.637264
tensor(6625.3994, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7343.918945, kl_divergence=62.460060
tensor(7346.6367, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7675.187500, kl_divergence=61.997025
tensor(7677.8921, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7664.990234, kl_divergence=61.431114
tensor(7667.6797, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8114.624023, kl_divergence=61.967228
tensor(8117.3276, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6783.685059, kl_divergence=63.651489
tensor(6786.4355, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7365.6

training: 17it [01:43,  4.85s/it]reconst_loss=7610.520996, kl_divergence=60.596703
tensor(7613.3384, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7681.781738, kl_divergence=60.784561
tensor(7684.6050, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6521.374512, kl_divergence=62.670807
tensor(6524.2539, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6659.321777, kl_divergence=61.364220
tensor(6662.1626, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6518.274414, kl_divergence=64.589340
tensor(6521.2114, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7764.369629, kl_divergence=65.763039
tensor(7767.3428, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8767.560547, kl_divergence=67.643143
tensor(8770.5889, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6891.672852, kl_divergence=67.928513
tensor(6894.7104, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6816.081055, kl_divergence=68.481949
tensor(6819.1357, device='cuda:0', grad_fn=<D

reconst_loss=6954.800293, kl_divergence=67.027039
tensor(6957.8110, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7693.083984, kl_divergence=66.547638
tensor(7696.0806, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7496.486816, kl_divergence=65.711578
tensor(7499.4580, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7602.765625, kl_divergence=64.774803
tensor(7605.7090, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7588.247070, kl_divergence=64.057800
tensor(7591.1680, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6776.112305, kl_divergence=64.342339
tensor(6779.0430, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6889.112305, kl_divergence=62.289360
tensor(6891.9805, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7967.641602, kl_divergence=61.976830
tensor(7970.5010, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7133.604492, kl_divergence=60.517677
tensor(7136.4204, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7159.4

reconst_loss=7359.721191, kl_divergence=63.896454
tensor(7362.7979, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8170.255371, kl_divergence=63.788914
tensor(8173.3291, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6592.618164, kl_divergence=62.507854
tensor(6595.6504, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6760.808594, kl_divergence=63.012611
tensor(6763.8564, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8320.501953, kl_divergence=62.357727
tensor(8323.5273, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8135.930664, kl_divergence=63.111504
tensor(8138.9819, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7814.694336, kl_divergence=62.670647
tensor(7817.7305, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6043.122070, kl_divergence=62.060371
tensor(6046.1387, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7377.748047, kl_divergence=64.649200
tensor(7380.8491, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6649.4

reconst_loss=8114.720215, kl_divergence=65.332687
tensor(8118.0063, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7050.853027, kl_divergence=65.142593
tensor(7054.1328, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6871.723633, kl_divergence=67.401260
tensor(6875.0830, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6543.266113, kl_divergence=67.379013
tensor(6546.6240, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=6844.592773, kl_divergence=66.743065
tensor(6847.9287, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7366.905762, kl_divergence=67.190140
tensor(7370.2568, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7649.646484, kl_divergence=66.991989
tensor(7652.9912, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7068.676758, kl_divergence=66.879311
tensor(7072.0176, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=7503.826172, kl_divergence=66.883377
tensor(7507.1670, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=8334.4

reconst_loss=8666.126953, kl_divergence=60.213333
reconst_loss=7788.877441, kl_divergence=59.433041
reconst_loss=7367.659180, kl_divergence=60.113060
reconst_loss=7381.947266, kl_divergence=59.518669
reconst_loss=7884.723633, kl_divergence=58.959480
reconst_loss=6783.906250, kl_divergence=59.966465
reconst_loss=6728.690430, kl_divergence=58.821411
reconst_loss=7360.670898, kl_divergence=59.344955
reconst_loss=7049.729004, kl_divergence=58.971390
reconst_loss=7120.522949, kl_divergence=60.104393
reconst_loss=7059.621582, kl_divergence=60.960823
reconst_loss=6900.122559, kl_divergence=59.350082
reconst_loss=7694.797363, kl_divergence=59.727440
reconst_loss=7664.945312, kl_divergence=59.333626
reconst_loss=7927.193359, kl_divergence=60.638489
reconst_loss=7352.181641, kl_divergence=59.952328
reconst_loss=7376.677734, kl_divergence=60.264072
reconst_loss=7250.594727, kl_divergence=61.569405
reconst_loss=7037.785156, kl_divergence=60.288170
reconst_loss=7238.524902, kl_divergence=59.409405


reconst_loss=6663.971680, kl_divergence=58.759277
reconst_loss=8322.383789, kl_divergence=60.952389
reconst_loss=6816.692383, kl_divergence=58.955139
reconst_loss=7132.244141, kl_divergence=59.328072
reconst_loss=7848.956055, kl_divergence=59.670311
reconst_loss=7804.445312, kl_divergence=59.842842
reconst_loss=6927.117676, kl_divergence=59.617317
reconst_loss=6685.281250, kl_divergence=58.133545
reconst_loss=7110.363281, kl_divergence=60.212921
reconst_loss=8880.710938, kl_divergence=59.623146
reconst_loss=7321.100586, kl_divergence=59.124451
reconst_loss=7755.806641, kl_divergence=59.271194
reconst_loss=7150.927734, kl_divergence=60.584282
reconst_loss=6960.537109, kl_divergence=59.817207
reconst_loss=7465.416016, kl_divergence=60.564529
reconst_loss=8460.965820, kl_divergence=59.960686
reconst_loss=6072.417969, kl_divergence=58.583290
reconst_loss=7182.586914, kl_divergence=58.631878
reconst_loss=7882.950195, kl_divergence=59.910378
reconst_loss=6863.925293, kl_divergence=60.236115


In [13]:
# ATAC pretrainer_posterior:
# Build a posterior over every cell of the ATAC dataset and extract the latent
# embedding from the pre-trained ATAC model.
full = pre_atac_trainer.create_posterior(pre_atac_trainer.model, atac_dataset, indices=np.arange(len(atac_dataset)))
latent, batch_indices, labels = full.sequential().get_latent()
batch_indices = batch_indices.ravel()

# AnnData container: raw ATAC matrix as X, latent embedding in obsm.
prior_adata = anndata.AnnData(X=atac_dataset.X)
prior_adata.obsm["X_multi_vi"] = latent
# Store one flat label per cell. The original wrapped the labels in an (n, 1)
# torch tensor, which is unnecessary for a DataFrame column and fragile across
# pandas versions; a 1-D NumPy array carries the same values.
prior_adata.obs['cell_type'] = np.asarray(labels).ravel()

# Neighbour graph and UMAP computed on the latent space.
sc.pp.neighbors(prior_adata, use_rep="X_multi_vi", n_neighbors=30)
sc.tl.umap(prior_adata, min_dist=0.3)
#matplotlib.use('TkAgg')
#fig, ax = plt.subplots(figsize=(7, 6))
#sc.pl.umap(prior_adata, color=["cell_type"])
#plt.show()

# Louvain clustering on the same graph; clusters are used to colour the UMAP.
sc.tl.louvain(prior_adata)
sc.pl.umap(prior_adata, color=['louvain'])
plt.show()

# save data as csv file
# UMAP coordinates per barcode, with the Louvain cluster as first column.
df = pd.DataFrame(data=prior_adata.obsm["X_umap"],  columns=["umap_dim1","umap_dim2"] , index=atac_dataset.barcodes )
df.insert(0,"labels",prior_adata.obs['louvain'].values)
df.to_csv(os.path.join(output_path,"scvi_atac_umap.csv"))

# Latent embedding per barcode (despite "imputation" in the file name, this
# file holds the latent matrix).
df = pd.DataFrame(data=prior_adata.obsm["X_multi_vi"],  index=atac_dataset.barcodes)
df.to_csv(os.path.join(output_path,"scvi_latent_atac_imputation.csv"))

# Imputed ATAC values, transposed to peaks x cells.
# NOTE(review): this frame is built but never written to disk in this cell --
# confirm whether an export is intended.
imputed_values = full.sequential().imputation()
df = pd.DataFrame(data=imputed_values.T, columns=atac_dataset.barcodes, index=atac_dataset.gene_names)

In [14]:
# RNA embedding
pre_vae = VAE_Attention(rna_dataset.nb_genes, n_latent=20,n_batch=0, n_layers=1, log_variational=True, reconstruction_loss="nb")
pre_trainer = UnsupervisedTrainer(
    pre_vae,
    rna_dataset,
    train_size=0.9,
    use_cuda=use_cuda,
    frequency=5,
)

# Set is_test_pragram to True to force retraining even if a checkpoint exists.
is_test_pragram = False
pre_trainer_checkpoint = '%s/pre_trainer.pkl' % output_path

if is_test_pragram or not os.path.isfile(pre_trainer_checkpoint):
    # Train from scratch and cache the weights. (The original retrained and
    # then immediately reloaded the file it had just written when the flag was
    # set; the reload was redundant.)
    pre_trainer.train(n_epochs=n_epochs, lr=lr)
    torch.save(pre_trainer.model.state_dict(), pre_trainer_checkpoint)
else:
    # Reuse previously trained weights.
    pre_trainer.model.load_state_dict(torch.load(pre_trainer_checkpoint))

# Always switch to evaluation mode before inference. The original only did so
# on the training path, so a model restored from the checkpoint was never
# explicitly put into eval mode (unlike the ATAC pre-training cell).
pre_trainer.model.eval()

# RNA pretrainer_posterior:
# Posterior over every cell of the RNA dataset: latent embedding + imputation.
full = pre_trainer.create_posterior(pre_trainer.model, rna_dataset, indices=np.arange(len(rna_dataset)))
latent, batch_indices, labels = full.sequential().get_latent()
batch_indices = batch_indices.ravel()
imputed_values = full.sequential().imputation()

# Imputed expression, transposed to genes x cells. The export below is
# currently disabled, so this frame is not written to disk.
df = pd.DataFrame(data=imputed_values.T, columns=rna_dataset.barcodes, index=rna_dataset.gene_names)
#df.to_csv(os.path.join(save_path,"gene_scvi_imputation_210324_2.csv"))

# visualization
prior_adata = anndata.AnnData(X=rna_dataset.X)
prior_adata.obsm["X_multi_vi"] = latent
# One flat label per cell; a 1-D NumPy array replaces the original (n, 1)
# torch tensor, which is unnecessary and fragile as a DataFrame column.
prior_adata.obs['cell_type'] = np.asarray(labels).ravel()
sc.pp.neighbors(prior_adata, use_rep="X_multi_vi", n_neighbors=30)
sc.tl.umap(prior_adata, min_dist=0.3)
#matplotlib.use('TkAgg')
#fig, ax = plt.subplots(figsize=(7, 6))
#sc.pl.umap(prior_adata, color=["cell_type"], ax=ax, show=show_plot)
#plt.show()

# Louvain clustering on the latent-space neighbour graph, shown on the UMAP.
sc.tl.louvain(prior_adata)
sc.pl.umap(prior_adata, color=['louvain'])
plt.show()

# save data as csv file
# UMAP coordinates per barcode, with the Louvain cluster as first column.
df = pd.DataFrame(data=prior_adata.obsm["X_umap"],  columns=["umap_dim1","umap_dim2"] , index=rna_dataset.barcodes )
df.insert(0,"labels",prior_adata.obs['louvain'].values)
df.to_csv(os.path.join(output_path,"scvi_umap.csv"))

# Latent embedding per barcode (despite "imputation" in the file name, this
# file holds the latent matrix).
df = pd.DataFrame(data=prior_adata.obsm["X_multi_vi"],  index=rna_dataset.barcodes)
df.to_csv(os.path.join(output_path,"scvi_latent_imputation.csv"))

training:   0%|          | 0/15 [00:00<?, ?it/s]reconst_loss=1199.452026, kl_divergence=0.672067
tensor(1200.4520, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=1148.259644, kl_divergence=0.903855
tensor(1149.2596, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=1070.145752, kl_divergence=1.356137
tensor(1071.1458, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=951.737305, kl_divergence=2.047019
tensor(952.7373, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=926.873535, kl_divergence=3.537813
tensor(927.8735, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=831.660217, kl_divergence=5.588002
tensor(832.6602, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=786.705566, kl_divergence=8.687210
tensor(787.7056, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=779.406311, kl_divergence=14.239708
tensor(780.4063, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=749.899292, kl_divergence=21.569992
tensor(750.8993, device='cuda:0', grad_fn=<DivBa

reconst_loss=634.147766, kl_divergence=76.453796
tensor(635.1478, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=611.898682, kl_divergence=77.181694
tensor(612.8987, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=630.787109, kl_divergence=77.247772
tensor(631.7871, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=635.204590, kl_divergence=75.637360
tensor(636.2046, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=630.183838, kl_divergence=77.242615
tensor(631.1838, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=646.910461, kl_divergence=78.201492
tensor(647.9105, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=616.549927, kl_divergence=77.008278
tensor(617.5499, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=653.387878, kl_divergence=80.244499
tensor(654.3879, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=619.033447, kl_divergence=78.063835
tensor(620.0334, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=632.760010, kl_divergenc

reconst_loss=614.494995, kl_divergence=89.900284
tensor(615.7197, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=585.941895, kl_divergence=91.582039
tensor(587.1709, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=645.009888, kl_divergence=92.352753
tensor(646.2408, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=616.139221, kl_divergence=92.155319
tensor(617.3696, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=617.932861, kl_divergence=91.269798
tensor(619.1610, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=605.965210, kl_divergence=92.021103
tensor(607.1953, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=628.863647, kl_divergence=90.571808
tensor(630.0901, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=626.250366, kl_divergence=92.466263
tensor(627.4816, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=594.143921, kl_divergence=92.974945
tensor(595.3763, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=630.743225, kl_divergenc

reconst_loss=587.886963, kl_divergence=93.481651
tensor(589.3544, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=611.476868, kl_divergence=94.320496
tensor(612.9485, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=608.935913, kl_divergence=96.641953
tensor(610.4191, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=612.629700, kl_divergence=94.716125
tensor(614.1033, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=620.420105, kl_divergence=93.199844
tensor(621.8861, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=624.156799, kl_divergence=94.526382
tensor(625.6295, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=607.006348, kl_divergence=94.011414
tensor(608.4764, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=631.153687, kl_divergence=94.537033
tensor(632.6264, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=625.370483, kl_divergence=95.223854
tensor(626.8466, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=611.212158, kl_divergenc

reconst_loss=616.108154, kl_divergence=92.583191
tensor(617.8026, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=606.214844, kl_divergence=94.723633
tensor(607.9253, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=589.672363, kl_divergence=94.448616
tensor(591.3807, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=624.279602, kl_divergence=94.058838
tensor(625.9850, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=601.731934, kl_divergence=93.757889
tensor(603.4351, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=616.782471, kl_divergence=93.429306
tensor(618.4832, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=569.736511, kl_divergence=92.982529
tensor(571.4338, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=594.019531, kl_divergence=93.302620
tensor(595.7192, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=615.009460, kl_divergence=92.403770
tensor(616.7025, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=587.820068, kl_divergenc

reconst_loss=575.860474, kl_divergence=89.130997
tensor(577.7517, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=600.689209, kl_divergence=91.378418
tensor(602.6030, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=580.778809, kl_divergence=88.000854
tensor(582.6589, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=599.874512, kl_divergence=90.153259
tensor(601.7761, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=571.169495, kl_divergence=88.555824
tensor(573.0551, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=577.805542, kl_divergence=89.514313
tensor(579.7007, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=592.026672, kl_divergence=89.206467
tensor(593.9188, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=593.477783, kl_divergence=86.290367
tensor(595.3407, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=590.888916, kl_divergence=88.479004
tensor(592.7737, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=594.353638, kl_divergenc

reconst_loss=609.198608, kl_divergence=87.311935
tensor(611.0717, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=597.873230, kl_divergence=87.425919
tensor(599.7475, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=598.632935, kl_divergence=85.103485
tensor(600.4840, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=597.935974, kl_divergence=87.865891
tensor(599.8146, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=612.465088, kl_divergence=87.919991
tensor(614.3443, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=597.035034, kl_divergence=86.071136
tensor(598.8958, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=586.355591, kl_divergence=85.865150
tensor(588.2142, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=606.641113, kl_divergence=86.013855
tensor(608.5013, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=602.745117, kl_divergence=87.015724
tensor(604.6153, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=586.145691, kl_divergenc

reconst_loss=601.645386, kl_divergence=85.034622
tensor(603.7083, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=595.926392, kl_divergence=84.503799
tensor(597.9828, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=573.845215, kl_divergence=82.582336
tensor(575.8775, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=606.826111, kl_divergence=82.669968
tensor(608.8594, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=596.669006, kl_divergence=84.369736
tensor(598.7236, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=599.096619, kl_divergence=82.954750
tensor(601.1335, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=594.977051, kl_divergence=84.984856
tensor(597.0393, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=583.385681, kl_divergence=81.786598
tensor(585.4080, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=634.646118, kl_divergence=85.260124
tensor(636.7119, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=588.337891, kl_divergenc

training:  60%|██████    | 9/15 [00:08<00:06,  1.14s/it]reconst_loss=588.115845, kl_divergence=83.148376
tensor(590.3630, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=576.447693, kl_divergence=83.826767
tensor(578.7051, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=574.333435, kl_divergence=80.722092
tensor(576.5442, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=579.784912, kl_divergence=80.536255
tensor(581.9929, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=589.743713, kl_divergence=80.737610
tensor(591.9548, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=585.950623, kl_divergence=80.787033
tensor(588.1625, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=603.648682, kl_divergence=83.081963
tensor(605.8949, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=589.541687, kl_divergence=81.791161
tensor(591.7686, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=589.843079, kl_divergence=81.147659
tensor(592.0603, device='cuda:0', grad_

reconst_loss=620.888428, kl_divergence=80.553963
tensor(623.0967, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=586.160950, kl_divergence=80.580536
tensor(588.3698, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=581.348511, kl_divergence=79.846832
tensor(583.5462, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=582.369873, kl_divergence=81.948410
tensor(584.5991, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=605.212585, kl_divergence=79.499771
tensor(607.4051, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=604.042786, kl_divergence=79.692551
tensor(606.2382, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=595.452759, kl_divergence=81.455185
tensor(597.6746, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=597.111938, kl_divergence=80.316490
tensor(599.3167, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=623.103333, kl_divergence=80.939888
tensor(625.3174, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=599.842896, kl_divergenc

reconst_loss=581.630981, kl_divergence=78.284576
tensor(584.0010, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=597.851440, kl_divergence=77.921661
tensor(600.2151, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=579.187378, kl_divergence=78.671951
tensor(581.5641, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=578.873291, kl_divergence=77.369202
tensor(581.2273, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=602.228638, kl_divergence=79.856850
tensor(604.6261, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=594.306030, kl_divergence=78.283508
tensor(596.6760, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=575.489929, kl_divergence=79.253830
tensor(577.8769, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=587.238770, kl_divergence=77.737579
tensor(589.5992, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=611.181763, kl_divergence=78.114120
tensor(613.5488, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=602.281616, kl_divergenc

reconst_loss=587.531860, kl_divergence=74.480896
tensor(590.0215, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=615.407959, kl_divergence=75.508141
tensor(617.9182, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=573.960571, kl_divergence=76.502007
tensor(576.4906, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=594.460449, kl_divergence=75.761169
tensor(596.9756, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=586.565308, kl_divergence=78.496170
tensor(589.1353, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=609.287476, kl_divergence=77.356354
tensor(611.8345, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=587.210876, kl_divergence=76.766289
tensor(589.7462, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=576.396790, kl_divergence=76.794182
tensor(578.9327, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=593.277405, kl_divergence=74.908173
tensor(595.7756, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=591.119507, kl_divergenc

reconst_loss=599.390625, kl_divergence=73.811073
tensor(602.0514, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=571.858704, kl_divergence=73.134697
tensor(574.5042, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=559.831787, kl_divergence=76.244583
tensor(562.5474, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=591.474365, kl_divergence=74.323486
tensor(594.1466, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=589.327026, kl_divergence=73.209366
tensor(591.9742, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=605.650513, kl_divergence=73.602524
tensor(608.3066, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=580.583008, kl_divergence=73.980194
tensor(583.2476, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=590.562561, kl_divergence=74.135330
tensor(593.2306, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=584.077698, kl_divergence=72.120277
tensor(586.7004, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=571.463074, kl_divergenc

reconst_loss=605.808716, kl_divergence=72.234222
tensor(608.4340, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=592.238892, kl_divergence=73.336319
tensor(594.8890, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=605.604126, kl_divergence=73.285690
tensor(608.2531, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=609.618042, kl_divergence=72.675903
tensor(612.2532, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=583.632507, kl_divergence=73.136711
reconst_loss=588.975098, kl_divergence=71.280342
reconst_loss=583.107239, kl_divergence=73.596176
reconst_loss=566.054688, kl_divergence=72.072601
reconst_loss=563.404907, kl_divergence=71.553787
reconst_loss=588.449036, kl_divergence=70.901993
reconst_loss=578.790527, kl_divergence=72.651764
reconst_loss=573.724854, kl_divergence=71.633636
reconst_loss=574.507324, kl_divergence=71.540390
reconst_loss=570.724243, kl_divergence=71.649902
reconst_loss=598.109558, kl_divergence=71.451622
reconst_loss=560.239014, kl_diver

reconst_loss=594.702759, kl_divergence=74.189987
tensor(597.5575, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=592.616028, kl_divergence=74.425507
tensor(595.4767, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=591.062317, kl_divergence=74.667374
tensor(593.9290, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=583.817810, kl_divergence=75.226578
tensor(586.6985, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=589.666992, kl_divergence=73.648003
tensor(592.5082, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=590.979980, kl_divergence=74.819366
tensor(593.8505, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=602.123657, kl_divergence=75.539078
tensor(605.0122, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=609.011292, kl_divergence=74.384247
tensor(611.8710, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=561.245117, kl_divergence=74.372208
tensor(564.1044, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=572.172729, kl_divergenc

tensor(592.1372, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=562.069397, kl_divergence=72.407066
tensor(565.0605, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=572.331604, kl_divergence=71.581772
tensor(575.3001, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=588.104309, kl_divergence=70.162354
tensor(591.0338, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=597.553650, kl_divergence=72.733719
tensor(600.5538, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=550.746460, kl_divergence=72.824371
tensor(553.7491, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=587.384216, kl_divergence=71.337296
tensor(590.3459, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=603.502258, kl_divergence=71.782211
tensor(606.4763, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=579.148071, kl_divergence=71.026154
tensor(582.1013, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=564.548340, kl_divergence=69.718277
tensor(567.4656, device='cuda:0', gra

reconst_loss=577.181702, kl_divergence=67.372345
tensor(580.0344, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=589.116699, kl_divergence=69.807785
tensor(592.0364, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=595.018921, kl_divergence=69.029594
tensor(597.9173, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=581.613892, kl_divergence=68.651627
tensor(584.5018, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=594.222778, kl_divergence=70.106415
tensor(597.1507, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=588.868530, kl_divergence=70.679153
tensor(591.8122, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=579.394531, kl_divergence=69.056541
tensor(582.2936, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=574.357300, kl_divergence=69.361282
tensor(577.2648, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=583.914978, kl_divergence=69.476379
tensor(586.8256, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=575.721130, kl_divergenc

reconst_loss=600.691650, kl_divergence=67.624939
tensor(603.7205, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=610.960327, kl_divergence=67.841095
tensor(613.9956, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=595.928101, kl_divergence=67.851089
tensor(598.9637, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=575.238647, kl_divergence=66.814590
tensor(578.2430, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=589.493408, kl_divergence=67.100235
tensor(592.5065, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=589.350952, kl_divergence=67.983536
tensor(592.3905, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=583.242554, kl_divergence=67.006233
tensor(586.2527, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=591.763794, kl_divergence=67.221558
tensor(594.7805, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=569.906311, kl_divergence=67.828278
tensor(572.9412, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=591.027832, kl_divergenc

reconst_loss=575.686523, kl_divergence=68.277657
tensor(578.9056, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=588.625122, kl_divergence=66.716614
tensor(591.7935, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=586.984131, kl_divergence=66.586517
tensor(590.1482, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=586.342163, kl_divergence=67.237930
tensor(589.5275, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=589.693726, kl_divergence=66.863983
tensor(592.8668, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=592.592163, kl_divergence=67.610153
tensor(595.7894, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=590.546204, kl_divergence=67.877396
tensor(593.7522, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=574.587769, kl_divergence=68.095505
tensor(577.8009, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=578.404419, kl_divergence=65.455345
tensor(581.5317, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=576.553955, kl_divergenc

reconst_loss=586.361694, kl_divergence=66.929184
tensor(589.7042, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=548.816101, kl_divergence=66.401993
tensor(552.1401, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=560.524597, kl_divergence=67.003502
tensor(563.8697, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=577.932495, kl_divergence=68.262009
tensor(581.3217, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=596.788940, kl_divergence=66.493393
tensor(600.1162, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=572.617493, kl_divergence=66.385956
tensor(575.9410, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=574.688721, kl_divergence=66.482025
tensor(578.0156, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=589.204956, kl_divergence=67.994194
tensor(592.5847, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=575.724304, kl_divergence=66.810455
tensor(579.0626, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=586.235107, kl_divergenc

reconst_loss=585.157715, kl_divergence=65.464127
tensor(588.4490, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=578.559143, kl_divergence=63.421627
tensor(581.7790, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=576.599182, kl_divergence=64.027641
tensor(579.8402, device='cuda:0', grad_fn=<DivBackward0>)
reconst_loss=576.117371, kl_divergence=63.019077
reconst_loss=583.682007, kl_divergence=61.687061
reconst_loss=556.178589, kl_divergence=61.521656
reconst_loss=582.863586, kl_divergence=62.913456
reconst_loss=574.858154, kl_divergence=61.143951
reconst_loss=547.859375, kl_divergence=61.801983
reconst_loss=566.607483, kl_divergence=62.497200
reconst_loss=573.582947, kl_divergence=62.703506
reconst_loss=573.871216, kl_divergence=63.800987
reconst_loss=571.589111, kl_divergence=60.553680
reconst_loss=584.776978, kl_divergence=64.359840
reconst_loss=542.604614, kl_divergence=63.331345
reconst_loss=544.170959, kl_divergence=61.628616
reconst_loss=579.333557, kl_divergence=60.

reconst_loss=573.510376, kl_divergence=62.868362
reconst_loss=554.262085, kl_divergence=60.992725
reconst_loss=568.926025, kl_divergence=62.847252
reconst_loss=575.838013, kl_divergence=60.571449
reconst_loss=568.200439, kl_divergence=61.766697
reconst_loss=559.683838, kl_divergence=61.270321
reconst_loss=569.552795, kl_divergence=61.715530
reconst_loss=565.359619, kl_divergence=64.090630
reconst_loss=569.586304, kl_divergence=61.552177
reconst_loss=572.751343, kl_divergence=62.817585
reconst_loss=568.283569, kl_divergence=60.586136
reconst_loss=567.166382, kl_divergence=61.812176
reconst_loss=575.046326, kl_divergence=61.689789
reconst_loss=563.541260, kl_divergence=60.253616
reconst_loss=540.712769, kl_divergence=60.610336
reconst_loss=578.133667, kl_divergence=62.462742
reconst_loss=565.893433, kl_divergence=62.812019
reconst_loss=579.180725, kl_divergence=62.269154
reconst_loss=594.066040, kl_divergence=60.692383
reconst_loss=571.924011, kl_divergence=59.930397
reconst_loss=561.478

In [15]:
# joint RNA and ATAC embedding
# Build the multi-modal VAE and either restore it from a checkpoint or
# initialise it from the two pre-trained single-modality encoders and train.
# Relies on notebook globals from earlier cells: prior_adata, latent, pre_vae,
# pre_atac_vae, rna_dataset, atac_dataset, n_alfa, use_cuda, lr, output_path.
# NOTE(review): `latent` and `prior_adata` are presumably the RNA embedding and
# its Louvain clustering from the preceding cell - confirm that cell ran last.

# One mixture component per Louvain cluster.
louvain_labels = np.array(prior_adata.obs['louvain'].tolist())
n_centroids = len(np.unique(louvain_labels))

multi_vae = Multi_VAE_Attention(rna_dataset.nb_genes, len(rna_dataset.atac_names), n_batch=0, n_latent=20, n_centroids=n_centroids, n_alfa = n_alfa, mode="mm-vae") # should provide ATAC num, alfa, mode and loss type
trainer = MultiTrainer(
    multi_vae,
    rna_dataset,
    train_size=0.9,
    use_cuda=use_cuda,
    frequency=5,
)

multi_vae_ckpt = '%s/multi_vae.pkl' % output_path

if os.path.isfile(multi_vae_ckpt):
    # NOTE(review): the checkpoint presumably only loads if n_centroids matches
    # the value used when it was saved - confirm if clustering has changed.
    # map_location="cpu" lets a GPU-saved checkpoint be read on any host;
    # load_state_dict copies the values onto the model's own device.
    trainer.model.load_state_dict(torch.load(multi_vae_ckpt, map_location="cpu"))
else:
    # Restore both single-modality models so their encoders can seed the
    # joint model.
    pre_trainer = UnsupervisedTrainer(
        pre_vae,
        rna_dataset,
        train_size=0.9,
        use_cuda=use_cuda,
        frequency=5,
    )
    pre_trainer.model.load_state_dict(torch.load('%s/pre_trainer.pkl' % output_path, map_location="cpu"))

    pre_atac_trainer = UnsupervisedTrainer(
        pre_atac_vae,
        atac_dataset,
        train_size=0.9,
        use_cuda=use_cuda,
        frequency=5,
    )
    pre_atac_trainer.model.load_state_dict(torch.load('%s/pre_atac_trainer.pkl' % output_path, map_location="cpu"))

    # Seed the mixture parameters from the Louvain clustering of `latent`.
    trainer.model.init_gmm_params_with_louvain(latent, louvain_labels.astype(int))

    # Copy the pre-trained encoder weights and keep them trainable.
    trainer.model.RNA_encoder.load_state_dict(pre_trainer.model.z_encoder.state_dict())
    for param in trainer.model.RNA_encoder.parameters():
        param.requires_grad = True
    trainer.model.ATAC_encoder.load_state_dict(pre_atac_trainer.model.z_encoder.state_dict())
    for param in trainer.model.ATAC_encoder.parameters():
        param.requires_grad = True
    # The epoch count is fixed at 15 here rather than taken from n_epochs.
    trainer.train(n_epochs=15, lr=lr)
    torch.save(trainer.model.state_dict(), multi_vae_ckpt)

# Inference mode for everything downstream, whichever branch ran.
trainer.model.eval()

training:   0%|          | 0/15 [00:00<?, ?it/s]logpzc:-287.84307861328125, logqcx:-2.079441547393799
kld_qz_pz = 246.227661,kld_qz_rna = 5933.751953,kld_qz_atac = 18137.398438,kl_divergence = 271.020050,reconst_loss_rna = 3568.731445,        reconst_loss_atac = 857.273010, mu=0.116825, sigma=1.686523
reconst_loss = 4476.372070,kl_divergence_local = 271.020050,kl_weight = 1.000000,loss = 4747.392578
tensor(4747.3926, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-46.061073303222656, logqcx:-2.079441547393799
kld_qz_pz = 9.502743,kld_qz_rna = 4684.460449,kld_qz_atac = 987.074951,kl_divergence = 15.104197,reconst_loss_rna = 3362.215332,        reconst_loss_atac = 831.204224, mu=0.116400, sigma=1.690874
reconst_loss = 4243.787598,kl_divergence_local = 15.104197,kl_weight = 1.000000,loss = 4258.892090
tensor(4258.8921, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-43.807518005371094, logqcx:-2.079441547393799
kld_qz_pz = 9.222754,kld_qz_rna = 4216.265625,kld_qz_atac = 259.255676,kl_di

kld_qz_pz = 7.434896,kld_qz_rna = 1525.923828,kld_qz_atac = 376.645874,kl_divergence = 10.437126,reconst_loss_rna = 1986.721558,        reconst_loss_atac = 1012.398254, mu=0.106982, sigma=1.741041
reconst_loss = 3049.488037,kl_divergence_local = 10.437126,kl_weight = 1.000000,loss = 3059.925049
tensor(3059.9250, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-38.041839599609375, logqcx:-2.079437732696533
kld_qz_pz = 6.691964,kld_qz_rna = 1649.321777,kld_qz_atac = 228.614502,kl_divergence = 9.539717,reconst_loss_rna = 2143.903320,        reconst_loss_atac = 805.674133, mu=0.106771, sigma=1.742744
reconst_loss = 2999.945557,kl_divergence_local = 9.539717,kl_weight = 1.000000,loss = 3009.485352
tensor(3009.4854, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-37.459205627441406, logqcx:-2.0794291496276855
kld_qz_pz = 6.176533,kld_qz_rna = 1429.081543,kld_qz_atac = 220.295380,kl_divergence = 8.953671,reconst_loss_rna = 2055.962402,        reconst_loss_atac = 790.665710, mu=0.106596, sigm

logpzc:-35.75185775756836, logqcx:-2.0794224739074707
kld_qz_pz = 5.638467,kld_qz_rna = 769.015930,kld_qz_atac = 37.200214,kl_divergence = 8.675638,reconst_loss_rna = 1928.748779,        reconst_loss_atac = 754.667297, mu=0.107400, sigma=1.765991
reconst_loss = 2733.783936,kl_divergence_local = 8.675638,kl_weight = 1.000000,loss = 2742.459717
tensor(2742.4597, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-35.90188980102539, logqcx:-2.07942533493042
kld_qz_pz = 5.722612,kld_qz_rna = 772.666626,kld_qz_atac = 38.509476,kl_divergence = 8.737254,reconst_loss_rna = 2104.901855,        reconst_loss_atac = 811.464722, mu=0.107634, sigma=1.766704
reconst_loss = 2966.734375,kl_divergence_local = 8.737254,kl_weight = 1.000000,loss = 2975.471680
tensor(2975.4717, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-36.04742431640625, logqcx:-2.0794217586517334
kld_qz_pz = 5.898539,kld_qz_rna = 579.984192,kld_qz_atac = 38.344681,kl_divergence = 8.973586,reconst_loss_rna = 2019.355469,        reconst

logpzc:-37.00790786743164, logqcx:-2.079432487487793
kld_qz_pz = 5.903799,kld_qz_rna = 649.324097,kld_qz_atac = 37.605194,kl_divergence = 8.641756,reconst_loss_rna = 1970.557861,        reconst_loss_atac = 721.948792, mu=0.109919, sigma=1.778276
reconst_loss = 2742.874756,kl_divergence_local = 8.641756,kl_weight = 1.000000,loss = 2751.516357
tensor(2751.5164, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-37.4891357421875, logqcx:-2.079430103302002
kld_qz_pz = 6.282133,kld_qz_rna = 547.873657,kld_qz_atac = 38.325630,kl_divergence = 9.114414,reconst_loss_rna = 1988.956543,        reconst_loss_atac = 772.844971, mu=0.110182, sigma=1.778881
reconst_loss = 2812.169189,kl_divergence_local = 9.114414,kl_weight = 1.000000,loss = 2821.283936
tensor(2821.2839, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-37.618350982666016, logqcx:-2.0794315338134766
kld_qz_pz = 6.477922,kld_qz_rna = 635.335876,kld_qz_atac = 40.015675,kl_divergence = 9.389725,reconst_loss_rna = 1920.216797,        reconst

logpzc:-35.51699447631836, logqcx:-2.079420566558838
kld_qz_pz = 5.193516,kld_qz_rna = 493.896179,kld_qz_atac = 44.006599,kl_divergence = 7.959767,reconst_loss_rna = 2014.837402,        reconst_loss_atac = 806.537109, mu=0.113178, sigma=1.792104
reconst_loss = 2871.742188,kl_divergence_local = 7.959767,kl_weight = 1.000000,loss = 2879.702148
tensor(2879.7021, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-35.43891525268555, logqcx:-2.0794272422790527
kld_qz_pz = 5.109072,kld_qz_rna = 464.163269,kld_qz_atac = 44.440559,kl_divergence = 7.838649,reconst_loss_rna = 2068.504883,        reconst_loss_atac = 869.861450, mu=0.113183, sigma=1.792720
reconst_loss = 2988.734375,kl_divergence_local = 7.838649,kl_weight = 1.000000,loss = 2996.572998
tensor(2996.5730, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-35.2779541015625, logqcx:-2.079423427581787
kld_qz_pz = 5.022397,kld_qz_rna = 461.880280,kld_qz_atac = 43.499184,kl_divergence = 7.779515,reconst_loss_rna = 1985.383545,        reconst_

logpzc:-34.22632598876953, logqcx:-2.0794179439544678
kld_qz_pz = 4.228137,kld_qz_rna = 397.866272,kld_qz_atac = 56.209057,kl_divergence = 6.896863,reconst_loss_rna = 1910.572754,        reconst_loss_atac = 713.167725, mu=0.115972, sigma=1.801345
reconst_loss = 2674.108398,kl_divergence_local = 6.896863,kl_weight = 1.000000,loss = 2681.005127
tensor(2681.0051, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-33.9879150390625, logqcx:-2.0794010162353516
kld_qz_pz = 3.990450,kld_qz_rna = 415.957520,kld_qz_atac = 57.848076,kl_divergence = 6.690016,reconst_loss_rna = 1977.136230,        reconst_loss_atac = 829.165771, mu=0.116070, sigma=1.801545
reconst_loss = 2856.669434,kl_divergence_local = 6.690016,kl_weight = 1.000000,loss = 2863.359375
tensor(2863.3594, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-33.917625427246094, logqcx:-2.0793988704681396
kld_qz_pz = 3.921724,kld_qz_rna = 414.813721,kld_qz_atac = 57.760880,kl_divergence = 6.626969,reconst_loss_rna = 1908.665894,        recon

logpzc:-33.36967468261719, logqcx:-2.0793652534484863
kld_qz_pz = 3.405431,kld_qz_rna = 393.384369,kld_qz_atac = 68.215591,kl_divergence = 6.324497,reconst_loss_rna = 1914.259766,        reconst_loss_atac = 714.703918, mu=0.118202, sigma=1.800179
reconst_loss = 2679.330322,kl_divergence_local = 6.324497,kl_weight = 1.000000,loss = 2685.654785
tensor(2685.6548, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-33.42973709106445, logqcx:-2.0793614387512207
kld_qz_pz = 3.480810,kld_qz_rna = 366.417755,kld_qz_atac = 69.307236,kl_divergence = 6.513424,reconst_loss_rna = 1928.497925,        reconst_loss_atac = 713.816101, mu=0.118229, sigma=1.800138
reconst_loss = 2692.680664,kl_divergence_local = 6.513424,kl_weight = 1.000000,loss = 2699.194092
tensor(2699.1941, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-33.392662048339844, logqcx:-2.0793700218200684
kld_qz_pz = 3.412508,kld_qz_rna = 404.908539,kld_qz_atac = 69.441864,kl_divergence = 6.395385,reconst_loss_rna = 2020.004639,        reco

logpzc:-32.96514129638672, logqcx:-2.079340934753418
kld_qz_pz = 2.956065,kld_qz_rna = 309.950897,kld_qz_atac = 74.377197,kl_divergence = 6.099885,reconst_loss_rna = 1920.958862,        reconst_loss_atac = 716.217529, mu=0.119920, sigma=1.797620
reconst_loss = 2687.542480,kl_divergence_local = 6.099885,kl_weight = 1.000000,loss = 2693.642334
tensor(2693.6423, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-32.84476852416992, logqcx:-2.0793490409851074
kld_qz_pz = 2.825952,kld_qz_rna = 303.001587,kld_qz_atac = 73.832184,kl_divergence = 5.804656,reconst_loss_rna = 1909.818237,        reconst_loss_atac = 818.815063, mu=0.120053, sigma=1.797285
reconst_loss = 2778.999756,kl_divergence_local = 5.804656,kl_weight = 1.000000,loss = 2784.804199
tensor(2784.8042, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-33.08845901489258, logqcx:-2.0793652534484863
kld_qz_pz = 3.056221,kld_qz_rna = 303.448792,kld_qz_atac = 73.970039,kl_divergence = 6.113116,reconst_loss_rna = 1882.736572,        recons

logpzc:-32.73193359375, logqcx:-2.0793051719665527
kld_qz_pz = 2.680962,kld_qz_rna = 301.166748,kld_qz_atac = 77.853096,kl_divergence = 5.727986,reconst_loss_rna = 1844.674561,        reconst_loss_atac = 699.036011, mu=0.122796, sigma=1.791673
reconst_loss = 2594.075684,kl_divergence_local = 5.727986,kl_weight = 1.000000,loss = 2599.803711
tensor(2599.8037, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-32.672752380371094, logqcx:-2.0792787075042725
kld_qz_pz = 2.617540,kld_qz_rna = 249.698364,kld_qz_atac = 76.879059,kl_divergence = 5.630688,reconst_loss_rna = 1892.312256,        reconst_loss_atac = 733.496704, mu=0.122945, sigma=1.791172
reconst_loss = 2676.173584,kl_divergence_local = 5.630688,kl_weight = 1.000000,loss = 2681.804199
tensor(2681.8042, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-32.58428955078125, logqcx:-2.079277992248535
kld_qz_pz = 2.535574,kld_qz_rna = 269.001221,kld_qz_atac = 78.339554,kl_divergence = 5.497868,reconst_loss_rna = 1951.512939,        reconst_

logpzc:-32.512298583984375, logqcx:-2.0792500972747803
kld_qz_pz = 2.439374,kld_qz_rna = 316.200012,kld_qz_atac = 80.990448,kl_divergence = 5.576858,reconst_loss_rna = 1879.268311,        reconst_loss_atac = 673.831665, mu=0.124920, sigma=1.784280
reconst_loss = 2603.464355,kl_divergence_local = 5.576858,kl_weight = 1.000000,loss = 2609.041260
tensor(2609.0413, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-32.55975341796875, logqcx:-2.0792551040649414
kld_qz_pz = 2.481133,kld_qz_rna = 299.665161,kld_qz_atac = 80.421616,kl_divergence = 5.599373,reconst_loss_rna = 1944.541504,        reconst_loss_atac = 809.409973, mu=0.124983, sigma=1.783778
reconst_loss = 2804.316162,kl_divergence_local = 5.599373,kl_weight = 1.000000,loss = 2809.915527
tensor(2809.9155, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-32.480560302734375, logqcx:-2.079233169555664
kld_qz_pz = 2.397196,kld_qz_rna = 301.871826,kld_qz_atac = 80.119949,kl_divergence = 5.431751,reconst_loss_rna = 1871.901611,        reco

logpzc:-32.31123352050781, logqcx:-2.079209804534912
kld_qz_pz = 2.189679,kld_qz_rna = 276.050446,kld_qz_atac = 80.512009,kl_divergence = 5.255541,reconst_loss_rna = 1875.650757,        reconst_loss_atac = 791.485291, mu=0.126185, sigma=1.772015
reconst_loss = 2717.499023,kl_divergence_local = 5.255541,kl_weight = 1.000000,loss = 2722.754883
tensor(2722.7549, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-32.24671936035156, logqcx:-2.0792036056518555
kld_qz_pz = 2.126317,kld_qz_rna = 272.285126,kld_qz_atac = 80.191246,kl_divergence = 5.209548,reconst_loss_rna = 1926.116699,        reconst_loss_atac = 864.325989, mu=0.126180, sigma=1.771773
reconst_loss = 2840.805664,kl_divergence_local = 5.209548,kl_weight = 1.000000,loss = 2846.015381
tensor(2846.0154, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-32.246551513671875, logqcx:-2.079190731048584
kld_qz_pz = 2.121011,kld_qz_rna = 266.104065,kld_qz_atac = 80.414825,kl_divergence = 5.153402,reconst_loss_rna = 1910.366577,        recons

logpzc:-32.11677932739258, logqcx:-2.0790650844573975
kld_qz_pz = 1.960444,kld_qz_rna = 254.470367,kld_qz_atac = 84.876450,kl_divergence = 4.980296,reconst_loss_rna = 1896.228271,        reconst_loss_atac = 737.676636, mu=0.126207, sigma=1.760360
reconst_loss = 2684.265137,kl_divergence_local = 4.980296,kl_weight = 1.000000,loss = 2689.245605
tensor(2689.2456, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-32.067413330078125, logqcx:-2.0791144371032715
kld_qz_pz = 1.908484,kld_qz_rna = 302.720032,kld_qz_atac = 84.663277,kl_divergence = 4.898642,reconst_loss_rna = 1872.806030,        reconst_loss_atac = 773.276978, mu=0.126260, sigma=1.759701
reconst_loss = 2696.443848,kl_divergence_local = 4.898642,kl_weight = 1.000000,loss = 2701.342529
tensor(2701.3425, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-32.02717971801758, logqcx:-2.0790486335754395
kld_qz_pz = 1.866450,kld_qz_rna = 255.307953,kld_qz_atac = 85.999191,kl_divergence = 4.841737,reconst_loss_rna = 1884.497314,        reco

logpzc:-31.83670425415039, logqcx:-2.078852653503418
kld_qz_pz = 1.641739,kld_qz_rna = 246.874802,kld_qz_atac = 84.107147,kl_divergence = 4.857654,reconst_loss_rna = 1833.832642,        reconst_loss_atac = 667.798096, mu=0.126856, sigma=1.751044
reconst_loss = 2551.985840,kl_divergence_local = 4.857654,kl_weight = 1.000000,loss = 2556.843262
tensor(2556.8433, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-31.899208068847656, logqcx:-2.078941822052002
kld_qz_pz = 1.703145,kld_qz_rna = 243.736115,kld_qz_atac = 83.180595,kl_divergence = 4.883110,reconst_loss_rna = 1906.075439,        reconst_loss_atac = 711.260010, mu=0.126806, sigma=1.751092
reconst_loss = 2667.692383,kl_divergence_local = 4.883110,kl_weight = 1.000000,loss = 2672.575684
tensor(2672.5757, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-31.91266632080078, logqcx:-2.0790023803710938
kld_qz_pz = 1.713328,kld_qz_rna = 243.854858,kld_qz_atac = 82.761948,kl_divergence = 5.020369,reconst_loss_rna = 1856.502441,        recons

logpzc:-31.660961151123047, logqcx:-2.0791492462158203
kld_qz_pz = 1.417816,kld_qz_rna = 241.018677,kld_qz_atac = 80.213654,kl_divergence = 4.588984,reconst_loss_rna = 1856.435913,        reconst_loss_atac = 867.031006, mu=0.126639, sigma=1.743946
reconst_loss = 2773.829102,kl_divergence_local = 4.588984,kl_weight = 1.000000,loss = 2778.418213
tensor(2778.4182, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-31.642425537109375, logqcx:-2.0790181159973145
kld_qz_pz = 1.396919,kld_qz_rna = 267.074463,kld_qz_atac = 81.885422,kl_divergence = 4.622563,reconst_loss_rna = 1861.545898,        reconst_loss_atac = 773.854431, mu=0.126641, sigma=1.743457
reconst_loss = 2685.760010,kl_divergence_local = 4.622563,kl_weight = 1.000000,loss = 2690.382812
tensor(2690.3828, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-31.69423484802246, logqcx:-2.0790724754333496
kld_qz_pz = 1.448504,kld_qz_rna = 259.780640,kld_qz_atac = 81.999649,kl_divergence = 4.633578,reconst_loss_rna = 1827.662598,        rec

training:  33%|███▎      | 5/15 [00:26<00:56,  5.65s/it]logpzc:-31.558340072631836, logqcx:-2.078925132751465
kld_qz_pz = 1.277503,kld_qz_rna = 211.971069,kld_qz_atac = 84.234016,kl_divergence = 4.389784,reconst_loss_rna = 1905.399658,        reconst_loss_atac = 788.100281, mu=0.126724, sigma=1.732182
reconst_loss = 2743.858398,kl_divergence_local = 4.389784,kl_weight = 1.000000,loss = 2748.248047
tensor(2748.2480, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-31.54503059387207, logqcx:-2.0789613723754883
kld_qz_pz = 1.260816,kld_qz_rna = 234.157959,kld_qz_atac = 84.676041,kl_divergence = 4.406197,reconst_loss_rna = 1778.364990,        reconst_loss_atac = 720.609131, mu=0.126742, sigma=1.731559
reconst_loss = 2549.332520,kl_divergence_local = 4.406197,kl_weight = 1.000000,loss = 2553.738770
tensor(2553.7388, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-31.644180297851562, logqcx:-2.078972816467285
kld_qz_pz = 1.357617,kld_qz_rna = 222.728043,kld_qz_atac = 84.690163,kl_divergence

logpzc:-31.478565216064453, logqcx:-2.0788726806640625
kld_qz_pz = 1.160950,kld_qz_rna = 224.457214,kld_qz_atac = 83.343811,kl_divergence = 4.535929,reconst_loss_rna = 1840.168945,        reconst_loss_atac = 698.892334, mu=0.127925, sigma=1.716081
reconst_loss = 2589.417236,kl_divergence_local = 4.535929,kl_weight = 1.000000,loss = 2593.953125
tensor(2593.9531, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-31.557533264160156, logqcx:-2.078909397125244
kld_qz_pz = 1.236624,kld_qz_rna = 232.081543,kld_qz_atac = 84.246284,kl_divergence = 4.665100,reconst_loss_rna = 1882.414307,        reconst_loss_atac = 792.996948, mu=0.127988, sigma=1.715264
reconst_loss = 2725.768311,kl_divergence_local = 4.665100,kl_weight = 1.000000,loss = 2730.433594
tensor(2730.4336, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-31.51433563232422, logqcx:-2.078861713409424
kld_qz_pz = 1.191795,kld_qz_rna = 242.346603,kld_qz_atac = 84.058350,kl_divergence = 4.595126,reconst_loss_rna = 1822.947510,        recon

logpzc:-31.26750946044922, logqcx:-2.078505516052246
kld_qz_pz = 0.907759,kld_qz_rna = 235.833237,kld_qz_atac = 84.639069,kl_divergence = 4.097042,reconst_loss_rna = 1845.582031,        reconst_loss_atac = 723.068604, mu=0.129668, sigma=1.702818
reconst_loss = 2618.999268,kl_divergence_local = 4.097042,kl_weight = 1.000000,loss = 2623.096191
tensor(2623.0962, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-31.233936309814453, logqcx:-2.0784192085266113
kld_qz_pz = 0.871945,kld_qz_rna = 244.694992,kld_qz_atac = 82.838882,kl_divergence = 4.095827,reconst_loss_rna = 1870.096436,        reconst_loss_atac = 820.789307, mu=0.129802, sigma=1.701891
reconst_loss = 2741.232422,kl_divergence_local = 4.095827,kl_weight = 1.000000,loss = 2745.328613
tensor(2745.3286, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-31.335140228271484, logqcx:-2.078549385070801
kld_qz_pz = 0.970204,kld_qz_rna = 215.384033,kld_qz_atac = 81.805405,kl_divergence = 4.252134,reconst_loss_rna = 1917.712402,        recon

logpzc:-31.144367218017578, logqcx:-2.0779426097869873
kld_qz_pz = 0.745269,kld_qz_rna = 214.886307,kld_qz_atac = 82.617279,kl_divergence = 4.056287,reconst_loss_rna = 1891.748169,        reconst_loss_atac = 783.663269, mu=0.131030, sigma=1.688347
reconst_loss = 2725.744141,kl_divergence_local = 4.056287,kl_weight = 1.000000,loss = 2729.800537
tensor(2729.8005, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-31.16468048095703, logqcx:-2.0780234336853027
kld_qz_pz = 0.762947,kld_qz_rna = 212.509979,kld_qz_atac = 82.282707,kl_divergence = 4.042509,reconst_loss_rna = 1859.414551,        reconst_loss_atac = 784.434937, mu=0.131076, sigma=1.687648
reconst_loss = 2694.184814,kl_divergence_local = 4.042509,kl_weight = 1.000000,loss = 2698.227295
tensor(2698.2273, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-31.104900360107422, logqcx:-2.078077793121338
kld_qz_pz = 0.700877,kld_qz_rna = 216.256866,kld_qz_atac = 82.027100,kl_divergence = 3.915493,reconst_loss_rna = 1893.640137,        reco

logpzc:-31.164134979248047, logqcx:-2.077380657196045
kld_qz_pz = 0.727022,kld_qz_rna = 209.728607,kld_qz_atac = 82.723877,kl_divergence = 4.157493,reconst_loss_rna = 1851.554932,        reconst_loss_atac = 749.137085, mu=0.130927, sigma=1.676149
reconst_loss = 2651.007812,kl_divergence_local = 4.157493,kl_weight = 1.000000,loss = 2655.165527
tensor(2655.1655, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-31.085927963256836, logqcx:-2.0773673057556152
kld_qz_pz = 0.648395,kld_qz_rna = 222.695648,kld_qz_atac = 83.697968,kl_divergence = 4.169912,reconst_loss_rna = 1812.187744,        reconst_loss_atac = 826.087891, mu=0.130948, sigma=1.675623
reconst_loss = 2688.590332,kl_divergence_local = 4.169912,kl_weight = 1.000000,loss = 2692.760254
tensor(2692.7603, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-31.150543212890625, logqcx:-2.07733154296875
kld_qz_pz = 0.710614,kld_qz_rna = 209.354034,kld_qz_atac = 83.830391,kl_divergence = 4.079953,reconst_loss_rna = 1867.206421,        recon

logpzc:-31.04597282409668, logqcx:-2.0740623474121094
kld_qz_pz = 0.571925,kld_qz_rna = 216.669739,kld_qz_atac = 83.377350,kl_divergence = 3.872063,reconst_loss_rna = 1899.570435,        reconst_loss_atac = 899.467773, mu=0.129907, sigma=1.657649
reconst_loss = 2849.259277,kl_divergence_local = 3.872063,kl_weight = 1.000000,loss = 2853.131348
tensor(2853.1313, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-31.01412010192871, logqcx:-2.0735249519348145
kld_qz_pz = 0.538351,kld_qz_rna = 207.828125,kld_qz_atac = 82.179024,kl_divergence = 3.812016,reconst_loss_rna = 1820.256104,        reconst_loss_atac = 759.518799, mu=0.129914, sigma=1.656713
reconst_loss = 2629.980469,kl_divergence_local = 3.812016,kl_weight = 1.000000,loss = 2633.792480
tensor(2633.7925, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.96231460571289, logqcx:-2.0730087757110596
kld_qz_pz = 0.485942,kld_qz_rna = 214.370758,kld_qz_atac = 83.710556,kl_divergence = 3.745341,reconst_loss_rna = 1863.255249,        recon

logpzc:-30.96162223815918, logqcx:-2.062192678451538
kld_qz_pz = 0.455218,kld_qz_rna = 204.510284,kld_qz_atac = 80.745522,kl_divergence = 4.074880,reconst_loss_rna = 1843.486084,        reconst_loss_atac = 724.464172, mu=0.130666, sigma=1.634542
reconst_loss = 2617.833008,kl_divergence_local = 4.074880,kl_weight = 1.000000,loss = 2621.907959
tensor(2621.9080, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.961410522460938, logqcx:-2.058901071548462
kld_qz_pz = 0.454609,kld_qz_rna = 215.670532,kld_qz_atac = 81.031372,kl_divergence = 4.081001,reconst_loss_rna = 1813.358398,        reconst_loss_atac = 685.320251, mu=0.130660, sigma=1.633168
reconst_loss = 2548.469238,kl_divergence_local = 4.081001,kl_weight = 1.000000,loss = 2552.550049
tensor(2552.5500, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.922439575195312, logqcx:-2.0557236671447754
kld_qz_pz = 0.414201,kld_qz_rna = 202.416443,kld_qz_atac = 81.022491,kl_divergence = 4.077281,reconst_loss_rna = 1863.517578,        recon

logpzc:-30.09973907470703, logqcx:-1.4828660488128662
kld_qz_pz = -0.069396,kld_qz_rna = 196.681305,kld_qz_atac = 88.074600,kl_divergence = 3.418360,reconst_loss_rna = 1785.785156,        reconst_loss_atac = 772.593994, mu=0.133019, sigma=1.610624
reconst_loss = 2588.341553,kl_divergence_local = 3.418360,kl_weight = 1.000000,loss = 2591.759766
tensor(2591.7598, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.058603286743164, logqcx:-1.3844163417816162
kld_qz_pz = -0.038876,kld_qz_rna = 200.270615,kld_qz_atac = 88.373154,kl_divergence = 3.500170,reconst_loss_rna = 1843.239136,        reconst_loss_atac = 575.283936, mu=0.133069, sigma=1.612836
reconst_loss = 2444.980957,kl_divergence_local = 3.500170,kl_weight = 1.000000,loss = 2448.480957
tensor(2448.4810, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-29.896961212158203, logqcx:-1.2386751174926758
kld_qz_pz = -0.091689,kld_qz_rna = 197.013809,kld_qz_atac = 88.930397,kl_divergence = 3.379905,reconst_loss_rna = 1778.780151,        

logpzc:-29.86079216003418, logqcx:-0.05670996010303497
kld_qz_pz = 0.726400,kld_qz_rna = 164.108856,kld_qz_atac = 99.969696,kl_divergence = 4.677781,reconst_loss_rna = 1838.584717,        reconst_loss_atac = 767.519287, mu=0.140542, sigma=1.681225
reconst_loss = 2584.349609,kl_divergence_local = 4.677781,kl_weight = 1.000000,loss = 2589.027344
tensor(2589.0273, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-29.90189552307129, logqcx:-0.04953993111848831
kld_qz_pz = 0.769601,kld_qz_rna = 158.870102,kld_qz_atac = 101.514343,kl_divergence = 4.848736,reconst_loss_rna = 1877.183594,        reconst_loss_atac = 826.705078, mu=0.140766, sigma=1.682874
reconst_loss = 2681.898438,kl_divergence_local = 4.848736,kl_weight = 1.000000,loss = 2686.747070
tensor(2686.7471, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.01618766784668, logqcx:-0.04771779477596283
kld_qz_pz = 0.881455,kld_qz_rna = 155.930878,kld_qz_atac = 101.513069,kl_divergence = 5.162529,reconst_loss_rna = 1814.266113,        

logpzc:-30.911357879638672, logqcx:-0.027254540473222733
kld_qz_pz = 1.749325,kld_qz_rna = 157.733109,kld_qz_atac = 96.018875,kl_divergence = 5.881831,reconst_loss_rna = 1908.202515,        reconst_loss_atac = 769.769653, mu=0.142864, sigma=1.698006
reconst_loss = 2655.313477,kl_divergence_local = 5.881831,kl_weight = 1.000000,loss = 2661.195557
tensor(2661.1956, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.84070587158203, logqcx:-0.02289329469203949
kld_qz_pz = 1.681078,kld_qz_rna = 153.038116,kld_qz_atac = 97.082123,kl_divergence = 5.805972,reconst_loss_rna = 1883.755737,        reconst_loss_atac = 722.816284, mu=0.142914, sigma=1.698342
reconst_loss = 2583.736816,kl_divergence_local = 5.805972,kl_weight = 1.000000,loss = 2589.542969
tensor(2589.5430, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.837406158447266, logqcx:-0.021845970302820206
kld_qz_pz = 1.677268,kld_qz_rna = 151.923706,kld_qz_atac = 98.239525,kl_divergence = 5.892383,reconst_loss_rna = 1801.715576,      

logpzc:-30.962703704833984, logqcx:-0.028340011835098267
kld_qz_pz = 1.787635,kld_qz_rna = 194.106140,kld_qz_atac = 91.067299,kl_divergence = 6.208712,reconst_loss_rna = 1904.845703,        reconst_loss_atac = 704.016479, mu=0.143397, sigma=1.700600
logpzc:-30.898658752441406, logqcx:-0.02427832782268524
kld_qz_pz = 1.727399,kld_qz_rna = 178.300751,kld_qz_atac = 91.316345,kl_divergence = 6.202508,reconst_loss_rna = 1868.510010,        reconst_loss_atac = 785.391235, mu=0.143397, sigma=1.700600
logpzc:-30.91919708251953, logqcx:-0.02642107382416725
kld_qz_pz = 1.745832,kld_qz_rna = 195.938385,kld_qz_atac = 90.793854,kl_divergence = 6.150460,reconst_loss_rna = 1895.955078,        reconst_loss_atac = 810.948120, mu=0.143397, sigma=1.700600
logpzc:-30.844459533691406, logqcx:-0.022400572896003723
kld_qz_pz = 1.674817,kld_qz_rna = 185.029114,kld_qz_atac = 91.376297,kl_divergence = 6.022407,reconst_loss_rna = 1890.495605,        reconst_loss_atac = 768.994141, mu=0.143397, sigma=1.700600
log

logpzc:-30.903533935546875, logqcx:-0.02428382821381092
kld_qz_pz = 1.732069,kld_qz_rna = 188.468231,kld_qz_atac = 91.254448,kl_divergence = 6.204212,reconst_loss_rna = 1875.258911,        reconst_loss_atac = 769.686646, mu=0.143397, sigma=1.700600
logpzc:-30.960094451904297, logqcx:-0.02571263536810875
kld_qz_pz = 1.787506,kld_qz_rna = 207.736252,kld_qz_atac = 91.222717,kl_divergence = 6.172026,reconst_loss_rna = 1864.231934,        reconst_loss_atac = 770.866455, mu=0.143397, sigma=1.700600
logpzc:-30.841171264648438, logqcx:-0.022325385361909866
kld_qz_pz = 1.671399,kld_qz_rna = 187.752411,kld_qz_atac = 91.174797,kl_divergence = 6.076326,reconst_loss_rna = 1870.275879,        reconst_loss_atac = 795.665466, mu=0.143397, sigma=1.700600
logpzc:-31.011478424072266, logqcx:-0.02693036012351513
kld_qz_pz = 1.837691,kld_qz_rna = 198.706863,kld_qz_atac = 91.489967,kl_divergence = 6.270150,reconst_loss_rna = 1915.549927,        reconst_loss_atac = 901.215881, mu=0.143397, sigma=1.700600
log

logpzc:-30.80036735534668, logqcx:-0.02161325141787529
kld_qz_pz = 1.631099,kld_qz_rna = 185.791107,kld_qz_atac = 90.906265,kl_divergence = 5.950751,reconst_loss_rna = 1905.378662,        reconst_loss_atac = 634.464966, mu=0.143397, sigma=1.700600
logpzc:-30.845661163330078, logqcx:-0.022482994943857193
kld_qz_pz = 1.675800,kld_qz_rna = 178.838089,kld_qz_atac = 91.144272,kl_divergence = 6.074995,reconst_loss_rna = 1889.959839,        reconst_loss_atac = 756.471191, mu=0.143397, sigma=1.700600
logpzc:-30.808130264282227, logqcx:-0.022582776844501495
kld_qz_pz = 1.638062,kld_qz_rna = 188.426056,kld_qz_atac = 91.213509,kl_divergence = 6.038120,reconst_loss_rna = 1891.665161,        reconst_loss_atac = 822.752563, mu=0.143397, sigma=1.700600
logpzc:-30.825645446777344, logqcx:-0.022283395752310753
kld_qz_pz = 1.656071,kld_qz_rna = 178.830261,kld_qz_atac = 91.404388,kl_divergence = 6.097339,reconst_loss_rna = 1901.102173,        reconst_loss_atac = 777.466431, mu=0.143397, sigma=1.700600
lo

logpzc:-30.916316986083984, logqcx:-0.02562151476740837
kld_qz_pz = 1.734228,kld_qz_rna = 140.596436,kld_qz_atac = 87.660812,kl_divergence = 5.854813,reconst_loss_rna = 1842.684326,        reconst_loss_atac = 797.234192, mu=0.143849, sigma=1.702350
reconst_loss = 2617.177246,kl_divergence_local = 5.854813,kl_weight = 1.000000,loss = 2623.032227
tensor(2623.0322, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.930145263671875, logqcx:-0.02590443380177021
kld_qz_pz = 1.746771,kld_qz_rna = 143.596436,kld_qz_atac = 88.570374,kl_divergence = 5.854850,reconst_loss_rna = 1772.556641,        reconst_loss_atac = 697.764038, mu=0.143899, sigma=1.702550
reconst_loss = 2447.589355,kl_divergence_local = 5.854850,kl_weight = 1.000000,loss = 2453.444336
tensor(2453.4443, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.862220764160156, logqcx:-0.024497922509908676
kld_qz_pz = 1.679003,kld_qz_rna = 138.807617,kld_qz_atac = 90.358879,kl_divergence = 5.867836,reconst_loss_rna = 1863.911743,      

logpzc:-30.79361343383789, logqcx:-0.027373017743229866
kld_qz_pz = 1.587551,kld_qz_rna = 137.280106,kld_qz_atac = 88.790443,kl_divergence = 5.341537,reconst_loss_rna = 1839.629639,        reconst_loss_atac = 779.848022, mu=0.144437, sigma=1.706211
reconst_loss = 2596.788818,kl_divergence_local = 5.341537,kl_weight = 1.000000,loss = 2602.130371
tensor(2602.1304, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.700054168701172, logqcx:-0.024765409529209137
kld_qz_pz = 1.495052,kld_qz_rna = 129.936295,kld_qz_atac = 89.682106,kl_divergence = 5.322715,reconst_loss_rna = 1868.379639,        reconst_loss_atac = 729.968018, mu=0.144471, sigma=1.706370
reconst_loss = 2575.572266,kl_divergence_local = 5.322715,kl_weight = 1.000000,loss = 2580.894775
tensor(2580.8948, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.633872985839844, logqcx:-0.023859567940235138
kld_qz_pz = 1.428838,kld_qz_rna = 132.530060,kld_qz_atac = 90.708298,kl_divergence = 5.231341,reconst_loss_rna = 1851.456543,     

logpzc:-30.663185119628906, logqcx:-0.028517527505755424
kld_qz_pz = 1.436926,kld_qz_rna = 124.879944,kld_qz_atac = 92.073463,kl_divergence = 5.326900,reconst_loss_rna = 1789.269043,        reconst_loss_atac = 694.525879, mu=0.144178, sigma=1.709305
reconst_loss = 2461.152100,kl_divergence_local = 5.326900,kl_weight = 1.000000,loss = 2466.479004
tensor(2466.4790, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.647174835205078, logqcx:-0.026484008878469467
kld_qz_pz = 1.421756,kld_qz_rna = 124.189987,kld_qz_atac = 91.978088,kl_divergence = 5.232314,reconst_loss_rna = 1790.263916,        reconst_loss_atac = 810.397827, mu=0.144199, sigma=1.709438
reconst_loss = 2577.940918,kl_divergence_local = 5.232314,kl_weight = 1.000000,loss = 2583.173340
tensor(2583.1733, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.63938331604004, logqcx:-0.025866888463497162
kld_qz_pz = 1.413687,kld_qz_rna = 129.486389,kld_qz_atac = 91.795181,kl_divergence = 5.193822,reconst_loss_rna = 1870.822510,     

logpzc:-30.54429054260254, logqcx:-0.025457069277763367
kld_qz_pz = 1.302604,kld_qz_rna = 131.183350,kld_qz_atac = 94.175598,kl_divergence = 4.972146,reconst_loss_rna = 1792.987061,        reconst_loss_atac = 719.979797, mu=0.144645, sigma=1.712145
reconst_loss = 2490.214600,kl_divergence_local = 4.972146,kl_weight = 1.000000,loss = 2495.186523
tensor(2495.1865, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.55257225036621, logqcx:-0.02617788128554821
kld_qz_pz = 1.309332,kld_qz_rna = 130.894165,kld_qz_atac = 94.030640,kl_divergence = 5.107082,reconst_loss_rna = 1838.636963,        reconst_loss_atac = 704.661560, mu=0.144680, sigma=1.712284
reconst_loss = 2520.571777,kl_divergence_local = 5.107082,kl_weight = 1.000000,loss = 2525.679199
tensor(2525.6792, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.610347747802734, logqcx:-0.027782611548900604
kld_qz_pz = 1.364797,kld_qz_rna = 132.053711,kld_qz_atac = 93.251068,kl_divergence = 5.206642,reconst_loss_rna = 1841.081665,       

logpzc:-30.507816314697266, logqcx:-0.025110002607107162
kld_qz_pz = 1.247887,kld_qz_rna = 129.065125,kld_qz_atac = 85.017632,kl_divergence = 5.214033,reconst_loss_rna = 1834.675781,        reconst_loss_atac = 775.362793, mu=0.144966, sigma=1.714924
reconst_loss = 2587.271973,kl_divergence_local = 5.214033,kl_weight = 1.000000,loss = 2592.486328
tensor(2592.4863, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.537311553955078, logqcx:-0.026210974901914597
kld_qz_pz = 1.275585,kld_qz_rna = 130.170761,kld_qz_atac = 86.683868,kl_divergence = 5.323939,reconst_loss_rna = 1846.272949,        reconst_loss_atac = 807.932434, mu=0.144979, sigma=1.715057
reconst_loss = 2631.476562,kl_divergence_local = 5.323939,kl_weight = 1.000000,loss = 2636.800537
tensor(2636.8005, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.502635955810547, logqcx:-0.025872349739074707
kld_qz_pz = 1.240241,kld_qz_rna = 124.833054,kld_qz_atac = 88.259209,kl_divergence = 5.327982,reconst_loss_rna = 1761.823242,    

logpzc:-30.561767578125, logqcx:-0.025765428319573402
kld_qz_pz = 1.282639,kld_qz_rna = 135.759552,kld_qz_atac = 88.602219,kl_divergence = 5.640165,reconst_loss_rna = 1816.685791,        reconst_loss_atac = 792.842346, mu=0.145788, sigma=1.717537
reconst_loss = 2586.780762,kl_divergence_local = 5.640165,kl_weight = 1.000000,loss = 2592.421143
tensor(2592.4211, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.534643173217773, logqcx:-0.025592178106307983
kld_qz_pz = 1.254785,kld_qz_rna = 126.188461,kld_qz_atac = 88.217651,kl_divergence = 5.379893,reconst_loss_rna = 1798.179932,        reconst_loss_atac = 793.972534, mu=0.145823, sigma=1.717663
reconst_loss = 2569.400879,kl_divergence_local = 5.379893,kl_weight = 1.000000,loss = 2574.780762
tensor(2574.7808, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.525127410888672, logqcx:-0.02679595723748207
kld_qz_pz = 1.243318,kld_qz_rna = 119.474258,kld_qz_atac = 87.628113,kl_divergence = 5.247136,reconst_loss_rna = 1832.726196,        

logpzc:-30.381305694580078, logqcx:-0.023683859035372734
kld_qz_pz = 1.086299,kld_qz_rna = 117.939751,kld_qz_atac = 92.306488,kl_divergence = 4.896877,reconst_loss_rna = 1900.433594,        reconst_loss_atac = 872.659851, mu=0.146414, sigma=1.720132
reconst_loss = 2750.290283,kl_divergence_local = 4.896877,kl_weight = 1.000000,loss = 2755.187256
tensor(2755.1873, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.467487335205078, logqcx:-0.02501732110977173
kld_qz_pz = 1.170546,kld_qz_rna = 126.811951,kld_qz_atac = 91.778244,kl_divergence = 4.952789,reconst_loss_rna = 1849.778931,        reconst_loss_atac = 854.547974, mu=0.146395, sigma=1.720256
reconst_loss = 2681.562256,kl_divergence_local = 4.952789,kl_weight = 1.000000,loss = 2686.515137
tensor(2686.5151, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.346405029296875, logqcx:-0.02220528945326805
kld_qz_pz = 1.051229,kld_qz_rna = 119.856613,kld_qz_atac = 91.379898,kl_divergence = 4.744335,reconst_loss_rna = 1788.460938,      

logpzc:-30.410703659057617, logqcx:-0.023086532950401306
kld_qz_pz = 1.099194,kld_qz_rna = 119.127556,kld_qz_atac = 88.495560,kl_divergence = 4.979261,reconst_loss_rna = 1759.956909,        reconst_loss_atac = 687.163208, mu=0.146160, sigma=1.722762
reconst_loss = 2424.289062,kl_divergence_local = 4.979261,kl_weight = 1.000000,loss = 2429.268555
tensor(2429.2686, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.45882797241211, logqcx:-0.024307284504175186
kld_qz_pz = 1.145291,kld_qz_rna = 127.630997,kld_qz_atac = 88.941307,kl_divergence = 5.045365,reconst_loss_rna = 1801.770142,        reconst_loss_atac = 736.393433, mu=0.146162, sigma=1.722885
reconst_loss = 2515.371826,kl_divergence_local = 5.045365,kl_weight = 1.000000,loss = 2520.416992
tensor(2520.4170, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.442508697509766, logqcx:-0.024077076464891434
kld_qz_pz = 1.128359,kld_qz_rna = 125.331909,kld_qz_atac = 89.919708,kl_divergence = 5.168628,reconst_loss_rna = 1765.366699,     

logpzc:-30.456575393676758, logqcx:-0.02497030794620514
kld_qz_pz = 1.125981,kld_qz_rna = 114.959404,kld_qz_atac = 86.184540,kl_divergence = 5.186956,reconst_loss_rna = 1758.156738,        reconst_loss_atac = 663.325012, mu=0.146602, sigma=1.725363
reconst_loss = 2398.714844,kl_divergence_local = 5.186956,kl_weight = 1.000000,loss = 2403.901855
tensor(2403.9019, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.41960906982422, logqcx:-0.024118460714817047
kld_qz_pz = 1.088828,kld_qz_rna = 112.162308,kld_qz_atac = 86.480736,kl_divergence = 5.183812,reconst_loss_rna = 1819.292114,        reconst_loss_atac = 703.626648, mu=0.146641, sigma=1.725486
reconst_loss = 2500.121094,kl_divergence_local = 5.183812,kl_weight = 1.000000,loss = 2505.304688
tensor(2505.3047, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.45888900756836, logqcx:-0.02504188008606434
kld_qz_pz = 1.126561,kld_qz_rna = 112.416321,kld_qz_atac = 86.610779,kl_divergence = 5.290322,reconst_loss_rna = 1888.664307,        

training:  73%|███████▎  | 11/15 [01:06<00:25,  6.26s/it]logpzc:-30.399877548217773, logqcx:-0.023718304932117462
kld_qz_pz = 1.054432,kld_qz_rna = 115.895744,kld_qz_atac = 89.286621,kl_divergence = 5.068312,reconst_loss_rna = 1789.048340,        reconst_loss_atac = 718.151611, mu=0.147069, sigma=1.727761
reconst_loss = 2484.390869,kl_divergence_local = 5.068312,kl_weight = 1.000000,loss = 2489.458984
tensor(2489.4590, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.38132095336914, logqcx:-0.02377670258283615
kld_qz_pz = 1.035144,kld_qz_rna = 118.683899,kld_qz_atac = 90.303406,kl_divergence = 5.102689,reconst_loss_rna = 1753.265625,        reconst_loss_atac = 696.901611, mu=0.147105, sigma=1.727877
reconst_loss = 2427.363281,kl_divergence_local = 5.102689,kl_weight = 1.000000,loss = 2432.465820
tensor(2432.4658, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.44501304626465, logqcx:-0.024779878556728363
kld_qz_pz = 1.097072,kld_qz_rna = 118.059006,kld_qz_atac = 89.969177,kl_div

logpzc:-30.33254623413086, logqcx:-0.022694364190101624
kld_qz_pz = 0.972118,kld_qz_rna = 106.837646,kld_qz_atac = 89.255989,kl_divergence = 4.984264,reconst_loss_rna = 1856.465576,        reconst_loss_atac = 738.812744, mu=0.147298, sigma=1.730209
reconst_loss = 2572.437988,kl_divergence_local = 4.984264,kl_weight = 1.000000,loss = 2577.422363
tensor(2577.4224, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.312667846679688, logqcx:-0.022252529859542847
kld_qz_pz = 0.951890,kld_qz_rna = 114.526291,kld_qz_atac = 89.398132,kl_divergence = 4.903587,reconst_loss_rna = 1847.134521,        reconst_loss_atac = 807.196716, mu=0.147287, sigma=1.730327
reconst_loss = 2631.476074,kl_divergence_local = 4.903587,kl_weight = 1.000000,loss = 2636.379639
tensor(2636.3796, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.309650421142578, logqcx:-0.0216597281396389
kld_qz_pz = 0.948783,kld_qz_rna = 114.219666,kld_qz_atac = 89.828644,kl_divergence = 4.961807,reconst_loss_rna = 1848.483154,       

logpzc:-30.38334846496582, logqcx:-0.023688949644565582
kld_qz_pz = 1.005575,kld_qz_rna = 115.823891,kld_qz_atac = 88.729347,kl_divergence = 5.192621,reconst_loss_rna = 1819.103760,        reconst_loss_atac = 762.381226, mu=0.147563, sigma=1.732566
reconst_loss = 2558.674805,kl_divergence_local = 5.192621,kl_weight = 1.000000,loss = 2563.867188
tensor(2563.8672, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.462738037109375, logqcx:-0.02588345855474472
kld_qz_pz = 1.082305,kld_qz_rna = 112.985611,kld_qz_atac = 87.871429,kl_divergence = 5.361361,reconst_loss_rna = 1788.309448,        reconst_loss_atac = 738.872192, mu=0.147570, sigma=1.732677
reconst_loss = 2504.443848,kl_divergence_local = 5.361361,kl_weight = 1.000000,loss = 2509.805176
tensor(2509.8052, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.387645721435547, logqcx:-0.024554744362831116
kld_qz_pz = 1.007648,kld_qz_rna = 116.439987,kld_qz_atac = 86.829124,kl_divergence = 5.156730,reconst_loss_rna = 1842.122681,      

logpzc:-30.32328224182129, logqcx:-0.02335532382130623
kld_qz_pz = 0.929620,kld_qz_rna = 113.152252,kld_qz_atac = 86.596375,kl_divergence = 4.914742,reconst_loss_rna = 1763.896729,        reconst_loss_atac = 678.302124, mu=0.148172, sigma=1.734878
reconst_loss = 2419.384766,kl_divergence_local = 4.914742,kl_weight = 1.000000,loss = 2424.299561
tensor(2424.2996, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.274126052856445, logqcx:-0.021592559292912483
kld_qz_pz = 0.881209,kld_qz_rna = 111.442017,kld_qz_atac = 86.480370,kl_divergence = 4.851830,reconst_loss_rna = 1779.763550,        reconst_loss_atac = 648.352966, mu=0.148191, sigma=1.734989
reconst_loss = 2405.240967,kl_divergence_local = 4.851830,kl_weight = 1.000000,loss = 2410.093018
tensor(2410.0930, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.242305755615234, logqcx:-0.020641569048166275
kld_qz_pz = 0.849346,kld_qz_rna = 109.312729,kld_qz_atac = 86.321518,kl_divergence = 4.762680,reconst_loss_rna = 1942.044922,      

logpzc:-30.263904571533203, logqcx:-0.022356834262609482
kld_qz_pz = 0.855639,kld_qz_rna = 105.723732,kld_qz_atac = 88.182037,kl_divergence = 4.801435,reconst_loss_rna = 1873.752563,        reconst_loss_atac = 754.601562, mu=0.148151, sigma=1.737229
reconst_loss = 2605.508789,kl_divergence_local = 4.801435,kl_weight = 1.000000,loss = 2610.310059
tensor(2610.3101, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.324073791503906, logqcx:-0.02499941736459732
kld_qz_pz = 0.912908,kld_qz_rna = 108.447311,kld_qz_atac = 87.743118,kl_divergence = 4.856293,reconst_loss_rna = 1841.180664,        reconst_loss_atac = 796.241577, mu=0.148135, sigma=1.737342
reconst_loss = 2614.677734,kl_divergence_local = 4.856293,kl_weight = 1.000000,loss = 2619.534180
tensor(2619.5342, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.311574935913086, logqcx:-0.023191748186945915
kld_qz_pz = 0.901237,kld_qz_rna = 111.734421,kld_qz_atac = 87.811005,kl_divergence = 4.921290,reconst_loss_rna = 1755.756958,     

logpzc:-30.277936935424805, logqcx:-0.02398526296019554
kld_qz_pz = 0.854133,kld_qz_rna = 111.913673,kld_qz_atac = 88.053719,kl_divergence = 4.851976,reconst_loss_rna = 1778.117188,        reconst_loss_atac = 711.412109, mu=0.148374, sigma=1.739599
reconst_loss = 2466.754150,kl_divergence_local = 4.851976,kl_weight = 1.000000,loss = 2471.605957
tensor(2471.6060, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.278961181640625, logqcx:-0.022008270025253296
kld_qz_pz = 0.856104,kld_qz_rna = 112.300453,kld_qz_atac = 88.346664,kl_divergence = 4.903828,reconst_loss_rna = 1788.453369,        reconst_loss_atac = 600.742371, mu=0.148401, sigma=1.739714
reconst_loss = 2366.331787,kl_divergence_local = 4.903828,kl_weight = 1.000000,loss = 2371.235596
tensor(2371.2356, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.277667999267578, logqcx:-0.022347427904605865
kld_qz_pz = 0.853980,kld_qz_rna = 107.493530,kld_qz_atac = 88.328018,kl_divergence = 4.948111,reconst_loss_rna = 1865.223877,     

logpzc:-30.28042221069336, logqcx:-0.02174951881170273
kld_qz_pz = 0.843904,kld_qz_rna = 106.242447,kld_qz_atac = 87.596321,kl_divergence = 5.112417,reconst_loss_rna = 1783.399048,        reconst_loss_atac = 682.764526, mu=0.148194, sigma=1.741804
reconst_loss = 2443.290771,kl_divergence_local = 5.112417,kl_weight = 1.000000,loss = 2448.403320
tensor(2448.4033, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.345340728759766, logqcx:-0.023603957146406174
kld_qz_pz = 0.906524,kld_qz_rna = 110.253174,kld_qz_atac = 88.480324,kl_divergence = 5.210508,reconst_loss_rna = 1746.397217,        reconst_loss_atac = 787.979004, mu=0.148184, sigma=1.741905
reconst_loss = 2511.564941,kl_divergence_local = 5.210508,kl_weight = 1.000000,loss = 2516.775391
tensor(2516.7754, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.334909439086914, logqcx:-0.022951766848564148
kld_qz_pz = 0.895828,kld_qz_rna = 108.845978,kld_qz_atac = 88.091858,kl_divergence = 5.119156,reconst_loss_rna = 1820.489014,      

logpzc:-30.28640365600586, logqcx:-0.022031551226973534
kld_qz_pz = 0.834223,kld_qz_rna = 104.193558,kld_qz_atac = 88.036835,kl_divergence = 4.870595,reconst_loss_rna = 1814.933838,        reconst_loss_atac = 836.577637, mu=0.149017, sigma=1.743927
reconst_loss = 2628.650879,kl_divergence_local = 4.870595,kl_weight = 1.000000,loss = 2633.521484
tensor(2633.5215, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.284912109375, logqcx:-0.0213402658700943
kld_qz_pz = 0.832385,kld_qz_rna = 104.184540,kld_qz_atac = 87.934746,kl_divergence = 4.732182,reconst_loss_rna = 1875.481689,        reconst_loss_atac = 821.482788, mu=0.149020, sigma=1.744033
reconst_loss = 2674.079102,kl_divergence_local = 4.732182,kl_weight = 1.000000,loss = 2678.811523
tensor(2678.8115, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.252342224121094, logqcx:-0.021068718284368515
kld_qz_pz = 0.799420,kld_qz_rna = 105.810219,kld_qz_atac = 89.280258,kl_divergence = 4.766800,reconst_loss_rna = 1789.449707,        re

logpzc:-30.23651695251465, logqcx:-0.022040314972400665
kld_qz_pz = 0.770190,kld_qz_rna = 101.414589,kld_qz_atac = 87.526253,kl_divergence = 4.749659,reconst_loss_rna = 1906.168457,        reconst_loss_atac = 775.866943, mu=0.148619, sigma=1.746223
reconst_loss = 2659.178955,kl_divergence_local = 4.749659,kl_weight = 1.000000,loss = 2663.928711
tensor(2663.9287, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.214298248291016, logqcx:-0.02081594243645668
kld_qz_pz = 0.748545,kld_qz_rna = 100.332138,kld_qz_atac = 88.397003,kl_divergence = 4.771957,reconst_loss_rna = 1729.799805,        reconst_loss_atac = 656.072266, mu=0.148602, sigma=1.746331
reconst_loss = 2362.971680,kl_divergence_local = 4.771957,kl_weight = 1.000000,loss = 2367.743408
tensor(2367.7434, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.23099136352539, logqcx:-0.02088618464767933
kld_qz_pz = 0.764386,kld_qz_rna = 103.288025,kld_qz_atac = 88.696075,kl_divergence = 4.777424,reconst_loss_rna = 1838.980469,        

logpzc:-30.174665451049805, logqcx:-0.02036880888044834
kld_qz_pz = 0.695549,kld_qz_rna = 100.993286,kld_qz_atac = 87.064133,kl_divergence = 4.627400,reconst_loss_rna = 1779.292725,        reconst_loss_atac = 673.382446, mu=0.149187, sigma=1.748485
reconst_loss = 2429.759033,kl_divergence_local = 4.627400,kl_weight = 1.000000,loss = 2434.386230
tensor(2434.3862, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.188621520996094, logqcx:-0.020794082432985306
kld_qz_pz = 0.708444,kld_qz_rna = 100.152229,kld_qz_atac = 86.358658,kl_divergence = 4.598874,reconst_loss_rna = 1860.456055,        reconst_loss_atac = 764.497681, mu=0.149222, sigma=1.748595
reconst_loss = 2602.051758,kl_divergence_local = 4.598874,kl_weight = 1.000000,loss = 2606.650391
tensor(2606.6504, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.203876495361328, logqcx:-0.021588148549199104
kld_qz_pz = 0.722472,kld_qz_rna = 99.901352,kld_qz_atac = 87.105042,kl_divergence = 4.681351,reconst_loss_rna = 1863.324341,      

logpzc:-30.160951614379883, logqcx:-0.020800797268748283
kld_qz_pz = 0.668039,kld_qz_rna = 103.809952,kld_qz_atac = 89.731232,kl_divergence = 4.623810,reconst_loss_rna = 1823.671875,        reconst_loss_atac = 821.388977, mu=0.149751, sigma=1.750755
reconst_loss = 2622.158691,kl_divergence_local = 4.623810,kl_weight = 1.000000,loss = 2626.782471
tensor(2626.7825, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.182308197021484, logqcx:-0.021509379148483276
kld_qz_pz = 0.688175,kld_qz_rna = 102.456573,kld_qz_atac = 90.048729,kl_divergence = 4.752495,reconst_loss_rna = 1772.158569,        reconst_loss_atac = 744.562134, mu=0.149786, sigma=1.750861
reconst_loss = 2493.844727,kl_divergence_local = 4.752495,kl_weight = 1.000000,loss = 2498.597168
tensor(2498.5972, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.1660099029541, logqcx:-0.021244242787361145
kld_qz_pz = 0.671475,kld_qz_rna = 105.858109,kld_qz_atac = 89.582489,kl_divergence = 4.697203,reconst_loss_rna = 1809.676514,      

logpzc:-30.138059616088867, logqcx:-0.020658692345023155
kld_qz_pz = 0.631650,kld_qz_rna = 101.939964,kld_qz_atac = 89.912964,kl_divergence = 4.976684,reconst_loss_rna = 1768.190063,        reconst_loss_atac = 833.883179, mu=0.150375, sigma=1.752968
reconst_loss = 2579.167236,kl_divergence_local = 4.976684,kl_weight = 1.000000,loss = 2584.143799
tensor(2584.1438, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.131744384765625, logqcx:-0.0205141082406044
kld_qz_pz = 0.624857,kld_qz_rna = 103.111267,kld_qz_atac = 90.403946,kl_divergence = 5.051039,reconst_loss_rna = 1822.808105,        reconst_loss_atac = 756.762756, mu=0.150369, sigma=1.753065
reconst_loss = 2556.660645,kl_divergence_local = 5.051039,kl_weight = 1.000000,loss = 2561.711670
tensor(2561.7117, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.17281150817871, logqcx:-0.021595720201730728
kld_qz_pz = 0.664260,kld_qz_rna = 99.359047,kld_qz_atac = 90.231522,kl_divergence = 4.964637,reconst_loss_rna = 1854.064453,        

logpzc:-30.110713958740234, logqcx:-0.020621027797460556
kld_qz_pz = 0.590444,kld_qz_rna = 101.199135,kld_qz_atac = 88.135735,kl_divergence = 4.529994,reconst_loss_rna = 1758.639893,        reconst_loss_atac = 727.144775, mu=0.150834, sigma=1.755076
reconst_loss = 2462.878418,kl_divergence_local = 4.529994,kl_weight = 1.000000,loss = 2467.408691
tensor(2467.4087, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.210586547851562, logqcx:-0.023300033062696457
kld_qz_pz = 0.687404,kld_qz_rna = 102.827721,kld_qz_atac = 88.706955,kl_divergence = 4.731611,reconst_loss_rna = 1879.100830,        reconst_loss_atac = 830.412964, mu=0.150842, sigma=1.755178
reconst_loss = 2686.695801,kl_divergence_local = 4.731611,kl_weight = 1.000000,loss = 2691.427246
tensor(2691.4272, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.147132873535156, logqcx:-0.020869750529527664
kld_qz_pz = 0.625422,kld_qz_rna = 102.358261,kld_qz_atac = 89.726212,kl_divergence = 4.753560,reconst_loss_rna = 1794.354126,    

logpzc:-30.147966384887695, logqcx:-0.021636268123984337
kld_qz_pz = 0.613693,kld_qz_rna = 103.681190,kld_qz_atac = 91.119324,kl_divergence = 4.669721,reconst_loss_rna = 1829.733521,        reconst_loss_atac = 645.472656, mu=0.150539, sigma=1.757233
reconst_loss = 2452.330566,kl_divergence_local = 4.669721,kl_weight = 1.000000,loss = 2457.000488
tensor(2457.0005, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.181026458740234, logqcx:-0.02184925600886345
kld_qz_pz = 0.645926,kld_qz_rna = 102.512924,kld_qz_atac = 91.669922,kl_divergence = 4.839202,reconst_loss_rna = 1764.587646,        reconst_loss_atac = 832.557312, mu=0.150539, sigma=1.757333
reconst_loss = 2574.274170,kl_divergence_local = 4.839202,kl_weight = 1.000000,loss = 2579.113525
tensor(2579.1135, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.084280014038086, logqcx:-0.019869085401296616
kld_qz_pz = 0.550090,kld_qz_rna = 102.744629,kld_qz_atac = 90.215042,kl_divergence = 4.529823,reconst_loss_rna = 1823.127197,     

reconst_loss = 2572.632080,kl_divergence_local = 4.819656,kl_weight = 1.000000,loss = 2577.451660
tensor(2577.4517, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.19377899169922, logqcx:-0.022214030846953392
kld_qz_pz = 0.646038,kld_qz_rna = 97.324982,kld_qz_atac = 86.424072,kl_divergence = 4.909023,reconst_loss_rna = 1778.682861,        reconst_loss_atac = 738.872498, mu=0.150777, sigma=1.759267
reconst_loss = 2494.697510,kl_divergence_local = 4.909023,kl_weight = 1.000000,loss = 2499.606689
tensor(2499.6067, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.121116638183594, logqcx:-0.02081826515495777
kld_qz_pz = 0.574000,kld_qz_rna = 98.459671,kld_qz_atac = 86.494194,kl_divergence = 4.736370,reconst_loss_rna = 1796.850098,        reconst_loss_atac = 750.237061, mu=0.150736, sigma=1.759357
reconst_loss = 2524.184326,kl_divergence_local = 4.736370,kl_weight = 1.000000,loss = 2528.920898
tensor(2528.9209, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.12590217590332, logqcx

logpzc:-30.130844116210938, logqcx:-0.021796083077788353
kld_qz_pz = 0.582479,kld_qz_rna = 105.300705,kld_qz_atac = 87.943710,kl_divergence = 4.618630,reconst_loss_rna = 1844.864258,        reconst_loss_atac = 857.497070, mu=0.150695, sigma=1.759447
logpzc:-30.089366912841797, logqcx:-0.020619414746761322
kld_qz_pz = 0.541763,kld_qz_rna = 107.496635,kld_qz_atac = 87.854301,kl_divergence = 4.566188,reconst_loss_rna = 1859.543091,        reconst_loss_atac = 748.932007, mu=0.150695, sigma=1.759447
logpzc:-30.08350944519043, logqcx:-0.020257115364074707
kld_qz_pz = 0.536423,kld_qz_rna = 102.820358,kld_qz_atac = 87.824982,kl_divergence = 4.576951,reconst_loss_rna = 1837.723633,        reconst_loss_atac = 748.493103, mu=0.150695, sigma=1.759447
logpzc:-30.13896942138672, logqcx:-0.021012892946600914
kld_qz_pz = 0.591130,kld_qz_rna = 105.901070,kld_qz_atac = 87.565781,kl_divergence = 4.641905,reconst_loss_rna = 1782.444824,        reconst_loss_atac = 697.861084, mu=0.150695, sigma=1.759447
lo

logpzc:-30.054399490356445, logqcx:-0.019392982125282288
kld_qz_pz = 0.508054,kld_qz_rna = 104.652069,kld_qz_atac = 87.751099,kl_divergence = 4.598820,reconst_loss_rna = 1830.572510,        reconst_loss_atac = 706.817505, mu=0.150695, sigma=1.759447
logpzc:-30.06378173828125, logqcx:-0.019752373918890953
kld_qz_pz = 0.517094,kld_qz_rna = 101.949081,kld_qz_atac = 87.922256,kl_divergence = 4.585655,reconst_loss_rna = 1754.338501,        reconst_loss_atac = 740.219604, mu=0.150695, sigma=1.759447
logpzc:-30.032176971435547, logqcx:-0.01930706575512886
kld_qz_pz = 0.485553,kld_qz_rna = 99.672318,kld_qz_atac = 87.395973,kl_divergence = 4.536216,reconst_loss_rna = 1800.516113,        reconst_loss_atac = 808.025452, mu=0.150695, sigma=1.759447
logpzc:-30.106765747070312, logqcx:-0.020668912678956985
kld_qz_pz = 0.559218,kld_qz_rna = 104.607338,kld_qz_atac = 87.721092,kl_divergence = 4.585736,reconst_loss_rna = 1816.715332,        reconst_loss_atac = 794.201477, mu=0.150695, sigma=1.759447
log

logpzc:-29.991657257080078, logqcx:-0.018369611352682114
kld_qz_pz = 0.446141,kld_qz_rna = 100.961807,kld_qz_atac = 87.925018,kl_divergence = 4.496940,reconst_loss_rna = 1853.009766,        reconst_loss_atac = 851.411011, mu=0.150695, sigma=1.759447
logpzc:-30.036216735839844, logqcx:-0.019014529883861542
kld_qz_pz = 0.490219,kld_qz_rna = 104.227463,kld_qz_atac = 87.919502,kl_divergence = 4.559896,reconst_loss_rna = 1923.289307,        reconst_loss_atac = 773.958618, mu=0.150695, sigma=1.759447
logpzc:-30.030742645263672, logqcx:-0.019052142277359962
kld_qz_pz = 0.484609,kld_qz_rna = 102.269028,kld_qz_atac = 87.594978,kl_divergence = 4.534973,reconst_loss_rna = 1870.008423,        reconst_loss_atac = 754.840637, mu=0.150695, sigma=1.759447
logpzc:-30.057374954223633, logqcx:-0.01962468773126602
kld_qz_pz = 0.510795,kld_qz_rna = 101.943382,kld_qz_atac = 87.417694,kl_divergence = 4.597278,reconst_loss_rna = 1925.834351,        reconst_loss_atac = 744.993530, mu=0.150695, sigma=1.759447
l

logpzc:-30.187578201293945, logqcx:-0.021660657599568367
kld_qz_pz = 0.630605,kld_qz_rna = 100.232719,kld_qz_atac = 89.902687,kl_divergence = 4.854908,reconst_loss_rna = 1757.366821,        reconst_loss_atac = 671.041382, mu=0.150934, sigma=1.760685
reconst_loss = 2405.534180,kl_divergence_local = 4.854908,kl_weight = 1.000000,loss = 2410.389160
tensor(2410.3892, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.207368850708008, logqcx:-0.02279214933514595
kld_qz_pz = 0.648834,kld_qz_rna = 98.366806,kld_qz_atac = 90.710953,kl_divergence = 4.973551,reconst_loss_rna = 1843.981445,        reconst_loss_atac = 668.311340, mu=0.150955, sigma=1.760767
reconst_loss = 2489.455566,kl_divergence_local = 4.973551,kl_weight = 1.000000,loss = 2494.429199
tensor(2494.4292, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.165416717529297, logqcx:-0.021155670285224915
kld_qz_pz = 0.607607,kld_qz_rna = 101.110580,kld_qz_atac = 90.757576,kl_divergence = 4.885382,reconst_loss_rna = 1813.608154,      

logpzc:-30.165782928466797, logqcx:-0.021482160314917564
kld_qz_pz = 0.595251,kld_qz_rna = 97.272713,kld_qz_atac = 87.217728,kl_divergence = 4.881033,reconst_loss_rna = 1848.401855,        reconst_loss_atac = 821.024536, mu=0.151280, sigma=1.762599
reconst_loss = 2646.550781,kl_divergence_local = 4.881033,kl_weight = 1.000000,loss = 2651.431885
tensor(2651.4319, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.186033248901367, logqcx:-0.022257156670093536
kld_qz_pz = 0.614270,kld_qz_rna = 100.582222,kld_qz_atac = 87.698853,kl_divergence = 5.020841,reconst_loss_rna = 1774.112305,        reconst_loss_atac = 799.350830, mu=0.151345, sigma=1.762690
reconst_loss = 2550.611328,kl_divergence_local = 5.020841,kl_weight = 1.000000,loss = 2555.632324
tensor(2555.6323, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.268953323364258, logqcx:-0.024375170469284058
kld_qz_pz = 0.694872,kld_qz_rna = 99.998901,kld_qz_atac = 87.786873,kl_divergence = 5.158951,reconst_loss_rna = 1856.364990,      

logpzc:-30.206531524658203, logqcx:-0.02197512611746788
kld_qz_pz = 0.621643,kld_qz_rna = 99.784935,kld_qz_atac = 85.442688,kl_divergence = 4.856315,reconst_loss_rna = 1828.055420,        reconst_loss_atac = 833.313416, mu=0.152667, sigma=1.764440
reconst_loss = 2638.508789,kl_divergence_local = 4.856315,kl_weight = 1.000000,loss = 2643.365234
tensor(2643.3652, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.194438934326172, logqcx:-0.020681604743003845
kld_qz_pz = 0.610141,kld_qz_rna = 96.945305,kld_qz_atac = 85.302071,kl_divergence = 4.862062,reconst_loss_rna = 1747.020264,        reconst_loss_atac = 705.203491, mu=0.152746, sigma=1.764530
reconst_loss = 2429.317871,kl_divergence_local = 4.862062,kl_weight = 1.000000,loss = 2434.179688
tensor(2434.1797, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.2003173828125, logqcx:-0.02140716463327408
kld_qz_pz = 0.614818,kld_qz_rna = 101.748123,kld_qz_atac = 84.880707,kl_divergence = 4.755867,reconst_loss_rna = 1777.292358,        re

logpzc:-30.143901824951172, logqcx:-0.020540248602628708
kld_qz_pz = 0.548725,kld_qz_rna = 98.449142,kld_qz_atac = 87.636475,kl_divergence = 4.649350,reconst_loss_rna = 1804.414673,        reconst_loss_atac = 815.180847, mu=0.153077, sigma=1.766540
reconst_loss = 2596.686768,kl_divergence_local = 4.649350,kl_weight = 1.000000,loss = 2601.336182
tensor(2601.3362, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.145301818847656, logqcx:-0.020320024341344833
kld_qz_pz = 0.549867,kld_qz_rna = 97.926819,kld_qz_atac = 87.590530,kl_divergence = 4.767081,reconst_loss_rna = 1762.071167,        reconst_loss_atac = 678.680664, mu=0.153074, sigma=1.766643
reconst_loss = 2417.834473,kl_divergence_local = 4.767081,kl_weight = 1.000000,loss = 2422.601562
tensor(2422.6016, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.14680290222168, logqcx:-0.020726308226585388
kld_qz_pz = 0.550585,kld_qz_rna = 98.102478,kld_qz_atac = 88.131531,kl_divergence = 4.786099,reconst_loss_rna = 1784.889771,        

training: 16it [01:46,  7.00s/it]                        logpzc:-30.07445526123047, logqcx:-0.019926924258470535
kld_qz_pz = 0.467762,kld_qz_rna = 93.977859,kld_qz_atac = 90.303299,kl_divergence = 4.526349,reconst_loss_rna = 1808.723877,        reconst_loss_atac = 826.854553, mu=0.153238, sigma=1.768624
reconst_loss = 2612.649170,kl_divergence_local = 4.526349,kl_weight = 1.000000,loss = 2617.175293
tensor(2617.1753, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.091270446777344, logqcx:-0.02037295699119568
kld_qz_pz = 0.483563,kld_qz_rna = 95.803368,kld_qz_atac = 89.960037,kl_divergence = 4.534919,reconst_loss_rna = 1692.183472,        reconst_loss_atac = 714.912231, mu=0.153288, sigma=1.768721
reconst_loss = 2384.181152,kl_divergence_local = 4.534919,kl_weight = 1.000000,loss = 2388.716064
tensor(2388.7161, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.06911849975586, logqcx:-0.02005743235349655
kld_qz_pz = 0.461125,kld_qz_rna = 97.625687,kld_qz_atac = 90.004440,kl_diverge

logpzc:-30.1567440032959, logqcx:-0.02212340757250786
kld_qz_pz = 0.536808,kld_qz_rna = 100.199219,kld_qz_atac = 90.551247,kl_divergence = 4.867879,reconst_loss_rna = 1818.606445,        reconst_loss_atac = 772.629517, mu=0.154258, sigma=1.770573
reconst_loss = 2568.378906,kl_divergence_local = 4.867879,kl_weight = 1.000000,loss = 2573.246826
tensor(2573.2468, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.166955947875977, logqcx:-0.02233307436108589
kld_qz_pz = 0.546328,kld_qz_rna = 100.388382,kld_qz_atac = 91.056625,kl_divergence = 4.854681,reconst_loss_rna = 1783.689941,        reconst_loss_atac = 766.760925, mu=0.154254, sigma=1.770660
reconst_loss = 2527.601562,kl_divergence_local = 4.854681,kl_weight = 1.000000,loss = 2532.456299
tensor(2532.4563, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.193546295166016, logqcx:-0.022207580506801605
kld_qz_pz = 0.572387,kld_qz_rna = 97.734901,kld_qz_atac = 91.083206,kl_divergence = 5.006492,reconst_loss_rna = 1786.802979,        r

logpzc:-30.07337188720703, logqcx:-0.01947271265089512
kld_qz_pz = 0.443560,kld_qz_rna = 94.134018,kld_qz_atac = 91.632271,kl_divergence = 4.676626,reconst_loss_rna = 1790.255859,        reconst_loss_atac = 718.395386, mu=0.153963, sigma=1.772372
reconst_loss = 2485.705566,kl_divergence_local = 4.676626,kl_weight = 1.000000,loss = 2490.382324
tensor(2490.3823, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.0915470123291, logqcx:-0.019705556333065033
kld_qz_pz = 0.460886,kld_qz_rna = 96.716721,kld_qz_atac = 91.336502,kl_divergence = 4.766809,reconst_loss_rna = 1661.783813,        reconst_loss_atac = 627.873108, mu=0.153968, sigma=1.772463
reconst_loss = 2266.717529,kl_divergence_local = 4.766809,kl_weight = 1.000000,loss = 2271.484375
tensor(2271.4844, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.116016387939453, logqcx:-0.020654354244470596
kld_qz_pz = 0.483991,kld_qz_rna = 93.925545,kld_qz_atac = 90.969368,kl_divergence = 4.769824,reconst_loss_rna = 1720.328857,        rec

logpzc:-30.004961013793945, logqcx:-0.018674563616514206
kld_qz_pz = 0.363691,kld_qz_rna = 93.966171,kld_qz_atac = 91.480705,kl_divergence = 4.551578,reconst_loss_rna = 1762.038696,        reconst_loss_atac = 697.840515, mu=0.154329, sigma=1.774194
reconst_loss = 2436.908936,kl_divergence_local = 4.551578,kl_weight = 1.000000,loss = 2441.460449
tensor(2441.4604, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.03107452392578, logqcx:-0.01969277858734131
kld_qz_pz = 0.388504,kld_qz_rna = 95.253128,kld_qz_atac = 91.303223,kl_divergence = 4.480945,reconst_loss_rna = 1865.107666,        reconst_loss_atac = 767.780518, mu=0.154365, sigma=1.774282
reconst_loss = 2609.952637,kl_divergence_local = 4.480945,kl_weight = 1.000000,loss = 2614.433594
tensor(2614.4336, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.044696807861328, logqcx:-0.019726641476154327
kld_qz_pz = 0.401577,kld_qz_rna = 95.000015,kld_qz_atac = 90.850822,kl_divergence = 4.537833,reconst_loss_rna = 1762.592773,        r

logpzc:-30.05314064025879, logqcx:-0.0199783556163311
kld_qz_pz = 0.399614,kld_qz_rna = 97.344223,kld_qz_atac = 91.650406,kl_divergence = 4.589379,reconst_loss_rna = 1774.130249,        reconst_loss_atac = 742.787842, mu=0.154643, sigma=1.776143
reconst_loss = 2493.988281,kl_divergence_local = 4.589379,kl_weight = 1.000000,loss = 2498.577637
tensor(2498.5776, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.018800735473633, logqcx:-0.019842248409986496
kld_qz_pz = 0.364718,kld_qz_rna = 96.803986,kld_qz_atac = 90.765617,kl_divergence = 4.554488,reconst_loss_rna = 1832.686523,        reconst_loss_atac = 664.919128, mu=0.154704, sigma=1.776241
reconst_loss = 2474.674316,kl_divergence_local = 4.554488,kl_weight = 1.000000,loss = 2479.229004
tensor(2479.2290, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-29.991056442260742, logqcx:-0.019913338124752045
kld_qz_pz = 0.336463,kld_qz_rna = 94.973740,kld_qz_atac = 90.675537,kl_divergence = 4.660694,reconst_loss_rna = 1790.736328,        re

logpzc:-30.0322265625, logqcx:-0.021298667415976524
kld_qz_pz = 0.365909,kld_qz_rna = 94.359222,kld_qz_atac = 92.702629,kl_divergence = 4.696903,reconst_loss_rna = 1771.315063,        reconst_loss_atac = 690.826660, mu=0.154745, sigma=1.777964
reconst_loss = 2439.260742,kl_divergence_local = 4.696903,kl_weight = 1.000000,loss = 2443.957520
tensor(2443.9575, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.0256404876709, logqcx:-0.020832251757383347
kld_qz_pz = 0.359221,kld_qz_rna = 91.525642,kld_qz_atac = 91.946396,kl_divergence = 4.666977,reconst_loss_rna = 1765.494263,        reconst_loss_atac = 803.920654, mu=0.154743, sigma=1.778043
reconst_loss = 2546.518066,kl_divergence_local = 4.666977,kl_weight = 1.000000,loss = 2551.185059
tensor(2551.1851, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.035724639892578, logqcx:-0.020815320312976837
kld_qz_pz = 0.368844,kld_qz_rna = 96.976357,kld_qz_atac = 90.532425,kl_divergence = 4.584926,reconst_loss_rna = 1813.386230,        recons

logpzc:-30.046707153320312, logqcx:-0.019674405455589294
kld_qz_pz = 0.370131,kld_qz_rna = 89.041977,kld_qz_atac = 89.465080,kl_divergence = 4.641997,reconst_loss_rna = 1728.821777,        reconst_loss_atac = 749.219177, mu=0.154177, sigma=1.779328
reconst_loss = 2455.102051,kl_divergence_local = 4.641997,kl_weight = 1.000000,loss = 2459.743896
tensor(2459.7439, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.072193145751953, logqcx:-0.019839152693748474
kld_qz_pz = 0.394943,kld_qz_rna = 92.061493,kld_qz_atac = 89.658195,kl_divergence = 4.607725,reconst_loss_rna = 1741.880249,        reconst_loss_atac = 780.828857, mu=0.154194, sigma=1.779396
reconst_loss = 2499.774902,kl_divergence_local = 4.607725,kl_weight = 1.000000,loss = 2504.382324
tensor(2504.3823, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.023616790771484, logqcx:-0.019097425043582916
kld_qz_pz = 0.346384,kld_qz_rna = 91.115189,kld_qz_atac = 89.854858,kl_divergence = 4.564250,reconst_loss_rna = 1768.373047,       

logpzc:-30.103294372558594, logqcx:-0.02052607387304306
kld_qz_pz = 0.414989,kld_qz_rna = 94.295273,kld_qz_atac = 88.614334,kl_divergence = 4.709306,reconst_loss_rna = 1771.571045,        reconst_loss_atac = 705.428467, mu=0.154532, sigma=1.780948
reconst_loss = 2454.088379,kl_divergence_local = 4.709306,kl_weight = 1.000000,loss = 2458.797607
tensor(2458.7976, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.082687377929688, logqcx:-0.020870991051197052
kld_qz_pz = 0.393567,kld_qz_rna = 92.745186,kld_qz_atac = 88.433380,kl_divergence = 4.674250,reconst_loss_rna = 1819.791016,        reconst_loss_atac = 745.173096, mu=0.154605, sigma=1.781028
reconst_loss = 2542.067383,kl_divergence_local = 4.674250,kl_weight = 1.000000,loss = 2546.741699
tensor(2546.7417, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-29.913558959960938, logqcx:-0.017275897786021233
kld_qz_pz = 0.226589,kld_qz_rna = 91.222458,kld_qz_atac = 88.161552,kl_divergence = 4.365747,reconst_loss_rna = 1776.232910,        

logpzc:-30.14312744140625, logqcx:-0.022328803315758705
kld_qz_pz = 0.442021,kld_qz_rna = 95.416893,kld_qz_atac = 90.503441,kl_divergence = 4.703038,reconst_loss_rna = 1842.786011,        reconst_loss_atac = 845.353638, mu=0.155067, sigma=1.782611
reconst_loss = 2665.291016,kl_divergence_local = 4.703038,kl_weight = 1.000000,loss = 2669.994141
tensor(2669.9941, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.08561134338379, logqcx:-0.02079165354371071
kld_qz_pz = 0.385114,kld_qz_rna = 91.090805,kld_qz_atac = 89.953430,kl_divergence = 4.658063,reconst_loss_rna = 1791.190430,        reconst_loss_atac = 670.352356, mu=0.155131, sigma=1.782686
reconst_loss = 2438.642578,kl_divergence_local = 4.658063,kl_weight = 1.000000,loss = 2443.300537
tensor(2443.3005, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.010543823242188, logqcx:-0.018774457275867462
kld_qz_pz = 0.311186,kld_qz_rna = 92.091408,kld_qz_atac = 89.949371,kl_divergence = 4.455805,reconst_loss_rna = 1779.592773,        re

logpzc:-30.013713836669922, logqcx:-0.01933470368385315
kld_qz_pz = 0.303761,kld_qz_rna = 90.973366,kld_qz_atac = 89.298538,kl_divergence = 4.628732,reconst_loss_rna = 1840.595215,        reconst_loss_atac = 853.014221, mu=0.155839, sigma=1.784252
reconst_loss = 2670.661621,kl_divergence_local = 4.628732,kl_weight = 1.000000,loss = 2675.290283
tensor(2675.2903, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.084569931030273, logqcx:-0.02129712514579296
kld_qz_pz = 0.372540,kld_qz_rna = 89.947296,kld_qz_atac = 89.011551,kl_divergence = 4.595290,reconst_loss_rna = 1806.745117,        reconst_loss_atac = 775.555298, mu=0.155870, sigma=1.784336
reconst_loss = 2559.419922,kl_divergence_local = 4.595290,kl_weight = 1.000000,loss = 2564.015381
tensor(2564.0154, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.002994537353516, logqcx:-0.019778277724981308
kld_qz_pz = 0.291641,kld_qz_rna = 88.681412,kld_qz_atac = 88.844788,kl_divergence = 4.424612,reconst_loss_rna = 1725.058716,        r

logpzc:-30.075971603393555, logqcx:-0.022104548290371895
kld_qz_pz = 0.353651,kld_qz_rna = 87.723785,kld_qz_atac = 90.471581,kl_divergence = 4.638409,reconst_loss_rna = 1754.698486,        reconst_loss_atac = 789.933838, mu=0.156471, sigma=1.786014
reconst_loss = 2521.778809,kl_divergence_local = 4.638409,kl_weight = 1.000000,loss = 2526.417480
tensor(2526.4175, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.033275604248047, logqcx:-0.020223204046487808
kld_qz_pz = 0.311906,kld_qz_rna = 86.949951,kld_qz_atac = 90.165459,kl_divergence = 4.585989,reconst_loss_rna = 1724.203369,        reconst_loss_atac = 690.931030, mu=0.156498, sigma=1.786096
reconst_loss = 2392.214355,kl_divergence_local = 4.585989,kl_weight = 1.000000,loss = 2396.800293
tensor(2396.8003, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.08955192565918, logqcx:-0.021697448566555977
kld_qz_pz = 0.366561,kld_qz_rna = 89.132790,kld_qz_atac = 90.217804,kl_divergence = 4.846704,reconst_loss_rna = 1876.388672,        

logpzc:-30.577186584472656, logqcx:-0.014562336727976799
kld_qz_pz = 0.840293,kld_qz_rna = 84.526039,kld_qz_atac = 75.634048,kl_divergence = 5.595485,reconst_loss_rna = 1749.079468,        reconst_loss_atac = 749.748901, mu=0.155778, sigma=1.782521
reconst_loss = 2475.735352,kl_divergence_local = 5.595485,kl_weight = 1.000000,loss = 2481.330811
tensor(2481.3308, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.711673736572266, logqcx:-0.016095900908112526
kld_qz_pz = 0.972884,kld_qz_rna = 86.559479,kld_qz_atac = 76.911812,kl_divergence = 5.765931,reconst_loss_rna = 1772.387939,        reconst_loss_atac = 763.189575, mu=0.155961, sigma=1.782498
reconst_loss = 2512.531494,kl_divergence_local = 5.765931,kl_weight = 1.000000,loss = 2518.297363
tensor(2518.2974, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.70584487915039, logqcx:-0.016177069395780563
kld_qz_pz = 0.966169,kld_qz_rna = 89.946167,kld_qz_atac = 78.622658,kl_divergence = 5.680056,reconst_loss_rna = 1822.723633,        

logpzc:-30.40205955505371, logqcx:-0.015175623819231987
kld_qz_pz = 0.655741,kld_qz_rna = 90.261406,kld_qz_atac = 90.054474,kl_divergence = 5.108608,reconst_loss_rna = 1775.536377,        reconst_loss_atac = 779.533203, mu=0.156952, sigma=1.785402
reconst_loss = 2531.990967,kl_divergence_local = 5.108608,kl_weight = 1.000000,loss = 2537.099609
tensor(2537.0996, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.43171501159668, logqcx:-0.01585271582007408
kld_qz_pz = 0.684413,kld_qz_rna = 88.989349,kld_qz_atac = 89.760628,kl_divergence = 5.125244,reconst_loss_rna = 1737.555908,        reconst_loss_atac = 787.573975, mu=0.156933, sigma=1.785612
reconst_loss = 2502.071777,kl_divergence_local = 5.125244,kl_weight = 1.000000,loss = 2507.196777
tensor(2507.1968, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.382415771484375, logqcx:-0.016168687492609024
kld_qz_pz = 0.634524,kld_qz_rna = 90.588409,kld_qz_atac = 89.406219,kl_divergence = 5.033135,reconst_loss_rna = 1823.188232,        re

logpzc:-30.13384437561035, logqcx:-0.018872834742069244
kld_qz_pz = 0.375002,kld_qz_rna = 92.294250,kld_qz_atac = 88.838165,kl_divergence = 4.617157,reconst_loss_rna = 1811.434326,        reconst_loss_atac = 733.184326, mu=0.156558, sigma=1.789349
reconst_loss = 2521.660156,kl_divergence_local = 4.617157,kl_weight = 1.000000,loss = 2526.277344
tensor(2526.2773, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.134340286254883, logqcx:-0.018973272293806076
kld_qz_pz = 0.374982,kld_qz_rna = 93.622704,kld_qz_atac = 89.126068,kl_divergence = 4.537897,reconst_loss_rna = 1752.072998,        reconst_loss_atac = 807.397278, mu=0.156555, sigma=1.789502
reconst_loss = 2536.514648,kl_divergence_local = 4.537897,kl_weight = 1.000000,loss = 2541.052734
tensor(2541.0527, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.089067459106445, logqcx:-0.01780708134174347
kld_qz_pz = 0.330143,kld_qz_rna = 95.537094,kld_qz_atac = 88.982430,kl_divergence = 4.390963,reconst_loss_rna = 1841.596191,        r

logpzc:-30.058326721191406, logqcx:-0.019963331520557404
kld_qz_pz = 0.289378,kld_qz_rna = 97.307266,kld_qz_atac = 91.035080,kl_divergence = 4.491994,reconst_loss_rna = 1755.575317,        reconst_loss_atac = 703.264771, mu=0.156878, sigma=1.791919
reconst_loss = 2435.912109,kl_divergence_local = 4.491994,kl_weight = 1.000000,loss = 2440.404297
tensor(2440.4043, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.013092041015625, logqcx:-0.01937197893857956
kld_qz_pz = 0.244019,kld_qz_rna = 93.146759,kld_qz_atac = 90.642944,kl_divergence = 4.387217,reconst_loss_rna = 1739.071167,        reconst_loss_atac = 702.787231, mu=0.156902, sigma=1.792017
reconst_loss = 2418.911865,kl_divergence_local = 4.387217,kl_weight = 1.000000,loss = 2423.299072
tensor(2423.2991, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.00720977783203, logqcx:-0.019916221499443054
kld_qz_pz = 0.237290,kld_qz_rna = 93.048714,kld_qz_atac = 90.764328,kl_divergence = 4.414396,reconst_loss_rna = 1800.191284,        r

logpzc:-30.0313777923584, logqcx:-0.021342042833566666
kld_qz_pz = 0.251835,kld_qz_rna = 92.846970,kld_qz_atac = 91.562881,kl_divergence = 4.524554,reconst_loss_rna = 1802.610352,        reconst_loss_atac = 699.252197, mu=0.157203, sigma=1.793636
reconst_loss = 2478.983887,kl_divergence_local = 4.524554,kl_weight = 1.000000,loss = 2483.508301
tensor(2483.5083, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-29.964614868164062, logqcx:-0.01948610134422779
kld_qz_pz = 0.186031,kld_qz_rna = 94.482643,kld_qz_atac = 91.553543,kl_divergence = 4.442956,reconst_loss_rna = 1790.770752,        reconst_loss_atac = 745.360474, mu=0.157200, sigma=1.793701
reconst_loss = 2513.189941,kl_divergence_local = 4.442956,kl_weight = 1.000000,loss = 2517.633057
tensor(2517.6331, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-29.939794540405273, logqcx:-0.018818173557519913
kld_qz_pz = 0.161252,kld_qz_rna = 93.157272,kld_qz_atac = 91.377167,kl_divergence = 4.461075,reconst_loss_rna = 1800.037842,        re

logpzc:-30.032573699951172, logqcx:-0.020527565851807594
kld_qz_pz = 0.243642,kld_qz_rna = 92.429443,kld_qz_atac = 89.268036,kl_divergence = 4.673082,reconst_loss_rna = 1763.422241,        reconst_loss_atac = 785.524841, mu=0.157735, sigma=1.794957
reconst_loss = 2526.039795,kl_divergence_local = 4.673082,kl_weight = 1.000000,loss = 2530.712891
tensor(2530.7129, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.017976760864258, logqcx:-0.019707251340150833
kld_qz_pz = 0.229259,kld_qz_rna = 90.664970,kld_qz_atac = 88.921654,kl_divergence = 4.570233,reconst_loss_rna = 1760.508423,        reconst_loss_atac = 663.871704, mu=0.157750, sigma=1.795016
reconst_loss = 2401.442871,kl_divergence_local = 4.570233,kl_weight = 1.000000,loss = 2406.013184
tensor(2406.0132, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-30.0642147064209, logqcx:-0.020721716806292534
kld_qz_pz = 0.274098,kld_qz_rna = 90.790611,kld_qz_atac = 88.781715,kl_divergence = 4.634200,reconst_loss_rna = 1813.746826,        r

logpzc:-29.93218994140625, logqcx:-0.017891187220811844
kld_qz_pz = 0.135066,kld_qz_rna = 90.262543,kld_qz_atac = 89.512032,kl_divergence = 4.565406,reconst_loss_rna = 1800.028076,        reconst_loss_atac = 801.763550, mu=0.158176, sigma=1.796315
reconst_loss = 2578.796387,kl_divergence_local = 4.565406,kl_weight = 1.000000,loss = 2583.361816
tensor(2583.3618, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-29.895984649658203, logqcx:-0.01758432388305664
kld_qz_pz = 0.098758,kld_qz_rna = 89.391472,kld_qz_atac = 89.713364,kl_divergence = 4.395152,reconst_loss_rna = 1792.143066,        reconst_loss_atac = 714.923096, mu=0.158238, sigma=1.796387
reconst_loss = 2484.062500,kl_divergence_local = 4.395152,kl_weight = 1.000000,loss = 2488.457275
tensor(2488.4573, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-29.930927276611328, logqcx:-0.018306339159607887
kld_qz_pz = 0.132585,kld_qz_rna = 89.752251,kld_qz_atac = 89.090828,kl_divergence = 4.342940,reconst_loss_rna = 1800.743164,        r

logpzc:-29.971508026123047, logqcx:-0.019571954384446144
kld_qz_pz = 0.163195,kld_qz_rna = 88.485550,kld_qz_atac = 90.415604,kl_divergence = 4.425909,reconst_loss_rna = 1797.396606,        reconst_loss_atac = 725.843811, mu=0.158150, sigma=1.797810
reconst_loss = 2500.302246,kl_divergence_local = 4.425909,kl_weight = 1.000000,loss = 2504.728027
tensor(2504.7280, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-29.991968154907227, logqcx:-0.020021915435791016
kld_qz_pz = 0.182919,kld_qz_rna = 89.501854,kld_qz_atac = 90.430054,kl_divergence = 4.349897,reconst_loss_rna = 1817.931152,        reconst_loss_atac = 715.468933, mu=0.158138, sigma=1.797879
reconst_loss = 2510.475098,kl_divergence_local = 4.349897,kl_weight = 1.000000,loss = 2514.825195
tensor(2514.8252, device='cuda:0', grad_fn=<DivBackward0>)
logpzc:-29.96550750732422, logqcx:-0.01929640769958496
kld_qz_pz = 0.156609,kld_qz_rna = 89.397034,kld_qz_atac = 90.791931,kl_divergence = 4.365183,reconst_loss_rna = 1825.119141,        r

kld_qz_pz = 0.146947,kld_qz_rna = 91.282883,kld_qz_atac = 91.270729,kl_divergence = 4.189649,reconst_loss_rna = 1829.946777,        reconst_loss_atac = 875.353149, mu=0.158373, sigma=1.799037
logpzc:-29.92209815979004, logqcx:-0.018819475546479225
kld_qz_pz = 0.106204,kld_qz_rna = 88.982803,kld_qz_atac = 91.351242,kl_divergence = 4.137659,reconst_loss_rna = 1735.761108,        reconst_loss_atac = 765.939026, mu=0.158373, sigma=1.799037
logpzc:-29.92742156982422, logqcx:-0.018887460231781006
kld_qz_pz = 0.111489,kld_qz_rna = 91.608078,kld_qz_atac = 90.878815,kl_divergence = 4.104341,reconst_loss_rna = 1795.671509,        reconst_loss_atac = 729.645630, mu=0.158373, sigma=1.799037
logpzc:-29.955543518066406, logqcx:-0.01891961321234703
kld_qz_pz = 0.139655,kld_qz_rna = 92.281570,kld_qz_atac = 91.238319,kl_divergence = 4.151643,reconst_loss_rna = 1754.568237,        reconst_loss_atac = 774.330322, mu=0.158373, sigma=1.799037
logpzc:-30.026866912841797, logqcx:-0.020234474912285805
kld_qz_

kld_qz_pz = 0.063008,kld_qz_rna = 87.152298,kld_qz_atac = 91.245911,kl_divergence = 4.050262,reconst_loss_rna = 1762.510010,        reconst_loss_atac = 685.530884, mu=0.158373, sigma=1.799037
logpzc:-29.913772583007812, logqcx:-0.018778014928102493
kld_qz_pz = 0.097813,kld_qz_rna = 89.555702,kld_qz_atac = 91.033592,kl_divergence = 4.018953,reconst_loss_rna = 1834.585693,        reconst_loss_atac = 922.357422, mu=0.158373, sigma=1.799037
logpzc:-29.892805099487305, logqcx:-0.01810034178197384
kld_qz_pz = 0.077400,kld_qz_rna = 87.993576,kld_qz_atac = 91.019333,kl_divergence = 4.030270,reconst_loss_rna = 1751.586548,        reconst_loss_atac = 695.669495, mu=0.158373, sigma=1.799037
logpzc:-29.92905044555664, logqcx:-0.019152969121932983
kld_qz_pz = 0.113026,kld_qz_rna = 87.497559,kld_qz_atac = 91.281502,kl_divergence = 4.140182,reconst_loss_rna = 1818.997192,        reconst_loss_atac = 743.191162, mu=0.158373, sigma=1.799037
logpzc:-29.95406723022461, logqcx:-0.019030433148145676
kld_qz_

logpzc:-29.949512481689453, logqcx:-0.02022586017847061
kld_qz_pz = 0.132495,kld_qz_rna = 89.879166,kld_qz_atac = 90.830566,kl_divergence = 4.200336,reconst_loss_rna = 1731.656006,        reconst_loss_atac = 788.268127, mu=0.158373, sigma=1.799037
logpzc:-29.94100570678711, logqcx:-0.01882471889257431
kld_qz_pz = 0.125105,kld_qz_rna = 89.360214,kld_qz_atac = 91.011566,kl_divergence = 4.133474,reconst_loss_rna = 1783.476196,        reconst_loss_atac = 744.270020, mu=0.158373, sigma=1.799037
logpzc:-29.853931427001953, logqcx:-0.017533203586935997
kld_qz_pz = 0.038939,kld_qz_rna = 87.625832,kld_qz_atac = 90.980240,kl_divergence = 4.037556,reconst_loss_rna = 1869.212280,        reconst_loss_atac = 701.954956, mu=0.158373, sigma=1.799037
logpzc:-29.910743713378906, logqcx:-0.018733087927103043
kld_qz_pz = 0.094841,kld_qz_rna = 86.233887,kld_qz_atac = 91.045944,kl_divergence = 4.019294,reconst_loss_rna = 1735.894043,        reconst_loss_atac = 788.107666, mu=0.158373, sigma=1.799037
logpzc:

training: 100%|██████████| 15/15 [02:25<00:00,  9.72s/it]
logpzc:-29.954814910888672, logqcx:-0.01890757866203785
kld_qz_pz = 0.138804,kld_qz_rna = 91.727951,kld_qz_atac = 91.157631,kl_divergence = 4.138890,reconst_loss_rna = 1778.038086,        reconst_loss_atac = 757.010864, mu=0.158373, sigma=1.799037
logpzc:-29.903770446777344, logqcx:-0.018651606515049934
kld_qz_pz = 0.087988,kld_qz_rna = 89.859802,kld_qz_atac = 91.154488,kl_divergence = 4.124632,reconst_loss_rna = 1753.792969,        reconst_loss_atac = 653.684875, mu=0.158373, sigma=1.799037
logpzc:-29.945724487304688, logqcx:-0.018924634903669357
kld_qz_pz = 0.129733,kld_qz_rna = 91.424347,kld_qz_atac = 91.153198,kl_divergence = 4.098310,reconst_loss_rna = 1817.390503,        reconst_loss_atac = 781.542847, mu=0.158373, sigma=1.799037
logpzc:-30.024080276489258, logqcx:-0.02064727433025837
kld_qz_pz = 0.206606,kld_qz_rna = 91.736282,kld_qz_atac = 91.023361,kl_divergence = 4.215245,reconst_loss_rna = 1722.080078,        reconst_

logpzc:-29.906047821044922, logqcx:-0.018214909359812737
kld_qz_pz = 0.090415,kld_qz_rna = 90.040993,kld_qz_atac = 90.768814,kl_divergence = 4.043533,reconst_loss_rna = 1681.079346,        reconst_loss_atac = 645.459229, mu=0.158373, sigma=1.799037
logpzc:-29.966197967529297, logqcx:-0.019135046750307083
kld_qz_pz = 0.149889,kld_qz_rna = 89.457443,kld_qz_atac = 90.951416,kl_divergence = 4.133798,reconst_loss_rna = 1702.669189,        reconst_loss_atac = 735.464355, mu=0.158373, sigma=1.799037
logpzc:-29.889102935791016, logqcx:-0.017933830618858337
kld_qz_pz = 0.073768,kld_qz_rna = 89.576073,kld_qz_atac = 90.921173,kl_divergence = 4.016088,reconst_loss_rna = 1868.348755,        reconst_loss_atac = 873.966797, mu=0.158373, sigma=1.799037
logpzc:-29.90842056274414, logqcx:-0.018104188144207
kld_qz_pz = 0.093088,kld_qz_rna = 88.910950,kld_qz_atac = 91.026031,kl_divergence = 4.076343,reconst_loss_rna = 1764.947266,        reconst_loss_atac = 896.280151, mu=0.158373, sigma=1.799037
logpzc:-

kld_qz_pz = 0.128916,kld_qz_rna = 88.450668,kld_qz_atac = 91.456848,kl_divergence = 4.160156,reconst_loss_rna = 1767.443848,        reconst_loss_atac = 735.968018, mu=0.158373, sigma=1.799037
logpzc:-29.90085220336914, logqcx:-0.018050622195005417
kld_qz_pz = 0.085602,kld_qz_rna = 88.361717,kld_qz_atac = 91.111549,kl_divergence = 4.040233,reconst_loss_rna = 1777.401978,        reconst_loss_atac = 822.298462, mu=0.158373, sigma=1.799037
logpzc:-29.90607452392578, logqcx:-0.017721308395266533
kld_qz_pz = 0.090962,kld_qz_rna = 89.396996,kld_qz_atac = 91.189514,kl_divergence = 4.059210,reconst_loss_rna = 1700.548340,        reconst_loss_atac = 669.758545, mu=0.158373, sigma=1.799037
logpzc:-29.864269256591797, logqcx:-0.017320483922958374
kld_qz_pz = 0.049427,kld_qz_rna = 87.525787,kld_qz_atac = 90.966942,kl_divergence = 4.006605,reconst_loss_rna = 1808.491821,        reconst_loss_atac = 664.507935, mu=0.158373, sigma=1.799037
logpzc:-29.993831634521484, logqcx:-0.01985449157655239
kld_qz_

logpzc:-29.8114013671875, logqcx:-0.01652890257537365
kld_qz_pz = -0.002581,kld_qz_rna = 87.341721,kld_qz_atac = 91.255630,kl_divergence = 3.958299,reconst_loss_rna = 1792.736450,        reconst_loss_atac = 646.043091, mu=0.158373, sigma=1.799037
logpzc:-29.793384552001953, logqcx:-0.01623871549963951
kld_qz_pz = -0.020309,kld_qz_rna = 89.033455,kld_qz_atac = 91.187546,kl_divergence = 4.007509,reconst_loss_rna = 1811.413086,        reconst_loss_atac = 707.693726, mu=0.158373, sigma=1.799037
logpzc:-29.811634063720703, logqcx:-0.01636490784585476
kld_qz_pz = -0.002233,kld_qz_rna = 90.384445,kld_qz_atac = 91.355484,kl_divergence = 3.918237,reconst_loss_rna = 1785.967163,        reconst_loss_atac = 754.048767, mu=0.158373, sigma=1.799037
logpzc:-29.804462432861328, logqcx:-0.016383811831474304
kld_qz_pz = -0.009476,kld_qz_rna = 87.369164,kld_qz_atac = 91.055565,kl_divergence = 4.000754,reconst_loss_rna = 1766.723755,        reconst_loss_atac = 704.606262, mu=0.158373, sigma=1.799037
logpz

In [16]:
# Export the trained model's results: latent embedding, UMAP + Louvain
# clusters, and the imputed RNA / ATAC matrices.

# Posterior over every cell in `dataset`.
full = trainer.create_posterior(trainer.model, dataset, indices=np.arange(len(dataset)),type_class=MultiPosterior)
latent, latent_rna, latent_atac, cluster_gamma, cluster_index, batch_indices, labels = full.sequential().get_latent()
batch_indices = batch_indices.ravel()
# imputed_values[0] is paired with gene names and imputed_values[1] with
# ATAC peak names further down in this cell.
imputed_values = full.sequential().imputation()

# Visualization: neighbour graph and UMAP computed on the joint latent space.
prior_adata = anndata.AnnData(X=latent)
prior_adata.obsm["X_multi_vi"] = latent
# Store labels as a flat 1-D column; a (n, 1) torch tensor is not needed here
# and only works through pandas' implicit array coercion.
prior_adata.obs['cell_type'] = np.asarray(labels).ravel()
sc.pp.neighbors(prior_adata, use_rep="X_multi_vi", n_neighbors=30)
sc.tl.umap(prior_adata, min_dist=0.3)
sc.tl.louvain(prior_adata)
# To colour by the dataset labels instead, pass color=["cell_type"].
sc.pl.umap(prior_adata, color=['louvain'])
plt.show()

# Save files. Create the output directory first so that to_csv does not fail
# on a fresh checkout where it does not exist yet.
os.makedirs(output_path, exist_ok=True)

# NOTE(review): rows/columns are labelled with rna_dataset.barcodes although the
# posterior was built from `dataset`; this assumes both hold the same cells in
# the same order - confirm.
# Despite the file name, this CSV holds the joint latent embedding.
df = pd.DataFrame(data=prior_adata.obsm["X_multi_vi"],  index=rna_dataset.barcodes)
df.to_csv(os.path.join(output_path,"multivae_latent_imputation.csv"))

# UMAP coordinates with the Louvain cluster assignment as the first column.
df = pd.DataFrame(data=prior_adata.obsm["X_umap"],  columns=["umap_dim1","umap_dim2"] , index=rna_dataset.barcodes)
df.insert(0,"louvain",prior_adata.obs['louvain'].values)
df.to_csv(os.path.join(output_path,"multivae_umap_louvain.csv"))

# Imputed ATAC matrix (peaks x cells). Export is left disabled; uncomment to write it.
df = pd.DataFrame(data=imputed_values[1].T, columns=rna_dataset.barcodes, index=rna_dataset.atac_names)
#df.to_csv(os.path.join(output_path,"atac_multivae_imputation_softmax.csv"))

# Imputed RNA matrix (genes x cells). Export is left disabled; uncomment to write it.
df = pd.DataFrame(data=imputed_values[0].T, columns=rna_dataset.barcodes, index=rna_dataset.gene_names)
#df.to_csv(os.path.join(output_path,"gene_multivae_imputation_softmax.csv"))