# SCVI Annotated

The `SCVI` class is a variational autoencoder (VAE)-based model for single-cell RNA-seq data. It inherits from multiple mixins (e.g., `EmbeddingMixin`, `RNASeqMixin`, `UnsupervisedTrainingMixin`) that provide modular functionality for training, data handling, and inference. The constructor accepts an `AnnData` object and stores architecture parameters such as `n_hidden` and `n_latent`, as well as distribution settings such as `dispersion` and `gene_likelihood`. If an `AnnData` object is supplied, the underlying `VAE` module is initialized immediately using dataset-specific metadata (number of genes, batches, covariates); if `adata` is omitted, module creation is deferred until `train` is called. Data access is managed through an `AnnDataManager`, which is configured via the `setup_anndata` class method.

## Imports

In [None]:
from __future__ import annotations  # Enables postponed evaluation of type annotations (so classes defined later can be referenced in type hints)
import logging                     # Standard Python logging module
import warnings                    # Used for issuing warning messages
from typing import TYPE_CHECKING  # Used to avoid circular imports at runtime by conditionally importing only for type checking

# Importing constants and settings from scvi-tools
from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager  # Manages how AnnData is used inside scvi
from scvi.data._constants import ADATA_MINIFY_TYPE  # Constants related to minified AnnData types
from scvi.data._utils import _get_adata_minify_type  # Utility to get the minified data type
from scvi.data.fields import (                      # Definitions for how various fields are extracted from AnnData
    CategoricalJointObsField,
    CategoricalObsField,
    LayerField,
    NumericalJointObsField,
    NumericalObsField,
)

from scvi.model._utils import _init_library_size  # Initializes library size parameters (used for normalization)
from scvi.model.base import EmbeddingMixin, UnsupervisedTrainingMixin  # Mixins that add embedding and training behaviors
from scvi.module import VAE  # The core variational autoencoder module
from scvi.utils import setup_anndata_dsp  # Decorator to help document and handle setup_anndata logic

# Importing additional mixin classes and base model class
from .base import ArchesMixin, BaseMinifiedModeModelClass, RNASeqMixin, VAEMixin

# Import type hints only during static type checking
if TYPE_CHECKING:
    from typing import Literal
    from anndata import AnnData  # AnnData is the main data structure for single-cell data

# Create a logger for this module
logger = logging.getLogger(__name__)

## Create SCVI Class

In [None]:
class SCVI(
    # 1. CLASS INHERITANCE
    EmbeddingMixin,  # Adds methods for getting latent representations
    RNASeqMixin,  # Adds single-cell RNA-seq-specific logic
    VAEMixin,  # Adds methods for working with a VAE model
    ArchesMixin,  # Adds functionality for transfer learning (ARCHES)
    UnsupervisedTrainingMixin,  # Adds methods for unsupervised training
    BaseMinifiedModeModelClass,  # Adds support for working with memory-efficient minified AnnData
):
    # 2. CLASS DOCSTRING
    # NOTE: the docstring must sit at the same indentation as the rest of the
    # class body; a deeper indent makes the following statements an
    # IndentationError.
    """single-cell Variational Inference :cite:p:`Lopez18`.

    Parameters
    ----------
    adata
        AnnData object that has been registered via ``setup_anndata``. If it is
        not provided, the module is not created in the constructor; its
        initialization is delayed until ``train`` is called.
    n_hidden
        Hidden units per layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of layers in the encoder/decoder.
    dropout_rate
        Dropout rate.
    dispersion
        Type of dispersion parameter.
    gene_likelihood
        Distribution used to model gene expression.
    latent_distribution
        Latent distribution type.
    **kwargs
        Any other parameters; they are forwarded unchanged to the VAE module.

    Examples
    --------
    Reading data, setting it up, training, and getting results:

    >>> adata = anndata.read_h5ad("data.h5ad")
    >>> scvi.model.SCVI.setup_anndata(adata, batch_key="batch")
    >>> model = scvi.model.SCVI(adata)
    >>> model.train()
    >>> adata.obsm["X_scVI"] = model.get_latent_representation()
    """

    # 3. CLASS ATTRIBUTES
    _module_cls = VAE  # Points to the VAE class that this model will use internally
    _LATENT_QZM_KEY = "scvi_latent_qzm"  # Key for the latent mean in AnnData
    _LATENT_QZV_KEY = "scvi_latent_qzv"  # Key for the latent variance in AnnData

    # 4. CONSTRUCTOR
    def __init__(
        self,
        adata: AnnData | None = None,  # Input data; None delays module creation until `train`
        n_hidden: int = 128,  # Hidden units per layer
        n_latent: int = 10,  # Dimensionality of latent space
        n_layers: int = 1,  # Number of layers in encoder/decoder
        dropout_rate: float = 0.1,  # Dropout rate
        # `Literal[...]` below is a placeholder kept from the annotated
        # walkthrough; the accepted string values are not spelled out here.
        dispersion: Literal[...] = "gene",  # Type of dispersion parameter
        gene_likelihood: Literal[...] = "zinb",  # Distribution to model gene expression
        latent_distribution: Literal[...] = "normal",  # Latent distribution type
        **kwargs,  # Any other parameters passed to the VAE
    ):
        """Store the architecture settings and, when data is available, build the VAE module.

        See the class docstring for the meaning of each parameter.
        """
        super().__init__(adata)  # Call the constructor of the parent mixin/base classes

        # Keep every module argument in one dictionary so it is still available
        # when the module has to be created later (lazy initialization).
        self._module_kwargs = {
            "n_hidden": n_hidden,
            "n_latent": n_latent,
            "n_layers": n_layers,
            "dropout_rate": dropout_rate,
            "dispersion": dispersion,
            "gene_likelihood": gene_likelihood,
            "latent_distribution": latent_distribution,
            **kwargs,
        }

        # Create a readable summary string
        self._model_summary_string = (
            "SCVI model with the following parameters: \n"
            f"n_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, "
            f"dropout_rate: {dropout_rate}, dispersion: {dispersion}, "
            f"gene_likelihood: {gene_likelihood}, latent_distribution: {latent_distribution}."
        )

        # If lazy initialization is enabled (adata is not provided), postpone
        # module creation until training and tell the user about it.
        if self._module_init_on_train:
            self.module = None
            warnings.warn(
                "Model was initialized without `adata`. The module will be initialized when "
                "calling `train`. This behavior is experimental and may change in the future.",
                UserWarning,
                stacklevel=settings.warnings_stacklevel,
            )
        else:
            # Get categorical covariate info, if available (None when no
            # categorical covariates were registered).
            n_cats_per_cov = (
                self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key
                if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry
                else None
            )

            # Get number of batches
            n_batch = self.summary_stats.n_batch

            # Determine if size factor is provided in the data
            use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry

            # Library size parameters are only computed when no size factor was
            # registered and the data is not a latent-posterior minified AnnData;
            # otherwise they stay None.
            library_log_means, library_log_vars = None, None
            if (
                not use_size_factor_key
                and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
            ):
                library_log_means, library_log_vars = _init_library_size(
                    self.adata_manager, n_batch
                )

            # Instantiate the actual VAE model
            self.module = self._module_cls(
                n_input=self.summary_stats.n_vars,  # Number of genes
                n_batch=n_batch,
                n_labels=self.summary_stats.n_labels,
                n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0),
                n_cats_per_cov=n_cats_per_cov,
                n_hidden=n_hidden,
                n_latent=n_latent,
                n_layers=n_layers,
                dropout_rate=dropout_rate,
                dispersion=dispersion,
                gene_likelihood=gene_likelihood,
                latent_distribution=latent_distribution,
                use_size_factor_key=use_size_factor_key,
                library_log_means=library_log_means,
                library_log_vars=library_log_vars,
                **kwargs,
            )

            # Set minified type to the model (used for memory optimization)
            self.module.minified_data_type = self.minified_data_type

        # Save init parameters for reproducibility
        self.init_params_ = self._get_init_params(locals())

    # 5. Define setup_anndata for preprocessing AnnData

    @classmethod
    @setup_anndata_dsp.dedent  # Automatically formats docstring from template
    def setup_anndata(
        cls,
        adata: AnnData,
        layer: str | None = None,  # Layer holding the count data (passed to LayerField)
        batch_key: str | None = None,  # Batch annotation column in adata.obs
        labels_key: str | None = None,  # Label annotation column
        size_factor_key: str | None = None,  # Precomputed size factor
        categorical_covariate_keys: list[str] | None = None,  # Categorical covariates
        continuous_covariate_keys: list[str] | None = None,  # Continuous covariates
        **kwargs,
    ):
        """%(summary)s.

        Parameters
        ----------
        %(param_adata)s
        %(param_layer)s
        %(param_batch_key)s
        %(param_labels_key)s
        %(param_size_factor_key)s
        %(param_cat_cov_keys)s
        %(param_cont_cov_keys)s
        """

        # Get arguments as a dictionary. This must stay the first statement:
        # `locals()` should contain only the call arguments at this point.
        setup_method_args = cls._get_setup_method_args(**locals())

        # Define how to extract relevant fields from AnnData
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
            NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False),
            CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
            NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
        ]

        # If this is a "minified" AnnData, add extra required fields
        adata_minify_type = _get_adata_minify_type(adata)
        if adata_minify_type is not None:
            anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)

        # Create a manager to track and validate all fields
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)

        # Register fields into the manager
        adata_manager.register_fields(adata, **kwargs)

        # Register the manager for this class (global to model)
        cls.register_manager(adata_manager)


## Usage

In [None]:
# `anndata` and `scvi` are referenced by their module names below, so they must
# be imported as modules here: `from scvi import ...` does not bind `scvi`, and
# `anndata` is otherwise only imported for type checking.
import anndata
import scvi

# Load the dataset from disk.
adata = anndata.read_h5ad("data.h5ad")

# Register the AnnData fields (here only the batch column) before building the model.
scvi.model.SCVI.setup_anndata(adata, batch_key="batch")

# Build the model on the registered data and train it.
model = scvi.model.SCVI(adata)
model.train()

# Store the latent representation of every cell back on the AnnData object.
adata.obsm["X_scVI"] = model.get_latent_representation()