---
title: Single-cell Elbow Plot
author: Zafer Kosar
format:
    html:
        code-fold: true
        code-summary: "Show code"
---

In [514]:
# Core scverse libraries
import anndata as ad

# data manipulation
import polars as pl
import scanpy as sc

# ggplot but interactive and python
from lets_plot import (
    LetsPlot,
    aes,
    element_blank,
    element_line,
    element_text,
    geom_blank,
    geom_jitter,
    geom_line,
    geom_point,
    geom_violin,
    geom_text,
    geom_label,
    gggrid,
    ggplot,
    ggsize,
    ggtb,
    guide_colorbar,
    guides,
    labs,
    layer_tooltips,
    scale_color_continuous,
    scale_color_gradient,
    scale_color_hue,
    scale_color_viridis,
    theme,
    theme_classic,
    scale_x_continuous,
    geom_smooth,
    geom_hline,
)

LetsPlot.setup_html()

from typing import TYPE_CHECKING, Literal

from lets_plot.plot.core import PlotSpec
from scipy.optimize import curve_fit
import numpy as np

In [515]:
# Read the preprocessed sample dataset (AnnData that already holds PCA scores)
adata = sc.read(filename="pbmc3k_pped.h5ad")

In [516]:
# Show the AnnData summary; elbow() below reads the PCA scores from obsm["X_pca"]
adata

AnnData object with n_obs × n_vars = 16680 × 2000
    obs: 'sample', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'n_genes', 'leiden'
    var: 'mt', 'ribo', 'hb', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean', 'std'
    uns: 'hvg', 'leiden', 'leiden_colors', 'log1p', 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'

In [517]:
def exp_decay(x, a, b):
    """Exponential decay model ``a * exp(-b * x)`` used to fit the elbow curve.

    Parameters
    ----------
    x : array-like
        Principal-component indices.
    a : float
        Amplitude (value of the curve at ``x = 0``).
    b : float
        Decay rate; positive values give a decreasing curve.
    """
    return a * np.exp(-x * b)


def pc_fit(df: pl.DataFrame, p0=None) -> tuple[pl.DataFrame, float]:
    """Fit an exponential decay to per-PC variance and locate the elbow cutoff.

    Parameters
    ----------
    df
        Frame with a numeric ``PC`` column (component index) and a ``variance``
        column (per-PC variance, possibly log-transformed by the caller).
    p0
        Optional initial guess ``(a, b)`` forwarded to ``scipy.optimize.curve_fit``.
        ``None`` keeps SciPy's default starting point (all ones).

    Returns
    -------
    tuple[pl.DataFrame, float]
        ``df`` with an added ``exp_fit`` column holding the fitted curve, and the
        PC position at which the fitted curve falls to ``max(exp_fit) / e**2``.
    """
    # Work in float64: the variance column may be f32 and PC is an integer column.
    x = df["PC"].to_numpy().astype(np.float64)
    values = df["variance"].to_numpy().astype(np.float64)
    popt, _ = curve_fit(exp_decay, x, values, p0=p0)
    a, b = popt
    fit = exp_decay(x, a, b)
    # Threshold two "mean lifetimes" (1/b) below the curve's peak, i.e. peak / e^2.
    threshold = np.max(fit) / np.exp(2.0)
    # Invert a * exp(-b * x) = threshold to find the PC where the fit crosses it.
    x_cutoff = np.log(threshold / a) / -b
    return df.with_columns(pl.Series("exp_fit", fit)), x_cutoff

In [518]:
def theme_elbow(func, *_unused_args, **_unused_kwargs):
    """Decorate a plot-returning function with the shared elbow-plot styling.

    The wrapper calls ``func`` and adds ``theme_classic``, Arial text sizes, a
    y-axis label and a 600x400 figure size to the returned lets-plot object.

    Parameters
    ----------
    func
        Callable returning a lets-plot plot that supports ``+=``.
    *_unused_args, **_unused_kwargs
        Ignored; accepted only so existing call sites keep working.

    Returns
    -------
    Callable
        Wrapper that preserves ``func``'s name, docstring and signature metadata.
    """
    import functools
    import inspect

    signature = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        plot = func(*args, **kwargs)
        # Resolve the effective ``scale`` argument (explicit or default) so the
        # y-axis is not labelled plain "Variance" when log-variance was plotted.
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        y_label = "log(Variance)" if bound.arguments.get("scale") == "log" else "Variance"
        plot += (
            theme_classic()
            + theme(
                text=element_text(color="#1f1f1f", family="Arial", size=12),
                axis_text_x=element_text(color="#1f1f1f", family="Arial", size=14),
                axis_text_y=element_text(color="#1f1f1f", family="Arial", size=14),
                axis_title=element_text(color="#1f1f1f", family="Arial", size=18),
            )
            + labs(y=y_label)
            + ggsize(600, 400)
        )
        return plot

    return wrapper

In [None]:
@theme_elbow
def elbow(
    data: sc.AnnData,
    n_pcs: int = 50,
    *,
    scale: Literal["log", "linear"] = "log",
    fit: bool = True,
    interactive: bool = False,
    color_hline: str = "#7FFFD4",
    color_point: str = "#6f6f6f",
    color_line: str = "#d26868",
    shadow: bool = True,
    hline: bool = True,
):
    """Plot the variance of each leading PC from ``data.obsm["X_pca"]`` as an elbow plot.

    Parameters
    ----------
    data
        AnnData object with PCA scores stored in ``obsm["X_pca"]``.
    n_pcs
        Number of leading principal components to plot (1 <= n_pcs <= stored PCs).
    scale
        ``"log"`` plots the natural log of the variance; ``"linear"`` plots it as is.
    fit
        Overlay an exponential-decay fit (via ``pc_fit``) and label the PC at which
        the fit falls to ``max / e**2``.
    interactive
        Currently unused; kept for API compatibility.
    color_hline, color_point, color_line
        Colours of the threshold line, the variance points and the fitted curve.
    shadow
        Draw a wide translucent line underneath the fitted curve.
    hline
        Draw a dashed horizontal line at the ``max / e**2`` threshold.

    Returns
    -------
    PlotSpec
        The lets-plot figure.

    Raises
    ------
    ValueError
        If ``data`` is not an AnnData, ``scale`` is not ``"log"``/``"linear"``,
        or ``n_pcs`` is outside the number of stored PCs.
    """
    if not isinstance(data, sc.AnnData):
        raise ValueError("data must be an AnnData object")
    # Validate cheap arguments before doing any work on the PCA matrix.
    if scale not in ("log", "linear"):
        raise ValueError("scale must be either 'log' or 'linear'")

    pca_scores = data.obsm["X_pca"]
    n_available = pca_scores.shape[1]
    if not 1 <= n_pcs <= n_available:
        raise ValueError(f"n_pcs must be between 1 and {n_available}, got {n_pcs}")

    # Columns are named by 1-based PC index so they become the "PC" column after transposing.
    col_names = [str(i) for i in range(1, n_pcs + 1)]
    frame = pl.from_numpy(pca_scores[:, :n_pcs], schema=col_names)
    # Variance of each PC's scores across cells, reshaped into long (PC, variance) form.
    frame = (
        frame.select(pl.all().var())
        .transpose(include_header=True, header_name="PC", column_names=["variance"])
        .with_columns(pl.col("PC").cast(pl.Int16))
    )
    if scale == "log":
        frame = frame.with_columns(pl.col("variance").log())

    elbw = ggplot(data=frame) + geom_point(aes(x="PC", y="variance"), size=5, color=color_point)

    if fit:
        frame, x_intercept = pc_fit(frame)
        if shadow:
            elbw += geom_line(
                data=frame, mapping=aes(x="PC", y="exp_fit"), size=4, color=color_line, alpha=0.2
            )
        elbw += geom_line(data=frame, mapping=aes(x="PC", y="exp_fit"), size=2, color=color_line)

        # Same threshold pc_fit uses to compute x_intercept: peak of the fit divided by e^2.
        threshold = frame["exp_fit"].max() / np.exp(2.0)
        if hline:
            elbw += geom_hline(yintercept=threshold, color=color_hline, size=1, linetype="dashed")

        # Anchor the label where the fitted curve crosses the threshold line.
        elbw += geom_label(
            x=x_intercept,
            y=threshold,
            label=f"X intercept = {x_intercept:.2f}",
            hjust=0.5,
            vjust=0.5,
            color="#3f3f3f",
            size=8,
            fontface="bold",
        )

    return elbw

In [564]:
# Elbow plot of the first 40 PCs on a log-variance scale
elbow(adata, n_pcs=40, scale="log")

  return a * np.exp(-x * b)


In [521]:
n_pcs = 40
# One column per PC, named by its 1-based index
col_names = [str(pc) for pc in range(1, n_pcs + 1)]
frame = pl.from_numpy(adata.obsm["X_pca"][:, :n_pcs], schema=col_names)

In [522]:
# Variance of each PC's scores across cells, as a long (PC, variance) table
frame = (
    frame
    .select(pl.all().var())
    .transpose(column_names=["variance"], header_name="PC", include_header=True)
    .with_columns(PC=pl.col("PC").cast(pl.Int16))
)
frame

PC,variance
i16,f32
1,38.158001
2,31.63534
3,25.29818
4,16.957245
5,9.686678
…,…
36,1.321278
37,1.308913
38,1.298041
39,1.295106


In [523]:
# Natural-log transform of the variance column (what elbow(scale="log") plots)
frame = frame.with_columns(variance=pl.col("variance").log())
frame

PC,variance
i16,f32
1,3.641736
2,3.454275
3,3.230733
4,2.830695
5,2.270752
…,…
36,0.2786
37,0.269197
38,0.260857
39,0.258592


In [524]:
# Fit the exponential decay; displays (frame with "exp_fit", cutoff PC)
pc_fit(df=frame)

  return a * np.exp(-x * b)


(shape: (40, 3)
 ┌─────┬──────────┬──────────┐
 │ PC  ┆ variance ┆ exp_fit  │
 │ --- ┆ ---      ┆ ---      │
 │ i16 ┆ f32      ┆ f64      │
 ╞═════╪══════════╪══════════╡
 │ 1   ┆ 3.641736 ┆ 3.525122 │
 │ 2   ┆ 3.454275 ┆ 3.240479 │
 │ 3   ┆ 3.230733 ┆ 2.97882  │
 │ 4   ┆ 2.830695 ┆ 2.738289 │
 │ 5   ┆ 2.270752 ┆ 2.51718  │
 │ …   ┆ …        ┆ …        │
 │ 36  ┆ 0.2786   ┆ 0.185098 │
 │ 37  ┆ 0.269197 ┆ 0.170152 │
 │ 38  ┆ 0.260857 ┆ 0.156412 │
 │ 39  ┆ 0.258592 ┆ 0.143783 │
 │ 40  ┆ 0.246411 ┆ 0.132173 │
 └─────┴──────────┴──────────┘,
 np.float64(24.75467472593728))