Merge pull request #250 from lilab-bcb/boli
Several important updates
yihming committed Jul 5, 2022
2 parents 3bf91da + 54db5cb commit e16428b
Showing 11 changed files with 294 additions and 53 deletions.
3 changes: 3 additions & 0 deletions pegasus/__init__.py
@@ -30,6 +30,7 @@
filter_data,
identify_robust_genes,
log_norm,
arcsinh_transform,
select_features,
pca,
pc_transform,
@@ -75,6 +76,7 @@
run_scvi,
train_scarches_scanvi,
predict_scarches_scanvi,
largest_variance_from_random_matrix,
)
from .annotate_cluster import infer_cell_types, annotate, infer_cluster_names
from .misc import search_genes, search_de_genes, find_outlier_clusters
@@ -94,6 +96,7 @@
ridgeplot,
wordcloud,
plot_gsea,
elbowplot,
)
from . import pseudo

2 changes: 1 addition & 1 deletion pegasus/annotate_cluster/human_lung_cell_markers.json
@@ -268,7 +268,7 @@
}
],
"subtypes" : {
"title" : "SMC subtype markers",
"title" : "Fibro/Myofib subtype markers",
"cell_types" : [
{
"name" : "Adventitial fibroblast",
1 change: 1 addition & 0 deletions pegasus/plotting/__init__.py
@@ -14,4 +14,5 @@
ridgeplot,
wordcloud,
plot_gsea,
elbowplot,
)
187 changes: 146 additions & 41 deletions pegasus/plotting/plot_library.py
@@ -17,7 +17,7 @@
import logging
logger = logging.getLogger(__name__)

from pegasus.tools import X_from_rep, slicing
from pegasus.tools import X_from_rep, slicing, largest_variance_from_random_matrix
from .plot_utils import (
_transform_basis,
_get_nrows_and_ncols,
@@ -36,14 +36,16 @@

def scatter(
data: Union[MultimodalData, UnimodalData, anndata.AnnData],
attrs: Union[str, List[str]] = None,
basis: Optional[str] = "umap",
attrs: Optional[Union[str, List[str]]] = None,
basis: Optional[Union[str, List[str]]] = "umap",
components: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = (1, 2),
matkey: Optional[str] = None,
restrictions: Optional[Union[str, List[str]]] = None,
show_background: Optional[bool] = False,
fix_corners: Optional[bool] = True,
alpha: Optional[Union[float, List[float]]] = 1.0,
legend_loc: Optional[Union[str, List[str]]] = "right margin",
legend_fontsize: Optional[Union[int, List[int]]] = 10,
legend_ncol: Optional[str] = None,
palettes: Optional[Union[str, List[str]]] = None,
cmaps: Optional[Union[str, List[str]]] = "YlOrRd",
@@ -71,8 +73,10 @@ def scatter(
Use current selected modality in data.
attrs: ``str`` or ``List[str]``, default: None
Color scatter plots by attrs. Each attribute in attrs can be one key in data.obs, data.var_names (e.g. one gene) or data.obsm (attribute has the format of 'obsm_key@component', like 'X_pca@0'). If one attribute is categorical, a palette will be used to color each category separately. Otherwise, a color map will be used. If no attributes are provided, plot the basis for all data.
basis: ``str``, optional, default: ``umap``
Basis to be used to generate scatter plots. Can be either 'umap', 'tsne', 'fitsne', 'fle', 'net_tsne', 'net_fitsne', 'net_umap' or 'net_fle'.
basis: ``str`` or ``List[str]``, optional, default: ``umap``
Basis to be used to generate scatter plots. Can be either 'pca', 'diffmap', 'umap', 'tsne', 'fitsne', 'fle', 'net_tsne', 'net_fitsne', 'net_umap' or 'net_fle'. If `basis` is a list, each element in `attrs` will be plotted for each basis in `basis`.
components: ``Tuple[int, int]`` or ``List[Tuple[int, int]]``, optional, default: ``(1, 2)``
Components in the basis to be used; defaults to the first two components. If `components` is a list, every listed component pair will be plotted for each element in `attrs` and each `basis`.
matkey: ``str``, optional, default: None
If matkey is set, select matrix with matkey as keyword in the current modality. Only works for MultimodalData or UnimodalData objects.
restrictions: ``str`` or ``List[str]``, optional, default: None
@@ -85,6 +89,8 @@
Alpha value for blending, from 0.0 (transparent) to 1.0 (opaque). If this is a list, the length must match attrs, which means we set a separate alpha value for each attribute.
legend_loc: ``str`` or ``List[str]``, optional, default: ``right margin``
Legend location. Can be either "right margin" or "on data". If a list is provided, set 'legend_loc' for each attribute in 'attrs' separately.
legend_fontsize: ``int`` or ``List[int]``, optional, default: ``10``
Legend fontsize. If a list is provided, set 'legend_fontsize' for each attribute in 'attrs' separately.
legend_ncol: ``str``, optional, default: None
Only applicable if legend_loc == "right margin". Set number of columns used to show legends.
palettes: ``str`` or ``List[str]``, optional, default: None
@@ -131,56 +137,92 @@
>>> pg.scatter(data, attrs=['louvain_labels', 'Channel'], basis='fitsne')
>>> pg.scatter(data, attrs=['CD14', 'TRAC'], basis='umap')
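>>> # Hypothetical sketches (not from the original examples) of the new list-valued 'basis' and
>>> # 'components' arguments, assuming the corresponding embeddings (e.g. X_pca, X_umap, X_fitsne)
>>> # exist in data.obsm:
>>> pg.scatter(data, attrs='louvain_labels', basis='pca', components=[(1, 2), (2, 3)])
>>> pg.scatter(data, attrs=['CD14', 'TRAC'], basis=['umap', 'fitsne'])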
"""
if attrs is None:
attrs = ['_all'] # default, plot all points
if palettes is None:
palettes = '_all:slategrey'
elif not is_list_like(attrs):
attrs = [attrs]
nattrs = len(attrs)

if not isinstance(data, anndata.AnnData):
cur_matkey = data.current_matrix()

if matkey is not None:
assert not isinstance(data, anndata.AnnData)
data.select_matrix(matkey)

x = data.obsm[f"X_{basis}"][:, 0]
y = data.obsm[f"X_{basis}"][:, 1]

# four corners of the plot
corners = np.array(np.meshgrid([x.min(), x.max()], [y.min(), y.max()])).T.reshape(-1, 2)
if attrs is None:
attrs = ['_all'] # default, plot all points
if palettes is None:
palettes = '_all:slategrey'
elif not is_list_like(attrs):
attrs = [attrs]

if isinstance(basis, str):
basis = [basis]
if isinstance(components, tuple):
components = [components]

basis = _transform_basis(basis)
global_marker_size = _get_marker_size(x.size) if marker_size is None else marker_size
nrows, ncols = _get_nrows_and_ncols(nattrs, nrows, ncols)
fig, axes = _get_subplot_layouts(nrows=nrows, ncols=ncols, panel_size=panel_size, dpi=dpi, left=left, bottom=bottom, wspace=wspace, hspace=hspace, squeeze=False)
# check validity for basis and components
max_comp = max(max([x[0] for x in components]), max([x[1] for x in components]))
for basis_key in basis:
rep = f"X_{basis_key}"
if rep not in data.obsm:
raise KeyError(f"Basis {basis_key} does not exist!")
if data.obsm[rep].shape[1] < max_comp:
raise KeyError(f"Basis {basis_key} only has {data.obsm[rep].shape[1]} components, less than max component {max_comp} specified in components!")

nattrs = len(attrs)
nbasis = len(basis)
ncomps = len(components)
nfigs = nattrs * nbasis * ncomps
share_xy = (nbasis == 1) and (ncomps == 1)
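# Share axes across panels only when a single basis and a single component pair are plotted,
# so that every panel uses the same coordinate system.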


if not is_list_like(alpha):
alpha = [alpha] * nattrs

if not is_list_like(legend_loc):
legend_loc = [legend_loc] * nattrs
legend_fontsize = [5 if x == "on data" else 10 for x in legend_loc]

if not is_list_like(legend_fontsize):
legend_fontsize = [legend_fontsize] * nattrs

palettes = DictWithDefault(palettes)
cmaps = DictWithDefault(cmaps)
restr_obj = RestrictionParser(restrictions)
restr_obj.calc_default(data)

global_marker_size = None

nrows, ncols = _get_nrows_and_ncols(nfigs, nrows, ncols)
fig, axes = _get_subplot_layouts(nrows=nrows, ncols=ncols, panel_size=panel_size, dpi=dpi, left=left, bottom=bottom, wspace=wspace, hspace=hspace, squeeze=False, sharex=share_xy, sharey=share_xy)

for i in range(nrows):
for j in range(ncols):
ax = axes[i, j]
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
if i * ncols + j >= nfigs:
ax.set_frame_on(False)

if i * ncols + j < nattrs:
pos = i * ncols + j
attr = attrs[pos]
offset_start_ = 0
offset_inc_ = nbasis * ncomps
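# Panel ordering: each attribute occupies a contiguous block of nbasis * ncomps panels,
# one panel per (basis, components) combination.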
for basis_key in basis:
basis_ = _transform_basis(basis_key)
for comp_key in components:
x = data.obsm[f"X_{basis_key}"][:, comp_key[0]-1]
y = data.obsm[f"X_{basis_key}"][:, comp_key[1]-1]

# four corners of the plot
corners = np.array(np.meshgrid([x.min(), x.max()], [y.min(), y.max()])).T.reshape(-1, 2)

if global_marker_size is None:
global_marker_size = _get_marker_size(x.size) if marker_size is None else marker_size

x_label = f"{basis_}{comp_key[0]}"
y_label = f"{basis_}{comp_key[1]}"

pos = offset_start_
for attr_id, attr in enumerate(attrs):
i = pos // ncols
j = pos % ncols
ax = axes[i, j]

if attr == '_all': # if default
values = pd.Categorical.from_codes(np.zeros(data.shape[0], dtype=int), categories=['cell'])
@@ -219,7 +261,7 @@ def scatter(
c=values[selected],
s=local_marker_size,
marker=".",
alpha=alpha[pos],
alpha=alpha[attr_id],
edgecolors="none",
cmap=cmap,
vmin=vmin,
@@ -232,7 +274,7 @@
y[selected] * scale_factor,
c=values[selected],
s=local_marker_size,
alpha=alpha[pos],
alpha=alpha[attr_id],
edgecolors="none",
cmap=cmap,
vmin=vmin,
@@ -245,7 +287,6 @@
rect = [left + width * (1.0 + 0.05), bottom, width * 0.1, height]
ax_colorbar = fig.add_axes(rect)
fig.colorbar(img, cax=ax_colorbar)

else:
# Categorical attribute
labels, with_background = _generate_categories(values, restr_obj.get_satisfied(data, attr))
@@ -261,10 +302,10 @@
for k, cat in enumerate(labels.categories):
idx = labels == cat
if idx.sum() > 0:
scatter_kwargs = {"alpha": alpha[pos], "edgecolors": "none", "rasterized": True}
scatter_kwargs = {"alpha": alpha[attr_id], "edgecolors": "none", "rasterized": True}

if cat != "":
if (legend_loc[pos] != "on data") and (scale_factor is None):
if (legend_loc[attr_id] != "on data") and (scale_factor is None):
scatter_kwargs["label"] = cat
else:
text_list.append((np.median(x[idx]), np.median(y[idx]), cat))
@@ -293,35 +334,38 @@
_plot_corners(ax, corners, local_marker_size)

if attr != '_all':
if legend_loc[pos] == "right margin":
if legend_loc[attr_id] == "right margin":
if scale_factor is not None:
for k, cat in enumerate(labels.categories):
ax.scatter([], [], c=palette[k], label=cat)
legend = ax.legend(
loc="center left",
bbox_to_anchor=(1, 0.5),
frameon=False,
fontsize=legend_fontsize[pos],
fontsize=legend_fontsize[attr_id],
ncol=_get_legend_ncol(label_size, legend_ncol),
)
for handle in legend.legendHandles:
handle.set_sizes([300.0 if scale_factor is None else 100.0])
elif legend_loc[pos] == "on data":
elif legend_loc[attr_id] == "on data":
texts = []
for px, py, txt in text_list:
texts.append(ax.text(px, py, txt, fontsize=legend_fontsize[pos], fontweight = "bold", ha = "center", va = "center"))
texts.append(ax.text(px, py, txt, fontsize=legend_fontsize[attr_id], fontweight = "bold", ha = "center", va = "center"))
# from adjustText import adjust_text
# adjust_text(texts, arrowprops=dict(arrowstyle='-', color='k', lw=0.5))

if attr != '_all':
ax.set_title(attr)
else:
ax.set_frame_on(False)

if i == nrows - 1:
ax.set_xlabel(f"{basis}1")
if j == 0:
ax.set_ylabel(f"{basis}2")
if (share_xy and (i + 1) * ncols + j >= nfigs) or (not share_xy):
ax.set_xlabel(x_label)

if (share_xy and j == 0) or (not share_xy):
ax.set_ylabel(y_label)

pos += offset_inc_

offset_start_ += 1

# Reset current matrix if needed.
if not isinstance(data, anndata.AnnData):
@@ -2169,11 +2213,72 @@ def plot_gsea(
df['NES Abs'] = np.abs(df['NES'])
df['pathway'] = df['pathway'].map(lambda x: ' '.join(x.split('_')))

fig, axes = _get_subplot_layouts(panel_size=panel_size, nrows=2, dpi=dpi, left=0.6, hspace=0.2)
fig, axes = _get_subplot_layouts(panel_size=panel_size, nrows=2, dpi=dpi, left=0.6, hspace=0.2, sharey=False)
df_up = df.loc[df['NES']>0]
_make_one_gsea_plot(df_up, axes[0], color='red')
df_dn = df.loc[df['NES']<0]
_make_one_gsea_plot(df_dn, axes[1], color='green')
axes[1].set_xlabel('-log10(q-value)')

return fig if return_fig else None


def elbowplot(
data: Union[MultimodalData, UnimodalData],
rep: str = "pca",
pval: str = "0.05",
panel_size: Optional[Tuple[float, float]] = (6, 4),
return_fig: Optional[bool] = False,
dpi: Optional[float] = 300.0,
**kwargs,
) -> Union[plt.Figure, None]:
"""Generate Elbowplot and suggest n_comps to select based on random matrix theory (see utils.largest_variance_from_random_matrix).
Parameters
----------
data : ``UnimodalData`` or ``MultimodalData`` object.
The main data object.
rep: ``str``, optional, default: ``pca``
Representation to consider, either "pca" or "tsvd".
pval: ``str``, optional (default: "0.05").
P value cutoff on the null distribution (random matrix), choosing from "0.01" and "0.05".
panel_size: `tuple`, optional (default: `(6, 4)`)
The plot size (width, height) in inches.
return_fig: ``bool``, optional, default: ``False``
Return a ``Figure`` object if ``True``; return ``None`` otherwise.
dpi: ``float``, optional, default: ``300.0``
The resolution in dots per inch.
Returns
-------
`Figure` object
A ``matplotlib.figure.Figure`` object containing the elbow plot if ``return_fig == True``.
Update ``data.uns``:
* ``{rep}_ncomps``: Recommended number of components to pick.
Examples
--------
>>> fig = pg.elbowplot(data, dpi = 500)
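>>> # Hypothetical follow-up (sketch): reuse the suggested component count stored by this call,
>>> # assuming the default rep="pca" was used.
>>> n_comps = data.uns["pca_ncomps"]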
"""
assert rep in data.uns
repf = data.uns[f"{rep}_features"]
nfeatures = data.var[repf].sum() if repf is not None else data.shape[1]
thre = largest_variance_from_random_matrix(data.shape[0], nfeatures, pval)
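# Keep components whose explained variance exceeds the largest variance expected from a
# pure-noise (random) matrix of the same dimensions at the given p-value cutoff.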
ncomps = (data.uns[rep]["variance"] > thre).sum()
data.uns[f"{rep}_ncomps"] = ncomps
logger.info(f"Selecting {ncomps} is recommended!")

fig, ax = _get_subplot_layouts(panel_size=panel_size, dpi=dpi)
ax.scatter(range(1, data.uns[rep]["variance"].size + 1), data.uns[rep]["variance"], s=8, c='k')
ax.set_yscale('log')
ax.set_xlabel(rep.upper())
ax.set_ylabel("Variance")
ax.axvline(x = ncomps + 0.5, ls = "--", c = "r", linewidth=1)

return fig if return_fig else None
2 changes: 2 additions & 0 deletions pegasus/tools/__init__.py
@@ -15,6 +15,7 @@
predefined_signatures,
predefined_pathways,
load_signatures_from_file,
largest_variance_from_random_matrix,
)

from .preprocessing import (
@@ -24,6 +25,7 @@
identify_robust_genes,
_run_filter_data,
log_norm,
arcsinh_transform,
select_features,
pca,
pc_transform,
