# Use an scVI run to created "de-batched" count data for use downstream

Stephen Fleming

2025.10.08

Use a Cellarium ML scVI checkpoint for downstream inference and explore the outputs.

This notebook assumes the following kind of training run has been carried out already 

```python
config_file = "../example_configs/scvi_pbmc_config.yaml"  # your config file

!cellarium-ml scvi fit --config {config_file}
```

and that we have saved checkpoints somewhere.

In [1]:
import lightning.pytorch as pl
import pandas as pd

from cellarium.ml.callbacks import PredictionWriter
from cellarium.ml.core import CellariumAnnDataDataModule, CellariumModule
from cellarium.ml.data import DistributedAnnDataCollection
from cellarium.ml.utilities.data import AnnDataField

%matplotlib inline

  from .autonotebook import tqdm as notebook_tqdm


## Load checkpoint

In [2]:
# your saved checkpoint file (you'll need to look for it)

checkpoint_file = "gs://cellarium-dev/scvi_glyco/scvi_2025_nothing_encoded_biglatent/lightning_logs/version_1/checkpoints/epoch=14-step=916605.ckpt"

In [3]:
# load the trained module
scvi_module = CellariumModule.load_from_checkpoint(checkpoint_file, map_location="cpu")

In [4]:
scvi_module

CellariumModule(pipeline = CellariumPipeline(
  (0): Filter(filter_list=['ENSG00000000005' 'ENSG00000000971' 'ENSG00000001167' ...
   'ENSG00000101911' 'ENSG00000170222' 'ENSG00000172264'])
  (1): SingleCellVariationalInference(
    (z_encoder): EncoderSCVI(
      (fully_connected): FullyConnectedWithBatchArchitecture(
        (module_list): ModuleList(
          (0): DressedLayer(
            (layer): Linear(in_features=8430, out_features=512, bias=True)
            (dressing): Sequential(
              (0): BatchNorm1d(512, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
              (1): ReLU()
              (2): Dropout(p=0.1, inplace=False)
            )
          )
          (1): DressedLayer(
            (layer): Linear(in_features=512, out_features=512, bias=True)
            (dressing): Sequential(
              (0): BatchNorm1d(512, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
              (1): ReLU()
              (2): Dropout(p=0.1, in

## Data

In [5]:
# helper functions to get file sizes (eventually this will be part of cellarium-ml)

import concurrent.futures
from typing import Callable

import gcsfs
import h5py
import numpy as np
from tqdm import tqdm


def _h5py_read_n_obs(h5handle: h5py.File) -> int:
    idx_col = h5handle["obs"].attrs["_index"]
    try:
        n_obs = h5handle[f"obs/{idx_col}"].shape[0]
    except AttributeError:
        # can happen if somehow the obs index is saved as a categorical (not supposed to be allowed)
        n_obs = h5handle[f"obs/{idx_col}/codes"].shape[0]
    return n_obs


def _h5py_read_var_names(h5handle: h5py.File) -> np.ndarray:
    idx_col = h5handle["var"].attrs["_index"]
    try:
        var_names = h5handle[f"var/{idx_col}"][:]
    except AttributeError:
        # can happen if somehow the var index is saved as a categorical (not supposed to be allowed)
        var_names = h5handle[f"var/{idx_col}/categories"][:]
    return var_names


def get_h5ad_file_n_cells(h5ad_path: str) -> int:
    """
    Get the number of cells in each h5ad file in a list of paths.
    """
    n_cells = _h5ad_file_read_elem(h5ad_path, fun=_h5py_read_n_obs)
    assert isinstance(n_cells, int), "Expected int from _h5py_read_n_obs"
    return n_cells


def get_h5ad_files_n_cells(h5ad_paths: list[str]) -> list[int]:
    """
    Get the number of cells in each h5ad file in a list of paths.
    ThreadPoolExecutor is used (preserves order).
    """
    # return [get_h5ad_file_n_cells(h5ad_path) for h5ad_path in h5ad_paths]
    with concurrent.futures.ThreadPoolExecutor(max_workers=None) as executor:
        return list(
            tqdm(
                executor.map(get_h5ad_file_n_cells, h5ad_paths),
                total=len(h5ad_paths),
                desc="Reading n_obs from h5ad files",
                unit="file",
            )
        )


def get_h5ad_files_limits(h5ad_paths: list[str]) -> np.ndarray:
    """
    Return the `limits` to be used in constructing a :class:`~cellarium.ml.data.DistributedAnnDataCollection`
    based on sizes of the provided h5ad files.
    """
    limits = np.cumsum(get_h5ad_files_n_cells(h5ad_paths))
    return limits


def get_h5ad_file_var_names_g(h5ad_path: str) -> np.ndarray:
    """
    Get var_names_g from an h5ad file.
    """
    var_names_g = _h5ad_file_read_elem(h5ad_path, fun=_h5py_read_var_names)
    assert isinstance(var_names_g, np.ndarray), "Expected numpy array from _h5py_read_var_names"
    return var_names_g.astype(str)


def _h5ad_file_read_elem(h5ad_path: str, fun: Callable[[h5py.File], int | np.ndarray]) -> np.ndarray | int:
    """
    Read info from an h5ad file, loading as little of it as possible.
    """

    def _gcloud_version(h5ad_path: str) -> int | np.ndarray:
        fs = gcsfs.GCSFileSystem()
        with fs.open(h5ad_path, "rb") as f:
            with h5py.File(f) as h5handle:
                return fun(h5handle)

    def _local_version(h5ad_path: str) -> int | np.ndarray:
        with h5py.File(h5ad_path, "r") as h5handle:
            return fun(h5handle)

    def _url_version(h5ad_path: str) -> int | np.ndarray:
        """Optimized version that streams only the needed parts of the file"""
        raise NotImplementedError("URL version not implemented here")
        # with SeekableHTTPFile(h5ad_path) as f:
        #     with h5py.File(f, "r") as h5handle:
        #         return fun(h5handle)

    if h5ad_path.startswith("gs://"):
        out = _gcloud_version(h5ad_path)
    elif h5ad_path.startswith("http://") or h5ad_path.startswith("https://"):
        out = _url_version(h5ad_path)
    else:
        out = _local_version(h5ad_path)

    return out

Here you'll need to instantiate a `CellariumAnnDataDataModule` to handle the data.

Much of the configuration for this exists in the checkpoint itself, but you have to supply the dataset you want 
as a `DistributedAnnDataCollectionDataset`.

In [6]:
# locate the data
# here I'm pulling from cloud, but this could be local
# NOTE: streaming from cloud is slow, so likely you should download the files

nexus_path = "gs://cellarium-nexus-file-system-3293a8/pipeline/data-extracts"  # fixed path for now
curriculum_name = "czi_20250130_human_primary_gte300umi"  # latest curriculum 62M cells

extract_numbers = [0, 1]  # can choose which extracts to use; here we use the first two extracts (0 and 1)

extract_paths = [f"{nexus_path}/{curriculum_name}/extract_files/extract_{i}.h5ad" for i in extract_numbers]
extract_paths

['gs://cellarium-nexus-file-system-3293a8/pipeline/data-extracts/czi_20250130_human_primary_gte300umi/extract_files/extract_0.h5ad',
 'gs://cellarium-nexus-file-system-3293a8/pipeline/data-extracts/czi_20250130_human_primary_gte300umi/extract_files/extract_1.h5ad']

You'll need the "limits", which is computed from the number of cells in each file.  There is a helper function above.

In [7]:
limits = get_h5ad_files_limits(extract_paths)
limits

Reading n_obs from h5ad files: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2/2 [00:05<00:00,  2.89s/file]


array([10000, 20000])

In [8]:
# this sets up the datamodule

datamodule = CellariumAnnDataDataModule.load_from_checkpoint(
    checkpoint_path=checkpoint_file,
    dadc=DistributedAnnDataCollection(
        filenames=extract_paths,
        limits=limits,
        obs_columns_to_validate=["cellarium_scvi_batch"],
        max_cache_size=2,
        cache_size_strictly_enforced=True,
    ),
    batch_size=1024,  # can adjust based on memory
    shuffle=False,
    num_workers=0,  # override the training setting which uses a beefy machine with 14 persistent workers
    prefetch_factor=None,  # override the training setting
    persistent_workers=False,  # override the training setting
)

datamodule.setup(stage="predict")



In [9]:
datamodule.batch_keys

{'x_ng': AnnDataField(attr='X', key=None, convert_fn=<function densify at 0x3496a28c0>),
 'var_names_g': AnnDataField(attr='var_names', key=None, convert_fn=None),
 'batch_index_n': AnnDataField(attr='obs', key='cellarium_scvi_batch', convert_fn=<function categories_to_codes at 0x3496a29e0>)}

In [10]:
# this is wonky but we need to add a batch_key with the cell names to help us save the output
# we do not load this data during training, so it is not part of the checkpointed datamodule

datamodule.batch_keys |= {"obs_names_n": AnnDataField(attr="obs_names")}

datamodule.batch_keys

{'x_ng': AnnDataField(attr='X', key=None, convert_fn=<function densify at 0x3496a28c0>),
 'var_names_g': AnnDataField(attr='var_names', key=None, convert_fn=None),
 'batch_index_n': AnnDataField(attr='obs', key='cellarium_scvi_batch', convert_fn=<function categories_to_codes at 0x3496a29e0>),
 'obs_names_n': AnnDataField(attr='obs_names', key=None, convert_fn=None)}

In [11]:
# we can access an anndata file directly if we wanted (but we don't):
datamodule.dadc.adatas[0].adata

AnnData object with n_obs Ã— n_vars = 10000 Ã— 61886
    obs: 'original_id', 'donor_id', 'cell_type', 'assay', 'development_stage', 'tissue', 'disease', 'organism', 'self_reported_ethnicity', 'sex', 'suspension_type', 'total_mrna_umis', 'cell_type_ontology_term_id', 'assay_ontology_term_id', 'development_stage_ontology_term_id', 'tissue_ontology_term_id', 'disease_ontology_term_id', 'organism_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'sex_ontology_term_id', 'tag', 'is_primary_data', 'cellarium_scvi_batch', 'observation_joinid'

In [12]:
# we could also grab a dataloader if we needed it (but we shouldn't):
datamodule.predict_dataloader()

<torch.utils.data.dataloader.DataLoader at 0x350a4c040>

## Reconstruction

### Pick genes to reconstruct

In [13]:
gene_df = pd.read_csv('/Users/sfleming/Documents/Projects/scvi/glycosylation/cellarium_train_20251007.tsv', sep='\t')
gene_df

Unnamed: 0,ensembl_id,Group,HGNC
0,ENSG00000000005,cellxgene_census_hvg_8000,TNMD
1,ENSG00000000971,cellxgene_census_hvg_8000,CFH
2,ENSG00000001167,cellxgene_census_hvg_8000,NFYA
3,ENSG00000001497,cellxgene_census_hvg_8000,LAS1L
4,ENSG00000001626,cellxgene_census_hvg_8000,CFTR
...,...,...,...
8425,ENSG00000169299,glyco_pentose_phosphate,PGM2
8426,ENSG00000147224,glyco_pentose_phosphate,PRPS1
8427,ENSG00000101911,glyco_pentose_phosphate,PRPS2
8428,ENSG00000170222,glyco_pentose_phosphate,ADPRM


In [14]:
gene_df['Group'].value_counts()

Group
cellxgene_census_hvg_8000      8000
glyco_degradation                72
glyco_Initiation                 58
glyco_transporter                40
glyco_Core_extension             37
glyco_Capping                    33
glyco_Capping_Sulfo              27
glyco_binding                    20
glyco_Elongation                 16
glyco_pentose_phosphate          15
glyco_mannose                    14
glyco_Detoxification             12
glyco_Branching                  11
glyco_glycolysis                 10
glyco_hexosamine                 10
glyco_Core_extension_repeat       9
glyco_Glycogen                    7
glyco_Unknown                     5
glyco_galactose                   5
glyco_fucose                      5
glyco_sialic_acid                 4
glyco_Pseudogene                  4
glyco_glucose                     4
glyco_Core_degradation            4
glyco_Inactive                    2
glyco_Donorpool                   2
glyco_QC                          2
glyco_tRNA            

In [15]:
# demo list

gene_ids_of_interest = gene_df.loc[gene_df['Group'].str.startswith('glyco'), 'ensembl_id'].values
gene_ids_of_interest[:5]

array(['ENSG00000163389', 'ENSG00000101346', 'ENSG00000186866',
       'ENSG00000130714', 'ENSG00000009830'], dtype=object)

In [16]:
scvi_module.model.var_names_g

array(['ENSG00000151136', 'ENSG00000253535', 'ENSG00000237468', ...,
       'ENSG00000288583', 'ENSG00000200883', 'ENSG00000222430'],
      dtype='<U15')

In [17]:
# are any genes missing

len(set(gene_ids_of_interest) - set(scvi_module.model.var_names_g))

0

In [18]:
gene_ids_of_interest = list(set(gene_ids_of_interest).intersection(set(scvi_module.model.var_names_g)))
len(gene_ids_of_interest)

430

### Patch a problematic function in cellarium-ml

(There is a github issue for this)

In [19]:
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Sequence
from functools import cache
from typing import Any

import numpy as np
import torch
from torch import nn

from cellarium.ml.utilities.testing import (
    assert_columns_and_array_lengths_equal,
)


class OrderedFilter(nn.Module):
    """
    Filter gene counts by a list of features.

    .. math::

        \\mathrm{mask}_g = \\mathrm{feature}_g \\in \\mathrm{filter\\_list}

        y_{ng} = x_{ng}[:, \\mathrm{mask}_g]

    Args:
        filter_list: A list of features to filter by.
    """

    def __init__(self, filter_list: Sequence[str]) -> None:
        super().__init__()
        self.filter_list = np.array(filter_list)
        if len(self.filter_list) == 0:
            raise ValueError(f"`filter_list` must not be empty. Got {self.filter_list}")

    @cache
    def filter(self, var_names_g: tuple) -> np.ndarray[Any, np.dtype[np.int_]]:
        """
        Args:
            var_names_g: The list of the variable names in the input data.

        Returns:
            An array of indices of the features in ``var_names_g`` that are in :attr:`filter_list`,
            ordered according to the order in :attr:`filter_list`.
        """
        # Convert var_names_g to numpy array for vectorized operations
        var_names_array = np.array(var_names_g)

        # Use numpy's isin to find which filter_list genes are present in var_names_g
        mask = np.isin(self.filter_list, var_names_array)

        if not np.any(mask):
            raise AssertionError("No features in `var_names_g` matched the `filter_list`")

        # Get the genes from filter_list that are present (maintains filter_list order)
        present_genes = self.filter_list[mask]

        # Use searchsorted to find indices efficiently
        # First sort var_names_g and get the sort indices
        sort_indices = np.argsort(var_names_array)
        sorted_var_names = var_names_array[sort_indices]

        # Find positions of present_genes in sorted array
        positions = np.searchsorted(sorted_var_names, present_genes)

        # Map back to original indices
        mask_indices = sort_indices[positions]

        return mask_indices.astype(np.int_)

    def forward(self, x_ng: torch.Tensor, var_names_g: np.ndarray) -> dict[str, torch.Tensor | np.ndarray]:
        """
        .. note::

            When used with :class:`~cellarium.ml.core.CellariumModule` or :class:`~cellarium.ml.core.CellariumPipeline`,
            ``x_ng`` and ``var_names_g`` keys in the input dictionary will be overwritten with the filtered values.

        Args:
            x_ng:
                Gene counts.
            var_names_g:
                The list of the variable names in the input data.

        Returns:
            A dictionary with the following keys:

            - ``x_ng``: Gene counts filtered by :attr:`filter_list`.
            - ``var_names_g``: The list of the variable names in the input data filtered by :attr:`filter_list`.
        """
        assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g)

        filter_indices = self.filter(tuple(var_names_g.tolist()))
        ndx = torch.arange(x_ng.shape[0])
        x_ng = x_ng[ndx[:, None], filter_indices]
        var_names_g = var_names_g[filter_indices]

        return {"x_ng": x_ng, "var_names_g": var_names_g}

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(filter_list={self.filter_list})"


In [20]:
# replace the Filter transform with an OrderedFilter transform

scvi_module.pipeline[0] = OrderedFilter(filter_list=scvi_module.model.var_names_g)

### Running reconstruction method

In [21]:
# set up the model for reconstruction mode
# this needs to be done because it was not part of the training checkpoint

scvi_module.model.reconstruct_counts_on_predict = True
scvi_module.model.reconstruction_var_names_g = gene_ids_of_interest
scvi_module.model.reconstruction_transform_batch = 0  # batch index to project into; None will not do batch correction
scvi_module.model.reconstruction_use_latent_mean = False  # don't do this: sample as below
scvi_module.model.reconstruction_n_latent_samples = 30  # time is linear in this, but more is better
scvi_module.model.reconstruction_use_importance_sampling = False  # you could, I think scvi-tools defaults to False
scvi_module.model.reconstructed_library_size = 10000  # set a fixed library size for reconstruction

There are a couple ways this could be done:
1. You could manually iterate through the dataloader and call the `reconstruct` method yourself, and deal with the outputs.
2. You can use a pytorch lightning trainer to do it for you, but you have to tell lightning what to do with the outputs,
 and this requires setting up a callback.
 This will also require you to set object attributes appropriately so that the model 
 knows that it should call `predict` and that `predict` should in turn call `reconstruct`.

#### Demo of manual mode

In [22]:
# see what is being loaded (it's the batch_keys from above)

for batch in datamodule.predict_dataloader():
    print(batch.keys())
    break

dict_keys(['x_ng', 'var_names_g', 'batch_index_n', 'obs_names_n'])




In [23]:
# see how it would be run through the predict method and what output you get

batch = next(iter(datamodule.predict_dataloader()))

# manual application of the Filter is needed to get the right genes in the right order
out = scvi_module.pipeline[0](x_ng=batch["x_ng"], var_names_g=batch["var_names_g"])
batch["x_ng"] = out["x_ng"]
batch["var_names_g"] = out["var_names_g"]

predict_output = scvi_module.model.predict(
    x_ng=batch["x_ng"],
    var_names_g=batch["var_names_g"],
    batch_index_n=batch["batch_index_n"],
    # categorical_covariate_index_nd=batch["categorical_covariate_index_nd"],  # if the model needs it, depends on model
)



In [24]:
predict_output.keys()

dict_keys(['x_ng'])

In [25]:
predict_output["x_ng"].shape

torch.Size([1024, 430])

In [26]:
# these are the cell names that go with the output

batch["obs_names_n"].shape

(1024,)

In [32]:
# this would probably be a reasonable way to deal with the output

pd.DataFrame(
    data=(
        {"obs_names": batch["obs_names_n"]}
        | {
            scvi_module.model.reconstruction_var_names_g[i]: predict_output["x_ng"][:, i].numpy()
            for i in range(predict_output["x_ng"].shape[1])
            }
    )
).set_index('obs_names')

Unnamed: 0_level_0,ENSG00000174684,ENSG00000163527,ENSG00000116406,ENSG00000146411,ENSG00000182050,ENSG00000113532,ENSG00000167130,ENSG00000167165,ENSG00000187210,ENSG00000175229,...,ENSG00000070614,ENSG00000133116,ENSG00000168995,ENSG00000189366,ENSG00000109181,ENSG00000118094,ENSG00000173852,ENSG00000143641,ENSG00000136542,ENSG00000110328
obs_names,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
12024477,0.040042,1.565121,0.709921,0.005763,0.002386,0.831880,0.066981,0.000245,0.323482,0.000570,...,0.092326,0.011391,0.439055,0.004553,0.000186,0.000491,0.073392,0.392312,0.015904,0.003739
29591694,0.122643,3.984713,1.911196,0.006702,0.083066,1.467092,0.132956,0.000133,0.996605,0.005099,...,0.157977,0.008412,0.013219,0.014772,0.000201,0.001176,0.343288,3.425517,0.013978,0.035153
53717858,0.041249,0.736323,0.312037,0.000599,0.001461,0.611627,0.065821,0.000016,0.298412,0.000143,...,0.065119,0.000637,0.563698,0.001025,0.000017,0.000048,0.078044,0.259090,0.000405,0.000953
20503671,0.640971,2.045768,1.269840,0.194664,0.014940,0.290521,0.225634,0.012869,0.429593,0.018281,...,0.441988,0.159351,0.004990,0.050459,0.003046,0.012775,0.378543,1.203335,0.148103,0.728982
90193068,0.178522,8.406000,3.417926,0.060516,0.032390,7.816574,0.311471,0.000560,1.380279,0.010784,...,0.111675,0.050222,0.007012,0.013987,0.000514,0.000650,0.804437,5.250274,0.008198,0.247794
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2625492,4.451240,3.437879,3.295121,0.162315,13.930861,1.566581,0.623352,0.000460,1.304275,0.367847,...,1.668445,0.068736,0.009307,0.097233,0.002531,0.049939,2.058690,2.509491,0.019584,7.477097
91767037,4.737754,2.496646,1.992297,0.365746,2.830714,0.219430,0.719214,0.000448,0.193655,0.387726,...,1.240770,0.063702,0.002049,0.094269,0.000706,0.041929,2.124146,2.852248,0.047519,0.517604
36724053,0.134790,1.509463,0.577212,0.005737,0.021473,1.993769,0.210360,0.000158,0.347409,0.001056,...,0.022029,0.001872,5.450201,0.014451,0.000531,0.000121,0.356201,0.953058,0.002630,0.015277
28676860,1.720756,2.704497,1.982792,0.240452,11.347136,0.722401,0.285187,0.000367,0.363596,0.421066,...,0.583255,0.019177,0.001027,0.111516,0.000482,0.005215,3.282543,1.842872,0.002354,23.579630


#### Batched mode using pytorch lightning

This is probably recommended since there is less room for error.

In [28]:
# demo of pytorch lightning mode

trainer = pl.Trainer(
    accelerator="cpu",  # adjust as needed
    devices=1,
    max_epochs=1,
    callbacks=[
        PredictionWriter(output_dir="/Users/sfleming/Desktop/tmp")
    ],
)

# return_predictions could work, but it can overflow memory, so we write to disk instead using PredictionWriter
trainer.predict(model=scvi_module, datamodule=datamodule, return_predictions=False)

ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/sfleming/miniconda3/envs/cellarium/lib/python3.10/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/sfleming/miniconda3/envs/cellarium/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/sfleming/miniconda3/envs/cellarium/lib/python3.10/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` 

Predicting DataLoader 0:   0%|          | 0/20 [00:00<?, ?it/s]

/Users/sfleming/miniconda3/envs/cellarium/lib/python3.10/site-packages/lightning/pytorch/loops/prediction_loop.py:307: Couldn't infer the batch indices fetched from your dataloader: `DataLoader`


Predicting DataLoader 0:  45%|â–ˆâ–ˆâ–ˆâ–ˆâ–Œ     | 9/20 [00:15<00:19,  0.58it/s]



Predicting DataLoader 0: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 20/20 [00:48<00:00,  0.41it/s]




In [29]:
!ls -lh /Users/sfleming/Desktop/tmp/batch*.csv.gz

-rw-r--r--  1 sfleming  staff   2.0M Oct  8 14:31 /Users/sfleming/Desktop/tmp/batch_0.csv.gz
-rw-r--r--  1 sfleming  staff   2.0M Oct  8 14:31 /Users/sfleming/Desktop/tmp/batch_1.csv.gz
-rw-r--r--  1 sfleming  staff   2.0M Oct  8 14:32 /Users/sfleming/Desktop/tmp/batch_10.csv.gz
-rw-r--r--  1 sfleming  staff   2.0M Oct  8 14:32 /Users/sfleming/Desktop/tmp/batch_11.csv.gz
-rw-r--r--  1 sfleming  staff   2.0M Oct  8 14:32 /Users/sfleming/Desktop/tmp/batch_12.csv.gz
-rw-r--r--  1 sfleming  staff   2.0M Oct  8 14:32 /Users/sfleming/Desktop/tmp/batch_13.csv.gz
-rw-r--r--  1 sfleming  staff   2.0M Oct  8 14:32 /Users/sfleming/Desktop/tmp/batch_14.csv.gz
-rw-r--r--  1 sfleming  staff   2.0M Oct  8 14:32 /Users/sfleming/Desktop/tmp/batch_15.csv.gz
-rw-r--r--  1 sfleming  staff   2.0M Oct  8 14:32 /Users/sfleming/Desktop/tmp/batch_16.csv.gz
-rw-r--r--  1 sfleming  staff   2.0M Oct  8 14:32 /Users/sfleming/Desktop/tmp/batch_17.csv.gz
-rw-r--r--  1 sfleming  staff   2.0M Oct  8 14:32 /Users/sflem

In [30]:
!cat /Users/sfleming/Desktop/tmp/batch_0.csv.gz | zcat | head -n 1

12024477,0.26802728,5.1590977,2.5461533,0.073427506,0.23911971,2.754567,0.16287391,0.0010512942,0.87522185,0.009504059,0.5389844,0.491597,1.012902,1.1374475,2.6108975,0.071048215,0.61987126,1.8056259,0.0005368311,2.0310016,2.5362048,0.4033317,0.64339423,0.24812263,0.62582624,1.1245711,0.035139628,1.6583856,0.037551142,0.07730511,2.4732704,6.187904,0.613545,0.27310076,0.3877924,0.109418795,0.104126446,0.31198278,0.0024354428,0.06651029,0.5976026,0.21146509,0.104931064,1.4340197,0.08538062,9.05576,1.5827175,0.17232628,0.26673633,4.0119247,0.47987586,0.43843925,0.07493413,0.72142416,0.44669056,3.499847,1.1892666,1.398297,0.7774532,1.4077824,0.0024590262,0.0033527168,0.0750121,0.028640687,0.030879576,1.178991,0.18245286,0.5480898,0.5090252,0.001019529,0.47942433,6.40905,0.036482256,1.0377126,0.0018258867,0.7787971,0.06776762,3.1111712,0.5945682,0.124100894,1.3088573,0.41907084,1.208893,0.9610857,0.9965388,0.1869039,0.1028203,0.0026125177,0.026380684,0.021720415,1.8363973,0.61317223,0.35121

The only catch with the `PredictionWriter` is that currently there are no gene names in the files. You have to remember which genes you asked for.

In [31]:
# they are in this order

scvi_module.model.reconstruction_var_names_g

['ENSG00000174684',
 'ENSG00000163527',
 'ENSG00000116406',
 'ENSG00000146411',
 'ENSG00000182050',
 'ENSG00000113532',
 'ENSG00000167130',
 'ENSG00000167165',
 'ENSG00000187210',
 'ENSG00000175229',
 'ENSG00000088035',
 'ENSG00000136720',
 'ENSG00000113552',
 'ENSG00000162688',
 'ENSG00000106392',
 'ENSG00000185674',
 'ENSG00000196376',
 'ENSG00000178234',
 'ENSG00000162040',
 'ENSG00000142657',
 'ENSG00000168961',
 'ENSG00000101346',
 'ENSG00000119523',
 'ENSG00000175040',
 'ENSG00000162139',
 'ENSG00000149541',
 'ENSG00000105492',
 'ENSG00000105220',
 'ENSG00000197496',
 'ENSG00000139044',
 'ENSG00000121964',
 'ENSG00000071073',
 'ENSG00000136213',
 'ENSG00000119227',
 'ENSG00000168917',
 'ENSG00000123989',
 'ENSG00000117411',
 'ENSG00000181027',
 'ENSG00000171489',
 'ENSG00000185090',
 'ENSG00000138459',
 'ENSG00000181830',
 'ENSG00000176928',
 'ENSG00000197713',
 'ENSG00000175164',
 'ENSG00000163931',
 'ENSG00000135838',
 'ENSG00000144214',
 'ENSG00000175548',
 'ENSG00000033170',
