## Imports

In [7]:
import tempfile
from typing import Literal

import scvi
import torch
from scvi import REGISTRY_KEYS
from scvi.module.base import (
    BaseModuleClass,
    LossOutput,
    auto_move_data,
)
from torch.distributions import NegativeBinomial, Normal
from torch.distributions import kl_divergence as kl

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Dirichlet, NegativeBinomial
from scvi.model.base import UnsupervisedTrainingMixin, BaseModelClass
from scvi.data import AnnDataManager
from scvi.data.fields import (
    LayerField,
    CategoricalObsField,
    NumericalObsField,
    CategoricalJointObsField,
    NumericalJointObsField,
)
from anndata import AnnData


  doc = func(self, args[0].__doc__, *args[1:], **kwargs)
  doc = func(self, args[0].__doc__, *args[1:], **kwargs)


In [8]:
scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__)

Seed set to 0


Last run with scvi-tools version: 1.3.0


In [8]:
torch.set_float32_matmul_precision("high")
save_dir = "cell_cycle_test_model"

%config InlineBackend.print_figure_kwargs={"facecolor": "w"}
%config InlineBackend.figure_format="retina"

## Neural Network Class

This class defines a fully connected neural network with a given number of inputs and outputs,
a single hidden layer with 128 neurons and ReLU activation, and the chosen final activation function.

In [9]:
class MyNeuralNet(torch.nn.Module):
# Initialization of network
    def __init__(
        # 1. Defining the inputs
        self,
        n_input: int,
        n_output: int,
        link_var: Literal["exp", "none", "softmax"],
    ):
        """Encodes data of ``n_input`` dimensions into a space of ``n_output`` dimensions.

        Uses a one layer fully-connected neural network with 128 hidden nodes.

        Parameters
        ----------
        n_input
            The dimensionality of the input.
        n_output
            The dimensionality of the output.
        link_var
            The final non-linearity.
        """
        # 2. Defining the architecture
        super().__init__() # Initialize the parent class torch.nn.Module
        self.neural_net = torch.nn.Sequential(
            torch.nn.Linear(n_input, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, n_output),
        )
        # 3. Defining the final activation function
        self.transformation = None
        if link_var == "softmax":
            self.transformation = torch.nn.Softmax(dim=-1) # applies softmax along the last dimension
        elif link_var == "exp":
            self.transformation = torch.exp # exp doesn’t need dim because it applies to each element independently.

# Forward pass
    def forward(self, x: torch.Tensor):
        output = self.neural_net(x)
        if self.transformation:
            output = self.transformation(output)
        return output

Exploring this class

In [10]:
my_neural_net = MyNeuralNet(10, 4, "softmax") # 100 input features, 10 output features with softmax activation
my_neural_net

MyNeuralNet(
  (neural_net): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=4, bias=True)
  )
  (transformation): Softmax(dim=-1)
)

In [11]:
# observe that the output sums to 1 and are positive!
x = torch.randn((2, 10))
my_neural_net(x)

tensor([[0.1545, 0.1928, 0.2447, 0.4081],
        [0.2039, 0.2265, 0.3178, 0.2517]], grad_fn=<SoftmaxBackward0>)

## Crafting the Module in vanilla PyTorch

In [12]:
class MyModule(BaseModuleClass):
    """Skeleton Variational auto-encoder model.

    Here we implement a basic version of scVI's underlying VAE [Lopez18]_.
    This implementation is for instructional purposes only.

    Parameters
    ----------
    n_input
        Number of input genes.
    n_latent
        Dimensionality of the latent space.
    """

    def __init__(
        self,
        n_input: int,
        n_latent: int = 10,
    ):
        super().__init__()
        # in the init, we create the parameters of our elementary stochastic computation unit.

        # First, we setup the parameters of the generative model
        self.decoder = MyNeuralNet(n_latent, n_input, "softmax")
        self.log_theta = torch.nn.Parameter(torch.randn(n_input)) # initializes a trainable parameter with random values from a normal distribution

        # Second, we setup the parameters of the variational distribution
        self.mean_encoder = MyNeuralNet(n_input, n_latent, "none") # the mean of a distribution can take any real value
        self.var_encoder = MyNeuralNet(n_input, n_latent, "exp") # applying exp ensures that variance is always positive

    def _get_inference_input(self, tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Parse the dictionary to get appropriate args"""
        # This function takes a dictionary of tensors and extracts the required input tensor (raw counts) for the inference step, formatting it into a new dictionary.
        return {"x": tensors[REGISTRY_KEYS.X_KEY]}

    @auto_move_data
    def inference(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        High level inference method.

        Runs the inference (encoder) model.
        """
        # log the input to the variational distribution for numerical stability
        x_ = torch.log1p(x)
        # get variational parameters via the encoder networks
        qz_m = self.mean_encoder(x_)
        qz_v = self.var_encoder(x_)
        # get one sample to feed to the generative model
        # under the hood here is the Reparametrization trick (Rsample)
        z = Normal(qz_m, torch.sqrt(qz_v)).rsample()

        return {"qz_m": qz_m, "qz_v": qz_v, "z": z}

    def _get_generative_input(
        self, tensors: dict[str, torch.Tensor], inference_outputs: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        return {
            "z": inference_outputs["z"],
            "library": torch.sum(tensors[REGISTRY_KEYS.X_KEY], dim=1, keepdim=True),
        }

    @auto_move_data
    def generative(self, z: torch.Tensor, library: torch.Tensor) -> dict[str, torch.Tensor]:
        """Runs the generative model."""
        # get the "normalized" mean of the negative binomial
        px_scale = self.decoder(z)
        # get the mean of the negative binomial
        px_rate = library * px_scale
        # get the dispersion parameter
        theta = torch.exp(self.log_theta)

        return {
            "px_scale": px_scale,
            "theta": theta,
            "px_rate": px_rate,
        }

    def loss(
        self,
        tensors: dict[str, torch.Tensor],
        inference_outputs: dict[str, torch.Tensor],
        generative_outputs: dict[str, torch.Tensor],
    ) -> LossOutput:
        # here, we would like to form the ELBO. There are two terms:
        #   1. one that pertains to the likelihood of the data
        #   2. one that pertains to the variational distribution
        # so we extract all the required information
        x = tensors[REGISTRY_KEYS.X_KEY]
        px_rate = generative_outputs["px_rate"]
        theta = generative_outputs["theta"]
        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]

        # term 1
        # the pytorch NB distribution uses a different parameterization
        # so we must apply a quick transformation (included in scvi-tools, but here we use the
        # pytorch code)
        nb_logits = (px_rate + 1e-4).log() - (theta + 1e-4).log()
        log_lik = NegativeBinomial(total_count=theta, logits=nb_logits).log_prob(x).sum(dim=-1)

        # term 2
        prior_dist = Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v))
        var_post_dist = Normal(qz_m, torch.sqrt(qz_v))
        kl_divergence = kl(var_post_dist, prior_dist).sum(dim=1)

        elbo = log_lik - kl_divergence
        loss = torch.mean(-elbo)
        return LossOutput(
            loss=loss,
            reconstruction_loss=-log_lik,
            kl_local=kl_divergence,
            kl_global=0.0,
        )

In [13]:
# try creating a module and see the description:
MyModule(100, 10)

MyModule(
  (decoder): MyNeuralNet(
    (neural_net): Sequential(
      (0): Linear(in_features=10, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=100, bias=True)
    )
    (transformation): Softmax(dim=-1)
  )
  (mean_encoder): MyNeuralNet(
    (neural_net): Sequential(
      (0): Linear(in_features=100, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=10, bias=True)
    )
  )
  (var_encoder): MyNeuralNet(
    (neural_net): Sequential(
      (0): Linear(in_features=100, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=10, bias=True)
    )
  )
)

## The model

In [14]:
class SCVI(UnsupervisedTrainingMixin, BaseModelClass):
    """single-cell Variational Inference [Lopez18]_."""

    def __init__(
        self,
        adata: AnnData,
        n_latent: int = 10,
        **model_kwargs,
    ):
        super().__init__(adata)

        self.module = VAE(
            n_input=self.summary_stats["n_vars"],
            n_batch=self.summary_stats["n_batch"],
            n_latent=n_latent,
            **model_kwargs,
        )
        self._model_summary_string = (
            f"SCVI Model with the following params: \nn_latent: {n_latent}"
        )
        self.init_params_ = self._get_init_params(locals())

    @classmethod
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: str | None = None,
        layer: str | None = None,
        **kwargs,
    ) -> AnnData | None:
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            # Dummy fields required for VAE class.
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None),
            NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, None, required=False),
            CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, None),
            NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, None),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

## Model with Dirichlet & NB Decoders

In [21]:

class CellCycleVAE(UnsupervisedTrainingMixin, BaseModelClass):
    """Single-cell Variational Inference model with Dirichlet and Negative Binomial decoders."""

    def __init__(self, adata: AnnData, n_latent: int = 10, **model_kwargs):
        """
        Parameters
        ----------
        adata : AnnData
            Annotated data matrix.
        n_latent : int
            Dimensionality of the latent space (default: 10).
        **model_kwargs : dict
            Additional arguments for the model.
        """
        super().__init__(adata)

        # Define the main generative module
        self.module = Cycle_VAE(
            n_input=self.summary_stats["n_vars"],
            n_batch=self.summary_stats["n_batch"],
            n_latent=n_latent,
            n_phases=4,  # 4 cell cycle phases: G0, G1, S, G2M
            **model_kwargs,
        )

        # Store model parameters
        self._model_summary_string = (
            f"CellCycleVAE Model with the following params: \nn_latent: {n_latent}"
        )
        self.init_params_ = self._get_init_params(locals())

    @classmethod
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: str | None = None,
        layer: str | None = None,
        **kwargs,
    ) -> AnnData | None:
        """
        Prepares an AnnData object for use with CellCycleVAE.
        """
        setup_method_args = cls._get_setup_method_args(**locals())

        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            # Dummy fields required for the VAE class
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None),
            NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, None, required=False),
            CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, None),
            NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, None),
        ]

        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)


In [26]:
class Cycle_VAE(nn.Module):
    """VAE with a Dirichlet decoder for cell cycle phase and a Negative Binomial decoder for gene expression."""

    def __init__(self, n_input, n_latent=10, n_phases=4, n_batch=0):
        """
        Parameters
        ----------
        n_input : int
            Number of genes (input features).
        n_latent : int
            Dimensionality of the latent space (default: 10).
        n_phases : int
            Number of cell cycle phases (default: 4 for G0, G1, S, G2M).
        n_batch : int
            Number of batch effect categories (default: 0, no batch correction).
        """
        super().__init__()

        # Encoder: Maps gene expression to latent space (z)
        self.encoder = nn.Sequential(
            nn.Linear(n_input, 128),
            nn.ReLU(),
            nn.Linear(128, n_latent * 2),  # Outputs mean & log variance for z
        )

        # Dirichlet Decoder: Maps latent space z to cell cycle phase distribution π
        self.dirichlet_decoder = nn.Sequential(
            nn.Linear(n_latent, 32),
            nn.ReLU(),
            nn.Linear(32, n_phases),
            nn.Softplus(),  # Ensures positive values for Dirichlet parameters
        )

        # Negative Binomial Decoder: Maps π to gene expression parameters (μ)
        self.nb_decoder = nn.Sequential(
            nn.Linear(n_phases, 128),
            nn.ReLU(),
            nn.Linear(128, n_input),  # Mean gene expression for NB distribution
        )

        # Dispersion parameter for Negative Binomial (theta)
        self.theta = nn.Parameter(torch.randn(n_input))

    def encode(self, x):
        """Encodes input gene expression x into latent space z."""
        q = self.encoder(x)
        mu, log_var = torch.chunk(q, 2, dim=-1)  # Split into mean & log variance
        std = torch.exp(0.5 * log_var)
        return mu, std

    def reparameterize(self, mu, std):
        """Reparametrization trick for sampling from q(z|x)."""
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Forward pass through the VAE model."""
        
        # 1. Encode x → latent space z
        mu, std = self.encode(x)
        z = self.reparameterize(mu, std)

        # 2. Decode z → cell cycle phase distribution π (Dirichlet)
        alpha = self.dirichlet_decoder(z) + 1e-6  # Add small constant for stability
        pi = Dirichlet(alpha).rsample()  # Sample π from Dirichlet(α)

        # 3. Decode π → gene expression parameters μ (NB decoder)
        mu_nb = self.nb_decoder(pi)
        mu_nb = torch.exp(mu_nb)  # Ensure positive gene expression means

        # 4. Compute Negative Binomial output
        theta = torch.exp(self.theta)  # Ensure dispersion is positive
        nb_dist = NegativeBinomial(total_count=theta, probs=mu_nb / (mu_nb + theta))

        return nb_dist, mu, std, pi

In [27]:
adata = scvi.data.synthetic_iid()
adata

AnnData object with n_obs × n_vars = 400 × 100
    obs: 'batch', 'labels'
    uns: 'protein_names'
    obsm: 'protein_expression', 'accessibility'

In [28]:
CellCycleVAE.setup_anndata(adata, batch_key="batch")
print(f"adata UUID (assigned by setup_anndata): {adata.uns['_scvi_uuid']}")
print(f"AnnDataManager: {CellCycleVAE._setup_adata_manager_store[adata.uns['_scvi_uuid']]}")
model = CellCycleVAE(adata)
model

adata UUID (assigned by setup_anndata): 9293e1c4-bcc7-4bf6-b27b-9bc6d9152318
AnnDataManager: <scvi.data._manager.AnnDataManager object at 0x2869ef380>


TypeError: Cycle_VAE.__init__() got an unexpected keyword argument 'n_batch'

In [29]:
model.train(max_epochs=20)

AttributeError: 'Cycle_VAE' object has no attribute 'loss'

## Building the Model

The model

In [None]:
class CCVI(UnsupervisedTrainingMixin, BaseModelClass):
    def __init__(
        self,
        adata: AnnData,
        n_latent: int = 10,
        n_phases=4,
        **model_kwargs,
    ):
        
        super().__init__(adata) # Initialize the parent class with the AnnData object


        self.module = CC_VAE(
                n_input=self.summary_stats["n_vars"],  # Number of input features (genes)
                n_batch=self.summary_stats["n_batch"],  # Number of unique batches in data
                n_latent=n_latent,  # Latent space dimensionality
                **model_kwargs,  # Pass additional keyword arguments to the VAE model
            )

        self._model_summary_string = (
            f"CCVI Model with the following params: \nn_latent: {n_latent}" # summary string
        )

        self.init_params_ = self._get_init_params(locals()) # Store the initialization parameters for reproducibility

    @classmethod
        def setup_anndata(
            cls,
            adata: AnnData,  # AnnData object to be processed and configured
            batch_key: str | None = None,  # Column name in `adata.obs` specifying batch information
            layer: str | None = None,  # Layer of `adata` containing raw count data
            **kwargs,  # Additional arguments for setup configuration
        ) -> AnnData | None:
            # Retrieve method arguments for setting up the AnnData object
            setup_method_args = cls._get_setup_method_args(**locals())

        # Define the fields required for proper data handling in the CCVI model
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),  # Expression data layer
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),  # Batch labels (if provided)
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None),  # Placeholder for labels (unused)
            NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, None, required=False),  # Placeholder for size factors (optional)
            CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, None),  # Placeholder for categorical covariates
            NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, None),  # Placeholder for numerical covariates
        ]

        # Create an AnnDataManager instance to handle preprocessing and field registration
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)

        # Register fields with the AnnData object, applying necessary preprocessing
        adata_manager.register_fields(adata, **kwargs)

        # Register the manager for future use
        cls.register_manager(adata_manager)


In [None]:
class CC_VAE(nn.Module):
    """VAE with a Dirichlet decoder for cell cycle phase and a Negative Binomial decoder for gene expression."""

    def __init__(self, n_input, n_latent=10, n_phases=4, n_batch=0):
        """
        Parameters
        ----------
        n_input : int
            Number of genes (input features).
        n_latent : int
            Dimensionality of the latent space (default: 10).
        n_phases : int
            Number of cell cycle phases (default: 4 for G0, G1, S, G2M).
        n_batch : int
            Number of batch effect categories (default: 0, no batch correction).
        """
        super().__init__()

        # Encoder: Maps gene expression to latent space (z)
        self.encoder = nn.Sequential(
            nn.Linear(n_input, 128),
            nn.ReLU(),
            nn.Linear(128, n_latent * 2),  # Outputs mean & log variance for z
        )

        # Dirichlet Decoder: Maps latent space z to cell cycle phase distribution π
        self.dirichlet_decoder = nn.Sequential(
            nn.Linear(n_latent, 32),
            nn.ReLU(),
            nn.Linear(32, n_phases),
            nn.Softplus(),  # Ensures positive values for Dirichlet parameters
        )

        # Negative Binomial Decoder: Maps π to gene expression parameters (μ)
        self.nb_decoder = nn.Sequential(
            nn.Linear(n_phases, 128),
            nn.ReLU(),
            nn.Linear(128, n_input),  # Mean gene expression for NB distribution
        )

        # Dispersion parameter for Negative Binomial (theta)
        self.theta = nn.Parameter(torch.randn(n_input))

    def encode(self, x):
        """Encodes input gene expression x into latent space z."""
        q = self.encoder(x)
        mu, log_var = torch.chunk(q, 2, dim=-1)  # Split into mean & log variance
        std = torch.exp(0.5 * log_var)
        return mu, std

    def reparameterize(self, mu, std):
        """Reparametrization trick for sampling from q(z|x)."""
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Forward pass through the VAE model."""
        
        # 1. Encode x → latent space z
        mu, std = self.encode(x)
        z = self.reparameterize(mu, std)

        # 2. Decode z → cell cycle phase distribution π (Dirichlet)
        alpha = self.dirichlet_decoder(z) + 1e-6  # Add small constant for stability
        pi = Dirichlet(alpha).rsample()  # Sample π from Dirichlet(α)

        # 3. Decode π → gene expression parameters μ (NB decoder)
        mu_nb = self.nb_decoder(pi)
        mu_nb = torch.exp(mu_nb)  # Ensure positive gene expression means

        # 4. Compute Negative Binomial output
        theta = torch.exp(self.theta)  # Ensure dispersion is positive
        nb_dist = NegativeBinomial(total_count=theta, probs=mu_nb / (mu_nb + theta))

        return nb_dist, mu, std, pi

1. create a class MyNeuralNet(torch.nn.Module)
   - given a number of inputs, layers, activations,...
   - functions: init and forward
2. create a class CC_VAE (BaseModuleClass)
   - init: specify the decoders and encoders
   - get_generative_input(): selecting the registered tensors from the AnnData, as well as the latent variables (from inference) used in the model
- generative(): run the decoders
- _get_inference_input(): selecting the registered tensors from the AnnData used in the inference
- inference(): run the encoder
- loss(): the log-likelihood or its lower bound
3. create a class CC_model (BaseModuleClass,UnsupervisedTrainingMixin)
   

In [None]:
from __future__ import annotations  # Enables postponed evaluation of type annotations (so classes defined later can be referenced in type hints)
import logging                     # Standard Python logging module
import warnings                    # Used for issuing warning messages
from typing import TYPE_CHECKING  # Used to avoid circular imports at runtime by conditionally importing only for type checking

# Importing constants and settings from scvi-tools
from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager  # Manages how AnnData is used inside scvi
from scvi.data._constants import ADATA_MINIFY_TYPE  # Constants related to minified AnnData types
from scvi.data._utils import _get_adata_minify_type  # Utility to get the minified data type
from scvi.data.fields import (                      # Definitions for how various fields are extracted from AnnData
    CategoricalJointObsField,
    CategoricalObsField,
    LayerField,
    NumericalJointObsField,
    NumericalObsField,
)

from scvi.model._utils import _init_library_size  # Initializes library size parameters (used for normalization)
from scvi.model.base import EmbeddingMixin, UnsupervisedTrainingMixin  # Mixins that add embedding and training behaviors
from scvi.module import VAE  # The core variational autoencoder module
from scvi.utils import setup_anndata_dsp  # Decorator to help document and handle setup_anndata logic

# Importing additional mixin classes and base model class
from .base import ArchesMixin, BaseMinifiedModeModelClass, RNASeqMixin, VAEMixin

# Import type hints only during static type checking
if TYPE_CHECKING:
    from typing import Literal
    from anndata import AnnData  # AnnData is the main data structure for single-cell data

# Create a logger for this module
logger = logging.getLogger(__name__)

## CC_model

In [23]:
class CC_model(
# 1. CLASS INHERITANCE
    EmbeddingMixin,               # Adds methods for getting latent representations
    RNASeqMixin,                  # Adds single-cell RNA-seq-specific logic
    VAEMixin,                     # Adds methods for working with a VAE model
    ArchesMixin,                  # Adds functionality for transfer learning (ARCHES)
    UnsupervisedTrainingMixin,   # Adds methods for unsupervised training
    BaseMinifiedModeModelClass,  # Adds support for working with memory-efficient minified AnnData
):
    
# 2. CLASS DOCSTRING
        """single-cell Variational Inference :cite:p:`Lopez18`.

    Parameters
    ----------
    adata
        AnnData object that has been registered via setup_anndata
    n_hidden, n_latent, etc.
        VAE architecture and model parameters
    Examples
    --------
    Shows usage with reading data, setting up, training, and getting results
    """

# 3. CLASS ATTRIBUTES
    _module_cls = CC_VAE  # Points to the VAE class that this model will use internally 
    _LATENT_QZM_KEY = "scvi_latent_qzm"  # Key for the latent mean in AnnData
    _LATENT_QZV_KEY = "scvi_latent_qzv"  # Key for the latent variance in AnnData
# 4. CONSTRUCTOR
    def __init__(
        self,
        adata: AnnData | None = None,  # Input data; can be None (if adata is not provided, the model will delay initialization until train is called).
        n_hidden: int = 128,           # Hidden units per layer
        n_latent: int = 10,            # Dimensionality of latent space
        n_layers: int = 1,             # Number of layers in encoder/decoder
        dropout_rate: float = 0.1,     # Dropout rate
        dispersion: Literal[...] = "gene",         # Type of dispersion parameter
        gene_likelihood: Literal[...] = "zinb",    # Distribution to model gene expression
        latent_distribution: Literal[...] = "normal",  # Latent distribution type
        **kwargs,                      # Any other parameters passed to the VAE
    ):
        super().__init__(adata)  # Call the constructor of the parent mixin/base classes

        # Store parameters in a dictionary
        self._module_kwargs = {
            "n_hidden": n_hidden,
            "n_latent": n_latent,
            "n_layers": n_layers,
            "dropout_rate": dropout_rate,
            "dispersion": dispersion,
            "gene_likelihood": gene_likelihood,
            "latent_distribution": latent_distribution,
            **kwargs,
        }

        # Create a readable summary string
        self._model_summary_string = (
            "CC model with the following parameters: \n"
            f"n_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, "
            f"dropout_rate: {dropout_rate}, dispersion: {dispersion}, "
            f"gene_likelihood: {gene_likelihood}, latent_distribution: {latent_distribution}."
        )

        # If lazy initialization is enabled (adata is not provided), postpone model creation until training
        if self._module_init_on_train:
            self.module = None
            warnings.warn(
                "Model was initialized without `adata`. The module will be initialized when "
                "calling `train`. This behavior is experimental and may change in the future.",
                UserWarning,
                stacklevel=settings.warnings_stacklevel,
            )
        else:
            # Get categorical covariate info, if available
            n_cats_per_cov = (
                self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key
                if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry
                else None
            )

            # Get number of batches
            n_batch = self.summary_stats.n_batch

            # Determine if size factor is provided in the data
            use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry

            # Initialize library size params if needed
            library_log_means, library_log_vars = None, None
            if (
                not use_size_factor_key
                and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
            ):
                library_log_means, library_log_vars = _init_library_size(
                    self.adata_manager, n_batch
                )

            # Instantiate the actual VAE model
            self.module = self._module_cls(
                n_input=self.summary_stats.n_vars,  # Number of genes
                n_batch=n_batch,
                n_labels=self.summary_stats.n_labels,
                n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0),
                n_cats_per_cov=n_cats_per_cov,
                n_hidden=n_hidden,
                n_latent=n_latent,
                n_layers=n_layers,
                dropout_rate=dropout_rate,
                dispersion=dispersion,
                gene_likelihood=gene_likelihood,
                latent_distribution=latent_distribution,
                use_size_factor_key=use_size_factor_key,
                library_log_means=library_log_means,
                library_log_vars=library_log_vars,
                **kwargs,
            )

            # Set minified type to the model (used for memory optimization)
            self.module.minified_data_type = self.minified_data_type

        # Save init parameters for reproducibility
        self.init_params_ = self._get_init_params(locals())

# 5. Define setup_anndata for preproccessing AnnData

    @classmethod
    @setup_anndata_dsp.dedent  # Automatically formats docstring from template
    def setup_anndata(
        cls,
        adata: AnnData,
        layer: str | None = None,  # Which layer of AnnData.X to use
        batch_key: str | None = None,  # Batch annotation column in adata.obs
        labels_key: str | None = None,  # Label annotation column
        size_factor_key: str | None = None,  # Precomputed size factor
        categorical_covariate_keys: list[str] | None = None,  # Categorical covariates
        continuous_covariate_keys: list[str] | None = None,   # Continuous covariates
        **kwargs,
    ):
        """%(summary)s.

        Parameters
        ----------
        %(param_adata)s
        %(param_layer)s
        %(param_batch_key)s
        %(param_labels_key)s
        %(param_size_factor_key)s
        %(param_cat_cov_keys)s
        %(param_cont_cov_keys)s
        """

        # Get arguments as a dictionary
        setup_method_args = cls._get_setup_method_args(**locals())

        # Define how to extract relevant fields from AnnData
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
            NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False),
            CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
            NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
        ]

        # If this is a "minified" AnnData, add extra required fields
        adata_minify_type = _get_adata_minify_type(adata)
        if adata_minify_type is not None:
            anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)

        # Create a manager to track and validate all fields
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)

        # Register fields into the manager
        adata_manager.register_fields(adata, **kwargs)

        # Register the manager for this class (global to model)
        cls.register_manager(adata_manager)

IndentationError: unindent does not match any outer indentation level (<string>, line 26)

## CC_VAE (inspired by VAE)

z ~ q(z | x)
α = f(z)
c ~ Dirichlet(α)
e ~ NB(f(c))

In [17]:
from scvi.module.base import BaseMinifiedModeModuleClass, EmbeddingModuleMixin, LossOutput, auto_move_data
from scvi import REGISTRY_KEYS
from scvi.module._constants import MODULE_KEYS

import torch
import torch.nn as nn
from torch.distributions import Normal, Dirichlet
from scvi.distributions import NegativeBinomial


class DirichletVAE(EmbeddingModuleMixin, BaseMinifiedModeModuleClass):
    """
    Custom hierarchical VAE:
    z ~ N(0, I)
    alpha = f(z)
    c ~ Dirichlet(alpha)
    expression ~ NB(f(c), scale)
    """

    def __init__(
        self,
        n_input,
        n_latent=10,
        n_categories=4,
        n_genes=None,
        use_observed_lib_size=True,
    ):
        super().__init__()
        self.use_observed_lib_size = use_observed_lib_size

        # z encoder
        self.encoder = nn.Sequential(
            nn.Linear(n_input, 128), nn.ReLU(), nn.Linear(128, 128), nn.ReLU()
        )
        self.z_mean = nn.Linear(128, n_latent)
        self.z_var = nn.Linear(128, n_latent)

        # z -> alpha for Dirichlet
        self.alpha_decoder = nn.Sequential(
            nn.Linear(n_latent, 128), nn.ReLU(), nn.Linear(128, n_categories), nn.Softplus()
        )

        # c -> expression mean
        self.decoder = nn.Sequential(
            nn.Linear(n_categories, 128), nn.ReLU(), nn.Linear(128, n_genes), nn.Softplus()
        )

        self.px_r = nn.Parameter(torch.randn(n_genes))

    @auto_move_data
    def forward(self, tensors, compute_loss=True):
        x = tensors[REGISTRY_KEYS.X_KEY]
        library = tensors[REGISTRY_KEYS.SIZE_FACTOR_KEY].unsqueeze(-1) if self.use_observed_lib_size else 1.0

        h = self.encoder(x)
        mu = self.z_mean(h)
        var = torch.exp(self.z_var(h)) + 1e-4
        qz = Normal(mu, var.sqrt())
        z = qz.rsample()

        alpha = self.alpha_decoder(z) + 1e-4  # ensure positive alpha
        q_dir = Dirichlet(alpha)
        c = q_dir.rsample()

        px_rate = self.decoder(c) * library
        px = NegativeBinomial(mu=px_rate, theta=torch.exp(self.px_r))

        if not compute_loss:
            return {"z": z, "c": c, "dist": px}

        reconst_loss = -px.log_prob(x).sum(-1)
        pz = Normal(torch.zeros_like(mu), torch.ones_like(var))
        kl_z = torch.distributions.kl_divergence(qz, pz).sum(-1)

        loss = torch.mean(reconst_loss + kl_z)

        return LossOutput(
            loss=loss,
            reconstruction_loss=reconst_loss,
            kl_local={"kl_z": kl_z},
            extra_metrics={"z": z, "c": c},
        )

## MyNeuralNet()

In [1]:
import torch
import torch.nn as nn
from torch.distributions import Normal, Dirichlet
from scvi.distributions import NegativeBinomial

  doc = func(self, args[0].__doc__, *args[1:], **kwargs)
  doc = func(self, args[0].__doc__, *args[1:], **kwargs)


In [2]:
# Encoder
class Encoder(nn.Module):
    """Encode data of ``n_input`` dimensions into a latent space of ``n_output`` dimensions.

    Uses a fully-connected neural network of ``n_hidden`` layers.

    Parameters
    ----------
    n_input
        The dimensionality of the input (data space)
    n_output
        The dimensionality of the output (latent space)
    n_cat_list
        A list containing the number of categories
        for each category of interest. Each category will be
        included using a one-hot encoding
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    dropout_rate
        Dropout rate to apply to each of the hidden layers
    distribution
        Distribution of z
    var_eps
        Minimum value for the variance;
        used for numerical stability
    var_activation
        Callable used to ensure positivity of the variance.
        Defaults to :meth:`torch.exp`.
    return_dist
        Return directly the distribution of z instead of its parameters.
    **kwargs
        Keyword args for :class:`~scvi.nn.FCLayers`
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
        dropout_rate: float = 0.1,
        distribution: str = "normal",
        var_eps: float = 1e-4,
        var_activation: Callable | None = None,
        return_dist: bool = False,
        **kwargs,
    ):
        super().__init__()

        self.distribution = distribution
        self.var_eps = var_eps
        self.encoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            **kwargs,
        )
        self.mean_encoder = nn.Linear(n_hidden, n_output)
        self.var_encoder = nn.Linear(n_hidden, n_output)
        self.return_dist = return_dist

        if distribution == "ln":
            self.z_transformation = nn.Softmax(dim=-1)
        else:
            self.z_transformation = _identity
        self.var_activation = torch.exp if var_activation is None else var_activation

    def forward(self, x: torch.Tensor, *cat_list: int):
        r"""The forward computation for a single sample.

         #. Encodes the data into latent space using the encoder network
         #. Generates a mean \\( q_m \\) and variance \\( q_v \\)
         #. Samples a new value from an i.i.d. multivariate normal
            \\( \\sim Ne(q_m, \\mathbf{I}q_v) \\)

        Parameters
        ----------
        x
            tensor with shape (n_input,)
        cat_list
            list of category membership(s) for this sample

        Returns
        -------
        3-tuple of :py:class:`torch.Tensor`
            tensors of shape ``(n_latent,)`` for mean and var, and sample

        """
        # Parameters for latent distribution
        q = self.encoder(x, *cat_list)
        q_m = self.mean_encoder(q)
        q_v = self.var_activation(self.var_encoder(q)) + self.var_eps
        dist = Normal(q_m, q_v.sqrt())
        latent = self.z_transformation(dist.rsample())
        if self.return_dist:
            return dist, latent
        return q_m, q_v, latent

In [3]:
class DirichletDecoder(nn.Module):
    def __init__(self, n_latent, n_categories, n_hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_latent, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_categories),
            nn.Softplus()  # ensure alpha > 0
        )

    def forward(self, z):
        alpha = self.net(z) + 1e-4
        return Dirichlet(alpha)

In [4]:
class ExpressionDecoder(nn.Module):
    def __init__(self, n_categories, n_genes, n_hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_categories, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_genes),
            nn.Softplus()  # mean must be positive
        )
        self.px_r = nn.Parameter(torch.randn(n_genes))  # gene-specific dispersion

    def forward(self, c, library_size):
        mean = self.net(c)
        scaled_mean = mean * library_size.unsqueeze(-1)
        return NegativeBinomial(mu=scaled_mean, theta=torch.exp(self.px_r))

In [10]:
class MyHierarchicalVAE(nn.Module):
    def __init__(self, n_input, n_latent, n_categories, n_genes):
        super().__init__()
        self.encoder_z = EncoderZ(n_input, n_latent)
        self.decoder_dirichlet = DirichletDecoder(n_latent, n_categories)
        self.decoder_expression = ExpressionDecoder(n_categories, n_genes)

    def forward(self, x, library_size):
        qz = self.encoder_z(x)
        z = qz.rsample()
        q_dir = self.decoder_dirichlet(z)
        c = q_dir.rsample()
        px = self.decoder_expression(c, library_size)
        return qz, q_dir, px, z, c

    def loss(self, x, library_size):
        qz, q_dir, px, z, c = self.forward(x, library_size)
        # Negative log-likelihood
        recon_loss = -px.log_prob(x).sum(-1)
        # KL[q(z|x) || p(z)]
        pz = Normal(torch.zeros_like(z), torch.ones_like(z))
        kl_z = torch.distributions.kl_divergence(qz, pz).sum(-1)
        # No KL for c because it's from p(c|z)
        total_loss = (recon_loss + kl_z).mean()
        return total_loss

In [11]:
from torch.utils.data import DataLoader, TensorDataset

# Assume `x` is (n_cells, n_genes)
x = torch.randn(1000, 1000)
lib_size = x.sum(dim=1)

dataset = TensorDataset(x, lib_size)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

model = MyHierarchicalVAE(n_input=1000, n_latent=10, n_categories=4, n_genes=1000)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [13]:
n_epochs = 400

model.train()

MyHierarchicalVAE(
  (encoder_z): EncoderZ(
    (encoder): Sequential(
      (0): Linear(in_features=1000, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
    )
    (mean_encoder): Linear(in_features=128, out_features=10, bias=True)
    (var_encoder): Linear(in_features=128, out_features=10, bias=True)
  )
  (decoder_dirichlet): DirichletDecoder(
    (net): Sequential(
      (0): Linear(in_features=10, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=4, bias=True)
      (3): Softplus(beta=1.0, threshold=20.0)
    )
  )
  (decoder_expression): ExpressionDecoder(
    (net): Sequential(
      (0): Linear(in_features=4, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=1000, bias=True)
      (3): Softplus(beta=1.0, threshold=20.0)
    )
  )
)

In [1]:
import torch
import torch.nn as nn
from torch.distributions import Dirichlet
from typing import Iterable, Literal


class DecoderCCVI(nn.Module):
    def __init__(
        self,
        n_input: int,           # latent dimension
        n_output: int,          # number of genes
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
        inject_covariates: bool = True,
        use_batch_norm: bool = False,
        use_layer_norm: bool = False,
        scale_activation: Literal["softmax", "softplus"] = "softmax",
        dirichlet_dim: int = 4,
        **kwargs,
    ):
        super().__init__()

        # z → Dirichlet parameters (α1, ..., α4)
        self.dirichlet_param_net = FCLayers(
            n_in=n_input,
            n_out=dirichlet_dim,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=0,
            inject_covariates=inject_covariates,
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm,
            **kwargs,
        )

        self.dirichlet_activation = nn.Softplus()  # ensure positive α values

        # c → px_scale (expression proportions)
        if scale_activation == "softmax":
            px_scale_activation = nn.Softmax(dim=-1)
        elif scale_activation == "softplus":
            px_scale_activation = nn.Softplus()
        else:
            raise ValueError("Unknown activation")

        self.cell_cycle_decoder = nn.Sequential(
            nn.Linear(dirichlet_dim, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_output),
            px_scale_activation,
        )

        self.px_r_decoder = nn.Linear(n_output, n_output)  # Optional: gene-cell dispersion
        #self.px_dropout_decoder = nn.Linear(n_output, n_output)

    def forward(
        self,
        dispersion: str,
        z: torch.Tensor,
        library: torch.Tensor,
        *cat_list: int,
    ):
        # 1. Generate Dirichlet parameters from z
        a = self.dirichlet_param_net(z, *cat_list)
        a = self.dirichlet_activation(a) + 1e-4  # ensure positivity

        # 2. Sample cell cycle phase from Dirichlet
        dirichlet_dist = Dirichlet(a)
        c = dirichlet_dist.rsample()  # reparameterized sample for backprop

        # 3. Generate px_scale from c
        px_scale = self.cell_cycle_decoder(c)

        # 4. Dropout and rate
        px_dropout = None # self.px_dropout_decoder(px_scale)
        px_rate = torch.exp(library) * px_scale
        px_r = self.px_r_decoder(px_scale) if dispersion == "gene-cell" else None

        return px_scale, px_r, px_rate, px_dropout


In [None]:
class CC_VAE(
    EmbeddingModuleMixin, BaseMinifiedModeModuleClass
):
    def __init__(
        self,
        n_input,
        n_batch=0,
        n_labels=0,
        n_hidden=128,
        n_latent=10,
        n_layers=1,
        n_continuous_cov=0,
        n_cats_per_cov=None,
        dropout_rate=0.1,
        dispersion="gene",
        log_variational=True,
        gene_likelihood="nb",
        latent_distribution="normal",
        encode_covariates=False,
        deeply_inject_covariates=True,
        batch_representation="one-hot",
        use_batch_norm="both",
        use_layer_norm="none",
        use_size_factor_key=False,
        use_observed_lib_size=True,
        library_log_means=None,
        library_log_vars=None,
        var_activation=None,
        extra_encoder_kwargs=None,
        extra_decoder_kwargs=None,
        batch_embedding_kwargs=None,
    ):
        from scvi.nn import DecoderSCVI, Encoder

        super().__init__()

        self.dispersion = dispersion
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.gene_likelihood = gene_likelihood
        self.n_batch = n_batch
        self.n_labels = n_labels
        self.latent_distribution = latent_distribution
        self.encode_covariates = encode_covariates
        self.use_size_factor_key = use_size_factor_key
        self.use_observed_lib_size = use_size_factor_key or use_observed_lib_size

        if not self.use_observed_lib_size:
            if library_log_means is None or library_log_vars is None:
                raise ValueError("Must provide library_log_means and library_log_vars if not using observed_lib_size.")
            self.register_buffer("library_log_means", torch.from_numpy(library_log_means).float())
            self.register_buffer("library_log_vars", torch.from_numpy(library_log_vars).float())

        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        elif self.dispersion != "gene-cell":
            raise ValueError("`dispersion` must be one of 'gene', 'gene-batch', 'gene-label', 'gene-cell'.")

        self.batch_representation = batch_representation
        if self.batch_representation == "embedding":
            self.init_embedding(REGISTRY_KEYS.BATCH_KEY, n_batch, **(batch_embedding_kwargs or {}))
            batch_dim = self.get_embedding(REGISTRY_KEYS.BATCH_KEY).embedding_dim
        elif self.batch_representation != "one-hot":
            raise ValueError("`batch_representation` must be one of 'one-hot', 'embedding'.")

        use_batch_norm_encoder = use_batch_norm in ["encoder", "both"]
        use_batch_norm_decoder = use_batch_norm in ["decoder", "both"]
        use_layer_norm_encoder = use_layer_norm in ["encoder", "both"]
        use_layer_norm_decoder = use_layer_norm in ["decoder", "both"]

        n_input_encoder = n_input + n_continuous_cov * encode_covariates
        if self.batch_representation == "embedding":
            n_input_encoder += batch_dim * encode_covariates
            cat_list = list([] if n_cats_per_cov is None else n_cats_per_cov)
        else:
            cat_list = [n_batch] + list([] if n_cats_per_cov is None else n_cats_per_cov)

        encoder_cat_list = cat_list if encode_covariates else None
        _extra_encoder_kwargs = extra_encoder_kwargs or {}

        self.z_encoder = Encoder(
            n_input_encoder,
            n_latent,
            n_cat_list=encoder_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            distribution=latent_distribution,
            inject_covariates=deeply_inject_covariates,
            use_batch_norm=use_batch_norm_encoder,
            use_layer_norm=use_layer_norm_encoder,
            var_activation=var_activation,
            return_dist=True,
            **_extra_encoder_kwargs,
        )

        self.l_encoder = Encoder(
            n_input_encoder,
            1,
            n_layers=1,
            n_cat_list=encoder_cat_list,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            inject_covariates=deeply_inject_covariates,
            use_batch_norm=use_batch_norm_encoder,
            use_layer_norm=use_layer_norm_encoder,
            var_activation=var_activation,
            return_dist=True,
            **_extra_encoder_kwargs,
        )

        n_input_decoder = n_latent + n_continuous_cov
        if self.batch_representation == "embedding":
            n_input_decoder += batch_dim

        _extra_decoder_kwargs = extra_decoder_kwargs or {}
        self.decoder = DecoderSCVI(
            n_input_decoder,
            n_input,
            n_cat_list=cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            inject_covariates=deeply_inject_covariates,
            use_batch_norm=use_batch_norm_decoder,
            use_layer_norm=use_layer_norm_decoder,
            scale_activation="softplus" if use_size_factor_key else "softmax",
            dirichlet_dim=4,
            **_extra_decoder_kwargs,
        )