In [18]:
import numpy as np
import scanpy as sc
import torch
import torch.nn
from scvi.distributions import NegativeBinomial
from typing import Optional
from anndata import AnnData
import scvi
from scvi.data import AnnDataManager
from scvi.data.fields import(
    LayerField, 
    CategoricalObsField,
    NumericalObsField,
    CategoricalJointObsField,
    NumericalJointObsField,
)
from scvi import REGISTRY_KEYS
from scvi.module.base import (
    BaseModuleClass,
    LossRecorder,
    auto_move_data,
)
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin

# Default figure size (inches) applied to every scanpy plot in this notebook.
FIGSIZE = (4, 4)
sc.set_figure_params(figsize=FIGSIZE)

For now I use randomly generated data from `scvi.data.synthetic_iid()`, whose counts follow a zero-inflated negative binomial (ZINB) distribution.

## Data Loading

In [21]:
# Fix scvi's global random seed so the synthetic data (and any later random
# parameter initialisation) is reproducible under Restart & Run All.
scvi.settings.seed = 0
adata = scvi.data.synthetic_iid()

# Register fields under the package's registry keys so this manager's registry
# uses the same keys the model code reads (REGISTRY_KEYS.X_KEY / BATCH_KEY).
anndata_fields = [
    LayerField(registry_key=REGISTRY_KEYS.X_KEY, layer=None, is_count_data=True),
    CategoricalObsField(registry_key=REGISTRY_KEYS.BATCH_KEY, obs_key="batch"),
]
adata_manager = AnnDataManager(fields=anndata_fields)
adata_manager.register_fields(adata)
# Besides the field registries, there is a _scvi_uuid key which is used to
# uniquely identify AnnData objects for subsequent retrieval.
adata_manager.registry.keys()

dict_keys(['scvi_version', 'model_name', 'setup_args', 'field_registries', '_scvi_uuid'])


## Probabilistic Model

For now, I am interested exclusively in an inference task; later I will build up a full generative model by placing distributional assumptions on these parameters. For now I also assume there is no variation in cell state, so the count $x_{ng}$ of gene $g$ in cell $n$ is modeled as:
$$ x_{ng} \sim \textrm{NegativeBinomial} (l_n \mu_g, \theta_g)$$

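where $l_n = \sum_g x_{ng}$ is the observed library size (total count) of cell $n$, and $\mu_g > 0$ and $\theta_g > 0$ are per-gene parameters to be learned. $l_n \mu_g$ is the mean of the negative binomial, and $\theta_g$ is its inverse dispersion, so that $\operatorname{Var}[x_{ng}] = l_n\mu_g + (l_n\mu_g)^2/\theta_g$.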

### Inference mechanism

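We use maximum likelihood estimation to infer the parameters $\Theta=\{ \mu_g, \theta_g \}_g$ by minimizing the average negative log-likelihood over cells with gradient descent. Both parameters are optimized in log space ($\log\mu_g$, $\log\theta_g$), so they remain positive.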

In [10]:
class NB_Module(BaseModuleClass):
    """
    Basic negative binomial model with no latent variables.

    Each count is modeled as ``x_ng ~ NB(l_n * mu_g, theta_g)``, where ``l_n`` is the
    observed library size (sum of counts) of cell ``n`` and ``mu_g`` / ``theta_g`` are
    learned per-gene parameters stored in log space. Because there are no latent
    variables, the inference step is empty and fitting is plain maximum likelihood.

    Parameters
    ----------
    n_input
        Number of input genes
    """

    def __init__(
        self,
        n_input: int
    ):
        super().__init__()
        # Per-gene generative parameters. They are kept in log space so that exp()
        # makes them strictly positive without constrained optimization.
        self.log_mu = torch.nn.Parameter(torch.randn(n_input))
        self.log_theta = torch.nn.Parameter(torch.randn(n_input))

    def _get_inference_input(self, tensors):
        """Return the keyword inputs of :meth:`inference` (none: no latent variables)."""
        # The scvi-tools forward pass is expected to unpack this as ``inference(**inputs)``,
        # so it must be a dict rather than the ``None`` an unimplemented stub returns.
        return {}

    def inference(self):
        """Run the (empty) inference model; there is nothing to infer per cell."""
        return {}

    def _get_generative_input(self, tensors, inference_outputs):
        """Extract the generative-model inputs (the per-cell library size) from ``tensors``."""
        x = tensors[REGISTRY_KEYS.X_KEY]
        # The number of UMIs per cell is treated as a known (observed) quantity.
        # keepdim=True keeps a trailing axis so it broadcasts against per-gene params.
        library = torch.sum(x, dim=1, keepdim=True)

        input_dict = {
            "library": library,
        }
        return input_dict

    @auto_move_data
    def generative(self, library):
        """Runs the generative model, returning the NB mean ``mu`` and inverse dispersion ``theta``."""

        # get the mean parameter of the negative binomial: library size times per-gene rate
        mu = library * torch.exp(self.log_mu)
        # get the inverse-dispersion parameter
        theta = torch.exp(self.log_theta)

        return dict(
            mu=mu, theta=theta
        )

    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
    ):
        """
        Negative log-likelihood of the observed counts.

        Parameters
        ----------
        tensors
            Minibatch dictionary holding the counts under ``REGISTRY_KEYS.X_KEY``.
        inference_outputs
            Output of :meth:`inference` (empty; accepted because the framework passes it).
        generative_outputs
            Output of :meth:`generative` with keys ``"mu"`` and ``"theta"``.

        Returns
        -------
        LossRecorder whose ``loss`` is the minibatch-mean NLL and whose
        ``reconstruction_loss`` is the per-cell NLL.
        """
        # here, we would like to form the log likelihood
        # so we extract all the required information
        x = tensors[REGISTRY_KEYS.X_KEY]
        mu = generative_outputs["mu"]
        theta = generative_outputs["theta"]

        # log likelihood per cell (summed over the last axis, i.e. genes)
        # note that I'm using the scVI NB which offers this mu/ theta parametrization, compatible with the Poisson-Gamma
        log_lik = NegativeBinomial(mu=mu, theta=theta).log_prob(x).sum(dim=-1)

        # Keep the reconstruction loss per cell: the training plan reduces it itself
        # (it takes .sum() and .shape[0]), which fails on a 0-dim mean.
        reconst_loss = -log_lik
        nll = torch.mean(reconst_loss)
        return LossRecorder(loss=nll, reconstruction_loss=reconst_loss)

## Training

In [34]:
class NB_Model(UnsupervisedTrainingMixin, BaseModelClass):
    """
    scvi-tools model wrapper around :class:`NB_Module`.

    Parameters
    ----------
    adata
        AnnData object previously registered with :meth:`NB_Model.setup_anndata`.
    **model_kwargs
        Keyword arguments forwarded to :class:`NB_Module`.
    """

    def __init__(
        self,
        adata: AnnData,
        **model_kwargs,
    ):
        super().__init__(adata)

        # The module only needs the number of genes, taken from the registered data.
        self.module = NB_Module(
            n_input=self.summary_stats["n_vars"],
            **model_kwargs,
        )
        self._model_summary_string = "NB Model has been created"
        # Keep last: records the constructor arguments captured from locals()
        # (presumably so the model can be re-instantiated, e.g. on load).
        self.init_params_ = self._get_init_params(locals())

    @classmethod
    def setup_anndata(
        cls,
        adata: AnnData,
        batch_key: Optional[str] = None,
        layer: Optional[str] = None,
        **kwargs,
    ) -> Optional[AnnData]:
        """
        Register the fields of ``adata`` that this model uses.

        Parameters
        ----------
        adata
            AnnData object holding raw counts.
        batch_key
            Key in ``adata.obs`` holding batch labels, or ``None``.
        layer
            Layer of ``adata`` holding the counts; ``None`` presumably means ``adata.X``.
        **kwargs
            Forwarded to :meth:`AnnDataManager.register_fields`.

        Returns
        -------
        None; the new manager is stored on the class via ``register_manager``.
        """
        # Must run first so locals() contains only the call's arguments.
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            # Placeholder fields registered with no obs key. NB_Module never reads
            # them; they mirror scvi's SCVI setup (presumably kept so generic scvi
            # utilities find the registry entries they expect).
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None),
            NumericalObsField(
                REGISTRY_KEYS.SIZE_FACTOR_KEY, None, required=False
            ),
            CategoricalJointObsField(
                REGISTRY_KEYS.CAT_COVS_KEY, None
            ),
            NumericalJointObsField(
                REGISTRY_KEYS.CONT_COVS_KEY, None
            ),
        ]
        adata_manager = AnnDataManager(
            fields=anndata_fields, setup_method_args=setup_method_args
        )
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

In [35]:
# Register the AnnData fields the model needs, then build the model on top of them.
NB_Model.setup_anndata(adata=adata, batch_key="batch")
my_model = NB_Model(adata=adata)

In [36]:
N_EPOCHS = 20  # short run: the model only has two parameters per gene
my_model.train(max_epochs=N_EPOCHS)

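TypeError: __init__() got an unexpected keyword argument 'enable_checkpointing'

**Note:** this error most likely comes from the environment rather than the model code. The installed scvi-tools passes `enable_checkpointing` to `pytorch_lightning.Trainer`, and that argument only exists in newer pytorch-lightning releases (1.5+). Installing the pytorch-lightning version that this scvi-tools release requires should resolve it.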