In [1]:
import anndata as ad
import scvi
import scanpy as sc
import mrvi
import pandas as pd
import scipy as sp
import numpy as np
import pickle as pkl
import utils

Global seed set to 0
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


# Data Preprocessing

In [5]:
# Read the AnnData object stored in the .h5ad file. The path is relative, so it
# resolves against the notebook's current working directory.
adata = sc.read('./../data/MGH66_bacdrop.h5ad')

In [6]:
# Quick look at the loaded object. Only the last expression of a cell is
# rendered, so the recorded output is the `adata.obs` table alone; the two
# expressions before it are evaluated and their values discarded.
# NOTE(review): wrap them in display(...) if they are meant to be shown.
adata.var_names
adata.obs_keys()
adata.obs

Unnamed: 0,sample,replicate
AACGTTCTGTCGAAAAAAAAAAAAAAAAA-0,untreated,0
AACGTTCTGTCGAAAACGAAAGAGCTACG-0,untreated,0
AACGTTCTGTCGAAAACGAAAGAGGTCCA-0,untreated,0
AACGTTCTGTCGAAAACGAAAGCACCATT-0,untreated,0
AACGTTCTGTCGAAAACGAACAACTCGAT-0,untreated,0
...,...,...
TTGCAGCCACAGCTTGTATCTGCCTAGTA-1,meropenem,1
TTGCAGCCACAGCTTGTGCCTGCGCACTG-1,meropenem,1
TTGCAGCCACAGCTTGTTTGTGGAGTAAT-1,meropenem,1
TTGCAGCCACAGCTTTAGCGTGAGGTGAC-1,meropenem,1


In [7]:
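# TODO: decide whether more data cleaning / filtering is needed here, before the
# model is set up. Currently the only filtering (sc.pp.filter_cells /
# sc.pp.filter_genes) runs in a later cell, after the model has been trained.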

# Running bacdrop data through MrVI

In [8]:
# Register `adata` with MrVI: the "sample" obs column is the sample key and the
# "replicate" obs column is passed as a categorical nuisance covariate.
mrvi.MrVI.setup_anndata(
    adata,
    sample_key="sample",
    categorical_nuisance_keys=["replicate"],
)

# Build the model on the registered AnnData (all other options at defaults).
mrvi_model = mrvi.MrVI(adata)

In [9]:
# Fit the model with the library's default training settings. The recorded run
# below used the GPU and trained for 5 epochs (about 11 minutes).
# NOTE(review): no epoch count is passed here, so 5 is presumably the library's
# default for a dataset of this size -- confirm this is enough training.
mrvi_model.train()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 5/5: 100%|██████████| 5/5 [11:16<00:00, 135.39s/it, loss=23.6, v_num=1]


In [10]:
# Compute the "z" latent representation (give_z=True) for every cell and store
# it on the AnnData under obsm["X_mrvi_z"]. The recorded run took ~11 minutes.
# NOTE(review): z is presumably MrVI's sample-aware latent space -- confirm
# against the MrVI documentation.
adata.obsm["X_mrvi_z"] = mrvi_model.get_latent_representation(give_z=True)

100%|██████████| 13223/13223 [10:49<00:00, 20.37it/s]


In [11]:
# Compute the "u" latent representation (give_z=False) for every cell and store
# it on the AnnData under obsm["X_mrvi_u"]. The recorded run took ~11 minutes.
# NOTE(review): u is presumably MrVI's sample-unaware latent space -- confirm
# against the MrVI documentation.
adata.obsm["X_mrvi_u"] = mrvi_model.get_latent_representation(give_z=False)

100%|██████████| 13223/13223 [10:50<00:00, 20.33it/s]


In [12]:
# Local sample representations: each cell's representation under each sample.
# Expected layout (per the original note): cells x n_sample x n_latent.
# The result lives only in this variable; it is not attached to `adata`, so it
# does not follow any later subsetting of `adata`.
cell_sample_representations = mrvi_model.get_local_sample_representation()

100%|██████████| 6612/6612 [01:09<00:00, 94.92it/s] 


In [13]:
# Local sample-sample distances: for each cell, an S x S matrix (S = number of
# samples) quantifying how gene expression differs between biological samples.
# Expected layout (per the original note): cells x n_sample x n_sample.
# "Section 3.1" presumably refers to the MrVI paper -- TODO confirm and link it.
# Like the representations, the result is not attached to `adata`.
cell_sample_sample_distances = mrvi_model.get_local_sample_representation(return_distances=True)

100%|██████████| 6612/6612 [01:09<00:00, 95.58it/s] 


In [14]:
# Sanity check: list the obsm keys on `adata`. The recorded output shows the two
# MrVI embeddings (X_mrvi_z, X_mrvi_u) next to the scvi nuisance-key entry.
adata.obsm

AxisArrays with keys: _scvi_categorical_nuisance_keys, X_mrvi_z, X_mrvi_u

In [15]:
# Inspect the AnnData held by the model. Its recorded repr lists the X_mrvi_z /
# X_mrvi_u obsm keys that were written to `adata`, so the model appears to
# reference the same object: in-place changes to `adata` also change it.
mrvi_model.adata

AnnData object with n_obs × n_vars = 1692542 × 4628
    obs: 'sample', 'replicate', '_scvi_sample', '_scvi_labels'
    uns: '_scvi_uuid', '_scvi_manager_uuid'
    obsm: '_scvi_categorical_nuisance_keys', 'X_mrvi_z', 'X_mrvi_u'

In [16]:
# Show the `_scvi_labels` obs column found on the model's AnnData. Every value
# visible in the recorded output is 0.
mrvi_model.adata.obs['_scvi_labels']

AACGTTCTGTCGAAAAAAAAAAAAAAAAA-0    0
AACGTTCTGTCGAAAACGAAAGAGCTACG-0    0
AACGTTCTGTCGAAAACGAAAGAGGTCCA-0    0
AACGTTCTGTCGAAAACGAAAGCACCATT-0    0
AACGTTCTGTCGAAAACGAACAACTCGAT-0    0
                                  ..
TTGCAGCCACAGCTTGTATCTGCCTAGTA-1    0
TTGCAGCCACAGCTTGTGCCTGCGCACTG-1    0
TTGCAGCCACAGCTTGTTTGTGGAGTAAT-1    0
TTGCAGCCACAGCTTTAGCGTGAGGTGAC-1    0
TTGCAGCCACAGCTTTCGTAACCATCCTC-1    0
Name: _scvi_labels, Length: 1692542, dtype: int8

In [17]:
# Summary of what the model registered: number of cells, genes (n_vars),
# samples, labels, and categorical nuisance keys.
mrvi_model.summary_stats

attrdict({'n_cells': 1692542, 'n_vars': 4628, 'n_sample': 4, 'n_labels': 1, 'n_categorical_nuisance_keys': 1})

In [18]:
# Quality filtering: keep cells with at least 20 detected genes, then keep genes
# detected in at least 3 of the remaining cells.
#
# This runs AFTER the model was trained and after the per-cell outputs were
# computed, so it must not be done in place on `adata`:
#   * `cell_sample_representations` and `cell_sample_sample_distances` are
#     stand-alone arrays with one entry per pre-filter cell (assumed to follow
#     adata's row order -- TODO confirm). Dropping rows from `adata` alone would
#     leave them misaligned with the AnnData that gets written to disk.
#   * `mrvi_model.adata` is the same object as `adata` (its repr lists the
#     X_mrvi_* obsm keys), so an in-place subset would also change the data the
#     trained model is registered on.
# Instead, `adata` is rebound to a filtered copy (the model keeps the unfiltered
# object) and the same cell mask is applied to the per-cell arrays.
# NOTE(review): filtering before `setup_anndata` would be cleaner still, but it
# requires retraining.
if not (
    len(cell_sample_representations)
    == len(cell_sample_sample_distances)
    == adata.n_obs
):
    raise ValueError(
        "Per-cell MrVI outputs do not have one entry per cell of `adata`; "
        "recompute them (or reload `adata`) before filtering."
    )

# inplace=False returns (boolean keep-mask, genes detected per cell) and leaves
# `adata` untouched.
keep_cells, n_genes_per_cell = sc.pp.filter_cells(adata, min_genes=20, inplace=False)

adata = adata[keep_cells].copy()
# Mirror the annotation the in-place variant of filter_cells would have added.
adata.obs["n_genes"] = n_genes_per_cell[keep_cells]

# `adata` is now a private copy, so in-place gene filtering is safe.
sc.pp.filter_genes(adata, min_cells=3)

# Keep the per-cell outputs row-aligned with the filtered AnnData.
cell_sample_representations = cell_sample_representations[keep_cells]
cell_sample_sample_distances = cell_sample_sample_distances[keep_cells]

In [19]:
# Persist the AnnData (including the X_mrvi_z / X_mrvi_u embeddings stored in
# obsm) so later sessions can reload it without recomputing them.
# The path is relative to the notebook's working directory.
outpath = './../data/MrVIoutputs/bacdrop.h5ad'
adata.write_h5ad(outpath)

In [20]:
# Pickle the trained model and the two per-cell outputs so they can be reloaded
# without rerunning training/inference. `utils.write_pickle` is a project helper
# that is not shown in this notebook.
# NOTE(review): pickles are tied to the installed package versions; if the model
# class offers its own save/load, that may be more durable -- verify.
for payload, target in (
    (mrvi_model, './../data/pickles/mrvi_model.pickle'),
    (cell_sample_representations, './../data/pickles/sample_representations.pickle'),
    (cell_sample_sample_distances, './../data/pickles/sample_distances.pickle'),
):
    utils.write_pickle(payload, target)