Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make scplot compatible with pegasusio #6

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions scplot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd
import scipy.sparse
from anndata import AnnData
from pegasusio import MultimodalData
from holoviews import dim
from holoviews.plotting.bokeh.callbacks import LinkCallback
from holoviews.plotting.links import Link
Expand All @@ -33,7 +34,7 @@ class __BrushLinkCallbackRange(LinkCallback):
target_handles = ['cds', 'glyph']

source_code = """

target_selected.indices = source_selected.indices;
"""

Expand Down Expand Up @@ -305,7 +306,7 @@ def __bin(df, nbins, coordinate_columns, reduce_function, coordinate_column_to_r
return df.groupby(coordinate_columns, as_index=False).agg(agg_func), df[coordinate_columns]


def violin(adata: AnnData, keys: Union[str, List[str], Tuple[str]], by: str = None,
def violin(adata: Union[AnnData, MultimodalData], keys: Union[str, List[str], Tuple[str]], by: str = None,
width: int = 300, cmap: Union[str, List[str], Tuple[str]] = None, cols: int = None,
use_raw: bool = None, **kwds) -> hv.core.element.Element:
"""
Expand All @@ -322,6 +323,8 @@ def violin(adata: AnnData, keys: Union[str, List[str], Tuple[str]], by: str = No
"""
if cols is None:
cols = 3
if not isinstance(adata, AnnData):
adata = adata.to_anndata()
adata_raw = __get_raw(adata, use_raw)
keys = __to_list(keys)
df = __get_df(adata, adata_raw, keys + ([] if by is None else [by]))
Expand All @@ -342,7 +345,7 @@ def violin(adata: AnnData, keys: Union[str, List[str], Tuple[str]], by: str = No
return layout


def heatmap(adata: AnnData, keys: Union[str, List[str], Tuple[str]], by: str,
def heatmap(adata: Union[AnnData, MultimodalData], keys: Union[str, List[str], Tuple[str]], by: str,
reduce_function: Callable[[np.ndarray], float] = np.mean,
use_raw: bool = None, cmap: Union[str, List[str], Tuple[str]] = 'Reds', **kwds) -> hv.core.element.Element:
"""
Expand All @@ -357,6 +360,8 @@ def heatmap(adata: AnnData, keys: Union[str, List[str], Tuple[str]], by: str,
use_raw: Use `raw` attribute of `adata` if present.
"""

if not isinstance(adata, AnnData):
adata = adata.to_anndata()
adata_raw = __get_raw(adata, use_raw)
keys = __to_list(keys)
df = None
Expand All @@ -378,7 +383,7 @@ def heatmap(adata: AnnData, keys: Union[str, List[str], Tuple[str]], by: str,
return df.hvplot.heatmap(x='feature', y=by, C='value', reduce_function=reduce_function, **keywords)


def scatter(adata: AnnData, x: str, y: str, color: str = None, size: Union[int, str] = None,
def scatter(adata: Union[AnnData, MultimodalData], x: str, y: str, color: str = None, size: Union[int, str] = None,
dot_min=2, dot_max=14, use_raw: bool = None, sort: bool = True, width: int = 400, height: int = 400,
nbins: int = -1, reduce_function: Callable[[np.array], float] = np.max,
cmap: Union[str, List[str], Tuple[str]] = None, palette: Union[str, List[str], Tuple[str]] = None,
Expand All @@ -403,6 +408,8 @@ def scatter(adata: AnnData, x: str, y: str, color: str = None, size: Union[int,
nbins: Number of bins used to summarize plot on a grid. Useful for large datasets. Negative one means automatically bin the plot.
reduce_function: Function used to summarize overlapping cells if nbins is specified
"""
if not isinstance(adata, AnnData):
adata = adata.to_anndata()
return __scatter(adata=adata, x=x, y=y, color=color, size=size, dot_min=dot_min, dot_max=dot_max, use_raw=use_raw,
sort=sort, width=width, height=height, nbins=nbins, reduce_function=reduce_function, cmap=cmap, palette=palette,
is_scatter=True, **kwds)
Expand Down Expand Up @@ -527,7 +534,7 @@ def __scatter(adata: AnnData, x: str, y: str, color=None, size: Union[int, str]
return return_value


def dotplot(adata: AnnData, keys: Union[str, List[str], Tuple[str]], by: str,
def dotplot(adata: Union[AnnData, MultimodalData], keys: Union[str, List[str], Tuple[str]], by: str,
reduce_function: Callable[[np.ndarray], float] = np.mean,
fraction_min: float = 0, fraction_max: float = None, dot_min: int = 1, dot_max: int = 26,
use_raw: bool = None, cmap: Union[str, List[str], Tuple[str]] = 'Reds',
Expand All @@ -549,6 +556,8 @@ def dotplot(adata: AnnData, keys: Union[str, List[str], Tuple[str]], by: str,
sort_function: Optional function that accepts summarized data frame and returns a list of row indices in the order to render in the heatmap.
"""

if not isinstance(adata, AnnData):
adata = adata.to_anndata()
adata_raw = __get_raw(adata, use_raw)
keys = __to_list(keys)
keywords = dict(colorbar=True, ylabel=str(by), xlabel='', padding=0, rot=90, cmap=cmap)
Expand Down Expand Up @@ -634,7 +643,7 @@ def non_zero(g):
return result


def scatter_matrix(adata: AnnData, keys: Union[str, List[str], Tuple[str]], color=None, use_raw: bool = None,
def scatter_matrix(adata: Union[AnnData, MultimodalData], keys: Union[str, List[str], Tuple[str]], color=None, use_raw: bool = None,
**kwds) -> hv.core.element.Element:
"""
Generate a scatter plot matrix.
Expand All @@ -646,6 +655,8 @@ def scatter_matrix(adata: AnnData, keys: Union[str, List[str], Tuple[str]], colo
use_raw: Use `raw` attribute of `adata` if present.
"""

if not isinstance(adata, AnnData):
adata = adata.to_anndata()
adata_raw = __get_raw(adata, use_raw)
keys = __to_list(keys)
if color is not None:
Expand Down Expand Up @@ -680,7 +691,7 @@ def __fix_scatter_colors(adata, df_to_plot, key, is_color_by_numeric, cmap, pale
del keywords[color_keyword_delete]


def embedding(adata: AnnData, basis: Union[str, List[str], Tuple[str]],
def embedding(adata: Union[AnnData, MultimodalData], basis: Union[str, List[str], Tuple[str]],
keys: Union[None, str, List[str], Tuple[str]] = None,
cmap: Union[str, List[str], Tuple[str]] = None, palette: Union[str, List[str], Tuple[str]] = None,
alpha: float = 1, size: float = None, width: int = 400, height: int = 400, sort: bool = True,
Expand Down Expand Up @@ -719,6 +730,8 @@ def embedding(adata: AnnData, basis: Union[str, List[str], Tuple[str]],
if keys is None:
keys = []
basis = __to_list(basis)
if not isinstance(adata, AnnData):
adata = adata.to_anndata()
adata_raw = __get_raw(adata, use_raw)
keys = __to_list(keys)
if tooltips is None:
Expand Down Expand Up @@ -812,17 +825,24 @@ def embedding(adata: AnnData, basis: Union[str, List[str], Tuple[str]],
return layout


def variable_feature_plot(adata: AnnData, **kwds) -> hv.core.element.Element:
def variable_feature_plot(adata: Union[AnnData, MultimodalData], **kwds) -> hv.core.element.Element:
"""
Generate a variable feature plot.

Args:
adata: Annotated data matrix.
"""

if 'hvf_loess' in adata.var:
keywords = dict(x='mean', y='var', y_fit='hvf_loess', color='highly_variable_features',
xlabel='Mean log expression', ylabel='Variance of log expression')
if not isinstance(adata, AnnData):
adata = adata.to_anndata()

if 'hvf_flavor' in adata.uns.keys():
if adata.uns['hvf_flavor'] == 'pegasus':
keywords = dict(x='mean', y='var', y_fit='hvf_loess', color='highly_variable_features',
xlabel='Mean log expression', ylabel='Variance of log expression')
else:
keywords = dict(x='mean', y='dispersion_norm', y_fit=None, color='highly_variable_features',
xlabel='Mean log expression', ylabel='Normalized dispersion')
else:
keywords = dict(x='means', y='dispersions_norm', y_fit=None, color='highly_variable',
xlabel='Mean log expression', ylabel='Normalized dispersion')
Expand All @@ -842,10 +862,14 @@ def variable_feature_plot(adata: AnnData, **kwds) -> hv.core.element.Element:
return scatter(adata, x=x, y=y, xlabel=xlabel, color=color,
ylabel=ylabel, **keywords) * line(adata, x=x, y=y_fit, line_color=line_color)
else:
return scatter(adata, x=x, y=y, color=color, xlabel=xlabel, ylabel=ylabel)
if 'robust' in adata.var:
return scatter(adata[:, adata.var['robust']], x=x, y=y, color=color,
xlabel=xlabel, ylabel=ylabel)
else:
return scatter(adata, x=x, y=y, color=color, xlabel=xlabel, ylabel=ylabel)


def volcano(adata: AnnData, basis: str = 'de_res', x: str = 'log_fold_change', y: str = 't_qval',
def volcano(adata: Union[AnnData, MultimodalData], basis: str = 'de_res', x: str = 'log_fold_change', y: str = 't_qval',
x_cutoff: float = 1, y_cutoff: float = 0.05, cluster_ids: Union[List, Tuple, Set] = None,
**kwds) -> hv.core.element.Element:
"""
Expand Down Expand Up @@ -921,7 +945,7 @@ def volcano(adata: AnnData, basis: str = 'de_res', x: str = 'log_fold_change', y
return result


def composition_plot(adata: AnnData, by: str, condition: str, stacked: bool = True, normalize: bool = True,
def composition_plot(adata: Union[AnnData, MultimodalData], by: str, condition: str, stacked: bool = True, normalize: bool = True,
condition_sort_by: str = None, cmap: Union[str, List[str], Tuple[str]] = None,
**kwds) -> hv.core.element.Element:
"""
Expand All @@ -937,6 +961,8 @@ def composition_plot(adata: AnnData, by: str, condition: str, stacked: bool = Tr
cmap: Color map name (hv.plotting.list_cmaps()) or a list of hex colors. See http://holoviews.org/user_guide/Styling_Plots.html for more information.
"""

if not isinstance(adata, AnnData):
adata = adata.to_anndata()
adata_raw = __get_raw(adata, False)
keys = [by, condition]
adata_df = __get_df(adata, adata_raw, keys)
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

requirements = [
'anndata',
'pegasusio',
'colorcet',
'holoviews',
'hvplot',
Expand Down