In [2]:
import numpy as np
import pandas as pd
import scipy.sparse as sp_sparse
from anndata import AnnData
from typing import Optional, Union
import scipy
from scipy import sparse
from dynamo.configuration import DynamoAdataKeyManager, DKM
from dynamo import main_info
from scipy.special import xlogy

def check_is_count_data(X: Union[np.ndarray, sparse.spmatrix]):
    """Return True when every inspected value of ``X`` is a non-negative integer.

    For a NumPy array all entries are inspected; for any other input only its
    ``.data`` attribute is inspected (for a sparse matrix, the stored entries).

    Parameters
    ----------
    X
        Dense array or sparse matrix to check.

    Returns
    -------
    bool
        False if any inspected value has its sign bit set or has a fractional
        part, True otherwise.
    """
    from numbers import Integral

    values = X.data if not isinstance(X, np.ndarray) else X

    # Any value with the sign bit set disqualifies the data.
    if np.signbit(values).any():
        return False

    # Integer dtypes cannot hold fractional parts, so no further check is needed.
    if issubclass(values.dtype.type, Integral):
        return True

    # Floating dtype: every value must be a whole number.
    has_fraction = np.any(np.mod(values, 1) != 0)
    return not has_fraction

def deviance_residuals(x, theta, mu=None):
    """Compute deviance residuals for a negative binomial model with fixed ``theta``.

    Parameters
    ----------
    x
        Array of counts. When ``mu`` is None it must be 2-D, because the
        expected counts are then built from its row and column sums.
    theta
        Overdispersion parameter of the negative binomial model. An infinite
        value selects the Poisson formula instead.
    mu
        Expected counts, broadcastable against ``x``. If None, it is computed
        as (row sums) @ (column sums) / (grand total) of ``x``.

    Returns
    -------
    The signed square roots of the deviance terms: ``sign(x - mu) * sqrt(term)``.
    Terms that come out negative are set to 0 before the square root, and a
    message reporting how many were affected is printed.
    """
    if mu is None:
        row_totals = np.sum(x, axis=1, keepdims=True)
        col_totals = np.sum(x, axis=0, keepdims=True)
        mu = row_totals @ col_totals / np.sum(x)

    # xlogy(x, x / mu) evaluates x * log(x / mu) and returns 0 where x == 0.
    log_term = xlogy(x, x / mu)
    if np.isinf(theta):  # Poisson
        sqrt_term = 2 * (log_term - (x - mu))
    else:  # negative binomial
        x_plus_theta = x + theta
        sqrt_term = 2 * (log_term - x_plus_theta * np.log(x_plus_theta / (mu + theta)))

    # Negative terms would turn into NaN under the square root, so zero them.
    # np.where (rather than in-place masked assignment) also works when the
    # term is a NumPy scalar.
    negative = sqrt_term < 0
    n_negative = int(np.count_nonzero(negative))
    if n_negative:
        # The value is scaled by 100 so that it matches the '%' in the message.
        print('Setting %u negative sqrt term values to 0 (%f%%)' % (n_negative, 100 * n_negative / np.size(sqrt_term)))
        sqrt_term = np.where(negative, 0, sqrt_term)

    return np.sign(x - mu) * np.sqrt(sqrt_term)


def highly_variable_residuals(
    adata: AnnData,
    layer: Optional[str] = None,
    n_top_genes: int = 1000,
    batch_key: Optional[str] = None,
    theta: float = 100,
    clip: Optional[float] = None,
    chunksize: int = 100,
    check_values: bool = True,
    subset: bool = False,
    inplace: bool = True,
    residual_type: str = 'pearson',
    debug=False,
) -> Optional[pd.DataFrame]:
    """\
    Rank genes by the variance of their residuals and flag the top ones.

    See `highly_variable_genes`.

    Parameters
    ----------
    adata
        Annotated data matrix.
    layer
        Layer handed to `DKM.select_layer_data`. Defaults to `DKM.X_LAYER`.
    n_top_genes
        Number of genes to flag as highly variable.
    batch_key
        Column of `adata.obs` holding batch labels. If given, residual
        variances are computed per batch and merged; otherwise all cells
        form a single batch.
    theta
        Overdispersion parameter. Pearson residuals use the variance
        `mu + mu ** 2 / theta`; deviance residuals pass it on to
        `deviance_residuals`.
    clip
        Pearson residuals are clipped to `[-clip, clip]`. If None, the square
        root of the number of cells in each batch is used for that batch.
    chunksize
        Number of genes densified and processed at a time.
    check_values
        If True, warn when the selected data is not count data.
    subset
        If True, subset `adata` in place to the highly variable genes.
    inplace
        If True, write the metrics to `adata.var`.
    residual_type
        Either `'pearson'` or `'deviance'`.
    debug
        Unused.

    Returns
    -------
    Depending on `inplace` returns calculated metrics (:class:`~pd.DataFrame`)
    or updates `.var` with the following fields (and returns None):

    highly_variable
        boolean indicator of highly-variable genes.
    means
        means per gene.
    variances
        variances per gene.
    residual_variances
        Residual variance per gene. Averaged in the case of multiple
        batches.
    highly_variable_rank
        Rank of the gene according to residual variance, median rank in the
        case of multiple batches.
    highly_variable_nbatches : int
        If batch_key is given, this denotes in how many batches genes are
        detected as HVG.
    highly_variable_intersection : bool
        If batch_key is given, this denotes the genes that are highly variable
        in all batches.

    Raises
    ------
    ValueError
        If `clip` is negative or `residual_type` is not recognised.

    Notes
    -----
    `dyn.preprocessing.filter_genes_by_outliers` is run on `adata`, after
    which `adata.var["pass_basic_filter"]` is read, so that column is
    presumably (re)written as a side effect. When `inplace` or `subset` is
    set, `adata.uns['hvg']` is written as well.
    """
    # Local imports keep the function self-contained: `warnings` and `dyn`
    # are not among the imports at the top of this file.
    import warnings

    import dynamo as dyn

    if residual_type not in ('pearson', 'deviance'):
        raise ValueError(
            f"`residual_type` must be 'pearson' or 'deviance', got {residual_type!r}."
        )
    if clip is not None and clip < 0:
        raise ValueError("Pearson residuals require `clip>=0` or `clip=None`.")

    if layer is None:
        layer = DKM.X_LAYER
    X = DKM.select_layer_data(adata, layer)
    computed_on = layer if layer else 'adata.X'

    # Check for raw counts
    if check_values and not check_is_count_data(X):
        warnings.warn(
            "`flavor='pearson_residuals'` expects raw count data, but non-integers were found.",
            UserWarning,
        )

    if batch_key is None:
        batch_info = np.zeros(adata.shape[0], dtype=int)
    else:
        batch_info = adata.obs[batch_key].values
    batches = np.unique(batch_info)
    n_batches = len(batches)

    # The basic gene filter works on the whole dataset, so it does not depend
    # on the batch and only needs to run once.
    dyn.preprocessing.filter_genes_by_outliers(adata, min_count_s=1)
    passes_basic_filter = np.asarray(adata.var["pass_basic_filter"], dtype=bool)

    # Get residual variances for each batch separately
    residual_gene_vars = []
    for batch in batches:
        adata_batch = adata[batch_info == batch]

        # A gene without any counts in this batch has mu == 0 everywhere,
        # which would make its residuals NaN; drop it together with the genes
        # that fail the basic filter.
        gene_totals = np.asarray(
            DKM.select_layer_data(adata_batch, layer).sum(axis=0)
        ).ravel()
        nonzero_genes = passes_basic_filter & (gene_totals > 0)
        X_batch = DKM.select_layer_data(adata_batch[:, nonzero_genes], layer)

        # The default clip depends on the size of *this* batch, so it must not
        # overwrite the `clip` argument that later batches still consult.
        batch_clip = np.sqrt(X_batch.shape[0]) if clip is None else clip

        # Marginal sums as float ndarrays shaped for the outer product below;
        # np.asarray unifies the np.matrix (sparse) and ndarray (dense) cases.
        sums_genes = np.asarray(X_batch.sum(axis=0), dtype=np.float64).reshape(1, -1)
        sums_cells = np.asarray(X_batch.sum(axis=1), dtype=np.float64).reshape(-1, 1)
        sum_total = sums_genes.sum()

        # Compute residuals in chunks of genes to bound the dense memory use
        residual_gene_var = np.empty(X_batch.shape[1])
        for start in range(0, X_batch.shape[1], chunksize):
            stop = start + chunksize

            chunk = X_batch[:, start:stop]
            X_dense = chunk.toarray() if sp_sparse.issparse(chunk) else np.asarray(chunk)
            mu = sums_cells @ sums_genes[:, start:stop] / sum_total
            if residual_type == 'pearson':
                residuals = (X_dense - mu) / np.sqrt(mu + mu ** 2 / theta)
                residuals = np.clip(residuals, a_min=-batch_clip, a_max=batch_clip)
            else:
                residuals = deviance_residuals(X_dense, theta, mu)
            residual_gene_var[start:stop] = np.var(residuals, axis=0)

        # Scatter back to the full gene order; filtered-out genes get 0
        full_gene_var = np.zeros(adata.shape[1])
        full_gene_var[nonzero_genes] = residual_gene_var
        residual_gene_vars.append(full_gene_var.reshape(1, -1))

    residual_gene_vars = np.concatenate(residual_gene_vars, axis=0)

    # Get cutoffs and define hvgs per batch. The cutoff index is clamped so
    # that asking for more genes than exist selects all of them.
    n_top = min(n_top_genes, residual_gene_vars.shape[1])
    residual_gene_vars_sorted = np.sort(residual_gene_vars, axis=1)
    cutoffs_per_batch = residual_gene_vars_sorted[:, -n_top]
    highly_variable_per_batch = np.greater_equal(
        residual_gene_vars.T, cutoffs_per_batch
    ).T

    # Merge hvgs across batches
    highly_variable_nbatches = np.sum(highly_variable_per_batch, axis=0)
    highly_variable_intersection = highly_variable_nbatches == n_batches

    # Get rank per gene within each batch
    # argsort twice gives ranks, small rank means most variable
    ranks_residual_var = np.argsort(np.argsort(-residual_gene_vars, axis=1), axis=1)
    ranks_residual_var = ranks_residual_var.astype(np.float32)
    ranks_residual_var[ranks_residual_var >= n_top_genes] = np.nan
    ranks_masked_array = np.ma.masked_invalid(ranks_residual_var)
    # Median rank across batches,
    # ignoring batches in which gene was not selected
    medianrank_residual_var = np.ma.median(ranks_masked_array, axis=0).filled(np.nan)

    # Summary statistics on the same data the residuals were computed from
    means, variances, _ = dyn.preprocessing.preprocessor_utils.calc_mean_var_dispersion_general_mat(X)

    df = pd.DataFrame.from_dict(
        dict(
            means=means,
            variances=variances,
            residual_variances=np.mean(residual_gene_vars, axis=0),
            highly_variable_rank=medianrank_residual_var,
            highly_variable_nbatches=highly_variable_nbatches,
            highly_variable_intersection=highly_variable_intersection,
        )
    )
    df = df.set_index(adata.var_names)

    # Sort genes by how often they selected as hvg within each batch and
    # break ties with median rank of residual variance across batches
    df = df.sort_values(
        ['highly_variable_nbatches', 'highly_variable_rank'],
        ascending=[False, True],
        na_position='last',
    )
    df['highly_variable'] = False
    # Single positional assignment: the chained `df.col.iloc[...] = ...` form
    # warns and does not write through under pandas copy-on-write.
    df.iloc[:n_top_genes, df.columns.get_loc('highly_variable')] = True
    # Restore the original gene order
    df = df.loc[adata.var_names]

    if inplace or subset:
        adata.uns['hvg'] = {'flavor': 'pearson_residuals', 'computed_on': computed_on}
        main_info(
            'added\n'
            '    \'highly_variable\', boolean vector (adata.var)\n'
            '    \'highly_variable_rank\', float vector (adata.var)\n'
            '    \'highly_variable_nbatches\', int vector (adata.var)\n'
            '    \'highly_variable_intersection\', boolean vector (adata.var)\n'
            '    \'means\', float vector (adata.var)\n'
            '    \'variances\', float vector (adata.var)\n'
            '    \'residual_variances\', float vector (adata.var)'
        )
        adata.var['highly_variable'] = df['highly_variable'].values
        adata.var['highly_variable_rank'] = df['highly_variable_rank'].values
        adata.var['means'] = df['means'].values
        adata.var['variances'] = df['variances'].values
        adata.var['residual_variances'] = df['residual_variances'].values.astype(
            'float64', copy=False
        )
        if batch_key is not None:
            adata.var['highly_variable_nbatches'] = df[
                'highly_variable_nbatches'
            ].values
            adata.var['highly_variable_intersection'] = df[
                'highly_variable_intersection'
            ].values
        if subset:
            adata._inplace_subset_var(df['highly_variable'].values)
    else:
        if batch_key is None:
            df = df.drop(
                ['highly_variable_nbatches', 'highly_variable_intersection'], axis=1
            )
        return df


# Driver cell: run highly_variable_residuals on dynamo's zebrafish sample data.
import warnings
# Silences every warning from here on, including the UserWarning that
# highly_variable_residuals issues when the data does not look like counts.
warnings.filterwarnings('ignore')

# `dyn` is needed for the sample-data loader below. DKM and np were already
# imported at the top of the file, so re-importing them here is harmless.
import dynamo as dyn 
from dynamo.configuration import DKM
import numpy as np
# Load the sample dataset (the captured output below shows it being
# downloaded to ./data/zebrafish.h5ad).
adata = dyn.sample_data.zebrafish()
# All defaults: inplace=True, so the results are written to adata.var and the
# call itself returns None.
highly_variable_residuals(adata)

|-----> Downloading data to ./data/zebrafish.h5ad
|-----> added
    'highly_variable', boolean vector (adata.var)
    'highly_variable_rank', float vector (adata.var)
    'highly_variable_nbatches', int vector (adata.var)
    'highly_variable_intersection', boolean vector (adata.var)
    'means', float vector (adata.var)
    'variances', float vector (adata.var)
    'residual_variances', float vector (adata.var)
