# Computing on X: Highly Variable Genes

*Goal:* demonstrate larger-than-core computation on X.

This demo finds highly variable genes in a user-specified cell selection. It is similar to the [scanpy.pp.highly_variable_genes](https://scanpy.readthedocs.io/en/stable/generated/scanpy.pp.highly_variable_genes.html) function when called with `flavor='seurat_v3'`.

*NOTE*: when query results are small, it may be easier to use the SOMA `ExperimentAxisQuery` class to extract an AnnData, and then just compute over that. This notebook shows a means of incrementally processing larger-than-core (RAM) data, using incremental (online) algorithms.

In [1]:
import cell_census
import tiledbsoma as soma

# Open the census via cell_census.open_soma() and keep a handle on the
# "homo_sapiens" entry under "census_data"; `human` is what the query cell
# further down calls axis_query() on.
census = cell_census.open_soma()
human = census["census_data"]["homo_sapiens"]

In [2]:
import numpy as np
import pandas as pd
import tiledbsoma as soma

from cell_census.compute import OnlineMatrixMeanVariance


def highly_variable_genes(query: soma.ExperimentAxisQuery, n_top_genes: int = 10) -> pd.DataFrame:
    """
    Flag the ``n_top_genes`` most variable genes among the cells selected by ``query``.

    The computation streams ``query.X("raw")`` table by table, twice, so the
    full matrix never has to be held in memory:

    1. Accumulate the per-gene mean and variance with ``OnlineMatrixMeanVariance``.
    2. Fit a degree-2 loess curve of log10(variance) against log10(mean) to get
       a regularized standard deviation per gene, clip each value at
       ``reg_std * sqrt(n_obs) + mean`` and accumulate the clipped sums needed
       for the normalized variance.

    Args:
        query: Axis query defining the cells (obs) and genes (var) to process.
        n_top_genes: Number of genes to flag as highly variable.

    Returns:
        A DataFrame indexed by ``soma_joinid`` (the query's var join ids) with
        the columns ``means``, ``variances``, ``highly_variable_rank`` (NaN for
        genes ranked outside the top ``n_top_genes``), ``variances_norm`` and
        the boolean ``highly_variable``.

    Raises:
        ImportError: If ``skmisc`` (scikit-misc) is not installed.

    Acknowledgements: scanpy highly variable genes implementation, github.com/scverse/scanpy
    """
    try:
        import skmisc.loess
    except ImportError as err:
        # Chain the original failure so the underlying cause stays visible.
        raise ImportError("Please install skmisc package via `pip install --user scikit-misc`") from err

    # First pass: online per-gene mean/variance over the streamed tables.
    indexer = query.indexer
    mvn = OnlineMatrixMeanVariance(query.n_obs, query.n_vars)
    for arrow_tbl in query.X("raw").tables():
        # Translate soma_dim_1 values into positions along the query's var axis.
        var_dim = indexer.by_var(arrow_tbl["soma_dim_1"])
        data = arrow_tbl["soma_data"].to_numpy()
        mvn.update(var_dim, data)

    u, v = mvn.finalize()
    var_df = pd.DataFrame(
        index=pd.Index(data=query.var_joinids(), name="soma_joinid"),
        data={
            "means": u,
            "variances": v,
        },
    )

    # Loess fit of log10(variance) on log10(mean), restricted to genes with
    # non-zero variance (log10 is undefined at 0). Constant genes keep an
    # estimate of 0, i.e. a regularized std of sqrt(10**0) == 1.
    estimated_variances = np.zeros((len(var_df),), dtype=np.float64)
    not_const = v > 0
    y = np.log10(v[not_const])
    x = np.log10(u[not_const])
    model = skmisc.loess.loess(x, y, span=0.3, degree=2)
    model.fit()
    estimated_variances[not_const] = model.outputs.fitted_values
    reg_std = np.sqrt(10**estimated_variances)

    # A second pass over the data is required because the clip value
    # is determined by the first pass
    N = query.n_obs
    vmax = np.sqrt(N)
    clip_val = reg_std * vmax + u
    counts_sum = np.zeros((query.n_vars,), dtype=np.float64)  # clipped
    squared_counts_sum = np.zeros((query.n_vars,), dtype=np.float64)  # clipped
    for arrow_tbl in query.X("raw").tables():
        var_dim = indexer.by_var(arrow_tbl["soma_dim_1"])
        data = arrow_tbl["soma_data"].to_numpy()
        # Clip into a new array instead of assigning back into `data`: the
        # float64 limits are then not rounded/truncated to the stored dtype,
        # and the squares below are computed in float64 like the accumulators.
        limits = clip_val[var_dim]
        clipped = np.where(data > limits, limits, data)
        # np.add.at accumulates correctly when a gene index repeats in a table.
        np.add.at(counts_sum, var_dim, clipped)
        np.add.at(squared_counts_sum, var_dim, np.square(clipped))

    # Variance of the clipped values around the unclipped mean, scaled by the
    # regularized variance.
    norm_gene_vars = (1 / ((N - 1) * np.square(reg_std))) * (
        (N * np.square(u)) + squared_counts_sum - 2 * counts_sum * u
    )
    # Shape (1, n_vars): a single "batch" row, so the per-batch ranking logic
    # below applies unchanged.
    norm_gene_vars = norm_gene_vars.reshape(1, -1)

    # argsort twice gives ranks, small rank means most variable
    ranked_norm_gene_vars = np.argsort(np.argsort(-norm_gene_vars, axis=1), axis=1)

    # this is done in SelectIntegrationFeatures() in Seurat v3
    ranked_norm_gene_vars = ranked_norm_gene_vars.astype(np.float32)
    num_batches_high_var = np.sum((ranked_norm_gene_vars < n_top_genes).astype(int), axis=0)
    # Ranks outside the top n_top_genes become NaN so the masked median ignores them.
    ranked_norm_gene_vars[ranked_norm_gene_vars >= n_top_genes] = np.nan
    ma_ranked = np.ma.masked_invalid(ranked_norm_gene_vars)  # type: ignore
    median_ranked = np.ma.median(ma_ranked, axis=0).filled(np.nan)  # type: ignore

    var_df = var_df.assign(
        highly_variable_nbatches=pd.Series(num_batches_high_var, index=var_df.index),
        highly_variable_rank=pd.Series(median_ranked, index=var_df.index),
        variances_norm=pd.Series(np.mean(norm_gene_vars, axis=0), index=var_df.index),
    )

    # Order genes by rank (ascending, NaN last), breaking ties by the number
    # of batches in which the gene was highly variable (descending).
    sorted_index = (
        var_df[["highly_variable_rank", "highly_variable_nbatches"]]
        .sort_values(
            ["highly_variable_rank", "highly_variable_nbatches"],
            ascending=[True, False],
            na_position="last",
        )
        .index
    )
    var_df["highly_variable"] = False
    # highly_variable_nbatches is only a sort key and is not part of the result.
    var_df = var_df.drop(columns=["highly_variable_nbatches"])
    var_df.loc[sorted_index[: int(n_top_genes)], "highly_variable"] = True
    return var_df

To use this function, which is also available in `cell_census.compute`, open an ExperimentAxisQuery and pass it to the function as a parameter.

In [3]:
# Build an axis query on the "RNA" measurement, selecting obs rows with the
# value filter tissue == 'brain'. The query is used as a context manager and
# the highly variable genes are computed while it is open.
with human.axis_query(
    measurement_name="RNA",
    obs_query=soma.AxisQuery(value_filter="tissue == 'brain'"),
) as query:
    hvg = highly_variable_genes(query)

# Show only the rows flagged as highly variable.
display(hvg[hvg.highly_variable])

Unnamed: 0_level_0,means,variances,highly_variable_rank,variances_norm,highly_variable
soma_joinid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
298,1.469808,27990.14,9.0,19.943684,True
3088,0.883362,34831.03,6.0,20.313552,True
3850,3.334319,266168.5,7.0,20.188258,True
4041,0.545531,5397.146,8.0,20.102508,True
10858,9.272079,1508529.0,1.0,25.324218,True
15568,5.726879,418069.3,0.0,26.027483,True
22410,2.593248,64117.81,5.0,20.669148,True
40604,5.558539,321908.0,2.0,24.388457,True
50271,6.475617,700085.6,4.0,21.696996,True
55861,8.855432,933523.1,3.0,23.784307,True
