forked from scverse/scanpy
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Moved data accesors to new module
sc.get
Based on discussion from in: scverse#562
- Loading branch information
Showing
5 changed files
with
285 additions
and
278 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
"""This module contains helper functions for accessing data.""" | ||
from typing import Optional, Iterable, Tuple | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from scipy import sparse | ||
|
||
from anndata import AnnData | ||
# -------------------------------------------------------------------------------- | ||
# Plotting data helpers | ||
# -------------------------------------------------------------------------------- | ||
|
||
|
||
# TODO: implement diffxpy method, make singledispatch | ||
def rank_genes_groups_df( | ||
adata: AnnData, | ||
group: str, # Can this be something other than a str? | ||
*, | ||
key: str = "rank_genes_groups", | ||
pval_cutoff: Optional[float] = None, | ||
log2fc_min: Optional[float] = None, | ||
log2fc_max: Optional[float] = None, | ||
gene_symbols: Optional[str] = None | ||
) -> pd.DataFrame: | ||
""" | ||
Get `rank_genes_groups` results in the form of a :class:`pd.DataFrame`. | ||
Params | ||
------ | ||
adata | ||
Object to get results from. | ||
group | ||
Which group (key from :func:`scanpy.tl.rank_genes_groups` `groupby`) to | ||
return results from. | ||
key | ||
Key differential expression groups were stored under. | ||
pval_cutoff | ||
Minimum adjusted pval to return. | ||
log2fc_min | ||
Minumum logfc to return. | ||
log2fc_max | ||
Maximum logfc to return. | ||
gene_symbols | ||
Column name in `.var` DataFrame that stores gene symbols. Specifying | ||
this will add that column to the returned dataframe. | ||
Example | ||
------- | ||
>>> pbmc = sc.datasets.pbmc68k_reduced() | ||
>>> sc.tl.rank_genes_groups(pbmc, groupby="louvain", use_raw=True, n_genes=pbmc.shape[1]) | ||
>>> dedf = sc.get.rank_genes_groups_df(pbmc, group="0") | ||
""" | ||
d = pd.DataFrame() | ||
for k in ['scores', 'names', 'logfoldchanges', 'pvals', 'pvals_adj']: | ||
d[k] = adata.uns["rank_genes_groups"][k][group] | ||
if pval_cutoff is not None: | ||
d = d[d["pvals_adj"] < pval_cutoff] | ||
if log2fc_min is not None: | ||
d = d[d["logfoldchanges"] > log2fc_min] | ||
if log2fc_max is not None: | ||
d = d[d["logfoldchanges"] < log2fc_max] | ||
if gene_symbols is not None: | ||
d = d.join(adata.var[gene_symbols], on="names") | ||
return d | ||
|
||
|
||
def obs_df( | ||
adata: AnnData, | ||
keys: Iterable[str] = (), | ||
obsm_keys: Iterable[Tuple[str, int]] = (), | ||
*, | ||
layer: str = None, | ||
gene_symbols: str = None, | ||
) -> pd.DataFrame: | ||
"""\ | ||
Return values for observations in adata. | ||
Params | ||
------ | ||
adata | ||
AnnData object to get values from. | ||
keys | ||
Keys from either `.var_names`, `.var[gene_symbols]`, or `.obs.columns`. | ||
obsm_keys | ||
Tuple of ``(key from obsm, column index of obsm[key])`. | ||
layer | ||
Layer of `adata` to use as expression values. | ||
gene_symbols | ||
Column of `adata.var` to search for `keys` in. | ||
Returns | ||
------- | ||
A dataframe with `adata.obs_names` as index, and values specified by `keys` | ||
and `obsm_keys`. | ||
Examples | ||
-------- | ||
Getting value for plotting: | ||
>>> pbmc = sc.datasets.pbmc68k_reduced() | ||
>>> plotdf = sc.get.obs_df( | ||
pbmc, | ||
keys=["CD8B", "n_genes"], | ||
obsm_keys=[("X_umap", 0), ("X_umap", 1)] | ||
) | ||
>>> plotdf.plot.scatter("X_umap0", "X_umap1", c="CD8B") | ||
Calculating mean expression for marker genes by cluster: | ||
>>> pbmc = sc.datasets.pbmc68k_reduced() | ||
>>> marker_genes = ['CD79A', 'MS4A1', 'CD8A', 'CD8B', 'LYZ'] | ||
>>> genedf = sc.get.obs_df( | ||
pbmc, | ||
keys=["louvain", *marker_genes] | ||
) | ||
>>> grouped = genedf.groupby("louvain") | ||
>>> mean, var = grouped.mean(), grouped.var() | ||
""" | ||
# Argument handling | ||
if gene_symbols is not None: | ||
gene_names = pd.Series(adata.var_names, index=adata.var[gene_symbols]) | ||
else: | ||
gene_names = pd.Series(adata.var_names, index=adata.var_names) | ||
lookup_keys = [] | ||
not_found = [] | ||
for key in keys: | ||
if key in adata.obs.columns: | ||
lookup_keys.append(key) | ||
elif key in gene_names.index: | ||
lookup_keys.append(gene_names[key]) | ||
else: | ||
not_found.append(key) | ||
if len(not_found) > 0: | ||
if gene_symbols is None: | ||
gene_error = "`adata.var_names`" | ||
else: | ||
gene_error = "gene_symbols column `adata.var[{}].values`".format(gene_symbols) | ||
raise KeyError( | ||
f"Could not find keys '{not_found}' in columns of `adata.obs` or in" | ||
f" {gene_error}." | ||
) | ||
|
||
# Make df | ||
df = pd.DataFrame(index=adata.obs_names) | ||
for k, l in zip(keys, lookup_keys): | ||
df[k] = adata.obs_vector(l, layer=layer) | ||
for k, idx in obsm_keys: | ||
added_k = f"{k}-{idx}" | ||
if isinstance(adata.obsm[k], (np.ndarray, sparse.csr_matrix)): | ||
df[added_k] = np.ravel(adata.obsm[k][:, idx]) | ||
elif isinstance(adata.obsm[k], pd.DataFrame): | ||
df[added_k] = adata.obsm[k].loc[:, idx] | ||
return df | ||
|
||
|
||
def var_df( | ||
adata: AnnData, | ||
keys: Iterable[str] = (), | ||
varm_keys: Iterable[Tuple[str, int]] = (), | ||
*, | ||
layer: str = None, | ||
) -> pd.DataFrame: | ||
"""\ | ||
Return values for observations in adata. | ||
Params | ||
------ | ||
adata | ||
AnnData object to get values from. | ||
keys | ||
Keys from either `.obs_names`, or `.var.columns`. | ||
varm_keys | ||
Tuple of ``(key from varm, column index of varm[key])`. | ||
layer | ||
Layer of `adata` to use as expression values. | ||
Returns | ||
------- | ||
A dataframe with `adata.var_names` as index, and values specified by `keys` | ||
and `varm_keys`. | ||
""" | ||
# Argument handling | ||
lookup_keys = [] | ||
not_found = [] | ||
for key in keys: | ||
if key in adata.var.columns: | ||
lookup_keys.append(key) | ||
elif key in adata.obs_names: | ||
lookup_keys.append(key) | ||
else: | ||
not_found.append(key) | ||
if len(not_found) > 0: | ||
raise KeyError( | ||
f"Could not find keys '{not_found}' in columns of `adata.var` or" | ||
" in `adata.obs_names`." | ||
) | ||
|
||
# Make df | ||
df = pd.DataFrame(index=adata.var_names) | ||
for k, l in zip(keys, lookup_keys): | ||
df[k] = adata.var_vector(l, layer=layer) | ||
for k, idx in varm_keys: | ||
added_k = f"{k}-{idx}" | ||
if isinstance(adata.varm[k], (np.ndarray, sparse.csr_matrix)): | ||
df[added_k] = np.ravel(adata.varm[k][:, idx]) | ||
elif isinstance(adata.varm[k], pd.DataFrame): | ||
df[added_k] = adata.varm[k].loc[:, idx] | ||
return df |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from itertools import repeat, chain | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
import scanpy as sc | ||
|
||
|
||
def test_obs_df(): | ||
adata = sc.AnnData( | ||
X=np.ones((2, 2)), | ||
obs=pd.DataFrame({"obs1": [0, 1], "obs2": ["a", "b"]}, index=["cell1", "cell2"]), | ||
var=pd.DataFrame({"gene_symbols": ["genesymbol1", "genesymbol2"]}, index=["gene1", "gene2"]), | ||
obsm={"eye": np.eye(2)}, | ||
layers={"double": np.ones((2, 2)) * 2} | ||
) | ||
assert np.all(np.equal( | ||
sc.get.obs_df(adata, keys=["gene2", "obs1"], obsm_keys=[("eye", 0)]), | ||
pd.DataFrame({"gene2": [1, 1], "obs1": [0, 1], "eye-0": [1, 0]}, index=adata.obs_names) | ||
)) | ||
assert np.all(np.equal( | ||
sc.get.obs_df(adata, keys=["genesymbol2", "obs1"], obsm_keys=[("eye", 0)], gene_symbols="gene_symbols"), | ||
pd.DataFrame({"genesymbol2": [1, 1], "obs1": [0, 1], "eye-0": [1, 0]}, index=adata.obs_names) | ||
)) | ||
assert np.all(np.equal( | ||
sc.get.obs_df(adata, keys=["gene2", "obs1"], layer="double"), | ||
pd.DataFrame({"gene2": [2, 2], "obs1": [0, 1]}, index=adata.obs_names) | ||
)) | ||
badkeys = ["badkey1", "badkey2"] | ||
with pytest.raises(KeyError) as badkey_err: | ||
sc.get.obs_df(adata, keys=badkeys) | ||
assert all(badkey_err.match(k) for k in badkeys) | ||
|
||
|
||
def test_var_df(): | ||
adata = sc.AnnData( | ||
X=np.ones((2, 2)), | ||
obs=pd.DataFrame({"obs1": [0, 1], "obs2": ["a", "b"]}, index=["cell1", "cell2"]), | ||
var=pd.DataFrame({"gene_symbols": ["genesymbol1", "genesymbol2"]}, index=["gene1", "gene2"]), | ||
varm={"eye": np.eye(2)}, | ||
layers={"double": np.ones((2, 2)) * 2} | ||
) | ||
assert np.all(np.equal( | ||
sc.get.var_df(adata, keys=["cell2", "gene_symbols"], varm_keys=[("eye", 0)]), | ||
pd.DataFrame({"cell2": [1, 1], "gene_symbols": ["genesymbol1", "genesymbol2"], "eye-0": [1, 0]}, index=adata.obs_names) | ||
)) | ||
assert np.all(np.equal( | ||
sc.get.var_df(adata, keys=["cell1", "gene_symbols"], layer="double"), | ||
pd.DataFrame({"cell1": [2, 2], "gene_symbols": ["genesymbol1", "genesymbol2"]}, index=adata.obs_names) | ||
)) | ||
badkeys = ["badkey1", "badkey2"] | ||
with pytest.raises(KeyError) as badkey_err: | ||
sc.get.var_df(adata, keys=badkeys) | ||
assert all(badkey_err.match(k) for k in badkeys) | ||
|
||
|
||
def test_rank_genes_groups_df(): | ||
a = np.zeros((20, 3)) | ||
a[:10, 0] = 5 | ||
adata = sc.AnnData( | ||
a, | ||
obs=pd.DataFrame( | ||
{"celltype": list(chain(repeat("a", 10), repeat("b", 10)))}, | ||
index=[f"cell{i}" for i in range(a.shape[0])] | ||
), | ||
var=pd.DataFrame(index=[f"gene{i}" for i in range(a.shape[1])]), | ||
) | ||
sc.tl.rank_genes_groups(adata, groupby="celltype", method="wilcoxon") | ||
dedf = sc.get.rank_genes_groups_df(adata, "a") | ||
assert dedf["pvals"].value_counts()[1.] == 2 | ||
assert sc.get.rank_genes_groups_df(adata, "a", log2fc_max=.1).shape[0] == 2 | ||
assert sc.get.rank_genes_groups_df(adata, "a", log2fc_min=.1).shape[0] == 1 | ||
assert sc.get.rank_genes_groups_df(adata, "a", pval_cutoff=.9).shape[0] == 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.