In [2]:
!pip install --quiet scvi-colab
from scvi_colab import install
install()

In [3]:
import numpy as np
import pandas as pd
import scanpy as sc
import scvelo as scv
import torch
from velovi import preprocess_data

import matplotlib.pyplot as plt
import seaborn as sns


In [4]:
from scvi.train import LoudEarlyStopping
class MyEarlyStopping(LoudEarlyStopping):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def _evaluate_stopping_criteria(self, current):
        should_stop, reason  = super()._evaluate_stopping_criteria(current)

        if not should_stop:
            new_lr = self.optimizer.param_groups[0]['lr']
            if self.watch_lr is not None and self.watch_lr != new_lr:
                self.watch_lr = new_lr
                self.update_prox_ops()

        return should_stop, reason



In [31]:
# -*- coding: utf-8 -*-
"""Main module."""
from typing import Callable, Iterable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from scvi._compat import Literal
from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data
from scvi.nn import Encoder, FCLayers
from torch import nn as nn
from torch.distributions import Categorical, Dirichlet, MixtureSameFamily, Normal
from torch.distributions import kl_divergence as kl
from scvi.distributions import NegativeBinomial

import logging
import warnings
from functools import partial
from typing import Iterable, List, Optional, Sequence, Tuple, Union

from anndata import AnnData
from joblib import Parallel, delayed
from scipy.stats import ttest_ind
from scvi._compat import Literal
from scvi._utils import _doc_params
from scvi.data import AnnDataManager
from scvi.data.fields import LayerField
from scvi.dataloaders import AnnDataLoader, DataSplitter
from scvi.model._utils import scrna_raw_counts_properties
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin
from scvi.model.base._utils import _de_core
from scvi.train import TrainRunner
from scvi.utils._docstrings import doc_differential_expression, setup_anndata_dsp
from sklearn.metrics.pairwise import cosine_similarity

from velovi import REGISTRY_KEYS

logger = logging.getLogger(__name__)

torch.backends.cudnn.benchmark = True

def one_hot_encoder(idx, n_cls):
    assert torch.max(idx).item() < n_cls
    if idx.dim() == 1:
        idx = idx.unsqueeze(1)
    onehot = torch.zeros(idx.size(0), n_cls)
    onehot = onehot.to(idx.device)
    onehot.scatter_(1, idx.long(), 1)
    return onehot

class MaskedLinear(nn.Linear):
    def __init__(self, n_in,  n_out, mask, bias=True):
        # mask should have the same dimensions as the transposed linear weight
        # n_input x n_output_nodes
        if n_in != mask.shape[0] or n_out != mask.shape[1]:
            raise ValueError('Incorrect shape of the mask.')

        super().__init__(n_in, n_out, bias)

        self.register_buffer('mask', mask.t())

        # zero out the weights for group lasso
        # gradient descent won't change these zero weights
        self.weight.data*=self.mask

    def forward(self, input):
        return nn.functional.linear(input, self.weight*self.mask, self.bias)

class MaskedCondLayers(nn.Module):
    def __init__(
        self,
        n_in: int,
        n_out: int,
        n_cond: int,
        bias: bool,
        n_ext: int = 0,
        n_ext_m: int = 0,
        mask: Optional[torch.Tensor] = None,
        ext_mask: Optional[torch.Tensor] = None
    ):
        super().__init__()
        self.n_cond = n_cond
        self.n_ext = n_ext
        self.n_ext_m = n_ext_m

        self.expr_L = nn.Linear(n_in, n_out, bias=bias)

        # if mask is None:
        #     self.expr_L = nn.Linear(n_in, n_out, bias=bias)
        # else:
        #     self.expr_L = MaskedLinear(n_in, n_out, mask, bias=bias)

        # if self.n_cond != 0:
        #     self.cond_L = nn.Linear(self.n_cond, n_out, bias=False)

        # if self.n_ext != 0:
        #     self.ext_L = nn.Linear(self.n_ext, n_out, bias=False)

        # if self.n_ext_m != 0:
        #     if ext_mask is not None:
        #         self.ext_L_m = MaskedLinear(self.n_ext_m, n_out, ext_mask, bias=False)
        #     else:
        #         self.ext_L_m = nn.Linear(self.n_ext_m, n_out, bias=False)

    def forward(self, x: torch.Tensor):
        # if self.n_cond == 0:
        #     expr, cond = x, None
        # else:
        #     expr, cond = torch.split(x, [x.shape[1] - self.n_cond, self.n_cond], dim=1)

        # if self.n_ext == 0:
        #     ext = None
        # else:
        #     expr, ext = torch.split(expr, [expr.shape[1] - self.n_ext, self.n_ext], dim=1)

        # if self.n_ext_m == 0:
        #     ext_m = None
        # else:
        #     expr, ext_m = torch.split(expr, [expr.shape[1] - self.n_ext_m, self.n_ext_m], dim=1)

        expr=x

        out = self.expr_L(expr)
        # if ext is not None:
        #     out = out + self.ext_L(ext)
        # if ext_m is not None:
        #     out = out + self.ext_L_m(ext_m)
        # if cond is not None:
        #     out = out + self.cond_L(cond)
        return out


# class MaskedLinearDecoder(nn.Module):
#     def __init__(self, in_dim, out_dim, n_cond, mask, ext_mask, recon_loss,
#                  last_layer=None, n_ext=0, n_ext_m=0):
#         super().__init__()

#         if recon_loss == "mse":
#             if last_layer == "softmax":
#                 raise ValueError("Can't specify softmax last layer with mse loss.")
#             last_layer = "identity" if last_layer is None else last_layer
#         elif recon_loss == "nb":
#             last_layer = "softmax" if last_layer is None else last_layer
#         else:
#             raise ValueError("Unrecognized loss.")

#         print("GP Decoder Architecture:")
#         print("\tMasked linear layer in, ext_m, ext, cond, out: ", in_dim, n_ext_m, n_ext, n_cond, out_dim)
#         if mask is not None:
#             print('\twith hard mask.')
#         else:
#             print('\twith soft mask.')

#         self.n_ext = n_ext
#         self.n_ext_m = n_ext_m

#         self.n_cond = 0
#         if n_cond is not None:
#             self.n_cond = n_cond

#         self.L0 = MaskedCondLayers(in_dim, out_dim, n_cond, bias=False, n_ext=n_ext, n_ext_m=n_ext_m,
#                                    mask=mask, ext_mask=ext_mask)

#         if last_layer == "softmax":
#             self.mean_decoder = nn.Softmax(dim=-1)
#         elif last_layer == "softplus":
#             self.mean_decoder = nn.Softplus()
#         elif last_layer == "exp":
#             self.mean_decoder = torch.exp
#         elif last_layer == "relu":
#             self.mean_decoder = nn.ReLU()
#         elif last_layer == "identity":
#             self.mean_decoder = lambda a: a
#         else:
#             raise ValueError("Unrecognized last layer.")

#         print("Last Decoder layer:", last_layer)

#     def forward(self, z, batch=None):
#         # if batch is not None:
#         #     batch = one_hot_encoder(batch, n_cls=self.n_cond)
#         #     z_cat = torch.cat((z, batch), dim=-1)
#         #     dec_latent = self.L0(z_cat)
#         # else:
#         #     dec_latent = self.L0(z)

#         dec_latent = self.L0(z)
#         recon_x = self.mean_decoder(dec_latent)


#         return recon_x, dec_latent

class DecoderVELOVI(nn.Module):
    """
    Decodes data from latent space of ``n_input`` dimensions ``n_output``dimensions.

    Uses a fully-connected neural network of ``n_hidden`` layers.

    Parameters
    ----------
    n_input
        The dimensionality of the input (latent space)
    n_output
        The dimensionality of the output (data space)
    n_cat_list
        A list containing the number of categories
        for each category of interest. Each category will be
        included using a one-hot encoding
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    dropout_rate
        Dropout rate to apply to each of the hidden layers
    inject_covariates
        Whether to inject covariates in each layer, or just the first (default).
    use_batch_norm
        Whether to use batch norm in layers
    use_layer_norm
        Whether to use layer norm in layers
    linear_decoder
        Whether to use linear decoder for time
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_ext: int = 0,
        n_ext_m: int = 0,
        n_cond: int = 0,
        last_layer: str =None,
        ext_mask: torch.Tensor = None,
        mask: torch.Tensor = None,
        recon_loss: str = 'nb',
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
        inject_covariates: bool = True,
        use_batch_norm: bool = True,
        use_layer_norm: bool = False,
        dropout_rate: float = 0.0,
        linear_decoder: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.n_ouput = n_output
        self.linear_decoder = linear_decoder

        ### GP decoder ###

        if recon_loss == "mse":
            if last_layer == "softmax":
                raise ValueError("Can't specify softmax last layer with mse loss.")
            last_layer = "identity" if last_layer is None else last_layer
        elif recon_loss == "nb":
            last_layer = "softmax" if last_layer is None else last_layer
        else:
            raise ValueError("Unrecognized loss.")

        #print("GP Decoder Architecture:")
        #print("\tMasked linear layer in, ext_m, ext, cond, out: ", in_dim, n_ext_m, n_ext, n_cond, out_dim)
        if mask is not None:
            print('\twith hard mask.')
        else:
            print('\twith soft mask.')

        self.n_ext = n_ext
        self.n_ext_m = n_ext_m

        self.n_cond = 0
        if n_cond is not None:
            self.n_cond = n_cond

        self.L0 = MaskedCondLayers(n_input, n_output, n_cond, bias=False, n_ext=n_ext, n_ext_m=n_ext_m,
                                   mask=mask, ext_mask=ext_mask)

        if last_layer == "softmax":
            self.mean_decoder = nn.Softmax(dim=-1)
        elif last_layer == "softplus":
            self.mean_decoder = nn.Softplus()
        elif last_layer == "exp":
            self.mean_decoder = torch.exp
        elif last_layer == "relu":
            self.mean_decoder = nn.ReLU()
        elif last_layer == "identity":
            self.mean_decoder = lambda a: a
        else:
            raise ValueError("Unrecognized last layer.")

        print("Last Decoder layer:", last_layer)

        self.rho_first_decoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden if not linear_decoder else n_output,
            n_cat_list=n_cat_list,
            n_layers=n_layers if not linear_decoder else 1,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            inject_covariates=inject_covariates,
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm if not linear_decoder else False,
            use_activation=not linear_decoder,
            bias=not linear_decoder,
            **kwargs,
        )

        self.pi_first_decoder = FCLayers(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            inject_covariates=inject_covariates,
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm,
            **kwargs,
        )

        self.px_pi_decoder = nn.Linear(n_hidden, 4 * n_output)
        

        # rho for induction
        self.px_rho_decoder = nn.Sequential(nn.Linear(n_hidden, n_output), nn.Sigmoid())

        # tau for repression
        self.px_tau_decoder = nn.Sequential(nn.Linear(n_hidden, n_output), nn.Sigmoid())

        self.linear_scaling_tau = nn.Parameter(torch.zeros(n_output))
        self.linear_scaling_tau_intercept = nn.Parameter(torch.zeros(n_output))

    def forward(self, z: torch.Tensor, latent_dim: int = None):
        """
        The forward computation for a single sample.

         #. Decodes the data from the latent space using the decoder network
         #. Returns parameters for the ZINB distribution of expression
         #. If ``dispersion != 'gene-cell'`` then value for that param will be ``None``

        Parameters
        ----------
        z :
            tensor with shape ``(n_input,)``
        cat_list
            list of category membership(s) for this sample

        Returns
        -------
        4-tuple of :py:class:`torch.Tensor`
            parameters for the ZINB distribution of expression

        """

        z_in = z
        if latent_dim is not None:
            mask = torch.zeros_like(z)
            mask[..., latent_dim] = 1
            z_in = z * mask
        # The decoder returns values for the parameters of the ZINB distribution
        rho_first = self.rho_first_decoder(z_in)

        dec_latent = self.L0(z)
        recon_x = self.mean_decoder(dec_latent)

        if not self.linear_decoder:
            px_rho = self.px_rho_decoder(rho_first)
            px_tau = self.px_tau_decoder(rho_first)
        else:
            px_rho = nn.Sigmoid()(rho_first)
            px_tau = 1 - nn.Sigmoid()(
                rho_first * self.linear_scaling_tau.exp()
                + self.linear_scaling_tau_intercept
            )

        # cells by genes by 4
        pi_first = self.pi_first_decoder(z)
        px_pi = nn.Softplus()(
            torch.reshape(self.px_pi_decoder(pi_first), (z.shape[0], self.n_ouput, 4))
        )

        return px_pi, px_rho, px_tau, recon_x, dec_latent


# VAE model
class VELOVAE(BaseModuleClass):
    """
    Variational auto-encoder model.

    This is an implementation of the scVI model descibed in [Lopez18]_

    Parameters
    ----------
    n_input
        Number of input genes
    n_hidden
        Number of nodes per hidden layer
    n_latent
        Dimensionality of the latent space
    n_layers
        Number of hidden layers used for encoder and decoder NNs
    dropout_rate
        Dropout rate for neural networks
    log_variational
        Log(data+1) prior to encoding for numerical stability. Not normalization.
    latent_distribution
        One of

        * ``'normal'`` - Isotropic normal
        * ``'ln'`` - Logistic normal with normal params N(0, 1)
    use_layer_norm
        Whether to use layer norm in layers
    use_observed_lib_size
        Use observed library size for RNA as scaling factor in mean of conditional distribution
    var_activation
        Callable used to ensure positivity of the variational distributions' variance.
        When `None`, defaults to `torch.exp`.
    """

    def __init__(
        self,
        n_input: int,
        true_time_switch: Optional[np.ndarray] = None,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        log_variational: bool = False,
        latent_distribution: str = "normal",
        use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
        use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "both",
        use_observed_lib_size: bool = True,
        var_activation: Optional[Callable] = torch.nn.Softplus(),
        model_steady_states: bool = True,
        gamma_unconstr_init: Optional[np.ndarray] = None,
        alpha_unconstr_init: Optional[np.ndarray] = None,
        alpha_1_unconstr_init: Optional[np.ndarray] = None,
        lambda_alpha_unconstr_init: Optional[np.ndarray] = None,
        switch_spliced: Optional[np.ndarray] = None,
        switch_unspliced: Optional[np.ndarray] = None,
        t_max: float = 20,
        penalty_scale: float = 0.2,
        dirichlet_concentration: float = 0.25,
        linear_decoder: bool = False,
        time_dep_transcription_rate: bool = False,
        #Parameters for masked linear decoder
        mask: torch.Tensor = None,
        recon_loss: str = 'nb',
        conditions: list = [],
        use_l_encoder: bool = False,
        dr_rate: float = 0.05,
        use_bn: bool = False,
        use_ln: bool = True,
        decoder_last_layer: Optional[str] = None,
        soft_mask: bool = False,
        n_ext: int = 0,
        n_ext_m: int = 0,
        use_hsic: bool = False,
        hsic_one_vs_all: bool = False,
        ext_mask: Optional[torch.Tensor] = None,
        soft_ext_mask: bool = False
    ):
        super().__init__()
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.latent_distribution = latent_distribution
        self.use_observed_lib_size = use_observed_lib_size
        self.n_input = n_input
        self.model_steady_states = model_steady_states
        self.t_max = t_max
        self.penalty_scale = penalty_scale
        self.dirichlet_concentration = dirichlet_concentration
        self.time_dep_transcription_rate = time_dep_transcription_rate

        if switch_spliced is not None:
            self.register_buffer("switch_spliced", torch.from_numpy(switch_spliced))
        else:
            self.switch_spliced = None
        if switch_unspliced is not None:
            self.register_buffer("switch_unspliced", torch.from_numpy(switch_unspliced))
        else:
            self.switch_unspliced = None

        n_genes = n_input * 2

        # switching time
        self.switch_time_unconstr = torch.nn.Parameter(7 + 0.5 * torch.randn(n_input))
        if true_time_switch is not None:
            self.register_buffer("true_time_switch", torch.from_numpy(true_time_switch))
        else:
            self.true_time_switch = None

        # degradation
        if gamma_unconstr_init is None:
            self.gamma_mean_unconstr = torch.nn.Parameter(-1 * torch.ones(n_input))
        else:
            self.gamma_mean_unconstr = torch.nn.Parameter(
                torch.from_numpy(gamma_unconstr_init)
            )

        # splicing
        # first samples around 1
        self.beta_mean_unconstr = torch.nn.Parameter(0.5 * torch.ones(n_input))

        # transcription
        if alpha_unconstr_init is None:
            self.alpha_unconstr = torch.nn.Parameter(0 * torch.ones(n_input))
        else:
            self.alpha_unconstr = torch.nn.Parameter(
                torch.from_numpy(alpha_unconstr_init)
            )

        # TODO: Add `require_grad`
        if alpha_1_unconstr_init is None:
            self.alpha_1_unconstr = torch.nn.Parameter(0 * torch.ones(n_input))
        else:
            self.alpha_1_unconstr = torch.nn.Parameter(
                torch.from_numpy(alpha_1_unconstr_init)
            )
        self.alpha_1_unconstr.requires_grad = time_dep_transcription_rate

        if lambda_alpha_unconstr_init is None:
            self.lambda_alpha_unconstr = torch.nn.Parameter(0 * torch.ones(n_input))
        else:
            self.lambda_alpha_unconstr = torch.nn.Parameter(
                torch.from_numpy(lambda_alpha_unconstr_init)
            )
        self.lambda_alpha_unconstr.requires_grad = time_dep_transcription_rate

        # likelihood dispersion
        # for now, with normal dist, this is just the variance
        self.scale_unconstr = torch.nn.Parameter(-1 * torch.ones(n_genes, 4))

        use_batch_norm_encoder = use_batch_norm == "encoder" or use_batch_norm == "both"
        use_batch_norm_decoder = use_batch_norm == "decoder" or use_batch_norm == "both"
        use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both"
        use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both"
        self.use_batch_norm_decoder = use_batch_norm_decoder

        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        n_input_encoder = n_genes
        self.z_encoder = Encoder(
            n_input_encoder,
            n_latent,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            distribution=latent_distribution,
            use_batch_norm=use_batch_norm_encoder,
            use_layer_norm=use_layer_norm_encoder,
            var_activation=var_activation,
            activation_fn=torch.nn.ReLU,
        )

         ### Attributes for masked linear decoder
        self.n_conditions = len(conditions)
        self.conditions = conditions
        self.n_conditions=0
        self.recon_loss = recon_loss
        self.freeze = False
        self.use_bn = use_bn
        self.use_ln = use_ln

        self.use_mmd = False

        self.n_ext_encoder = n_ext + n_ext_m
        self.n_ext_decoder = n_ext
        self.n_ext_m_decoder = n_ext_m

        self.use_hsic = use_hsic and self.n_ext_decoder > 0
        self.hsic_one_vs_all = hsic_one_vs_all

        self.soft_mask = soft_mask and mask is not None
        self.soft_ext_mask = soft_ext_mask and ext_mask is not None

        if decoder_last_layer is None:
            if recon_loss == 'nb':
                self.decoder_last_layer = 'softmax'
            else:
                self.decoder_last_layer = 'identity'
        else:
            self.decoder_last_layer = decoder_last_layer

        self.use_l_encoder = use_l_encoder

        self.dr_rate = dr_rate
        if self.dr_rate > 0:
            self.use_dr = True
        else:
            self.use_dr = False

        if recon_loss == "nb":
            if self.n_conditions != 0:
                self.theta = torch.nn.Parameter(torch.randn(self.n_input, self.n_conditions))
            else:
                self.theta = torch.nn.Parameter(torch.randn(1, self.n_input))
        else:
            self.theta = None

        if self.soft_mask:
            self.n_inact_genes = (1-mask).sum().item()
            soft_shape = mask.shape
            if soft_shape[0] != n_latent or soft_shape[1] != n_input:
                raise ValueError('Incorrect shape of the soft mask.')
            self.mask = mask.t()
            mask = None
        else:
            self.mask = None

        if self.soft_ext_mask:
            self.n_inact_ext_genes = (1-ext_mask).sum().item()
            ext_shape = ext_mask.shape
            if ext_shape[0] != self.n_ext_m_decoder:
                raise ValueError('Dim 0 of ext_mask should be the same as n_ext_m_decoder.')
            if ext_shape[1] != self.n_input:
                raise ValueError('Dim 1 of ext_mask should be the same as n_input.')
            self.ext_mask = ext_mask.t()
            ext_mask = None
        else:
            self.ext_mask = None
            
        # decoder goes from n_latent-dimensional space to n_input-d data
        n_input_decoder = n_latent
        self.decoder = DecoderVELOVI(
            n_input_decoder,
            n_input,
            n_ext = 0,
            n_ext_m= 0,
            n_cond= 0,
            last_layer=None,
            ext_mask = None,
            mask = None,
            recon_loss = 'nb',
            n_cat_list= None,
            n_layers=n_layers,
            n_hidden=n_hidden,
            use_batch_norm=use_batch_norm_decoder,
            use_layer_norm=use_layer_norm_decoder,
            activation_fn=torch.nn.ReLU,
            linear_decoder=linear_decoder,
            )

       


    def _get_inference_input(self, tensors):
        spliced = tensors[REGISTRY_KEYS.X_KEY]
        unspliced = tensors[REGISTRY_KEYS.U_KEY]

        input_dict = dict(
            spliced=spliced,
            unspliced=unspliced,
        )
        return input_dict

    def _get_generative_input(self, tensors, inference_outputs):
        z = inference_outputs["z"]
        gamma = inference_outputs["gamma"]
        beta = inference_outputs["beta"]
        alpha = inference_outputs["alpha"]
        alpha_1 = inference_outputs["alpha_1"]
        lambda_alpha = inference_outputs["lambda_alpha"]

        input_dict = {
            "z": z,
            "gamma": gamma,
            "beta": beta,
            "alpha": alpha,
            "alpha_1": alpha_1,
            "lambda_alpha": lambda_alpha,
        }
        return input_dict

    @auto_move_data
    def inference(
        self,
        spliced,
        unspliced,
        n_samples=1,
    ):
        """
        High level inference method.

        Runs the inference (encoder) model.
        """
        spliced_ = spliced
        unspliced_ = unspliced
        if self.log_variational:
            spliced_ = torch.log(0.01 + spliced)
            unspliced_ = torch.log(0.01 + unspliced)

        encoder_input = torch.cat((spliced_, unspliced_), dim=-1)

        qz_m, qz_v, z = self.z_encoder(encoder_input)

        if n_samples > 1:
            qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
            qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
            # when z is normal, untran_z == z
            untran_z = Normal(qz_m, qz_v.sqrt()).sample()
            z = self.z_encoder.z_transformation(untran_z)

        gamma, beta, alpha, alpha_1, lambda_alpha = self._get_rates()

        outputs = dict(
            z=z,
            qz_m=qz_m,
            qz_v=qz_v,
            gamma=gamma,
            beta=beta,
            alpha=alpha,
            alpha_1=alpha_1,
            lambda_alpha=lambda_alpha,
        )
        return outputs

    def _get_rates(self):
        # globals
        # degradation
        gamma = torch.clamp(F.softplus(self.gamma_mean_unconstr), 0, 50)
        # splicing
        beta = torch.clamp(F.softplus(self.beta_mean_unconstr), 0, 50)
        # transcription
        alpha = torch.clamp(F.softplus(self.alpha_unconstr), 0, 50)
        if self.time_dep_transcription_rate:
            alpha_1 = torch.clamp(F.softplus(self.alpha_1_unconstr), 0, 50)
            lambda_alpha = torch.clamp(F.softplus(self.lambda_alpha_unconstr), 0, 50)
        else:
            alpha_1 = self.alpha_1_unconstr
            lambda_alpha = self.lambda_alpha_unconstr

        return gamma, beta, alpha, alpha_1, lambda_alpha

    @auto_move_data
    def generative(self, z, gamma, beta, alpha, alpha_1, lambda_alpha, latent_dim=None):
        """Runs the generative model."""
        decoder_input = z
        px_pi_alpha, px_rho, px_tau, dec_mean, dec_latent = self.decoder(decoder_input, latent_dim=latent_dim)

        px_pi = Dirichlet(px_pi_alpha).rsample()

        #dec_mean, dec_latent = self.GP_linear_decoder(decoder_input, batch=None)

        scale_unconstr = self.scale_unconstr
        scale = F.softplus(scale_unconstr)

        mixture_dist_s, mixture_dist_u, end_penalty = self.get_px(
            px_pi,
            px_rho,
            px_tau,
            scale,
            gamma,
            beta,
            alpha,
            alpha_1,
            lambda_alpha,
        )

        return dict(
            px_pi=px_pi,
            px_rho=px_rho,
            px_tau=px_tau,
            scale=scale,
            px_pi_alpha=px_pi_alpha,
            mixture_dist_u=mixture_dist_u,
            mixture_dist_s=mixture_dist_s,
            end_penalty=end_penalty,
            gene_recon = dec_mean,
            dec_latent = dec_latent
        )

    def loss(
        self,
        tensors,
        inference_outputs,
        generative_outputs,
        cond_batch=None,
        kl_weight: float = 1.0,
        n_obs: float = 1.0,
    ):
        spliced = tensors[REGISTRY_KEYS.X_KEY]
        unspliced = tensors[REGISTRY_KEYS.U_KEY]

        #gene reconstruction loss
        ground_truth_counts = spliced + unspliced
        

        if cond_batch is not None:
            dispersion = F.linear(one_hot_encoder(cond_batch, self.n_conditions), self.theta) #batch is the
        else:
            dispersion = self.theta   
        dispersion = torch.exp(dispersion)

        dec_mean = generative_outputs["gene_recon"]
        negbin = NegativeBinomial(mu=dec_mean, theta=dispersion)
        
        gene_recon_loss = -negbin.log_prob(ground_truth_counts).sum(dim=-1)
        

        qz_m = inference_outputs["qz_m"]
        qz_v = inference_outputs["qz_v"]

        px_pi = generative_outputs["px_pi"]
        px_pi_alpha = generative_outputs["px_pi_alpha"]

        end_penalty = generative_outputs["end_penalty"]
        mixture_dist_s = generative_outputs["mixture_dist_s"]
        mixture_dist_u = generative_outputs["mixture_dist_u"]

        kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(0, 1)).sum(dim=1)

        reconst_loss_s = -mixture_dist_s.log_prob(spliced)
        reconst_loss_u = -mixture_dist_u.log_prob(unspliced)
        reconst_loss = reconst_loss_u.sum(dim=-1) + reconst_loss_s.sum(dim=-1) 

        kl_pi = kl(
            Dirichlet(px_pi_alpha),
            Dirichlet(self.dirichlet_concentration * torch.ones_like(px_pi)),
        ).sum(dim=-1)

        # local loss
        kl_local = kl_divergence_z + kl_pi
        weighted_kl_local = kl_weight * (kl_divergence_z) + kl_pi

        local_loss = torch.mean(reconst_loss + gene_recon_loss + weighted_kl_local)

        # combine local and global
        global_loss = 0
        loss = (
            local_loss
            + self.penalty_scale * (1 - kl_weight) * end_penalty
            + (1 / n_obs) * kl_weight * (global_loss)
        )

        loss_recorder = LossRecorder(
            loss, reconst_loss, kl_local, torch.tensor(global_loss)
        )

        return loss_recorder


    @auto_move_data
    def get_px(
        self,
        px_pi,
        px_rho,
        px_tau,
        scale,
        gamma,
        beta,
        alpha,
        alpha_1,
        lambda_alpha,
    ) -> torch.Tensor:

        t_s = torch.clamp(F.softplus(self.switch_time_unconstr), 0, self.t_max)

        n_cells = px_pi.shape[0]

        # component dist
        comp_dist = Categorical(probs=px_pi)

        # induction
        mean_u_ind, mean_s_ind = self._get_induction_unspliced_spliced(
            alpha, alpha_1, lambda_alpha, beta, gamma, t_s * px_rho
        )

        if self.time_dep_transcription_rate:
            mean_u_ind_steady = (alpha_1 / beta).expand(n_cells, self.n_input)
            mean_s_ind_steady = (alpha_1 / gamma).expand(n_cells, self.n_input)
        else:
            mean_u_ind_steady = (alpha / beta).expand(n_cells, self.n_input)
            mean_s_ind_steady = (alpha / gamma).expand(n_cells, self.n_input)
        scale_u = scale[: self.n_input, :].expand(n_cells, self.n_input, 4).sqrt()

        # repression
        u_0, s_0 = self._get_induction_unspliced_spliced(
            alpha, alpha_1, lambda_alpha, beta, gamma, t_s
        )

        tau = px_tau
        mean_u_rep, mean_s_rep = self._get_repression_unspliced_spliced(
            u_0,
            s_0,
            beta,
            gamma,
            (self.t_max - t_s) * tau,
        )
        mean_u_rep_steady = torch.zeros_like(mean_u_ind)
        mean_s_rep_steady = torch.zeros_like(mean_u_ind)
        scale_s = scale[self.n_input :, :].expand(n_cells, self.n_input, 4).sqrt()

        end_penalty = ((u_0 - self.switch_unspliced).pow(2)).sum() + (
            (s_0 - self.switch_spliced).pow(2)
        ).sum()

        # unspliced
        mean_u = torch.stack(
            (
                mean_u_ind,
                mean_u_ind_steady,
                mean_u_rep,
                mean_u_rep_steady,
            ),
            dim=2,
        )
        scale_u = torch.stack(
            (
                scale_u[..., 0],
                scale_u[..., 0],
                scale_u[..., 0],
                0.1 * scale_u[..., 0],
            ),
            dim=2,
        )
        dist_u = Normal(mean_u, scale_u)
        mixture_dist_u = MixtureSameFamily(comp_dist, dist_u)

        # spliced
        mean_s = torch.stack(
            (mean_s_ind, mean_s_ind_steady, mean_s_rep, mean_s_rep_steady),
            dim=2,
        )
        scale_s = torch.stack(
            (
                scale_s[..., 0],
                scale_s[..., 0],
                scale_s[..., 0],
                0.1 * scale_s[..., 0],
            ),
            dim=2,
        )
        dist_s = Normal(mean_s, scale_s)
        mixture_dist_s = MixtureSameFamily(comp_dist, dist_s)

        return mixture_dist_s, mixture_dist_u, end_penalty

    def _get_induction_unspliced_spliced(
        self, alpha, alpha_1, lambda_alpha, beta, gamma, t, eps=1e-6
    ):
        if self.time_dep_transcription_rate:
            unspliced = alpha_1 / beta * (1 - torch.exp(-beta * t)) - (
                alpha_1 - alpha
            ) / (beta - lambda_alpha) * (
                torch.exp(-lambda_alpha * t) - torch.exp(-beta * t)
            )

            spliced = (
                alpha_1 / gamma * (1 - torch.exp(-gamma * t))
                + alpha_1
                / (gamma - beta + eps)
                * (torch.exp(-gamma * t) - torch.exp(-beta * t))
                - beta
                * (alpha_1 - alpha)
                / (beta - lambda_alpha + eps)
                / (gamma - lambda_alpha + eps)
                * (torch.exp(-lambda_alpha * t) - torch.exp(-gamma * t))
                + beta
                * (alpha_1 - alpha)
                / (beta - lambda_alpha + eps)
                / (gamma - beta + eps)
                * (torch.exp(-beta * t) - torch.exp(-gamma * t))
            )
        else:
            unspliced = (alpha / beta) * (1 - torch.exp(-beta * t))
            spliced = (alpha / gamma) * (1 - torch.exp(-gamma * t)) + (
                alpha / ((gamma - beta) + eps)
            ) * (torch.exp(-gamma * t) - torch.exp(-beta * t))

        return unspliced, spliced

    def _get_repression_unspliced_spliced(self, u_0, s_0, beta, gamma, t, eps=1e-6):
        unspliced = torch.exp(-beta * t) * u_0
        spliced = s_0 * torch.exp(-gamma * t) - (
            beta * u_0 / ((gamma - beta) + eps)
        ) * (torch.exp(-gamma * t) - torch.exp(-beta * t))
        return unspliced, spliced

    def sample(
        self,
    ) -> np.ndarray:
        """Not implemented."""
        raise NotImplementedError

    @torch.no_grad()
    def get_loadings(self) -> np.ndarray:
        """Extract per-gene weights (for each Z, shape is genes by dim(Z)) in the linear decoder."""
        # This is BW, where B is diag(b) batch norm, W is weight matrix
        if self.decoder.linear_decoder is False:
            raise ValueError("Model not trained with linear decoder")
        w = self.decoder.rho_first_decoder.fc_layers[0][0].weight
        if self.use_batch_norm_decoder:
            bn = self.decoder.rho_first_decoder.fc_layers[0][1]
            sigma = torch.sqrt(bn.running_var + bn.eps)
            gamma = bn.weight
            b = gamma / sigma
            b_identity = torch.diag(b)
            loadings = torch.matmul(b_identity, w)
        else:
            loadings = w
        loadings = loadings.detach().cpu().numpy()

        return loadings




def _softplus_inverse(x: np.ndarray) -> np.ndarray:
    x = torch.from_numpy(x)
    x_inv = torch.where(x > 20, x, x.expm1().log()).numpy()
    return x_inv




class VELOVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
    """
    Velocity Variational Inference

    Parameters
    ----------
    adata
        AnnData object that has been registered via :func:`~velovi.VELOVI.setup_anndata`.
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers used for encoder and decoder NNs.
    dropout_rate
        Dropout rate for neural networks.
    gamma_init_data
        Initialize gamma using the data-driven technique.
    linear_decoder
        Use a linear decoder from latent space to time.
    **model_kwargs
        Keyword args for :class:`~velovi.VELOVAE`
    """

    def __init__(
        self,
        adata: AnnData,
        n_hidden: int = 256,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        gamma_init_data: bool = False,
        linear_decoder: bool = False,
        mask: Optional[Union[np.ndarray, list]] = None,
        mask_key: str = 'I',
        soft_mask: bool = False,
        **model_kwargs,
    ):
        super().__init__(adata)
        self.n_latent = n_latent

        if mask is None and mask_key not in self.adata.varm:
            raise ValueError('Please provide mask.')
        
        if mask is None:
            mask = adata.varm[mask_key].T

        self.mask_ = mask if isinstance(mask, list) else mask.tolist()
        mask = torch.tensor(mask).float()

        self.soft_mask_ = soft_mask

        spliced = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
        unspliced = self.adata_manager.get_from_registry(REGISTRY_KEYS.U_KEY)

        sorted_unspliced = np.argsort(unspliced, axis=0)
        ind = int(adata.n_obs * 0.99)
        us_upper_ind = sorted_unspliced[ind:, :]

        us_upper = []
        ms_upper = []
        for i in range(len(us_upper_ind)):
            row = us_upper_ind[i]
            us_upper += [unspliced[row, np.arange(adata.n_vars)][np.newaxis, :]]
            ms_upper += [spliced[row, np.arange(adata.n_vars)][np.newaxis, :]]
        us_upper = np.median(np.concatenate(us_upper, axis=0), axis=0)
        ms_upper = np.median(np.concatenate(ms_upper, axis=0), axis=0)

        alpha_unconstr = _softplus_inverse(us_upper)
        alpha_unconstr = np.asarray(alpha_unconstr).ravel()

        alpha_1_unconstr = np.zeros(us_upper.shape).ravel()
        lambda_alpha_unconstr = np.zeros(us_upper.shape).ravel()

        if gamma_init_data:
            gamma_unconstr = np.clip(_softplus_inverse(us_upper / ms_upper), None, 10)
        else:
            gamma_unconstr = None

        self.module = VELOVAE(
            n_input=self.summary_stats["n_vars"],
            n_hidden=n_hidden,
            n_latent=n_latent,
            n_layers=n_layers,
            dropout_rate=dropout_rate,
            gamma_unconstr_init=gamma_unconstr,
            alpha_unconstr_init=alpha_unconstr,
            alpha_1_unconstr_init=alpha_1_unconstr,
            lambda_alpha_unconstr_init=lambda_alpha_unconstr,
            switch_spliced=ms_upper,
            switch_unspliced=us_upper,
            linear_decoder=linear_decoder,
            mask=mask,
            soft_mask=self.soft_mask_,
            **model_kwargs,
        )
        self._model_summary_string = (
            "VELOVI Model with the following params: \nn_hidden: {}, n_latent: {}, n_layers: {}, dropout_rate: "
            "{}"
        ).format(
            n_hidden,
            n_latent,
            n_layers,
            dropout_rate,
        )
        self.init_params_ = self._get_init_params(locals())

    def train(
        self,
        max_epochs: Optional[int] = 500,
        lr: float = 1e-2,
        weight_decay: float = 1e-2,
        use_gpu: Optional[Union[str, int, bool]] = None,
        train_size: float = 0.9,
        validation_size: Optional[float] = None,
        batch_size: int = 256,
        early_stopping: bool = True,
        gradient_clip_val: float = 10,
        alpha=0.7,
        plan_kwargs: Optional[dict] = None,
        **trainer_kwargs,
    ):
        """
        Train the model.

        Parameters
        ----------
        max_epochs
            Number of passes through the dataset. If `None`, defaults to
            `np.min([round((20000 / n_cells) * 400), 400])`
        lr
            Learning rate for optimization
        weight_decay
            Weight decay for optimization
        use_gpu
            Use default GPU if available (if None or True), or index of GPU to use (if int),
            or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
        train_size
            Size of training set in the range [0.0, 1.0].
        validation_size
            Size of the test set. If `None`, defaults to 1 - `train_size`. If
            `train_size + validation_size < 1`, the remaining cells belong to a test set.
        batch_size
            Minibatch size to use during training.
        early_stopping
            Perform early stopping. Additional arguments can be passed in `**kwargs`.
            See :class:`~scvi.train.Trainer` for further options.
        gradient_clip_val
            Val for gradient clipping
        plan_kwargs
            Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to
            `train()` will overwrite values present in `plan_kwargs`, when appropriate.
        **trainer_kwargs
            Other keyword args for :class:`~scvi.train.Trainer`.
        """
        user_plan_kwargs = (
            plan_kwargs.copy() if isinstance(plan_kwargs, dict) else dict()
        )
        plan_kwargs = dict(lr=lr, weight_decay=weight_decay, optimizer="AdamW")
        plan_kwargs.update(user_plan_kwargs)

        user_train_kwargs = trainer_kwargs.copy()
        trainer_kwargs = dict(gradient_clip_val=gradient_clip_val)
        trainer_kwargs.update(user_train_kwargs)

        data_splitter = DataSplitter(
            self.adata_manager,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
            use_gpu=use_gpu,
        )
        training_plan = CustomTrainingPlan(self.module, alpha=alpha, **plan_kwargs)

        es = "early_stopping"
        trainer_kwargs[es] = (
            early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
        )

        #trainer_kwargs["callbacks"] = MyEarlyStopping()

        runner = TrainRunner(
            self,
            training_plan=training_plan,
            data_splitter=data_splitter,
            max_epochs=max_epochs,
            use_gpu=use_gpu,
            **trainer_kwargs,
        )
        return runner()

    
    #optim = torch.optim.Adam(linear_reg_model.parameters(), lr=0.05)

    def get_loss(self, adata):
        # run the model forward on the data

        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(
            adata=adata, indices=None, batch_size=256
        )

        for tensors in scdl:    
            inference_outputs, generative_outputs, loss = self.module.forward(
                tensors=tensors,
                compute_loss=True,
                )
        
        return inference_outputs, generative_outputs, loss
        # # calculate the mse loss
        
        # # initialize gradients to zero
        # optim.zero_grad()
        # # backpropagate
        # loss.backward()
        # # take a gradient step
        # optim.step()
        # return loss

    @torch.no_grad()
    def get_state_assignment(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        gene_list: Optional[Sequence[str]] = None,
        hard_assignment: bool = False,
        n_samples: int = 20,
        batch_size: Optional[int] = None,
        return_mean: bool = True,
        return_numpy: Optional[bool] = None,
    ) -> Tuple[Union[np.ndarray, pd.DataFrame], List[str]]:
        """
        Returns cells by genes by states probabilities.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        gene_list
            Return frequencies of expression for a subset of genes.
            This can save memory when working with large datasets and few genes are
            of interest.
        hard_assignment
            Return a hard state assignment
        n_samples
            Number of posterior samples to use for estimation.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
        return_mean
            Whether to return the mean of the samples.
        return_numpy
            Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes
            gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`.
            Otherwise, it defaults to `True`.

        Returns
        -------
        If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
        Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
        """
        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(
            adata=adata, indices=indices, batch_size=batch_size
        )

        if gene_list is None:
            gene_mask = slice(None)
        else:
            all_genes = adata.var_names
            gene_mask = [True if gene in gene_list else False for gene in all_genes]

        if n_samples > 1 and return_mean is False:
            if return_numpy is False:
                warnings.warn(
                    "return_numpy must be True if n_samples > 1 and return_mean is False, returning np.ndarray"
                )
            return_numpy = True
        if indices is None:
            indices = np.arange(adata.n_obs)

        states = []
        for tensors in scdl:
            minibatch_samples = []
            for _ in range(n_samples):
                _, generative_outputs = self.module.forward(
                    tensors=tensors,
                    compute_loss=True,
                )
                output = generative_outputs["px_pi"]
                output = output[..., gene_mask, :]
                output = output.cpu().numpy()
                minibatch_samples.append(output)
            # samples by cells by genes by four
            states.append(np.stack(minibatch_samples, axis=0))
            if return_mean:
                states[-1] = np.mean(states[-1], axis=0)

        states = np.concatenate(states, axis=0)
        state_cats = [
            "induction",
            "induction_steady",
            "repression",
            "repression_steady",
        ]
        if hard_assignment and return_mean:
            hard_assign = states.argmax(-1)

            hard_assign = pd.DataFrame(
                data=hard_assign, index=adata.obs_names, columns=adata.var_names
            )
            for i, s in enumerate(state_cats):
                hard_assign = hard_assign.replace(i, s)

            states = hard_assign

        return states, state_cats

    @torch.no_grad()
    def get_latent_time(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        gene_list: Optional[Sequence[str]] = None,
        time_statistic: Literal["mean", "max"] = "mean",
        n_samples: int = 1,
        n_samples_overall: Optional[int] = None,
        batch_size: Optional[int] = None,
        return_mean: bool = True,
        return_numpy: Optional[bool] = None,
    ) -> Union[np.ndarray, pd.DataFrame]:
        """
        Returns the cells by genes latent time.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        gene_list
            Return frequencies of expression for a subset of genes.
            This can save memory when working with large datasets and few genes are
            of interest.
        time_statistic
            Whether to compute expected time over states, or maximum a posteriori time over maximal
            probability state.
        n_samples
            Number of posterior samples to use for estimation.
        n_samples_overall
            Number of overall samples to return. Setting this forces n_samples=1.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
        return_mean
            Whether to return the mean of the samples.
        return_numpy
            Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes
            gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`.
            Otherwise, it defaults to `True`.

        Returns
        -------
        If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
        Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
        """
        adata = self._validate_anndata(adata)
        if indices is None:
            indices = np.arange(adata.n_obs)
        if n_samples_overall is not None:
            indices = np.random.choice(indices, n_samples_overall)
        scdl = self._make_data_loader(
            adata=adata, indices=indices, batch_size=batch_size
        )

        if gene_list is None:
            gene_mask = slice(None)
        else:
            all_genes = adata.var_names
            gene_mask = [True if gene in gene_list else False for gene in all_genes]

        if n_samples > 1 and return_mean is False:
            if return_numpy is False:
                warnings.warn(
                    "return_numpy must be True if n_samples > 1 and return_mean is False, returning np.ndarray"
                )
            return_numpy = True
        if indices is None:
            indices = np.arange(adata.n_obs)

        times = []
        for tensors in scdl:
            minibatch_samples = []
            for _ in range(n_samples):
                _, generative_outputs = self.module.forward(
                    tensors=tensors,
                    compute_loss=False,
                )
                pi = generative_outputs["px_pi"]
                ind_prob = pi[..., 0]
                steady_prob = pi[..., 1]
                rep_prob = pi[..., 2]
                # rep_steady_prob = pi[..., 3]
                switch_time = F.softplus(self.module.switch_time_unconstr)

                ind_time = generative_outputs["px_rho"] * switch_time
                rep_time = switch_time + (
                    generative_outputs["px_tau"] * (self.module.t_max - switch_time)
                )

                if time_statistic == "mean":
                    output = (
                        ind_prob * ind_time
                        + rep_prob * rep_time
                        + steady_prob * switch_time
                        # + rep_steady_prob * self.module.t_max
                    )
                else:
                    t = torch.stack(
                        [
                            ind_time,
                            switch_time.expand(ind_time.shape),
                            rep_time,
                            torch.zeros_like(ind_time),
                        ],
                        dim=2,
                    )
                    max_prob = torch.amax(pi, dim=-1)
                    max_prob = torch.stack([max_prob] * 4, dim=2)
                    max_prob_mask = pi.ge(max_prob)
                    output = (t * max_prob_mask).sum(dim=-1)

                output = output[..., gene_mask]
                output = output.cpu().numpy()
                minibatch_samples.append(output)
            # samples by cells by genes by four
            times.append(np.stack(minibatch_samples, axis=0))
            if return_mean:
                times[-1] = np.mean(times[-1], axis=0)

        if n_samples > 1:
            # The -2 axis correspond to cells.
            times = np.concatenate(times, axis=-2)
        else:
            times = np.concatenate(times, axis=0)

        if return_numpy is None or return_numpy is False:
            return pd.DataFrame(
                times,
                columns=adata.var_names[gene_mask],
                index=adata.obs_names[indices],
            )
        else:
            return times

    @torch.no_grad()
    def get_velocity(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        gene_list: Optional[Sequence[str]] = None,
        n_samples: int = 1,
        n_samples_overall: Optional[int] = None,
        batch_size: Optional[int] = None,
        return_mean: bool = True,
        return_numpy: Optional[bool] = None,
        velo_statistic: str = "mean",
        velo_mode: Literal["spliced", "unspliced"] = "spliced",
        clip: bool = True,
    ) -> Union[np.ndarray, pd.DataFrame]:
        """
        Returns cells by genes velocity estimates.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        gene_list
            Return velocities for a subset of genes.
            This can save memory when working with large datasets and few genes are
            of interest.
        n_samples
            Number of posterior samples to use for estimation for each cell.
        n_samples_overall
            Number of overall samples to return. Setting this forces n_samples=1.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
        return_mean
            Whether to return the mean of the samples.
        return_numpy
            Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes
            gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`.
            Otherwise, it defaults to `True`.
        velo_statistic
            Whether to compute expected velocity over states, or maximum a posteriori velocity over maximal
            probability state.
        velo_mode
            Compute ds/dt or du/dt.
        clip
            Clip to minus spliced value

        Returns
        -------
        If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
        Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
        """
        adata = self._validate_anndata(adata)
        if indices is None:
            indices = np.arange(adata.n_obs)
        if n_samples_overall is not None:
            indices = np.random.choice(indices, n_samples_overall)
            n_samples = 1
        scdl = self._make_data_loader(
            adata=adata, indices=indices, batch_size=batch_size
        )

        if gene_list is None:
            gene_mask = slice(None)
        else:
            all_genes = adata.var_names
            gene_mask = [True if gene in gene_list else False for gene in all_genes]

        if n_samples > 1 and return_mean is False:
            if return_numpy is False:
                warnings.warn(
                    "return_numpy must be True if n_samples > 1 and return_mean is False, returning np.ndarray"
                )
            return_numpy = True
        if indices is None:
            indices = np.arange(adata.n_obs)

        velos = []
        for tensors in scdl:
            minibatch_samples = []
            for _ in range(n_samples):
                inference_outputs, generative_outputs = self.module.forward(
                    tensors=tensors,
                    compute_loss=False,
                )
                pi = generative_outputs["px_pi"]
                alpha = inference_outputs["alpha"]
                alpha_1 = inference_outputs["alpha_1"]
                lambda_alpha = inference_outputs["lambda_alpha"]
                beta = inference_outputs["beta"]
                gamma = inference_outputs["gamma"]
                tau = generative_outputs["px_tau"]
                rho = generative_outputs["px_rho"]

                ind_prob = pi[..., 0]
                steady_prob = pi[..., 1]
                rep_prob = pi[..., 2]
                switch_time = F.softplus(self.module.switch_time_unconstr)

                ind_time = switch_time * rho
                u_0, s_0 = self.module._get_induction_unspliced_spliced(
                    alpha, alpha_1, lambda_alpha, beta, gamma, switch_time
                )
                rep_time = (self.module.t_max - switch_time) * tau
                mean_u_rep, mean_s_rep = self.module._get_repression_unspliced_spliced(
                    u_0,
                    s_0,
                    beta,
                    gamma,
                    rep_time,
                )
                if velo_mode == "spliced":
                    velo_rep = beta * mean_u_rep - gamma * mean_s_rep
                else:
                    velo_rep = -beta * mean_u_rep
                mean_u_ind, mean_s_ind = self.module._get_induction_unspliced_spliced(
                    alpha, alpha_1, lambda_alpha, beta, gamma, ind_time
                )
                if velo_mode == "spliced":
                    velo_ind = beta * mean_u_ind - gamma * mean_s_ind
                else:
                    transcription_rate = alpha_1 - (alpha_1 - alpha) * torch.exp(
                        -lambda_alpha * ind_time
                    )
                    velo_ind = transcription_rate - beta * mean_u_ind

                if velo_mode == "spliced":
                    # velo_steady = beta * u_0 - gamma * s_0
                    velo_steady = torch.zeros_like(velo_ind)
                else:
                    # velo_steady = alpha - beta * u_0
                    velo_steady = torch.zeros_like(velo_ind)

                # expectation
                if velo_statistic == "mean":
                    output = (
                        ind_prob * velo_ind
                        + rep_prob * velo_rep
                        + steady_prob * velo_steady
                    )
                # maximum
                else:
                    v = torch.stack(
                        [
                            velo_ind,
                            velo_steady.expand(velo_ind.shape),
                            velo_rep,
                            torch.zeros_like(velo_rep),
                        ],
                        dim=2,
                    )
                    max_prob = torch.amax(pi, dim=-1)
                    max_prob = torch.stack([max_prob] * 4, dim=2)
                    max_prob_mask = pi.ge(max_prob)
                    output = (v * max_prob_mask).sum(dim=-1)

                output = output[..., gene_mask]
                output = output.cpu().numpy()
                minibatch_samples.append(output)
            # samples by cells by genes
            velos.append(np.stack(minibatch_samples, axis=0))
            if return_mean:
                # mean over samples axis
                velos[-1] = np.mean(velos[-1], axis=0)

        if n_samples > 1:
            # The -2 axis correspond to cells.
            velos = np.concatenate(velos, axis=-2)
        else:
            velos = np.concatenate(velos, axis=0)

        spliced = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)

        if clip:
            velos = np.clip(velos, -spliced[indices], None)

        if return_numpy is None or return_numpy is False:
            return pd.DataFrame(
                velos,
                columns=adata.var_names[gene_mask],
                index=adata.obs_names[indices],
            )
        else:
            return velos

    @torch.no_grad()
    def get_velocity_from_latent(
        self,
        latent_representation: np.ndarray,
        return_numpy: Optional[bool] = None,
        velo_statistic: str = "mean",
        velo_mode: Literal["spliced", "unspliced"] = "spliced",
        clip: bool = True,
    ) -> Union[np.ndarray, pd.DataFrame]:
        r"""
        Returns the normalized (decoded) gene expression.

        This is denoted as :math:`\rho_n` in the scVI paper.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        return_numpy
            Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes
            gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`.
            Otherwise, it defaults to `True`.
        clip
            Clip to minus spliced value

        Returns
        -------
        If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
        Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
        """
        adata = AnnData(latent_representation)
        data_key = "Z"
        manager = AnnDataManager(
            [LayerField(data_key, layer=None, is_count_data=False)]
        )
        manager.register_fields(adata)
        scdl = AnnDataLoader(manager)

        gamma, beta, alpha, alpha_1, lambda_alpha = self.module._get_rates()

        velos = []
        for tensors in scdl:
            z = tensors[data_key]
            generative_outputs = self.module.generative(
                z=z,
                gamma=gamma,
                beta=beta,
                alpha=alpha,
                alpha_1=alpha_1,
                lambda_alpha=lambda_alpha,
            )
            pi = generative_outputs["px_pi"]
            tau = generative_outputs["px_tau"]
            rho = generative_outputs["px_rho"]

            ind_prob = pi[..., 0]
            steady_prob = pi[..., 1]
            rep_prob = pi[..., 2]
            switch_time = F.softplus(self.module.switch_time_unconstr)

            ind_time = switch_time * rho
            u_0, s_0 = self.module._get_induction_unspliced_spliced(
                alpha, alpha_1, lambda_alpha, beta, gamma, switch_time
            )
            rep_time = (self.module.t_max - switch_time) * tau
            mean_u_rep, mean_s_rep = self.module._get_repression_unspliced_spliced(
                u_0,
                s_0,
                beta,
                gamma,
                rep_time,
            )
            if velo_mode == "spliced":
                velo_rep = beta * mean_u_rep - gamma * mean_s_rep
            else:
                velo_rep = -beta * mean_u_rep
            mean_u_ind, mean_s_ind = self.module._get_induction_unspliced_spliced(
                alpha, alpha_1, lambda_alpha, beta, gamma, ind_time
            )
            if velo_mode == "spliced":
                velo_ind = beta * mean_u_ind - gamma * mean_s_ind
            else:
                transcription_rate = alpha_1 - (alpha_1 - alpha) * torch.exp(
                    -lambda_alpha * ind_time
                )
                velo_ind = transcription_rate - beta * mean_u_ind

            if velo_mode == "spliced":
                # velo_steady = beta * u_0 - gamma * s_0
                velo_steady = torch.zeros_like(velo_ind)
            else:
                # velo_steady = alpha - beta * u_0
                velo_steady = torch.zeros_like(velo_ind)

            # expectation
            if velo_statistic == "mean":
                output = (
                    ind_prob * velo_ind
                    + rep_prob * velo_rep
                    + steady_prob * velo_steady
                )
            # maximum
            else:
                v = torch.stack(
                    [
                        velo_ind,
                        velo_steady.expand(velo_ind.shape),
                        velo_rep,
                        torch.zeros_like(velo_rep),
                    ],
                    dim=2,
                )
                max_prob = torch.amax(pi, dim=-1)
                max_prob = torch.stack([max_prob] * 4, dim=2)
                max_prob_mask = pi.ge(max_prob)
                output = (v * max_prob_mask).sum(dim=-1)

            # samples by cells by genes
            velos.append(output.cpu().numpy())

        velos = np.concatenate(velos, axis=0)

        if return_numpy is None or return_numpy is False:
            return pd.DataFrame(
                velos,
                columns=self.adata.var_names,
            )
        else:
            return velos

    @torch.no_grad()
    def get_expression_fit(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        gene_list: Optional[Sequence[str]] = None,
        n_samples: int = 1,
        batch_size: Optional[int] = None,
        return_mean: bool = True,
        return_numpy: Optional[bool] = None,
        restrict_to_latent_dim: Optional[int] = None,
    ) -> Union[np.ndarray, pd.DataFrame]:
        r"""
        Returns the fitted spliced and unspliced abundance (s(t) and u(t)).

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        gene_list
            Return frequencies of expression for a subset of genes.
            This can save memory when working with large datasets and few genes are
            of interest.
        n_samples
            Number of posterior samples to use for estimation.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
        return_mean
            Whether to return the mean of the samples.
        return_numpy
            Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes
            gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`.
            Otherwise, it defaults to `True`.

        Returns
        -------
        If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
        Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
        """
        adata = self._validate_anndata(adata)

        scdl = self._make_data_loader(
            adata=adata, indices=indices, batch_size=batch_size
        )

        if gene_list is None:
            gene_mask = slice(None)
        else:
            all_genes = adata.var_names
            gene_mask = [True if gene in gene_list else False for gene in all_genes]

        if n_samples > 1 and return_mean is False:
            if return_numpy is False:
                warnings.warn(
                    "return_numpy must be True if n_samples > 1 and return_mean is False, returning np.ndarray"
                )
            return_numpy = True
        if indices is None:
            indices = np.arange(adata.n_obs)

        fits_s = []
        fits_u = []
        for tensors in scdl:
            minibatch_samples_s = []
            minibatch_samples_u = []
            for _ in range(n_samples):
                inference_outputs, generative_outputs = self.module.forward(
                    tensors=tensors,
                    compute_loss=False,
                    generative_kwargs=dict(latent_dim=restrict_to_latent_dim),
                )

                gamma = inference_outputs["gamma"]
                beta = inference_outputs["beta"]
                alpha = inference_outputs["alpha"]
                alpha_1 = inference_outputs["alpha_1"]
                lambda_alpha = inference_outputs["lambda_alpha"]
                px_pi = generative_outputs["px_pi"]
                scale = generative_outputs["scale"]
                px_rho = generative_outputs["px_rho"]
                px_tau = generative_outputs["px_tau"]

                (mixture_dist_s, mixture_dist_u, _,) = self.module.get_px(
                    px_pi,
                    px_rho,
                    px_tau,
                    scale,
                    gamma,
                    beta,
                    alpha,
                    alpha_1,
                    lambda_alpha,
                )
                fit_s = mixture_dist_s.mean
                fit_u = mixture_dist_u.mean

                fit_s = fit_s[..., gene_mask]
                fit_s = fit_s.cpu().numpy()
                fit_u = fit_u[..., gene_mask]
                fit_u = fit_u.cpu().numpy()

                minibatch_samples_s.append(fit_s)
                minibatch_samples_u.append(fit_u)

            # samples by cells by genes
            fits_s.append(np.stack(minibatch_samples_s, axis=0))
            if return_mean:
                # mean over samples axis
                fits_s[-1] = np.mean(fits_s[-1], axis=0)
            # samples by cells by genes
            fits_u.append(np.stack(minibatch_samples_u, axis=0))
            if return_mean:
                # mean over samples axis
                fits_u[-1] = np.mean(fits_u[-1], axis=0)

        if n_samples > 1:
            # The -2 axis correspond to cells.
            fits_s = np.concatenate(fits_s, axis=-2)
            fits_u = np.concatenate(fits_u, axis=-2)
        else:
            fits_s = np.concatenate(fits_s, axis=0)
            fits_u = np.concatenate(fits_u, axis=0)

        if return_numpy is None or return_numpy is False:
            df_s = pd.DataFrame(
                fits_s,
                columns=adata.var_names[gene_mask],
                index=adata.obs_names[indices],
            )
            df_u = pd.DataFrame(
                fits_u,
                columns=adata.var_names[gene_mask],
                index=adata.obs_names[indices],
            )
            return df_s, df_u
        else:
            return fits_s, fits_u

    @torch.no_grad()
    def get_gene_likelihood(
        self,
        adata: Optional[AnnData] = None,
        indices: Optional[Sequence[int]] = None,
        gene_list: Optional[Sequence[str]] = None,
        n_samples: int = 1,
        batch_size: Optional[int] = None,
        return_mean: bool = True,
        return_numpy: Optional[bool] = None,
    ) -> Union[np.ndarray, pd.DataFrame]:
        r"""
        Returns the likelihood per gene. Higher is better.

        This is denoted as :math:`\rho_n` in the scVI paper.

        Parameters
        ----------
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.
        indices
            Indices of cells in adata to use. If `None`, all cells are used.
        transform_batch
            Batch to condition on.
            If transform_batch is:

            - None, then real observed batch is used.
            - int, then batch transform_batch is used.
        gene_list
            Return frequencies of expression for a subset of genes.
            This can save memory when working with large datasets and few genes are
            of interest.
        library_size
            Scale the expression frequencies to a common library size.
            This allows gene expression levels to be interpreted on a common scale of relevant
            magnitude. If set to `"latent"`, use the latent libary size.
        n_samples
            Number of posterior samples to use for estimation.
        batch_size
            Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
        return_mean
            Whether to return the mean of the samples.
        return_numpy
            Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes
            gene names as columns. If either `n_samples=1` or `return_mean=True`, defaults to `False`.
            Otherwise, it defaults to `True`.

        Returns
        -------
        If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
        Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
        """
        adata = self._validate_anndata(adata)
        scdl = self._make_data_loader(
            adata=adata, indices=indices, batch_size=batch_size
        )

        if gene_list is None:
            gene_mask = slice(None)
        else:
            all_genes = adata.var_names
            gene_mask = [True if gene in gene_list else False for gene in all_genes]

        if n_samples > 1 and return_mean is False:
            if return_numpy is False:
                warnings.warn(
                    "return_numpy must be True if n_samples > 1 and return_mean is False, returning np.ndarray"
                )
            return_numpy = True
        if indices is None:
            indices = np.arange(adata.n_obs)

        rls = []
        for tensors in scdl:
            minibatch_samples = []
            for _ in range(n_samples):
                inference_outputs, generative_outputs = self.module.forward(
                    tensors=tensors,
                    compute_loss=False,
                )
                spliced = tensors[REGISTRY_KEYS.X_KEY]
                unspliced = tensors[REGISTRY_KEYS.U_KEY]

                gamma = inference_outputs["gamma"]
                beta = inference_outputs["beta"]
                alpha = inference_outputs["alpha"]
                alpha_1 = inference_outputs["alpha_1"]
                lambda_alpha = inference_outputs["lambda_alpha"]
                px_pi = generative_outputs["px_pi"]
                scale = generative_outputs["scale"]
                px_rho = generative_outputs["px_rho"]
                px_tau = generative_outputs["px_tau"]
                dec_mean = generative_outputs["gene_recon"]
                dec_lat = generative_outputs["dec_latent"]

                (mixture_dist_s, mixture_dist_u, _,) = self.module.get_px(
                    px_pi,
                    px_rho,
                    px_tau,
                    scale,
                    gamma,
                    beta,
                    alpha,
                    alpha_1,
                    lambda_alpha,
                )
                reconst_loss_s = -mixture_dist_s.log_prob(spliced)
                reconst_loss_u = -mixture_dist_u.log_prob(unspliced)
                output = -(reconst_loss_s + reconst_loss_u)
                output = output[..., gene_mask]
                output = output.cpu().numpy()
                minibatch_samples.append(output)
            # samples by cells by genes by four
            rls.append(np.stack(minibatch_samples, axis=0))
            if return_mean:
                rls[-1] = np.mean(rls[-1], axis=0)

        rls = np.concatenate(rls, axis=0)
        return rls.shape, dec_mean.shape, dec_lat.shape

    @torch.no_grad()
    def get_rates(self, mean: bool = True):

        gamma, beta, alpha, alpha_1, lambda_alpha = self.module._get_rates()

        return {
            "beta": beta.cpu().numpy(),
            "gamma": gamma.cpu().numpy(),
            "alpha": alpha.cpu().numpy(),
            "alpha_1": alpha_1.cpu().numpy(),
            "lambda_alpha": lambda_alpha.cpu().numpy(),
        }

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        spliced_layer: str,
        unspliced_layer: str,
        **kwargs,
    ) -> Optional[AnnData]:
        """
        %(summary)s.
        Parameters
        ----------
        %(param_adata)s
        spliced_layer
            Layer in adata with spliced normalized expression
        unspliced_layer
            Layer in adata with unspliced normalized expression

        Returns
        -------
        %(returns)s
        """
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, spliced_layer, is_count_data=False),
            LayerField(REGISTRY_KEYS.U_KEY, unspliced_layer, is_count_data=False),
        ]
        adata_manager = AnnDataManager(
            fields=anndata_fields, setup_method_args=setup_method_args
        )
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

    @torch.no_grad()
    @_doc_params(
        doc_differential_expression=doc_differential_expression,
    )
    def differential_velocity(
        self,
        adata: Optional[AnnData] = None,
        groupby: Optional[str] = None,
        group1: Optional[Iterable[str]] = None,
        group2: Optional[str] = None,
        idx1: Optional[Union[Sequence[int], Sequence[bool], str]] = None,
        idx2: Optional[Union[Sequence[int], Sequence[bool], str]] = None,
        mode: Literal["vanilla", "change"] = "vanilla",
        delta: float = 0.25,
        batch_size: Optional[int] = None,
        all_stats: bool = True,
        batch_correction: bool = False,
        batchid1: Optional[Iterable[str]] = None,
        batchid2: Optional[Iterable[str]] = None,
        fdr_target: float = 0.05,
        silent: bool = False,
        **kwargs,
    ) -> pd.DataFrame:
        r"""
        A unified method for differential velocity analysis.

        Implements `"vanilla"` DE [Lopez18]_ and `"change"` mode DE [Boyeau19]_.

        Parameters
        ----------
        {doc_differential_expression}
        **kwargs
            Keyword args for :meth:`scvi.model.base.DifferentialComputation.get_bayes_factors`

        Returns
        -------
        Differential expression DataFrame.
        """
        adata = self._validate_anndata(adata)

        def model_fn(adata, **kwargs):
            if "transform_batch" in kwargs.keys():
                kwargs.pop("transform_batch")
            return partial(
                self.get_velocity,
                batch_size=batch_size,
                n_samples=1,
                return_numpy=True,
                clip=False,
            )(adata, **kwargs)

        col_names = adata.var_names

        result = _de_core(
            self.get_anndata_manager(adata, required=True),
            model_fn,
            groupby,
            group1,
            group2,
            idx1,
            idx2,
            all_stats,
            scrna_raw_counts_properties,
            col_names,
            mode,
            batchid1,
            batchid2,
            delta,
            batch_correction,
            fdr_target,
            silent,
            **kwargs,
        )

        return result

    @torch.no_grad()
    def differential_transition(
        self,
        groupby: str,
        group1: str,
        group2: str,
        adata: Optional[AnnData] = None,
        batch_size: Optional[int] = None,
        n_samples: Optional[int] = 5000,
    ) -> pd.DataFrame:
        adata = self._validate_anndata(adata)
        adata_manager = self.get_anndata_manager(adata, required=True)

        if not isinstance(group1, str):
            raise ValueError("Group 1 must be a string")

        cell_idx1 = (adata.obs[groupby] == group1).to_numpy().ravel()
        if group2 is None:
            cell_idx2 = ~cell_idx1
        else:
            cell_idx2 = (adata.obs[groupby] == group2).to_numpy().ravel()

        indices1 = np.random.choice(
            np.asarray(np.where(cell_idx1)[0].ravel()), n_samples
        )
        indices2 = np.random.choice(
            np.asarray(np.where(cell_idx2)[0].ravel()), n_samples
        )

        velo1 = self.get_velocity(
            adata,
            return_numpy=True,
            indices=indices1,
            n_samples=1,
            batch_size=batch_size,
        )
        velo1 = velo1 - velo1.mean(1)[:, np.newaxis]
        velo2 = self.get_velocity(
            adata,
            return_numpy=True,
            indices=indices2,
            n_samples=1,
            batch_size=batch_size,
        )
        velo2 = velo2 - velo2.mean(1)[:, np.newaxis]

        spliced = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
        delta12 = spliced[indices2] - spliced[indices1]
        delta12 = delta12 - delta12.mean(1)[:, np.newaxis]

        delta21 = spliced[indices1] - spliced[indices2]
        delta21 = delta21 - delta21.mean(1)[:, np.newaxis]

        # TODO: Make more efficient
        correlation12 = np.diagonal(cosine_similarity(velo1, delta12))
        correlation21 = np.diagonal(cosine_similarity(velo2, delta21))

        return correlation12, correlation21

    def get_loadings(self) -> pd.DataFrame:
        """
        Extract per-gene weights in the linear decoder.

        Shape is genes by `n_latent`.
        """
        cols = ["Z_{}".format(i) for i in range(self.n_latent)]
        var_names = self.adata.var_names
        loadings = pd.DataFrame(
            self.module.get_loadings(), index=var_names, columns=cols
        )

        return loadings

    def get_variance_explained(
        self,
        adata: Optional[AnnData] = None,
        labels_key: Optional[str] = None,
        n_samples: int = 10,
    ) -> pd.DataFrame:

        if self.module.decoder.linear_decoder is False:
            raise ValueError("Model not trained with linear decoder")
        adata = self._validate_anndata(adata)
        adata_manager = self.get_anndata_manager(adata)
        n_latent = self.module.n_latent

        if labels_key is not None:
            groups = np.unique(adata.obs[labels_key])
        else:
            groups = [None]

        spliced = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
        unspliced = adata_manager.get_from_registry(REGISTRY_KEYS.U_KEY)

        centered_s = spliced - spliced.mean(0)
        centered_u = unspliced - unspliced.mean(0)

        def r_squared(true_s, pred_s, true_u, pred_u, centered_s, centered_u):
            rss_s = np.sum((true_s - pred_s) ** 2)
            tss_s = np.sum(centered_s**2)
            rss_u = np.sum((true_u - pred_u) ** 2)
            tss_u = np.sum(centered_u**2)

            return (1 - (rss_s + rss_u) / (tss_s + tss_u)) * 100

        df_out = pd.DataFrame(
            data=np.zeros((n_latent, len(groups))),
            index=[f"Z_{i}" for i in range(n_latent)],
            columns=groups if groups[0] is not None else ["0"],
        )

        for i in range(n_latent):
            fitted_s, fitted_u = self.get_expression_fit(
                adata, restrict_to_latent_dim=i, n_samples=n_samples, return_numpy=True
            )
            for j, g in enumerate(groups):
                if g is None:
                    subset = slice(None)
                else:
                    subset = adata.obs[labels_key] == g
                r_2 = r_squared(
                    spliced[subset],
                    fitted_s[subset],
                    unspliced[subset],
                    fitted_u[subset],
                    centered_s[subset],
                    centered_u[subset],
                )
                df_out.iloc[i, j] = r_2

        return df_out

    def get_directional_uncertainty(
        self,
        adata: Optional[AnnData] = None,
        n_samples: int = 50,
        gene_list: Iterable[str] = None,
        n_jobs: int = -1,
    ):

        adata = self._validate_anndata(adata)

        logger.info("Sampling from model...")
        velocities_all = self.get_velocity(
            n_samples=n_samples, return_mean=False, gene_list=gene_list
        )  # (n_samples, n_cells, n_genes)

        df, cosine_sims = _compute_directional_statistics_tensor(
            tensor=velocities_all, n_jobs=n_jobs, n_cells=adata.n_obs
        )
        df.index = adata.obs_names

        return df, cosine_sims

    def get_permutation_scores(
        self, labels_key: str, adata: Optional[AnnData] = None
    ) -> Tuple[pd.DataFrame, AnnData]:
        """
        Compute permutation scores.

        Parameters
        ----------
        labels_key
            Key in adata.obs encoding cell types
        adata
            AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
            AnnData object used to initialize the model.

        Returns
        -------
        Tuple of DataFrame and AnnData. DataFrame is genes by cell types with score per cell type.
        AnnData is the permutated version of the original AnnData.
        """
        adata = self._validate_anndata(adata)
        adata_manager = self.get_anndata_manager(adata)
        if labels_key not in adata.obs:
            raise ValueError(f"{labels_key} not found in adata.obs")

        # shuffle spliced then unspliced
        bdata = self._shuffle_layer_celltype(
            adata_manager, labels_key, REGISTRY_KEYS.X_KEY
        )
        bdata_manager = self.get_anndata_manager(bdata)
        bdata = self._shuffle_layer_celltype(
            bdata_manager, labels_key, REGISTRY_KEYS.U_KEY
        )
        bdata_manager = self.get_anndata_manager(bdata)

        ms_ = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
        mu_ = adata_manager.get_from_registry(REGISTRY_KEYS.U_KEY)

        ms_p = bdata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
        mu_p = bdata_manager.get_from_registry(REGISTRY_KEYS.U_KEY)

        spliced_, unspliced_ = self.get_expression_fit(adata, n_samples=10)
        root_squared_error = np.abs(spliced_ - ms_)
        root_squared_error += np.abs(unspliced_ - mu_)

        spliced_p, unspliced_p = self.get_expression_fit(bdata, n_samples=10)
        root_squared_error_p = np.abs(spliced_p - ms_p)
        root_squared_error_p += np.abs(unspliced_p - mu_p)

        celltypes = np.unique(adata.obs[labels_key])

        dynamical_df = pd.DataFrame(
            index=adata.var_names,
            columns=celltypes,
            data=np.zeros((adata.shape[1], len(celltypes))),
        )
        N = 200
        for ct in celltypes:
            for g in adata.var_names.tolist():
                x = root_squared_error_p[g][adata.obs[labels_key] == ct]
                y = root_squared_error[g][adata.obs[labels_key] == ct]
                ratio = ttest_ind(x[:N], y[:N])[0]
                dynamical_df.loc[g, ct] = ratio

        return dynamical_df, bdata

    def _shuffle_layer_celltype(
        self, adata_manager: AnnDataManager, labels_key: str, registry_key: str
    ) -> AnnData:
        """Shuffle cells within cell types for each gene."""
        from scvi.data._constants import _SCVI_UUID_KEY

        bdata = adata_manager.adata.copy()
        labels = bdata.obs[labels_key]
        del bdata.uns[_SCVI_UUID_KEY]
        self._validate_anndata(bdata)
        bdata_manager = self.get_anndata_manager(bdata)

        # get registry info to later set data back in bdata
        # in a way that doesn't require actual knowledge of location
        unspliced = bdata_manager.get_from_registry(registry_key)
        u_registry = bdata_manager.data_registry[registry_key]
        attr_name = u_registry.attr_name
        attr_key = u_registry.attr_key

        for lab in np.unique(labels):
            mask = np.asarray(labels == lab)
            unspliced_ct = unspliced[mask].copy()
            unspliced_ct = np.apply_along_axis(
                np.random.permutation, axis=0, arr=unspliced_ct
            )
            unspliced[mask] = unspliced_ct
        # e.g., if using adata.X
        if attr_key is None:
            setattr(bdata, attr_name, unspliced)
        # e.g., if using a layer
        elif attr_key is not None:
            attribute = getattr(bdata, attr_name)
            attribute[attr_key] = unspliced
            setattr(bdata, attr_name, attribute)

        return bdata


def _compute_directional_statistics_tensor(
    tensor: np.ndarray, n_jobs: int, n_cells: int
) -> pd.DataFrame:
    df = pd.DataFrame(index=np.arange(n_cells))
    df["directional_variance"] = np.nan
    df["directional_difference"] = np.nan
    df["directional_cosine_sim_variance"] = np.nan
    df["directional_cosine_sim_difference"] = np.nan
    df["directional_cosine_sim_mean"] = np.nan
    logger.info("Computing the uncertainties...")
    results = Parallel(n_jobs=n_jobs, verbose=3)(
        delayed(_directional_statistics_per_cell)(tensor[:, cell_index, :])
        for cell_index in range(n_cells)
    )
    # cells by samples
    cosine_sims = np.stack([results[i][0] for i in range(n_cells)])
    df.loc[:, "directional_cosine_sim_variance"] = [
        results[i][1] for i in range(n_cells)
    ]
    df.loc[:, "directional_cosine_sim_difference"] = [
        results[i][2] for i in range(n_cells)
    ]
    df.loc[:, "directional_variance"] = [results[i][3] for i in range(n_cells)]
    df.loc[:, "directional_difference"] = [results[i][4] for i in range(n_cells)]
    df.loc[:, "directional_cosine_sim_mean"] = [results[i][5] for i in range(n_cells)]

    return df, cosine_sims


def _directional_statistics_per_cell(
    tensor: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Internal function for parallelization.

    Parameters
    ----------
    tensor
        Shape of samples by genes for a given cell.
    """
    n_samples = tensor.shape[0]
    # over samples axis
    mean_velocity_of_cell = tensor.mean(0)
    cosine_sims = [
        _cosine_sim(tensor[i, :], mean_velocity_of_cell) for i in range(n_samples)
    ]
    angle_samples = [np.arccos(el) for el in cosine_sims]
    return (
        cosine_sims,
        np.var(cosine_sims),
        np.percentile(cosine_sims, 95) - np.percentile(cosine_sims, 5),
        np.var(angle_samples),
        np.percentile(angle_samples, 95) - np.percentile(angle_samples, 5),
        np.mean(cosine_sims),
    )


def _centered_unit_vector(vector: np.ndarray) -> np.ndarray:
    """Returns the centered unit vector of the vector."""
    vector = vector - np.mean(vector)
    return vector / np.linalg.norm(vector)


def _cosine_sim(v1: np.ndarray, v2: np.ndarray) -> np.ndarray:
    """Returns cosine similarity of the vectors."""
    v1_u = _centered_unit_vector(v1)
    v2_u = _centered_unit_vector(v2)
    return np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)



In [30]:
from scvi import train

class ProxGroupLasso:
    def __init__(self, alpha, omega=None, inplace=True):
    # omega - vector of coefficients with size
    # equal to the number of groups
        if omega is None:
            self._group_coeff = alpha
        else:
            self._group_coeff = (omega*alpha).view(-1)

        # to check for update
        self._alpha = alpha

        self._inplace = inplace

    def __call__(self, W):
        if not self._inplace:
            W = W.clone()

        norm_vect = W.norm(p=2, dim=0)
        norm_g_gr_vect = norm_vect>self._group_coeff

        scaled_norm_vector = norm_vect/self._group_coeff
        scaled_norm_vector+=(~(scaled_norm_vector>0)).float()

        W-=W/scaled_norm_vector
        W*=norm_g_gr_vect.float()

        return W


class ProxL1:
    def __init__(self, alpha, I=None, inplace=True):
        self._I = ~I.bool() if I is not None else None
        self._alpha=alpha
        self._inplace=inplace

    def __call__(self, W):
        if not self._inplace:
            W = W.clone()

        W_geq_alpha = W>=self._alpha
        W_leq_neg_alpha = W<=-self._alpha
        W_cond_joint = ~W_geq_alpha&~W_leq_neg_alpha

        if self._I is not None:
            W_geq_alpha &= self._I
            W_leq_neg_alpha &= self._I
            W_cond_joint &= self._I

        W -= W_geq_alpha.float()*self._alpha
        W += W_leq_neg_alpha.float()*self._alpha
        W -= W_cond_joint.float()*W

        return W

class CustomTrainingPlan(train.TrainingPlan):
    def __init__(self, 
            model,
            alpha,
            omega=None,
            alpha_l1=None,
            alpha_l1_epoch_anneal=None,
            alpha_l1_anneal_each=5,
            gamma_ext=None,
            gamma_epoch_anneal=None,
            gamma_anneal_each=5,
            beta=1.,
            print_stats=False,
            **kwargs):
        super().__init__(model, **kwargs)

        self.model=model
        self.print_stats = print_stats

        self.alpha = alpha
        self.omega = omega

        if self.omega is not None:
            self.omega = self.omega.to(self.device)

        # self.gamma_ext = gamma_ext
        # self.gamma_epoch_anneal = gamma_epoch_anneal
        # self.gamma_anneal_each = gamma_anneal_each

        self.alpha_l1 = alpha_l1
        self.alpha_l1_epoch_anneal = alpha_l1_epoch_anneal
        self.alpha_l1_anneal_each = alpha_l1_anneal_each

        # if self.model.use_hsic:
        #     self.beta = beta
        # else:
        #     self.beta = None

        self.watch_lr = None

        self.use_prox_ops = self.check_prox_ops()
        self.prox_ops = {}

        self.corr_coeffs = self.init_anneal()

        print(f"init corr coeffs: {self.corr_coeffs}")

    def check_prox_ops(self):
        use_prox_ops = {}

        use_main = self.model.decoder.L0.expr_L.weight.requires_grad

        use_prox_ops['main_group_lasso'] = use_main and self.alpha is not None

        use_mask = use_main and self.model.mask is not None
        use_prox_ops['main_soft_mask'] = use_mask and self.alpha_l1 is not None

        # use_ext_m = self.model.n_ext_m_decoder > 0 and self.alpha_l1 is not None
        # use_ext_m = use_ext_m and self.model.decoder.L0.ext_L_m.weight.requires_grad
        # use_prox_ops['ext_soft_mask'] = use_ext_m and self.model.ext_mask is not None

        return use_prox_ops

    def init_anneal(self):
        corr_coeffs = {}

        use_soft_mask = self.use_prox_ops['main_soft_mask'] #or self.use_prox_ops['ext_soft_mask']
        if use_soft_mask and self.alpha_l1_epoch_anneal is not None:
            corr_coeffs['alpha_l1'] = 1. / self.alpha_l1_epoch_anneal
        else:
            corr_coeffs['alpha_l1'] = 1.

        # if self.use_prox_ops['ext_unannot_l1'] and self.gamma_epoch_anneal is not None:
        #     corr_coeffs['gamma_ext'] = 1. / self.gamma_epoch_anneal
        # else:
        #     corr_coeffs['gamma_ext'] = 1.

        return corr_coeffs

    def anneal(self):
        any_change = False

        # if self.corr_coeffs['gamma_ext'] < 1.:
        #     any_change = True
        #     time_to_anneal = self.epoch > 0 and self.epoch % self.gamma_anneal_each == 0
        #     if time_to_anneal:
        #         self.corr_coeffs['gamma_ext'] = min(self.epoch / self.gamma_epoch_anneal, 1.)
        #         if self.print_stats:
        #             print('New gamma_ext anneal coefficient:', self.corr_coeffs['gamma_ext'])

        if self.corr_coeffs['alpha_l1'] < 1.:
            any_change = True
            time_to_anneal = self.epoch > 0 and self.epoch % self.self.alpha_l1_anneal_each == 0
            if time_to_anneal:
                self.corr_coeffs['alpha_l1'] = min(self.epoch / self.alpha_l1_epoch_anneal, 1.)
                if self.print_stats:
                    print('New alpha_l1 anneal coefficient:', self.corr_coeffs['alpha_l1'])

        return any_change

    def init_prox_ops(self):
        if any(self.use_prox_ops.values()) and self.watch_lr is None:
            self.watch_lr = self.optimizer.param_groups[0]['lr']

        if 'main_group_lasso' not in self.prox_ops and self.use_prox_ops['main_group_lasso']:
            print('Init the group lasso proximal operator for the main terms.')
            alpha_corr = self.alpha * self.watch_lr
            self.prox_ops['main_group_lasso'] = ProxGroupLasso(alpha_corr, self.omega)

        if 'main_soft_mask' not in self.prox_ops and self.use_prox_ops['main_soft_mask']:
            print('Init the soft mask proximal operator for the main terms.')
            main_mask = self.model.mask.to(self.device)
            alpha_l1_corr = self.alpha_l1 * self.watch_lr * self.corr_coeffs['alpha_l1']
            self.prox_ops['main_soft_mask'] = ProxL1(alpha_l1_corr, main_mask)

        # if 'ext_unannot_l1' not in self.prox_ops and self.use_prox_ops['ext_unannot_l1']:
        #     print('Init the L1 proximal operator for the unannotated extension.')
        #     gamma_ext_corr = self.gamma_ext * self.watch_lr * self.corr_coeffs['gamma_ext']
        #     self.prox_ops['ext_unannot_l1'] = ProxL1(gamma_ext_corr)

        # if 'ext_soft_mask' not in self.prox_ops and self.use_prox_ops['ext_soft_mask']:
        #     print('Init the soft mask proximal operator for the annotated extension.')
        #     ext_mask = self.model.ext_mask.to(self.device)
        #     alpha_l1_corr = self.alpha_l1 * self.watch_lr * self.corr_coeffs['alpha_l1']
        #     self.prox_ops['ext_soft_mask'] = ProxL1(alpha_l1_corr, ext_mask)

    def update_prox_ops(self):
        if 'main_group_lasso' in self.prox_ops:
            alpha_corr = self.alpha * self.watch_lr
            if self.prox_ops['main_group_lasso']._alpha != alpha_corr:
                self.prox_ops['main_group_lasso'] = ProxGroupLasso(alpha_corr, self.omega)

        # if 'ext_unannot_l1' in self.prox_ops:
        #     gamma_ext_corr = self.gamma_ext * self.watch_lr * self.corr_coeffs['gamma_ext']
        #     if self.prox_ops['ext_unannot_l1']._alpha != gamma_ext_corr:
        #         self.prox_ops['ext_unannot_l1']._alpha = gamma_ext_corr

        for mask_key in ('main_soft_mask'):#, 'ext_soft_mask'):
            if mask_key in self.prox_ops:
                alpha_l1_corr = self.alpha_l1 * self.watch_lr * self.corr_coeffs['alpha_l1']
                if self.prox_ops[mask_key]._alpha != alpha_l1_corr:
                    self.prox_ops[mask_key]._alpha = alpha_l1_corr

    def apply_prox_ops(self):
        if 'main_soft_mask' in self.prox_ops:
            self.prox_ops['main_soft_mask'](self.model.decoder.L0.expr_L.weight.data)
        if 'main_group_lasso' in self.prox_ops:
            self.prox_ops['main_group_lasso'](self.model.decoder.L0.expr_L.weight.data)
        # if 'ext_unannot_l1' in self.prox_ops:
        #     self.prox_ops['ext_unannot_l1'](self.model.decoder.L0.ext_L.weight.data)
        # if 'ext_soft_mask' in self.prox_ops:
        #     self.prox_ops['ext_soft_mask'](self.model.decoder.L0.ext_L_m.weight.data)

        def training_step(self, batch, batch_idx, optimizer_idx=0):
            self.init_prox_ops()
            super().training_step(batch, batch_idx, optimizer_idx=0)
            self.apply_prox_ops()
            print("applied prox ops")



Get AnnData

In [10]:
adata = scv.datasets.pancreas()

In [11]:
scv.pp.filter_and_normalize(adata, min_shared_counts=30, n_top_genes=2000)
scv.pp.moments(adata, n_pcs=30, n_neighbors=30)

Filtered out 21611 genes that are detected 30 counts (shared).
Normalized count data: X, spliced, unspliced.
Extracted 2000 highly variable genes.
Logarithmized X.
computing neighbors
    finished (0:00:19) --> added 
    'distances' and 'connectivities', weighted adjacency matrices (adata.obsp)
computing moments based on connectivities
    finished (0:00:01) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)


In [12]:
adata = preprocess_data(adata)

computing velocities
    finished (0:00:00) --> added 
    'velocity', velocity vectors for each individual cell (adata.layers)


Get annotations

In [13]:
import gdown
url = 'https://drive.google.com/uc?id=1136LntaVr92G1MphGeMVcmpE0AqcqM6c'
output = 'reactome.gmt'
gdown.download(url, output, quiet=False)

Downloading...
From: https://drive.google.com/uc?id=1136LntaVr92G1MphGeMVcmpE0AqcqM6c
To: /home/chels/thesis/repos/velo_interpret/velovi/reactome.gmt
100%|██████████| 331k/331k [00:00<00:00, 1.84MB/s]


'reactome.gmt'

In [14]:
import scarches as sca
sca.utils.add_annotations(adata, 'reactome.gmt', min_genes=12, clean=True)

In [15]:
#Remove all genes not present in annotations

adata._inplace_subset_var(adata.varm['I'].sum(1)>0)

In [16]:
#Filter out terms with less than 12 genes 

select_terms = adata.varm['I'].sum(0)>12
adata.uns['terms'] = np.array(adata.uns['terms'])[select_terms].tolist()
adata.varm['I'] = adata.varm['I'][:, select_terms]

In [None]:
#TODO: Use only highly variable genes

In [32]:
VELOVI.setup_anndata(adata, spliced_layer="Ms", unspliced_layer="Mu")
vae = VELOVI(adata)

	with soft mask.
Last Decoder layer: softmax


In [33]:
vae.train()

init corr coeffs: {'alpha_l1': 1.0}


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Epoch 3/500:   0%|          | 2/500 [03:23<14:35:19, 105.46s/it, loss=655, v_num=1]