In [1]:
from __future__ import annotations

from collections.abc import Sequence
from typing import Iterable, Literal

import polars as pl
from anndata import AnnData, read_h5ad
from scipy.sparse import issparse

from cellestial.util.errors import KeyNotFoundError

from lets_plot import LetsPlot
# Configure lets-plot to emit HTML output so figures render inline in the notebook.
LetsPlot.setup_html()



In [2]:
# Path to the PBMC 3k AnnData file ("pped" presumably = preprocessed — confirm),
# relative to the notebook's working directory.
PBMC3K_H5AD = "data/pbmc3k_pped.h5ad"
data = read_h5ad(PBMC3K_H5AD)

In [4]:
from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING

import polars as pl
from anndata import AnnData
from lets_plot import (
    aes,
    geom_point,
    ggplot,
    ggtb,
    scale_color_gradient,
    scale_fill_gradient,
)
from lets_plot.plot.core import PlotSpec

from cellestial.frames import build_frame
from cellestial.themes import _THEME_DOTPLOT

if TYPE_CHECKING:
    from collections.abc import Sequence


def dotplot(
    data: AnnData,
    keys: Sequence[str],
    group_by: str,
    *,
    threshold: float = 0,
    variables_name: str = "gene",
    value_name: str = "expression",
    color_low: str = "#e6e6e6",
    color_high: str = "#D2042D",
    fill: bool = False,
    sort_by: str | Sequence[str] | None = None,
    sort_order: Literal["ascending", "descending"] = "descending",
    percentage_key: str = "pct_exp",
    mean_key: str = "avg_exp",
    show_tooltips: bool = True,
    interactive: bool = False,
    **geom_kwargs,
) -> PlotSpec:
    """
    Dotplot of per-group expression statistics.

    For every combination of a ``group_by`` value and a variable in ``keys``
    one dot is drawn: its size is the percentage of observations whose value
    is above ``threshold`` and its colour (or fill) is the mean value.

    Parameters
    ----------
    data : AnnData
        The AnnData object of the single cell data.
    keys : str | Sequence[str]
        The variable keys or names to include in the dotplot. A single string
        is treated as one key (not as a sequence of characters).
    group_by : str
        The key to group the data by.
    threshold : float, default=0
        A value strictly greater than this counts as "expressed".
    variables_name : str, default='gene'
        The name of the variable column in the long format.
    value_name : str, default='expression'
        The name of the value column in the long format.
    color_low : str, default='#e6e6e6'
        The low color for the gradient.
    color_high : str, default='#D2042D'
        The high color for the gradient.
    fill : bool, default=False
        Whether to map the mean to the fill aesthetic instead of color.
    sort_by : str | Sequence[str] | None, default=None
        Column(s) of the statistics frame to sort by (for example
        ``group_by``, ``variables_name``, ``mean_key`` or ``percentage_key``).
        Integer-like group labels are always first sorted numerically in
        descending order; ``sort_by`` is then applied on top with a stable sort.
    sort_order : {'ascending', 'descending'}, default='descending'
        The order used for ``sort_by``.
    percentage_key : str, default='pct_exp'
        The name of the percentage column.
    mean_key : str, default='avg_exp'
        The name of the mean expression column.
    show_tooltips : bool, default=True
        Whether to show tooltips. When False, ``tooltips='none'`` is passed to
        the geom_point layer unless ``tooltips`` is given in ``geom_kwargs``.
    interactive : bool, default=False
        Whether to make the plot interactive.
    **geom_kwargs : Any
        Additional keyword arguments for the geom_point layer.

    Returns
    -------
    PlotSpec
        Dotplot.

    Raises
    ------
    TypeError
        If ``data`` is not an ``AnnData`` object.
    ValueError
        If ``keys`` is empty or ``sort_order`` is not 'ascending'/'descending'.
    """
    # HANDLE: Argument validation
    if not isinstance(data, AnnData):
        msg = "data must be an `AnnData` object"
        raise TypeError(msg)
    if sort_order not in ("ascending", "descending"):
        msg = f"sort_order must be 'ascending' or 'descending', got {sort_order!r}"
        raise ValueError(msg)
    # A bare str is itself a Sequence[str]; wrap it so it is one key and the
    # membership test below is not a substring test.
    keys = [keys] if isinstance(keys, str) else list(keys)
    if not keys:
        msg = "keys must contain at least one variable name"
        raise ValueError(msg)

    # BUILD: dataframe
    frame = build_frame(data=data, axis=0, variable_keys=keys)
    key_set = set(keys)
    index_columns = [col for col in frame.columns if col not in key_set]

    #  CRITICAL PARTS: Dataframe Operations
    # 1. Unpivot frame: one row per (observation, variable)
    long_frame = frame.unpivot(
        on=keys,
        index=index_columns,
        variable_name=variables_name,
        value_name=value_name,
    )
    # 2. Aggregate and compute stats. maintain_order=True keeps the groups in
    #    order of first appearance (i.e. following `keys`), so the row order fed
    #    to the plot does not depend on polars' hash-based grouping order.
    stats_frame = long_frame.group_by(
        [group_by, variables_name], maintain_order=True
    ).agg(
        [
            pl.col(value_name).mean().alias(mean_key),
            (pl.col(value_name) > threshold).mean().mul(100).alias(percentage_key),
        ]
    )

    # HANDLE: Sorting
    # Pseudo-categorical integer labels (e.g. "0", "1", ..., "10") would sort
    # lexicographically as strings, so sort them numerically instead. The
    # non-strict cast turns unparsable labels into nulls; if no new nulls
    # appear, every label is an integer.
    group_col = stats_frame[group_by]
    as_int = group_col.cast(pl.String).cast(pl.Int64, strict=False)
    group_is_integer = as_int.null_count() == group_col.null_count()
    if group_is_integer:
        stats_frame = stats_frame.with_columns(as_int.alias(group_by)).sort(
            group_by, descending=True, maintain_order=True
        )

    # User-requested sorting; stable so ties keep the previous ordering.
    if sort_by is not None:
        stats_frame = stats_frame.sort(
            by=sort_by,
            descending=(sort_order == "descending"),
            maintain_order=True,
        )

    # Cast integer labels back to a discrete (categorical) type for plotting.
    if group_is_integer:
        stats_frame = stats_frame.with_columns(
            pl.col(group_by).cast(pl.String).cast(pl.Categorical)
        )

    # BUILD: Dotplot
    if not show_tooltips:
        # lets-plot hides a layer's tooltips with tooltips="none"; an explicit
        # `tooltips` passed through geom_kwargs takes precedence.
        geom_kwargs.setdefault("tooltips", "none")
    # The mean is mapped to either the color or the fill aesthetic, with the
    # matching gradient scale.
    mean_aesthetic = "fill" if fill else "color"
    gradient_scale = scale_fill_gradient if fill else scale_color_gradient
    dtplt = (
        ggplot(stats_frame, aes(x=variables_name, y=group_by))
        + geom_point(
            aes(size=percentage_key, **{mean_aesthetic: mean_key}), **geom_kwargs
        )
        + gradient_scale(low=color_low, high=color_high)
    )

    # ADD: layers
    dtplt += _THEME_DOTPLOT

    # HANDLE: interactive
    if interactive:
        dtplt += ggtb(size_zoomin=-1)

    return dtplt


In [5]:
# Number of highly variable genes to show in the dotplot (first N in `data.var` order).
N_HVG = 50
hvg_table = data.var[data.var["highly_variable"]]
hvg = hvg_table.index[:N_HVG].tolist()
hvg

['AL390719.2',
 'C1QTNF12',
 'AL162741.1',
 'LINC01786',
 'AL391244.2',
 'TMEM52',
 'AL589739.1',
 'PLCH2',
 'AL513320.1',
 'CHD5',
 'AL021155.5',
 'VPS13D',
 'AL031283.1',
 'FAM131C',
 'LINC01772',
 'LINC01783',
 'UBR4',
 'AL031727.2',
 'AL031005.2',
 'NBPF3',
 'WNT4',
 'C1QA',
 'C1QC',
 'C1QB',
 'LINC01355',
 'ID3',
 'AL031432.1',
 'STMN1',
 'CNKSR1',
 'ZNF683',
 'AL513365.2',
 'AL512408.1',
 'SFN',
 'AL020997.4',
 'AL360012.1',
 'OPRD1',
 'AC114488.2',
 'AL136115.2',
 'AL136115.1',
 'DCDC2B',
 'CSMD2',
 'CSF3R',
 'DNALI1',
 'POU3F1',
 'AL356055.1',
 'AL603839.3',
 'TMEM269',
 'ERMAP',
 'SLC2A1',
 'AL139289.1']

In [15]:
from lets_plot import scale_y_reverse, scale_x_reverse, scale_y_discrete

# Dotplot of the selected HVGs per Leiden cluster. A cell counts as expressing
# a gene when its value exceeds 1. `shape` and `stroke` are forwarded to
# geom_point; shape=21 is the fillable circle marker, so the mean is mapped to
# the fill aesthetic (fill=True) with a snow-to-red gradient.
dp = dotplot(
    data=data,
    keys=hvg,
    group_by="leiden",
    threshold=1,
    fill=True,
    color_low="snow",
    color_high="red",
    interactive=False,
    shape=21,
    stroke=0.2,
)
dp
