# Utils

In [4]:
from scipy.stats import chi2_contingency
from scipy.stats import mannwhitneyu
from scipy.stats import ttest_ind
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import statsmodels.api as sm
import anndata as AnnData
import scrublet as scr
import cellrank as cr
import seaborn as sns
import scvelo as scv
import pandas as pd
import scanpy as sc
import numpy as np
import matplotlib
import networkx
import fsspec
import igraph
import scvi
import desc 
import umap
import h5py
import os



Global seed set to 0
  new_rank_zero_deprecation(
  return new_rank_zero_deprecation(*args, **kwargs)


In [1]:
import collections
import math
import numpy as np
import pandas as pd
from typing import Iterable
import torch
from torch import nn as nn
from torch.distributions import Normal
import warnings
from typing import Optional, Tuple, Union
import torch.nn.functional as F
from torch.distributions import Distribution, Gamma, Poisson, constraints
from torch.distributions.utils import (
    broadcast_all,
    lazy_property,
    logits_to_probs,
    probs_to_logits,
)


def make_labels_ccle(metapath, expar, barcodes):
    """
    Align an expression matrix with cell-type labels from a metadata table.

    Parameters
    ----------
    metapath
        Path to a tab-separated metadata file; its first column is used
        as the index. If a ``Site_Primary`` column is present it is copied
        to ``CellType`` and the index is copied to ``Barcode``; otherwise
        the table must already provide ``CellType`` and ``Barcode``.
    expar
        Expression array indexed as ``expar[rows, :]``; its rows are
        taken to follow the order of ``barcodes``.
    barcodes
        Sequence of barcodes, one per row of ``expar``.

    Returns
    -------
    tuple
        ``(outar, outdf, out_barcodes, one_hot_tensor)`` restricted to the
        barcodes that are both in ``barcodes`` and labelled in the
        metadata. Rows follow the sorted barcode order produced by
        ``np.intersect1d``. ``one_hot_tensor`` holds one indicator column
        per cell type that remains after filtering.
    """
    metadf = pd.read_csv(metapath, sep="\t", index_col=0)
    if "Site_Primary" in metadf.columns:
        metadf["CellType"] = metadf["Site_Primary"]
        metadf["Barcode"] = metadf.index

    # Drop unlabelled rows only: real missing values and the literal
    # string "nan". A substring test would also discard legitimate labels
    # that merely contain "nan".
    cell_types = metadf["CellType"]
    labelled = cell_types.notna() & (cell_types.astype(str) != "nan")
    metadf = metadf[labelled]
    metadf = metadf[metadf["Barcode"].isin(barcodes)]

    # idx_1 indexes into `barcodes`/`expar`, idx_2 into the filtered table.
    _, idx_1, idx_2 = np.intersect1d(
        barcodes, np.array(metadf["Barcode"]),
        return_indices=True)
    outar = expar[idx_1, :]
    outdf = metadf.iloc[idx_2, :]
    out_barcodes = np.array(barcodes, dtype="|U64")[idx_1]
    label_dummies = pd.get_dummies(outdf["CellType"])
    one_hot_tensor = torch.from_numpy(np.array(label_dummies))
    return outar, outdf, out_barcodes, one_hot_tensor


def one_hot(index, n_cat):
    """
    One-hot encode integer category indices.

    ``index`` is scattered along dimension 1 of a zero matrix with
    ``index.size(0)`` rows and ``n_cat`` columns, created on the same
    device as ``index``. The result is returned as ``float32``.
    """
    n_rows = index.size(0)
    encoded = torch.zeros((n_rows, n_cat), device=index.device)
    positions = index.long()
    encoded.scatter_(1, positions, 1)
    return encoded.to(torch.float32)


def reparameterize_gaussian(mu, var):
    """
    Draw a reparameterized sample from ``Normal(mu, sqrt(var))``.

    ``rsample`` is used so the draw stays differentiable with respect to
    ``mu`` and ``var``.
    """
    std = var.sqrt()
    posterior = Normal(mu, std)
    return posterior.rsample()


def identity(x):
    """Return ``x`` unchanged (no-op stand-in for a transformation)."""
    return x


class CustomConnected(nn.Module):
    """
    Linear layer whose weights are masked element-wise by ``connections``.

    The layer computes ``x @ (weights * connections).T + bias``, so an
    entry of ``connections`` equal to zero removes the corresponding
    input-to-output connection.

    Parameters
    ----------
    inputsize
        Number of input features.
    hiddensize
        Number of output features.
    connections
        Tensor multiplied element-wise with the
        ``(hiddensize, inputsize)`` weight matrix (a binary TF x Gene
        mask according to the original comments).
    """

    def __init__(self, inputsize, hiddensize, connections):
        super().__init__()
        self.inputsize = inputsize
        self.hiddensize = hiddensize
        # Connections in TF x Gene (binary). Kept as a non-persistent
        # buffer: it follows the module across `.to(device)` calls while
        # staying out of `state_dict`, so checkpoint keys are unchanged.
        self.register_buffer("connections", connections, persistent=False)
        # Weights in TF x Gene dimension
        self.weights = nn.Parameter(
            torch.empty(self.hiddensize, self.inputsize))
        self.bias = nn.Parameter(torch.empty(self.hiddensize))
        # Initialize weights
        nn.init.kaiming_uniform_(
            self.weights, a=math.sqrt(5),
            nonlinearity='leaky_relu')
        # Initialize bias with a uniform distribution
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weights)
        # Guard against a zero-width input, which would divide by zero.
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound)  # bias init

    def forward(self, x):
        """
        Apply the masked linear map to ``x``.

        ``x`` may carry extra leading dimensions; ``torch.matmul``
        broadcasts over them.
        """
        # Only the mask is detached. The masked weights must stay in the
        # autograd graph, otherwise `self.weights` never gets a gradient.
        enforced_weights = torch.mul(
            self.weights, self.connections.detach())
        ew_times_x = torch.matmul(x, enforced_weights.t())
        return torch.add(ew_times_x, self.bias)


class FCLayersEncoder(nn.Module):
    """
    A helper class to build masked fully-connected layers for an encoder.

    Every layer is a ``CustomConnected`` block (a linear map whose weights
    are multiplied by ``connections``) followed by optional BatchNorm,
    LayerNorm, activation and Dropout.

    Parameters
    ----------
    n_in
        The dimensionality of the input
    n_out
        The dimensionality of the output
    n_cat_list
        A list containing, for each category of interest,
        the number of categories. Each category will be
        included using a one-hot encoding.
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    connections
        A tensor indicating which weights to set to zero. When it has
        fewer columns than ``n_in`` plus the covariate width, zero
        columns are appended. When ``None``, an all-ones mask is used.
    dropout_rate
        Dropout rate to apply to each of the hidden layers
    use_batch_norm
        Whether to have `BatchNorm` layers or not
    use_layer_norm
        Whether to have `LayerNorm` layers or not
    use_activation
        Whether to have layer activation or not
    bias
        Accepted for signature compatibility; it is not used here
        because ``CustomConnected`` always creates a bias.
    inject_covariates
        Whether to inject covariates in each layer,
        or just the first (default).
    activation_fn
        Which activation function to use

    Notes
    -----
    The same ``connections`` tensor is handed to every layer.
    NOTE(review): with ``n_layers > 1`` its shape presumably only matches
    the first layer; confirm before using more than one layer.
    """

    def __init__(
        self,
        n_in: int,
        n_out: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
        connections=None,
        dropout_rate: float = 0.1,
        use_batch_norm: bool = True,
        use_layer_norm: bool = False,
        use_activation: bool = True,
        bias: bool = True,
        inject_covariates: bool = True,
        activation_fn: nn.Module = nn.ReLU,
    ):
        super().__init__()
        self.inject_covariates = inject_covariates
        layers_dim = [n_in] + (n_layers - 1) * [n_hidden] + [n_out]

        if n_cat_list is not None:
            # n_cat = 1 will be ignored
            self.n_cat_list = [
                n_cat if n_cat > 1 else 0 for n_cat in n_cat_list]
        else:
            self.n_cat_list = []
        cat_dim = sum(self.n_cat_list)

        # Mask width must equal the real first-layer input width, i.e.
        # the effective covariate width (ignored categories count as 0),
        # not the raw first entry of n_cat_list.
        connect_dim2 = n_in + cat_dim
        if connections is None:
            device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
            connections = torch.ones(n_hidden, connect_dim2).long().to(device)
        elif connections.shape[1] < connect_dim2:
            print("Appending connections")
            # Pad on the mask's own device/dtype so the concat cannot fail
            # when the caller's mask lives elsewhere than the default.
            padding = connections.new_zeros(
                connections.shape[0], connect_dim2 - connections.shape[1])
            connections = torch.cat((connections, padding), dim=-1)

        self.connections = connections

        self.fc_layers = nn.Sequential(
            collections.OrderedDict(
                [
                    (
                        "Layer_{}".format(i),
                        nn.Sequential(
                            CustomConnected(
                                n_in + cat_dim * self.inject_into_layer(i),
                                n_out,
                                connections
                            ),
                            # non-default params come from defaults in
                            # original Tensorflow implementation
                            nn.BatchNorm1d(n_out, momentum=0.01, eps=0.001)
                            if use_batch_norm
                            else None,
                            nn.LayerNorm(n_out, elementwise_affine=False)
                            if use_layer_norm
                            else None,
                            activation_fn() if use_activation else None,
                            nn.Dropout(
                                p=dropout_rate) if dropout_rate > 0 else None,
                        ),
                    )
                    for i, (n_in, n_out) in enumerate(
                        zip(layers_dim[:-1], layers_dim[1:])
                    )
                ]
            )
        )

    def inject_into_layer(self, layer_num) -> bool:
        """Helper to determine if covariates should be injected."""
        user_cond = layer_num == 0 or (
            layer_num > 0 and self.inject_covariates)
        return user_cond

    def set_online_update_hooks(self, hook_first_layer=True):
        """
        Register gradient hooks that freeze everything but covariate weights.

        For layers that receive covariates, weight gradients are zeroed
        except for the trailing covariate columns; other weight gradients
        and all bias gradients are zeroed. The hook handles are stored in
        ``self.hooks``.

        Parameters
        ----------
        hook_first_layer
            When ``False`` the first layer group is left untouched.
        """
        self.hooks = []

        def _hook_fn_weight(grad):
            categorical_dims = sum(self.n_cat_list)
            new_grad = torch.zeros_like(grad)
            if categorical_dims > 0:
                new_grad[:, -categorical_dims:] = grad[:, -categorical_dims:]
            return new_grad

        def _hook_fn_zero_out(grad):
            return grad * 0

        for i, layers in enumerate(self.fc_layers):
            if i == 0 and not hook_first_layer:
                continue
            for layer in layers:
                # This class builds CustomConnected layers (parameter
                # `weights`); matching only nn.Linear registered nothing.
                if isinstance(layer, CustomConnected):
                    weight = layer.weights
                elif isinstance(layer, nn.Linear):
                    weight = layer.weight
                else:
                    continue
                if self.inject_into_layer(i):
                    w = weight.register_hook(_hook_fn_weight)
                else:
                    w = weight.register_hook(_hook_fn_zero_out)
                self.hooks.append(w)
                if layer.bias is not None:
                    b = layer.bias.register_hook(_hook_fn_zero_out)
                    self.hooks.append(b)

    def forward(self, x: torch.Tensor, batch_index):
        """
        Forward computation on ``x``.

        Parameters
        ----------
        x
            input tensor; 2-D, or 3-D in which case BatchNorm is applied
            slice by slice along the first dimension
        batch_index
            tensor of batch membership(s); used as-is when its second
            dimension already equals the number of categories, otherwise
            it is one-hot encoded

        Returns
        -------
        py:class:`torch.Tensor`
            output of the last layer

        Raises
        ------
        ValueError
            If more categorical covariates were declared at construction
            than are provided here, or a required one is ``None``.
        """
        cat_list = [batch_index]
        one_hot_cat_list = []
        # for generality in this list many indices useless.

        if len(self.n_cat_list) > len(cat_list):
            print(self.n_cat_list)
            print(cat_list)
            raise ValueError(
                "nb. categorical args provided doesn't match init. params."
            )
        for n_cat, cat in zip(self.n_cat_list, cat_list):
            if n_cat and cat is None:
                raise ValueError(
                    "cat not provided while n_cat != 0 in init. params.")
            # n_cat = 1 will be ignored - no additional information
            if n_cat > 1:
                if cat.size(1) != n_cat:
                    one_hot_cat = one_hot(cat, n_cat)
                else:
                    one_hot_cat = cat  # cat has already been one_hot encoded
                one_hot_cat_list += [one_hot_cat]
        for i, layers in enumerate(self.fc_layers):
            for layer in layers:
                if layer is None:
                    continue
                if isinstance(layer, nn.BatchNorm1d):
                    if x.dim() == 3:
                        x = torch.cat(
                            [(layer(slice_x)).unsqueeze(0)
                             for slice_x in x], dim=0
                        )
                    else:
                        x = layer(x)
                else:
                    # Explicit grouping: covariates are appended only in
                    # front of a linear-type layer that should get them.
                    is_linear = isinstance(
                        layer, (nn.Linear, CustomConnected))
                    if is_linear and self.inject_into_layer(i):
                        if x.dim() == 3:
                            one_hot_cat_list_layer = [
                                o.unsqueeze(0).expand(
                                    (x.size(0), o.size(0), o.size(1))
                                )
                                for o in one_hot_cat_list
                            ]
                        else:
                            one_hot_cat_list_layer = one_hot_cat_list
                        x = torch.cat((x, *one_hot_cat_list_layer), dim=-1)
                    x = layer(x)
        return x


class FCLayers(nn.Module):
    """
    A helper class to build fully-connected layers for a neural network.

    Each layer is an ``nn.Linear`` followed by optional BatchNorm,
    LayerNorm, activation and Dropout.

    Parameters
    ----------
    n_in
        The dimensionality of the input
    n_out
        The dimensionality of the output
    n_cat_list
        A list containing, for each category of interest,
        the number of categories. Each category will be
        included using a one-hot encoding.
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    dropout_rate
        Dropout rate to apply to each of the hidden layers
    use_batch_norm
        Whether to have `BatchNorm` layers or not
    use_layer_norm
        Whether to have `LayerNorm` layers or not
    use_activation
        Whether to have layer activation or not
    bias
        Whether to learn bias in linear layers or not
    inject_covariates
        Whether to inject covariates in each
        layer, or just the first (default).
    activation_fn
        Which activation function to use
    """

    def __init__(
        self,
        n_in: int,
        n_out: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
        dropout_rate: float = 0.1,
        use_batch_norm: bool = True,
        use_layer_norm: bool = False,
        use_activation: bool = True,
        bias: bool = True,
        inject_covariates: bool = True,
        activation_fn: nn.Module = nn.ReLU,
    ):
        super().__init__()
        self.inject_covariates = inject_covariates
        layers_dim = [n_in] + (n_layers - 1) * [n_hidden] + [n_out]

        if n_cat_list is not None:
            # n_cat = 1 will be ignored
            self.n_cat_list = [
                n_cat if n_cat > 1 else 0 for n_cat in n_cat_list]
        else:
            self.n_cat_list = []

        cat_dim = sum(self.n_cat_list)
        self.fc_layers = nn.Sequential(
            collections.OrderedDict(
                [
                    (
                        "Layer_{}".format(i),
                        nn.Sequential(
                            nn.Linear(
                                n_in + cat_dim * self.inject_into_layer(i),
                                n_out,
                                bias=bias,
                            ),
                            # non-default params come from defaults in
                            # original Tensorflow implementation
                            nn.BatchNorm1d(n_out, momentum=0.01, eps=0.001)
                            if use_batch_norm
                            else None,
                            nn.LayerNorm(n_out, elementwise_affine=False)
                            if use_layer_norm
                            else None,
                            activation_fn() if use_activation else None,
                            nn.Dropout(
                                p=dropout_rate) if dropout_rate > 0 else None,
                        ),
                    )
                    for i, (n_in, n_out) in enumerate(
                        zip(layers_dim[:-1], layers_dim[1:])
                    )
                ]
            )
        )

    def inject_into_layer(self, layer_num) -> bool:
        """Helper to determine if covariates should be injected."""
        user_cond = layer_num == 0 or\
            (layer_num > 0 and self.inject_covariates)
        return user_cond

    def set_online_update_hooks(self, hook_first_layer=True):
        """
        Register gradient hooks that freeze everything but covariate weights.

        For linear layers that receive covariates, weight gradients are
        zeroed except for the trailing covariate columns; other weight
        gradients and all bias gradients are zeroed. The hook handles are
        stored in ``self.hooks``.

        Parameters
        ----------
        hook_first_layer
            When ``False`` the first layer group is left untouched.
        """
        self.hooks = []

        def _hook_fn_weight(grad):
            categorical_dims = sum(self.n_cat_list)
            new_grad = torch.zeros_like(grad)
            if categorical_dims > 0:
                new_grad[:, -categorical_dims:] = grad[:, -categorical_dims:]
            return new_grad

        def _hook_fn_zero_out(grad):
            return grad * 0

        for i, layers in enumerate(self.fc_layers):
            if i == 0 and not hook_first_layer:
                continue
            for layer in layers:
                if not isinstance(layer, nn.Linear):
                    continue
                if self.inject_into_layer(i):
                    w = layer.weight.register_hook(_hook_fn_weight)
                else:
                    w = layer.weight.register_hook(_hook_fn_zero_out)
                self.hooks.append(w)
                # Layers built with bias=False have no bias parameter.
                if layer.bias is not None:
                    b = layer.bias.register_hook(_hook_fn_zero_out)
                    self.hooks.append(b)

    def forward(self, x: torch.Tensor, *cat_list: torch.Tensor):
        """
        Forward computation on ``x``.

        Parameters
        ----------
        x
            input tensor; 2-D, or 3-D in which case BatchNorm is applied
            slice by slice along the first dimension
        cat_list
            category membership tensor(s); each is used as-is when its
            second dimension already equals the number of categories,
            otherwise it is one-hot encoded

        Returns
        -------
        py:class:`torch.Tensor`
            output of the last layer

        Raises
        ------
        ValueError
            If fewer categorical covariates are provided than were
            declared at construction, or a required one is ``None``.
        """
        # for generality in this list many indices useless
        one_hot_cat_list = []

        if len(self.n_cat_list) > len(cat_list):
            print(self.n_cat_list)
            print(cat_list)
            raise ValueError(
                "nb. categorical args provided doesn't match init. params."
            )
        for n_cat, cat in zip(self.n_cat_list, cat_list):
            if n_cat and cat is None:
                raise ValueError(
                    "cat not provided while n_cat != 0 in init. params.")
            # n_cat = 1 will be ignored - no additional information
            if n_cat > 1:
                if cat.size(1) != n_cat:
                    one_hot_cat = one_hot(cat, n_cat)
                else:
                    one_hot_cat = cat  # cat has already been one_hot encoded
                one_hot_cat_list += [one_hot_cat]
        for i, layers in enumerate(self.fc_layers):
            for layer in layers:
                if layer is None:
                    continue
                if isinstance(layer, nn.BatchNorm1d):
                    if x.dim() == 3:
                        x = torch.cat(
                            [(layer(slice_x)).unsqueeze(0)
                             for slice_x in x], dim=0
                        )
                    else:
                        x = layer(x)
                else:
                    if isinstance(layer, nn.Linear) and\
                            self.inject_into_layer(i):
                        if x.dim() == 3:
                            one_hot_cat_list_layer = [
                                o.unsqueeze(0).expand(
                                    (x.size(0), o.size(0), o.size(1))
                                )
                                for o in one_hot_cat_list
                            ]
                        else:
                            one_hot_cat_list_layer = one_hot_cat_list
                        x = torch.cat((x, *one_hot_cat_list_layer), dim=-1)
                    x = layer(x)
        return x


class Encoder(nn.Module):
    """
    Encodes data of ``n_input`` dimensions into a
    latent space of ``n_output`` dimensions.

    The input goes through an ``FCLayersEncoder`` producing ``n_hidden``
    features; two linear heads then give the mean and log-variance of
    the latent Gaussian.

    Parameters
    ----------
    n_input
        The dimensionality of the input (data space)
    n_output
        The dimensionality of the output (latent space)
    connections
        Mask forwarded to ``FCLayersEncoder``
    n_cat_list
        A list containing the number of categories
        for each category of interest. Each category will be
        included using a one-hot encoding
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    dropout_rate
        Dropout rate to apply to each of the hidden layers
    distribution
        Distribution of z; ``"ln"`` applies a softmax over the last
        dimension of the sample, anything else leaves it unchanged
    **kwargs
        Extra keyword args forwarded to ``FCLayersEncoder``
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        connections=None,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
        dropout_rate: float = 0.1,
        distribution: str = "normal",
        **kwargs,
    ):
        super().__init__()

        self.distribution = distribution
        # Shared trunk: its output width is n_hidden, feeding both heads.
        self.encoder = FCLayersEncoder(
            n_in=n_input,
            n_out=n_hidden,
            n_hidden=n_hidden,
            n_layers=n_layers,
            n_cat_list=n_cat_list,
            connections=connections,
            dropout_rate=dropout_rate,
            **kwargs,
        )
        self.mean_encoder = nn.Linear(n_hidden, n_output)
        self.var_encoder = nn.Linear(n_hidden, n_output)

        self.z_transformation = (
            nn.Softmax(dim=-1) if distribution == "ln" else identity
        )

    def forward(self, x: torch.Tensor, batch_index):
        r"""
        The forward computation for a single sample.

         #. Encodes the data into a hidden representation
         #. Generates a mean \\( q_m \\) and variance \\( q_v \\)
         #. Samples from \\( \\sim Ne(q_m, \\mathbf{I}q_v) \\) and applies
         the configured transformation to the sample

        Parameters
        ----------
        x
            input tensor
        batch_index
            tensor of batch membership

        Returns
        -------
        3-tuple of :py:class:`torch.Tensor`
            mean, variance and (transformed) sample

        """
        hidden = self.encoder(x, batch_index)
        q_m = self.mean_encoder(hidden)
        # The head predicts a log-variance; the small constant keeps the
        # variance strictly positive.
        log_var = self.var_encoder(hidden)
        q_v = torch.exp(log_var) + 1e-4
        sample = reparameterize_gaussian(q_m, q_v)
        return q_m, q_v, self.z_transformation(sample)


def log_zinb_positive(
    x: torch.Tensor, mu: torch.Tensor,
    theta: torch.Tensor, pi: torch.Tensor, eps=1e-8
):
    """
    Element-wise log likelihood under a zero-inflated negative binomial.

    Parameters
    ----------
    x
        Data
    mu
        mean of the negative binomial (has to be
        positive support) (shape: minibatch x vars)
    theta
        inverse dispersion parameter (has to be
        positive support) (shape: minibatch x vars);
        a 1-D tensor is reshaped to ``(1, n)`` for broadcasting
    pi
        logit of the dropout parameter (real support)
        (shape: minibatch x vars)
    eps
        numerical stability constant

    Returns
    -------
    Tensor of per-entry log likelihoods (no reduction is applied).

    Notes
    -----
    The Bernoulli part is parametrized by its logits, which is why
    softplus appears: ``log(sigmoid(t)) = -softplus(-t)``.
    """
    if theta.ndimension() == 1:
        # one value per feature, shared across the minibatch
        theta = theta.view(1, theta.size(0))

    neg_log_sigmoid_pi = F.softplus(-pi)
    log_theta_plus_mu = torch.log(theta + mu + eps)
    zero_nb_logit = -pi + theta * (torch.log(theta + eps) - log_theta_plus_mu)

    # Entries observed as zero: dropout or a genuine NB zero.
    zero_term = F.softplus(zero_nb_logit) - neg_log_sigmoid_pi
    zero_part = torch.mul((x < eps).type(torch.float32), zero_term)

    # Entries observed as positive: not dropped out and an NB draw of x.
    positive_term = (
        -neg_log_sigmoid_pi
        + zero_nb_logit
        + x * (torch.log(mu + eps) - log_theta_plus_mu)
        + torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)
    )
    positive_part = torch.mul((x > eps).type(torch.float32), positive_term)

    return zero_part + positive_part


def log_nb_positive(x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor, eps=1e-8):
    """
    Element-wise log likelihood under a negative binomial model.

    Parameters
    ----------
    x
        data
    mu
        mean of the negative binomial (has to be positive support)
        (shape: minibatch x vars)
    theta
        inverse dispersion parameter (has to be positive support)
        (shape: minibatch x vars); a 1-D tensor is reshaped to
        ``(1, n)`` for broadcasting
    eps
        numerical stability constant

    Returns
    -------
    Tensor of per-entry log likelihoods (no reduction is applied).
    """
    if theta.ndimension() == 1:
        # one value per feature, shared across the minibatch
        theta = theta.view(1, theta.size(0))

    log_theta_plus_mu = torch.log(theta + mu + eps)
    dispersion_term = theta * (torch.log(theta + eps) - log_theta_plus_mu)
    count_term = x * (torch.log(mu + eps) - log_theta_plus_mu)

    return (
        dispersion_term
        + count_term
        + torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)
    )


def log_mixture_nb(
    x: torch.Tensor,
    mu_1: torch.Tensor,
    mu_2: torch.Tensor,
    theta_1: torch.Tensor,
    theta_2: torch.Tensor,
    pi_logits: torch.Tensor,
    eps=1e-8,
):
    """
    Element-wise log likelihood under a two-component NB mixture.

    pi_logits is the probability (logits) to be in the first component.
    For totalVI, the first component should be background.

    Parameters
    ----------
    x
        Observed data
    mu_1
        Mean of the first negative binomial component (has to be
        positive support) (shape: minibatch x features)
    mu_2
        Mean of the second negative binomial (has to be positive
        support) (shape: minibatch x features)
    theta_1
        First inverse dispersion parameter (has to be positive support)
        (shape: minibatch x features)
    theta_2
        Second inverse dispersion parameter (has to be positive support)
        (shape: minibatch x features)
        If None, assume one shared inverse dispersion parameter.
    pi_logits
        Probability of belonging to mixture component 1 (logits scale)
    eps
        Numerical stability constant, applied in both branches

    Returns
    -------
    Tensor of per-entry log likelihoods (no reduction is applied).
    """
    if theta_2 is not None:
        # Forward eps so a caller-supplied value is honoured here as well.
        log_nb_1 = log_nb_positive(x, mu_1, theta_1, eps=eps)
        log_nb_2 = log_nb_positive(x, mu_2, theta_2, eps=eps)
    # this is intended to reduce repeated computations
    else:
        theta = theta_1
        if theta.ndimension() == 1:
            theta = theta.view(
                1, theta.size(0)
            )  # In this case, we reshape theta for broadcasting

        log_theta_eps = torch.log(theta + eps)
        log_theta_mu_1_eps = torch.log(theta + mu_1 + eps)
        log_theta_mu_2_eps = torch.log(theta + mu_2 + eps)
        # The lgamma terms depend only on x and the shared theta.
        lgamma_x_theta = torch.lgamma(x + theta)
        lgamma_theta = torch.lgamma(theta)
        lgamma_x_plus_1 = torch.lgamma(x + 1)

        log_nb_1 = (
            theta * (log_theta_eps - log_theta_mu_1_eps)
            + x * (torch.log(mu_1 + eps) - log_theta_mu_1_eps)
            + lgamma_x_theta
            - lgamma_theta
            - lgamma_x_plus_1
        )
        log_nb_2 = (
            theta * (log_theta_eps - log_theta_mu_2_eps)
            + x * (torch.log(mu_2 + eps) - log_theta_mu_2_eps)
            + lgamma_x_theta
            - lgamma_theta
            - lgamma_x_plus_1
        )

    # log(sigmoid(pi) * p1 + (1 - sigmoid(pi)) * p2), computed stably:
    # factor out log(sigmoid(pi)) = -softplus(-pi).
    logsumexp = torch.logsumexp(
        torch.stack((log_nb_1, log_nb_2 - pi_logits)), dim=0)
    softplus_pi = F.softplus(-pi_logits)

    return logsumexp - softplus_pi


def _convert_mean_disp_to_counts_logits(mu, theta, eps=1e-6):
    r"""
    Convert the (mean, inverse dispersion) NB parameters to
    (total_count, logits).

    Parameters
    ----------
    mu
        mean of the NB distribution.
    theta
        inverse overdispersion.
    eps
        constant added inside both logarithms for numerical
        stability. (Default value = 1e-6)

    Returns
    -------
    type
        ``(total_count, logits)`` with ``total_count = theta`` and
        ``logits = log(mu + eps) - log(theta + eps)``.

    Raises
    ------
    ValueError
        If exactly one of ``mu`` and ``theta`` is ``None``.
    """
    mu_missing = mu is None
    theta_missing = theta is None
    if mu_missing != theta_missing:
        raise ValueError(
            "If using the mu/theta NB parameterization, both parameters must be specified"
        )
    log_mu = (mu + eps).log()
    log_theta = (theta + eps).log()
    return theta, log_mu - log_theta


def _convert_counts_logits_to_mean_disp(total_count, logits):
    """
    Convert the (total_count, logits) NB parameters to
    (mean, inverse dispersion).

    Parameters
    ----------
    total_count
        Number of failures until the experiment is stopped.
    logits
        success logits.

    Returns
    -------
    type
        ``(mu, theta)`` with ``theta = total_count`` and
        ``mu = exp(logits) * theta``.

    """
    odds = logits.exp()
    return odds * total_count, total_count


def _gamma(theta, mu):
    """
    Build the Gamma distribution with shape ``theta`` and rate ``theta / mu``.

    ``torch.distributions.Gamma`` takes a rate (1 / scale), not a scale.
    """
    return Gamma(concentration=theta, rate=theta / mu)


class NegativeBinomial(Distribution):
    r"""
    Negative binomial distribution.

    One of the following parameterizations must be provided:

    (1), (`total_count`, `probs`) where `total_count` is the number of failures until
    the experiment is stopped and `probs` the success probability (`logits` may be
    given instead of `probs`). (2), (`mu`, `theta`) parameterization, which is the one
    used by scvi-tools. These parameters respectively control the mean and inverse
    dispersion of the distribution.

    In the (`mu`, `theta`) parameterization, samples from the negative binomial are generated as follows:

    1. :math:`w \sim \textrm{Gamma}(\underbrace{\theta}_{\text{shape}}, \underbrace{\theta/\mu}_{\text{rate}})`
    2. :math:`x \sim \textrm{Poisson}(w)`

    Parameters
    ----------
    total_count
        Number of failures until the experiment is stopped.
    probs
        The success probability.
    logits
        Success logits; takes precedence over ``probs`` when both are given.
    mu
        Mean of the distribution.
    theta
        Inverse dispersion.
    validate_args
        Raise ValueError if arguments do not match constraints

    Raises
    ------
    ValueError
        If both or neither of ``mu`` and ``total_count`` are given, or if
        the chosen parameterization is incomplete.
    """

    arg_constraints = {
        "mu": constraints.greater_than_eq(0),
        "theta": constraints.greater_than_eq(0),
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count: Optional[torch.Tensor] = None,
        probs: Optional[torch.Tensor] = None,
        logits: Optional[torch.Tensor] = None,
        mu: Optional[torch.Tensor] = None,
        theta: Optional[torch.Tensor] = None,
        validate_args: bool = False,
    ):
        self._eps = 1e-8
        if (mu is None) == (total_count is None):
            raise ValueError(
                "Please use one of the two possible parameterizations. Refer to the documentation for more information."
            )

        if total_count is not None:
            # Fail with an explicit message instead of letting the
            # broadcast of missing values raise an unrelated error.
            if logits is None and probs is None:
                raise ValueError(
                    "The `total_count` parameterization requires either `probs` or `logits`."
                )
            logits = logits if logits is not None else probs_to_logits(probs)
            total_count = total_count.type_as(logits)
            total_count, logits = broadcast_all(total_count, logits)
            mu, theta = _convert_counts_logits_to_mean_disp(total_count, logits)
        else:
            if theta is None:
                raise ValueError(
                    "The `mu` parameterization requires `theta` to be specified."
                )
            mu, theta = broadcast_all(mu, theta)
        self.mu = mu
        self.theta = theta
        super().__init__(validate_args=validate_args)

    @property
    def mean(self):
        """Mean of the distribution (``mu``)."""
        return self.mu

    @property
    def variance(self):
        """Variance ``mu + mu ** 2 / theta``."""
        return self.mean + (self.mean ** 2) / self.theta

    def sample(
        self, sample_shape: Union[torch.Size, Tuple] = torch.Size()
    ) -> torch.Tensor:
        """Draw counts as Poisson samples whose rates are Gamma draws."""
        with torch.no_grad():
            gamma_d = self._gamma()
            p_means = gamma_d.sample(sample_shape)

            # Clamping as distributions objects can have buggy behaviors when
            # their parameters are too high
            l_train = torch.clamp(p_means, max=1e8)
            counts = Poisson(
                l_train
            ).sample()  # Shape : (n_samples, n_cells_batch, n_vars)
            return counts

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Element-wise log likelihood of ``value``.

        When argument validation is enabled, a value outside the support
        only triggers a ``UserWarning``; the log likelihood is still
        computed.
        """
        if self._validate_args:
            try:
                self._validate_sample(value)
            except ValueError:
                warnings.warn(
                    "The value argument must be within the support of the distribution",
                    UserWarning,
                )
        return log_nb_positive(value, mu=self.mu, theta=self.theta, eps=self._eps)

    def _gamma(self):
        """Gamma mixing distribution with shape ``theta``, rate ``theta / mu``."""
        return _gamma(self.theta, self.mu)


class ZeroInflatedNegativeBinomial(NegativeBinomial):
    r"""
    Negative binomial distribution with an extra zero-inflation component.

    Exactly one of the following parameterizations must be provided:

    (1), (`total_count`, `probs`) where `total_count` is the number of failures until
    the experiment is stopped and `probs` the success probability. (2), (`mu`, `theta`)
    parameterization, which is the one used by scvi-tools. These parameters respectively
    control the mean and inverse dispersion of the distribution.

    In the (`mu`, `theta`) parameterization, samples from the negative binomial are generated as follows:

    1. :math:`w \sim \textrm{Gamma}(\underbrace{\theta}_{\text{shape}}, \underbrace{\theta/\mu}_{\text{rate}})`
    2. :math:`x \sim \textrm{Poisson}(w)`

    Sampled counts are then set to zero wherever a uniform draw is at most
    the zero-inflation probability.

    Parameters
    ----------
    total_count
        Number of failures until the experiment is stopped.
    probs
        The success probability.
    logits
        Forwarded unchanged to the parent constructor.
    mu
        Mean of the (non-inflated) negative binomial.
    theta
        Inverse dispersion.
    zi_logits
        Logits scale of zero inflation probability.
    validate_args
        Forwarded to the parent constructor.
    """

    arg_constraints = {
        "mu": constraints.greater_than_eq(0),
        "theta": constraints.greater_than_eq(0),
        "zi_probs": constraints.half_open_interval(0.0, 1.0),
        "zi_logits": constraints.real,
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count: Optional[torch.Tensor] = None,
        probs: Optional[torch.Tensor] = None,
        logits: Optional[torch.Tensor] = None,
        mu: Optional[torch.Tensor] = None,
        theta: Optional[torch.Tensor] = None,
        zi_logits: Optional[torch.Tensor] = None,
        validate_args: bool = False,
    ):
        # The parent resolves whichever parameterization was supplied into
        # ``self.mu`` / ``self.theta``.
        super().__init__(
            total_count=total_count,
            probs=probs,
            logits=logits,
            mu=mu,
            theta=theta,
            validate_args=validate_args,
        )
        # Bring the zero-inflation logits and the NB parameters to a common shape.
        zi_b, mu_b, theta_b = broadcast_all(zi_logits, self.mu, self.theta)
        self.zi_logits = zi_b
        self.mu = mu_b
        self.theta = theta_b

    @property
    def mean(self):
        """Mean of the zero-inflated distribution: ``(1 - zi_probs) * mu``."""
        keep_prob = 1 - self.zi_probs
        return keep_prob * self.mu

    @property
    def variance(self):
        """Not implemented for the zero-inflated distribution."""
        raise NotImplementedError

    @lazy_property
    def zi_logits(self) -> torch.Tensor:
        """Zero-inflation logits derived from ``zi_probs``."""
        return probs_to_logits(self.zi_probs, is_binary=True)

    @lazy_property
    def zi_probs(self) -> torch.Tensor:
        """Zero-inflation probabilities derived from ``zi_logits``."""
        return logits_to_probs(self.zi_logits, is_binary=True)

    def sample(
        self, sample_shape: Union[torch.Size, Tuple] = torch.Size()
    ) -> torch.Tensor:
        """Sample from the parent distribution, then zero out dropped entries."""
        with torch.no_grad():
            counts = super().sample(sample_shape=sample_shape)
            # An entry is dropped when its uniform draw is <= the dropout probability.
            dropped = torch.rand_like(counts) <= self.zi_probs
            counts[dropped] = 0.0
            return counts

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Log-probability of ``value`` computed by ``log_zinb_positive``.

        A value outside the support only triggers a ``UserWarning``.
        """
        try:
            self._validate_sample(value)
        except ValueError:
            warnings.warn(
                "The value argument must be within the support of the distribution",
                UserWarning,
            )
        mean, inverse_dispersion, dropout_logits = (
            self.mu,
            self.theta,
            self.zi_logits,
        )
        return log_zinb_positive(
            value, mean, inverse_dispersion, dropout_logits, eps=1e-08
        )


class NegativeBinomialMixture(Distribution):
    """
    Negative binomial mixture distribution.

    See :class:`~scvi.distributions.NegativeBinomial` for further description
    of parameters.

    Parameters
    ----------
    mu1
        Mean of the component 1 distribution.
    mu2
        Mean of the component 2 distribution.
    theta1
        Inverse dispersion for component 1.
    mixture_logits
        Logits scale probability of belonging to component 1.
    theta2
        Inverse dispersion for component 2. If `None`, assumed to be equal to `theta1`.
    validate_args
        Raise ValueError if arguments do not match constraints
    """

    arg_constraints = {
        "mu1": constraints.greater_than_eq(0),
        "mu2": constraints.greater_than_eq(0),
        "theta1": constraints.greater_than_eq(0),
        "mixture_probs": constraints.half_open_interval(0.0, 1.0),
        "mixture_logits": constraints.real,
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        mu1: torch.Tensor,
        mu2: torch.Tensor,
        theta1: torch.Tensor,
        mixture_logits: torch.Tensor,
        theta2: Optional[torch.Tensor] = None,
        validate_args: bool = False,
    ):

        (
            self.mu1,
            self.theta1,
            self.mu2,
            self.mixture_logits,
        ) = broadcast_all(mu1, theta1, mu2, mixture_logits)

        super().__init__(validate_args=validate_args)

        if theta2 is not None:
            # broadcast_all returns one tensor per input; keep only the
            # broadcast theta2 (storing the whole tuple would break the
            # arithmetic in `sample` and the call in `log_prob`).
            _, self.theta2 = broadcast_all(self.mu1, theta2)
        else:
            self.theta2 = None

    @property
    def mean(self):
        """Mixture mean: ``pi * mu1 + (1 - pi) * mu2`` with ``pi = mixture_probs``."""
        pi = self.mixture_probs
        return pi * self.mu1 + (1 - pi) * self.mu2

    @lazy_property
    def mixture_probs(self) -> torch.Tensor:
        """Probability of component 1, derived from ``mixture_logits``."""
        return logits_to_probs(self.mixture_logits, is_binary=True)

    def sample(
        self, sample_shape: Union[torch.Size, Tuple] = torch.Size()
    ) -> torch.Tensor:
        """
        Draw counts from the mixture without tracking gradients.

        A Bernoulli indicator selects, per entry, the parameters of
        component 1 or component 2; counts then follow the same
        Gamma-Poisson construction as ``NegativeBinomial.sample``. The
        indicator is drawn once (without ``sample_shape``); ``sample_shape``
        is applied to the Gamma draw only.
        """
        with torch.no_grad():
            pi = self.mixture_probs
            mixing_sample = torch.distributions.Bernoulli(pi).sample()
            mu = self.mu1 * mixing_sample + self.mu2 * (1 - mixing_sample)
            if self.theta2 is None:
                theta = self.theta1
            else:
                theta = self.theta1 * mixing_sample + self.theta2 * (1 - mixing_sample)
            # Same (theta, mu) argument order that NegativeBinomial._gamma
            # uses when calling the module-level helper.
            gamma_d = _gamma(theta, mu)
            p_means = gamma_d.sample(sample_shape)

            # Clamping as distributions objects can have buggy behaviors when
            # their parameters are too high
            l_train = torch.clamp(p_means, max=1e8)
            counts = Poisson(
                l_train
            ).sample()  # Shape : (n_samples, n_cells_batch, n_features)
            return counts

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Log-probability of ``value`` computed by ``log_mixture_nb``.

        A value outside the support only triggers a ``UserWarning``.
        ``self.theta2`` is forwarded as-is and may be ``None``.
        """
        try:
            self._validate_sample(value)
        except ValueError:
            warnings.warn(
                "The value argument must be within the support of the distribution",
                UserWarning,
            )
        return log_mixture_nb(
            value,
            self.mu1,
            self.mu2,
            self.theta1,
            self.theta2,
            self.mixture_logits,
            eps=1e-08,
        )


In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.distributions import Normal, Poisson
from torch.distributions import kl_divergence as kl
from typing import Dict, Tuple
from typing import Iterable, Optional
#from utils import Encoder, FCLayers
#from utils import NegativeBinomial, ZeroInflatedNegativeBinomial


# Resolve ``Literal`` for the type hints used below: prefer the standard
# library, then the ``typing_extensions`` backport, and finally fall back to
# a minimal local stand-in so that ``Literal[...]`` subscripts still evaluate.
try:
    from typing import Literal
except ImportError:
    try:
        from typing_extensions import Literal
    except ImportError:

        # Metaclass that makes the stand-in class subscriptable.
        class LiteralMeta(type):
            def __getitem__(cls, values):
                # Normalise a single value to a 1-tuple.
                if not isinstance(values, tuple):
                    values = (values,)
                # Return a new subclass of the stand-in ``Literal`` that
                # carries the subscripted values in ``__args__``.
                return type("Literal_", (Literal,), dict(__args__=values))

        # Stand-in with no runtime behaviour of its own.
        class Literal(metaclass=LiteralMeta):
            pass


class DecoderSCVI(nn.Module):
    """
    Decoder mapping a latent representation to expression-distribution parameters.

    A shared ``FCLayers`` trunk produces a hidden representation of size
    ``n_hidden``; three linear heads on top of it produce the normalized
    scale (softmax), the per-gene-per-cell dispersion, and the dropout logits.

    Parameters
    ----------
    n_input
        The dimensionality of the input (latent space)
    n_output
        The dimensionality of the output (data space)
    n_cat_list
        A list containing the number of categories
        for each category of interest; forwarded to ``FCLayers``
    n_layers
        The number of fully-connected hidden layers
    n_hidden
        The number of nodes per hidden layer
    inject_covariates
        Whether to inject covariates in each layer,
        or just the first; forwarded to ``FCLayers``
    use_batch_norm
        Whether to use batch norm in layers
    use_layer_norm
        Whether to use layer norm in layers
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_cat_list: Iterable[int] = None,
        n_layers: int = 1,
        n_hidden: int = 128,
        inject_covariates: bool = True,
        use_batch_norm: bool = False,
        use_layer_norm: bool = False,
    ):
        super().__init__()

        # Shared trunk; dropout is fixed to 0 for the decoder.
        trunk_kwargs = dict(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=0,
            inject_covariates=inject_covariates,
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm,
        )
        self.px_decoder = FCLayers(**trunk_kwargs)

        self.n_cat_list = n_cat_list

        # Head 1: normalized mean (softmax over the output dimension).
        scale_head = [nn.Linear(n_hidden, n_output), nn.Softmax(dim=-1)]
        self.px_scale_decoder = nn.Sequential(*scale_head)

        # Head 2: dispersion, only used in the gene-cell dispersion case.
        self.px_r_decoder = nn.Linear(n_hidden, n_output)

        # Head 3: dropout logits.
        self.px_dropout_decoder = nn.Linear(n_hidden, n_output)

    def forward(
        self, dispersion: str, z: torch.Tensor,
        library: torch.Tensor, batch_index
    ):
        """
        Decode latent values into distribution parameters.

        Parameters
        ----------
        dispersion
            Dispersion mode. The dispersion head is evaluated only when
            this equals ``'gene-cell'``; for any other value the returned
            dispersion is ``None``.
        z
            Latent tensor fed to the decoder trunk.
        library
            Log library size; it is exponentiated and multiplied with the
            normalized scale to obtain the rate.
        batch_index
            Passed to the decoder trunk together with ``z``.

        Returns
        -------
        5-tuple ``(px_scale, px_r, px_rate, px_dropout, px)`` where ``px``
        is the hidden representation produced by the trunk.
        """
        hidden = self.px_decoder(z, batch_index)
        scale = self.px_scale_decoder(hidden)
        dropout_logits = self.px_dropout_decoder(hidden)
        # No clamping is applied to the exponentiated library size.
        rate = torch.exp(library) * scale
        if dispersion == "gene-cell":
            inverse_dispersion = self.px_r_decoder(hidden)
        else:
            inverse_dispersion = None
        return scale, inverse_dispersion, rate, dropout_logits, hidden


def one_hot(index, n_cat):
    """
    One-hot encode a column of category indices.

    ``index`` is cast to ``long`` and used as the column positions for a
    scatter along dimension 1, so it must be 2-D with one row per sample.
    The result has shape ``(index.size(0), n_cat)``, dtype ``float32`` and
    lives on the same device as ``index``.
    """
    n_rows = index.size(0)
    columns = index.long()
    encoded = torch.zeros(n_rows, n_cat, device=index.device)
    encoded.scatter_(1, columns, 1)
    return encoded.float()


def reparameterize_gaussian(mu, var):
    """Draw a reparameterized sample from ``Normal(mu, sqrt(var))``."""
    std = var.sqrt()
    gaussian = Normal(mu, std)
    return gaussian.rsample()


def identity(x):
    """Return ``x`` unchanged."""
    return x


def loss_function(
        qz_m, qz_v, x, px_rate, px_r, px_dropout,
        ql_m, ql_v, use_observed_lib_size,
        local_l_mean, local_l_var):
    """
    Compute the two loss terms of the model.

    Returns a pair ``(reconstruction + KL_library, KL_z)``:

    * ``KL_z`` is the KL divergence between ``Normal(qz_m, sqrt(qz_v))``
      and a standard normal, summed over dimension 1.
    * ``KL_library`` is ``0.0`` when ``use_observed_lib_size`` is truthy;
      otherwise it is the KL divergence between
      ``Normal(ql_m, sqrt(ql_v))`` and
      ``Normal(local_l_mean, sqrt(local_l_var))``, summed over dimension 1.
    * The reconstruction term comes from ``get_reconstruction_loss`` called
      with its default likelihood.
    """
    posterior_z = Normal(qz_m, torch.sqrt(qz_v))
    prior_z = Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v))
    kl_z = kl(posterior_z, prior_z).sum(dim=1)

    if use_observed_lib_size:
        # Library size is observed, so there is no library posterior to penalize.
        kl_library = 0.0
    else:
        posterior_l = Normal(ql_m, torch.sqrt(ql_v))
        prior_l = Normal(local_l_mean, torch.sqrt(local_l_var))
        kl_library = kl(posterior_l, prior_l).sum(dim=1)

    reconstruction = get_reconstruction_loss(x, px_rate, px_r, px_dropout)
    return reconstruction + kl_library, kl_z


def get_reconstruction_loss(
        x, px_rate, px_r, px_dropout, gene_likelihood="zinb"):
    """
    Negative log-likelihood of ``x``, summed over the last dimension.

    Parameters
    ----------
    x
        Observed values.
    px_rate
        Mean / rate parameter of the chosen distribution.
    px_r
        Inverse dispersion; used by ``"zinb"`` and ``"nb"`` only.
    px_dropout
        Zero-inflation logits; used by ``"zinb"`` only.
    gene_likelihood
        One of ``"zinb"``, ``"nb"`` or ``"poisson"``.

    Returns
    -------
    Tensor holding ``-log_prob(x).sum(dim=-1)``.

    Raises
    ------
    ValueError
        If ``gene_likelihood`` is not one of the supported names.
    """
    if gene_likelihood == "zinb":
        log_likelihood = ZeroInflatedNegativeBinomial(
            mu=px_rate, theta=px_r, zi_logits=px_dropout
        ).log_prob(x)
    elif gene_likelihood == "nb":
        log_likelihood = NegativeBinomial(
            mu=px_rate, theta=px_r).log_prob(x)
    elif gene_likelihood == "poisson":
        log_likelihood = Poisson(px_rate).log_prob(x)
    else:
        # Fail with a clear message instead of an UnboundLocalError on return.
        raise ValueError(
            "gene_likelihood must be one of ['zinb', 'nb', 'poisson'], "
            "but input was {}.".format(gene_likelihood)
        )
    return -log_likelihood.sum(dim=-1)


class VAE(nn.Module):
    """
    Variational auto-encoder with a latent encoder, a library-size encoder,
    a ``DecoderSCVI`` decoder and a linear cell-type prediction head.

    Parameters
    ----------
    n_input
        Number of input features (genes).
    connections
        Passed as the third positional argument to the latent ``Encoder``.
    n_celltypes
        Output size of the cell-type prediction head.
    n_batch
        Number of batches. When greater than 0, dispersion is forced to
        ``"gene-batch"`` and ``encode_covariates`` to ``True``.
    n_labels
        Number of labels; used for ``"gene-label"`` dispersion.
    n_hidden
        Number of nodes per hidden layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of hidden layers of the latent encoder and of the decoder
        (the library encoder always uses one layer).
    n_continuous_cov
        Number of continuous covariates; added to the encoder input size
        when covariates are encoded and to the decoder input size always.
    n_cats_per_cov
        Number of categories of each extra categorical covariate.
    dropout_rate
        Dropout rate passed to both encoders.
    dispersion
        One of ``'gene'``, ``'gene-batch'``, ``'gene-label'``, ``'gene-cell'``.
    log_variational
        Whether ``log(1 + x)`` is applied to the input before encoding.
    gene_likelihood
        Stored on the instance; not otherwise used inside this class.
    latent_distribution
        Passed to the latent ``Encoder`` as ``distribution``.
    encode_covariates
        Whether covariates are fed to the encoders.
    deeply_inject_covariates
        Passed to encoders and decoder as ``inject_covariates``.
    use_batch_norm
        Where to use batch norm: ``"encoder"``, ``"decoder"``, ``"none"`` or ``"both"``.
    use_layer_norm
        Where to use layer norm: ``"encoder"``, ``"decoder"``, ``"none"`` or ``"both"``.
    use_observed_lib_size
        Whether the library size is ``log(x.sum(1))`` instead of the
        library encoder's output.

    Raises
    ------
    ValueError
        If the effective ``dispersion`` is not one of the supported values.
    """

    def __init__(
        self,
        n_input: int,
        connections=None,
        n_celltypes: int = 10,
        n_batch: int = 0,
        n_labels: int = 0,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        n_continuous_cov: int = 0,
        n_cats_per_cov: Optional[Iterable[int]] = None,
        dropout_rate: float = 0.1,
        dispersion: str = "gene",
        log_variational: bool = True,
        gene_likelihood: str = "zinb",
        latent_distribution: str = "normal",
        encode_covariates: bool = False,
        deeply_inject_covariates: bool = True,
        use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
        use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
        use_observed_lib_size: bool = True,
    ):
        super().__init__()
        self.n_celltypes = n_celltypes
        self.connections = connections
        self.dispersion = dispersion
        if n_batch > 0:
            print(
                "Setting dispersion to gene-batch and"
                " encode_covariates to True")
            self.dispersion = "gene-batch"
            encode_covariates = True
        self.n_latent = n_latent
        self.log_variational = log_variational
        self.gene_likelihood = gene_likelihood
        self.n_batch = n_batch
        self.n_labels = n_labels
        self.latent_distribution = latent_distribution
        self.encode_covariates = encode_covariates
        self.use_observed_lib_size = use_observed_lib_size

        # Learnable (log) inverse dispersion; its shape depends on the mode.
        # In the gene-cell case the decoder predicts it instead.
        if self.dispersion == "gene":
            self.px_r = torch.nn.Parameter(torch.randn(n_input))
        elif self.dispersion == "gene-batch":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_batch))
        elif self.dispersion == "gene-label":
            self.px_r = torch.nn.Parameter(torch.randn(n_input, n_labels))
        elif self.dispersion == "gene-cell":
            pass
        else:
            raise ValueError(
                (
                    "dispersion must be one of ['gene', 'gene-batch',"
                    " 'gene-label', 'gene-cell'], but input was "
                    "{}."
                ).format(self.dispersion)
            )

        use_batch_norm_encoder = use_batch_norm in ("encoder", "both")
        use_batch_norm_decoder = use_batch_norm in ("decoder", "both")
        use_layer_norm_encoder = use_layer_norm in ("encoder", "both")
        use_layer_norm_decoder = use_layer_norm in ("decoder", "both")

        # z encoder goes from the n_input-dimensional data to an n_latent-d
        # latent space representation
        n_input_encoder = n_input + n_continuous_cov * encode_covariates
        cat_list = [n_batch] + list(
            [] if n_cats_per_cov is None else n_cats_per_cov)
        self.cat_list = cat_list
        encoder_cat_list = cat_list if encode_covariates else None
        self.encoder_cat_list = encoder_cat_list
        self.z_encoder = Encoder(
            n_input_encoder,
            n_latent,
            self.connections,
            n_cat_list=encoder_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            distribution=latent_distribution,
            inject_covariates=deeply_inject_covariates,
            use_batch_norm=use_batch_norm_encoder,
            use_layer_norm=use_layer_norm_encoder,
        )

        # Cell-type prediction head applied to the latent mean.
        self.ctpred_linear = nn.Linear(n_latent, n_celltypes)
        self.ctpred_activation = nn.ReLU()

        # l encoder goes from n_input-dimensional data to 1-d library size
        self.l_encoder = Encoder(
            n_input_encoder,
            1,
            None,
            n_layers=1,
            n_cat_list=encoder_cat_list,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            inject_covariates=deeply_inject_covariates,
            use_batch_norm=use_batch_norm_encoder,
            use_layer_norm=use_layer_norm_encoder,
        )
        # decoder goes from n_latent-dimensional space to n_input-d data
        n_input_decoder = n_latent + n_continuous_cov
        self.decoder = DecoderSCVI(
            n_input_decoder,
            n_input,
            n_cat_list=cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            inject_covariates=deeply_inject_covariates,
            use_batch_norm=use_batch_norm_decoder,
            use_layer_norm=use_layer_norm_decoder,
        )

    def get_latents(self, x, y=None) -> torch.Tensor:
        """
        Returns the result of ``sample_from_posterior_z`` inside a list.
        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, inputsize)``
        y
            tensor of cell-types labels with shape
            ``(batch_size, n_labels)`` (Default value = None)
        Returns
        -------
        type
            one element list of tensor
        """
        # NOTE(review): ``y`` is forwarded positionally, so it lands in the
        # ``batch_index`` slot of ``sample_from_posterior_z``. Left unchanged
        # because callers may rely on it; confirm the intended argument.
        return [self.sample_from_posterior_z(x, y)]

    def sample_from_posterior_z(
        self, x, batch_index=None, y=None, give_mean=False, n_samples=5000
    ) -> torch.Tensor:
        """
        Samples the tensor of latent values from the posterior.
        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, inputsize)``
        batch_index
            passed to the latent encoder together with ``x``
            (Default value = None)
        y
            tensor of cell-types labels with shape ``(batch_size, n_labels)``;
            not used by this method (Default value = None)
        give_mean
            is True when we want the mean of the posterior
            distribution rather than sampling (Default value = False)
        n_samples
            how many MC samples to average over for
            transformed mean (Default value = 5000)
        Returns
        -------
        type
            tensor of shape ``(batch_size, lvsize)``
        """
        if self.log_variational:
            x = torch.log(1 + x)
        qz_m, qz_v, z = self.z_encoder(x, batch_index)

        if give_mean:
            if self.latent_distribution == "ln":
                # Monte-Carlo estimate of the mean of the transformed latent.
                samples = Normal(qz_m, qz_v.sqrt()).sample([n_samples])
                # NOTE(review): ``inference`` calls ``z_transformation`` with a
                # single argument; confirm this two-argument call matches the
                # Encoder's API.
                z = self.z_encoder.z_transformation(samples, batch_index)
                z = z.mean(dim=0)
            else:
                z = qz_m
        return z

    def sample_from_posterior_l(
        self, x, batch_index=None, give_mean=True
    ) -> torch.Tensor:
        """
        Samples the tensor of library sizes from the posterior.
        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, inputsize)``
        batch_index
            passed to the library encoder together with ``x``
            (Default value = None)
        give_mean
            Return the mean of ``LogNormal(ql_m, sqrt(ql_v))`` unless this
            is exactly ``False``, in which case the encoder's sample is
            returned
        Returns
        -------
        type
            tensor of shape ``(batch_size, 1)``
        """
        if self.log_variational:
            x = torch.log(1 + x)
        ql_m, ql_v, library = self.l_encoder(x, batch_index)
        if give_mean is not False:
            library = torch.distributions.LogNormal(ql_m, ql_v.sqrt()).mean
        return library

    def get_sample_scale(
        self, x, batch_index=None, y=None, n_samples=1, transform_batch=None
    ) -> torch.Tensor:
        """
        Returns the tensor of predicted frequencies of expression.
        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, inputsize)``
        batch_index
            array that indicates which batch the cells
            belong to with shape ``batch_size`` (Default value = None)
        y
            tensor of cell-types labels with shape ``(batch_size,
            n_labels)`` (Default value = None)
        n_samples
            number of samples (Default value = 1)
        transform_batch
            int of batch to transform samples into (Default value = None)
        Returns
        -------
        type
            tensor of predicted frequencies of expression
            with shape ``(batch_size, inputsize)``
        """
        return self.inference(
            x,
            batch_index=batch_index,
            y=y,
            n_samples=n_samples,
            transform_batch=transform_batch,
        )["px_scale"]

    def get_sample_rate(
        self, x, batch_index=None, y=None, n_samples=1, transform_batch=None
    ) -> torch.Tensor:
        """
        Returns the tensor of means of the negative binomial distribution.
        Parameters
        ----------
        x
            tensor of values with shape ``(batch_size, inputsize)``
        y
            tensor of cell-types labels with shape
            ``(batch_size, n_labels)`` (Default value = None)
        batch_index
            array that indicates which batch the cells belong to with
            shape ``batch_size`` (Default value = None)
        n_samples
            number of samples (Default value = 1)
        transform_batch
            int of batch to transform samples into (Default value = None)
        Returns
        -------
        type
            tensor of means of the negative binomial distribution with
            shape ``(batch_size, inputsize)``
        """
        return self.inference(
            x,
            batch_index=batch_index,
            y=y,
            n_samples=n_samples,
            transform_batch=transform_batch,
        )["px_rate"]

    def inference(
        self, x, batch_index=None, y=None, n_samples=1, transform_batch=None
    ) -> Dict[str, torch.Tensor]:
        """
        Helper function used in forward pass.

        Encodes ``x``, predicts cell types from the latent mean, decodes the
        latent sample and returns every intermediate tensor in a dict with
        keys ``px_scale``, ``px_r``, ``px_rate``, ``px_dropout``, ``qz_m``,
        ``qz_v``, ``z``, ``ql_m``, ``ql_v``, ``library``, ``px`` and
        ``ctpred``. When ``transform_batch`` is given, the decoder and the
        ``gene-batch`` dispersion both use that batch instead of
        ``batch_index``.
        """
        x_ = x
        if self.use_observed_lib_size:
            # Observed log library size, computed from the raw input.
            library = torch.log(x.sum(1)).unsqueeze(1)
        if self.log_variational:
            x_ = torch.log(1 + x_)

        # Sampling
        qz_m, qz_v, z = self.z_encoder(x_, batch_index)
        ql_m, ql_v, library_encoded = self.l_encoder(x_, batch_index)
        if not self.use_observed_lib_size:
            library = library_encoded

        # Predict celltypes using z
        ctpred = self.ctpred_activation(
            self.ctpred_linear(qz_m))

        if n_samples > 1:
            qz_m = qz_m.unsqueeze(
                0).expand((n_samples, qz_m.size(0), qz_m.size(1)))
            qz_v = qz_v.unsqueeze(
                0).expand((n_samples, qz_v.size(0), qz_v.size(1)))
            # when z is normal, untran_z == z
            untran_z = Normal(qz_m, qz_v.sqrt()).sample()
            z = self.z_encoder.z_transformation(untran_z)
            ql_m = ql_m.unsqueeze(
                0).expand((n_samples, ql_m.size(0), ql_m.size(1)))
            ql_v = ql_v.unsqueeze(
                0).expand((n_samples, ql_v.size(0), ql_v.size(1)))
            if self.use_observed_lib_size:
                library = library.unsqueeze(0).expand(
                    (n_samples, library.size(0), library.size(1))
                )
            else:
                library = Normal(ql_m, ql_v.sqrt()).sample()

        if transform_batch is not None:
            dec_batch_index = transform_batch * torch.ones_like(batch_index)
        else:
            dec_batch_index = batch_index

        # Decode with the (possibly transformed) batch so that
        # ``transform_batch`` affects the decoder output, not only px_r.
        px_scale, px_r, px_rate, px_dropout, px = self.decoder(
            self.dispersion, z, library, dec_batch_index
        )
        if self.dispersion == "gene-label":
            px_r = F.linear(
                one_hot(y, self.n_labels), self.px_r
            )  # px_r gets transposed - last dimension is nb genes
        elif self.dispersion == "gene-batch":
            px_r = F.linear(one_hot(dec_batch_index, self.n_batch), self.px_r)
        elif self.dispersion == "gene":
            px_r = self.px_r
        px_r = torch.exp(px_r)

        return dict(
            px_scale=px_scale,
            px_r=px_r,
            px_rate=px_rate,
            px_dropout=px_dropout,
            qz_m=qz_m,
            qz_v=qz_v,
            z=z,
            ql_m=ql_m,
            ql_v=ql_v,
            library=library,
            px=px,
            ctpred=ctpred
        )

    def forward(
        self, x, batch_index=None, y=None
    ) -> Dict[str, torch.Tensor]:
        """Run ``inference`` with a single sample and return its output dict."""
        outputs = self.inference(x, batch_index, y)
        return outputs


In [3]:
# Import necessary libraries
import os
import numpy as np
import pandas as pd
import torch
import scanpy as sc
import scvelo as scv
import matplotlib.pyplot as plt
from statsmodels.stats.multitest import multipletests
import scipy.stats
from collections import OrderedDict
from datetime import datetime
from sklearn.metrics import accuracy_score
import torch.nn.functional as F
import torch.nn as nn
import scipy.sparse as sp_sparse

# Device configuration: use the first CUDA GPU ("cuda:0") when PyTorch reports
# CUDA as available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define the VAE model class
class VAE(nn.Module):
    """
    Small fully-connected variational auto-encoder.

    The encoder maps ``n_input`` features through one hidden layer to
    ``2 * n_latent`` values (mean and log-variance); the decoder mirrors it
    back to ``n_input`` features. When ``predict_celltype`` is true, a linear
    classifier with ``n_celltypes`` outputs is attached to the latent sample.

    ``connections``, ``n_batch``, ``n_labels`` and ``n_layers`` are accepted
    but not used by this class.
    """

    def __init__(
        self,
        n_input,
        connections=None,
        n_celltypes=0,
        n_batch=0,
        n_labels=0,
        n_hidden=128,
        n_latent=10,
        n_layers=1,
        dropout_rate=0.1,
        predict_celltype=False
    ):
        super().__init__()
        self.n_input = n_input
        self.n_latent = n_latent
        self.predict_celltype = predict_celltype

        # Encoder: input -> hidden -> (mu, logvar) concatenated.
        encoder_layers = [
            nn.Linear(n_input, n_hidden),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(n_hidden, 2 * n_latent),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: latent -> hidden -> reconstructed input.
        decoder_layers = [
            nn.Linear(n_latent, n_hidden),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(n_hidden, n_input),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

        # Classifier head exists only when cell-type prediction is requested.
        if predict_celltype:
            self.classifier = nn.Linear(n_latent, n_celltypes)

    def encode(self, x):
        """Return ``(mu, logvar)``, with ``logvar`` clamped to ``[-10, 10]``."""
        stats = self.encoder(x)
        mu, raw_logvar = torch.chunk(stats, 2, dim=-1)
        # Bounding the log-variance keeps exp() numerically well-behaved.
        return mu, raw_logvar.clamp(min=-10, max=10)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps`` standard normal noise."""
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z):
        """Map latent values back to input space."""
        return self.decoder(z)

    def forward(self, x):
        """
        Encode, sample and decode ``x``.

        Returns a dict with ``recon_x``, ``mu`` and ``logvar``, plus
        ``state_pred`` when the classifier head is enabled.
        """
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        result = dict(recon_x=self.decode(latent), mu=mu, logvar=logvar)
        if self.predict_celltype:
            result['state_pred'] = self.classifier(latent)
        return result

# Define the loss function
def loss_function(recon_x, x, mu, logvar, state_pred=None, state_true=None, loss_scalers=(1000, 1, 1)):
    """
    Combine reconstruction, KL and (optionally) classification losses.

    Parameters
    ----------
    recon_x, x
        Reconstruction and target; compared with a summed squared error.
    mu, logvar
        Mean and log-variance of the approximate posterior.
    state_pred, state_true
        Class logits and integer class targets. The cross-entropy term is
        added only when both are given.
    loss_scalers
        Indexable triple of divisors applied to the reconstruction, KL and
        cross-entropy terms, in that order.

    Returns
    -------
    ``(loss, recon_loss, kld, ce_loss)``; ``recon_loss``, ``kld`` and
    ``ce_loss`` are the unscaled terms. ``ce_loss`` is a zero scalar on
    ``recon_x``'s device when no classification targets are supplied.
    """
    # Reconstruction loss
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')

    # KL divergence between N(mu, exp(logvar)) and the standard normal
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Total loss
    loss = (recon_loss / loss_scalers[0]) + (kld / loss_scalers[1])

    # Cell type prediction loss
    if state_pred is not None and state_true is not None:
        ce_loss = F.cross_entropy(state_pred, state_true)
        loss = loss + ce_loss / loss_scalers[2]
    else:
        # Keep the placeholder on the same device as the other loss terms
        # rather than on a module-level global device.
        ce_loss = torch.tensor(0.0, device=recon_x.device)

    return loss, recon_loss, kld, ce_loss

def main(
    gmtpath,
    nparpaths,
    outdir,
    numlvs,
    metapaths,
    dont_train,
    genepath,
    existingmodelpath,
    use_connections,
    loss_scalers,
    predict_celltypes,
    num_celltypes,
    filter_var,
    num_genes,
    include_batches
):
    """
    Load data, build the VAE, optionally train it, and return the model.

    Parameters
    ----------
    gmtpath, genepath
        Passed to ``make_gmtmat`` together with ``outdir``.
    nparpaths, metapaths
        Passed to ``load_inputs``.
    outdir
        Directory receiving ``vae.pt``, ``vae_chkp.pt`` and the training logs.
    numlvs
        Number of latent variables (``n_latent`` of the VAE).
    dont_train
        When truthy, training is skipped.
    existingmodelpath
        Path given to ``load_existing_model``; the string ``'NA'`` disables
        loading.
    use_connections
        When truthy, the aligned GMT matrix is passed as ``connections`` and
        its number of columns is used as the hidden size (otherwise 128).
    loss_scalers
        Forwarded to ``train_model``.
    predict_celltypes
        When truthy, labels are built from the ``'State'`` metadata column.
    num_celltypes, filter_var, num_genes
        Accepted but not used by this function.
    include_batches
        When truthy and a ``'Batch'`` metadata column exists, batch codes are
        derived from it.

    Returns
    -------
    ``(vae, gmtmat_aligned, tfs, states)``. ``states`` lists the state names
    so that ``states[i]`` is the name encoded as label ``i``; it is empty
    when ``predict_celltypes`` is falsy.
    """
    # Load the GMT matrix and gene list
    gmtmat, tfs, genes_gmtmat = make_gmtmat(gmtpath, outdir, genepath)

    # Load input data
    dict_inputs = load_inputs(
        nparpaths, gmtmat, outdir, genes_gmtmat, metapaths
    )
    expar = dict_inputs["expar"]          # Expression matrix (cells x genes)
    metadf = dict_inputs["metadf"]        # Metadata DataFrame
    genes_expar = dict_inputs["genes"]    # Genes from expression data
    gmtmat_original = dict_inputs["gmtmat"]  # Original gmtmat (genes x tfs)

    # Align gmtmat to the genes in expar; genes missing from the GMT matrix
    # get all-zero rows.
    gmtmat_df = pd.DataFrame(gmtmat_original, index=genes_gmtmat, columns=tfs)
    gmtmat_aligned = gmtmat_df.reindex(genes_expar).fillna(0).values

    # Prepare state labels if predicting cell types
    if predict_celltypes:
        # Take names and codes from the same categorical so that states[i]
        # is the state encoded as label i (unique() follows order of
        # appearance, which need not match the category order of the codes).
        state_categories = metadf['State'].astype('category')
        states = state_categories.cat.categories.tolist()
        num_states = len(states)
        metadf['State_Code'] = state_categories.cat.codes
        state_labels = torch.from_numpy(
            metadf['State_Code'].values).long().to(device)
    else:
        states = []
        state_labels = None
        num_states = 0  # No classifier outputs when not predicting states

    # Handle batch indices if necessary
    if include_batches and 'Batch' in metadf.columns:
        metadf['Batch_Code'] = metadf['Batch'].astype('category').cat.codes
        batch_idxs = metadf['Batch_Code'].values
    else:
        batch_idxs = None

    # Convert expression data to PyTorch tensor
    expar = torch.from_numpy(expar).float().to(device)

    # Replace NaN / Inf entries with zeros
    finite_mask = torch.isfinite(expar)
    if not finite_mask.all():
        print("NaN or Inf detected in expar. Replacing with zeros.")
        expar = torch.where(finite_mask, expar, torch.zeros_like(expar))

    # Initialize the model
    vae = VAE(
        n_input=expar.shape[1],  # Number of genes
        connections=gmtmat_aligned if use_connections else None,
        n_celltypes=num_states,
        n_batch=0 if batch_idxs is None else len(np.unique(batch_idxs)),
        n_labels=0,
        n_hidden=gmtmat_aligned.shape[1] if use_connections else 128,
        n_latent=numlvs,
        n_layers=1,
        dropout_rate=0.1,
        predict_celltype=predict_celltypes
    ).to(device)

    # Initialize the optimizer
    optimizer = torch.optim.Adam(vae.parameters(), lr=1e-4)

    # Paths for saving the model and checkpoints
    modelpath = os.path.join(outdir, "vae.pt")
    chkpath = os.path.join(outdir, "vae_chkp.pt")
    logdir = outdir

    # Load existing model if specified
    if existingmodelpath != 'NA':
        vae, optimizer = load_existing_model(
            existingmodelpath, chkpath, vae, optimizer)

    # Training parameters
    MINIBATCH = 64  # Mini-batch size
    MAXEPOCH = 50   # Number of training epochs

    # Train the model
    if not dont_train:
        vae = train_model(
            vae, optimizer, MINIBATCH, MAXEPOCH,
            expar, logdir,
            modelpath, chkpath, state_labels,
            loss_scalers, predict_celltypes,
            states, batch_idxs
        )

    # Return the trained model and other necessary variables
    return vae, gmtmat_aligned, tfs, states

def train_model(
    vae,
    optimizer,
    MINIBATCH,
    MAXEPOCH,
    expar,
    logdir,
    modelpath,
    chkpath,
    state_labels,
    loss_scalers,
    predict_celltypes,
    states=None,
    batch_idxs=None
):
    """Train ``vae`` on ``expar`` with shuffled minibatches and log each epoch.

    Two tab-separated files are written into ``logdir``: a training log
    (``training.log.<SLURM_JOB_ID or NA>.<timestamp>``) with one row per epoch,
    and, when ``predict_celltypes`` is true, a companion ``*_accuracy.txt``
    with one accuracy column per entry of ``states``.  A checkpoint holding
    the model and optimizer state dicts is saved to both ``modelpath`` and
    ``chkpath`` every 10th epoch and after the final epoch.

    Parameters
    ----------
    vae : torch.nn.Module
        Model called as ``vae(minibatch)``; must return a dict with
        ``'recon_x'``, ``'mu'``, ``'logvar'`` and, when ``predict_celltypes``
        is true, ``'state_pred'``.
    optimizer : torch.optim.Optimizer
        Optimizer stepping the parameters of ``vae``.
    MINIBATCH : int
        Number of rows of ``expar`` per minibatch.
    MAXEPOCH : int
        Number of passes over ``expar``.
    expar : torch.Tensor
        Expression matrix; rows are indexed per minibatch.
    logdir : str
        Directory that receives the log files.
    modelpath, chkpath : str
        Destinations of the periodic checkpoint (same content in both).
    state_labels : indexable
        Per-row class labels, indexed with the shuffled row indices; only
        read when ``predict_celltypes`` is true.
    loss_scalers : sequence of 3 numbers
        Forwarded to ``loss_function``; logged loss components are divided
        by the matching entry.
    predict_celltypes : bool
        Whether the classification head is trained and evaluated.
    states : list of str, optional
        State names used for the per-state accuracy columns.
    batch_idxs : optional
        Accepted for interface compatibility.  NOTE(review): the forward pass
        is called with the expression minibatch only, so this value is not
        used during training — confirm whether the model should receive it.

    Returns
    -------
    The trained ``vae``.

    Raises
    ------
    ValueError
        If ``expar`` has no rows, or if the loss becomes NaN.
    """
    # Avoid a shared mutable default argument.
    states = [] if states is None else states

    n_cells = expar.shape[0]
    if n_cells == 0:
        # Without this guard the epoch averages divide by zero further down.
        raise ValueError("train_model received an empty expression matrix (0 cells).")

    time_str = str(datetime.now())
    time_str = time_str.replace(" ", "_")
    time_str = time_str.replace(":", "0")

    # Use os.environ.get() to avoid KeyError outside of SLURM
    job_id = os.environ.get("SLURM_JOB_ID", "NA")

    logpath = os.path.join(
        logdir,
        "training.log.{}.{}".format(job_id, time_str)
    )
    accpath = logpath + "_accuracy.txt"

    # Initialize log files
    with open(logpath, "w") as loglink:
        header = [
            "Epoch",
            "Reconstruction.Loss",
            "KLD",
            "CE.Loss",
            "Accuracy",
            "MiniBatch.ID",
            "Time.Stamp"
        ]
        loglink.write("\t".join(header) + "\n")

    if predict_celltypes:
        with open(accpath, "a") as acclink:
            header_acc = ["Epoch"] + [state + ".acc" for state in states]
            acclink.write("\t".join(header_acc) + "\n")

    TOTBATCHIDX = int(np.ceil(n_cells / MINIBATCH))

    for epoch in range(MAXEPOCH):
        # Shuffle indices for each epoch
        sampled_idxs = np.random.choice(
            np.arange(n_cells), n_cells, replace=False
        )

        running_loss_reconst = 0
        running_kld = 0
        running_ce = 0
        running_loss = 0
        accval = 0
        # Per-cell true/predicted labels, filled in as minibatches are visited.
        state_resps = np.zeros(n_cells)
        state_preds = np.zeros(n_cells)

        for idxbatch in range(TOTBATCHIDX):
            idxbatch_st = idxbatch * MINIBATCH
            idxbatch_end = min((idxbatch + 1) * MINIBATCH, n_cells)
            cur_sidxs = sampled_idxs[idxbatch_st:idxbatch_end]
            train1 = expar[cur_sidxs, :]

            optimizer.zero_grad()

            # Forward pass
            outdict = vae(train1)
            recon_x = outdict['recon_x']
            mu = outdict['mu']
            logvar = outdict['logvar']

            if predict_celltypes:
                state_pred = outdict['state_pred']
                state_true = state_labels[cur_sidxs]
            else:
                state_pred = None
                state_true = None

            # Total loss followed by its three components (logged below as
            # reconstruction, KLD and CE).
            loss, loss_1, loss_2, loss_3 = loss_function(
                recon_x, train1, mu, logvar,
                state_pred=state_pred, state_true=state_true,
                loss_scalers=loss_scalers
            )

            # Abort with diagnostics as soon as the loss turns NaN
            if torch.isnan(loss):
                print(f"Losses: Reconstruction Loss = {loss_1.item()}, KLD = {loss_2.item()}, CE Loss = {loss_3.item()}")
                print("Checking for NaNs in recon_x, x, mu, logvar, state_pred, state_true")
                print(f"NaNs in recon_x: {torch.isnan(recon_x).any()}")
                print(f"NaNs in x: {torch.isnan(train1).any()}")
                print(f"NaNs in mu: {torch.isnan(mu).any()}")
                print(f"NaNs in logvar: {torch.isnan(logvar).any()}")
                if predict_celltypes:
                    print(f"NaNs in state_pred: {torch.isnan(state_pred).any()}")
                    print(f"NaNs in state_true: {torch.isnan(state_true).any()}")
                raise ValueError("NaN occurred in loss computation.")

            # Backward pass
            loss.backward()
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
            optimizer.step()

            # Undo the scaling so the logged components are comparable
            # across different loss_scalers settings.
            running_loss_reconst += loss_1.item() / loss_scalers[0]
            running_kld += loss_2.item() / loss_scalers[1]
            running_ce += loss_3.item() / loss_scalers[2]
            running_loss += loss.item()

            if predict_celltypes:
                _, predicted = torch.max(state_pred.detach(), 1)
                total = state_true.size(0)
                correct = (predicted == state_true).sum().item()
                accval += correct / total
                state_resps[cur_sidxs] = state_true.cpu().numpy()
                state_preds[cur_sidxs] = predicted.cpu().numpy()

            del train1, outdict
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        cur_loss = running_loss / TOTBATCHIDX
        cur_loss_reconst = running_loss_reconst / TOTBATCHIDX
        cur_kld = running_kld / TOTBATCHIDX
        cur_ce = running_ce / TOTBATCHIDX
        accval = accval / TOTBATCHIDX

        if predict_celltypes:
            # One-vs-rest accuracy per state: agreement between "is state k"
            # in the labels and in the predictions.
            adlist_cts = [str(epoch)]
            for k in range(len(states)):
                pred_state = state_preds == k
                resp_state = state_resps == k
                cur_acc = accuracy_score(resp_state, pred_state)
                adlist_cts.append(str(round(cur_acc, 3)))
            with open(accpath, "a") as acclink:
                acclink.write("\t".join(adlist_cts) + "\n")

        print(f"Epoch {epoch}, Loss {cur_loss} at {datetime.now()}")

        with open(logpath, "a") as loglink:
            adlist = [
                str(epoch),
                str(cur_loss_reconst),
                str(cur_kld),
                str(cur_ce),
                str(round(accval, 3)),
                str(idxbatch),  # index of the last minibatch of the epoch
                str(datetime.now())
            ]
            loglink.write("\t".join(adlist) + "\n")

        # Save every 10th epoch and always after the last one, so the file on
        # disk matches the model that is returned.
        is_last_epoch = epoch == MAXEPOCH - 1
        if epoch % 10 == 0 or is_last_epoch:
            checkpoint = {
                'model': vae.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            for eachpath in [modelpath, chkpath]:
                torch.save(checkpoint, eachpath)

    return vae

def load_inputs(nparpaths, gmtmat, outdir, genes, metapaths):
    """Load one or more AnnData ``.h5ad`` files and stack them cell-wise.

    Parameters
    ----------
    nparpaths : sequence of str
        Paths of the ``.h5ad`` files to read with ``sc.read_h5ad``.
    gmtmat :
        Passed through unchanged into the returned dictionary.
    outdir, metapaths :
        Not used by this function; kept for interface compatibility.
    genes :
        Ignored on input: it is replaced by the gene names read from the
        files (``adata.var['Gene']`` when present, otherwise the ``var``
        index).

    Returns
    -------
    dict
        ``'expar'``    dense expression matrix, all files concatenated on axis 0;
        ``'metadf'``   concatenated ``adata.obs`` tables with a fresh integer index;
        ``'barcodes'`` concatenated cell identifiers (``obs['CellID']`` when
                       present, otherwise the ``obs`` index);
        ``'genes'``    gene names of the last file read;
        ``'gmtmat'``   the ``gmtmat`` argument.

    Raises
    ------
    ValueError
        If ``nparpaths`` is empty.

    Warns
    -----
    UserWarning
        If a file's gene names differ from those of the first file; the rows
        are still concatenated positionally, so columns may not line up.
    """
    import warnings

    if len(nparpaths) == 0:
        raise ValueError("load_inputs requires at least one path in nparpaths.")

    expar_list = []
    barcodes_list = []
    metadf_list = []
    reference_genes = None

    for path in nparpaths:
        print(f"Loading {path}")
        # Load the AnnData object
        adata = sc.read_h5ad(path)

        # Extract expression matrix, densifying sparse storage
        expar = adata.X
        if sp_sparse.issparse(expar):
            expar = expar.toarray()

        # Extract gene names
        if 'Gene' in adata.var.columns:
            genes = adata.var['Gene'].values
        else:
            genes = adata.var.index.values

        # Concatenation below is positional, so flag files whose gene axis
        # does not match the first file instead of mixing columns silently.
        if reference_genes is None:
            reference_genes = genes
        elif not np.array_equal(np.asarray(reference_genes), np.asarray(genes)):
            warnings.warn(
                f"Gene names in {path} differ from those in {nparpaths[0]}; "
                "expression columns may be misaligned after concatenation."
            )

        # Extract cell barcodes
        if 'CellID' in adata.obs.columns:
            barcodes = adata.obs['CellID'].values
        else:
            barcodes = adata.obs.index.values

        # Collect this file's pieces
        expar_list.append(expar)
        barcodes_list.append(barcodes)
        metadf_list.append(adata.obs.copy())

    # Concatenate data from all datasets
    expar = np.concatenate(expar_list, axis=0)
    barcodes = np.concatenate(barcodes_list, axis=0)
    metadf = pd.concat(metadf_list, axis=0, ignore_index=True)

    # Return the data in a dictionary
    return {
        'expar': expar,
        'metadf': metadf,
        'barcodes': barcodes,
        'genes': genes,
        'gmtmat': gmtmat
    }

def load_existing_model(modelpath, chkpath, vae, optimizer):
    """Restore model and optimizer state from the first usable checkpoint.

    ``modelpath`` is tried first, then ``chkpath``.  A checkpoint is expected
    to be a dict with a ``'model'`` state dict and an ``'optimizer'`` state
    dict.  Any ``'module.'`` substring is removed from the model keys before
    loading (this is the prefix ``torch.nn.DataParallel`` adds).

    Loading is best-effort: a missing file is skipped, and a file that fails
    to load is reported and skipped rather than raising.  Note that if the
    model weights load but the optimizer state does not, the model keeps the
    weights it just received while the next path is tried.

    Parameters
    ----------
    modelpath, chkpath : str
        Candidate checkpoint files, in order of preference.
    vae : torch.nn.Module
        Model to load the weights into (modified in place).
    optimizer : torch.optim.Optimizer
        Optimizer to load the state into (modified in place).

    Returns
    -------
    tuple
        ``(vae, optimizer)`` — the same objects that were passed in.
    """
    # Map stored tensors onto the device the model currently lives on, so a
    # checkpoint written on a GPU can still be restored on a CPU-only machine.
    try:
        target_device = next(vae.parameters()).device
    except StopIteration:
        target_device = torch.device("cpu")

    for eachpath in [modelpath, chkpath]:
        if not os.path.exists(eachpath):
            continue
        try:
            checkpoint = torch.load(eachpath, map_location=target_device)
            state_dict = checkpoint['model']
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                new_state_dict[k.replace('module.', '')] = v
            vae.load_state_dict(new_state_dict)
            optimizer.load_state_dict(checkpoint['optimizer'])
        except Exception as err:
            # Deliberately broad: any unreadable or incompatible checkpoint
            # should fall through to the next candidate, but say why.
            print(f"Failed to load from {eachpath}: {err}")
            continue
        print(f"Loaded from {eachpath}")
        return vae, optimizer

    print("Didn't load from any")
    return vae, optimizer

def make_gmtmat(gmtpath, outdir, genepath):
    """
    Builds a binary gene x TF membership matrix from a GMT file and a gene list.

    Each non-blank line of the GMT file is one gene set: tab-separated, with
    the set (TF) name in the first field, a description in the second, and
    the member genes in the remaining fields.  Every such line becomes one
    column, in file order; a TF name that occurs on several lines therefore
    yields several columns, each with the targets of its own line.

    Parameters
    ----------
    gmtpath : str
        Path to the GMT file containing gene set definitions.
    outdir : str
        Not used by this function; kept for interface compatibility.
    genepath : str
        Path to a text file with one gene name per line; line order defines
        the row order of the matrix.

    Returns
    -------
    gmtmat : np.ndarray
        float32 matrix of shape (n_genes, n_tfs); entry is 1 where the gene
        is listed as a target of the TF, 0 otherwise.
    tfs : list
        TF (gene set) names, one per column.
    genes_gmtmat : list
        Gene names, one per row.
    """
    # Load the gene list from 'genepath'
    with open(genepath, 'r') as f:
        genes_gmtmat = [line.strip() for line in f]

    # Parse the GMT file; targets are kept per line (not per name) so that a
    # repeated TF name cannot overwrite an earlier line's gene set.
    tfs = []
    tf_targets = []
    with open(gmtpath, 'r') as gmt_file:
        for line in gmt_file:
            line = line.strip()
            if not line:
                # Blank lines would otherwise become a TF named ''.
                continue
            tokens = line.split('\t')
            tfs.append(tokens[0])
            tf_targets.append(set(tokens[2:]))  # Skip description in tokens[1]

    n_tfs = len(tfs)
    n_genes = len(genes_gmtmat)

    # Initialize gmtmat with zeros
    gmtmat = np.zeros((n_genes, n_tfs), dtype=np.float32)

    gene_to_index = {gene: idx for idx, gene in enumerate(genes_gmtmat)}

    for tf_idx, target_genes in enumerate(tf_targets):
        # Targets that are absent from the gene list are ignored.
        rows = [gene_to_index[gene] for gene in target_genes if gene in gene_to_index]
        if rows:
            gmtmat[rows, tf_idx] = 1  # Binary connections

    return gmtmat, tfs, genes_gmtmat


Global seed set to 0
  new_rank_zero_deprecation(
  return new_rank_zero_deprecation(*args, **kwargs)


In [4]:
# Import necessary modules
import os
import torch
import numpy as np
#from model import VAE, loss_function  # Make sure model.py is in the same directory
#from train import main  # Make sure train.py is in the same directory
import h5py
import scvelo as scv

Global seed set to 0
  new_rank_zero_deprecation(
  return new_rank_zero_deprecation(*args, **kwargs)


# Running MAVE on OS152, OS384, and OS742

In [None]:

# Main code

# Purpose of this cell: for every dataset, train the model `n_bootstrap`
# times via main(), derive a (state x TF) weight matrix from each trained
# model, and summarise the repeats by their mean and standard deviation
# before setting up the plotting grid.

# List of datasets
datasets = ['OS152', 'OS384', 'OS742']

# Number of bootstrap samples
# NOTE(review): every iteration below calls main() with identical arguments
# and no resampling of cells is visible in this cell; run-to-run variation
# presumably comes only from random initialisation / minibatch shuffling —
# confirm this is the intended meaning of "bootstrap".
n_bootstrap = 5

# Loop over each dataset
for dataset in datasets:
    print(f"Processing dataset: {dataset}")

    # Define paths and parameters specific to the dataset
    BASE_DIR = '/Users/brendamelano/Desktop/Reprogramming_Osteosarcoma/plain_scRNAseq_analysis'
    H5PATH = f'{BASE_DIR}/{dataset}/{dataset}_adata_subtype_PCA.h5ad'
    METAPATH = f'{BASE_DIR}/{dataset}/{dataset}_10Xbarcode_subtype.tsv'
    GENEPATH = '/Users/brendamelano/Desktop/ciberatac/mave_data/Genes_passing_40p.txt'
    GMTPATH = '/Users/brendamelano/Desktop/ciberatac/mave_data/c3.tft.v7.2.symbols.gmt'
    OUTDIR = f'{BASE_DIR}/{dataset}/output'

    # Create output directory if it doesn't exist
    os.makedirs(OUTDIR, exist_ok=True)

    # Set parameters (forwarded unchanged to main() as keyword arguments)
    gmtpath = GMTPATH
    outdir = OUTDIR
    nparpaths = [H5PATH]
    numlvs = 10
    genepath = GENEPATH
    metapaths = [METAPATH]
    num_celltypes = 3
    predict_celltypes = True
    use_connections = True
    loss_scalers = [1000, 1, 1]
    dont_train = False
    filter_var = False
    num_genes = 2000
    include_batches = False

    # Initialize arrays to store bootstrap weights
    # (one (n_states, n_tfs) matrix is appended per iteration)
    state_tf_weights_bootstrap = []

    for b in range(n_bootstrap):
        print(f"Bootstrap iteration: {b+1}/{n_bootstrap}")

        # Run the main function and capture the returned values
        # NOTE(review): all iterations share the same `outdir`, so files that
        # main() writes there under fixed names are presumably overwritten by
        # each run — confirm nothing from earlier iterations is needed.
        vae, gmtmat_aligned, tfs, states = main(
            gmtpath=gmtpath,
            nparpaths=nparpaths,
            outdir=outdir,
            numlvs=numlvs,
            metapaths=metapaths,
            dont_train=dont_train,
            genepath=genepath,
            existingmodelpath='NA',
            use_connections=use_connections,
            loss_scalers=loss_scalers,
            predict_celltypes=predict_celltypes,
            num_celltypes=num_celltypes,
            filter_var=filter_var,
            num_genes=num_genes,
            include_batches=include_batches
        )

        # Now vae is defined, and you can proceed to extract the model weights
        # Extract model weights
        model_state_dict = vae.state_dict()

        # Extract classifier weights
        classifier_weights = model_state_dict['classifier.weight'].cpu().detach().numpy()  # Shape: (n_states, n_latent)

        # Extract decoder weights
        decoder_weight_0 = model_state_dict['decoder.0.weight'].cpu().detach().numpy()  # Shape: (n_hidden, n_latent)
        decoder_weight_3 = model_state_dict['decoder.3.weight'].cpu().detach().numpy()  # Shape: (n_input, n_hidden)

        # Compute the effective decoder weight matrix
        # NOTE(review): this multiplies the two weight matrices directly, i.e.
        # it treats the decoder as one linear map; whatever sits between
        # decoder.0 and decoder.3 in the model is ignored — confirm that this
        # approximation is acceptable.
        effective_decoder_weights = np.dot(decoder_weight_3, decoder_weight_0)  # Shape: (n_input, n_latent)

        # 'gmtmat_aligned' is already aligned with 'expar' genes
        connections = gmtmat_aligned  # Shape: (n_input, n_tfs)

        # Compute TF weights
        tf_weights = np.dot(effective_decoder_weights.T, connections)  # Shape: (n_latent, n_tfs)

        # Compute influence of TFs on States
        state_tf_weights = np.dot(classifier_weights, tf_weights)  # Shape: (n_states, n_tfs)

        # Store the weights
        state_tf_weights_bootstrap.append(state_tf_weights)

    # Convert list to numpy array
    state_tf_weights_bootstrap = np.array(state_tf_weights_bootstrap)  # Shape: (n_bootstrap, n_states, n_tfs)

    # Compute mean and std deviations across the repeated runs (axis 0)
    state_tf_weights_mean = np.mean(state_tf_weights_bootstrap, axis=0)  # Shape: (n_states, n_tfs)
    state_tf_weights_std = np.std(state_tf_weights_bootstrap, axis=0)  # Shape: (n_states, n_tfs)

    # Proceed with p-value computation and plotting
    # One subplot per state, laid out on a grid with `ncols` columns.
    n_states = len(states)
    ncols = 3  # Adjust as needed
    nrows = int(np.ceil(n_states / ncols))

    subplot_width = 3  # inches
    subplot_height = 3  # inches
    fig, axes = plt.subplots(nrows, ncols, figsize=(subplot_width * ncols, subplot_height * nrows))
    axes = axes.flatten()

    top_n = 10  # Number of top TFs to highlight per State
    effect_size_cutoff = 0.1  # Adjust as needed
    p_value_cutoff = 0.05  # Adjust as needed

    # Calculate the -log10 of the p-value cutoff for plotting the horizontal line


Processing dataset: OS152
Bootstrap iteration: 1/5
Loading /Users/brendamelano/Desktop/Reprogramming_Osteosarcoma/plain_scRNAseq_analysis/OS152/OS152_adata_subtype_PCA.h5ad
Epoch 0, Loss 209.85550914091223 at 2024-10-08 11:31:38.107602
Epoch 1, Loss 197.8373923956179 at 2024-10-08 11:31:43.131246
Epoch 2, Loss 195.7526246613147 at 2024-10-08 11:31:47.899802
Epoch 3, Loss 194.6672815061083 at 2024-10-08 11:31:53.338234
Epoch 4, Loss 193.8763753853592 at 2024-10-08 11:31:59.912385
Epoch 5, Loss 193.14154411764707 at 2024-10-08 11:32:05.053773
Epoch 6, Loss 192.5203742980957 at 2024-10-08 11:32:10.371475
Epoch 7, Loss 192.0291124231675 at 2024-10-08 11:32:15.223442
Epoch 8, Loss 191.5847600301107 at 2024-10-08 11:32:20.055241
Epoch 9, Loss 191.17907318414427 at 2024-10-08 11:32:25.071320
Epoch 10, Loss 190.8751890145096 at 2024-10-08 11:32:30.010517
Epoch 11, Loss 190.64690159816368 at 2024-10-08 11:32:35.401424
Epoch 12, Loss 190.43299237419578 at 2024-10-08 11:32:40.381388
Epoch 13, Los

Epoch 22, Loss 189.53717549641928 at 2024-10-08 11:42:12.594240
Epoch 23, Loss 189.50076383702896 at 2024-10-08 11:42:17.325624
Epoch 24, Loss 189.4604226654651 at 2024-10-08 11:42:22.082865
Epoch 25, Loss 189.42577175065583 at 2024-10-08 11:42:27.069373
Epoch 26, Loss 189.39593333824007 at 2024-10-08 11:42:32.312110
Epoch 27, Loss 189.35771785062903 at 2024-10-08 11:42:38.187603
Epoch 28, Loss 189.3230701521331 at 2024-10-08 11:42:43.398162
Epoch 29, Loss 189.30475130268172 at 2024-10-08 11:42:48.546531
Epoch 30, Loss 189.27764324113434 at 2024-10-08 11:42:54.270519
Epoch 31, Loss 189.25508716059667 at 2024-10-08 11:42:59.593366
Epoch 32, Loss 189.21683988384171 at 2024-10-08 11:43:04.931194
Epoch 33, Loss 189.2094035429113 at 2024-10-08 11:43:10.243494
Epoch 34, Loss 189.1825013254203 at 2024-10-08 11:43:15.385942
Epoch 35, Loss 189.16685373642866 at 2024-10-08 11:43:20.551246
Epoch 36, Loss 189.1439037696988 at 2024-10-08 11:43:25.581154
Epoch 37, Loss 189.1276984869265 at 2024-10-0

Epoch 47, Loss 188.95772492651847 at 2024-10-08 11:53:31.742216
Epoch 48, Loss 188.93910179886163 at 2024-10-08 11:53:39.002504
Epoch 49, Loss 188.92387292899338 at 2024-10-08 11:53:45.378344
Processing dataset: OS384
Bootstrap iteration: 1/5
Loading /Users/brendamelano/Desktop/Reprogramming_Osteosarcoma/plain_scRNAseq_analysis/OS384/OS384_adata_subtype_PCA.h5ad
Epoch 0, Loss 216.42497343175552 at 2024-10-08 11:53:49.542117
Epoch 1, Loss 201.6567720899395 at 2024-10-08 11:53:52.755610
Epoch 2, Loss 198.8819663851869 at 2024-10-08 11:53:55.707450
Epoch 3, Loss 198.13379594391466 at 2024-10-08 11:53:58.604124
Epoch 4, Loss 197.54298131606157 at 2024-10-08 11:54:01.420384
Epoch 5, Loss 196.95059174182367 at 2024-10-08 11:54:04.219739
Epoch 6, Loss 196.508633183498 at 2024-10-08 11:54:07.186960
Epoch 7, Loss 196.0842829685585 at 2024-10-08 11:54:10.214764
Epoch 8, Loss 195.67930662865732 at 2024-10-08 11:54:13.157655
Epoch 9, Loss 195.36234956629136 at 2024-10-08 11:54:17.287162
Epoch 10, 

Epoch 19, Loss 192.88334835276885 at 2024-10-08 11:59:22.890420
Epoch 20, Loss 192.7198851342295 at 2024-10-08 11:59:25.358319
Epoch 21, Loss 192.5856460870481 at 2024-10-08 11:59:28.016064
Epoch 22, Loss 192.45180421717026 at 2024-10-08 11:59:30.513541
Epoch 23, Loss 192.31359803442862 at 2024-10-08 11:59:33.268190
Epoch 24, Loss 192.21567550359987 at 2024-10-08 11:59:36.404464
Epoch 25, Loss 192.13661762312347 at 2024-10-08 11:59:39.255031
Epoch 26, Loss 192.0325087005017 at 2024-10-08 11:59:42.316969
Epoch 27, Loss 191.96672147863052 at 2024-10-08 11:59:45.042822
Epoch 28, Loss 191.8806032666973 at 2024-10-08 11:59:47.565990
Epoch 29, Loss 191.83342668121935 at 2024-10-08 11:59:50.194684
Epoch 30, Loss 191.76276652018228 at 2024-10-08 11:59:53.048725
Epoch 31, Loss 191.72049623377183 at 2024-10-08 11:59:55.866095
Epoch 32, Loss 191.66573528214997 at 2024-10-08 11:59:58.579632
Epoch 33, Loss 191.64467067344515 at 2024-10-08 12:00:01.363417
Epoch 34, Loss 191.56983379289215 at 2024-10

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import scanpy as sc
import scvelo as scv
import matplotlib.pyplot as plt
from statsmodels.stats.multitest import multipletests
import scipy.stats

# Purpose of this cell: for every dataset, train the model once via main(),
# derive a row-normalised (state x TF) weight matrix from the trained model,
# and save one volcano plot per dataset (one subplot per state) as SVG.

# List of datasets
datasets = ['OS152', 'OS384', 'OS742']

# Loop over each dataset
for dataset in datasets:
    print(f"Processing dataset: {dataset}")

    # Define paths and parameters specific to the dataset
    BASE_DIR = '/Users/brendamelano/Desktop/Reprogramming_Osteosarcoma/plain_scRNAseq_analysis'
    H5PATH = f'{BASE_DIR}/{dataset}/{dataset}_adata_subtype_PCA.h5ad'
    METAPATH = f'{BASE_DIR}/{dataset}/{dataset}_10Xbarcode_subtype.tsv'
    GENEPATH = '/Users/brendamelano/Desktop/ciberatac/mave_data/Genes_passing_40p.txt'
    GMTPATH = '/Users/brendamelano/Desktop/ciberatac/mave_data/c3.tft.v7.2.symbols.gmt'
    OUTDIR = f'{BASE_DIR}/{dataset}/output'

    # Load your data
    # NOTE(review): `adata` is not used anywhere else in this cell; it is kept
    # because later cells may rely on it — drop the load if they do not.
    adata_path = f"{BASE_DIR}/{dataset}/{dataset}_adata_subtype_PCA.h5ad"
    adata = scv.read(adata_path)

    # Create output directory if it doesn't exist
    os.makedirs(OUTDIR, exist_ok=True)

    # Set parameters (forwarded unchanged to main() as keyword arguments)
    gmtpath = GMTPATH
    outdir = OUTDIR
    nparpaths = [H5PATH]
    numlvs = 10
    genepath = GENEPATH
    metapaths = [METAPATH]
    num_celltypes = 8
    predict_celltypes = True
    use_connections = True
    loss_scalers = [1000, 1, 1]
    dont_train = False
    filter_var = False
    num_genes = 2000
    include_batches = False

    # Run the main function and capture the returned values
    vae, gmtmat_aligned, tfs, states = main(
        gmtpath=gmtpath,
        nparpaths=nparpaths,
        outdir=outdir,
        numlvs=numlvs,
        metapaths=metapaths,
        dont_train=dont_train,
        genepath=genepath,
        existingmodelpath='NA',
        use_connections=use_connections,
        loss_scalers=loss_scalers,
        predict_celltypes=predict_celltypes,
        num_celltypes=num_celltypes,
        filter_var=filter_var,
        num_genes=num_genes,
        include_batches=include_batches
    )

    # Extract model weights
    model_state_dict = vae.state_dict()

    # Extract classifier weights
    classifier_weights = model_state_dict['classifier.weight'].cpu().detach().numpy()  # Shape: (n_states, n_latent)

    # Extract decoder weights
    decoder_weight_0 = model_state_dict['decoder.0.weight'].cpu().detach().numpy()  # Shape: (n_hidden, n_latent)
    decoder_weight_3 = model_state_dict['decoder.3.weight'].cpu().detach().numpy()  # Shape: (n_input, n_hidden)

    # Compute the effective decoder weight matrix
    # NOTE(review): the two weight matrices are multiplied directly, so
    # whatever sits between decoder.0 and decoder.3 in the model is ignored —
    # confirm that this linear approximation is acceptable.
    effective_decoder_weights = np.dot(decoder_weight_3, decoder_weight_0)  # Shape: (n_input, n_latent)

    # 'gmtmat_aligned' is already aligned with 'expar' genes
    connections = gmtmat_aligned  # Shape: (n_input, n_tfs)

    # Compute TF weights
    tf_weights = np.dot(effective_decoder_weights.T, connections)  # Shape: (n_latent, n_tfs)

    # Compute influence of TFs on States
    state_tf_weights = np.dot(classifier_weights, tf_weights)  # Shape: (n_states, n_tfs)

    # Normalize each state's weights to unit L2 norm; an all-zero row is left
    # as zeros instead of becoming NaN through a 0/0 division.
    row_norms = np.linalg.norm(state_tf_weights, axis=1, keepdims=True)
    row_norms = np.where(row_norms > 0, row_norms, 1.0)
    state_tf_weights_norm = state_tf_weights / row_norms

    # Get TF names
    tf_names = tfs  # List of TF names

    # Plotting with FDR correction: one subplot per state
    n_states = len(states)
    ncols = 3  # Adjust as needed
    nrows = int(np.ceil(n_states / ncols))

    subplot_width = 3  # inches
    subplot_height = 3  # inches
    fig, axes = plt.subplots(nrows, ncols, figsize=(subplot_width * ncols, subplot_height * nrows))
    axes = axes.flatten()

    top_n = 10  # Number of top TFs to highlight per State
    effect_size_cutoff = 0.1  # Adjust as needed
    p_value_cutoff = 0.05  # Adjust as needed

    # Calculate the -log10 of the p-value cutoff for plotting the horizontal line
    neg_log_p_value_cutoff = -np.log10(p_value_cutoff)

    # Smallest positive float; used to keep -log10 finite if a p-value
    # underflows to exactly zero.
    smallest_p = np.finfo(float).tiny

    for i, state in enumerate(states):
        state_weights = state_tf_weights_norm[i]
        tf_weights = state_weights  # All TF weights for the state
        tf_names_state = tf_names  # All TF names

        # Compute p-values: each weight is scaled by std/sqrt(n) of the
        # state's weights and compared against a t distribution.
        # NOTE(review): this uses the spread across TFs as the noise estimate
        # for every individual TF — confirm this is the intended test.
        std_dev = np.std(state_weights)
        if not std_dev > 0:
            std_dev = 1
        df = len(state_weights) - 1
        t_stats = tf_weights / (std_dev / np.sqrt(len(state_weights)))
        # Survival function instead of 1 - cdf: for large |t| the cdf rounds
        # to exactly 1.0, which would turn the p-value into 0 and the plotted
        # -log10(p) into inf (the point then disappears from the plot).
        p_values = 2 * scipy.stats.t.sf(np.abs(t_stats), df)

        # Apply FDR correction
        corrected_p_values = multipletests(p_values, method='fdr_bh')[1]

        # Create a volcano plot
        ax = axes[i]
        effect_sizes = tf_weights  # Effect sizes on x-axis
        log_p_values = -np.log10(np.clip(corrected_p_values, smallest_p, None))  # -log10(p-values) on y-axis

        # Scatter plot
        ax.scatter(effect_sizes, log_p_values, color='grey', alpha=0.7)

        # Determine which TFs to highlight
        significant = (np.abs(effect_sizes) >= effect_size_cutoff) & (corrected_p_values <= p_value_cutoff)
        top_tf_indices = np.argsort(-np.abs(effect_sizes))[:top_n]
        highlight_indices = np.where(significant)[0]
        # Combine top N and significant indices
        highlight_indices = np.unique(np.concatenate([top_tf_indices, highlight_indices]))

        # Highlight and annotate the selected TFs
        ax.scatter(effect_sizes[highlight_indices], log_p_values[highlight_indices], color='red', alpha=0.8)
        for idx in highlight_indices:
            tf_name = tf_names_state[idx]
            x = effect_sizes[idx]
            y = log_p_values[idx]
            ax.annotate(tf_name, (x, y), textcoords="offset points", xytext=(0,5), ha='center', fontsize=8)

        # Add the horizontal dotted red line at the significance cutoff
        ax.axhline(y=neg_log_p_value_cutoff, color='red', linestyle='--', linewidth=1)
        ax.text(ax.get_xlim()[1], neg_log_p_value_cutoff, f'p = {p_value_cutoff}', color='red', va='bottom', ha='right', fontsize=8)

        # Set titles and labels
        ax.set_title(f"State {state} in Dataset {dataset}")
        ax.set_xlabel("Effect Size (Normalized Weight)")
        ax.set_ylabel("-log10(p-value)")

    # Remove any empty subplots; indexed from n_states so this also works
    # when `states` is empty and the loop variable was never bound.
    for j in range(n_states, len(axes)):
        fig.delaxes(axes[j])

    plt.tight_layout()

    # Save the figure as SVG to desktop
    desktop_path = '/Users/brendamelano/Desktop/'
    output_filename = f'{desktop_path}{dataset}_state_tf_volcano_plot.svg'
    plt.savefig(output_filename, format='svg')
    print(f"Plot saved to {output_filename}")

    plt.close(fig)  # Close the figure to prevent overlap in the next iteration
