Skip to content

Commit

Permalink
Moved data accessors to new module sc.get
Browse files Browse the repository at this point in the history
Based on discussion in: scverse#562
  • Loading branch information
ivirshup committed Jun 24, 2019
1 parent 3246fa8 commit 9fa7672
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 276 deletions.
2 changes: 1 addition & 1 deletion scanpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from . import tools as tl
from . import preprocessing as pp
from . import plotting as pl
from . import datasets, logging, queries, external
from . import datasets, logging, queries, external, get

from anndata import AnnData
from anndata import read_h5ad, read_csv, read_excel, read_hdf, read_loom, read_mtx, read_text, read_umi_tools
Expand Down
208 changes: 208 additions & 0 deletions scanpy/get.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
"""This module contains helper functions for accessing data."""
from typing import Optional, Iterable, Tuple

import numpy as np
import pandas as pd
from scipy import sparse

from anndata import AnnData
# --------------------------------------------------------------------------------
# Plotting data helpers
# --------------------------------------------------------------------------------


# TODO: implement diffxpy method, make singledispatch
def rank_genes_groups_df(
    adata: AnnData,
    group: str,  # Can this be something other than a str?
    *,
    key: str = "rank_genes_groups",
    pval_cutoff: Optional[float] = None,
    log2fc_min: Optional[float] = None,
    log2fc_max: Optional[float] = None,
    gene_symbols: Optional[str] = None
) -> pd.DataFrame:
    """
    Get `rank_genes_groups` results in the form of a :class:`pd.DataFrame`.

    Params
    ------
    adata
        Object to get results from.
    group
        Which group (key from :func:`scanpy.tl.rank_genes_groups` `groupby`) to
        return results from.
    key
        Key in `adata.uns` that the differential expression results were
        stored under.
    pval_cutoff
        Maximum adjusted pval to return; only rows with `pvals_adj` strictly
        below this value are kept.
    log2fc_min
        Minimum logfc to return; only rows with `logfoldchanges` strictly
        above this value are kept.
    log2fc_max
        Maximum logfc to return; only rows with `logfoldchanges` strictly
        below this value are kept.
    gene_symbols
        Column name in `.var` DataFrame that stores gene symbols. Specifying
        this will add that column to the returned dataframe.

    Returns
    -------
    A dataframe with the columns `scores`, `names`, `logfoldchanges`, `pvals`
    and `pvals_adj` for `group` (plus the `gene_symbols` column if requested),
    restricted to the rows passing every cutoff that was given.

    Example
    -------
    >>> pbmc = sc.datasets.pbmc68k_reduced()
    >>> sc.tl.rank_genes_groups(pbmc, groupby="louvain", use_raw=True, n_genes=pbmc.shape[1])
    >>> dedf = sc.get.rank_genes_groups_df(pbmc, group="0")
    """
    # Honour `key` instead of always reading the default location.
    results = adata.uns[key]
    d = pd.DataFrame()
    for k in ['scores', 'names', 'logfoldchanges', 'pvals', 'pvals_adj']:
        d[k] = results[k][group]
    if pval_cutoff is not None:
        d = d[d["pvals_adj"] < pval_cutoff]
    if log2fc_min is not None:
        d = d[d["logfoldchanges"] > log2fc_min]
    if log2fc_max is not None:
        d = d[d["logfoldchanges"] < log2fc_max]
    if gene_symbols is not None:
        # Match each gene name against the index of `adata.var`.
        d = d.join(adata.var[gene_symbols], on="names")
    return d


def obs_df(
    adata: AnnData,
    keys: Iterable[str] = (),
    obsm_keys: Iterable[Tuple[str, int]] = (),
    *,
    layer: Optional[str] = None,
    gene_symbols: Optional[str] = None,
) -> pd.DataFrame:
    """\
    Return values for observations in adata.

    Params
    ------
    adata
        AnnData object to get values from.
    keys
        Keys from either `.var_names`, `.var[gene_symbols]`, or `.obs.columns`.
        A key found in `.obs.columns` takes precedence over a gene of the
        same name.
    obsm_keys
        Tuples of `(key from obsm, column index of obsm[key])`.
    layer
        Layer of `adata` to use as expression values.
    gene_symbols
        Column of `adata.var` to search for `keys` in.

    Returns
    -------
    A dataframe with `adata.obs_names` as index, and values specified by `keys`
    and `obsm_keys`. Columns taken from `obsm` are named `"{key}-{index}"`.

    Raises
    ------
    KeyError
        If any of `keys` is found neither in `adata.obs.columns` nor among the
        gene names being searched.

    Examples
    --------
    Getting value for plotting:

    >>> pbmc = sc.datasets.pbmc68k_reduced()
    >>> plotdf = sc.get.obs_df(
    ...     pbmc,
    ...     keys=["CD8B", "n_genes"],
    ...     obsm_keys=[("X_umap", 0), ("X_umap", 1)]
    ... )
    >>> plotdf.plot.scatter("X_umap-0", "X_umap-1", c="CD8B")

    Calculating mean expression for marker genes by cluster:

    >>> pbmc = sc.datasets.pbmc68k_reduced()
    >>> marker_genes = ['CD79A', 'MS4A1', 'CD8A', 'CD8B', 'LYZ']
    >>> genedf = sc.get.obs_df(
    ...     pbmc,
    ...     keys=["louvain", *marker_genes]
    ... )
    >>> grouped = genedf.groupby("louvain")
    >>> mean, var = grouped.mean(), grouped.var()
    """
    # `keys` is walked twice (validation, then column building); materialize
    # it so that one-shot iterables such as generators are not exhausted early.
    keys = list(keys)

    # Argument handling
    if gene_symbols is not None:
        # Map gene symbol -> var name.
        gene_names = pd.Series(adata.var_names, index=adata.var[gene_symbols])
    else:
        gene_names = pd.Series(adata.var_names, index=adata.var_names)
    lookup_keys = []
    not_found = []
    for key in keys:
        if key in adata.obs.columns:
            lookup_keys.append(key)
        elif key in gene_names.index:
            lookup_keys.append(gene_names[key])
        else:
            not_found.append(key)
    if len(not_found) > 0:
        if gene_symbols is None:
            gene_error = "`adata.var_names`"
        else:
            gene_error = "gene_symbols column `adata.var[{}].values`".format(gene_symbols)
        raise KeyError(
            f"Could not find keys '{not_found}' in columns of `adata.obs` or in"
            f" {gene_error}."
        )

    # Make df
    df = pd.DataFrame(index=adata.obs_names)
    for k, l in zip(keys, lookup_keys):
        # Columns are labelled with the key as requested, not the resolved name.
        df[k] = adata.obs_vector(l, layer=layer)
    for k, idx in obsm_keys:
        added_k = f"{k}-{idx}"
        val = adata.obsm[k]
        if isinstance(val, np.ndarray):
            df[added_k] = np.ravel(val[:, idx])
        elif sparse.issparse(val):
            # `np.ravel` does not densify sparse input, so convert the
            # selected column explicitly before flattening it.
            df[added_k] = np.ravel(val[:, idx].toarray())
        elif isinstance(val, pd.DataFrame):
            df[added_k] = val.loc[:, idx]
    return df


def var_df(
    adata: AnnData,
    keys: Iterable[str] = (),
    varm_keys: Iterable[Tuple[str, int]] = (),
    *,
    layer: Optional[str] = None,
) -> pd.DataFrame:
    """\
    Return values for variables in adata.

    Params
    ------
    adata
        AnnData object to get values from.
    keys
        Keys from either `.obs_names`, or `.var.columns`. A key found in
        `.var.columns` takes precedence over an observation of the same name.
    varm_keys
        Tuples of `(key from varm, column index of varm[key])`.
    layer
        Layer of `adata` to use as expression values.

    Returns
    -------
    A dataframe with `adata.var_names` as index, and values specified by `keys`
    and `varm_keys`. Columns taken from `varm` are named `"{key}-{index}"`.

    Raises
    ------
    KeyError
        If any of `keys` is found neither in `adata.var.columns` nor in
        `adata.obs_names`.
    """
    # `keys` is walked twice (validation, then column building); materialize
    # it so that one-shot iterables such as generators are not exhausted early.
    keys = list(keys)

    # Argument handling
    lookup_keys = []
    not_found = []
    for key in keys:
        if key in adata.var.columns:
            lookup_keys.append(key)
        elif key in adata.obs_names:
            lookup_keys.append(key)
        else:
            not_found.append(key)
    if len(not_found) > 0:
        raise KeyError(
            f"Could not find keys '{not_found}' in columns of `adata.var` or"
            " in `adata.obs_names`."
        )

    # Make df
    df = pd.DataFrame(index=adata.var_names)
    for k, l in zip(keys, lookup_keys):
        df[k] = adata.var_vector(l, layer=layer)
    for k, idx in varm_keys:
        added_k = f"{k}-{idx}"
        val = adata.varm[k]
        if isinstance(val, np.ndarray):
            df[added_k] = np.ravel(val[:, idx])
        elif sparse.issparse(val):
            # `np.ravel` does not densify sparse input, so convert the
            # selected column explicitly before flattening it.
            df[added_k] = np.ravel(val[:, idx].toarray())
        elif isinstance(val, pd.DataFrame):
            df[added_k] = val.loc[:, idx]
    return df
74 changes: 74 additions & 0 deletions scanpy/tests/test_get.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from itertools import repeat, chain

import numpy as np
import pandas as pd
import pytest

import scanpy as sc


def test_obs_df():
    """Exercise `sc.get.obs_df` with var names, gene symbols, a layer and bad keys."""
    obs = pd.DataFrame({"obs1": [0, 1], "obs2": ["a", "b"]}, index=["cell1", "cell2"])
    var = pd.DataFrame({"gene_symbols": ["genesymbol1", "genesymbol2"]}, index=["gene1", "gene2"])
    adata = sc.AnnData(
        X=np.ones((2, 2)),
        obs=obs,
        var=var,
        obsm={"eye": np.eye(2)},
        layers={"double": np.ones((2, 2)) * 2},
    )

    # Gene addressed by its var name, plus an obs column and one obsm column.
    result = sc.get.obs_df(adata, keys=["gene2", "obs1"], obsm_keys=[("eye", 0)])
    expected = pd.DataFrame(
        {"gene2": [1, 1], "obs1": [0, 1], "eye-0": [1, 0]}, index=adata.obs_names
    )
    assert np.all(np.equal(result, expected))

    # Same gene addressed through the `gene_symbols` column of `.var`.
    result = sc.get.obs_df(
        adata,
        keys=["genesymbol2", "obs1"],
        obsm_keys=[("eye", 0)],
        gene_symbols="gene_symbols",
    )
    expected = pd.DataFrame(
        {"genesymbol2": [1, 1], "obs1": [0, 1], "eye-0": [1, 0]}, index=adata.obs_names
    )
    assert np.all(np.equal(result, expected))

    # Expression values read from the "double" layer instead of X.
    result = sc.get.obs_df(adata, keys=["gene2", "obs1"], layer="double")
    expected = pd.DataFrame({"gene2": [2, 2], "obs1": [0, 1]}, index=adata.obs_names)
    assert np.all(np.equal(result, expected))

    # Every unknown key must be named in the error message.
    badkeys = ["badkey1", "badkey2"]
    with pytest.raises(KeyError) as badkey_err:
        sc.get.obs_df(adata, keys=badkeys)
    for missing in badkeys:
        assert badkey_err.match(missing)


def test_var_df():
    """Exercise `sc.get.var_df` with obs names, var columns, a layer and bad keys.

    `var_df` returns one row per variable, so the expected frames are indexed
    by `adata.var_names` (not `adata.obs_names`).
    """
    adata = sc.AnnData(
        X=np.ones((2, 2)),
        obs=pd.DataFrame({"obs1": [0, 1], "obs2": ["a", "b"]}, index=["cell1", "cell2"]),
        var=pd.DataFrame({"gene_symbols": ["genesymbol1", "genesymbol2"]}, index=["gene1", "gene2"]),
        varm={"eye": np.eye(2)},
        layers={"double": np.ones((2, 2)) * 2}
    )
    # An observation's values across variables, a var column and one varm column.
    assert np.all(np.equal(
        sc.get.var_df(adata, keys=["cell2", "gene_symbols"], varm_keys=[("eye", 0)]),
        pd.DataFrame({"cell2": [1, 1], "gene_symbols": ["genesymbol1", "genesymbol2"], "eye-0": [1, 0]}, index=adata.var_names)
    ))
    # Expression values read from the "double" layer instead of X.
    assert np.all(np.equal(
        sc.get.var_df(adata, keys=["cell1", "gene_symbols"], layer="double"),
        pd.DataFrame({"cell1": [2, 2], "gene_symbols": ["genesymbol1", "genesymbol2"]}, index=adata.var_names)
    ))
    # Every unknown key must be named in the error message.
    badkeys = ["badkey1", "badkey2"]
    with pytest.raises(KeyError) as badkey_err:
        sc.get.var_df(adata, keys=badkeys)
    assert all(badkey_err.match(k) for k in badkeys)


def test_rank_genes_groups_df():
    """Check the filtering options of `sc.get.rank_genes_groups_df` on toy data."""
    # 20 cells x 3 genes; only gene 0 is non-zero, and only in the first 10 cells.
    a = np.zeros((20, 3))
    a[:10, 0] = 5
    n_cells, n_genes = a.shape
    celltypes = ["a"] * 10 + ["b"] * 10
    obs = pd.DataFrame(
        {"celltype": celltypes},
        index=[f"cell{i}" for i in range(n_cells)],
    )
    var = pd.DataFrame(index=[f"gene{i}" for i in range(n_genes)])
    adata = sc.AnnData(a, obs=obs, var=var)
    sc.tl.rank_genes_groups(adata, groupby="celltype", method="wilcoxon")

    dedf = sc.get.rank_genes_groups_df(adata, "a")
    pval_counts = dedf["pvals"].value_counts()
    assert pval_counts[1.] == 2

    below_fc = sc.get.rank_genes_groups_df(adata, "a", log2fc_max=.1)
    assert below_fc.shape[0] == 2
    above_fc = sc.get.rank_genes_groups_df(adata, "a", log2fc_min=.1)
    assert above_fc.shape[0] == 1
    significant = sc.get.rank_genes_groups_df(adata, "a", pval_cutoff=.9)
    assert significant.shape[0] == 1
74 changes: 1 addition & 73 deletions scanpy/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from types import ModuleType
from itertools import repeat, chain

import numpy as np
import pandas as pd
import pytest

import scanpy as sc
from scanpy.utils import descend_classes_and_funcs, obs_df, var_df, rank_genes_groups_df
from scanpy.utils import descend_classes_and_funcs


def test_descend_classes_and_funcs():
Expand All @@ -24,70 +19,3 @@ def test_descend_classes_and_funcs():
a.b.a = a

assert {a.A, a.b.B} == set(descend_classes_and_funcs(a, 'a'))


def test_obs_df():
    """Exercise `obs_df` with var names, gene symbols, a layer and bad keys."""
    obs = pd.DataFrame({"obs1": [0, 1], "obs2": ["a", "b"]}, index=["cell1", "cell2"])
    var = pd.DataFrame({"gene_symbols": ["genesymbol1", "genesymbol2"]}, index=["gene1", "gene2"])
    adata = sc.AnnData(
        X=np.ones((2, 2)),
        obs=obs,
        var=var,
        obsm={"eye": np.eye(2)},
        layers={"double": np.ones((2, 2)) * 2},
    )

    # Gene addressed by its var name, plus an obs column and one obsm column.
    result = obs_df(adata, keys=["gene2", "obs1"], obsm_keys=[("eye", 0)])
    expected = pd.DataFrame(
        {"gene2": [1, 1], "obs1": [0, 1], "eye-0": [1, 0]}, index=adata.obs_names
    )
    assert np.all(np.equal(result, expected))

    # Same gene addressed through the `gene_symbols` column of `.var`.
    result = obs_df(
        adata,
        keys=["genesymbol2", "obs1"],
        obsm_keys=[("eye", 0)],
        gene_symbols="gene_symbols",
    )
    expected = pd.DataFrame(
        {"genesymbol2": [1, 1], "obs1": [0, 1], "eye-0": [1, 0]}, index=adata.obs_names
    )
    assert np.all(np.equal(result, expected))

    # Expression values read from the "double" layer instead of X.
    result = obs_df(adata, keys=["gene2", "obs1"], layer="double")
    expected = pd.DataFrame({"gene2": [2, 2], "obs1": [0, 1]}, index=adata.obs_names)
    assert np.all(np.equal(result, expected))

    # Every unknown key must be named in the error message.
    badkeys = ["badkey1", "badkey2"]
    with pytest.raises(KeyError) as badkey_err:
        obs_df(adata, keys=badkeys)
    for missing in badkeys:
        assert badkey_err.match(missing)


def test_var_df():
    """Exercise `var_df` with obs names, var columns, a layer and bad keys.

    `var_df` returns one row per variable, so the expected frames are indexed
    by `adata.var_names` (not `adata.obs_names`).
    """
    adata = sc.AnnData(
        X=np.ones((2, 2)),
        obs=pd.DataFrame({"obs1": [0, 1], "obs2": ["a", "b"]}, index=["cell1", "cell2"]),
        var=pd.DataFrame({"gene_symbols": ["genesymbol1", "genesymbol2"]}, index=["gene1", "gene2"]),
        varm={"eye": np.eye(2)},
        layers={"double": np.ones((2, 2)) * 2}
    )
    # An observation's values across variables, a var column and one varm column.
    assert np.all(np.equal(
        var_df(adata, keys=["cell2", "gene_symbols"], varm_keys=[("eye", 0)]),
        pd.DataFrame({"cell2": [1, 1], "gene_symbols": ["genesymbol1", "genesymbol2"], "eye-0": [1, 0]}, index=adata.var_names)
    ))
    # Expression values read from the "double" layer instead of X.
    assert np.all(np.equal(
        var_df(adata, keys=["cell1", "gene_symbols"], layer="double"),
        pd.DataFrame({"cell1": [2, 2], "gene_symbols": ["genesymbol1", "genesymbol2"]}, index=adata.var_names)
    ))
    # Every unknown key must be named in the error message.
    badkeys = ["badkey1", "badkey2"]
    with pytest.raises(KeyError) as badkey_err:
        var_df(adata, keys=badkeys)
    assert all(badkey_err.match(k) for k in badkeys)


def test_rank_genes_groups_df():
    """Check the filtering options of `rank_genes_groups_df` on toy data."""
    # 20 cells x 3 genes; only gene 0 is non-zero, and only in the first 10 cells.
    a = np.zeros((20, 3))
    a[:10, 0] = 5
    n_cells, n_genes = a.shape
    celltypes = ["a"] * 10 + ["b"] * 10
    obs = pd.DataFrame(
        {"celltype": celltypes},
        index=[f"cell{i}" for i in range(n_cells)],
    )
    var = pd.DataFrame(index=[f"gene{i}" for i in range(n_genes)])
    adata = sc.AnnData(a, obs=obs, var=var)
    sc.tl.rank_genes_groups(adata, groupby="celltype", method="wilcoxon")

    dedf = rank_genes_groups_df(adata, "a")
    pval_counts = dedf["pvals"].value_counts()
    assert pval_counts[1.] == 2

    below_fc = rank_genes_groups_df(adata, "a", log2fc_max=.1)
    assert below_fc.shape[0] == 2
    above_fc = rank_genes_groups_df(adata, "a", log2fc_min=.1)
    assert above_fc.shape[0] == 1
    significant = rank_genes_groups_df(adata, "a", pval_cutoff=.9)
    assert significant.shape[0] == 1
Loading

0 comments on commit 9fa7672

Please sign in to comment.